文章

07 迁移学习(Transfer Learning)

07 迁移学习(Transfer Learning)

07 迁移学习(Transfer Learning)

一、核心思想

迁移学习的核心理念是:在一个任务(源任务)上训练好的模型,其学到的知识可以迁移到另一个相关但不同的目标任务上,从而提升目标任务的学习效率或性能,尤其是在目标任务数据量较少的情况下。

关键直觉:神经网络的前几层通常学习的是通用的低层次特征(如边缘、纹理、曲线等),这些特征在多个视觉或听觉任务中具有共通性。


二、典型应用场景

场景1:图像识别 → 医学影像诊断

  • 源任务 A:大规模图像分类(如 ImageNet,100 万张图像,类别:猫、狗、鸟等)
  • 目标任务 B:放射科 X 光片诊断(仅 100 张图像,标签:疾病/正常)

场景2:语音识别 → 唤醒词检测

  • 源任务 A:语音转文本(10,000 小时音频)
  • 目标任务 B:触发词检测(如 “Hey Siri”,仅 1 小时数据)

三、迁移学习的操作流程

步骤 1:预训练(Pre-training)

在源任务 A 上训练一个完整的神经网络:

\[\min_{\theta} \mathcal{L}_A(\theta) = \frac{1}{N_A} \sum_{i=1}^{N_A} \ell(f_\theta(x_i^A), y_i^A)\]

其中:

  • $x_i^A$:源任务输入(如自然图像)
  • $y_i^A$:源任务标签
  • $\theta$:网络全部参数
  • $\ell$:损失函数(如交叉熵)

步骤 2:迁移与微调(Fine-tuning)

  1. 移除原输出层(对应源任务的分类头)
  2. 添加新的输出层,适配目标任务 B 的类别数
  3. 初始化新输出层参数为随机值(如 $\mathbf{W}_{\text{new}} \sim \mathcal{N}(0, \sigma^2)$)
  4. 在目标任务 B 上继续训练:

情况 A:小目标数据集(如 10k 样本)

  • 微调整个网络(所有层参数均可更新):

    \[\min_{\theta} \mathcal{L}_B(\theta) = \frac{1}{N_B} \sum_{i=1}^{N_B} \ell(f_\theta(x_i^B), y_i^B)\]

    其中学习率通常设得较小(如 $10^{-4}$),以避免破坏预训练学到的通用特征。


四、为何有效?——特征复用原理

神经网络具有层次化特征表示能力

  • 浅层(靠近输入) :学习通用低级特征(edges, corners, blobs)

    \[\phi_1(x) \approx \text{Gabor filters, edge detectors}\]
  • 中层:组合成部件(如眼睛、轮子)
  • 深层(靠近输出) :任务特定语义(如“这是猫”)

在目标任务 B 中,即使数据少,也能复用源任务 A 学到的通用特征提取器(即共享的底层网络),只需学习新的高层判别器。


五、迁移学习有效的前提条件

迁移学习有意义当且仅当满足以下条件:

条件说明
输入空间一致源任务与目标任务的输入类型相同(如都是图像或都是音频)
源任务数据 » 目标任务数据$N_A \gg N_B$(如 $10^6 \gg 10^2$)
低层特征可迁移源任务学到的特征对目标任务有帮助(如自然图像中的边缘对 X 光片仍有意义)
反向迁移通常无效若 $N_B > N_A$,则源任务数据价值低,迁移收益有限

⚠️ 反例:用 100 张猫狗图预训练,去提升 1000 张 X 光片的诊断——不推荐,因为每张 X 光片的信息价值远高于普通图像。


六、术语澄清

术语含义
Pre-training(预训练)在大规模源任务上训练模型,获得初始参数 $\theta_0$
Fine-tuning(微调)在目标任务上继续训练(部分或全部参数),从 $\theta_0$ 开始优化
Feature extraction(特征提取)冻结预训练网络,仅训练新分类头——适用于极小目标数据集

七、实践建议

目标数据规模推荐策略
极小( 10k)微调整个网络,使用较小学习率

八、与多任务学习的区别(预告)

  • 迁移学习串行学习 —— 先学任务 A,再迁移到任务 B
  • 多任务学习并行学习 —— 同时优化多个任务的联合损失:

    \[\min_{\theta} \sum_{k=1}^K \lambda_k \mathcal{L}_k(\theta)\]

    (将在下一讲介绍)


九、总结口诀

“同输入、大数据、小目标、低层通”
—— 当源任务与目标任务输入相同源数据远多于目标数据目标任务数据少、且低层特征通用时,迁移学习最有效!

本文由作者按照 CC BY 4.0 进行授权