稳定的数值计算

本文将介绍在实现softmax和交叉熵损失计算时需要注意的一些数学问题,并且回顾相关函数的定义与性质;代码来源于极市平台 - 编写高效的Pytorch代码技巧

部分技巧

计算机中的浮点数运算通常存在两个问题:精度损失和算术溢出,本质上都是浮点计算单元的位数有限导致的。在训练模型过程中令人讨厌的问题之一就是梯度/loss出现Nan或者Inf,当算法本身在数学上没有错误时,就要考虑是不是出现了不稳定的数值计算,尤其警惕那些特别小或者特别大的数值;下面记录两个避免数值问题的技巧:

softmax

不稳定的写法如下,\(e^x\)很容易发生上溢:

1
2
3
def unstable_softmax(logits):
exp = torch.exp(logits)
return exp / torch.sum(exp)

可以将分子分母同时除以一个常数,相当于对输入数据减去一个常量,经验上,这个常量取输入的最大值即可:

1
2
3
def softmax(logits):
exp = torch.exp(logits - torch.max(logits))
return exp / torch.sum(exp)

cross entropy

计算交叉熵损失通常在对预测得分完成softmax之后,但这并不意味着它的实现与softmax解耦了,事实上,下面的改进正是利用了这种耦合性

不稳定的写法如下,\(\log x\)很容易发生下溢:

1
2
3
def unstable_softmax_cross_entropy(labels, logits):
logits = torch.log(softmax(logits))
return -torch.sum(labels * logits)

可以这样重写: \[ \log p_i = \log(\frac{e^{x_i}}{\sum_je^{x_j}})=\log(\frac{e^{x_i-m}}{\sum_je^{x_j-m}})=(x_i-m) - \text{LSE}(\mathbf{x}-\mathbf{1}m) \] 其中\(m=\max_jx_j\)\(\text{LSE}\)是LogSoftMax;

1
2
3
4
def softmax_cross_entropy(labels, logits, dim=-1):
scaled_logits = logits - torch.max(logits)
normalized_logits = scaled_logits - torch.logsumexp(scaled_logits, dim)
return -torch.sum(labels * normalized_logits)

补充的数学知识

softmax

通常使用这个函数将输出转化为一组在\((0,1)\)范围内的,和为1的输出,有着概率意义;

softmax是这样的函数:\(\sigma:\mathbb{R}^K\rightarrow (0,1)^K\),其中\(K>1\) \[ \sigma(\mathbf{z})_i=\frac{e^{z_i}}{\sum_{j=1}^Ke^{z_j}} \] 除了以\(e\)为底,任何一个\(b>0\)可以作为底,如果\(0<b<1\),则原来更小的值将获得更大的输出;如果\(b>1\),则原来更大的值将获得更小的输出;令\(b=e^\beta\),其中\(\beta\)是实数,则还有: \[ \sigma(\mathbf{z})_i=\frac{e^{\beta z_i}}{\sum_{j=1}^Ke^{\beta z_j}} \] softmax是平滑版的argmax,当\(\beta\rightarrow \infty\)时,softmax收敛到argmax;一般地,平滑即连续可导,这样的性质在神经网络中非常重要;如果\(z_1,\dots,z_n\)两两不相等,有: \[ \arg\max(z_1,\dots,z_n)=(y_1,\dots,y_n)=(0,\dots,0,1,0,\dots,0) \] 其中\(\mathbf{z}\)的一个微小的改变都可能让argmax的输出发生跳变;

argmax不可导,但是max(x1,x2)在x1!=x2时是可导的,后文也会介绍max的平滑近似——LSE

类似的,当\(\beta\rightarrow -\infty\)时,argmax收敛到argmin;

cross entropy的梯度推导

softmax+交叉熵耦合在一起的另一个原因是,其导数的形式非常简洁:

假设真值为第\(i\)个类别,\(\mathbf{q}\)是真实分布,\(\mathbf{p}=\text{softmax}(\mathbf{s})\)是估计分布,则损失函数\(L=-\sum_jq_j\log p_j=-\log p_i\),其中\(p_i = \frac{\exp(s_i)}{\sum_j\exp(s_j)}\);要求出\(\frac{\partial L}{\partial \mathbf{s}}\),其中\(\mathbf{s}=(1,\dots,j,\dots,d)\),分为两种情况:

下面用\(\sum\)代替\(s\),注意\(\frac{\partial L}{\partial p_i}=-\frac{1}{p_i}=-\frac{\sum{}}{\exp(s_i)}\)

  1. \(j = i\),此时\(p_i\)的分子也是\(s_j\)的函数: \[ \frac{\partial {p_i}}{\partial{s_j}}=\frac{\exp(s_i)\sum{}-\exp(2s_i)}{\sum{}^2} \]\[ \begin{aligned} \frac{\partial L}{\partial s_j} &= -\frac{\sum{}}{\exp(s_i)}\frac{\exp(s_i)\sum{}-\exp(2s_i)}{\sum{}^2}\\ &=p_i-1\\ &=p_j-q_j \end{aligned} \]

  2. \(j\ne i\),此时\(p_i\)的分子与\(s_j\)无关: \[ \frac{\partial {p_i}}{\partial{s_j}}=\frac{-\exp(s_i+s_j)}{\sum{}^2} \]\[ \begin{aligned} \frac{\partial L}{\partial s_j} &= -\frac{\sum{}}{\exp(s_i)}\frac{-\exp(s_i+s_j)}{\sum{}^2}\\ &=p_j\\ &=p_j-q_j \end{aligned} \]

综上所述,\(\frac{\partial L}{\partial \mathbf{s}}=\mathbf{p}-\mathbf{q}\)

LogSoftExp

函数的形式为: \[ \text{LSE}(\mathbf{x})=\log(\sum_je^{x_j}) \] 该函数是max函数的平滑版本,是凸函数,假设\(\mathbf{x}\in\mathbb{R}^n\)\(m=\max_j x_j\),则 \[ e^m\le \sum_j^ne^{x_j}\le ne^m \] 两边同时取对数,则有: \[ m\le \text{LSE}(\mathbf{x}) \le \log(n) + m \] 在使用计算机计算LSE时同样也有数值问题,而解决方法同样是,对于等式右边,减去再加上\(m = \log e^m\)