稳定的数值计算
本文将介绍在实现softmax和交叉熵损失计算时需要注意的一些数学问题,并且回顾相关函数的定义与性质;代码来源于极市平台 - 编写高效的Pytorch代码技巧;
部分技巧
计算机中的浮点数运算通常存在两个问题:精度损失和算术溢出,本质上都是浮点计算单元的位数有限导致的。在训练模型过程中令人讨厌的问题之一就是梯度/loss出现Nan
或者Inf
,当算法本身在数学上没有错误时,就要考虑是不是出现了不稳定的数值计算,尤其警惕那些特别小或者特别大的数值;下面记录两个避免数值问题的技巧:
softmax
不稳定的写法如下,\(e^x\)很容易发生上溢:
1 | def unstable_softmax(logits): |
可以将分子分母同时除以一个常数,相当于对输入数据减去一个常量,经验上,这个常量取输入的最大值即可:
1 | def softmax(logits): |
cross entropy
计算交叉熵损失通常在对预测得分完成softmax之后,但这并不意味着它的实现与softmax解耦了,事实上,下面的改进正是利用了这种耦合性
不稳定的写法如下,\(\log x\)很容易发生下溢:
1 | def unstable_softmax_cross_entropy(labels, logits): |
可以这样重写: \[ \log p_i = \log(\frac{e^{x_i}}{\sum_je^{x_j}})=\log(\frac{e^{x_i-m}}{\sum_je^{x_j-m}})=(x_i-m) - \text{LSE}(\mathbf{x}-\mathbf{1}m) \] 其中\(m=\max_jx_j\);\(\text{LSE}\)是LogSoftMax;
1 | def softmax_cross_entropy(labels, logits, dim=-1): |
补充的数学知识
softmax
通常使用这个函数将输出转化为一组在\((0,1)\)范围内的,和为1的输出,有着概率意义;
softmax是这样的函数:\(\sigma:\mathbb{R}^K\rightarrow (0,1)^K\),其中\(K>1\) \[ \sigma(\mathbf{z})_i=\frac{e^{z_i}}{\sum_{j=1}^Ke^{z_j}} \] 除了以\(e\)为底,任何一个\(b>0\)可以作为底,如果\(0<b<1\),则原来更小的值将获得更大的输出;如果\(b>1\),则原来更大的值将获得更小的输出;令\(b=e^\beta\),其中\(\beta\)是实数,则还有: \[ \sigma(\mathbf{z})_i=\frac{e^{\beta z_i}}{\sum_{j=1}^Ke^{\beta z_j}} \] softmax是平滑版的argmax,当\(\beta\rightarrow \infty\)时,softmax收敛到argmax;一般地,平滑即连续可导,这样的性质在神经网络中非常重要;如果\(z_1,\dots,z_n\)两两不相等,有: \[ \arg\max(z_1,\dots,z_n)=(y_1,\dots,y_n)=(0,\dots,0,1,0,\dots,0) \] 其中\(\mathbf{z}\)的一个微小的改变都可能让argmax的输出发生跳变;
argmax不可导,但是max(x1,x2)在x1!=x2时是可导的,后文也会介绍max的平滑近似——LSE
类似的,当\(\beta\rightarrow -\infty\)时,argmax收敛到argmin;
cross entropy的梯度推导
softmax+交叉熵耦合在一起的另一个原因是,其导数的形式非常简洁:
假设真值为第\(i\)个类别,\(\mathbf{q}\)是真实分布,\(\mathbf{p}=\text{softmax}(\mathbf{s})\)是估计分布,则损失函数\(L=-\sum_jq_j\log p_j=-\log p_i\),其中\(p_i = \frac{\exp(s_i)}{\sum_j\exp(s_j)}\);要求出\(\frac{\partial L}{\partial \mathbf{s}}\),其中\(\mathbf{s}=(1,\dots,j,\dots,d)\),分为两种情况:
下面用\(\sum\)代替\(s\),注意\(\frac{\partial L}{\partial p_i}=-\frac{1}{p_i}=-\frac{\sum{}}{\exp(s_i)}\)
\(j = i\),此时\(p_i\)的分子也是\(s_j\)的函数: \[ \frac{\partial {p_i}}{\partial{s_j}}=\frac{\exp(s_i)\sum{}-\exp(2s_i)}{\sum{}^2} \] 则 \[ \begin{aligned} \frac{\partial L}{\partial s_j} &= -\frac{\sum{}}{\exp(s_i)}\frac{\exp(s_i)\sum{}-\exp(2s_i)}{\sum{}^2}\\ &=p_i-1\\ &=p_j-q_j \end{aligned} \]
\(j\ne i\),此时\(p_i\)的分子与\(s_j\)无关: \[ \frac{\partial {p_i}}{\partial{s_j}}=\frac{-\exp(s_i+s_j)}{\sum{}^2} \] 则 \[ \begin{aligned} \frac{\partial L}{\partial s_j} &= -\frac{\sum{}}{\exp(s_i)}\frac{-\exp(s_i+s_j)}{\sum{}^2}\\ &=p_j\\ &=p_j-q_j \end{aligned} \]
综上所述,\(\frac{\partial L}{\partial \mathbf{s}}=\mathbf{p}-\mathbf{q}\)
LogSoftExp
函数的形式为: \[ \text{LSE}(\mathbf{x})=\log(\sum_je^{x_j}) \] 该函数是max函数的平滑版本,是凸函数,假设\(\mathbf{x}\in\mathbb{R}^n\),\(m=\max_j x_j\),则 \[ e^m\le \sum_j^ne^{x_j}\le ne^m \] 两边同时取对数,则有: \[ m\le \text{LSE}(\mathbf{x}) \le \log(n) + m \] 在使用计算机计算LSE时同样也有数值问题,而解决方法同样是,对于等式右边,减去再加上\(m = \log e^m\)