2018 AI 开发者会议是一场由中俄人工智能技术前辈携手构建的 AI 技术与产业的年度盛事!这里有 15+ 硅谷实力讲联队、80+AI 领军企业技术核心人物、100+ 技术&大众实力媒体、1500+AI 专业开发者——我们只讲技术,拒绝空谈!
参加 2018 AI 开发者会议,请点击 ↑↑↑
作者 | Vincent Mühler
译者 | 刘旭坤
整理 | Jane
出品 | AI科技大本营
【导读】TensorFlow.js 的发布可以说是 JS 社区开发者的福音!但是在浏览器中训练一些模型还是会存在一些问题与不同,如何可以让训练疗效更好?本文的作者,是一位后端工程师,经过自己不断的经验积累,为你们总结了 18 个 Tips,希望可以帮助你们训练出更好的模型。
TensorFlow.js 发布以后我就把之前训练的目标/人脸检查和人脸辨识的模型往 TensorFlow.js 里导,我发觉有些模型在浏览器里运行的疗效还相当不错。感觉 TensorFlow.js 让我们搞后端的也潮了一把。
虽说浏览器也能跑深度学习模型了,这些模型终归不是为在浏览器里运行设计的,所以好多限制和挑战也就随之而来了。就拿目标测量来说,不说实时监测,就是维持一定的分辨率似乎都很困难。更别提动辄上百兆的模型给用户浏览器和带宽(手机端的话)带来的压力了。
不过只要我们遵守一定的原则,用卷积神经网络 CNN 和 TensorFlow.js 在浏览器里训练个像样的深度学习模型并非痴人说梦。从下边图里可以听到,我训练的这几个模型大小都控制在了 2 MB 以下,最小的才 3 KB。
大家可能心里会有个疑惑:你弱智吗?要用浏览器训练模型?对,用自己笔记本、服务器、集群或则云来训练深度学习模型肯定是一条正道,但并非人人都有钱用NVIDIA GTX 1080 Ti 或者Titan X(尤其是主板集体大降价以后)。这时,在浏览器中训练深度学习模型的优势就彰显下来了,有了 WebGL 和 TensorFLow.js 我用笔记本上的 AMD GPU 也能很方便地训练深度学习模型。
对目标辨识问题,为了稳当起见一般还会建议你们用一些现成的构架例如YOLO、SSD、残差网路 ResNet 或 MobileNet ,但我个人觉得假如完全仿效的话,在浏览器上训练疗效肯定是不好的。在浏览器上训练就要求模型要小、要快、要越容易训练越好。下面我们就从模型构架、训练和调试等几个方面来瞧瞧怎么能够做到这三点。
模型构架
▌1. 控制模型大小
控制模型的规模很重要。如果模型构架太大太复杂,训练和运行的速率就会减少,从浏览器载入模型度速率也会变慢。控制模型的规模说起来简单,难的是取得准确率和模型规模之间的平衡。如果准确率达不到要求,模型再小也是废物。
▌2. 使用深度可分离频域操作
与标准频域操作不同,深度可分离频域先对每位通道进行频域操作,之后再进行1X1跨通道频域。这样做的用处是可以大大减少参数个数,所以模型运行速率会有很大提高,资源的消耗和训练速率也会有所提高。深度可分离频域操作的过程如下图所示:
MobileNet 和 Xception 都使用了深度可分离频域,TensorFlow.js 版本的 MobileNet 和 PoseNet 中你也能看到深度可分离频域的身影。虽然深度可分离频域对模型准确率的影响还有争议,但从我个人的经验来看在浏览器里训练模型用它肯定没错。
第一层我推荐用标准的 conv2d 操作来保持提取完特点的通道之间的关系。因为第一层通常参数不多,所以对性能的影响不大。
其他频域层就可以都用深度可分离频域了。比如这儿我们就使用了两个过滤器。
这里 tf.separableConv2d 使用的频域核结构分别是[3,3,32,1]和[1,1,32,64]。
▌3.运用跳跃联接和密集块
随着网路层数的降低,梯度消失问题出现的可能性也会减小。梯度消失会导致损失函数增长太慢训练时间超长或则干脆失败。ResNet 和 DenseNet 中采用的跳跃联接则能防止这一问题。简单说来跳跃联接就是把个别层的输出跳过激活函数直接传给网路深处的隐藏层作为输入,如下图所示:
这样就防止了由于激活函数和链式导数引起的梯度消失问题,我们也能依照需求降低网路的层数了。
显然跳跃联接蕴涵的一个要求就是联接的两层输出和输入的格式必须能对应得上。我们要用方差网路的话,那最好保证两层的过滤器数量和填充都一致但是弯度为1(不过肯定有其它做法来保证格式对应)。
一开始我模仿方差网路的思路隔一层加一个跳跃联接(如下图)。不过我发觉密集块疗效更好,模型收敛的速率比加跳跃联接快得多。
下面我们就来瞧瞧具体的代码,这里的密集块有四个深度可分离频域层,其中第一层我把弯度设为 2 来改变输入的大小。
▌4.激活函数选ReLU
在浏览器里训练深度网路的话激活函数不用看直接选 ReLU 就行了,主要缘由还是梯度消失。不过你们可以试试 ReLU 的不同变种,比如
和 MobileNet 用的 ReLU-6 (y = min(max(x, 0), 6)):
训练过程
▌5.优化器选Adam
这也是我个人的经验只谈。之前用 SGD 经常会卡在局部极小值或则出现梯度爆燃。我推荐你们一开始把学习速度设为 0.001 然后其他参数都用默认:
▌6.动态调整学习速度
一般来说当损失函数不再增长的时侯我们就该停止训练了,因为再训练就过拟合了。不过若果我们发觉损失函数出现上下回落的情况,则可能通过降低学习速度让损失函数显得更小。
下面这个事例中我们可以看见学习速度一开始设的是 0.01,然后从 32 期开始出现回落(黄线)。这里通过将学习速度改为 0.001(蓝线)使损失函数又减少了大约 0.3。
▌7.权重初始化原则
我个人喜欢把偏置量设为 0,权重则用传统的正态分布。我通常用的是 Glorot 正态分布初始化法:
▌8.把数据集次序搅乱
老生常谈了。TensorFlow.js 中我们可以用 tf.utils.shuffle 来实现。
▌9. 保存模型
js 可以通过 FileSaver.js 来实现模型的储存(或者叫下载)。比如下边的代码就可以把模型所有的权重保存上去:
保存成哪些格式是自己定的,但 FileSaver.js 只管存,所以这儿要用JSON.strinfify 把 Blob 转成字符串:
调试
▌10.保证预处理和后处理的正确性
虽然是句屁话但“垃圾数据垃圾结果”实在是至理名言。标记要标对,每层的输入输出也要前后一致。尤其是对图片做过一些预处理和后处理的话更要仔细,有时候那些小问题还比较难发觉。所以其实费些工夫但磨刀不误砍柴工。
▌11.自定义损失函数
TensorFlow.js 提供了好多现成的损失函数给大家用,而且一般说来也够用了,所以我不太建议你们自己写。如果实在要自己写的话,请一定注意先测试测试。
▌12.在数据子集试试过拟合
我建议你们模型定义好以后先挑个十几二十张图试试看损失函数有没有收敛。最好能把结果可视化一下,这样才能很明显地看出这个模型有没有成功的潜质。
这样做我们也能早早地发觉模型和预处理时的一些低级错误。这似乎也就是 11 条里说的测试测试损失函数。
性能
▌13.内存泄漏
不知道你们知不知道 TensorFlow.js 不会手动帮你进行垃圾回收。张量所占的显存必须自己自动调用 tensor.dispose() 来释放。如果忘掉回收的话内存泄漏是早晚的事。
判断有没有内存泄漏很容易。大家把 tf.memory() 每次迭代都输下来瞧瞧张量的个数。如果没有仍然降低那说明没泄露。
▌14.调整画布大小,而不是张量大小
在调用 TF . from pixels 之前,要将画布转换成张量,请调整画布的大小,否则你会很快用尽 GPU 内存。
如果你的训练图象大小都一样,这将不会是一个问题,但是假如你必须明确地调整它们的大小,你可以参考下边的代码。(注意,以下句子仅在 tfjs - core 的当前状态下有效,我当前正在使用 tfjs - core 版本 0.12.14)
▌15.慎选批大小
每一批的样本数选多少,也就是批大小似乎取决于我们用的哪些 GPU 和网路结构,所以你们最好试试不同的批大小瞧瞧如何最快。我通常从 1 开始试,而且有时候我发觉降低批大小对训练的效率也没啥帮助。
▌16.善用IndexedDB
我们训练的数据集由于都是图片所以有时候还是挺大的。如果每次都下载的话肯定效率低,最好是用 IndexedDB 来储存。IndexedDB 其实就是浏览器里嵌入的一个本地数据库,任何数据都能以通配符对的方式进行储存。读取和保存数据也只要几行代码能够搞定。
▌17.异步返回损失函数值
要实时检测损失函数值的话可以用下边的代码这来自己算之后异步返回:
需要注意的是假如每期训练完要把损失函数值存到文件里的话这样的代码就有点问题了。因为现今损失函数的值是异步返回了所以我们得等最后一个 promise 返回能够存。不过我通常都暴力地在一期结束以后直接等个 10 秒再存:
▌18.权重的量化
为了实现又小又快的目标,在模型训练完成以后我们应当对权重进行量化来压缩模型。权重量化不光能减少模型的容积,对提升模型的速率也很有帮助,而且几乎全是用处没益处。这一步就让模型又能小又能快,非常适宜我们在浏览器里训练深度学习模型。
在浏览器里训练深度学习模型的十八招(实际十七招)就总结到这儿,希望你们读了这篇文章才能有所收获。
如果有问题也欢迎在后台给我们留言,大家一起讨论!
原文链接:
【完】
微信改版了,
想快速见到CSDN的热乎文章,
赶快把CSDN公众号设为星标吧,
打开公众号,点击“设为星标”就可以啦!
“
征稿啦”
CSDN公众号奉行着「与千万技术人共成长」理念,不仅以「极客头条」、「畅言」栏目在第一时间以技术人的奇特视角描述技术人关心的行业焦点风波,更有「技术头条」专栏,深度剖析行业内的热门技术与场景应用,让所有的开发者紧随技术时尚,保持警醒的技术味觉,对行业趋势、技术有更为全面的认知。
如果你有优质的文章,或是行业热点风波、技术趋势的真知灼见,或是深度的应用实践、场景方案等的新看法,欢迎联系CSDN投稿,联系方法:微信(guorui_1118,请备注投稿+姓名+公司职位),邮箱(guorui@csdn.net)。