import argparse
import os
import random
import torch
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt

from transformers import AutoModelForCausalLM, AutoTokenizer
from sklearn.decomposition import PCA
from datasets import load_dataset

# === Parse command-line arguments ===
# Each entry is (flag, add_argument keyword options); flags are registered in this order.
_CLI_OPTIONS = (
    ("--lr", {"type": float, "default": 3e-5, "help": "Learning rate, e.g. 3e-5"}),
    ("--model_name", {"type": str, "default": "Qwen/Qwen2.5-7B", "help": "Model name or path"}),
    ("--unlearning_algorithm", {"type": str, "default": "GA", "help": "Unlearning algorithm identifier"}),
    ("--type", {"choices": ['Text', 'Math'], "default": "Text", "help": "Data type: Text or Math"}),
    ("--phase", {"type": str, "default": "N1_Request100", "help": "Unlearning phase identifier"}),
    ("--relearn_data_type", {"type": str, "default": "forget_set", "help": "forget set or retain set or unrelated"}),
    ("--test_data_type", {"type": str, "default": "forget_set", "help": "forget set or retain set or unrelated"}),
)
parser = argparse.ArgumentParser(description="Compute and plot PCA Shift vs Principle")
for _flag, _opts in _CLI_OPTIONS:
    parser.add_argument(_flag, **_opts)
del _flag, _opts
args = parser.parse_args()

# Module-level aliases consumed by the rest of the script.
(lr, model_name, unlearning_algorithm, data_type, phase, relearn_data_type, test_data_type) = (
    args.lr,
    args.model_name,
    args.unlearning_algorithm,
    args.type,
    args.phase,
    args.relearn_data_type,
    args.test_data_type,
)

# === Build checkpoint paths ===
# Last path component of the model id, e.g. "Qwen/Qwen2.5-7B" -> "Qwen2.5-7B".
base = model_name.rsplit('/', 1)[-1]
checkpoint_before = model_name
checkpoint_after_un = f"Model/{data_type}/all_layer/{base}/lr{lr}_{unlearning_algorithm}_{phase}"
# Recovery runs on the forget set live directly under the algorithm directory;
# any other relearn data type gets an "<algorithm>_<relearn type>" directory.
_recovery_dir = (
    unlearning_algorithm
    if relearn_data_type == "forget_set"
    else f"{unlearning_algorithm}_{relearn_data_type}"
)
checkpoint_after_re = f"Model/recovery/{data_type}/{base}/{_recovery_dir}/lr{lr}_all_layers_forget_{phase}"
del _recovery_dir

# === Seed every RNG the script uses and pick the compute device ===
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(42)
del _seed_fn
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# === Load evaluation data ===
# ``texts`` ends up holding up to 30 samples, picked after a shuffle with the
# seeded ``random`` module.
if data_type == "Text":
    # Validate before downloading anything. An unknown value would otherwise
    # fall through every branch below and crash later with a NameError on
    # ``test_list``.
    _text_test_sets = ("forget_set", "retain_set", "unrelated")
    if test_data_type not in _text_test_sets:
        raise ValueError(
            f"Unknown --test_data_type {test_data_type!r}; expected one of {_text_test_sets}"
        )

    forget_dataset_arxiv = load_dataset("llmunlearn/unlearn_dataset", name="arxiv", split="forget")
    retain_dataset_arxiv = load_dataset("llmunlearn/unlearn_dataset", name="arxiv", split="retain")
    forget_dataset_github = load_dataset("llmunlearn/unlearn_dataset", name="github", split="forget")
    retain_dataset_github = load_dataset("llmunlearn/unlearn_dataset", name="github", split="retain")
    retain_dataset_general = load_dataset("llmunlearn/unlearn_dataset", name="general", split="retain")
    forget_list_arxiv = [data['text'] for data in forget_dataset_arxiv]
    retain_list_arxiv = [data['text'] for data in retain_dataset_arxiv]
    forget_list_github = [data['text'] for data in forget_dataset_github]
    retain_list_github = [data['text'] for data in retain_dataset_github]
    retain_list_general = [data['text'] for data in retain_dataset_general]

    if test_data_type == "forget_set":
        test_list = forget_list_arxiv + forget_list_github
    elif test_data_type == "retain_set":
        test_list = retain_list_arxiv + retain_list_github
    else:  # "unrelated": the only remaining value that passed validation
        test_list = retain_list_general

    random.shuffle(test_list)
    # Take the first 30 samples for analysis
    texts = test_list[:30]

elif data_type == "Math":
    # === Math task scenario: NuminaMath ===
    dataset   = load_dataset("AI-MO/NuminaMath-1.5")
    problems  = dataset["train"]["problem"]
    solutions = dataset["train"]["solution"]
    answers   = dataset["train"]["answer"]

    # Method A: Each question is combined into a triplet "problem solution answer"
    forget_list = [
        f"{p} {s} {a}"
        for p, s, a in zip(problems, solutions, answers)
    ]

    # Group every N questions into one long text.
    # NOTE(review): the features are computed on token-truncated inputs, so
    # only the beginning of each group is likely to be seen. Confirm that
    # grouping is still intended.
    N = 2000
    grouped = [
        "  ".join(forget_list[i : i + N])
        for i in range(0, len(forget_list), N)
    ]
    forget_list = grouped

    random.shuffle(forget_list)
    # Take the first 30 grouped texts
    texts = forget_list[:30]

else:
    raise ValueError(f"Unknown --type {data_type!r}; expected 'Text' or 'Math'")

# === Feature extraction function ===
def extract_features(model, tokenizer, texts, layer_idx, max_length=128):
    """Mean-pool one layer's hidden states for a batch of texts.

    Args:
        model: Causal LM called with ``output_hidden_states=True``. The encoded
            inputs are moved to ``model.device``.
        tokenizer: Tokenizer used to encode ``texts`` as one padded batch,
            truncated to ``max_length`` tokens.
        texts: List of strings to encode.
        layer_idx: Index into ``outputs.hidden_states``.
        max_length: Truncation length in tokens (default 128).

    Returns:
        float32 ``numpy.ndarray`` with one row per text: the average hidden
        state over that text's non-padding tokens.
    """
    inputs = tokenizer(
        texts,
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=max_length,
    ).to(model.device)
    # Inference only. Skipping the autograd graph saves a lot of memory, since
    # the forward pass keeps every layer's hidden states.
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
        hs = outputs.hidden_states[layer_idx].to(torch.float32)
        mask = inputs.get("attention_mask")
        if mask is None:
            pooled = hs.mean(dim=1)
        else:
            # Average over real tokens only. A plain mean over dim=1 would also
            # include padding positions, so a text's feature would depend on
            # the lengths of the other texts in the batch.
            mask = mask.to(device=hs.device, dtype=hs.dtype).unsqueeze(-1)
            pooled = (hs * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
    return pooled.cpu().numpy()

# === Load Tokenizer and three models ===
tokenizer = AutoTokenizer.from_pretrained(checkpoint_before, trust_remote_code=True)
# Batched encoding with padding=True raises if no pad token is defined, and some
# causal-LM checkpoints leave it unset. Reuse the EOS token in that case.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

def load_and_move(path):
    """Load a causal-LM checkpoint in bfloat16 and let accelerate place it.

    ``device_map="auto"`` already dispatches the weights (across several GPUs
    if they are visible, or to CPU otherwise). The model is therefore returned
    as placed. A follow-up ``.to(device)`` would fight the dispatch hooks: it
    collapses a multi-GPU split onto one device and raises if any module was
    offloaded to CPU or disk.

    Args:
        path: Model id or local checkpoint directory.

    Returns:
        The loaded model. Inputs should be sent to ``model.device``.
    """
    # Flash-attention 2 only runs on CUDA; fall back to PyTorch SDPA elsewhere.
    # ``attn_implementation`` replaces the deprecated ``use_flash_attention_2``.
    attn_impl = "flash_attention_2" if device.type == "cuda" else "sdpa"
    return AutoModelForCausalLM.from_pretrained(
        path,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        attn_implementation=attn_impl,
        device_map="auto",
    )

# Load the original, unlearned and relearned checkpoints, in that order.
model_before, model_after_un, model_after_re = map(
    load_and_move,
    (checkpoint_before, checkpoint_after_un, checkpoint_after_re),
)

# === Compute Shift vs Principle layer by layer ===
# For every hidden-state index, fit a 2-component PCA on the original model's
# features. Each model state is then summarised by its mean projection onto that
# basis: "shift" is the PC1 offset from the original model (0.0 for the original
# itself), and "principle" is the raw mean PC2 projection.
state_models = (
    ("Original", model_before),
    ("Unlearned", model_after_un),
    ("Relearned", model_after_re),
)
# hidden_states holds the embedding output plus one entry per layer.
num_layers = model_before.config.num_hidden_layers + 1
records = []
for layer in range(num_layers):
    layer_feats = {
        state_name: extract_features(state_model, tokenizer, texts, layer)
        for state_name, state_model in state_models
    }
    basis = PCA(n_components=2).fit(layer_feats["Original"]).components_
    mean_proj = {
        state_name: (feats.dot(basis[0]).mean(), feats.dot(basis[1]).mean())
        for state_name, feats in layer_feats.items()
    }
    origin_pc1 = mean_proj["Original"][0]
    for state_name, _ in state_models:
        pc1, pc2 = mean_proj[state_name]
        shift = 0.0 if state_name == "Original" else pc1 - origin_pc1
        records.append({"layer": layer, "state": state_name, "shift": shift, "principle": pc2})

df = pd.DataFrame(records)

# === Set plot style ===
# Serif text at enlarged sizes, a light dashed grid, no top/right spines, and the
# Okabe-Ito colour-blind-safe colours (blue, vermillion, bluish green) as the
# colour cycle.
mpl.rc("font", family="serif", serif=["DejaVu Serif"], size=18)  # Use 'Calibri' if available
mpl.rc(
    "axes",
    titlesize=20,
    labelsize=18,
    linewidth=1.2,
    grid=True,
    prop_cycle=mpl.cycler('color', ['#0072B2', '#D55E00', '#009E73']),
)
mpl.rc("axes.spines", top=False, right=False)
mpl.rc("xtick", labelsize=16)
mpl.rc("ytick", labelsize=16)
mpl.rc("lines", linewidth=2, markersize=8)
mpl.rc("grid", linestyle='--', linewidth=0.6, alpha=0.6)
mpl.rc("legend", frameon=True, fontsize=16, title_fontsize=16)

# === Plotting ===
fig, ax = plt.subplots(figsize=(5, 3))

# Trajectory order of the three model states. The connector lines must follow
# it; sorting on the "state" string alone is alphabetical (Original, Relearned,
# Unlearned) and would join the points in the wrong sequence.
state_order = {"Original": 0, "Unlearned": 1, "Relearned": 2}

# Gray lines connecting the states of each layer
for layer in df["layer"].unique():
    sub = df[df["layer"] == layer].sort_values(
        "state", key=lambda col: col.map(state_order)
    )
    ax.plot(sub["shift"], sub["principle"], color="gray", linewidth=1, alpha=0.5, zorder=1)

# Colored markers
markers = {"Original": "o", "Unlearned": "^", "Relearned": "s"}
for state in df["state"].unique():
    sub = df[df["state"] == state]
    ax.scatter(sub["shift"], sub["principle"], marker=markers[state],
               label=state, edgecolors="black", zorder=2)

# # Layer number labels (optional)
# for _, r in df.iterrows():
#     ax.text(r["shift"], r["principle"], str(int(r["layer"])),
#             fontsize=12, ha="center", va="center", color="black")

# Axis and title
ax.set_xlabel("(PC1 Δ)")
ax.set_ylabel("(PC2)")
ax.set_title("PCA Shift", pad=12)

# Axis limits: a 5% margin on the left/bottom and extra headroom on the
# right (150% of the x range) and top (10% of the y range). Each axis's
# margins are derived from that axis's own data range.
x_min, x_max = df["shift"].min(), df["shift"].max()
y_min, y_max = df["principle"].min(), df["principle"].max()
x_span, y_span = x_max - x_min, y_max - y_min
x_pad = 1.5 * x_span
y_pad = 0.1 * y_span
ax.set_xlim(x_min - 0.05 * x_span, x_max + x_pad)
ax.set_ylim(y_min - 0.05 * y_span, y_max + y_pad)

# Legend styling and placement
ax.legend(
    loc="upper center",
    frameon=False,
    bbox_to_anchor=(0.7, 1.2),
    fancybox=True
)

plt.tight_layout()

# Save as PDF (high-quality vector graphic)
save_path = f"Analyse_Figure/{data_type}/{base}_{relearn_data_type}/PCA_{test_data_type}"
os.makedirs(save_path, exist_ok=True)
plt.savefig(f"{save_path}/Shift_lr{lr}_{unlearning_algorithm}_{phase}.pdf", dpi=300, bbox_inches='tight')
plt.show()
