# ga_train_ddp.py  ----------------------------------------------
import os, json, torch, torch.distributed as dist, pandas as pd
# Rendezvous defaults for the env:// init method used in main(); values already
# present in the environment are left untouched.
if "MASTER_ADDR" not in os.environ:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
if "MASTER_PORT" not in os.environ:
    os.environ["MASTER_PORT"] = "2941"
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from models.transformer_model import TransformerModel
from utils.data_loader_wmdp import get_train_data_ddp, preprocess_test, LanguageModelingDataset
from peft import LoraConfig, PeftModel
from utils.lora import apply_lora
from transformers import GPT2LMHeadModel, AutoModelForCausalLM, AutoTokenizer
from utils.data_loader_wmdp import get_train_data, preprocess_test
from huggingface_hub import login
# Default to GPU 3, but respect a device selection already made by the caller
# or a job scheduler instead of overwriting it. This has to run before the
# first CUDA call so the setting is picked up when CUDA is initialized.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "3")
# Prefer a token supplied through the HF_TOKEN environment variable so a real
# credential never has to be written into this file; the in-file value remains
# the fallback for setups that still edit it by hand.
# NOTE(review): this login runs at import time, so it presumably also runs in
# every worker started by mp.spawn (spawned workers re-import this module).
access_token = os.environ.get("HF_TOKEN", "your_huggingface_access_token")
login(token=access_token)
# ------------- Main DDP logic ----------------
# rank: index of the current process
# world_size: total number of processes
# args: training arguments
# Corpus names handed to get_train_data_ddp for each supported --dataset value,
# stored as (forget_corpora, retain_corpora).
# NOTE(review): "bio_forget" does not follow the "*-forget-corpus" naming of the
# cyber entry; kept exactly as in the original - confirm against the data loader.
_CORPORA_BY_DATASET = {
    "wmdp-cyber": (["cyber-forget-corpus"], ["cyber-retain-corpus"]),
    "wmdp-bio": (["bio_forget"], ["bio-retain-corpus"]),
}


def _save_checkpoint(model, tokenizer, ckpt_dir):
    """Save the DDP-wrapped model (via ``model.module``) and tokenizer to ``ckpt_dir``."""
    os.makedirs(ckpt_dir, exist_ok=True)
    model.module.save_pretrained(ckpt_dir)
    tokenizer.save_pretrained(ckpt_dir)
    print(f"✅  Saving LoRA checkpoint to {ckpt_dir}...")


def _append_epoch_loss(csv_path, epochs_done, loss_value):
    """Append one ``epoch,loss`` row to ``csv_path``, writing the header if the file is empty."""
    with open(csv_path, "a", encoding="utf-8") as f:
        if f.tell() == 0:
            f.write("epoch,loss\n")
        f.write(f"{epochs_done},{loss_value}\n")


def _write_step_losses(csv_path, step_losses):
    """Rewrite ``csv_path`` with a ``step,loss`` header and one 1-based row per logged loss."""
    with open(csv_path, "w", encoding="utf-8") as f:
        f.write("step,loss\n")
        f.writelines(f"{idx},{value}\n" for idx, value in enumerate(step_losses, start=1))


def main(rank, world_size, args):
    """Run gradient-ascent training on the forget corpora in one DDP worker.

    Intended as the target of ``torch.multiprocessing.spawn``: every worker
    joins an NCCL process group, builds its forget-set data loaders, wraps the
    model returned by ``apply_lora`` in ``DistributedDataParallel`` and
    minimizes the *negated* language-modeling loss (i.e. gradient ascent).

    Only rank 0 touches the filesystem. Under
    ``<args.save_dir>/<last path component of args.model_name>/`` it writes:

    * ``hf_ckpt_<k>``: model and tokenizer after ``k`` completed epochs
      (``hf_ckpt_0`` is the state before any update, ``hf_ckpt_<epochs>`` the
      final state).
    * ``loss_epoch.csv``: one ``epoch,loss`` row per completed epoch, where
      the loss is that of the last training step of the epoch on rank 0 (not
      an epoch average).
    * ``loss_step.csv``: the rank-0 losses logged every 100 steps, rewritten
      after each loader.

    Args:
        rank: Index of this process; also used as the CUDA device index.
        world_size: Total number of processes in the group.
        args: Parsed CLI namespace providing ``model_name``, ``dataset``,
            ``save_dir``, ``epochs``, ``batch_size``, ``lr`` and ``lora_r``.
            ``args.save_dir`` is replaced in place by the per-model directory.

    Raises:
        ValueError: If ``args.dataset`` is not a supported dataset name. This
            is checked before any distributed or model setup.
    """
    # Fail fast on a bad --dataset before paying for NCCL setup and model loading.
    corpora = _CORPORA_BY_DATASET.get(args.dataset)
    if corpora is None:
        raise ValueError(f"Unsupported dataset: {args.dataset}")
    forget_corpora, retain_corpora = list(corpora[0]), list(corpora[1])

    # 1. Initialize NCCL communication.
    print(f"Process {rank} initializing DDP..., world size={world_size}")
    dist.init_process_group(backend="nccl", init_method="env://", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    device = torch.device(f"cuda:{rank}")
    tokenizer = TransformerModel(args.model_name).get_tokenizer()
    args.save_dir = os.path.join(args.save_dir, args.model_name.split("/")[-1])

    # 2. Datasets and distributed samplers (each process gets its own shard).
    forget_loaders, retain_loaders = get_train_data_ddp(
        forget_corpora, retain_corpora, tokenizer,
        batch_size=args.batch_size,
        sampler_cls=DistributedSampler,
        world_size=world_size, rank=rank,
    )
    # Report the size of the datasets behind the loaders.
    print(f"[Rank {rank}] Forget loaders: {[len(loader.dataset) for loader in forget_loaders]}")
    print(f"[Rank {rank}] Retain loaders: {[len(loader.dataset) for loader in retain_loaders]}")

    # 3. Model.
    model = TransformerModel(args.model_name).get_model()
    model = apply_lora(model, args.lora_r, model_name=args.model_name)
    model.to(device)
    model = DDP(model, device_ids=[rank], output_device=rank,
                find_unused_parameters=False)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Training loop. All file output is restricted to rank 0 so that workers
    # never race on (or truncate) the same files.
    is_main = rank == 0
    step_losses = []
    loss_save_path = os.path.join(args.save_dir, "loss_epoch.csv")
    loss_step_save_path = os.path.join(args.save_dir, "loss_step.csv")
    if is_main:
        os.makedirs(args.save_dir, exist_ok=True)

    for epoch in range(args.epochs):
        model.train()
        if is_main and epoch == 0:
            # Baseline checkpoint: the model before any gradient-ascent update.
            _save_checkpoint(model, tokenizer, args.save_dir + "/hf_ckpt_0")

        last_loss = None  # loss of the most recent step in this epoch, if any
        for loader in forget_loaders:
            loader.sampler.set_epoch(epoch)
            for step, batch in enumerate(loader):
                inp = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                loss = model(inp, labels=inp,
                             attention_mask=attention_mask).loss
                optimizer.zero_grad()
                # Gradient ascent: descend on the negated LM loss.
                (-loss).backward()
                optimizer.step()
                last_loss = loss.detach()
                if is_main and step % 100 == 0:
                    loss_value = loss.item()
                    print(f"[E{epoch+1}|S{step}] GA Loss={loss_value:.4f}")
                    step_losses.append(loss_value)
            if is_main:
                _write_step_losses(loss_step_save_path, step_losses)

        if is_main:
            epochs_done = epoch + 1
            # An epoch without a single batch has no loss to record.
            if last_loss is not None:
                _append_epoch_loss(loss_save_path, epochs_done, last_loss.item())
            # Saving after every epoch also captures the final trained weights.
            _save_checkpoint(model, tokenizer, args.save_dir + f"/hf_ckpt_{epochs_done}")

    if is_main:
        print("Training done.")
    # Every rank joined the process group, so every rank must tear it down.
    dist.destroy_process_group()

# ------------- Launch ---------------
if __name__ == "__main__":
    # CLI entry point: parse the training options and start one DDP worker
    # (running main) per visible CUDA device.
    import argparse
    import torch.multiprocessing as mp

    ap = argparse.ArgumentParser(
        description="LoRA gradient-ascent training with DistributedDataParallel."
    )
    ap.add_argument("--model_name", type=str, required=True,
                    help="model identifier passed to TransformerModel")
    ap.add_argument("--dataset", type=str, required=True,
                    help='"wmdp-bio" or "wmdp-cyber"')
    ap.add_argument("--save_dir", type=str, required=True,
                    help="root directory for checkpoints and loss logs")
    ap.add_argument("--epochs", type=int, required=True)
    ap.add_argument("--batch_size", type=int, required=True,
                    help="batch size per GPU (not the global batch size)")
    ap.add_argument("--lr", type=float, required=True)
    ap.add_argument("--lora_r", type=int, required=True)
    args = ap.parse_args()

    world = torch.cuda.device_count()
    if world < 1:
        # With no visible GPU, spawning zero workers would end the script
        # without training anything; stop with an explicit message instead.
        raise SystemExit(
            "No CUDA devices are visible; NCCL-based DDP training needs at least one GPU."
        )
    mp.spawn(main, args=(world, args), nprocs=world)
