# rmu_train_ddp.py  ----------------------------------------------
import os, json, torch, torch.distributed as dist, pandas as pd
# Rendezvous address/port read by torch.distributed's "env://" init method in
# main(); setdefault keeps any values already exported by a launcher or shell.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "2942")
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from models.transformer_model import TransformerModel
from utils.data_loader_wmdp import get_train_data_ddp, preprocess_test, LanguageModelingDataset
from peft import LoraConfig, PeftModel
from utils.lora import apply_lora
from transformers import GPT2LMHeadModel, AutoModelForCausalLM, AutoTokenizer
from utils.data_loader_wmdp import get_train_data, preprocess_test
from huggingface_hub import login
# Default to GPU 3 only when the launcher/shell has not already chosen devices;
# an unconditional assignment would silently override CUDA_VISIBLE_DEVICES.
# NOTE: with a single visible device, torch.cuda.device_count() in the launcher
# returns 1, so only one DDP process is spawned.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "3")
# Read the Hugging Face token from the environment instead of hard-coding a
# secret in source. When HF_TOKEN is unset, skip the explicit login and rely on
# whatever credentials huggingface_hub already has available.
access_token = os.environ.get("HF_TOKEN")
if access_token:
    login(token=access_token)

def compute_activation_loss(model, input_ids, attention_mask=None, reference_activations=None, c=10.0, alpha=1.0, u=None, layer_idx=6):
    """Mean-squared activation loss at ``hidden_states[layer_idx]`` (RMU loss terms).

    Runs ``model`` with ``output_hidden_states=True`` and selects the hidden
    state at ``layer_idx``. Two modes:

    * ``reference_activations`` given (retain term): MSE between the selected
      activations and ``reference_activations``; ``c`` and ``u`` are ignored.
    * ``reference_activations`` is None (forget term): MSE between the selected
      activations and the control vector ``c * u``. If ``u`` is None, a fresh
      random unit vector (``torch.rand`` then normalised) is drawn on every call.

    ``alpha`` is accepted for interface compatibility but is not used here;
    callers apply the retain weight themselves.

    NOTE(review): for Hugging Face models ``hidden_states[0]`` is usually the
    embedding output, so index k is presumably the output of block k-1 — confirm.

    Returns:
        A 0-dim tensor: the mean over every element. Padding positions are
        included; ``attention_mask`` is only forwarded to the model.
    """
    outputs = model(input_ids, attention_mask=attention_mask, output_hidden_states=True)
    acts = outputs.hidden_states[layer_idx]

    if reference_activations is not None:
        # Retain term: stay close to the reference (frozen-model) activations.
        return (acts - reference_activations).pow(2).mean()

    # Forget term: pull activations toward a scaled unit "control" direction.
    direction = u
    if direction is None:
        direction = torch.rand(acts.size(-1), device=acts.device)
        direction = direction / direction.norm()
    return (acts - c * direction).pow(2).mean()

# ------------- DDP training logic ----------------

# dataset name -> (forget corpora, retain corpora) handed to get_train_data_ddp.
_RMU_CORPORA = {
    "wmdp-cyber": (["cyber-forget-corpus"], ["cyber-retain-corpus"]),
    "wmdp-bio": (["bio_forget"], ["bio-retain-corpus"]),
}

# Within an epoch, log losses and write a LoRA checkpoint every this many steps.
_LOG_EVERY = 500


def _save_lora_checkpoint(model, tokenizer, ckpt_dir):
    """Save the DDP-wrapped model's ``.module`` and the tokenizer to *ckpt_dir*."""
    os.makedirs(ckpt_dir, exist_ok=True)
    model.module.save_pretrained(ckpt_dir)
    tokenizer.save_pretrained(ckpt_dir)
    print(f"✅  Saving LoRA checkpoint to {ckpt_dir}...")


def _append_csv_row(path, header, row):
    """Append *row* to the CSV at *path*, writing *header* first if the file is new or empty."""
    needs_header = not os.path.exists(path) or os.path.getsize(path) == 0
    with open(path, "a") as f:
        if needs_header:
            f.write(f"{header}\n")
        f.write(",".join(map(str, row)) + "\n")


def _write_csv(path, header, rows):
    """Overwrite the CSV at *path* with *header* followed by *rows*."""
    with open(path, "w") as f:
        f.write(f"{header}\n")
        f.writelines(",".join(map(str, row)) + "\n" for row in rows)


def _next_batch(batch_iter, loader):
    """Return ``(batch, iterator)``, restarting *loader* once *batch_iter* is exhausted."""
    try:
        return next(batch_iter), batch_iter
    except StopIteration:
        batch_iter = iter(loader)
        return next(batch_iter), batch_iter


def main(rank, world_size, args):
    """Run RMU unlearning on one DDP rank (entry point for ``mp.spawn``).

    Args:
        rank: index of this process; also used as its CUDA device index.
        world_size: total number of processes in the NCCL group.
        args: parsed CLI namespace; reads model_name, dataset, save_dir, epochs,
            batch_size, lr, lora_r, layer_idx, alpha and c. ``args.save_dir`` is
            rebound to ``<save_dir>/<model basename>``.

    Raises:
        ValueError: if ``args.dataset`` is not a supported WMDP split.

    Side effects (rank 0 only), all under the rebound ``args.save_dir``:
        * ``hf_ckpt_{k}``: LoRA adapter + tokenizer after k completed epochs,
          including the final state ``hf_ckpt_{args.epochs}``.
        * ``hf_ckpt_{epoch}_{step}`` every ``_LOG_EVERY`` steps.
        * Loss CSVs. ``*_loss_epoch.csv`` holds the per-epoch mean across all
          ranks. ``*_loss_step.csv`` holds sampled losses keyed by global
          optimizer step.
    """
    if args.dataset not in _RMU_CORPORA:
        raise ValueError(f"Unsupported dataset: {args.dataset}")
    forget_corpora, retain_corpora = _RMU_CORPORA[args.dataset]

    # 1. Initialise NCCL communication.
    print(f"Process {rank} initializing DDP..., world size={world_size}")
    dist.init_process_group(backend="nccl", init_method="env://", rank=rank, world_size=world_size)
    try:
        _train(rank, world_size, args, forget_corpora, retain_corpora)
    finally:
        # Every rank must tear down its own process group, not just rank 0.
        dist.destroy_process_group()


def _train(rank, world_size, args, forget_corpora, retain_corpora):
    """Build data, models and optimizer on this rank, then run the RMU training loop."""
    torch.cuda.set_device(rank)
    device = torch.device(f"cuda:{rank}")
    tokenizer = TransformerModel(args.model_name).get_tokenizer()
    tokenizer.pad_token = tokenizer.eos_token
    args.save_dir = os.path.join(args.save_dir, args.model_name.split("/")[-1])

    # 2. Datasets; DistributedSampler gives each rank its own shard.
    forget_loaders, retain_loaders = get_train_data_ddp(
        list(forget_corpora), list(retain_corpora), tokenizer,
        batch_size=args.batch_size,
        sampler_cls=DistributedSampler,
        world_size=world_size, rank=rank,
    )
    print(f"[Rank {rank}] Forget loaders: {[len(loader.dataset) for loader in forget_loaders]}")
    print(f"[Rank {rank}] Retain loaders: {[len(loader.dataset) for loader in retain_loaders]}")

    # 3. Trainable model: LoRA adapters on the base model, wrapped in DDP.
    model = TransformerModel(args.model_name).get_model()
    model = apply_lora(model, args.lora_r, model_name=args.model_name)
    model.to(device)
    # NOTE(review): the forget and retain passes are two forwards before a single
    # backward; this relies on DDP tolerating that with find_unused_parameters=True.
    model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=True)

    # 4. Frozen copy (same initial weights) that supplies retain reference activations.
    frozen_model = TransformerModel(args.model_name).get_model()
    frozen_model = apply_lora(frozen_model, args.lora_r, model_name=args.model_name)
    frozen_model.load_state_dict(model.module.state_dict())
    frozen_model.to(device)
    frozen_model.eval()

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    if rank == 0:
        os.makedirs(args.save_dir, exist_ok=True)
    forget_loss_save_path = os.path.join(args.save_dir, "forget_loss_epoch.csv")
    retain_loss_save_path = os.path.join(args.save_dir, "retain_loss_epoch.csv")
    forget_loss_step_save_path = os.path.join(args.save_dir, "forget_loss_step.csv")
    retain_loss_step_save_path = os.path.join(args.save_dir, "retain_loss_step.csv")

    # RMU alternates forget and retain batches; only the first loader of each list is used.
    forget_loader, retain_loader = forget_loaders[0], retain_loaders[0]
    # RMU control vector u: sampled once for the whole run (not per step) and
    # broadcast from rank 0 so every replica optimises toward the same target.
    control_vec = None
    forget_step_rows, retain_step_rows = [], []
    global_step = 0

    for epoch in range(args.epochs):
        model.train()
        if rank == 0:
            # hf_ckpt_{epoch} is the adapter state after `epoch` completed epochs.
            _save_lora_checkpoint(model, tokenizer, os.path.join(args.save_dir, f"hf_ckpt_{epoch}"))

        # set_epoch must precede iter(): with worker processes the DataLoader
        # starts prefetching (fixing the shuffle order) when the iterator is built.
        for loader in (forget_loader, retain_loader):
            if hasattr(loader.sampler, "set_epoch"):
                loader.sampler.set_epoch(epoch)
        forget_iter, retain_iter = iter(forget_loader), iter(retain_loader)

        # The shorter loader is restarted so both streams cover max_steps batches.
        max_steps = max(len(forget_loader), len(retain_loader))
        forget_sum = torch.zeros((), device=device)
        retain_sum = torch.zeros((), device=device)

        for step in range(max_steps):
            forget_batch, forget_iter = _next_batch(forget_iter, forget_loader)
            retain_batch, retain_iter = _next_batch(retain_iter, retain_loader)

            input_ids_forget = forget_batch["input_ids"].to(device)
            attention_mask_forget = forget_batch["attention_mask"].to(device)
            input_ids_retain = retain_batch["input_ids"].to(device)
            attention_mask_retain = retain_batch["attention_mask"].to(device)

            # Reference activations for the retain term come from the frozen model.
            with torch.no_grad():
                retain_ref_activations = frozen_model(
                    input_ids_retain, attention_mask=attention_mask_retain, output_hidden_states=True
                ).hidden_states[args.layer_idx]

            if control_vec is None:
                control_vec = torch.rand(retain_ref_activations.size(-1), device=device)
                control_vec = control_vec / control_vec.norm()
                dist.broadcast(control_vec, src=0)

            forget_loss = compute_activation_loss(
                model, input_ids_forget, attention_mask_forget,
                reference_activations=None, c=args.c, u=control_vec, layer_idx=args.layer_idx,
            )
            retain_loss = compute_activation_loss(
                model, input_ids_retain, attention_mask_retain,
                reference_activations=retain_ref_activations, layer_idx=args.layer_idx,
            )
            total_loss = forget_loss + args.alpha * retain_loss

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            global_step += 1

            # Accumulate on-device to avoid a host sync every step.
            forget_sum += forget_loss.detach().float()
            retain_sum += retain_loss.detach().float()

            if rank == 0 and step % _LOG_EVERY == 0:
                forget_val, retain_val = forget_loss.item(), retain_loss.item()
                print(f"[E{epoch+1}|S{step}] RMU Forget Loss={forget_val:.4f}, Retain Loss={retain_val:.4f}")
                forget_step_rows.append((global_step, forget_val))
                retain_step_rows.append((global_step, retain_val))
                _save_lora_checkpoint(model, tokenizer, os.path.join(args.save_dir, f"hf_ckpt_{epoch}_{step}"))

        # Epoch mean over all ranks (a collective: every rank must reach this line).
        epoch_sums = torch.stack([forget_sum, retain_sum])
        dist.all_reduce(epoch_sums, op=dist.ReduceOp.SUM)
        forget_mean, retain_mean = (epoch_sums / (world_size * max(max_steps, 1))).tolist()

        if rank == 0:
            _append_csv_row(forget_loss_save_path, "epoch,forget_loss", (epoch, forget_mean))
            _append_csv_row(retain_loss_save_path, "epoch,retain_loss", (epoch, retain_mean))
            # Step-level history is rewritten in full each epoch.
            _write_csv(forget_loss_step_save_path, "step,forget_loss", forget_step_rows)
            _write_csv(retain_loss_step_save_path, "step,retain_loss", retain_step_rows)

    if rank == 0:
        # Persist the fully trained adapter; otherwise updates after the last periodic save would be lost.
        _save_lora_checkpoint(model, tokenizer, os.path.join(args.save_dir, f"hf_ckpt_{args.epochs}"))
        print("RMU Training done.")

# ------------- Launch ---------------
if __name__ == "__main__":
    import argparse, torch.multiprocessing as mp
    ap = argparse.ArgumentParser(description="DDP RMU unlearning with LoRA adapters.")
    ap.add_argument("--model_name", type=str, required=True)
    ap.add_argument("--dataset", type=str, required=True, choices=["wmdp-cyber", "wmdp-bio"])
    ap.add_argument("--save_dir", type=str, required=True)
    ap.add_argument("--epochs", type=int, required=True)
    ap.add_argument("--batch_size", type=int, required=True, help="per-GPU batch size (not global)")
    ap.add_argument("--lr", type=float, required=True)
    ap.add_argument("--lora_r", type=int, required=True)
    ap.add_argument("--layer_idx", type=int, required=True, help="index into the model's hidden_states used for the RMU losses")
    ap.add_argument("--alpha", type=float, required=True, help="weight of the retain loss")
    ap.add_argument("--c", type=float, required=True, help="scale of the forget-loss control vector")
    args = ap.parse_args()

    # One DDP process per visible GPU.
    world = torch.cuda.device_count()
    if world == 0:
        # mp.spawn(nprocs=0) would return immediately and silently train nothing.
        raise SystemExit("No CUDA devices visible; NCCL-based DDP training requires at least one GPU.")
    mp.spawn(main, args=(world, args), nprocs=world)