04 通过时间的反向传播(Backpropagation Through Time, BPTT)
04 通过时间的反向传播(Backpropagation Through Time, BPTT)
1. 核心概念
前向传播 (Forward Propagation) :
- 方向:从左到右(时间步 $t$ 从 $1$ 增加到 $T_x$)。
- 过程:利用输入序列 $x^{(1)}, x^{(2)}, \dots, x^{(T_x)}$ 和共享参数 $W_{ax}, W_{aa}, b_a$,依次计算每个时间步的激活值 $a^{(t)}$。
- 输出计算:利用激活值 $a^{(t)}$ 和参数 $W_{ya}, b_y$ 计算预测值 $\hat{y}^{(t)}$。
- 特点:所有时间步共享同一组参数,当前时刻的激活值依赖于上一时刻的激活值。
反向传播 (Backward Propagation) :
- 方向:从右到左(时间步 $t$ 从 $T_x$ 递减到 $1$),即“时间倒流”。
- 目的:计算损失函数相对于各参数的梯度,以便使用梯度下降法更新参数。
- 命名由来:因为计算方向与时间流逝方向相反,仿佛穿越时光,故称为“通过时间的反向传播” (BPTT)。
2. 损失函数定义
为了进行反向传播,首先定义损失函数:
单时间步损失 (Element-wise Loss) :
\[\mathcal{L}^{(t)}(\hat{y}^{(t)}, y^{(t)}) = -y^{(t)} \log(\hat{y}^{(t)}) - (1-y^{(t)}) \log(1-\hat{y}^{(t)})\]
对于序列中第 $t$ 个位置的预测,使用交叉熵损失函数(逻辑回归损失):(注:此处以二分类为例,若是多分类则对应标准的交叉熵形式)
整个序列的总损失 (Total Loss) :
\[\mathcal{L}(\hat{y}, y) = \sum_{t=1}^{T_x} \mathcal{L}^{(t)}(\hat{y}^{(t)}, y^{(t)})\]
将所有时间步的损失相加:
3. BPTT 算法流程
- 构建计算图:将 RNN 在时间上展开,形成一个深层的前馈神经网络结构。
- 前向计算:从左至右计算所有 $a^{(t)}$ 和 $\hat{y}^{(t)}$,并最终计算出总损失 $\mathcal{L}$。
反向传递:
- 从最后一个时间步 $T_x$ 开始,计算 $\frac{\partial \mathcal{L}}{\partial \hat{y}^{(T_x)}}$。
- 沿着红色箭头方向(从右向左),利用链式法则依次计算梯度。
- 关键步骤是将梯度从 $t+1$ 时刻传递回 $t$ 时刻,因为 $a^{(t+1)}$ 依赖于 $a^{(t)}$。
参数更新:
- 由于参数 $W_{ax}, W_{aa}, W_{ya}, b_a, b_y$ 在所有时间步被共享,最终的梯度是所有时间步梯度的累加和。
- 例如:$\frac{\partial \mathcal{L}}{\partial W_{aa}} = \sum_{t=1}^{T_x} \frac{\partial \mathcal{L}^{(t)}}{\partial W_{aa}}$。
- 利用计算出的总梯度更新参数。
4. 总结与展望
- 实现细节:在实际编程框架(如 TensorFlow, PyTorch)中,BPTT 通常由框架自动处理,但理解其“时间展开”和“反向累加梯度”的机制对于调试和理解 RNN 行为(如梯度消失/爆炸问题)至关重要。
- 当前局限:本节讨论的是输入序列长度等于输出序列长度($T_x = T_y$)的情况。
- 后续内容:接下来的课程将介绍更多样的 RNN 架构,以处理输入输出长度不一致的更广泛应用场景(如机器翻译、语音识别等)。
本文由作者按照 CC BY 4.0 进行授权