import matplotlib
from matplotlib.ticker import MultipleLocator
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from torch.utils.data import DistributedSampler
from utils.data_loader_wmdp import get_train_data_ddp
# Render off-screen with the non-interactive Agg backend (no display needed).
plt.switch_backend('agg')
# Font type 42 embeds TrueType fonts in PDF output, keeping text editable.
matplotlib.rcParams.update({'pdf.fonttype': 42})
# Global font settings applied to every figure created by this script.
font = dict(size=18, family='Arial')
matplotlib.rc('font', **font)
# 🔧 Estimate the mean loss over the first few batches of a dataloader (used for plotting)
def compute_dataset_loss(model, dataloader, max_batches=2):
    """Return the mean language-modelling loss over the first batches of *dataloader*.

    The model is switched to eval mode and run without gradient tracking.
    Each batch must be a mapping with ``'input_ids'`` and ``'attention_mask'``
    tensors; ``input_ids`` doubles as the labels, and the model's output must
    expose a scalar ``.loss``.

    Args:
        model: Callable model exposing ``eval()`` and ``device``.
        dataloader: Iterable of batches (see above).
        max_batches: Maximum number of batches to average over. Defaults to
            2, the number of forward passes the function has always made.

    Returns:
        The arithmetic mean of the per-batch losses as a Python float, or
        ``0.0`` when the dataloader yields no batches.

    Raises:
        ValueError: If ``max_batches`` is smaller than 1.
    """
    if max_batches < 1:
        raise ValueError("max_batches must be at least 1")

    model.eval()
    total_loss = 0.0
    count = 0
    with torch.no_grad():
        # Walk a single iterator so consecutive batches are distinct; calling
        # next(iter(dataloader)) repeatedly restarts the loader every time and
        # keeps re-reading its first batch.
        for batch in dataloader:
            input_ids = batch['input_ids'].to(model.device)
            attention_mask = batch['attention_mask'].to(model.device)
            # Labels are the inputs themselves (labels = input_ids).
            # NOTE(review): if batches are padded, pad positions are included
            # in the labels here -- confirm whether they should be masked out.
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=input_ids,
            )
            total_loss += outputs.loss.item()
            count += 1
            # Stop right after the last wanted batch so no extra batch is fetched.
            if count >= max_batches:
                break
    return total_loss / max(count, 1)

# 🔧 Main plotting routine: perturb only the LoRA parameters and map the loss on the given dataloader
def visualize_loss_surface(model: PeftModel, method, dataloader, steps=21, radius=0.1):
    """Scan the loss on a random 2-D slice of LoRA parameter space and plot it.

    All parameters whose name contains ``'lora'`` are flattened into one
    vector. Two random orthonormal directions are drawn, and the loss returned
    by ``compute_dataset_loss`` is evaluated on a ``steps x steps`` grid of
    offsets in ``[-radius, radius]`` along each direction. The original LoRA
    weights are written back afterwards, even if the scan raises.

    Args:
        model: PEFT model whose LoRA weights are perturbed in place.
        method: Label interpolated into the output file names.
        dataloader: Batches forwarded to ``compute_dataset_loss``.
        steps: Number of grid points per direction.
        radius: Largest offset applied along each unit-norm direction.

    Returns:
        None. Progress and the grids are printed. The filled contour plot is
        saved to ``lora_loss_surface_forget_{method}.pdf`` and the 3-D surface
        to ``lora_loss_surface3d_forget_{method}.pdf``.

    Raises:
        ValueError: If the model has no parameter with ``'lora'`` in its name.
    """
    device = model.device

    # Select the LoRA tensors once and reuse the same list for reading and
    # writing, so both passes are guaranteed to see the same parameters in the
    # same order without having to toggle requires_grad on the model.
    lora_params = [param for name, param in model.named_parameters() if 'lora' in name]
    if not lora_params:
        raise ValueError("model has no parameters with 'lora' in their name")
    # torch.cat allocates a new tensor, so this is an independent snapshot.
    lora_vector = torch.cat([p.detach().reshape(-1).to(device) for p in lora_params])

    # Two random directions, made orthogonal (Gram-Schmidt) and unit-norm.
    dim = lora_vector.numel()
    d1 = torch.randn(dim, device=device)
    d2 = torch.randn(dim, device=device)
    d2 -= d1 * torch.dot(d1, d2) / torch.dot(d1, d1)
    d1 = d1 / torch.norm(d1)
    d2 = d2 / torch.norm(d2)

    alphas = betas = np.linspace(-radius, radius, steps)
    loss_surface = np.zeros((steps, steps))

    def set_lora_vector(vec):
        """Copy consecutive slices of the flat vector back into the LoRA tensors."""
        pointer = 0
        for param in lora_params:
            numel = param.numel()
            param.data.copy_(vec[pointer:pointer + numel].view_as(param))
            pointer += numel

    # Scan the 2-D perturbation grid; loss_surface[i, j] belongs to
    # (alphas[i], betas[j]).
    try:
        for i, alpha in enumerate(alphas):
            for j, beta in enumerate(betas):
                # float() keeps the arithmetic in torch instead of relying on
                # NumPy-scalar / tensor operator dispatch.
                perturbed = lora_vector + float(alpha) * d1 + float(beta) * d2
                set_lora_vector(perturbed)
                loss_surface[i, j] = compute_dataset_loss(model, dataloader)
                print(f"({i},{j}) loss = {loss_surface[i,j]:.4f}")
    finally:
        # Restore the original weights even when the scan is interrupted.
        set_lora_vector(lora_vector)

    # 'ij' indexing makes X[i, j] == alphas[i] and Y[i, j] == betas[j], which
    # matches the layout of loss_surface; the default 'xy' indexing would plot
    # the surface with the alpha and beta axes swapped.
    X, Y = np.meshgrid(alphas, betas, indexing='ij')
    print(X)
    print(Y)
    print(loss_surface)

    # 3-D surface plot.
    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(111, projection='3d')
    ax.plot_surface(X, Y, loss_surface, cmap='viridis')
    ax.set_xlabel("Perturb α (Dir 1)", labelpad=15)
    ax.set_ylabel("Perturb β (Dir 2)", labelpad=15)
    ax.set_zlabel("Loss", labelpad=15)
    ax.xaxis.set_major_locator(MultipleLocator(0.5))
    ax.yaxis.set_major_locator(MultipleLocator(0.5))
    zmin = np.min(loss_surface)
    zmax = np.max(loss_surface)
    ax.set_zlim(zmin, zmax)
    ax.zaxis.set_major_locator(MultipleLocator(0.2))
    # Previously this figure was built but never written anywhere.
    fig.savefig(f"lora_loss_surface3d_forget_{method}.pdf")

    # Filled contour plot with one level every 0.1 loss units.
    contour_fig, contour_ax = plt.subplots(figsize=(8, 8))
    levels = np.arange(zmin, zmax + 0.1, 0.1)
    cp = contour_ax.contourf(X, Y, loss_surface, levels=levels, cmap='viridis')
    contour_fig.colorbar(cp, ax=contour_ax, label="Loss")
    contour_ax.set_title("LoRA Parameter Loss Contour (Forget Set)")
    contour_ax.set_xlabel("Perturb α (Dir 1)")
    contour_ax.set_ylabel("Perturb β (Dir 2)")
    contour_ax.grid(True)
    contour_fig.savefig(f"lora_loss_surface_forget_{method}.pdf")
    plt.show()

    # Release both figures so repeated calls do not accumulate open figures.
    plt.close(fig)
    plt.close(contour_fig)

# 🔧 Entry point: load the model and the data, then call the plotting routine
def main():
    """Load the base model with its un-merged LoRA adapter and plot the forget-set loss."""
    # Inference only: turn autograd off globally and fix the RNG seed.
    torch.set_grad_enabled(False)
    torch.manual_seed(42)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Model locations (placeholders to be filled in before running).
    model_name = "your_base_model_name"  # e.g., "mistralai/Mistral-7B-v0.1"
    method = "ga_7"
    checkpoint_path = "your_lora_checkpoint_path"  # e.g., "/data/wwh/llmUN/relearn/zephyr-7b-beta/ga_7"
    token = "your_huggingface_access_token"

    # Load the half-precision base model, then attach the LoRA adapter without merging it.
    backbone = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.float16, token=token
    )
    backbone = backbone.to(device)
    peft_model = PeftModel.from_pretrained(backbone, checkpoint_path)
    peft_model = peft_model.to(device)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)

    # Build the data loaders; only the first forget loader is used below.
    forget_loaders, _ = get_train_data_ddp(
        ["bio_forget"], ["bio-retain-corpus"], tokenizer,
        batch_size=4,
        sampler_cls=DistributedSampler,
        world_size=1, rank=0
    )
    first_forget_loader = forget_loaders[0]

    # Plot.
    visualize_loss_surface(peft_model, method, first_forget_loader, steps=21, radius=1)


if __name__ == "__main__":
    main()