文章

04 通过时间的反向传播(Backpropagation Through Time, BPTT)

04 通过时间的反向传播(Backpropagation Through Time, BPTT)

1. 核心概念

  • 前向传播 (Forward Propagation)

    • 方向:从左到右(时间步 $t$ 从 $1$ 增加到 $T_x$)。
    • 过程:利用输入序列 $x^{(1)}, x^{(2)}, \dots, x^{(T_x)}$ 和共享参数 $W_{ax}, W_{aa}, b_a$,依次计算每个时间步的激活值 $a^{(t)}$。
    • 输出计算:利用激活值 $a^{(t)}$ 和参数 $W_{ya}, b_y$ 计算预测值 $\hat{y}^{(t)}$。
    • 特点:所有时间步共享同一组参数,当前时刻的激活值依赖于上一时刻的激活值。
  • 反向传播 (Backward Propagation)

    • 方向:从右到左(时间步 $t$ 从 $T_x$ 递减到 $1$),即“时间倒流”。
    • 目的:计算损失函数相对于各参数的梯度,以便使用梯度下降法更新参数。
    • 命名由来:因为计算方向与时间流逝方向相反,仿佛穿越时光,故称为“通过时间的反向传播” (BPTT)。

2. 损失函数定义

  为了进行反向传播,首先定义损失函数:

  • 单时间步损失 (Element-wise Loss)
    对于序列中第 $t$ 个位置的预测,使用交叉熵损失函数(逻辑回归损失):

    \[\mathcal{L}^{(t)}(\hat{y}^{(t)}, y^{(t)}) = -y^{(t)} \log(\hat{y}^{(t)}) - (1-y^{(t)}) \log(1-\hat{y}^{(t)})\]

    (注:此处以二分类为例,若是多分类则对应标准的交叉熵形式)

  • 整个序列的总损失 (Total Loss)
    将所有时间步的损失相加:

    \[\mathcal{L}(\hat{y}, y) = \sum_{t=1}^{T_x} \mathcal{L}^{(t)}(\hat{y}^{(t)}, y^{(t)})\]

3. BPTT 算法流程

  1. 构建计算图:将 RNN 在时间上展开,形成一个深层的前馈神经网络结构。
  2. 前向计算:从左至右计算所有 $a^{(t)}$ 和 $\hat{y}^{(t)}$,并最终计算出总损失 $\mathcal{L}$。
  3. 反向传递

    • 从最后一个时间步 $T_x$ 开始,计算 $\frac{\partial \mathcal{L}}{\partial \hat{y}^{(T_x)}}$。
    • 沿着红色箭头方向(从右向左),利用链式法则依次计算梯度。
    • 关键步骤是将梯度从 $t+1$ 时刻传递回 $t$ 时刻,因为 $a^{(t+1)}$ 依赖于 $a^{(t)}$。
  4. 参数更新

    • 由于参数 $W_{ax}, W_{aa}, W_{ya}, b_a, b_y$ 在所有时间步被共享,最终的梯度是所有时间步梯度的累加和。
    • 例如:$\frac{\partial \mathcal{L}}{\partial W_{aa}} = \sum_{t=1}^{T_x} \frac{\partial \mathcal{L}^{(t)}}{\partial W_{aa}}$。
    • 利用计算出的总梯度更新参数。

4. 总结与展望

  • 实现细节:在实际编程框架(如 TensorFlow, PyTorch)中,BPTT 通常由框架自动处理,但理解其“时间展开”和“反向累加梯度”的机制对于调试和理解 RNN 行为(如梯度消失/爆炸问题)至关重要。
  • 当前局限:本节讨论的是输入序列长度等于输出序列长度($T_x = T_y$)的情况。
  • 后续内容:接下来的课程将介绍更多样的 RNN 架构,以处理输入输出长度不一致的更广泛应用场景(如机器翻译、语音识别等)。
本文由作者按照 CC BY 4.0 进行授权