# npo_train_ddp.py  ----------------------------------------------
import os
import json
import torch
import torch.distributed as dist
import torch.nn.functional as F
import pandas as pd
# Default rendezvous address/port for the process group created in main(), which
# uses init_method="env://". setdefault() keeps any value already present in the
# environment, so an external launcher can still override them.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29510")
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

from models.transformer_model import TransformerModel
from utils.data_loader_wmdp import get_train_data_ddp, preprocess_test, LanguageModelingDataset

from peft import LoraConfig, PeftModel
from utils.lora import apply_lora

from transformers import GPT2LMHeadModel, AutoModelForCausalLM, AutoTokenizer

# Restrict the GPUs visible to this process; edit the line below as needed.
# Example: use GPU 0 only.
# NOTE(review): this assignment is unconditional and overwrites any
# CUDA_VISIBLE_DEVICES value inherited from the environment; the launcher at the
# bottom of the file derives world_size from torch.cuda.device_count().
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Your HuggingFace access token (needed when loading private models, or when
# trust_remote_code requires authorization).
# NOTE(review): not referenced anywhere else in the code visible here.
access_token = "your_huggingface_access_token"


def compute_npo_loss(
        current_model,        # trainable policy; any callable with the HF forward signature
        initial_model,        # frozen reference model
        input_ids,
        attention_mask,
        beta=0.1):
    """
    Compute the sequence-level NPO (Negative Preference Optimization) loss.

    For every sequence the summed log-probability of its own next tokens is
    computed under both models, and the loss is

        L_NPO = -(2 / beta) * log sigmoid(-beta * (log p_cur(x) - log p_ref(x)))

    averaged over the batch.

    Args:
        current_model: Model being trained. It is called as
            ``model(input_ids=..., attention_mask=..., return_dict=True)`` and
            must return an object exposing ``.logits``.
        initial_model: Reference model, called the same way under
            ``torch.no_grad()`` so it never receives gradients.
        input_ids: Token ids, indexed as ``(batch, seq_len)``.
        attention_mask: Mask with the same layout as ``input_ids`` (non-zero =
            real token), or ``None`` to score every position.
        beta: Positive NPO temperature.

    Returns:
        Scalar tensor holding the batch-mean loss, computed in float32.

    Raises:
        ValueError: If ``beta`` is not strictly positive.
    """
    if beta <= 0:
        raise ValueError(f"beta must be positive, got {beta}")

    # Position t predicts token t+1, so the targets are the inputs shifted left.
    targets = input_ids[:, 1:]                              # (bsz, seq_len-1)

    def _target_logprobs(model):
        # Per-token log-probability that `model` assigns to the actual next token.
        logits = model(input_ids=input_ids,
                       attention_mask=attention_mask,
                       return_dict=True).logits
        # Upcast before the softmax: half-precision logits lose too much
        # accuracy once the per-token log-probs are summed over the sequence.
        logprobs = F.log_softmax(logits[:, :-1, :].float(), dim=-1)
        return logprobs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)

    # Reference forward pass: no graph is needed.
    with torch.no_grad():
        ref_selected = _target_logprobs(initial_model)
    # Policy forward pass: gradients flow through this one.
    cur_selected = _target_logprobs(current_model)

    # Drop positions whose *target* token is masked out (e.g. padding).
    if attention_mask is not None:
        keep = attention_mask[:, 1:].bool()                 # aligned with targets
        cur_selected = cur_selected.masked_fill(~keep, 0.0)
        ref_selected = ref_selected.masked_fill(~keep, 0.0)

    # Sequence-level log-likelihood ratio, shape (bsz,).
    diff = (cur_selected - ref_selected).sum(dim=-1)

    loss = -(2.0 / beta) * F.logsigmoid(-beta * diff)
    return loss.mean()

# Forget / retain corpus names handed to get_train_data_ddp for each supported --dataset value.
_CORPORA_BY_DATASET = {
    "wmdp-cyber": (["cyber-forget-corpus"], ["cyber-retain-corpus"]),
    "wmdp-bio": (["bio_forget"], ["bio-retain-corpus"]),
}


def _append_csv_row(path, header, row=None):
    """Append ``row`` to the CSV at ``path``, writing ``header`` first when the file is empty."""
    with open(path, "a") as f:
        if f.tell() == 0:
            f.write(header + "\n")
        if row is not None:
            f.write(row + "\n")


def main(rank, world_size, args):
    """
    Per-process entry point for DDP training with the NPO loss.

    Steps: initialise the NCCL process group, build the forget/retain loaders,
    load the base model plus a saved LoRA adapter, wrap it in DDP, load a
    frozen fp16 reference model, then minimise ``compute_npo_loss`` over the
    forget loaders. Rank 0 writes checkpoints and loss logs under
    ``<save_dir>/<model short name>``.

    Checkpoint ``hf_ckpt_<k>`` holds the adapter after ``k`` completed epochs
    (``k = 0 .. epochs``), and ``loss_epoch.csv`` row ``k`` holds the
    last-batch loss seen before that checkpoint.

    Args:
        rank: Index of this process (0 .. world_size-1); also used as the CUDA
            device index.
        world_size: Total number of processes.
        args: Parsed command-line arguments. ``args.save_dir`` is rewritten in
            place to the model-specific output directory.

    Raises:
        ValueError: If ``args.dataset`` is not a supported dataset name.
    """
    # 1. Initialise NCCL communication and bind this process to its GPU.
    print(f"[Rank {rank}] Initializing DDP (world_size={world_size}) ...")
    dist.init_process_group(backend="nccl", init_method="env://", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    device = torch.device(f"cuda:{rank}")

    # 2. Tokenizer.
    tokenizer = TransformerModel(args.model_name).get_tokenizer()

    # All outputs go to <save_dir>/<last path component of the model name>.
    args.save_dir = os.path.join(args.save_dir, args.model_name.split("/")[-1])

    # 3. Data: DistributedSampler gives each process its own shard.
    try:
        forget_corpora, retain_corpora = _CORPORA_BY_DATASET[args.dataset]
    except KeyError:
        raise ValueError(f"Unsupported dataset: {args.dataset}") from None
    forget_loaders, retain_loaders = get_train_data_ddp(
        list(forget_corpora), list(retain_corpora), tokenizer,
        batch_size=args.batch_size,
        sampler_cls=DistributedSampler,
        world_size=world_size, rank=rank
    )
    print(f"[Rank {rank}] Forget loaders: {[len(loader.dataset) for loader in forget_loaders]}")
    print(f"[Rank {rank}] Retain loaders: {[len(loader.dataset) for loader in retain_loaders]}")

    # 4. Trainable model: base architecture plus a previously saved LoRA adapter.
    base_model = TransformerModel(args.model_name).get_model()
    lora_model = PeftModel.from_pretrained(base_model, f"/data/wwh/llmUN2/base/{args.model_name.split('/')[-1]}/0")
    # args.lora_r is only used for this log message here.
    print(f"[Rank {rank}] Loaded LoRA model with {args.lora_r} rank.")
    lora_model.to(device)
    lora_model.train()
    # Train only parameters whose name contains "lora"; freeze everything else.
    for name, param in lora_model.named_parameters():
        param.requires_grad = "lora" in name.lower()
    model = DDP(lora_model, device_ids=[rank], output_device=rank, find_unused_parameters=False)

    # 5. Frozen fp16 reference model, one copy per GPU.
    with torch.inference_mode():
        initial_model = AutoModelForCausalLM.from_pretrained(
            args.model_name,
            torch_dtype=torch.float16,
            device_map={"": rank})
    initial_model.requires_grad_(False)
    initial_model.eval()

    # 6. Optimizer over the trainable (LoRA) parameters only.
    optimizer = torch.optim.Adam(
        [p for p in model.parameters() if p.requires_grad],
        lr=args.lr, betas=(0.9, 0.95), eps=1e-6
    )

    # 7. Logging / checkpoint bookkeeping.
    loss_hist = []            # last-batch loss of every completed loader pass
    logged_steps = 0          # running row index for loss_step.csv
    loss_save_path = os.path.join(args.save_dir, "loss_epoch.csv")
    loss_step_save_path = os.path.join(args.save_dir, "loss_step.csv")
    if rank == 0:
        os.makedirs(args.save_dir, exist_ok=True)
        # The per-step log is restarted for every run; rows are appended below.
        with open(loss_step_save_path, "w") as f:
            f.write("step,loss\n")

    def _snapshot(tag):
        # Rank 0 only: save adapter + tokenizer as hf_ckpt_<tag> and log the latest loss.
        ckpt_dir = args.save_dir + f"/hf_ckpt_{tag}"
        os.makedirs(ckpt_dir, exist_ok=True)
        model.module.save_pretrained(ckpt_dir)
        tokenizer.save_pretrained(ckpt_dir)
        print(f"[Rank {rank}] Saving LoRA checkpoint to '{ckpt_dir}'")
        # No row is written before the first loader pass: there is no loss yet.
        row = f"{tag},{loss_hist[-1]:.6f}" if loss_hist else None
        _append_csv_row(loss_save_path, "epoch,loss", row)

    # 8. Training loop.
    for epoch in range(args.epochs):
        model.train()
        if rank == 0:
            _snapshot(epoch)          # state after `epoch` completed epochs

        for loader in forget_loaders:
            loader.sampler.set_epoch(epoch)  # keep shuffling in sync across ranks
            batch_loss = None
            for step, batch in enumerate(loader):
                inp_ids = batch["input_ids"].to(device)              # (bsz, seq_len)
                attention_mask = batch["attention_mask"].to(device)  # (bsz, seq_len)

                # Forward through the DDP wrapper (not model.module): DDP only
                # all-reduces gradients for forwards made via the wrapper.
                batch_loss = compute_npo_loss(model, initial_model, inp_ids, attention_mask, beta=args.beta)

                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()

                if rank == 0 and step % 10 == 0:
                    loss_value = batch_loss.item()
                    print(f"[Epoch {epoch+1} | Step {step}] NPO Loss = {loss_value:.6f}")
                    with open(loss_step_save_path, "a") as f:
                        f.write(f"{logged_steps},{loss_value:.6f}\n")
                    logged_steps += 1

            # Remember the last batch's loss (skipped if the loader was empty).
            if batch_loss is not None:
                loss_hist.append(batch_loss.item())

    if rank == 0:
        # Persist the fully trained adapter and the last epoch's loss.
        _snapshot(args.epochs)
        print("🔔 Training completed.")
    dist.destroy_process_group()

if __name__ == "__main__":
    # Script entry point: parse the CLI, then spawn one training process per visible GPU.
    import argparse
    import torch.multiprocessing as mp

    parser = argparse.ArgumentParser(description="DDP Training with NPO (DPO) Loss")
    parser.add_argument("--model_name", type=str, required=True)
    # Validated here so an unsupported name fails before any worker is spawned;
    # these are the dataset names main() accepts.
    parser.add_argument("--dataset", type=str, required=True, choices=["wmdp-cyber", "wmdp-bio"])
    parser.add_argument("--save_dir", type=str, required=True)
    parser.add_argument("--epochs", type=int, required=True, help="训练总轮数")
    parser.add_argument("--batch_size", type=int, required=True, help="每张 GPU 上的 batch size")
    parser.add_argument("--lr", type=float, required=True, help="学习率")
    parser.add_argument("--lora_r", type=int, required=True, help="LoRA rank (r)")
    parser.add_argument("--beta", type=float, required=True, help="NPO (DPO) 中的 β 超参数")

    args = parser.parse_args()

    # One process per visible CUDA device.
    world_size = torch.cuda.device_count()
    if world_size < 1:
        # mp.spawn with nprocs=0 would return without running anything.
        parser.error("no CUDA devices are visible; NCCL-based DDP training needs at least one GPU")
    # mp.spawn passes the process index as the first argument (rank) of main.
    mp.spawn(main, args=(world_size, args), nprocs=world_size)
