何时需要全参数微调
大多数场景 LoRA/QLoRA 已经足够,但以下情况考虑全参数微调:
- 需要从根本上改变模型的"世界观"(语言、领域)
- 持续预训练(Continual Pretraining)在新语料上
- LoRA 效果达到瓶颈,评估指标不再提升
- 数据量 > 10 万条,充分利用大数据集
- 追求最高精度,不关心显存成本
DeepSpeed ZeRO:分割显存压力
ZeRO Stage 1 — 分割优化器状态
Adam 优化器状态(momentum、variance)分散到各 GPU,每张 GPU 显存减少约 4×(8 GPU)。梯度和参数仍然每卡一份。
ZeRO Stage 2 — 分割梯度 + 优化器状态
在 Stage 1 基础上,梯度也分散存储。总节省约 8×(8 GPU)。训练速度基本不受影响,是多数场景的推荐选择。
ZeRO Stage 3 — 分割参数 + 梯度 + 优化器
模型参数本身也分散存储。总节省约 64×(8 GPU)。但 forward 需要 all-gather 参数,通信开销较大。适合极大模型(70B+)。
DeepSpeed 配置
{
"zero_optimization": {
"stage": 2,
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"gather_16bit_weights_on_model_save": true
},
"bf16": {"enabled": true},
"gradient_clipping": 1.0,
"train_micro_batch_size_per_gpu": 4,
"gradient_accumulation_steps": 4
}
启动多 GPU 训练
# 单机 8 GPU 训练
deepspeed --num_gpus=8 train.py \
--deepspeed ds_config.json \
--model_name_or_path meta-llama/Meta-Llama-3-8B \
--dataset_path data/train.jsonl \
--output_dir ./output \
--num_train_epochs 3
# 或使用 torchrun(FSDP)
torchrun --nproc_per_node=8 \
--master_addr="localhost" \
--master_port=29500 \
train_fsdp.py
FSDP vs DeepSpeed 选择
| 特性 | FSDP | DeepSpeed ZeRO |
|---|---|---|
| 来源 | PyTorch 原生 | 微软独立库 |
| 集成难度 | 较简单 | 需要配置 JSON |
| HuggingFace 集成 | Trainer 原生支持 | 需要 deepspeed 库 |
| 多机支持 | 好 | 好 |
| CPU Offload | 有限 | 强大(可卸载到 CPU/NVMe) |
| 推荐场景 | 单机多卡,7B-30B | 多机多卡,30B+ |
全参数微调训练技巧
使用更低的学习率
全参数微调的学习率通常比 LoRA 低 10 倍:1e-5 到 5e-5。过高学习率会"灾难性遗忘"预训练知识。
Flash Attention 2
安装 flash-attn 可将 Attention 计算速度提升 2-4 倍,显存降低 50%。这是全参数微调的必选项。
混合精度 BF16
使用 BF16 而不是 FP16。BF16 数值范围更大,不容易出现 loss scaling 相关的训练不稳定问题。
全参数微调与 LoRA 的显存对比
单张 GPU 的显存占用分析
以 7B 参数模型为例,BF16 精度下参数占用约 14 GB;Adam 优化器状态(momentum + variance,FP32)额外占用约 56 GB;梯度(BF16)约 14 GB;激活值视 batch size 和序列长度而定,约 5-20 GB。总计单卡训练至少需要 80-100+ GB 显存,远超单张 A100 80GB 的容量,这是为什么需要 ZeRO 或 FSDP 分布式训练的根本原因。
ZeRO 的通信代价
ZeRO 通过在 GPU 间分散存储来减少每张 GPU 的显存占用,但代价是需要额外的通信:Stage 3 在每次 forward/backward 时需要 all-gather 参数,NVLink 节点内通信快(~300 GB/s),跨节点 InfiniBand 通信慢(~100 Gb/s)。因此 Stage 3 在多机环境下速度显著低于单机,需要权衡。对于能放进单机的模型,优先 Stage 2。
Gradient Checkpointing 的原理
标准训练需要保存所有中间激活值(forward 时计算,backward 时使用)。Gradient Checkpointing 只保存关键的 checkpoint 激活值(通常每隔 N 层保存一次),backward 时需要重新计算没有保存的激活值。代价是约 30% 的训练速度降低,换取约 10 倍的激活值显存节省。在全参数微调中几乎是必开的选项。
灾难性遗忘的防范
全参数微调最大的风险是模型在学习新数据的同时"遗忘"了预训练的能力(如通用语言理解、指令跟随)。防范措施:使用极小的学习率(1e-5 量级);混合部分通用数据和领域数据训练(比例约 1:5 到 1:10);使用 Warm-up + Cosine 学习率调度;在验证集上监控通用基准指标,一旦发现下降停止训练。
Flash Attention 2 的工作原理
# 安装 Flash Attention 2(需要 CUDA 11.6+,Ampere/Ada 架构 GPU)
pip install flash-attn --no-build-isolation
# 在 Transformers 中启用
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Meta-Llama-3-8B",
attn_implementation="flash_attention_2", # 关键参数
torch_dtype=torch.bfloat16,
)
为什么 Flash Attention 快?
标准 Attention 计算 QK^T 矩阵时,中间结果(S×S 注意力矩阵)需要写入 HBM(高带宽内存),再读回来做 softmax,产生大量的 HBM 读写。Flash Attention 2 将整个 Attention 计算分块放进 SRAM(片上缓存),大幅减少 HBM 访问次数。由于 Attention 计算是显存带宽受限而非算力受限,Flash Attention 2 可以提速 2-4x,且无数值误差(数学等价)。
多机训练的环境配置
# 多机 DeepSpeed 训练(主节点 + 工作节点)
# 需要所有节点 SSH 免密登录,共享存储(NFS/NAS)
# 主节点上创建 hostfile(列出所有节点)
cat hostfile
worker-0 slots=8 # 8张A100
worker-1 slots=8
# 在主节点启动(DeepSpeed 自动 SSH 到工作节点)
deepspeed --hostfile hostfile \
--num_gpus 8 \
--num_nodes 2 \
train.py \
--deepspeed ds_z3_config.json \
--model_name_or_path meta-llama/Meta-Llama-3-70B
# torchrun 多机版本
torchrun --nnodes=2 \
--nproc_per_node=8 \
--rdzv_id=100 \
--rdzv_backend=c10d \
--rdzv_endpoint="worker-0:29400" \
train_fsdp.py
分布式训练的常见问题与调试
NCCL 通信超时
多机训练中 NCCL 集合通信(all-reduce、all-gather)超时是最常见的问题。原因:网络带宽不足、防火墙阻断 NCCL 端口(默认随机端口)、节点间时钟不同步。排查方法:设置 NCCL_DEBUG=INFO 打印通信日志;使用 NCCL_SOCKET_IFNAME 指定正确的网络接口;在启动前用 iperf 测试节点间带宽。
OOM(显存溢出)的系统化排查
显存溢出时,不要只是调小 batch_size,应该先用 nvidia-smi 分析哪个阶段 OOM:如果在 forward 开始就 OOM,说明模型本身太大(需要升 ZeRO Stage 或开启 CPU Offload);如果在 backward OOM,说明激活值太多(需要开启 Gradient Checkpointing);如果在 optimizer step OOM,说明优化器状态太大(需要升 ZeRO Stage 2+)。
训练挂起(Hanging)的定位
分布式训练挂起(所有进程停止响应但不报错)通常是某个进程在等待其他进程的集合通信。定位方法:设置 TORCH_DISTRIBUTED_DEBUG=DETAIL;用 py-spy 或 gdb 查看所有进程的调用栈;注意数据加载不均匀(某个 GPU 等待其他 GPU 完成数据加载)——使用 DataLoader 时 num_workers 要设为 0 或确保各节点数据加载速度一致。
保存和恢复分布式训练状态
全参数微调时间长(几十小时),必须做定期 checkpoint(每 500-1000 步)。ZeRO Stage 3 下每个 GPU 只有部分参数,save_pretrained 需要先聚合:使用 zero_to_fp32.py 脚本将分片 checkpoint 合并为完整模型。FSDP 使用 FullStateDictConfig(offload_to_cpu=True, rank0_only=True) 在 rank 0 上聚合保存,避免 OOM。
全参数微调的云成本控制
A100 集群训练成本极高,务必做好成本控制:使用 Spot/Preemptible 实例(价格约 70% 折扣,但随时可能被中断——必须有 checkpoint 恢复机制);设置训练时间上限(max_steps + 预期时间估算);在正式训练前用 1% 数据跑 5 步验证配置正确;使用 W&B 或 TensorBoard 实时监控 loss 曲线,loss 不下降立即停止排查。
Gradient Checkpointing 配置与效果
from transformers import AutoModelForCausalLM
import torch
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Meta-Llama-3-8B",
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16,
)
# 开启 Gradient Checkpointing(大幅节省激活值显存)
model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False} # 推荐 False,更稳定
)
# 注意:Gradient Checkpointing 与以下特性不兼容
# - 某些自定义注意力实现
# - model.generate() 时需要先 model.gradient_checkpointing_disable()
# - 速度下降约 20-30%,但激活值显存减少约 10 倍
训练稳定性:梯度裁剪与 BF16 精度
# TrainingArguments 中的稳定性关键参数
from transformers import TrainingArguments
args = TrainingArguments(
output_dir="./output",
bf16=True, # 优先 BF16(数值范围比 FP16 更大)
max_grad_norm=1.0, # 梯度裁剪阈值,防止梯度爆炸
learning_rate=2e-5, # 全参数微调用较小学习率
warmup_ratio=0.03, # 前 3% 步 warmup,防止初始训练不稳定
lr_scheduler_type="cosine", # Cosine 衰减,优于线性
weight_decay=0.01, # L2 正则化,防止过拟合
gradient_checkpointing=True,
gradient_accumulation_steps=4,
per_device_train_batch_size=2,
save_strategy="steps",
save_steps=500, # 每 500 步保存 checkpoint
save_total_limit=3, # 只保留最近 3 个 checkpoint
)
本章小结
本章核心要点
- 何时选择全参数微调:数据量 > 10万条、需要深度改变模型语言/知识体系、LoRA 效果达到瓶颈、从事持续预训练。大多数业务场景 LoRA/QLoRA 已经足够,全参数微调仅在确有必要时使用。
- ZeRO 三个 Stage 的选择:Stage 1(分散优化器状态,4× 节省);Stage 2(再分散梯度,8× 节省,推荐起点);Stage 3(再分散参数,64× 节省,适合 70B+,但通信开销大)。从 Stage 2 开始,不够再升 Stage 3。
- FSDP vs DeepSpeed:FSDP 是 PyTorch 原生,与 Trainer 集成好,适合单机多卡 7B-30B;DeepSpeed 功能更强大(CPU Offload),适合多机多卡极大模型。新手从 FSDP 开始,需要 CPU 卸载再换 DeepSpeed。
- Flash Attention 2 是必选项:提速 2-4×,显存降 50%,数学等价。安装后只需加 attn_implementation="flash_attention_2" 参数,改动极小,收益极大。
- 防范灾难性遗忘:学习率 ≤ 5e-5;混合通用数据和领域数据;监控通用基准指标;使用 Cosine 学习率衰减而非常量。这些是全参数微调稳定训练的关键保证。