问题
在机器学习中,设计合理、有效的目标函数是为人所津津乐道的技能(本小学生尚无此功力)。倘若设计出来的目标函数连自己都不会求解(优化)那就很尴尬了。像我这种小学生瞎鼓捣出来的目标函数,想知道究竟能不能work,不求解一下写成代码在数据集上跑一跑又如何知晓?
纵观求解方法,有贪心的,动态规划的,蒙特卡洛的,期望最大的,梯度的等等(小学生无责任乱分)。前途无限的深度学习其求解方法全都依赖于梯度,在可以预计的将来极有可能成为大一统的求解方法。因此,如何求解损失函数(这里的目标函数可以称作损失函数)的梯度成了小学生心中最关键的一环。本小节将把叙述重心放在微分部分,也一并给出python编码。
如何求解梯度建议先看看cs231n lecture 4 的ppt或者vidio(cs231n的其他学习资料在这里)。
一些表达上的技巧可以帮助我们后面更好地计算梯度.
梯度推导过程中的trick
逻辑判断
把一些分段函数化为便于求导的乘积形式,多少有点用吧。
二值选择
min(x2,ey)=1{x2<ey}x2+1{x2≥ey}ey多值选择
Y7=∑i1{i=7}Yi分组讨论:
∂(2s1−3s2+es3)∂si={2,i=1−3,i=2es3,i=3=1{i=1}2−1{i=2}3+1{i=3}es3逻辑运算
与运算
1{x>5∧x<10}=1{x>5}⋅1{x<10}或运算
1{x>5∨x<10}=1{x>5}+1{x<10}
打分向量与打分矩阵
对于分类问题,损失函数的输出是实数,输入是打分向量,该向量的分量为样本属于各类别的打分;而batch的平均损失的输入则是打分矩阵S(scores),如下图所示,这个batch中只有3个样本。
实际中的例子
Multiclass SVM loss
L=K∑k≠ymax(0,sk−sy+1)其中s为某样本属于各类别的打分,若是k分类,s为长度为K的数组,为了表达它是一个行向量,以下标记为sT, sk为该样本属于第j类的打分,y为真实类别,详情请参见该ppt。
求∂L/∂sj
把L中的二值选择用逻辑判断代替
L=K∑m≠ymax(0,sm−sy+1)=K∑m≠y1{sm−sy+1≥0}(sj−sy+1)=K∑m≠yqm,y(sm−sy+1)
其中qm,y=1{sm−sy+1≥0}, 则
∂L∂sj=∂∂sj∑m≠yqm,y(sm−sy+1)=∑m≠yqm,y∂(sm−sy+1)∂sj=∑m≠yqm,y(1{j≠y∧m=j}−1{j=y})=1{j≠y}∑m≠yqm,y1{m=j}−1{j=y}∑m≠yqm,y=1{j≠y}qj,y−1{j=y}∑m≠yqm,y
注意: 上面逆用了多值选择的情况
∑mqm,y1{m=j}=qj,y因为有j≠y的保证,为∑加上m≠y的限制不会有任何影响
在矩阵分析中,学习过,实数函数对向量或矩阵求导就是实数函数对向量或矩阵中的每个分量求偏导,若打分向量s是一个1×K的行向量,则单样本损失对打分向量的梯度有:
∂L∂s=(∂L∂s1,∂L∂s2,...,∂L∂sK)
求平均损失对打分矩阵的梯度
我们所谓的损失指的都是期望损失,也就是平均损失。若第i条样本的损失为L(i),则N条样本的平均损失为
ˉL=1NN∑i=1L(i)
注意前面那张图片,打分矩阵各行为各样本的打分向量,因此打分矩阵为 S=(s(1)s(2)⋮s(N))=(s(1)1⋯s(1)K⋮⋱⋮s(N)1⋯s(N)K)
因此平均损失对打分矩阵S的梯度为 dˉLdS=1NN∑i=1dL(i)dS=1NN∑i=1(dL(i)ds(1)⋮dL(i)ds(i)⋮dL(i)ds(N))=1NN∑i=1(0⋮dL(i)ds(i)⋮0)=1N(dL(1)ds(1)⋮dL(N)ds(N))=1N(∂L(1)∂s(1)1⋯∂L(1)∂s(1)K⋮⋱⋮∂L(N)∂s(N)1⋯∂L(N)∂s(N)K):=dS
定义了一个矩阵dS,其中的元素为
(dS)ij=1N∂L(i)∂s(i)j=1N(1{j≠y(i)}qj,y(i)−1{j=y(i)}K∑m≠y(i)qm,y(i))
编码实现
|
|
Softmax loss
L=−log(esy∑mesm)除了loss外,其他设定与Multiclass SVM loss一样,求解梯度一样需要以下四个步骤:
- 求∂L/∂sj
- 求平均损失对打分矩阵的梯度
- 编码实现
因为只有loss不同,其他都是一样的,因此只有求∂L/∂sj和编码实现不同。
求∂L/∂sj(方法1)
L=−log(esy∑mesm)=log∑mesm−sy其中 pj=esj∑mesm
求∂L/∂sj(方法2)
之所以要再多写个方法二,是为了增加对这种逻辑求导运算的熟练度。
L=−log(esy∑mesm)=−log(py)
则
∂L∂sj=−1py⋅∂py∂sj
其中
∂py∂sj=∂∂sj(esy∑mesm)=1{j=y}esy(∑mesm)−esyesj(∑mesm)2−1{j≠y}esyesj(∑mesm)2
将其带入∂L/∂sj得,
∂L∂sj=−∑mesmesy(1{j=y}esy(∑mesm)−esyesj(∑mesm)2−1{j≠y}esyesj(∑mesm)2)=−(1{j=y}(∑mesm)−esj∑mesm−1{j≠y}esj∑mesm)=1{j≠y}esj∑mesm−1{j=y}(∑mesm)−esj∑mesm=1{j≠y}pj−1{j=y}(1−pj)=1{j≠y}pj−1{j=y}+1{j=y}pj=pj−1{j=y}
注意: 上式最后一步用了一个很简单的或运算
1{j≠y}pj+1{j=y}pj=1{j≠y∨j=y}pj=pj
编码实现
|
|
注意: 避免数值问题
esj−smax∑mesm−smax=esj/esmax∑m(esm/esmax)=esj/esmax(∑mesm)/esmax=esj∑mesm
这么做的好处是避免指数运算出现特别大的数而产生溢出
总结
通过上面两个例子,求损失对打分的梯度,关键在于对打分下标的考虑,将选择取值和分组讨论问题转化为逻辑运算,可以简化求导过程,最终简化编码实现。