import copy
from itertools import islice
import os, json, random
import torch
import torch.distributed as dist
import torch.nn.functional as F
import argparse, torch.multiprocessing as mp
import pandas as pd
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoModelForCausalLM
from models.transformer_model import TransformerModel
from utils.data_loader_wiki import get_wikitext_dataloader_ddp
from utils.data_loader_wmdp import get_train_data_ddp, preprocess_test, LanguageModelingDataset
from utils.lora import apply_lora
from peft import PeftModel
# Default to GPU 2, but respect a device selection already made by the caller
# (e.g. `CUDA_VISIBLE_DEVICES=0,1 python ...`), so the DDP launcher below,
# which spawns one process per visible device, can use more than one GPU.
# This must run before CUDA is initialised in this process.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "2")

# -------- Adversarial perturbation generation --------
def apply_perturbation(params, grads, method, rho, mu=1e-2, wa_state=None, N=5, max_norm=1e-3):
    """Return a perturbed copy of ``params`` according to ``method``.

    Args:
        params: list of tensors, the parameters to perturb.
        grads: list of tensors aligned with ``params``; used by the
            gradient-based methods 'SAM', 'GP' and 'CR' (may be None otherwise).
        method: one of 'SAM', 'RS', 'GP', 'CR', 'WA'. Any other value yields
            unperturbed clones of ``params``.
        rho: perturbation magnitude. For 'SAM'/'GP' it is the step length along
            the globally normalised gradient; for 'RS' it scales standard
            Gaussian noise.
        mu: step size for 'CR' (``p + mu * g``).
        wa_state: mutable dict ``{'theta_wa': list | None, 'count': int}``
            holding a running weight average; required for 'WA' and updated
            in place.
        N: upper bound on the averaging count kept by 'WA'.
        max_norm: per-tensor bound on the L2 norm of the perturbation
            ``perturbed - params``. The parameters themselves are never rescaled.

    Returns:
        list of tensors: perturbed parameters in the same order as ``params``.

    Raises:
        ValueError: if ``method == 'WA'`` and ``wa_state`` is None.
    """
    if method in ('SAM', 'GP'):
        # GP is the first-order SAM approximation; both step rho along g / ||g||,
        # where ||g|| is the norm over all tensors jointly.
        norm = torch.sqrt(sum((g ** 2).sum() for g in grads)) + 1e-12
        perturbed = [p + rho * g / norm for p, g in zip(params, grads)]
    elif method == 'RS':
        # Random Gaussian perturbation.
        perturbed = [p + torch.randn_like(p) * rho for p in params]
    elif method == 'CR':
        # Curvature-regularisation style step: delta = mu * grad.
        perturbed = [p + mu * g for p, g in zip(params, grads)]
    elif method == 'WA':
        if wa_state is None:
            raise ValueError("method 'WA' requires a wa_state dict")
        # Running average: theta_wa = (count * theta_wa + theta) / (count + 1),
        # with count capped at N so recent parameters keep a minimum weight.
        if wa_state['theta_wa'] is None:
            wa_state['theta_wa'] = [p.clone() for p in params]
            wa_state['count'] = 1
        else:
            n = wa_state['count']
            for idx, p in enumerate(params):
                wa_state['theta_wa'][idx] = (n * wa_state['theta_wa'][idx] + p) / (n + 1)
            wa_state['count'] = min(n + 1, N)
        perturbed = [w.clone() for w in wa_state['theta_wa']]
    else:
        perturbed = [p.clone() for p in params]

    # Bound the size of the *perturbation* per tensor. The previous version
    # rescaled the perturbed parameter tensors themselves to norm 1e-3, which
    # collapsed the weights instead of limiting the perturbation.
    clipped = []
    for p, q in zip(params, perturbed):
        delta = q - p
        delta_norm = delta.norm()
        if delta_norm > max_norm:
            q = p + delta * (max_norm / delta_norm)
        clipped.append(q)
    return clipped

def clip_perturbations(deltas, max_norm=1e-3):
    """Rescale each tensor in ``deltas`` so its L2 norm does not exceed ``max_norm``.

    Tensors already within the bound are returned unchanged (same object);
    larger ones are scaled down to have norm exactly ``max_norm``.

    Args:
        deltas: iterable of tensors.
        max_norm: per-tensor norm ceiling.

    Returns:
        list of tensors, in input order.
    """
    def _bounded(t):
        t_norm = t.norm()
        return t * (max_norm / t_norm) if t_norm > max_norm else t

    return [_bounded(t) for t in deltas]

# -------- Meta-learning loop --------
def _save_lora_checkpoint(model, tokenizer, save_dir):
    """Write the DDP-wrapped model's ``module`` and the tokenizer to ``save_dir``."""
    os.makedirs(save_dir, exist_ok=True)
    model.module.save_pretrained(save_dir)
    tokenizer.save_pretrained(save_dir)
    print(f"Saving LoRA checkpoint to {save_dir}...")


def _cycle_batches(loader):
    """Yield batches from ``loader`` forever, starting a fresh pass when one ends.

    Each fresh pass re-iterates the loader (and hence its sampler), so a
    ``set_epoch`` call made in between takes effect at the next pass boundary.

    Raises:
        ValueError: if a full pass over ``loader`` yields no batches.
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            raise ValueError("dataloader yielded no batches")


def main(rank, world_size, args):
    """Run one DDP worker of the meta-unlearning training loop.

    For every (forget, retain) batch pair the loop:
      1. computes ``total_loss = KL(retain) - lambda_ga * GA(forget)``, where GA
         is the LM loss on the forget batch and KL is measured against a
         frozen copy of the starting model;
      2. takes a look-ahead step ``meta = p - lr * grad(total_loss)``;
      3. evaluates the forget batch at two randomly chosen perturbations of the
         look-ahead parameters (``apply_perturbation``) and logs the result;
      4. with the parameters set to the look-ahead values, computes the KL to
         the frozen model on a WikiText batch;
      5. merges retain and forget gradients, projecting the forget gradient
         off the retain gradient when they conflict, and steps the optimizer.

    Checkpoints are written at the start of each epoch and every 100 / 1000
    steps. Losses are logged to ``loss_epoch.csv`` and ``loss_step.csv``.

    Args:
        rank: process index, also used as the CUDA device index.
        world_size: number of DDP processes.
        args: parsed CLI namespace (fields defined by the argument parser).
    """
    # 1. Initialise the NCCL process group.
    print(f"Process {rank} initializing DDP..., world size={world_size}")
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29505")
    dist.init_process_group(backend='nccl', init_method='env://', rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    device = torch.device(f'cuda:{rank}')
    tokenizer = TransformerModel(args.model_name).get_tokenizer()
    args.save_dir = os.path.join(args.save_dir, args.model_name.split("/")[-1])

    # 2. Data: DistributedSampler gives each process its own shard.
    corpora_by_dataset = {
        "wmdp-cyber": (["cyber-forget-corpus"], ["cyber-retain-corpus"]),
        "wmdp-bio": (["bio_forget"], ["bio-retain-corpus"]),
    }
    if args.dataset not in corpora_by_dataset:
        raise ValueError(f"Unsupported dataset: {args.dataset}")
    forget_corpora, retain_corpora = corpora_by_dataset[args.dataset]
    forget_loaders, retain_loaders = get_train_data_ddp(
        forget_corpora, retain_corpora, tokenizer,
        batch_size=args.batch_size,
        sampler_cls=DistributedSampler,
        world_size=world_size, rank=rank,
    )
    wikidataloader = get_wikitext_dataloader_ddp(
        tokenizer=tokenizer, batch_size=args.batch_size,
        sampler_cls=DistributedSampler, world_size=world_size, rank=rank,
    )
    print(f"[Rank {rank}] Forget loaders: {[len(loader.dataset) for loader in forget_loaders]}")
    print(f"[Rank {rank}] Retain loaders: {[len(loader.dataset) for loader in retain_loaders]}")
    print(f"[Rank {rank}] Wiki loader: {len(wikidataloader[0])}")

    # 3. Model: base PEFT adapter + fresh LoRA; only requires_grad params train.
    model = TransformerModel(args.model_name).get_model()
    model = PeftModel.from_pretrained(model, f"/data/wwh/llmUN2/base/{args.model_name.split('/')[-1]}/0")
    model = apply_lora(model, args.lora_r, model_name=args.model_name)
    model.to(device)
    model = DDP(model, device_ids=[rank], output_device=rank)
    # Frozen reference copy of the starting model (unwrapped from DDP).
    orig_model = copy.deepcopy(model.module)
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable_params, lr=args.lr, betas=(0.9, 0.95), eps=1e-6)

    methods = ['SAM', 'RS', 'GP', 'CR', 'WA']
    wa_state = {'theta_wa': None, 'count': 0}
    # Per-step histories, seeded with 0.0 so the first epoch row is defined.
    GA_loss_hist = [0.0]
    GD_loss_hist = [0.0]
    forget_loss_hist = [0.0]
    retain_loss_hist = [0.0]
    # Sub-sampled (every 10 steps) histories written to loss_step.csv.
    ga_step_hist = []
    gd_step_hist = []
    forget_step_hist = []
    retain_step_hist = []

    loss_save_path = os.path.join(args.save_dir, "loss_epoch.csv")
    loss_step_save_path = os.path.join(args.save_dir, "loss_step.csv")
    if rank == 0:
        os.makedirs(args.save_dir, exist_ok=True)
        # Fresh per-run step log; rows are appended as they are produced.
        with open(loss_step_save_path, "w") as f:
            f.write("step,ga_loss,gd_loss,forget_loss,retain_loss\n")

    # Persistent iterator over WikiText batches. The previous code re-created
    # the DataLoader iterator and skipped idx batches on every step via islice.
    wiki_batches = _cycle_batches(wikidataloader[0])

    for epoch in range(args.epochs):
        model.train()
        if rank == 0:
            _save_lora_checkpoint(model, tokenizer, args.save_dir + f"/hf_ckpt_{epoch}")
            with open(loss_save_path, "a") as f:
                if f.tell() == 0:
                    f.write("Epoch,GA Loss,GD Loss,forget Loss, retain_loss\n")
                # Always write the row (previously skipped whenever the header was written).
                f.write(f"{epoch},{GA_loss_hist[-1]},{GD_loss_hist[-1]},{forget_loss_hist[-1]},{retain_loss_hist[-1]}\n")
        for loader in forget_loaders + retain_loaders + wikidataloader:
            loader.sampler.set_epoch(epoch)
        for f_loader, r_loader in zip(forget_loaders, retain_loaders):
            step = 0
            for f_batch, r_batch in zip(f_loader, r_loader):
                optimizer.zero_grad()
                step += 1
                # ---------- GA: LM loss on the forget batch ------------
                f_ids = f_batch["input_ids"].to(device)
                f_mask = f_batch["attention_mask"].to(device)
                ga_loss = model(f_ids, labels=f_ids, attention_mask=f_mask).loss
                # ---------- KL to the frozen reference on the retain batch ------------
                r_ids = r_batch["input_ids"].to(device)
                r_mask = r_batch["attention_mask"].to(device)
                with torch.no_grad():
                    orig_logits = orig_model(r_ids, attention_mask=r_mask).logits
                P = F.softmax(orig_logits, dim=-1)
                cur_logits = model(r_ids, attention_mask=r_mask).logits
                kl_loss = F.kl_div(F.log_softmax(cur_logits, dim=-1), P, reduction='none')
                # Sum over vocab, mask padding, then SUM over tokens (not a mean).
                # Assumes logits are (batch, seq, vocab) and mask is (batch, seq).
                kl_loss = (kl_loss.sum(-1) * r_mask).sum()
                GA_loss_hist.append(ga_loss.item())
                GD_loss_hist.append(kl_loss.item())

                total_loss = kl_loss - args.lambda_ga * ga_loss
                # Differentiate while the parameters still hold their true values.
                # p.data is overwritten below, so this graph must not be reused.
                grads_f = torch.autograd.grad(total_loss, trainable_params)
                with torch.no_grad():
                    meta_params = [p - args.lr * g for p, g in zip(trainable_params, grads_f)]

                # ---------- Forgetting feedback: evaluate perturbed look-ahead params ----------
                loss_adv = 0.0
                for method in random.sample(methods, 2):
                    with torch.no_grad():
                        perturbed = apply_perturbation(meta_params, grads_f, method, args.rho,
                                                       mu=args.mu, wa_state=wa_state, N=args.N)
                        for p, p_new in zip(trainable_params, perturbed):
                            p.data.copy_(p_new)
                        loss_adv += -model(f_ids, attention_mask=f_mask, labels=f_ids).loss
                        print(f"step {step} Method: {method}, Forget Loss: {loss_adv.item()}")
                # loss_adv is computed under no_grad, so `gamma * loss_adv` contributes no
                # gradient. The original grad(loss_adv * gamma + total_loss) therefore
                # equals grads_f, but it was evaluated after p.data had been overwritten,
                # i.e. with the wrong weights. NOTE(review): args.gamma currently has
                # no effect on training; confirm whether loss_adv should be differentiable.
                grads_forget = grads_f
                forget_loss_hist.append(loss_adv.item())

                # ---------- Retaining feedback at the look-ahead params ----------
                with torch.no_grad():
                    for p, p_meta in zip(trainable_params, meta_params):
                        p.data.copy_(p_meta)
                w_batch = next(wiki_batches)
                w_ids = w_batch["input_ids"].to(device)
                w_mask = w_batch["attention_mask"].to(device)
                with torch.no_grad():
                    orig_logits0 = orig_model(w_ids, attention_mask=w_mask).logits
                PP = F.softmax(orig_logits0, dim=-1)
                loss_retain = F.kl_div(F.log_softmax(model(w_ids, attention_mask=w_mask).logits, dim=-1),
                                       PP, reduction='none')
                # Mask with the WikiText batch's own mask (was r_mask from a different batch).
                loss_retain = (loss_retain.sum(-1) * w_mask).sum()
                print(f"Retain Loss: {loss_retain.item()}")
                retain_loss_hist.append(loss_retain.item())
                # Equivalent to grad(loss_retain + total_loss), reusing the correctly
                # computed grads_f for the total_loss term.
                grads_w = torch.autograd.grad(loss_retain, trainable_params)
                grads_retain = [g_w + g_f for g_w, g_f in zip(grads_w, grads_f)]

                # ---------- Gradient harmonisation ----------
                new_grads = []
                for g_r, g_f in zip(grads_retain, grads_forget):
                    flat_r = g_r.reshape(-1)
                    flat_f = g_f.reshape(-1)
                    # If the gradients conflict, remove g_f's component along g_r.
                    if F.cosine_similarity(flat_r, flat_f, dim=0) < 0:
                        flat_f = flat_f - (flat_f.dot(flat_r) / flat_r.norm().pow(2)) * flat_r
                    new_grads.append((flat_r + flat_f).view_as(g_r))
                # NOTE(review): gradients are assigned manually rather than via
                # .backward(), so DDP's gradient all-reduce does not run; with
                # world_size > 1 ranks may diverge. Also, the parameters still hold
                # the look-ahead values here, so the optimizer steps from `meta`,
                # not from the pre-step weights. Confirm both are intended.
                for p, g in zip(trainable_params, new_grads):
                    p.grad = g
                optimizer.step()

                if rank == 0 and step % 10 == 0:
                    print(f"[E{epoch+1}|S{step}] GA Loss={ga_loss.item():.4f}, "
                          f"KL Loss={kl_loss.item():.4f}, "
                          f"Forget Loss={forget_loss_hist[-1]:.4f}, "
                          f"Retain Loss={retain_loss_hist[-1]:.4f}, ")
                    ga_step_hist.append(ga_loss.item())
                    gd_step_hist.append(kl_loss.item())
                    forget_step_hist.append(forget_loss_hist[-1])
                    retain_step_hist.append(retain_loss_hist[-1])
                    with open(loss_step_save_path, "a") as f:
                        f.write(f"{len(ga_step_hist)},{ga_step_hist[-1]},{gd_step_hist[-1]},"
                                f"{forget_step_hist[-1]},{retain_step_hist[-1]}\n")
                if rank == 0 and step % 100 == 0:
                    # NOTE(review): overwrites this epoch's start-of-epoch checkpoint.
                    _save_lora_checkpoint(model, tokenizer, args.save_dir + f"/hf_ckpt_{epoch}")
                if rank == 0 and step % 1000 == 0:
                    _save_lora_checkpoint(model, tokenizer, args.save_dir + f"/hf_ckpt_{epoch}_{step}")
                # Drop gradients to free memory.
                for p in trainable_params:
                    p.grad = None
                print(f"[Rank {rank}] Epoch {epoch+1}, Step {step} done.")
    if rank == 0:
        print("Training done.")
    dist.destroy_process_group()

if __name__ == '__main__':
    # `argparse`, `torch` and `mp` come from the module-level imports.
    parser = argparse.ArgumentParser(
        description="DDP meta-unlearning with adversarial perturbations.")
    parser.add_argument('--model_name', required=True,
                        help="model id/path; its last path component names the output subdirectory")
    parser.add_argument("--dataset", type=str, required=True, help="'wmdp-cyber' or 'wmdp-bio'")
    parser.add_argument('--epochs', type=int, required=True)
    parser.add_argument('--batch_size', type=int, required=True)
    parser.add_argument('--lr', type=float, required=True,
                        help="Adam learning rate, also the look-ahead step size")
    parser.add_argument('--lora_r', type=int, required=True, help="LoRA rank")
    parser.add_argument('--rho', type=float, required=True,
                        help="perturbation magnitude for SAM/GP/RS")
    parser.add_argument('--mu', type=float, required=True, help="step size for the CR perturbation")
    parser.add_argument('--N', type=int, required=True, help="count cap for the WA running average")
    parser.add_argument("--lambda_ga", type=float, required=True, help="weight of the GA (forget) loss")
    parser.add_argument('--gamma', type=float, required=True,
                        help="weight of the adversarial forget loss")
    parser.add_argument("--save_dir", type=str, required=True)
    args = parser.parse_args()
    world = torch.cuda.device_count()
    if world == 0:
        # With nprocs=0, mp.spawn would start no workers and the script would
        # exit without training; fail loudly instead (NCCL needs a GPU).
        raise SystemExit("No CUDA devices are visible; at least one GPU is required.")
    mp.spawn(main, args=(world, args), nprocs=world)
