import argparse
import os
import random
import torch
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt

from transformers import AutoModelForCausalLM, AutoTokenizer
from sklearn.decomposition import PCA
from datasets import load_dataset

# === Command-line interface ===
# Each entry is (flag, keyword arguments for add_argument). The options are
# registered in this order, which is also the order shown by --help.
_CLI_OPTIONS = (
    ('--lr', dict(type=float, default=3e-5, help="Learning rate, e.g. 3e-5")),
    ('--model_name', dict(type=str, default="Qwen/Qwen2.5-7B", help="Model name or path")),
    ('--unlearning_algorithm', dict(type=str, default="GA", help="Unlearning algorithm name")),
    ('--type', dict(choices=['Text', 'Math'], default="Text", help="Data type: Text or Math")),
    ('--phase', dict(type=str, default="N1_Request100", help="Unlearning phase identifier")),
    ('--relearn_data_type', dict(type=str, default="forget_set", help="forget set or retain set or unrelated")),
    ('--test_data_type', dict(type=str, default="forget_set", help="forget set or retain set or unrelated")),
)

parser = argparse.ArgumentParser(description="PCA similarity analysis after unlearning/relearning")
for _flag, _kwargs in _CLI_OPTIONS:
    parser.add_argument(_flag, **_kwargs)
args = parser.parse_args()

# Expose the parsed options under the module-level names used by the rest of
# the script (note that --type is stored as `data_type`).
lr, model_name, unlearning_algorithm = args.lr, args.model_name, args.unlearning_algorithm
data_type, phase = args.type, args.phase
relearn_data_type, test_data_type = args.relearn_data_type, args.test_data_type

# === Checkpoint locations ===
# The reference ("before") model is loaded directly from `model_name`; the
# unlearned and relearned checkpoints are looked up under the local "Model/"
# tree, keyed by the run settings. `base` is the last path component of the
# model name (e.g. "Qwen2.5-7B" for "Qwen/Qwen2.5-7B").
base = model_name.rsplit('/', 1)[-1]
checkpoint_before = model_name
checkpoint_after_un = f"Model/{data_type}/all_layer/{base}/lr{lr}_{unlearning_algorithm}_{phase}"
# Relearning on the forget set uses the plain algorithm directory; any other
# relearn data type gets its name appended to the algorithm directory.
checkpoint_after_re = (
    f"Model/recovery/{data_type}/{base}/{unlearning_algorithm}/lr{lr}_all_layers_forget_{phase}"
    if relearn_data_type == "forget_set"
    else f"Model/recovery/{data_type}/{base}/{unlearning_algorithm}_{relearn_data_type}/lr{lr}_all_layers_forget_{phase}"
)

# === Reproducibility and device selection ===
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# === Load data ===
# Builds `texts`: the 30 shuffled examples that are fed to every model.
NUM_SAMPLES = 30

if data_type == "Text":
    # (dataset config, split) pairs that make up each evaluation set, listed
    # in the order in which they are concatenated before shuffling.
    text_sources = {
        "forget_set": [("arxiv", "forget"), ("github", "forget")],
        "retain_set": [("arxiv", "retain"), ("github", "retain")],
        "unrelated":  [("general", "retain")],
    }
    if test_data_type not in text_sources:
        # Fail early with a clear message; otherwise `test_list` would never be
        # assigned and the script would die with a NameError further down.
        raise ValueError(
            f"Unsupported --test_data_type {test_data_type!r}; "
            f"expected one of {sorted(text_sources)}"
        )

    # Only the subsets needed for the requested evaluation set are loaded.
    test_list = []
    for subset, split in text_sources[test_data_type]:
        subset_data = load_dataset("llmunlearn/unlearn_dataset", name=subset, split=split)
        test_list.extend(row['text'] for row in subset_data)

    random.shuffle(test_list)
    # Take the first 30 examples for analysis
    texts = test_list[:NUM_SAMPLES]

elif data_type == "Math":
    # === Math scenario: NuminaMath ===
    # NOTE: --test_data_type is not consulted in this branch.
    dataset = load_dataset("AI-MO/NuminaMath-1.5")
    train_split = dataset["train"]
    problems = train_split["problem"]
    solutions = train_split["solution"]
    answers = train_split["answer"]

    # Method A: Combine each problem into a triple "problem solution answer"
    forget_list = [
        f"{p} {s} {a}"
        for p, s, a in zip(problems, solutions, answers)
    ]

    # Combine every N consecutive items into one long sequence
    # (joined with two spaces).
    N = 10
    forget_list = [
        "  ".join(forget_list[i : i + N])
        for i in range(0, len(forget_list), N)
    ]

    random.shuffle(forget_list)
    # Take the first 30 grouped sequences
    texts = forget_list[:NUM_SAMPLES]

# === Feature extraction function ===
def extract_features(model, tokenizer, texts, layer_idx):
    """Return one mean-pooled hidden-state vector per input text.

    Args:
        model: Callable model that accepts the tokenizer output as keyword
            arguments together with ``output_hidden_states=True``, exposes a
            ``device`` attribute, and returns an object with a
            ``hidden_states`` sequence.
        tokenizer: Callable used to tokenize ``texts``; its result must
            support ``.to(device)`` and ``**`` unpacking.
        texts: Batch of strings to encode (padded, truncated to 128 tokens).
        layer_idx: Index into ``outputs.hidden_states`` selecting the layer.

    Returns:
        A NumPy float32 array obtained by averaging the selected hidden state
        over dimension 1 (presumably the token axis, giving one row per
        text — confirm against the model's output layout).
    """
    inputs = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=128
    ).to(model.device)
    # Pure inference: disable autograd so no graph or intermediate activations
    # are kept alive for a backward pass that never happens.
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
    # Cast to float32 before pooling/NumPy conversion (the models in this
    # script are loaded in bfloat16, which NumPy cannot represent).
    hs = outputs.hidden_states[layer_idx].to(torch.float32)
    # NOTE(review): the mean covers every position along dim 1, padded
    # positions included — confirm whether attention-mask pooling is wanted.
    return hs.mean(dim=1).detach().cpu().numpy()

# === Load tokenizer and model ===
# The tokenizer of the reference checkpoint is shared by all three models.
tokenizer = AutoTokenizer.from_pretrained(checkpoint_before, trust_remote_code=True)
# Batched tokenization with padding=True requires a pad token. Some causal-LM
# tokenizers ship without one, so fall back to the EOS token in that case.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

def load_and_move(path):
    """Load a causal-LM checkpoint in bfloat16 with FlashAttention-2.

    Args:
        path: Hub model name or local checkpoint directory.

    Returns:
        The loaded model, already placed on the available device(s) by
        ``device_map="auto"``.
    """
    # `device_map="auto"` already dispatches the weights, so the model is
    # returned as-is: calling `.to(device)` on a dispatched model is
    # discouraged and fails when some modules are offloaded.
    # `attn_implementation` replaces the deprecated `use_flash_attention_2`.
    return AutoModelForCausalLM.from_pretrained(
        path,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        attn_implementation="flash_attention_2",
        device_map="auto",
    )

# === Compute cosine similarity of PCA first principal component ===
# Checkpoints are loaded, processed and released one at a time, so only a
# single model is resident in memory at any point. The processing order
# (before -> after_un -> after_re, layers ascending) is kept fixed because
# PCA may draw from the globally seeded NumPy RNG.
pc_components = {}
num_layers = None

for state, checkpoint in [
    ("before",   checkpoint_before),
    ("after_un", checkpoint_after_un),
    ("after_re", checkpoint_after_re),
]:
    model = load_and_move(checkpoint)
    if num_layers is None:
        # Taken from the reference model and reused for the other two.
        # hidden_states is indexed 0..num_hidden_layers (presumably index 0 is
        # the embedding output — confirm for the model in use).
        num_layers = model.config.num_hidden_layers + 1

    for layer in range(num_layers):
        feats = extract_features(model, tokenizer, texts, layer)
        # Only the first principal axis is kept for the comparison.
        pc_components[(state, layer)] = PCA(n_components=2).fit(feats).components_[0]

    # Drop the model before loading the next checkpoint.
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# NOTE(review): the sign of a principal axis is arbitrary, so a negative
# cosine may only reflect a flipped axis — confirm whether abs() is wanted.
rows = []
for layer in range(num_layers):
    v0 = pc_components[("before",   layer)]
    v1 = pc_components[("after_un", layer)]
    v2 = pc_components[("after_re", layer)]
    cos_un = np.dot(v0, v1) / (np.linalg.norm(v0) * np.linalg.norm(v1))
    cos_re = np.dot(v0, v2) / (np.linalg.norm(v0) * np.linalg.norm(v2))
    rows.append({"layer": layer, "before_after_un": cos_un, "before_after_re": cos_re})

# One row per layer: similarity of the reference model's first principal axis
# to that of the unlearned and the relearned model.
df = pd.DataFrame(rows)

# === Publication-style matplotlib defaults (DejaVu Serif font) ===
_FONT_STYLE = {
    'font.family':        'serif',
    'font.serif':         ['DejaVu Serif'],
    'font.size':          18,
    'axes.titlesize':     20,
    'axes.labelsize':     18,
    'xtick.labelsize':    16,
    'ytick.labelsize':    16,
    'legend.fontsize':    16,
}
_LINE_AND_AXES_STYLE = {
    'lines.linewidth':    2.0,
    'lines.markersize':   8,
    'axes.linewidth':     1.2,
    # Open frame: hide the top and right spines.
    'axes.spines.top':    False,
    'axes.spines.right':  False,
    'legend.frameon':     False,
    # Two-colour cycle (blue, vermillion).
    'axes.prop_cycle':    mpl.cycler('color', ['#0072B2', '#D55E00']),
}
_GRID_STYLE = {
    'axes.grid':          True,
    'grid.linestyle':     '--',
    'grid.linewidth':     0.5,
    'grid.alpha':         0.7,
}
mpl.rcParams.update({**_FONT_STYLE, **_LINE_AND_AXES_STYLE, **_GRID_STYLE})

# === Plotting ===
fig, ax = plt.subplots(figsize=(5, 3))
layers = df["layer"]

# (column, legend label, line style, marker, colour) for each curve.
_series = (
    ("before_after_un", "Unlearned", '-',  'o', '#0072B2'),
    ("before_after_re", "Relearned", '--', '^', '#D55E00'),
)

# === Main lines ===
for _column, _label, _line_style, _, _colour in _series:
    ax.plot(layers, df[_column], linestyle=_line_style, color=_colour, label=_label)

# === Markers on every `marker_freq`-th layer ===
# Drawn as separate, unlabeled artists so the legend shows the lines only.
marker_freq = 5
marker_idx = list(range(0, len(layers), marker_freq))
for _column, _, _, _marker, _colour in _series:
    ax.plot(layers.iloc[marker_idx], df[_column].iloc[marker_idx],
            _marker, color=_colour, markersize=6)

# === Axis labels, title and limits ===
ax.set_xlabel("Layer")
ax.set_ylabel("Cosine Similarity")
ax.set_title("PCA Similarity", pad=12)
# NOTE(review): the upper limit of 3 exceeds the cosine range; presumably
# headroom for the legend — confirm.
ax.set_ylim(-1, 3)
ax.grid(True)

# === Legend ===
ax.legend(loc="best", frameon=False, fancybox=True)

fig.tight_layout()

# === Save as PDF, creating the output directory if needed ===
save_path = f"Analyse_Figure/{data_type}/{base}_{relearn_data_type}/PCA_{test_data_type}"
os.makedirs(save_path, exist_ok=True)
fig.savefig(f"{save_path}/Sim_lr{lr}_{unlearning_algorithm}_{phase}.pdf", dpi=300, bbox_inches='tight')
plt.show()
