import copy
import os, json, random
import torch
import torch.distributed as dist
import torch.nn.functional as F
import argparse, torch.multiprocessing as mp
import pandas as pd
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoModelForCausalLM
from models.transformer_model import TransformerModel
from utils.data_loader_wmdp import get_train_data_ddp, preprocess_test, LanguageModelingDataset
from utils.lora import apply_lora
from peft import PeftModel
# Default this script to physical GPU 3, but respect a CUDA_VISIBLE_DEVICES value
# already exported by the caller. It has to be set before CUDA is initialised.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "3")

def compute_npo_loss(
        current_model,        # model being trained (without the DDP wrapper)
        initial_model,        # frozen reference model
        input_ids,
        attention_mask,
        beta=0.1):
    """
    Sequence-level NPO (Negative Preference Optimization) loss; returns a scalar.

    For every sequence the log-likelihood ratio between the current and the
    reference model is summed over its non-padding target tokens::

        diff = sum_t [log p_cur(x_t | x_<t) - log p_ref(x_t | x_<t)]

    and the loss is ``-(2/beta) * log sigmoid(-beta * diff)``, averaged over
    the batch. Minimising it pushes ``diff`` down, i.e. lowers the current
    model's likelihood of the sequences relative to the reference model.

    Args:
        current_model: callable taking ``input_ids``, ``attention_mask`` and
            ``return_dict`` keywords and returning an object whose ``.logits``
            is a (bsz, seq_len, vocab) tensor.
        initial_model: reference model with the same calling convention; it
            is evaluated under ``torch.no_grad()``.
        input_ids: (bsz, seq_len) token ids; the token at position t+1 is the
            target predicted from position t.
        attention_mask: (bsz, seq_len) mask (nonzero = real token) or ``None``
            to score every target.
        beta: NPO temperature; must be nonzero.

    Returns:
        0-dim float32 tensor, the batch-mean loss, differentiable with respect
        to ``current_model``'s parameters.
    """
    targets = input_ids[:, 1:].unsqueeze(-1)           # (bsz, seq_len-1, 1)

    def _target_logprob(logits):
        # log p(target) = logit[target] - logsumexp(logits), computed in float32:
        # the reference model may run in half precision, and half-precision
        # log-probs lose accuracy once summed over a whole sequence.
        logits = logits[:, :-1, :].float()
        picked = logits.gather(-1, targets).squeeze(-1)
        return picked - torch.logsumexp(logits, dim=-1)  # (bsz, seq_len-1)

    # --- 1. reference model forward (no gradients) ---
    with torch.no_grad():
        ref_logits = initial_model(input_ids=input_ids,
                                   attention_mask=attention_mask,
                                   return_dict=True).logits
        ref_selected = _target_logprob(ref_logits)

    # --- 2. current model forward ---
    cur_logits = current_model(input_ids=input_ids,
                               attention_mask=attention_mask,
                               return_dict=True).logits
    cur_selected = _target_logprob(cur_logits)

    # --- 3. ignore targets that are padding ---
    if attention_mask is not None:
        mask = attention_mask[:, 1:].bool()            # aligned with the targets
        cur_selected = cur_selected.masked_fill(~mask, 0.0)
        ref_selected = ref_selected.masked_fill(~mask, 0.0)

    # --- 4. sequence-level log-likelihood ratio, shape (bsz,) ---
    diff = (cur_selected - ref_selected).sum(dim=-1)

    # --- 5. L_NPO = -(2/beta) * log sigmoid(-beta * diff) ---
    loss = -(2.0 / beta) * F.logsigmoid(-beta * diff)
    return loss.mean()

def compute_dpo_loss(
        current_model,        # model being trained (without the DDP wrapper)
        initial_model,        # frozen reference model
        input_ids,
        attention_mask,
        beta=0.1):
    """
    Sequence-level "DPO-style" loss on a single set of sequences; returns a scalar.

    There is no chosen/rejected pair. With the per-sequence log-likelihood
    ratio over non-padding target tokens::

        diff = sum_t [log p_cur(x_t | x_<t) - log p_ref(x_t | x_<t)]

    the loss is ``-(1/beta) * log sigmoid(beta * diff)``, averaged over the
    batch. Minimising it pushes ``diff`` up, i.e. keeps the current model's
    likelihood of the sequences at or above the reference model's.

    Args:
        current_model: callable taking ``input_ids``, ``attention_mask`` and
            ``return_dict`` keywords and returning an object whose ``.logits``
            is a (bsz, seq_len, vocab) tensor.
        initial_model: reference model with the same calling convention; it
            is evaluated under ``torch.no_grad()``.
        input_ids: (bsz, seq_len) token ids; the token at position t+1 is the
            target predicted from position t.
        attention_mask: (bsz, seq_len) mask (nonzero = real token) or ``None``
            to score every target.
        beta: temperature; must be nonzero.

    Returns:
        0-dim float32 tensor, the batch-mean loss, differentiable with respect
        to ``current_model``'s parameters.
    """
    targets = input_ids[:, 1:].unsqueeze(-1)           # (bsz, seq_len-1, 1)

    def _target_logprob(logits):
        # log p(target) = logit[target] - logsumexp(logits), computed in float32:
        # the reference model may run in half precision, and half-precision
        # log-probs lose accuracy once summed over a whole sequence.
        logits = logits[:, :-1, :].float()
        picked = logits.gather(-1, targets).squeeze(-1)
        return picked - torch.logsumexp(logits, dim=-1)  # (bsz, seq_len-1)

    # --- 1. reference model forward (no gradients) ---
    with torch.no_grad():
        ref_logits = initial_model(input_ids=input_ids,
                                   attention_mask=attention_mask,
                                   return_dict=True).logits
        ref_selected = _target_logprob(ref_logits)

    # --- 2. current model forward ---
    cur_logits = current_model(input_ids=input_ids,
                               attention_mask=attention_mask,
                               return_dict=True).logits
    cur_selected = _target_logprob(cur_logits)

    # --- 3. ignore targets that are padding ---
    if attention_mask is not None:
        mask = attention_mask[:, 1:].bool()            # aligned with the targets
        cur_selected = cur_selected.masked_fill(~mask, 0.0)
        ref_selected = ref_selected.masked_fill(~mask, 0.0)

    # --- 4. sequence-level log-likelihood ratio, shape (bsz,) ---
    diff = (cur_selected - ref_selected).sum(dim=-1)

    # --- 5. L = -(1/beta) * log sigmoid(beta * diff) ---
    loss = -(1.0 / beta) * F.logsigmoid(beta * diff)
    return loss.mean()
    
# -------- 对抗扰动生成 --------
def apply_perturbation(params, grads, method, rho, mu=1e-2, wa_state=None, N=5,
                       max_norm=1e-3):
    """
    Build a perturbed copy of ``params`` for one robustness method.

    Args:
        params: list of tensors, the current (trainable) parameter values.
        grads: list of tensors aligned with ``params``; gradients of the
            forget loss, used by 'SAM', 'GP' and 'CR'.
        method: one of
            'SAM' / 'GP' - step of length ``rho`` along the gradient
                           normalised jointly over all tensors,
            'RS'         - add Gaussian noise with standard deviation ``rho``,
            'CR'         - step of ``mu * grad``,
            'WA'         - running weight average stored in ``wa_state``;
            any other value returns an unperturbed copy.
        rho: perturbation magnitude for 'SAM', 'GP' and 'RS'.
        mu: step size for 'CR'.
        wa_state: dict ``{'theta_wa': list | None, 'count': int}``, updated in
            place by 'WA'; required for that method.
        N: cap on the averaging count for 'WA'. Once the cap is reached, each
            new parameter set enters with weight 1/(N+1), so the average
            becomes an exponential moving average.
        max_norm: per-tensor cap on the displacement ``perturbed - param``;
            ``None`` disables clipping.

    Returns:
        List of new tensors, in the order of ``params``, to be used as the
        perturbed parameter values.

    Raises:
        ValueError: if ``method == 'WA'`` and ``wa_state`` is None.
    """
    if method in ('SAM', 'GP'):
        # GP is the first-order approximation of SAM: both move by rho along
        # the gradient normalised over all tensors together.
        norm = torch.sqrt(sum((g ** 2).sum() for g in grads)) + 1e-12
        perturbed = [p + rho * g / norm for p, g in zip(params, grads)]
    elif method == 'RS':
        perturbed = [p + torch.randn_like(p) * rho for p in params]
    elif method == 'CR':
        perturbed = [p + mu * g for p, g in zip(params, grads)]
    elif method == 'WA':
        if wa_state is None:
            raise ValueError("method 'WA' requires a wa_state dict")
        if wa_state['theta_wa'] is None:
            # Detached so the stored average never keeps an autograd graph alive.
            wa_state['theta_wa'] = [p.detach().clone() for p in params]
            wa_state['count'] = 1
        else:
            n = wa_state['count']
            theta_wa = wa_state['theta_wa']
            for idx, p in enumerate(params):
                theta_wa[idx] = (n * theta_wa[idx] + p.detach()) / (n + 1)
            wa_state['count'] = min(n + 1, N)
        perturbed = [w.clone() for w in wa_state['theta_wa']]
    else:
        perturbed = [p.clone() for p in params]

    # Limit how far each tensor moves. The displacement is clipped, not the
    # parameter itself; rescaling the whole weight tensor to a tiny norm would
    # wipe out the parameters the caller copies back into the model.
    if max_norm is not None:
        for i, (p, q) in enumerate(zip(params, perturbed)):
            delta = q - p
            delta_norm = delta.norm()
            if delta_norm > max_norm:
                perturbed[i] = p + delta * (max_norm / delta_norm)
    return perturbed

# -------- 元学习循环 --------
def _assign_trainable(params, values):
    """Overwrite each tensor in ``params`` in place with the matching tensor in ``values``."""
    # Under no_grad so the copies are not recorded by autograd.
    with torch.no_grad():
        for param, value in zip(params, values):
            param.copy_(value)


def _save_lora_checkpoint(model, tokenizer, save_dir):
    """Save the adapter of the DDP-wrapped ``model`` and the tokenizer to ``save_dir``."""
    os.makedirs(save_dir, exist_ok=True)
    model.module.save_pretrained(save_dir)
    tokenizer.save_pretrained(save_dir)
    print(f"Saving LoRA checkpoint to {save_dir}...")


def _append_epoch_losses(path, epoch, npo_loss, forget_loss, retain_loss):
    """Append one row to the epoch-loss CSV, writing the header first if the file is empty."""
    with open(path, "a") as f:
        if f.tell() == 0:
            f.write("Epoch,NPO Loss,forget Loss, retain_loss\n")
        f.write(f"{epoch},{npo_loss},{forget_loss},{retain_loss}\n")


def _harmonize_gradients(grads_retain, grads_forget):
    """Return per-tensor ``g_r + g_f``, where ``g_f`` first has its component
    along ``g_r`` projected out whenever the two conflict (cosine < 0)."""
    merged = []
    for g_r, g_f in zip(grads_retain, grads_forget):
        flat_r = g_r.reshape(-1)
        flat_f = g_f.reshape(-1)
        if F.cosine_similarity(flat_r, flat_f, dim=0) < 0:
            flat_f = flat_f - (flat_f.dot(flat_r) / flat_r.norm().pow(2)) * flat_r
        merged.append((flat_r + flat_f).view_as(g_r))
    return merged


def _average_across_ranks(grads, world_size):
    """All-reduce ``grads`` in place to their mean over all ranks and return them."""
    # The gradients come from torch.autograd.grad on model.module, which never
    # goes through DDP's reducer. Without this step every rank would update its
    # replica with only its own shard's gradients and the replicas would diverge.
    if world_size > 1:
        for g in grads:
            dist.all_reduce(g, op=dist.ReduceOp.SUM)
            g.div_(world_size)
    return grads


def _meta_step(model, initial_model, params, optimizer, f_batch, r_batch,
               device, methods, wa_state, args, rank, world_size):
    """Run one meta-unlearning update of ``params`` (the trainable tensors).

    Returns ``(npo_loss, forget_loss, retain_loss)`` as Python floats.
    ``forget_loss`` is the NPO loss summed over the perturbed parameter sets.
    """
    net = model.module  # forward passes use the unwrapped model
    f_ids = f_batch["input_ids"].to(device)
    f_mask = f_batch["attention_mask"].to(device)
    r_ids = r_batch["input_ids"].to(device)
    r_mask = r_batch["attention_mask"].to(device)

    # NPO loss and its gradient g_f at the current weights theta.
    npo_loss = compute_npo_loss(net, initial_model, f_ids, f_mask, beta=args.beta)
    grads_f = torch.autograd.grad(npo_loss, params)

    with torch.no_grad():
        # A real value snapshot of theta. A dict of Parameter references would
        # be overwritten by the in-place assignments below.
        theta = [p.detach().clone() for p in params]
        # Inner (meta-tune) step: theta' = theta - lr * g_f.
        meta_params = [p - args.lr * g for p, g in zip(theta, grads_f)]

    # Forgetting feedback: NPO gradient at perturbed copies of theta'. Each
    # gradient is taken right after its forward pass, because the next method
    # overwrites the parameter values that pass saved for backward.
    forget_loss = 0.0
    adv_grads = [torch.zeros_like(p) for p in params]
    for method in random.sample(methods, 2):
        with torch.no_grad():
            perturbed = apply_perturbation(meta_params, grads_f, method, args.rho,
                                           mu=args.mu, wa_state=wa_state, N=args.N)
        _assign_trainable(params, perturbed)
        loss_m = compute_npo_loss(net, initial_model, f_ids, f_mask, beta=args.beta)
        for acc, g in zip(adv_grads, torch.autograd.grad(loss_m, params)):
            acc.add_(g)
        forget_loss += loss_m.item()
        if rank == 0:
            print(f"Method: {method}, Forget Loss: {forget_loss}")

    # Retaining feedback: retain-loss gradient at theta'.
    _assign_trainable(params, meta_params)
    retain_loss = compute_dpo_loss(net, initial_model, r_ids, r_mask, beta=args.beta)
    grads_r = torch.autograd.grad(retain_loss, params)

    # The outer update is applied at theta, not at theta'.
    _assign_trainable(params, theta)
    # Both objectives keep the NPO term:
    # gamma * L_adv + L_npo for forgetting, and L_retain + L_npo for retaining.
    grads_forget = [args.gamma * a + g for a, g in zip(adv_grads, grads_f)]
    grads_retain = [r + g for r, g in zip(grads_r, grads_f)]
    new_grads = _average_across_ranks(
        _harmonize_gradients(grads_retain, grads_forget), world_size)
    for p, g in zip(params, new_grads):
        p.grad = g
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    return npo_loss.item(), forget_loss, retain_loss.item()


def main(rank, world_size, args):
    """
    Per-process training entry point, launched by ``mp.spawn`` with ``rank``.

    Each step of the meta-unlearning loop:
      1. computes the NPO loss on a forget batch at the current LoRA weights
         theta, and its gradient g_f;
      2. takes an inner step theta' = theta - lr * g_f;
      3. (forgetting feedback) for two randomly chosen methods, perturbs theta'
         with ``apply_perturbation`` and takes the NPO-loss gradient there;
      4. (retaining feedback) takes the gradient of ``compute_dpo_loss`` on a
         retain batch at theta';
      5. restores theta, harmonises the forget/retain gradients, averages them
         across ranks and applies one Adam step.

    Rank 0 saves a LoRA checkpoint at the start of every epoch and after
    training. It also appends the latest losses to ``loss_epoch.csv`` and
    writes every 10th step's losses to ``loss_step.csv``.
    """
    # 1. Initialise the NCCL process group for this rank.
    print(f"Process {rank} initializing DDP..., world size={world_size}")
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29005")
    dist.init_process_group(backend='nccl', init_method='env://', rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    device = torch.device(f'cuda:{rank}')
    tokenizer = TransformerModel(args.model_name).get_tokenizer()
    args.save_dir = os.path.join(args.save_dir, args.model_name.split("/")[-1])

    # 2. Data: forget/retain loader lists, each sharded across ranks.
    corpora_by_dataset = {
        "wmdp-cyber": (["cyber-forget-corpus"], ["cyber-retain-corpus"]),
        "wmdp-bio": (["bio_forget"], ["bio-retain-corpus"]),
    }
    if args.dataset not in corpora_by_dataset:
        raise ValueError(f"Unsupported dataset: {args.dataset}")
    forget_corpora, retain_corpora = corpora_by_dataset[args.dataset]
    forget_loaders, retain_loaders = get_train_data_ddp(
        forget_corpora, retain_corpora, tokenizer,
        batch_size=args.batch_size,
        sampler_cls=DistributedSampler,  # each process gets its own shard
        world_size=world_size, rank=rank
    )
    print(f"[Rank {rank}] Forget loaders: {[len(loader.dataset) for loader in forget_loaders]}")
    print(f"[Rank {rank}] Retain loaders: {[len(loader.dataset) for loader in retain_loaders]}")

    # 3. Trainable LoRA model and frozen float16 reference model.
    model = TransformerModel(args.model_name).get_model()
    model = apply_lora(model, args.lora_r, model_name=args.model_name)
    model.to(device)
    model = DDP(model, device_ids=[rank], output_device=rank)
    trainable_params = [p for p in model.parameters() if p.requires_grad]  # LoRA only
    optimizer = torch.optim.Adam(trainable_params, lr=args.lr, betas=(0.9, 0.95), eps=1e-6)

    with torch.inference_mode():
        initial_model = AutoModelForCausalLM.from_pretrained(
            args.model_name,
            torch_dtype=torch.float16,
            device_map={"": rank}
        )  # one copy per GPU
    initial_model.requires_grad_(False)
    initial_model.eval()

    methods = ['SAM', 'RS', 'GP', 'CR', 'WA']
    wa_state = {'theta_wa': None, 'count': 0}
    # Losses of the most recent step; 0.0 until the first step has run.
    last_npo = last_forget = last_retain = 0.0
    npo_step_hist, forget_step_hist, retain_step_hist = [], [], []
    loss_save_path = os.path.join(args.save_dir, "loss_epoch.csv")
    loss_step_save_path = os.path.join(args.save_dir, "loss_step.csv")

    for epoch in range(args.epochs):
        model.train()
        if rank == 0:
            # Snapshot taken before epoch `epoch`, i.e. after `epoch` epochs of training.
            _save_lora_checkpoint(model, tokenizer, args.save_dir + f"/hf_ckpt_{epoch}")
            _append_epoch_losses(loss_save_path, epoch, last_npo, last_forget, last_retain)
        for loader in forget_loaders + retain_loaders:
            loader.sampler.set_epoch(epoch)
        for f_loader, r_loader in zip(forget_loaders, retain_loaders):
            for step, (f_batch, r_batch) in enumerate(zip(f_loader, r_loader), start=1):
                last_npo, last_forget, last_retain = _meta_step(
                    model, initial_model, trainable_params, optimizer,
                    f_batch, r_batch, device, methods, wa_state, args,
                    rank, world_size)
                if rank == 0 and step % 10 == 0:
                    print(f"[E{epoch+1}|S{step}] NPO Loss={last_npo:.4f}, "
                          f"Forget Loss={last_forget:.4f}, "
                          f"Retain Loss={last_retain:.4f}, ")
                    npo_step_hist.append(last_npo)
                    forget_step_hist.append(last_forget)
                    retain_step_hist.append(last_retain)
                # Only rank 0 writes, so ranks don't race on the same directory.
                if rank == 0 and step % 10000 == 0:
                    _save_lora_checkpoint(model, tokenizer,
                                          args.save_dir + f"/hf_ckpt_{epoch}_{step}")

    if rank == 0:
        # Per-epoch snapshots are taken at the start of each epoch, so the
        # fully trained adapter needs its own checkpoint.
        _save_lora_checkpoint(model, tokenizer, args.save_dir + f"/hf_ckpt_{args.epochs}")
        _append_epoch_losses(loss_save_path, args.epochs, last_npo, last_forget, last_retain)
        print("Training done.")
        with open(loss_step_save_path, "w") as f:
            f.write("step,npo_loss,forget_loss,retain_loss\n")
            rows = zip(npo_step_hist, forget_step_hist, retain_step_hist)
            for idx, (npo, forget, retain) in enumerate(rows, start=1):
                f.write(f"{idx},{npo},{forget},{retain}\n")
    dist.destroy_process_group()

if __name__ == '__main__':
    # argparse and torch.multiprocessing (mp) are imported at the top of the file.
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', required=True)
    parser.add_argument("--dataset", type=str, required=True)  # "wmdp-bio" or "wmdp-cyber"
    parser.add_argument('--epochs', type=int, required=True)
    parser.add_argument('--batch_size', type=int, required=True)
    parser.add_argument('--lr', type=float, required=True)
    parser.add_argument('--lora_r', type=int, required=True)
    parser.add_argument('--rho', type=float, required=True)
    parser.add_argument('--mu', type=float, required=True)
    parser.add_argument('--N', type=int, required=True)
    parser.add_argument("--lambda_ga", type=float, required=True)  # GA loss weight (not read by main)
    parser.add_argument('--query_size', type=int, required=True)  # not read by main
    parser.add_argument('--gamma', type=float, required=True)
    parser.add_argument("--save_dir", type=str, required=True)
    parser.add_argument("--beta", type=float, required=True, help="NPO (DPO) 中的 β 超参数")
    args = parser.parse_args()
    # One worker process per visible GPU.
    world = torch.cuda.device_count()
    if world == 0:
        raise SystemExit("No CUDA device visible; check CUDA_VISIBLE_DEVICES.")
    mp.spawn(main, args=(world, args), nprocs=world)
