import argparse
import random
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from sklearn.decomposition import PCA
import matplotlib as mpl
# === Parse command-line arguments ===
parser = argparse.ArgumentParser(description="Compute FIM and CKA analyses")
parser.add_argument('--lr', type=float, default=1e-6, help="Learning rate, e.g., 1e-6")
parser.add_argument('--model_name', type=str, default="Qwen/Qwen2.5-7B", help="Model name or path")
parser.add_argument('--unlearning_algorithm', type=str, default="GA", help="Unlearning algorithm identifier")
parser.add_argument('--type', choices=['Text','Math'], default="Text", help="Data type: Text or Math")
parser.add_argument('--phase', type=str, default="N1_Request100", help="Unlearning phase identifier")
parser.add_argument('--batch_size', type=int, default=4, help="Batch size for DataLoader")
parser.add_argument('--num_batches', type=int, default=10, help="Number of batches to estimate FIM and CKA")
# The two data-type options below are matched against exact tokens later in
# the script, so the help text spells those tokens out verbatim.
parser.add_argument('--relearn_data_type', type=str, default="forget_set",
                    help="Relearning data: 'forget_set' (default), or another tag such as "
                         "'retain_set' / 'unrelated' that selects the '<algorithm>_<tag>' "
                         "recovery checkpoint directory")
parser.add_argument('--test_data_type', type=str, default="forget_set",
                    help="Evaluation data for --type Text: 'forget_set', 'retain_set' or 'unrelated'")
args = parser.parse_args()

# === Hyperparameters & checkpoint paths ===
(lr, model_name, unlearning_algorithm, data_type, phase,
 batch_size, num_batches, relearn_data_type, test_data_type) = (
    args.lr, args.model_name, args.unlearning_algorithm, args.type, args.phase,
    args.batch_size, args.num_batches, args.relearn_data_type, args.test_data_type,
)
# Short model name (last path component) used inside checkpoint/output paths.
base = model_name.rsplit('/', 1)[-1]
checkpoint_before = model_name
checkpoint_after_un = f"Model/{data_type}/all_layer/{base}/lr{lr}_{unlearning_algorithm}_{phase}"
# Relearning on the forget set lives under the bare algorithm directory; any
# other relearn data type uses an "<algorithm>_<relearn_data_type>" directory.
_recovery_dir = (unlearning_algorithm if relearn_data_type == "forget_set"
                 else f"{unlearning_algorithm}_{relearn_data_type}")
checkpoint_after_re = f"Model/recovery/{data_type}/{base}/{_recovery_dir}/lr{lr}_all_layers_forget_{phase}"

# === Random seed & device ===
# Seed Python, NumPy and PyTorch RNGs identically (same call order as before).
_SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(_SEED)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# === Load Tokenizer & Models ===
tokenizer = AutoTokenizer.from_pretrained(checkpoint_before, use_fast=True)
# TextDataset pads every text to a fixed length, which needs a pad token.
# Some base checkpoints ship without one, so reuse EOS in that case.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

def load_model(path):
    """Load a causal-LM checkpoint in bfloat16, ready for inference.

    Weight placement is delegated entirely to ``device_map="auto"``, so the
    model is deliberately *not* moved afterwards. Calling ``.to(device)`` on a
    dispatched model would either pull a multi-GPU shard onto a single device
    or fail when some modules were offloaded to CPU/disk. Offloading is a real
    possibility here because three full models are loaded side by side.

    Args:
        path: Hub model id or local checkpoint directory.

    Returns:
        The loaded model, switched to eval mode.
    """
    m = AutoModelForCausalLM.from_pretrained(
        path,
        torch_dtype=torch.bfloat16,
        # NOTE(review): executes model code shipped with the checkpoint;
        # only point this at trusted checkpoints.
        trust_remote_code=True,
        # `use_flash_attention_2=True` is deprecated in favour of this kwarg.
        # FlashAttention-2 needs CUDA, so fall back to SDPA on CPU-only hosts.
        attn_implementation="flash_attention_2" if torch.cuda.is_available() else "sdpa",
        device_map="auto",
    )
    return m.eval()

# Load the three checkpoints in order: original, after unlearning, after relearning.
model_before, model_after_un, model_after_re = (
    load_model(ckpt)
    for ckpt in (checkpoint_before, checkpoint_after_un, checkpoint_after_re)
)
# Every decoder layer of the original model is analysed.
layers_to_analyze = list(range(len(model_before.model.layers)))
num_layers = len(layers_to_analyze)

# === Load Data & Sample ===
# Evaluation sources per --test_data_type: (config name, split) pairs of the
# llmunlearn/unlearn_dataset hub dataset, concatenated in this order.
_UNLEARN_SOURCES = {
    "forget_set": [("arxiv", "forget"), ("github", "forget")],
    "retain_set": [("arxiv", "retain"), ("github", "retain")],
    "unrelated": [("general", "retain")],
}
# Number of texts kept after shuffling.
_NUM_SAMPLES = 30


def _load_unlearn_texts(data_kind):
    """Return the 'text' field of every row in the splits selected by ``data_kind``.

    Only the splits that are actually needed are downloaded/loaded.

    Raises:
        ValueError: if ``data_kind`` is not a key of ``_UNLEARN_SOURCES``.
    """
    try:
        sources = _UNLEARN_SOURCES[data_kind]
    except KeyError:
        raise ValueError(
            f"Unknown test_data_type {data_kind!r}; expected one of {sorted(_UNLEARN_SOURCES)}"
        ) from None
    collected = []
    for name, split in sources:
        rows = load_dataset("llmunlearn/unlearn_dataset", name=name, split=split)
        collected.extend(row['text'] for row in rows)
    return collected


def _load_math_texts(group_size=10):
    """Build NuminaMath-1.5 documents.

    Each train example becomes "problem solution answer". Consecutive
    examples are then joined (two-space separator) into documents of
    ``group_size`` examples each.
    """
    train = load_dataset("AI-MO/NuminaMath-1.5")["train"]
    examples = [
        f"{p} {s} {a}"
        for p, s, a in zip(train["problem"], train["solution"], train["answer"])
    ]
    return [
        "  ".join(examples[i : i + group_size])
        for i in range(0, len(examples), group_size)
    ]


if data_type == "Text":
    pool = _load_unlearn_texts(test_data_type)
elif data_type == "Math":
    # NOTE: --test_data_type does not affect the Math sample; it is always NuminaMath.
    pool = _load_math_texts()
else:  # argparse restricts --type, but fail loudly if data_type is edited by hand
    raise ValueError(f"Unsupported data type {data_type!r}")
random.shuffle(pool)
texts = pool[:_NUM_SAMPLES]

# === Tokenized Dataset ===
class TextDataset(Dataset):
    """Pre-tokenized, fixed-length causal-LM dataset stored on ``model.device``.

    Every text is truncated or padded to ``max_length`` tokens. ``labels`` is a
    separate copy of ``input_ids`` in which padding positions (attention mask
    0) are set to -100. That value is the default ``ignore_index`` of
    ``torch.nn.CrossEntropyLoss``, so a loss over these labels skips padding.

    Args:
        texts: List of raw strings to tokenize.
        tokenizer: Callable tokenizer returning ``input_ids``/``attention_mask``
            tensors and supporting ``.to(device)``; it must have a pad token.
        model: Any object exposing ``.device``; tensors are moved there once.
        max_length: Fixed sequence length after truncation/padding.
    """

    def __init__(self, texts, tokenizer, model, max_length=128):
        enc = tokenizer(texts, return_tensors="pt", truncation=True,
                        padding="max_length", max_length=max_length).to(model.device)
        self.input_ids = enc["input_ids"]
        self.attention_mask = enc["attention_mask"]
        # masked_fill returns a new tensor, so input_ids itself is left untouched.
        self.labels = self.input_ids.masked_fill(self.attention_mask == 0, -100)

    def __len__(self):
        return self.input_ids.size(0)

    def __getitem__(self, idx):
        return {
            "input_ids":      self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "labels":         self.labels[idx],
        }

dataset = TextDataset(texts, tokenizer, model_before)
# Fixed order (no shuffling): the same loader is replayed for every model.
loader = DataLoader(dataset, shuffle=False, batch_size=batch_size)

# === CKA Calculation ===
def center_gram(K):
    """Double-centre an (n, n) Gram matrix, i.e. return H K H with H = I - 11^T/n.

    Computed in expanded form as K - U K - K U + U K U, where U is the n x n
    matrix whose entries are all 1/n.
    """
    size = K.shape[0]
    avg = np.ones((size, size)) / size
    left = avg @ K
    return K - left - K @ avg + left @ avg

def linear_cka(X, Y):
    """Linear Centered Kernel Alignment between two representation matrices.

    Args:
        X: Array of shape (n, d1); row i is the representation of example i.
        Y: Array of shape (n, d2) for the same n examples, in the same order.

    Returns:
        CKA similarity as a Python float in [0, 1]. It is 1 when Y is an
        isotropic scaling / rotation / translation of X.

    Raises:
        ValueError: if X and Y do not have the same number of rows.
    """
    # Accumulate in float64: Gram entries of hidden states can be large, and
    # float32 round-off in the products below would bias the ratio.
    X = np.asarray(X, dtype=np.float64)
    Y = np.asarray(Y, dtype=np.float64)
    if X.shape[0] != Y.shape[0]:
        raise ValueError(
            f"linear_cka needs the same number of examples, got {X.shape[0]} and {Y.shape[0]}"
        )
    Xc = X - X.mean(0, keepdims=True)
    Yc = Y - Y.mean(0, keepdims=True)
    # The Gram matrix of column-centred data is already double-centred
    # (H @ Xc == Xc), so no separate centring pass is needed.
    Kx = Xc @ Xc.T
    Ky = Yc @ Yc.T
    # For symmetric A, B: trace(A @ B) == sum(A * B), without forming A @ B.
    hsic = np.sum(Kx * Ky)
    denom = np.sqrt(np.sum(Kx * Kx) * np.sum(Ky * Ky) + 1e-12)
    return float(hsic / denom)

def _first_token_states(mdl, layer_ids, data_loader, max_batches):
    """Collect position-0 hidden states of selected decoder layers.

    One forward hook is registered per layer in ``layer_ids`` (module name
    ``model.layers.<L>``). The model then runs over at most ``max_batches``
    batches, so every batch needs a single forward pass regardless of how
    many layers are analysed.

    Args:
        mdl: Causal LM whose decoder layers are reachable as ``model.layers.<L>``.
        layer_ids: Layer indices to capture.
        data_loader: Yields dicts with ``input_ids`` and ``attention_mask``.
        max_batches: Maximum number of batches to run.

    Returns:
        Dict mapping layer index -> float32 numpy array with one row per example.
        Assumes each layer output is (batch, seq, hidden); TODO confirm for
        non-Qwen architectures.
    """
    modules = dict(mdl.named_modules())
    captured = {L: [] for L in layer_ids}

    def make_hook(store):
        # Bind ``store`` per layer explicitly instead of relying on a loop variable.
        def hook_fn(module, inp, out):
            # Decoder layers may return a tuple (hidden_states, ...) or a bare tensor.
            t = out[0] if isinstance(out, tuple) else out
            # NOTE(review): under causal attention position 0 only sees the first
            # token; consider mask-aware mean pooling if a sequence-level
            # representation is intended.
            store.append(t[:, 0, :].float().detach().cpu().numpy())
        return hook_fn

    handles = [
        modules[f"model.layers.{L}"].register_forward_hook(make_hook(captured[L]))
        for L in layer_ids
    ]
    try:
        with torch.no_grad():
            for i, batch in enumerate(data_loader):
                if i >= max_batches:
                    break
                mdl(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
    finally:
        # Always detach the hooks, even if a forward pass raises.
        for h in handles:
            h.remove()
    return {L: np.concatenate(captured[L], axis=0) for L in layer_ids}


activations = {
    tag: _first_token_states(mdl, layers_to_analyze, loader, num_batches)
    for tag, mdl in [("Original", model_before),
                     ("Unlearned", model_after_un),
                     ("Relearned", model_after_re)]
}

# Per layer: CKA of the original model against the unlearned and relearned models.
cka_results = {}
for L in layers_to_analyze:
    reference = activations["Original"][L]
    cka_results[L] = {
        "Orig–Un": linear_cka(reference, activations["Unlearned"][L]),
        "Orig–Re": linear_cka(reference, activations["Relearned"][L]),
    }

# === Plot settings ===
# DejaVu Serif is a serif typeface, so it is selected through the serif
# family instead of being listed under 'font.sans-serif'; the rendered
# font is the same, but the configuration now says what it does.
mpl.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['DejaVu Serif'],
    'font.size': 18,
    'axes.titlesize': 22,
    'axes.labelsize': 18,
    'xtick.labelsize': 16,
    'ytick.labelsize': 16,
    'lines.linewidth': 2,
    'lines.markersize': 8,
    'axes.linewidth': 1.2,
    'axes.spines.top': False,
    'axes.spines.right': False,
    'axes.grid': True,
    'grid.linestyle': '--',
    'grid.linewidth': 0.6,
    'grid.alpha': 0.6,
    'legend.frameon': True,
    'legend.fontsize': 16,
    'legend.title_fontsize': 16,
})

# === Scientific color palette ===
# Per-series colour and line style, keyed by model tag.
_series = ("Original", "Unlearned", "Relearned")
plot_colors = dict(zip(_series, ("#000000", "#1b9e77", "#d95f02")))
linestyles = dict(zip(_series, ("-", "--", ":")))

# === Plotting ===
fig, ax = plt.subplots(figsize=(5, 3))
layers = layers_to_analyze
y_un = [cka_results[L]["Orig–Un"] for L in layers]
y_re = [cka_results[L]["Orig–Re"] for L in layers]

# Markers and x-tick labels are placed on every `marker_freq`-th layer.
marker_freq = 5
marker_indices = [i for i in range(len(layers)) if i % marker_freq == 0]

# One call per series with `markevery` (markers at indices 0, N, 2N, ...),
# so each legend entry shows the marker as well as the line style.
# Styles come from the shared palette dicts.
for y_vals, tag, marker in ((y_un, "Unlearned", "o"), (y_re, "Relearned", "s")):
    ax.plot(layers, y_vals,
            linestyle=linestyles[tag], linewidth=2,
            color=plot_colors[tag],
            marker=marker, markersize=6, markevery=marker_freq,
            label=tag)

tick_layers = [layers[i] for i in marker_indices]
ax.set_xticks(tick_layers)
ax.set_xticklabels([str(L) for L in tick_layers])
ax.set_xlabel("Layer index")
ax.set_ylabel("Linear CKA")
ax.set_title("CKA", pad=12)
# Linear CKA is bounded in [0, 1]; a little headroom keeps values at 1 visible.
ax.set_ylim(0, 1.05)

ax.legend(
    loc="best",
    frameon=False,
    fancybox=True
)

fig.tight_layout()

# === Save to PDF ===
# Figures are grouped by data type, model/relearn data, and evaluation data.
out_dir = f"Analyse_Figure/{data_type}/{base}_{relearn_data_type}/CKA_{test_data_type}"
os.makedirs(out_dir, exist_ok=True)
pdf_path = f"{out_dir}/lr{lr}_{unlearning_algorithm}_{phase}.pdf"
fig.savefig(pdf_path, bbox_inches='tight', dpi=300)
plt.show()
