前言
搞定了LSTM理论之后,按照理解搭建了一个简易的模型,但是在一切都看起来没什么问题的情况下,报错了。
报错不太寻常,因为并不是跑一轮就报错,而是正常跑几轮才会报错。
用异常捕获机制强制停止出现异常的轮次,打印发现模型输出包含nan
的张量。
漫长的debug就这样开始了。
神经网络传播过程中nan的处理
如果神经网络设计有缺陷,确实可能出现传播过程中出现nan的情况,而且这种情况在网络上非常常见。
比如:PyTorch训练过程中出现NaN的排查笔记 - 知乎
我觉得就排查的非常好,非常有条理。
神经网络传播过程中可能出现nan的地方很多,这里直接借用这位作者整理的:
- 学习率过大
学习率是控制模型参数更新的重要超参数。如果学习率设置得过大,模型参数更新的幅度可能会过大,从而导致损失值发散。这种情况下,测试损失的值可能会变为NaN。可以尝试减小学习率,以确保模型稳定地收敛。
- 模型设计问题
测试损失变为NaN还可能是模型设计存在问题。在一些情况下,模型的架构可能导致数值不稳定,从而出现NaN值。这可能是由于某些操作(例如relu、softmax等)在特定情况下产生了数值溢出或欠溢出。可以尝试改变模型的架构,例如使用不同的激活函数或正则化操作,来处理这个问题。
- 尺度不平衡的初始化
“尺度不平衡的初始化”是指权重初始化得过大或过小,造成了梯度更新时的不稳定性。使用适合你使用的激活函数的初始化方法(如He或Xavier初始化)可以有效地解决这一问题。
这位作者的排查方式比较专业,但是对于我来说实在是太难了,因为我并不十分清楚lstm的构造,说白了把lstm当个黑盒在使用,有没有更方便的方法?
有的,总算给我找到了:训练过程中出现nan(not a number)的原因及解决方案 - 知乎
可以在 python 文件头部使用如下函数打开 nan 检查:
torch``**.**``autograd``**.**``set_detect_anomaly(True)
加了这行,一旦传播过程中出现了nan,会直接在该处报错停止代码。由于是直接给出了问题发生的位置,这让debug变得非常容易!
另外,如果是反向传播过程中要使用这种检查,需要这样
loss = model(X)
with torch.autograd.detect_anomaly():
loss.backward()
这样就能定位因为梯度爆炸之类的原因产生的nan了。
开启 nan 检查后,直接定位在了数据输入模型这步上,也就是说数据集有空值!
数据集有nan值的处理
一般来说,用以下语句定位数据集确定的有nan的位置
assert not torch.any(torch.isnan(T))
然后,要么变换,要么删掉就好了!
其他可能出现nan的情况
调研的时候看了不少文章,这里汇总一下
1、pytorch混合精度训练,使用半精度等能提升batch,但是有出现nan的风险
2、升级你的numpy…
pytorch中第一轮训练loss就是nan是为什么啊? - 知乎
反正看到过很多奇怪的情况,结果都是升级numpy解决的…
[文章导入自 http://qzq-go.notion.site/12949a7b4e7580fc88fbe94d75bcf251 访问原文获取高清图片]