11 使用计算图求导(Derivatives with a Computation Graph)
🧠 课程核心思想
在深度学习中,反向传播(Backpropagation) 是通过计算图(Computation Graph) 高效计算损失函数对各参数偏导数的核心机制。
本节通过一个具体例子,展示了如何利用链式法则(Chain Rule) 沿着计算图从右向左(backward pass) 逐层计算梯度。
📐 示例函数与计算图结构
考虑如下函数:
$$
J = 3v, \quad \text{其中} \quad v = a + u, \quad u = b \cdot c
$$
给定初始值:
- $a = 5$
- $b = 3$
- $c = 2$
- $u = b \cdot c = 6$
- $v = a + u = 11$
- $J = 3v = 33$
计算图结构(从左到右为前向传播):
b ──┐ |
🔁 反向传播:从右向左计算导数
目标:计算 $J$ 对各个中间变量和输入变量的偏导数:
- $\frac{\partial J}{\partial v}$
- $\frac{\partial J}{\partial a}$
- $\frac{\partial J}{\partial u}$
- $\frac{\partial J}{\partial b}$
- $\frac{\partial J}{\partial c}$
1. 计算 $\frac{\partial J}{\partial v}$
$$
J = 3v \quad \Rightarrow \quad \frac{\partial J}{\partial v} = 3
$$
记作代码中的变量:dv = 3
2. 计算 $\frac{\partial J}{\partial a}$
由于 $v = a + u$,所以:
$$
\frac{\partial v}{\partial a} = 1
$$
应用链式法则:
$$
\frac{\partial J}{\partial a} = \frac{\partial J}{\partial v} \cdot \frac{\partial v}{\partial a} = 3 \cdot 1 = 3
$$
记作:da = 3
3. 计算 $\frac{\partial J}{\partial u}$
同理,$v = a + u \Rightarrow \frac{\partial v}{\partial u} = 1$
$$
\frac{\partial J}{\partial u} = \frac{\partial J}{\partial v} \cdot \frac{\partial v}{\partial u} = 3 \cdot 1 = 3
$$
记作:du = 3
4. 计算 $\frac{\partial J}{\partial b}$
由于 $u = b \cdot c$,所以:
$$
\frac{\partial u}{\partial b} = c = 2
$$
再用链式法则:
$$
\frac{\partial J}{\partial b} = \frac{\partial J}{\partial u} \cdot \frac{\partial u}{\partial b} = 3 \cdot 2 = 6
$$
记作:db = 6
5. 计算 $\frac{\partial J}{\partial c}$
同理:
$$
\frac{\partial u}{\partial c} = b = 3
$$
$$
\frac{\partial J}{\partial c} = \frac{\partial J}{\partial u} \cdot \frac{\partial u}{\partial c} = 3 \cdot 3 = 9
$$
记作:dc = 9
💻 编程中的简化记号
在实际代码(如 Python)中,为了简洁,通常省略对最终输出 $J$ 的显式书写,而直接用变量名表示其对某中间量的偏导:
| 数学符号 | 代码变量名 | 含义 |
|---|---|---|
| $\frac{\partial J}{\partial v}$ | dv |
$J$ 对 $v$ 的偏导 |
| $\frac{\partial J}{\partial a}$ | da |
$J$ 对 $a$ 的偏导 |
| $\frac{\partial J}{\partial u}$ | du |
$J$ 对 $u$ 的偏导 |
| $\frac{\partial J}{\partial b}$ | db |
$J$ 对 $b$ 的偏导 |
| $\frac{\partial J}{\partial c}$ | dc |
$J$ 对 $c$ 的偏导 |
✅ 这种记法假设所有导数都是相对于同一个最终输出 $J$(即损失函数),因此无需重复写
dJ_d...。
🔗 链式法则(Chain Rule)回顾
若变量依赖关系为:
$$
J \leftarrow v \leftarrow u \leftarrow b
$$
则:
$$
\frac{\partial J}{\partial b} = \frac{\partial J}{\partial v} \cdot \frac{\partial v}{\partial u} \cdot \frac{\partial u}{\partial b}
$$
但在本例中,因 $v$ 直接依赖于 $u$,而 $J$ 直接依赖于 $v$,可分步计算:
$$
\frac{\partial J}{\partial b} = \underbrace{\frac{\partial J}{\partial u}}{=3} \cdot \underbrace{\frac{\partial u}{\partial b}}{=2} = 6
$$
🔄 前向 vs 反向传播
| 步骤 | 方向 | 目的 | 计算内容 |
|---|---|---|---|
| 前向传播(Forward Pass) | 左 → 右 | 计算输出 $J$ | $a, b, c \rightarrow u \rightarrow v \rightarrow J$ |
| 反向传播(Backward Pass) | 右 ← 左 | 计算梯度 | $\frac{\partial J}{\partial v} \rightarrow \frac{\partial J}{\partial a}, \frac{\partial J}{\partial u} \rightarrow \frac{\partial J}{\partial b}, \frac{\partial J}{\partial c}$ |
⚡ 效率关键:反向传播避免重复计算,利用已算出的上游梯度(如
dv)高效推导下游梯度(如da,du)。
✅ 关键结论
- 计算图是理解和实现自动微分的基础工具。
- 反向传播的本质是链式法则的系统化应用。
- 编程中采用简写记号(如
dv , da )极大提升代码可读性与简洁性。 - 所有梯度最终服务于优化目标(如最小化损失函数 **$J$** ) 。
🔜 下一步预告
正如视频末尾所提,下一节将把这一机制应用到逻辑回归(Logistic Regression) 中,展示如何在真实模型中计算梯度并更新参数。