import argparse
import os
import random
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer
)
from datasets import load_dataset

# === Command-line interface ===
parser = argparse.ArgumentParser(
    description="Compute and plot Fisher Information Matrix histograms one layer at a time"
)
# (flag, add_argument options) pairs, registered in the order they should
# appear in --help.
_cli_options = [
    ('--lr', dict(type=float, default=1e-6, help="Learning rate, e.g., 1e-6")),
    ('--model_name', dict(type=str, default="Qwen/Qwen2.5-7B",
                          help="Model name or local path")),
    ('--unlearning_algorithm', dict(type=str, default="GA",
                                    help="Unlearning algorithm identifier")),
    ('--type', dict(choices=['Text','Math'], default="Text",
                    help="Data type: Text or Math")),
    ('--phase', dict(type=str, default="N1_Request100",
                     help="Unlearning phase identifier")),
    ('--batch_size', dict(type=int, default=4,
                          help="Batch size for DataLoader")),
    ('--num_batches', dict(type=int, default=10,
                           help="Number of batches used to estimate FIM")),
    ('--relearn_data_type', dict(type=str, default="forget_set",
                                 help="forget set or retain set or unrelated")),
    ('--test_data_type', dict(type=str, default="forget_set",
                              help="forget set or retain set or unrelated")),
]
for _flag, _options in _cli_options:
    parser.add_argument(_flag, **_options)
args = parser.parse_args()

# Expose the parsed options under the module-level names used by the rest of
# the script.
lr, model_name, unlearning_algorithm = args.lr, args.model_name, args.unlearning_algorithm
data_type, phase = args.type, args.phase
batch_size, num_batches = args.batch_size, args.num_batches
relearn_data_type, test_data_type = args.relearn_data_type, args.test_data_type

# Last path component of the model name (the whole name when it has no '/').
base = model_name.rsplit('/', 1)[-1]

# The three checkpoints being compared: the original model, the model after
# unlearning, and the model after relearning. `lr` is interpolated with
# Python's default float formatting.
checkpoint_before   = model_name
checkpoint_after_un = f"Model/{data_type}/all_layer/{base}/lr{lr}_{unlearning_algorithm}_{phase}"
# Relearning on the forget set uses the bare algorithm name as its directory;
# any other relearn data type is appended to it.
_recovery_subdir = (
    unlearning_algorithm
    if relearn_data_type == "forget_set"
    else f"{unlearning_algorithm}_{relearn_data_type}"
)
checkpoint_after_re = f"Model/recovery/{data_type}/{base}/{_recovery_subdir}/lr{lr}_all_layers_forget_{phase}"

# === Reproducibility & device selection ===
# Seed Python's, NumPy's and PyTorch's generators with the same fixed value.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(42)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# === Work out how many layers the model has ===
# Read the count from the config first: `num_hidden_layers`, and if that is
# missing or falsy, `n_layer`.
config = AutoConfig.from_pretrained(checkpoint_before)
num_layers = getattr(config, "num_hidden_layers", None)
if not num_layers:
    num_layers = getattr(config, "n_layer", None)

if num_layers is None:
    # The config exposes neither attribute: load the model once, count its
    # `model.layers` entries, then drop it again.
    _probe = AutoModelForCausalLM.from_pretrained(
        checkpoint_before,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        use_flash_attention_2=True,
        device_map="auto",
    )
    _probe = _probe.to(device)
    num_layers = len(_probe.model.layers)
    del _probe
    torch.cuda.empty_cache()

# Only every third layer (0, 3, 6, ...) is analysed.
layers_to_analyze = list(range(num_layers))[::3]
print(f"Detected {num_layers} layers; will process each layer sequentially.")

# === Prepare data & tokenizer ===
# Builds `texts`, the 30 shuffled documents used to estimate the Fisher
# information, and loads the tokenizer of the original checkpoint.
if data_type == "Text":
    # (dataset config, split) pairs that make up each evaluation set, listed in
    # the order their texts are concatenated before shuffling.
    _text_sources = {
        "forget_set": [("arxiv", "forget"), ("github", "forget")],
        "retain_set": [("arxiv", "retain"), ("github", "retain")],
        "unrelated":  [("general", "retain")],
    }
    if test_data_type not in _text_sources:
        # Fail with a clear message instead of a NameError on `test_list`.
        raise ValueError(
            f"Unsupported --test_data_type {test_data_type!r}; "
            f"expected one of {sorted(_text_sources)}"
        )

    # Only the splits that feed the requested evaluation set are loaded. The
    # concatenated list is the same as before, so the seeded shuffle below
    # picks the same sample (assuming dataset loading does not consume
    # Python's `random` state).
    test_list = []
    for _config_name, _split in _text_sources[test_data_type]:
        _subset = load_dataset("llmunlearn/unlearn_dataset", name=_config_name, split=_split)
        test_list.extend(row['text'] for row in _subset)
    random.shuffle(test_list)
    texts = test_list[:30]
else:
    # Math: each text is "<problem> <solution> <answer>" from the train split.
    math_ds = load_dataset("AI-MO/NuminaMath-1.5")
    texts = [
        f"{p} {s} {a}"
        for p, s, a in zip(math_ds["train"]["problem"],
                           math_ds["train"]["solution"],
                           math_ds["train"]["answer"])
    ]
    random.shuffle(texts)
    texts = texts[:30]

tokenizer = AutoTokenizer.from_pretrained(checkpoint_before, use_fast=True)

class TextDataset(Dataset):
    """Texts tokenised once, padded/truncated to a fixed length, for causal-LM loss.

    The constructor calls ``tokenizer`` on the whole list of texts with
    ``padding="max_length"`` and ``truncation=True`` and stores the resulting
    ``input_ids`` and ``attention_mask`` tensors. ``labels`` is a separate copy
    of ``input_ids`` in which every position whose attention mask is 0 (i.e.
    padding) is replaced by -100, the index that PyTorch's cross-entropy loss
    ignores by default, so padding does not contribute to the loss or to its
    gradients.

    Args:
        texts: List of strings to tokenise.
        tokenizer: Callable returning a mapping with ``"input_ids"`` and
            ``"attention_mask"`` tensors when called with
            ``return_tensors="pt"``.
        max_length: Length every sequence is padded or truncated to.
    """

    def __init__(self, texts, tokenizer, max_length=128):
        enc = tokenizer(
            texts,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=max_length
        )
        self.input_ids      = enc["input_ids"]
        self.attention_mask = enc["attention_mask"]
        # masked_fill returns a new tensor, so labels no longer alias input_ids.
        self.labels         = self.input_ids.masked_fill(self.attention_mask == 0, -100)

    def __len__(self):
        """Return the number of tokenised texts."""
        return self.input_ids.size(0)

    def __getitem__(self, idx):
        """Return the ``input_ids``, ``attention_mask`` and ``labels`` of item ``idx``."""
        return {
            "input_ids":      self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "labels":         self.labels[idx],
        }

# Tokenise the sampled texts once; batches are served in stored order (no shuffling).
dataset = TextDataset(texts=texts, tokenizer=tokenizer)
loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False)

# === Diagonal Fisher Information Matrix Estimation ===
def compute_fim_diag(model, dataloader, num_batches, layer_key):
    """Estimate the diagonal of the empirical Fisher information for one layer.

    For each of the first ``num_batches`` batches the model's loss is
    back-propagated and the squared gradient of every selected parameter is
    accumulated; the accumulated sum is divided by the number of batches that
    were actually processed.

    Note: the gradient is taken of whatever batch-level loss the model returns,
    so with more than one example per batch each term is the square of the
    batch-level gradient, not an average of per-example squared gradients.

    Args:
        model: Module called as ``model(input_ids=..., attention_mask=...,
            labels=...)`` whose output exposes a ``.loss`` tensor.
        dataloader: Iterable of dicts of tensors with the keys ``input_ids``,
            ``attention_mask`` and ``labels``.
        num_batches: Maximum number of batches to use.
        layer_key: Dotted name selecting the parameters, e.g.
            ``"model.layers.3"``. It is matched against whole dot-separated
            components of the parameter name, so ``"model.layers.1"`` does not
            select ``"model.layers.10"``. Leading/trailing dots are ignored;
            an empty key selects every trainable parameter.

    Returns:
        A 1-D float32 NumPy array: the flattened per-parameter estimates,
        concatenated in ``model.named_parameters()`` order.

    Raises:
        ValueError: If no trainable parameter matches ``layer_key`` or no
            batch was processed.
    """
    key = layer_key.strip(".")
    # Padding both the key and the name with dots turns the substring test
    # into a whole-component match.
    needle = f".{key}."
    params = [
        p for name, p in model.named_parameters()
        if p.requires_grad and (not key or needle in f".{name}.")
    ]
    if not params:
        raise ValueError(f"No trainable parameters match layer_key={layer_key!r}")

    # Inputs are moved to the device holding the model's first parameter.
    input_device = next(model.parameters()).device
    # Accumulate in float32 so that sums of squared low-precision (e.g.
    # bfloat16) gradients are not rounded at every step.
    fim_acc = [torch.zeros_like(p, dtype=torch.float32) for p in params]

    processed = 0
    for batch in dataloader:
        if processed >= num_batches:
            break
        b = {k: v.to(input_device) for k, v in batch.items()}
        model.zero_grad()
        out = model(
            input_ids=b["input_ids"],
            attention_mask=b["attention_mask"],
            labels=b["labels"]
        )
        out.loss.backward()
        for acc, p in zip(fim_acc, params):
            # A parameter that did not take part in the forward pass has no
            # gradient; it contributes zero for this batch.
            if p.grad is not None:
                acc += p.grad.detach().float() ** 2
        processed += 1

    if processed == 0:
        raise ValueError("No batches were processed; cannot estimate the Fisher diagonal")

    return torch.cat([acc.reshape(-1) / processed for acc in fim_acc]).cpu().numpy()

# === Where the figures are written ===
out_dir = f"Analyse_Figure/{data_type}/{base}_{relearn_data_type}/Fisher_{test_data_type}"
os.makedirs(out_dir, exist_ok=True)

# === Global matplotlib style ===
# Font family and sizes, figure resolution, and a dashed, semi-transparent grid.
_plot_style = {
    'font.family': 'sans-serif',
    'font.sans-serif': ['DejaVu Serif'],
    'font.size': 18,
    'axes.titlesize': 20,
    'axes.labelsize': 16,
    'xtick.labelsize': 16,
    'ytick.labelsize': 16,
    'legend.fontsize': 16,
    'figure.dpi': 300,
    'axes.grid': True,
    'grid.linestyle': '--',
    'grid.linewidth': 0.6,
    'grid.alpha': 0.6,
}
mpl.rcParams.update(_plot_style)

# === Per-checkpoint colour and line style ===
# One row per curve: (tag, colour, line style).
_series_style = (
    ("Original",  "#d62728", "-"),
    ("Unlearned", "#1f77b4", "--"),
    ("Relearned", "#2ca02c", ":"),
)
plot_colors = {tag: color for tag, color, _ in _series_style}
linestyles  = {tag: dash for tag, _, dash in _series_style}

# === Plot histogram for each layer ===
# For every analysed layer, the diagonal Fisher estimate of that layer is
# computed for the original, unlearned and relearned checkpoints and the three
# distributions are drawn as step histograms on a shared log-scaled x-axis.
# Each layer reloads all three checkpoints; only one model is held at a time.
_series_order = ["Original", "Unlearned", "Relearned"]

for L in layers_to_analyze:
    print(f"--- Processing layer {L} ---")
    # The trailing dot stops plain substring matching from also selecting
    # layers whose index merely starts with the same digits (e.g. 1 vs 10).
    layer_key = f"model.layers.{L}."
    fim_vals = {}

    for tag, ckpt in [
        ("Original", checkpoint_before),
        ("Unlearned", checkpoint_after_un),
        ("Relearned", checkpoint_after_re)
    ]:
        print(f"  Loading {tag} model...")
        model = AutoModelForCausalLM.from_pretrained(
            ckpt,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            use_flash_attention_2=True,
            device_map="auto"
        ).to(device).eval()

        fim_vals[tag] = compute_fim_diag(model, loader, num_batches, layer_key)

        # Release the model before the next checkpoint is loaded.
        del model
        torch.cuda.empty_cache()

    # === Shared, log-spaced bins over the strictly positive values ===
    # Zeros cannot be placed on a log axis, so the bin range is taken from the
    # positive entries only; values outside the bin edges are not counted.
    x_min = min(
        float(np.min(fim_vals[tag], where=fim_vals[tag] > 0, initial=np.inf))
        for tag in _series_order
    )
    x_max = max(
        float(np.max(fim_vals[tag], where=fim_vals[tag] > 0, initial=0.0))
        for tag in _series_order
    )
    if not (np.isfinite(x_min) and np.isfinite(x_max) and x_max > 0):
        print(f"  Skipping layer {L}: no finite positive Fisher values to plot.")
        del fim_vals
        torch.cuda.empty_cache()
        continue
    if x_min == x_max:
        # A single distinct value: widen the range so the bins are not degenerate.
        x_min, x_max = x_min / 10.0, x_max * 10.0
    # 41 edges -> 40 bins that have equal width on the log axis.
    bin_edges = np.geomspace(x_min, x_max, 41)

    # === Plot histogram ===
    fig, ax = plt.subplots(figsize=(5, 3))
    ymax = 0
    for tag in _series_order:
        counts, _, _ = ax.hist(
            fim_vals[tag],
            bins=bin_edges,
            histtype='step',
            linewidth=2,
            linestyle=linestyles[tag],
            color=plot_colors[tag],
            label=tag
        )
        # Track the tallest bin across the three curves for the y-limit below.
        ymax = max(ymax, counts.max())

    ax.set_xscale('log')
    ax.set_xlabel("Fisher Diagonal Value (log scale)")
    ax.set_ylabel("Frequency")
    # The title shows a 1-based layer number; the file name below uses L as is.
    ax.set_title(f"FIM @ Layer {L+1}", pad=16)

    # Adjust y-axis (avoid legend crowding)
    ax.set_ylim(0, ymax * 1.2)

    # Adjust x-axis for log padding: a small margin on the left and a larger
    # one on the right.
    # NOTE(review): the 0.01 offset comes from the original code; for values
    # far below 0.01 it makes the span (and so both margins) almost zero.
    # Kept as is to preserve the figure layout - confirm whether intended.
    log_span = np.log10(x_max + 0.01) - np.log10(x_min + 0.01)
    x_pad = 1.5 * log_span
    x_small_pad = 0.1 * log_span
    ax.set_xlim(10**(np.log10(x_min) - x_small_pad), 10**(np.log10(x_max) + x_pad))

    # === Legend on top right ===
    ax.legend(
        loc="upper right",
        frameon=False,
        fancybox=True
    )

    fig.tight_layout()
    fig.savefig(f"{out_dir}/layer_{L}_lr{lr}_{unlearning_algorithm}_{phase}.pdf", dpi=300, bbox_inches='tight')
    plt.close(fig)
    del fim_vals
    torch.cuda.empty_cache()

print("All layers processed successfully.")
