# ga_train_ddp.py  ----------------------------------------------
import os, json, torch, torch.distributed as dist, pandas as pd
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from models.transformer_model import TransformerModel
from utils.data_loader_wmdp import get_train_data_ddp, preprocess_test, LanguageModelingDataset
from peft import LoraConfig, PeftModel
from utils.data_loader_wmdp import get_train_data, preprocess_test
from utils.lora import apply_lora
from transformers import GPT2LMHeadModel, AutoModelForCausalLM, AutoTokenizer
# Placeholder Hugging Face access token. It is not referenced anywhere else in
# the visible part of this file.
access_token = "your_huggingface_access_token"
# Unconditionally sets CUDA_VISIBLE_DEVICES to "2" for this process at import
# time, replacing any value the caller exported.
# NOTE(review): presumably meant to pin the run to physical GPU 2 — confirm the
# override of a caller-provided value is intended.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

# ------------- Main DDP logic ----------------
# Maps each supported --dataset value to (forget corpora, retain corpora).
_DATASET_CORPORA = {
    "wmdp-cyber": (("cyber-forget-corpus",), ("cyber-retain-corpus",)),
    "wmdp-bio": (("bio_forget",), ("bio-retain-corpus",)),
}


def _save_checkpoint(model, tokenizer, save_dir, tag):
    """Save the DDP-wrapped model and the tokenizer to ``<save_dir>/hf_ckpt_<tag>``."""
    lora_save_dir = os.path.join(save_dir, f"hf_ckpt_{tag}")
    os.makedirs(lora_save_dir, exist_ok=True)
    # Unwrap DDP so the underlying model's own save_pretrained is used.
    model.module.save_pretrained(lora_save_dir)
    tokenizer.save_pretrained(lora_save_dir)
    print(f"✅  已保存 LoRA 模型到 {lora_save_dir}")


def _append_epoch_loss(path, epoch, ga_loss, gd_loss, total_loss):
    """Append one row to the per-epoch loss CSV, writing the header if the file is empty."""
    with open(path, "a") as f:
        if f.tell() == 0:
            f.write("epoch,ga_loss,gd_loss,total_loss\n")
        f.write(f"{epoch},{ga_loss},{gd_loss},{total_loss}\n")


def main(rank, world_size, args):
    """Run GA (forget) + GD (retain) LoRA training in one DDP worker process.

    Each step minimises ``gd_loss - args.lambda_ga * ga_loss``, i.e. descends
    on the retain batch while ascending on the forget batch.

    Rank 0 writes the following under ``args.save_dir``:
      * ``hf_ckpt_<k>``: model and tokenizer after ``k`` completed epochs
        (``k = 0 .. args.epochs``; ``hf_ckpt_0`` is the starting point).
      * ``loss_epoch.csv``: one row per epoch (1-based) holding the losses
        of the last optimisation step of that epoch.
      * ``loss_step.csv``: one row per logged step (every 10th step),
        numbered sequentially from 1.

    Args:
        rank: Index of this process; also used as the CUDA device index.
        world_size: Total number of processes in the process group.
        args: Parsed arguments. Reads ``model_name``, ``dataset``,
            ``save_dir``, ``epochs``, ``batch_size``, ``lr``, ``lora_r`` and
            ``lambda_ga``. ``save_dir`` is rewritten in place to
            ``<save_dir>/<last path component of model_name>``.

    Raises:
        ValueError: If ``args.dataset`` is not a supported dataset name.
    """
    # 1. Initialise NCCL communication.
    print(f"Process {rank} initializing DDP..., world size={world_size}")
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    device = torch.device(f"cuda:{rank}")

    # 2. Data.
    model_short_name = args.model_name.split("/")[-1]
    tokenizer = TransformerModel(args.model_name).get_tokenizer()
    args.save_dir = os.path.join(args.save_dir, model_short_name)
    if args.dataset not in _DATASET_CORPORA:
        raise ValueError(f"Unsupported dataset: {args.dataset}")
    forget_corpora, retain_corpora = (list(c) for c in _DATASET_CORPORA[args.dataset])
    forget_loaders, retain_loaders = get_train_data_ddp(
        forget_corpora, retain_corpora, tokenizer,
        batch_size=args.batch_size,
        sampler_cls=DistributedSampler,  # so each process takes its own slice
        world_size=world_size, rank=rank,
    )
    # Report the dataset size behind every loader.
    print(f"[Rank {rank}] Forget loaders: {[len(loader.dataset) for loader in forget_loaders]}")
    print(f"[Rank {rank}] Retain loaders: {[len(loader.dataset) for loader in retain_loaders]}")

    # 3. Model: apply LoRA, load the saved adapter, train only LoRA weights.
    model = TransformerModel(args.model_name).get_model()
    model = apply_lora(model, args.lora_r, model_name=args.model_name)
    model = PeftModel.from_pretrained(model, f"/data/wwh/llmUN2/base/{model_short_name}/0")
    model = model.to(device)
    model.train()
    for name, param in model.named_parameters():
        param.requires_grad = "lora" in name.lower()
    model = DDP(model, device_ids=[rank], output_device=rank,
                find_unused_parameters=False)
    # Frozen parameters never receive gradients, so keep them out of the optimizer.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable_params, lr=args.lr)

    # 4. Training loop.
    ga_step_hist, gd_step_hist, total_step_hist = [], [], []
    if rank == 0:
        os.makedirs(args.save_dir, exist_ok=True)

    loss_save_path = os.path.join(args.save_dir, "loss_epoch.csv")
    loss_step_save_path = os.path.join(args.save_dir, "loss_step.csv")

    for epoch in range(args.epochs):
        model.train()
        if rank == 0:
            # Saved before the epoch runs: hf_ckpt_<epoch> == state after `epoch` epochs.
            _save_checkpoint(model, tokenizer, args.save_dir, epoch)
        for loader in forget_loaders + retain_loaders:
            loader.sampler.set_epoch(epoch)

        # Losses of the most recent step; stays None if no batch was processed.
        last_losses = None
        for f_loader, r_loader in zip(forget_loaders, retain_loaders):
            for step, (f_batch, r_batch) in enumerate(zip(f_loader, r_loader), start=1):
                # ---------- GA (forget batch) ------------
                f_ids = f_batch["input_ids"].to(device)
                f_mask = f_batch["attention_mask"].to(device)
                ga_loss = model(f_ids, labels=f_ids,
                                attention_mask=f_mask).loss  # plain forward LM loss
                # ---------- GD (retain batch) ------------
                r_ids = r_batch["input_ids"].to(device)
                r_mask = r_batch["attention_mask"].to(device)
                gd_loss = model(r_ids, labels=r_ids,
                                attention_mask=r_mask).loss

                # Subtracting the forget loss turns descent on it into ascent.
                total_loss = gd_loss - args.lambda_ga * ga_loss

                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()
                if rank == 0 and step % 10 == 0:
                    print(f"[E{epoch+1}|S{step}] GA Loss={ga_loss.item():.4f}, GD Loss={gd_loss.item():.4f}, Total Loss={total_loss.item():.4f}")
                    ga_step_hist.append(ga_loss.item())
                    gd_step_hist.append(gd_loss.item())
                    total_step_hist.append(total_loss.item())
                last_losses = (ga_loss.detach(), gd_loss.detach(), total_loss.detach())

        # Log after the epoch so the row carries this epoch's number and loss.
        if rank == 0 and last_losses is not None:
            ga_val, gd_val, total_val = (t.item() for t in last_losses)
            _append_epoch_loss(loss_save_path, epoch + 1, ga_val, gd_val, total_val)

    if rank == 0:
        # The in-loop checkpoints are taken before each epoch, so the fully
        # trained weights must be saved explicitly.
        _save_checkpoint(model, tokenizer, args.save_dir, args.epochs)
        print("Training done.")
        with open(loss_step_save_path, "w") as f:
            f.write("step,ga_loss,gd_loss,total_loss\n")
            rows = zip(ga_step_hist, gd_step_hist, total_step_hist)
            for idx, (ga_val, gd_val, total_val) in enumerate(rows, start=1):
                f.write(f"{idx},{ga_val},{gd_val},{total_val}\n")
    dist.destroy_process_group()

# ------------- Launch ---------------
def build_arg_parser():
    """Return the command-line parser for this training script.

    Every option is required; none has a default value.
    """
    import argparse

    ap = argparse.ArgumentParser()
    ap.add_argument("--model_name", type=str, required=True)
    ap.add_argument("--dataset", type=str, required=True)  # "wmdp-bio" or "wmdp-cyber"
    ap.add_argument("--save_dir", type=str, required=True)
    ap.add_argument("--epochs", type=int, required=True)
    ap.add_argument("--batch_size", type=int, required=True)   # already the per-GPU batch size
    ap.add_argument("--lr", type=float, required=True)
    ap.add_argument("--lora_r", type=int, required=True)
    ap.add_argument("--lambda_ga", type=float, required=True)  # weight of the GA loss
    return ap


if __name__ == "__main__":
    import torch.multiprocessing as mp

    args = build_arg_parser().parse_args()
    # One worker process per visible CUDA device.
    world = torch.cuda.device_count()
    if world < 1:
        # Spawning zero workers would otherwise exit silently without training.
        raise RuntimeError("No CUDA devices are visible; NCCL DDP training needs at least one GPU.")
    mp.spawn(main, args=(world, args), nprocs=world)
