Skip to content

这是一个非常关键的误解!你正好卡在了整个推导中最容易把人绕晕的 “下标” 上。

你之所以觉得这里应该是 0,是因为你把 yjyt 搞混了。让我们放慢动作,把这层窗户纸捅破。

1. 谁是 0,谁不是 0?

回想一下我们的交叉熵损失函数:=log(yt)。(t 是正确类别的索引)。

如果网络预测了 3 个类别的概率:y0,y1,y2,并且真实的正确类别是 t=1。 那么损失仅仅是:=log(y1)

现在,我们拿这个 去对所有的 y 求偏导:

  • y0 求导:公式里没有 y0,所以 y0=0
  • y1(也就是 yt)求导:根据对数求导法则,y1=1y1 (注意!这个绝对不是 0)
  • y2 求导:公式里没有 y2,所以 y2=0

结论: 只有当对错误类别的概率求导时,结果才是 0。对正确类别 t 的概率求导时,结果是 1yt

2. 求和公式是如何坍缩的?

根据链式法则,我们要求 对某个错误类别分数 as 的导数:

as=jyjyjas

展开来看(假设有 3 个类别,正确是 t=1,我们正在对错误的 a0 也就是 s=0 求导):

=(y0y0a0)+(y1y1a0)+(y2y2a0)

现在,把第一步求出的“0”和“非0”代入进去:

=(0)+(1y1y1a0)+(0)

看到了吗?正是因为那些 0 的存在,把周围没用的项全部“杀”掉了! 整个长长的求和公式,最后只活下来了中间那一项(也就是 j=t 的那一项)。

3. 活下来的那一项是什么?

活下来的这一项是:

ytytas
  • 左半边 yt 它是 1yt不是 0
  • 右半边 ytas 这是在问:“错误类别的打分 as 变化时,正确类别的概率 yt 会怎么变?”因为 ts(一个是正确类别,一个是错误类别,下标不同),所以应用我们在上一节讲的“非对角线(挤压)”法则,结果是 ytys

把它们乘在一起:

(1yt)[ytys]

负负得正,分母的 yt 和分子的 yt 约掉,最后就得出了完美的结果:ys


总结一下你的困惑点: 0 乘以任何数确实是 0,但在链式法则的求和公式中,乘以 0 的那些项(jt)已经被彻底消灭丢弃了。我们公式里剩下的 (1yt),是茫茫 0 海中唯一一个不是 0 的幸存者

基于 Kaggle MNIST 数据集,使用纯 numpy 从零实现