08 多任务学习(Multi-task Learning)
08 多任务学习(Multi-task Learning)
一、多任务学习 vs 迁移学习
| 特性 | 迁移学习(Transfer Learning) | 多任务学习(Multi-task Learning) |
|---|---|---|
| 学习方式 | 串行:先在源任务上训练,再迁移到目标任务 | 并行:同时训练多个任务 |
| 目标 | 利用大数据任务提升小数据任务性能 | 通过共享表示,让多个任务互相促进 |
| 典型场景 | 目标任务数据少,源任务数据多 | 多个相关任务需同时完成,且数据量相近 |
✅ 关键区别:迁移学习是“先学A,再用于B”;多任务学习是“同时学A、B、C、D”。
二、多任务学习的核心思想
构建一个共享的神经网络,同时输出多个任务的预测结果。
例如:自动驾驶中,一张图像需同时判断是否存在:
- 行人(pedestrian)
- 车辆(car)
- 停车标志(stop sign)
- 交通灯(traffic light)
每个样本的标签是一个多维二值向量(而非单标签):
\[y = \begin{bmatrix} y_1 \\ y_2 \\ y_3 \\ y_4 \end{bmatrix} \in \{0,1\}^4\]其中 $y_i = 1$ 表示存在第 $i$ 类物体,$0$ 表示不存在。
📌 注意:一张图可含多个物体(如同时有车和停车标志),因此是多标签分类(multi-label classification) ,而非互斥的 softmax 分类。
三、损失函数设计
对每个任务使用 logistic 损失(即二元交叉熵) ,总损失为各任务损失之和的平均:
设训练集有 $m$ 个样本,第 $i$ 个样本的真实标签为 $y^{(i)} = [y_1^{(i)}, y_2^{(i)}, y_3^{(i)}, y_4^{(i)}]^\top$,模型预测为 $\hat{y}^{(i)} = [\hat{y}_1^{(i)}, \dots, \hat{y}_4^{(i)}]^\top$。
则整体损失函数为:
\[\mathcal{L} = \frac{1}{m} \sum_{i=1}^{m} \sum_{j=1}^{4} \mathcal{L}_{\text{logistic}}(y_j^{(i)}, \hat{y}_j^{(i)})\]其中:
\[\mathcal{L}_{\text{logistic}}(y, \hat{y}) = - \left[ y \log \hat{y} + (1 - y) \log (1 - \hat{y}) \right]\]✅ 此处每个输出节点独立使用 sigmoid 激活 + 二元交叉熵,不是 softmax(因标签非互斥)。
四、处理部分标注数据(Missing Labels)
实际中,某些样本可能只标注了部分任务(如只标了“有行人”,未标“是否有交通灯”)。
解决方案:在计算损失时,仅对已标注的任务求和,忽略缺失项(用问号 ? 表示)。
形式化地,定义掩码 $M_j^{(i)} = 1$ 若 $y_j^{(i)}$ 已知,否则为 0,则损失变为:
\[\mathcal{L} = \frac{1}{m} \sum_{i=1}^{m} \frac{1}{\sum_{j=1}^{4} M_j^{(i)}} \sum_{j=1}^{4} M_j^{(i)} \cdot \mathcal{L}_{\text{logistic}}(y_j^{(i)}, \hat{y}_j^{(i)})\]💡 实践中常简化为:只对非缺失标签计算损失,不归一化也可接受。
五、多任务学习有效的三大条件
1. 任务可共享低层特征(Shared Representations)
- 例如:行人、车辆、交通标志都出现在道路场景中,共享边缘、纹理、形状等底层视觉特征。
- 神经网络前几层可提取通用特征,后几层做任务特异性预测。
2. 各任务数据量相近,或总量远大于单任务数据
- 若单独训练任务 $T_k$ 只有 1000 个样本,但其他 99 个任务共提供 99,000 个样本,
- 则多任务学习可通过隐式数据增强提升 $T_k$ 性能。
- ⚠️ 若某任务数据极少而其他任务数据极多,迁移学习可能更合适。
3. 神经网络足够大,能同时拟合所有任务
- Rich Caruana 的研究指出:只有当网络容量不足时,多任务学习才可能比单任务差。
- 若网络足够大,多任务学习不会降低性能,通常还能提升泛化能力。
✅ 建议:使用足够深/宽的网络,并加入任务特定的输出头(task-specific heads)。
六、多任务学习 vs 单任务多个模型
| 方法 | 优点 | 缺点 |
|---|---|---|
| 多任务学习(一个网络) | 共享特征,节省计算资源;任务间正则化,提升泛化 | 设计复杂;任务冲突可能降低性能 |
| 多个独立网络 | 任务完全解耦,无干扰 | 无法共享特征;训练/部署成本高 |
✅ 在计算机视觉(如目标检测)中,多任务学习广泛成功(如 YOLO、Faster R-CNN 同时预测类别+边界框)。
七、实际应用与使用频率
- 多任务学习使用频率 📌 吴恩达观点: 迁移学习更常用 ,尤其在小数据场景; 多任务学习是特定场景下的强力工具**。
八、总结(Key Takeaways)
- 多任务学习 = 一个网络 + 多个输出 + 共享表示。
- 标签是多维二值向量,使用 sigmoid + 二元交叉熵,非 softmax。
- 可处理部分标注数据,只需在损失中忽略缺失项。
有效前提:
- 任务共享底层特征;
- 数据总量充足;
- 网络容量足够大。
- 虽不如迁移学习普及,但在计算机视觉目标检测等领域效果显著。