当人工智能遇上视频直播——基于 Agora Web SDK 实现目标识别

我的上一篇文章介绍了如何自定义直播画面。既然可以自定义直播画面了,我们就能实现更多功能。在这篇文章里,我们可以将可以看到,集合Agora Web SDK,把现在流行的AI功能集成到我们的直播中是一件多么简单的事情。

1 目标识别

我们可以将目标识别定义为根据物体的颜色、形状、纹理或者其他特征对目标进行特征化,根据这个特征我们可以确认一幅图像包含什么样的物体,以及这些物体的具体位置。当图像变为一个序列,形成一个视频,目标识别则变为目标追踪,目标物体的位置和大小会随着图像序列的变化而变化。传统的目标识别依赖于特征的人工选取,通常效率并不是很高,现代深度学习技术很好的解决了这个问题,其中最常用的深度学习框架是Google的TensorFlow,而最常用的深度学习模型则是卷积神经网络(CNN)。

一个比较常见的目标检测工作管线可以描述成这样:首先分析输入进来的图像,然后进行区域划分,基于图像算法给出几个可能是物体的位置,大多数以矩形区域的形式标注,然后使用CNN进行特征提取,再之后进行训练和分类。分类是深度学习中最经典并且最常见的一种任务,简单理解就是给出一个输入的数据,并且预先定义好一批标签,机器学习的任务就是给这个输入的数据选择正确的标签,并作为结果返回。因此,我们对图像中几个可能是物体的区域进行识别后,给出的汇总结果就是最终这幅图像目标检测的结果。
这个工作管线的核心是把目标识别的工作划分为了2个阶段——区域划分和区域标注,因此我们又称之为2阶段检测 (Two-Stage Methods)。与之相对的还存在一种更快的检测方法,我们称之为1阶段检测 (One-Stage Methods),顾名思义,1阶段检测同时进行区域划分和区域标注,因此大大增强了目标检测的速度,我们熟知的YOLO就是1阶段检测中的典型代表。当然,这里我们不做过多展开,在后面的内容中你将会看到,即使对深度学习原理不太了解,也不妨碍我们使用现成的工具开发AI应用。

2 TensorFlow.js

TensorFlow通过Python暴露出各种深度学习编程接口,给算法工程师带来了无限的可能,迄今为止已经诞生出不计其数的AI应用。但Python始终让TensorFlow最多只能以后台程序的身份离线运行在用户看不到的地方,这些后台程序最终将得到一套AI模型,这些AI模型最终通过iOS或者Android的本地应用程序部署在用户的设备中。

TensorFlow.js的出现改变了这个局面,Google使用WebGLWebAssemblyJavaScript进行了扩展,开发出了能部署在浏览器里的深度学习库,这使得在浏览器中直接运行深度学习应用成为了可能。

3 用于目标识别的预训练模型

在我们的例子中,我使用TensorFlow.js的预训练模型,这些预训练模型是TensorFlow.js团队预先包装好的深度学习模型,使用这些模型我们可以跳过训练的步骤,毕竟,训练模型并不是我们这次讨论的重点。

我们使用COCO-SSD模型,这个模型基于COCO数据集被训练,是一个专用于目标检测的深度学习模型,可以在一张图片内识别并标注出多个物体。因为是预先训练好的模型,所以我们只能用它标注训练时指定的物体类别,超出这些类别范围的物体将会被跳过。如果你希望在这个例子中检测到更多的物体,就需要用更多被标定好的包含新物体的数据去训练这个模型。

4 实现

我们基于上一篇文章的仓库作为起始代码进行改造。

首先,我们需要引入tensorflow.jscoco-ssd模型文件

<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js" type="text/javascript"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/coco-ssd"></script>

犹豫初始化tensorflow.js以及加载模型需要不少时间(在我的电脑上花费了5秒左右),因此,我们需要让canvas的绘制在模型加载完成之后。因此,我们可以在原来的imageLoadCompleted函数里加入我们的tensorflow.js代码。

// 预备好存储模型的变量,可以在canvas绘制的时候使用
var model = undefined;

function imageLoadCompleted() {
  imagesLoaded += 1;
  // 这里只有一张图需要加载,因此已加载图片数量大于等于1即可
  if(imagesLoaded>=1) {
    if (navigator.mediaDevices === undefined) {
      navigator.mediaDevices = {};
    }
    var constraints = { audio: true, video: { width: 1280, height: 720 } };
    navigator.mediaDevices.getUserMedia(constraints)
    .then(function(mediaStream) {
      // 在获取到mediaStream后再加载tensorflow.js
      // 因为需要通过补货视频画面来使用模型
      cocoSsd.load().then(function (loadedModel) {
        model = loadedModel;
        var video = document.querySelector('video');
        video.srcObject = mediaStream;
        video.onloadedmetadata = function(e) {
          video.addEventListener('loadeddata', function() {
            updateVideoToCanvas();
          });
          video.play();
        };
      });
    })
    .catch(function(err) { console.log(err.name + ": " + err.message); });
    }
}

捕获到视频后,我们就可以从每一帧中使用模型进行目标检测,此时我们可以修改updateVideoToCanvas来实现视频直播画面中的物体标注功能。

function updateVideoToCanvas() {
  context.clearRect(0, 0, mycanvas.width, mycanvas.height);
  context.drawImage(video, 0, 0, 800, 450);
  context.globalAlpha = 1;
  context.drawImage(imgTitle, 0, 0);

  // 这里开始是本次新增的代码
  model.detect(video).then(function (predictions) {
    for (let n = 0; n < predictions.length; n++) {
      context.fillStyle = 'rgba(225,225,225,0.5)';
      context.fillRect(predictions[n].bbox[0],predictions[n].bbox[1],predictions[n].bbox[3],predictions[n].bbox[2]);
      
      let innerText = predictions[n].class  + ' - 拥有 ' 
      + Math.round(parseFloat(predictions[n].score) * 100) 
      + '% 的置信度.';
      context.font = "20px Verdana";
      context.fillStyle = 'rgba(225,225,225,0.5)';
      context.fillText(innerText, predictions[n].bbox[0]-10, predictions[n].bbox[1]-10);
    }
    
  });

  requestAnimationFrame(updateVideoToCanvas);
}

这个过程其实就是模型推断,深度学习中,模型推断通常就是接受一个输入,然后通过已经训练好的模型,给出对应的结果,最典型的一个例子就是给一张图片,然后告诉用户这个图片里面的物体是猫还是狗。

通常,预训练好的模型,都会提供类似上面的detect函数的方法,在我们的例子中,这个方法会返回一个predictions数组,这就是保存了我们模型推断结果的数组。注意,这里不是一个结果,而是一系列的结果,因为我们视频的每一帧很可能被侦测到多个目标物体。一次调用detect我们可以获得当前帧的数据,这就是为什么我们需要把这个方法写在每一帧的绘制函数里。获得了predictions以后,我们即可遍历这个数组,下面简单的解释一下这个数组里的数据类型。

class : 模型推断出这个物体所属的类别,最常见的是人,会被标注为person,其他常见的还有laptopclock之类的,通常都是生活中常见的物品
score : 置信度,因为我们的深度学习模型几乎永远不可能达到100%的正确率,因此这个数字代表着我们推断出的结果又多少可行性
bbox : 目标物体所在的位置,是我们常见的(left, top, width, height)的四元数组形式,借此我们才得以方便的绘制出包含了目标物体的矩形,并给予适当的标注。

至此,运行我们新的代码,就可以在直播的时候实时的引入目标检测效果,是不是很酷?藉由这篇文章的方法,我们可以实现更多基于WebRTC的人工智能结合视频直播的应用,只要弄清楚了其中的核心原理,借助声网的API,整个流程其实并不复杂。实现了这样的功能,我们就可以最大限度的扩展我们的视频直播产品,例如,我们可以对直播间的物体做敏感物体检测,如果出现了违规现象,平台方可以第一时间得知,从而大大减少了传统通过人工来监控视频的工作量。再往后考虑,我们可以结合现在流行的GAN神经网络对直播画面进行AI滤镜化,抖音中的很多效果都可以通过这个思路来实现,只要展开想象力,就可以完成出色的应用。