Paper Overview

https://arxiv.org/abs/2407.10490

Abstract:

The paper studies the learning dynamics of large language models (LLMs) during different kinds of fine-tuning, i.e. how learning a particular training example changes the model's predictions on other examples. By analyzing a step-wise decomposition of how influence accumulates across different potential responses, the authors propose a unified framework that explains many interesting observations about popular instruction-tuning and preference-tuning algorithms. In particular, the paper offers a hypothetical explanation for why certain kinds of hallucination are amplified after fine-tuning, e.g. the model answering question A with phrases or facts from the response to question B, or the model repeating similar simple phrases over and over while generating a response. The framework is also extended to explain a previously observed phenomenon in off-policy direct preference optimization (off-policy DPO), the "squeezing effect": running DPO for too long makes even the desired outputs less likely. The framework further provides insight into where the advantages of on-policy DPO and other variants come from. Beyond offering a new perspective for understanding LLM fine-tuning, the analysis also inspires a simple and effective method for improving alignment performance.

  • Main contributions / key findings:

    1. A unified learning-dynamics framework: a formal framework for the learning dynamics of LLM fine-tuning that decomposes the change in the model's predictions into three terms with distinct roles. The framework applies uniformly to a wide range of fine-tuning algorithms, such as supervised fine-tuning (SFT), direct preference optimization (DPO) and its variants, and even RL-based methods.
    2. Explaining key phenomena: the framework helps explain several interesting and counter-intuitive observations made during fine-tuning, including:

      • The amplification of certain kinds of hallucination (e.g., the model mixing up the answer content of different questions).
      • The model's tendency to repeat simple phrases after preference tuning (the "repeater" phenomenon).
      • The drop in confidence of all responses, including the desired ones, during off-policy DPO.
    3. Identifying and explaining the "squeezing effect": in updates such as DPO's push-down on the rejected response, the cross-entropy loss behind the softmax layer produces a "squeezing effect": at every token, the negative gradient lowers the probability of almost all possible output labels and shifts that probability mass onto the most likely label. The effect is especially severe when the negative gradient is applied to a label that is already unlikely, which explains why the confidence of nearly all responses drops during off-policy DPO (a toy numeric sketch of this mechanism follows this list).
    4. Explaining the advantage of on-policy DPO: building on the squeezing-effect analysis, the paper offers a new perspective on why on-policy DPO and its variants outperform off-policy DPO (on-policy methods mitigate the negative impact of the squeezing effect).
    5. A new recipe for better alignment: inspired by this dynamics analysis, the paper proposes a simple, counter-intuitive yet effective way to further improve alignment: during the SFT stage, train on both the desired response ($y_u^+$) and the undesired response ($y_u^-$), which mitigates the squeezing effect in the subsequent DPO stage.

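To make the squeezing effect concrete, here is a small numeric sketch (an illustration of the mechanism only, not code from the paper): one gradient-ascent step of the token-level cross-entropy loss is applied to an already-unlikely label, and nearly all of the probability mass that is freed up lands on the argmax token, while every other token's probability drops.

# squeeze_toy.py -- toy illustration of the squeezing effect
# For a softmax policy p = softmax(z), cross-entropy on label y has gradient
# dL/dz = p - onehot(y). Pushing label y DOWN (the negative gradient that DPO
# applies to the rejected response) corresponds to z <- z + eta * (p - onehot(y)).
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

z = np.array([3.0, 1.0, 0.0, 0.0, -1.0])   # toy logits over a 5-token vocabulary
y = 3                                      # an already-unlikely token gets pushed down
eta = 5.0                                  # exaggerated step size to make the effect visible

p_before = softmax(z)
grad = p_before.copy()
grad[y] -= 1.0                             # dL/dz for cross-entropy on label y
p_after = softmax(z + eta * grad)          # gradient ascent on the loss: p[y] goes down

print("before:", np.round(p_before, 4))    # ~[0.798, 0.108, 0.040, 0.040, 0.015]
print("after :", np.round(p_after, 4))     # ~[0.994, 0.004, 0.001, 0.000, 0.000]
# Every non-argmax probability shrinks; the freed mass is squeezed onto the argmax token.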


Datasets and Models

Dataset: markyfsun/chinese-enthusiastic-dpo
https://huggingface.co/datasets/markyfsun/chinese-enthusiastic-dpo

Model: Qwen3-4B-Base

Packages to install with uv (declared in the pyproject.toml below; running uv sync in the project root installs them)

File: pyproject.toml

[project]
name = "tryllm"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
    "accelerate>=1.8.1",
    "datasets>=3.6.0",
    "lighteval>=0.10.0",
    "matplotlib>=3.10.3",
    "more-itertools>=10.7.0",
    "peft>=0.15.2",
    "unsloth>=2025.6.8",
    "vllm>=0.9.1",
    "wandb>=0.21.0",
]

Experiments

DPO only

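Before the script itself, here is the bookkeeping it does (a restatement, in standard notation, of what calculate_log_prob and calculate_argmax_log_prob below compute): for a prompt $x$ and a response $y$ of length $T$, the tracked sequence log-likelihood and the argmax confidence along the chosen path are

$\log \pi_\theta(y \mid x) = \sum_{t=1}^{T} \log \pi_\theta(y_t \mid x, y_{<t})$, for $y \in \{y_{chosen}, y_{rejected}\}$

$\mathrm{ArgmaxConf}(x, y_{chosen}) = \sum_{t=1}^{T} \max_{v \in \mathcal{V}} \log \pi_\theta(v \mid x, y_{chosen,<t})$

Both are averaged over the probe set at every evaluation epoch; the squeezing effect shows up when the chosen and rejected curves both fall while the argmax curve keeps rising.

File: DPO.py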
# DPO.py
import unsloth
import torch
import torch.nn.functional as F  # F is used for log_softmax
import gc
import numpy as np
import matplotlib.pyplot as plt
import json
import os
from datasets import load_dataset
from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl
from unsloth import FastLanguageModel, PatchDPOTrainer
# Patch DPOTrainer so it uses unsloth's optimizations
PatchDPOTrainer()
from trl import DPOTrainer, DPOConfig
import warnings

# Silence some unnecessary warnings
warnings.filterwarnings("ignore")

# 1. Model and tokenizer initialization
# =================================================================
max_seq_length = 2048
lora_rank = 32

# Load the pretrained model; supports 4-bit quantization and fast vLLM inference
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Qwen3-4B-Base",
    max_seq_length=max_seq_length,
    load_in_4bit=False,  # use 16-bit LoRA so that log-probs can be computed
    fast_inference=True,
    max_lora_rank=lora_rank,
    gpu_memory_utilization=0.98,
    # token="hf_...", # fill in your HuggingFace token if needed
)

# Add the LoRA adapters
model = FastLanguageModel.get_peft_model(
    model,
    r=lora_rank,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha=lora_rank * 2,
    use_gradient_checkpointing="unsloth",
    random_state=3407,
)


# Set the chat_template for the Qwen3 model
# based on the Qwen3 chat template from the unsloth docs
chat_template = \
    "{% if messages[0]['role'] == 'system' %}"\
        "{{ messages[0]['content'] + eos_token }}"\
        "{% set loop_messages = messages[1:] %}"\
    "{% else %}"\
        "{{ 'You are a helpful assistant.' + eos_token }}"\
        "{% set loop_messages = messages %}"\
    "{% endif %}"\
    "{% for message in loop_messages %}"\
        "{% if message['role'] == 'user' %}"\
            "{{ message['content'] }}"\
        "{% elif message['role'] == 'assistant' %}"\
            "{{ message['content'] + eos_token }}"\
        "{% endif %}"\
    "{% endfor %}"

tokenizer.chat_template = chat_template
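# (Optional sanity check, not part of the original experiment: render one example
#  with the custom template above to confirm the plain-text format looks as expected.)
# print(repr(tokenizer.apply_chat_template(
#     [{"role": "user", "content": "你好"}], tokenize=False)))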

# 2. Dataset loading and preprocessing
# =================================================================
def format_chat_template(example):
    """
    Format the dataset for the Qwen3 model.
    The 'prompt' column will contain the system and user turns.
    The 'chosen' and 'rejected' columns contain only the assistant's response text.
    """
    # Build the prompt part in the format the DPO trainer expects.
    # The Qwen3 template expects a system prompt, even an empty one.
    prompt_messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": example['prompt']}
    ]

    # tokenizer.apply_chat_template renders the messages with the custom
    # plain-text template defined above (no generation prompt is appended here).
    formatted_prompt = tokenizer.apply_chat_template(
        prompt_messages,
        tokenize=False
    )

    return {
        "prompt": formatted_prompt,
        "chosen": example['chosen'],
        "rejected": example['rejected']
    }

# Load the dataset
ds = load_dataset("markyfsun/chinese-enthusiastic-dpo")

# 3. Core replication logic: custom callback and log-likelihood computation
# =================================================================
# Training configuration
NUM_EPOCHS = 100  # number of training epochs
PROBE_DATASET_SIZE = 200  # number of samples to use (-1 means use all samples)
EVAL_INTERVAL = 4  # evaluation interval (evaluate every this many epochs)

# Format the dataset
formatted_ds = ds.map(format_chat_template)
# Select the number of samples according to PROBE_DATASET_SIZE
if PROBE_DATASET_SIZE == -1:
    train_dataset = formatted_ds['train']
    actual_dataset_size = len(train_dataset)
    print(f"使用所有样本进行训练和评估,共 {actual_dataset_size} 个样本")
else:
    train_dataset = formatted_ds['train'].select(range(PROBE_DATASET_SIZE))
    actual_dataset_size = PROBE_DATASET_SIZE
    print(f"使用 {actual_dataset_size} 个样本进行训练和评估")

print("数据格式化示例:")
print("Prompt:\n", repr(train_dataset[0]['prompt']))
print("\nChosen:\n", repr(train_dataset[0]['chosen']))

# Dataset used for evaluation (the probe set)
if PROBE_DATASET_SIZE == -1:
    probe_dataset = train_dataset
else:
    probe_dataset = train_dataset.select(range(PROBE_DATASET_SIZE))

# Build result file names
dataset_size_str = "all" if PROBE_DATASET_SIZE == -1 else str(PROBE_DATASET_SIZE)
results_filename = f"dpo_results_samples_{dataset_size_str}_interval_{EVAL_INTERVAL}_epochs_{NUM_EPOCHS}.json"
plot_filename = f"squeezing_effect_samples_{dataset_size_str}_interval_{EVAL_INTERVAL}_epochs_{NUM_EPOCHS}.png"

print(f"结果将保存到: {results_filename}")
print(f"图表将保存到: {plot_filename}")

# Global store for the per-epoch results
log_probs_history = {
    "epoch": [],
    "chosen_log_probs": [],
    "rejected_log_probs": [],
    "argmax_log_probs": [],  # argmax confidence along the chosen responses
}

def save_results_to_json():
    """Save the current results to a JSON file."""
    results = {
        "config": {
            "num_epochs": NUM_EPOCHS,
            "probe_dataset_size": PROBE_DATASET_SIZE,
            "actual_dataset_size": actual_dataset_size,
            "eval_interval": EVAL_INTERVAL
        },
        "results": log_probs_history
    }

    with open(results_filename, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"结果已保存到 {results_filename}")

def calculate_log_prob(model, tokenizer, prompt, response):
    """
    Compute the log-likelihood of the model generating a specific response given the prompt.
    """
    # Tokenize the prompt and the response separately
    prompt_tokens = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).input_ids.to(model.device)
    response_tokens = tokenizer(response, return_tensors="pt", add_special_tokens=False).input_ids.to(model.device)

    # Concatenate them into the full input sequence
    full_tokens = torch.cat([prompt_tokens, response_tokens], dim=1)

    # Build labels aligned with the response; mask the prompt part with -100
    labels = torch.full_like(full_tokens, -100)
    labels[:, prompt_tokens.shape[1]:] = response_tokens

    with torch.no_grad():
        outputs = model(full_tokens, labels=labels)
        # outputs.loss is the mean per-token loss; multiplying by the number of
        # response tokens gives the total loss, i.e. the negative log-likelihood
        log_prob = -outputs.loss.item() * (response_tokens.shape[1])

    return log_prob if not np.isnan(log_prob) else -np.inf


def calculate_argmax_log_prob(model, tokenizer, prompt, chosen_response):
    """
    Sum the maximum log-probability at every token position along the chosen_response path.
    This corresponds to the "argmax confidence" in the paper.
    """
    prompt_tokens = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).input_ids.to(model.device)
    response_tokens = tokenizer(chosen_response, return_tensors="pt", add_special_tokens=False).input_ids.to(model.device)
    full_tokens = torch.cat([prompt_tokens, response_tokens], dim=1)

    with torch.no_grad():
        outputs = model(full_tokens)
        logits = outputs.logits

    # We only care about the logits for the response part.
    # logits has shape (batch, seq_len, vocab_size); the logits that predict
    # response_tokens[i] sit at the position of full_tokens[i-1].
    prompt_len = prompt_tokens.shape[1]
    response_len = response_tokens.shape[1]

    # Extract the logits that correspond to the response:
    # the logits for the first token of the response are at index `prompt_len - 1`
    response_logits = logits[:, prompt_len - 1 : -1, :]

    # Log-probabilities at each position
    log_probs = F.log_softmax(response_logits, dim=-1)

    # Maximum log-probability at each position
    # log_probs has shape (1, response_len, vocab_size)
    max_log_probs, _ = torch.max(log_probs, dim=2)

    # Sum the maxima over all positions
    total_argmax_log_prob = torch.sum(max_log_probs).item()

    return total_argmax_log_prob if not np.isnan(total_argmax_log_prob) else -np.inf


def plot_squeezing_effect():
    """
    Plot how the tracked log-likelihoods change over training.
    """
    epochs = log_probs_history["epoch"]
    chosen_lp = log_probs_history["chosen_log_probs"]
    rejected_lp = log_probs_history["rejected_log_probs"]
    argmax_lp = log_probs_history["argmax_log_probs"]  # argmax confidence along the chosen responses

    plt.style.use('seaborn-v0_8-whitegrid')
    fig, ax = plt.subplots(figsize=(12, 7))

    ax.plot(epochs, chosen_lp, 'o-', label=r"Avg Log-Prob of $y_{chosen}$", color='green', lw=2)
    ax.plot(epochs, rejected_lp, 'o-', label=r"Avg Log-Prob of $y_{rejected}$", color='red', lw=2)
    # Label chosen to match Figure 4 in the paper
    ax.plot(epochs, argmax_lp, 'o-', label=r"Avg Log-Prob of Argmax ($y^*$)", color='blue', lw=2)

    ax.set_title("Replication of 'Squeezing Effect' in DPO", fontsize=16)
    ax.set_xlabel("Training Epoch", fontsize=12)
    ax.set_ylabel("Average Log-Probability", fontsize=12)
    ax.legend(fontsize=12)
    ax.grid(True)

    # Align the x-axis ticks with the evaluation interval
    max_epoch = int(max(epochs))
    ax.set_xticks(np.arange(0, max_epoch + 1, step=EVAL_INTERVAL))

    fig.tight_layout()

    # Save the figure
    plt.savefig(plot_filename, dpi=300)
    print(f"\nResults plot saved to '{plot_filename}'")
    plt.show()


class SqueezingEffectCallback(TrainerCallback):
    """
    At the configured interval, compute and record the log-likelihoods of the chosen and rejected responses, plus the argmax confidence.
    """
    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        current_epoch = int(state.epoch)
        total_epochs = int(args.num_train_epochs) if args else NUM_EPOCHS  # fall back to the configured epoch count

        # Only evaluate every EVAL_INTERVAL epochs, or on the final epoch
        if current_epoch % EVAL_INTERVAL != 0 and current_epoch != total_epochs:
            print(f"\nEpoch {current_epoch} ended. Skipping evaluation (will evaluate every {EVAL_INTERVAL} epochs).")
            return

        print(f"\nEpoch {current_epoch} ended. Running squeezing effect analysis on {actual_dataset_size} samples...")

        # Get the current model and switch it to eval mode
        current_model = kwargs['model'].eval()

        chosen_log_probs, rejected_log_probs, argmax_log_probs = [], [], []

        for i, example in enumerate(probe_dataset):
            try:
                prompt_text = example['prompt']
                chosen_text = example['chosen']
                rejected_text = example['rejected']

                # 1. Log-likelihoods of the chosen and rejected responses
                chosen_log_probs.append(calculate_log_prob(current_model, tokenizer, prompt_text, chosen_text))
                rejected_log_probs.append(calculate_log_prob(current_model, tokenizer, prompt_text, rejected_text))

                # 2. Argmax confidence along the chosen path (rather than generating a greedy response)
                argmax_log_probs.append(calculate_argmax_log_prob(current_model, tokenizer, prompt_text, chosen_text))

                print(f"  Processed sample {i+1}/{actual_dataset_size}", end='\r')
            except Exception as e:
                print(f"\n  Error processing sample {i+1}: {e}")
                # Use sentinel values so a failure does not abort the loop
                chosen_log_probs.append(-1000.0)
                rejected_log_probs.append(-1000.0)
                argmax_log_probs.append(-1000.0)

        # Average over the probe set and store the results
        log_probs_history["epoch"].append(state.epoch)
        log_probs_history["chosen_log_probs"].append(np.mean(chosen_log_probs))
        log_probs_history["rejected_log_probs"].append(np.mean(rejected_log_probs))
        log_probs_history["argmax_log_probs"].append(np.mean(argmax_log_probs))

        print("\nAnalysis complete for this epoch.")
        print(f"  Avg Chosen LogProb: {np.mean(chosen_log_probs):.4f}")
        print(f"  Avg Rejected LogProb: {np.mean(rejected_log_probs):.4f}")
        print(f"  Avg Argmax Confidence: {np.mean(argmax_log_probs):.4f}")

        # Save intermediate results to the JSON file
        save_results_to_json()

        # Generate and save the current plot
        plot_squeezing_effect()

        # Free memory
        gc.collect()
        torch.cuda.empty_cache()

        # Switch the model back to training mode
        current_model.train()

# 4. DPO trainer setup and training
# =================================================================

# Before training starts, evaluate the untrained model once as the baseline (Epoch 0)
squeezing_callback = SqueezingEffectCallback()

# A temporary config is needed for this initial evaluation
temp_args = DPOConfig(num_train_epochs=NUM_EPOCHS)  # use the configured epoch count
squeezing_callback.on_epoch_end(temp_args, TrainerState(epoch=0.0), TrainerControl(), model=model)


dpo_trainer = DPOTrainer(
    model=model,
    ref_model=None, # DPO with LoRA and no reference model
    args=DPOConfig(
        per_device_train_batch_size=1,  # small batch size for this test
        gradient_accumulation_steps=1,  # reduced for this test
        warmup_ratio=0.1,
        num_train_epochs=NUM_EPOCHS, # use the configured number of epochs
        learning_rate=5e-6,
        logging_steps=1,
        optim="adamw_8bit",
        weight_decay=0.0,
        lr_scheduler_type="linear",
        seed=42,
        output_dir="outputs",
        report_to="none",
        beta=0.1,  # the beta parameter is set in DPOConfig
        max_length=max_seq_length,
        max_prompt_length=max_seq_length // 2,
    ),
    train_dataset=train_dataset,
    processing_class=tokenizer,  # newer TRL versions take processing_class instead of tokenizer
    callbacks=[squeezing_callback], # our custom callback
)

print("\nStarting DPO training...")
dpo_trainer.train()
print("Training finished.")

# Clean up distributed resources
import torch.distributed as dist
if dist.is_initialized():
    dist.destroy_process_group()
    print("Distributed process group destroyed.")

# 5. Result visualization
# =================================================================

# Final plotting (in case nothing was plotted during training)
if len(log_probs_history["epoch"]) > 0:
    print("\n训练完成!最终结果图表已在训练过程中生成。")
    print(f"最终结果文件: {results_filename}")
    print(f"最终图表文件: {plot_filename}")
else:
    print("\n没有评测数据,生成空图表...")
    plot_squeezing_effect()

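Because the callback writes intermediate results to JSON after every evaluation, the curves can also be re-plotted offline. Below is a minimal sketch, assuming the schema written by save_results_to_json above (adjust the filename to whatever your run actually produced):

# replot_from_json.py -- re-plot saved results without re-running training
import json
import matplotlib.pyplot as plt

with open("dpo_results_samples_200_interval_4_epochs_100.json", encoding="utf-8") as f:
    results = json.load(f)["results"]

plt.plot(results["epoch"], results["chosen_log_probs"], "o-", label="chosen")
plt.plot(results["epoch"], results["rejected_log_probs"], "o-", label="rejected")
plt.plot(results["epoch"], results["argmax_log_probs"], "o-", label="argmax")
plt.xlabel("Training Epoch")
plt.ylabel("Average Log-Probability")
plt.legend()
plt.savefig("replot.png", dpi=300)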
Result plot after running for 60 epochs:

[Figure: average log-probabilities of the chosen, rejected and argmax sequences over 60 epochs of DPO-only training]

We can observe that the log-probability of $y_{win}$ rises early on but eventually declines as well, while the model's original argmax output keeps becoming more likely.

SFT+DPO

First run SFT on both $y_{win}$ and $y_{lose}$ for two epochs, then run DPO; the extended SFT objective is sketched below.

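Following the paper's proposal (finding 5 above), the SFT stage here trains on both responses of every preference pair before DPO starts; per prompt $x$ with pair $(y^+, y^-)$ the SFT objective is simply

$\mathcal{L}_{SFT}(x) = -\log \pi_\theta(y^+ \mid x) - \log \pi_\theta(y^- \mid x)$

which is what create_sft_dataset in the script below implements by emitting one SFT example for the chosen response and one for the rejected response.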
File: SFT_DPO.py

# SFT_DPO.py
import unsloth
import torch
import torch.nn.functional as F
import gc
import numpy as np
import matplotlib.pyplot as plt
import json
import os
from datasets import load_dataset, Dataset
from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl
from unsloth import FastLanguageModel, PatchDPOTrainer
from trl import SFTTrainer, SFTConfig, DPOTrainer, DPOConfig
import warnings

# Silence some unnecessary warnings
warnings.filterwarnings("ignore")

# Patch DPOTrainer so it uses unsloth's optimizations
PatchDPOTrainer()

# 1. Training configuration
# =================================================================
# SFT stage
SFT_EPOCHS = 2  # number of SFT epochs suggested in the paper

# DPO stage
DPO_EPOCHS = 60 # number of DPO epochs
PROBE_DATASET_SIZE = 200 # number of samples to use (-1 means use all samples)
EVAL_INTERVAL = 4 # DPO evaluation interval (evaluate every this many epochs)

# LoRA and model configuration
max_seq_length = 2048
lora_rank = 32

# 2. Model and tokenizer initialization
# =================================================================
# Load the pretrained model; supports 4-bit quantization and fast vLLM inference
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Qwen3-4B-Base",
    max_seq_length=max_seq_length,
    load_in_4bit=False,  # use 16-bit LoRA so that log-probs can be computed
    fast_inference=True,
    max_lora_rank=lora_rank,
    gpu_memory_utilization=0.98,
    # token="hf_...", # fill in your HuggingFace token if needed
)

# Add the LoRA adapters
model = FastLanguageModel.get_peft_model(
    model,
    r=lora_rank,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha=lora_rank * 2,
    use_gradient_checkpointing="unsloth",
    random_state=3407,
)

# Set the chat_template for the Qwen3 model
chat_template = \
    "{% if messages[0]['role'] == 'system' %}"\
        "{{ messages[0]['content'] + eos_token }}"\
        "{% set loop_messages = messages[1:] %}"\
    "{% else %}"\
        "{{ 'You are a helpful assistant.' + eos_token }}"\
        "{% set loop_messages = messages %}"\
    "{% endif %}"\
    "{% for message in loop_messages %}"\
        "{% if message['role'] == 'user' %}"\
            "{{ message['content'] }}"\
        "{% elif message['role'] == 'assistant' %}"\
            "{{ message['content'] + eos_token }}"\
        "{% endif %}"\
    "{% endfor %}"

tokenizer.chat_template = chat_template

# 3. Dataset loading and preprocessing
# =================================================================
def create_sft_dataset(dpo_dataset):
    """
    Build the SFT dataset following the paper's "extend" recipe:
    the SFT stage trains on both the chosen and the rejected responses.
    """
    sft_data = {"text": []}
    for example in dpo_dataset:
        prompt = example['prompt']
        # Build the chosen sample
        chosen_messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": example['chosen']}
        ]
        sft_data["text"].append(tokenizer.apply_chat_template(chosen_messages, tokenize=False))

        # Build the rejected sample
        rejected_messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": example['rejected']}
        ]
        sft_data["text"].append(tokenizer.apply_chat_template(rejected_messages, tokenize=False))

    return Dataset.from_dict(sft_data)

def format_dpo_template(example):
    """
    Format the dataset for DPO training.
    """
    prompt_messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": example['prompt']}
    ]
    formatted_prompt = tokenizer.apply_chat_template(prompt_messages, tokenize=False, add_generation_prompt=True)
    return {
        "prompt": formatted_prompt,
        "chosen": example['chosen'],
        "rejected": example['rejected']
    }

# Load the raw dataset
ds = load_dataset("markyfsun/chinese-enthusiastic-dpo")['train']

# Select samples according to PROBE_DATASET_SIZE
if PROBE_DATASET_SIZE != -1:
    ds = ds.select(range(PROBE_DATASET_SIZE))
actual_dataset_size = len(ds)
print(f"使用 {actual_dataset_size} 个核心样本进行SFT, DPO和评估。")


# Build the SFT and DPO datasets
sft_train_dataset = create_sft_dataset(ds)
dpo_train_dataset = ds.map(format_dpo_template)
probe_dataset = dpo_train_dataset # the DPO-formatted data is used for evaluation

print("SFT 数据集示例 (包含chosen和rejected响应):")
print(repr(sft_train_dataset[0]['text']))
print(repr(sft_train_dataset[1]['text']))

# 4. SFT training
# =================================================================
print("\n--- 阶段 1: 开始SFT训练 ---")
sft_trainer = SFTTrainer(
    model=model,
    train_dataset=sft_train_dataset,
    args=SFTConfig(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        warmup_steps=5,
        num_train_epochs=SFT_EPOCHS,
        learning_rate=2e-4,
        logging_steps=1,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=42,
        output_dir="sft_outputs",
        report_to="none",
        max_seq_length=max_seq_length,
    ),
    dataset_text_field="text",
)
sft_trainer.train()
print("--- SFT训练完成 ---")


# 5. DPO training and squeezing-effect replication
# =================================================================
# Build result file names
results_filename = f"sft_dpo_results_samples_{actual_dataset_size}_sfte_{SFT_EPOCHS}_dpoe_{DPO_EPOCHS}.json"
plot_filename = f"sft_dpo_squeezing_effect_samples_{actual_dataset_size}_sfte_{SFT_EPOCHS}_dpoe_{DPO_EPOCHS}.png"

print(f"结果将保存到: {results_filename}")
print(f"图表将保存到: {plot_filename}")

# Global store for the per-epoch results
log_probs_history = {
    "epoch": [],
    "chosen_log_probs": [],
    "rejected_log_probs": [],
    "argmax_log_probs": [],
}

def save_results_to_json():
    """Save the current results to a JSON file."""
    results = {
        "config": {
            "sft_epochs": SFT_EPOCHS,
            "dpo_epochs": DPO_EPOCHS,
            "probe_dataset_size": PROBE_DATASET_SIZE,
            "actual_dataset_size": actual_dataset_size,
            "eval_interval": EVAL_INTERVAL
        },
        "results": log_probs_history
    }
    with open(results_filename, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"结果已保存到 {results_filename}")

# (calculate_log_prob, calculate_argmax_log_prob and plot_squeezing_effect are identical to the ones in DPO.py)
def calculate_log_prob(model, tokenizer, prompt, response):
    prompt_tokens = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).input_ids.to(model.device)
    response_tokens = tokenizer(response, return_tensors="pt", add_special_tokens=False).input_ids.to(model.device)
    full_tokens = torch.cat([prompt_tokens, response_tokens], dim=1)
    labels = torch.full_like(full_tokens, -100)
    labels[:, prompt_tokens.shape[1]:] = response_tokens
    with torch.no_grad():
        outputs = model(full_tokens, labels=labels)
        log_prob = -outputs.loss.item() * (response_tokens.shape[1])
    return log_prob if not np.isnan(log_prob) else -np.inf

def calculate_argmax_log_prob(model, tokenizer, prompt, chosen_response):
    prompt_tokens = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).input_ids.to(model.device)
    response_tokens = tokenizer(chosen_response, return_tensors="pt", add_special_tokens=False).input_ids.to(model.device)
    full_tokens = torch.cat([prompt_tokens, response_tokens], dim=1)
    with torch.no_grad():
        outputs = model(full_tokens)
        logits = outputs.logits
    prompt_len = prompt_tokens.shape[1]
    response_logits = logits[:, prompt_len - 1 : -1, :]
    log_probs = F.log_softmax(response_logits, dim=-1)
    max_log_probs, _ = torch.max(log_probs, dim=2)
    total_argmax_log_prob = torch.sum(max_log_probs).item()
    return total_argmax_log_prob if not np.isnan(total_argmax_log_prob) else -np.inf

def plot_squeezing_effect():
    epochs = log_probs_history["epoch"]
    chosen_lp = log_probs_history["chosen_log_probs"]
    rejected_lp = log_probs_history["rejected_log_probs"]
    argmax_lp = log_probs_history["argmax_log_probs"]
    plt.style.use('seaborn-v0_8-whitegrid')
    fig, ax = plt.subplots(figsize=(12, 7))
    ax.plot(epochs, chosen_lp, 'o-', label=r"Avg Log-Prob of $y_{chosen}$", color='green', lw=2)
    ax.plot(epochs, rejected_lp, 'o-', label=r"Avg Log-Prob of $y_{rejected}$", color='red', lw=2)
    ax.plot(epochs, argmax_lp, 'o-', label=r"Avg Log-Prob of Argmax ($y^*$)", color='blue', lw=2)
    ax.set_title(f"Squeezing Effect (SFT {SFT_EPOCHS} epochs then DPO)", fontsize=16)
    ax.set_xlabel("DPO Training Epoch", fontsize=12)
    ax.set_ylabel("Average Log-Probability", fontsize=12)
    ax.legend(fontsize=12)
    ax.grid(True)
    # Add a vertical dashed line at x=0 to mark the SFT/DPO boundary
    ax.axvline(x=0, color='grey', linestyle='--', linewidth=2, label='SFT/DPO Boundary')
    ax.legend()
    fig.tight_layout()
    plt.savefig(plot_filename, dpi=300)
    print(f"\nResults plot saved to '{plot_filename}'")
    plt.show()

class SqueezingEffectCallback(TrainerCallback):
    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        current_epoch = int(state.epoch)
        total_epochs = int(args.num_train_epochs) if args else DPO_EPOCHS
        if current_epoch % EVAL_INTERVAL != 0 and current_epoch != total_epochs:
            print(f"\nDPO Epoch {current_epoch} ended. Skipping evaluation.")
            return

        print(f"\nDPO Epoch {current_epoch} ended. Running squeezing effect analysis...")
        current_model = kwargs['model'].eval()
        chosen_log_probs, rejected_log_probs, argmax_log_probs = [], [], []

        for i, example in enumerate(probe_dataset):
            try:
                prompt_text, chosen_text, rejected_text = example['prompt'], example['chosen'], example['rejected']
                chosen_log_probs.append(calculate_log_prob(current_model, tokenizer, prompt_text, chosen_text))
                rejected_log_probs.append(calculate_log_prob(current_model, tokenizer, prompt_text, rejected_text))
                argmax_log_probs.append(calculate_argmax_log_prob(current_model, tokenizer, prompt_text, chosen_text))
                print(f"  Processed sample {i+1}/{actual_dataset_size}", end='\r')
            except Exception as e:
                print(f"\n  Error processing sample {i+1}: {e}")
                chosen_log_probs.append(-1000.0)
                rejected_log_probs.append(-1000.0)
                argmax_log_probs.append(-1000.0)

        log_probs_history["epoch"].append(state.epoch)
        log_probs_history["chosen_log_probs"].append(np.mean(chosen_log_probs))
        log_probs_history["rejected_log_probs"].append(np.mean(rejected_log_probs))
        log_probs_history["argmax_log_probs"].append(np.mean(argmax_log_probs))

        print("\nAnalysis complete for this DPO epoch.")
        print(f"  Avg Chosen LogProb: {np.mean(chosen_log_probs):.4f}")
        print(f"  Avg Rejected LogProb: {np.mean(rejected_log_probs):.4f}")
        print(f"  Avg Argmax Confidence: {np.mean(argmax_log_probs):.4f}")

        save_results_to_json()
        plot_squeezing_effect()
        gc.collect()
        torch.cuda.empty_cache()
        current_model.train()

# Manually run the "post-SFT / pre-DPO" evaluation
print("\n--- 评估 SFT 后的模型 (DPO Epoch 0) ---")
squeezing_callback = SqueezingEffectCallback()
temp_args = DPOConfig(num_train_epochs=DPO_EPOCHS)
squeezing_callback.on_epoch_end(temp_args, TrainerState(epoch=0.0), TrainerControl(), model=model)


print("\n--- 阶段 2: 开始DPO训练 ---")
dpo_trainer = DPOTrainer(
    model=model,
    ref_model=None, # DPO with LoRA and no reference model
    args=DPOConfig(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=1,
        warmup_ratio=0.1,
        num_train_epochs=DPO_EPOCHS,
        learning_rate=5e-6,
        logging_steps=1,
        optim="adamw_8bit",
        weight_decay=0.0,
        lr_scheduler_type="linear",
        seed=42,
        output_dir="dpo_outputs",
        report_to="none",
        beta=0.1,
        max_length=max_seq_length,
        max_prompt_length=max_seq_length // 2,
    ),
    train_dataset=dpo_train_dataset,
    processing_class=tokenizer,  # keep consistent with DPO.py (newer TRL versions expect processing_class)
    callbacks=[squeezing_callback],
)

dpo_trainer.train()
print("--- DPO训练完成 ---")

We can see that the log-probability of $y_{win}$ does not drop as far as it did in the previous (DPO-only) experiment.

[Figure: average log-probabilities of the chosen, rejected and argmax sequences for the SFT-then-DPO run]
