文章

08 多任务学习(Multi-task Learning)

08 多任务学习(Multi-task Learning)

08 多任务学习(Multi-task Learning)

一、多任务学习 vs 迁移学习

特性迁移学习(Transfer Learning)多任务学习(Multi-task Learning)
学习方式串行:先在源任务上训练,再迁移到目标任务并行:同时训练多个任务
目标利用大数据任务提升小数据任务性能通过共享表示,让多个任务互相促进
典型场景目标任务数据少,源任务数据多多个相关任务需同时完成,且数据量相近

关键区别:迁移学习是“先学A,再用于B”;多任务学习是“同时学A、B、C、D”。


二、多任务学习的核心思想

构建一个共享的神经网络,同时输出多个任务的预测结果。
例如:自动驾驶中,一张图像需同时判断是否存在:

  • 行人(pedestrian)
  • 车辆(car)
  • 停车标志(stop sign)
  • 交通灯(traffic light)

每个样本的标签是一个多维二值向量(而非单标签):

\[y = \begin{bmatrix} y_1 \\ y_2 \\ y_3 \\ y_4 \end{bmatrix} \in \{0,1\}^4\]

其中 $y_i = 1$ 表示存在第 $i$ 类物体,$0$ 表示不存在。

📌 注意:一张图可含多个物体(如同时有车和停车标志),因此是多标签分类(multi-label classification) ,而非互斥的 softmax 分类。


三、损失函数设计

对每个任务使用 logistic 损失(即二元交叉熵) ,总损失为各任务损失之和的平均:

设训练集有 $m$ 个样本,第 $i$ 个样本的真实标签为 $y^{(i)} = [y_1^{(i)}, y_2^{(i)}, y_3^{(i)}, y_4^{(i)}]^\top$,模型预测为 $\hat{y}^{(i)} = [\hat{y}_1^{(i)}, \dots, \hat{y}_4^{(i)}]^\top$。

则整体损失函数为:

\[\mathcal{L} = \frac{1}{m} \sum_{i=1}^{m} \sum_{j=1}^{4} \mathcal{L}_{\text{logistic}}(y_j^{(i)}, \hat{y}_j^{(i)})\]

其中:

\[\mathcal{L}_{\text{logistic}}(y, \hat{y}) = - \left[ y \log \hat{y} + (1 - y) \log (1 - \hat{y}) \right]\]

✅ 此处每个输出节点独立使用 sigmoid 激活 + 二元交叉熵,不是 softmax(因标签非互斥)。


四、处理部分标注数据(Missing Labels)

实际中,某些样本可能只标注了部分任务(如只标了“有行人”,未标“是否有交通灯”)。

解决方案:在计算损失时,仅对已标注的任务求和,忽略缺失项(用问号 ? 表示)。

形式化地,定义掩码 $M_j^{(i)} = 1$ 若 $y_j^{(i)}$ 已知,否则为 0,则损失变为:

\[\mathcal{L} = \frac{1}{m} \sum_{i=1}^{m} \frac{1}{\sum_{j=1}^{4} M_j^{(i)}} \sum_{j=1}^{4} M_j^{(i)} \cdot \mathcal{L}_{\text{logistic}}(y_j^{(i)}, \hat{y}_j^{(i)})\]

💡 实践中常简化为:只对非缺失标签计算损失,不归一化也可接受。


五、多任务学习有效的三大条件

1. 任务可共享低层特征(Shared Representations)

  • 例如:行人、车辆、交通标志都出现在道路场景中,共享边缘、纹理、形状等底层视觉特征。
  • 神经网络前几层可提取通用特征,后几层做任务特异性预测。

2. 各任务数据量相近,或总量远大于单任务数据

  • 若单独训练任务 $T_k$ 只有 1000 个样本,但其他 99 个任务共提供 99,000 个样本,
  • 则多任务学习可通过隐式数据增强提升 $T_k$ 性能。
  • ⚠️ 若某任务数据极少而其他任务数据极多,迁移学习可能更合适

3. 神经网络足够大,能同时拟合所有任务

  • Rich Caruana 的研究指出:只有当网络容量不足时,多任务学习才可能比单任务差
  • 若网络足够大,多任务学习不会降低性能,通常还能提升泛化能力。

✅ 建议:使用足够深/宽的网络,并加入任务特定的输出头(task-specific heads)。


六、多任务学习 vs 单任务多个模型

方法优点缺点
多任务学习(一个网络)共享特征,节省计算资源;任务间正则化,提升泛化设计复杂;任务冲突可能降低性能
多个独立网络任务完全解耦,无干扰无法共享特征;训练/部署成本高

✅ 在计算机视觉(如目标检测)中,多任务学习广泛成功(如 YOLO、Faster R-CNN 同时预测类别+边界框)。


七、实际应用与使用频率

  • 多任务学习使用频率 📌 吴恩达观点: 迁移学习更常用 ,尤其在小数据场景; 多任务学习是特定场景下的强力工具**。

八、总结(Key Takeaways)

  1. 多任务学习 = 一个网络 + 多个输出 + 共享表示
  2. 标签是多维二值向量,使用 sigmoid + 二元交叉熵,非 softmax。
  3. 可处理部分标注数据,只需在损失中忽略缺失项。
  4. 有效前提:

    • 任务共享底层特征;
    • 数据总量充足;
    • 网络容量足够大。
  5. 虽不如迁移学习普及,但在计算机视觉目标检测等领域效果显著。
本文由作者按照 CC BY 4.0 进行授权