import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessor, LogitsProcessorList
from collections import defaultdict
import re
import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
from matplotlib.lines import Line2D
import pickle

# Helper functions for saliency analysis

# --- Custom logits processor: allow only certain words ---
class OnlyAllowCertainLogitsProcessor(LogitsProcessor):
    """Logits processor that restricts generation to a fixed set of token ids.

    Every vocabulary entry whose id is not in ``allowed_ids`` has its score
    replaced by ``-inf``; the scores of the allowed ids pass through unchanged.
    """

    def __init__(self, allowed_ids):
        # Token ids (column indices into the score matrix) that stay selectable.
        self.allowed_ids = allowed_ids

    def __call__(self,  input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return a copy of ``scores`` with every non-allowed column set to ``-inf``.

        ``input_ids`` is unused; it is only part of the signature because the
        processor interface passes it. ``scores`` is indexed as
        ``[batch, vocab]``.
        """
        keep = self.allowed_ids
        # Start from an all -inf tensor shaped like the scores, then copy the
        # permitted columns back in.
        restricted = torch.full_like(scores, float("-inf"))
        restricted[:, keep] = scores[:, keep]
        return restricted

def print_top_saliencies_old(saliency_dict, title, top_k=50):
    """Print and return the ``top_k`` highest-scoring parameters.

    Entries whose name contains "layernorm" (case-insensitive) are dropped
    before ranking. Values are compared and formatted directly, so they must
    be orderable and accept the ``.4e`` format spec.

    Returns the list of printed lines (header first).
    """
    ranked = sorted(
        (
            (param, value)
            for param, value in saliency_dict.items()
            if "layernorm" not in param.lower()
        ),
        key=lambda pair: pair[1],
        reverse=True,
    )

    lines = [f"\n--- Top {top_k} parameters for {title} ---"]
    lines.extend(f"{param}: {value:.4e}" for param, value in ranked[:top_k])

    print("\n".join(lines))
    return lines

def print_top_saliencies(saliency_dict, title, top_k=50):
    """
    Print the ``top_k`` most salient parameters and return the printed lines.

    Each value may be a plain number or a ``torch.Tensor``; tensors are
    collapsed to their mean so every parameter is ranked by one scalar.
    Parameters whose name contains "layernorm" (case-insensitive) are skipped.
    """

    def _as_scalar(value):
        # Tensors hold per-element saliency; the mean is the ranking score.
        if isinstance(value, torch.Tensor):
            return value.mean().item()
        return float(value)

    scores = {
        param: _as_scalar(value)
        for param, value in saliency_dict.items()
        if "layernorm" not in param.lower()
    }

    # Highest score first.
    ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)

    header = f"\n--- Top {top_k} parameters for {title} ---"
    lines = [header] + [f"{param}: {value:.4e}" for param, value in ranked[:top_k]]

    print("\n".join(lines))
    return lines

# def save_saliency_npz(saliency_dict, npz_path):
#     np.savez_compressed(npz_path, **{
#         f"{drug.replace(' ', '_')}": {
#             k: v.cpu().numpy() if isinstance(v, torch.Tensor) else np.array(v)
#             for k, v in sal.items()
#         }
#         for drug, sal in saliency_dict.items()
#     })

def merge_saliency_dicts(saliency_dicts, model=None):
    """
    Average per-prompt saliency maps and reduce them to scalar summaries.

    Args:
        saliency_dicts: {prompt -> {param_name -> tensor or number}}.
        model: optional model; when given, per-head and per-MLP scores are
            computed for every prompt via
            ``aggregate_saliency_to_heads_and_mlps`` and averaged over prompts.

    Returns:
        - avg_sal: {param_name -> scalar mean of the across-prompt average}.
          A prompt that lacks a parameter contributes zero to that average.
        - avg_sal_per_prompt: {prompt -> {param_name -> scalar mean}}
        - avg_sal_per_head: {(layer_idx, head_idx) -> mean score over prompts}
          (empty when ``model`` is None)
        - avg_sal_per_mlp: {layer_idx -> mean score over prompts}
          (empty when ``model`` is None)
    """

    def _scalar(value):
        # Tensors collapse to their mean; anything else is coerced to float.
        return value.mean().item() if isinstance(value, torch.Tensor) else float(value)

    def _as_float_tensor(value):
        if isinstance(value, torch.Tensor):
            return value.float()
        return torch.tensor(value, dtype=torch.float32)

    per_prompt_maps = list(saliency_dicts.values())
    param_names = {name for inner in per_prompt_maps for name in inner}

    # Element-wise mean over prompts for every parameter seen anywhere.
    merged = {}
    for name in param_names:
        contributions = [
            _as_float_tensor(inner.get(name, torch.tensor(0.0)))
            for inner in per_prompt_maps
        ]
        merged[name] = sum(contributions) / len(contributions)

    avg_sal = {name: _scalar(value) for name, value in merged.items()}

    avg_sal_per_prompt = {}
    for prompt, inner in saliency_dicts.items():
        avg_sal_per_prompt[prompt] = {
            name: _scalar(value) for name, value in inner.items()
        }

    avg_sal_per_head = {}
    avg_sal_per_mlp = {}
    if model is None:
        return avg_sal, avg_sal_per_prompt, avg_sal_per_head, avg_sal_per_mlp

    # Head/MLP scores are computed per prompt and then averaged, rather than
    # being derived from the already-merged tensors.
    head_samples = defaultdict(list)
    mlp_samples = defaultdict(list)
    for inner in per_prompt_maps:
        head_part, mlp_part, _ = aggregate_saliency_to_heads_and_mlps(inner, model)
        for head_key, score in head_part.items():
            head_samples[head_key].append(float(score))
        for layer_key, score in mlp_part.items():
            mlp_samples[layer_key].append(float(score))

    avg_sal_per_head = {
        head_key: float(np.mean(samples))
        for head_key, samples in head_samples.items()
        if samples
    }
    avg_sal_per_mlp = {
        layer_key: float(np.mean(samples))
        for layer_key, samples in mlp_samples.items()
        if samples
    }

    return avg_sal, avg_sal_per_prompt, avg_sal_per_head, avg_sal_per_mlp

# def save_saliency_npz(npz_path, saliency_dicts, avg_sal):
#     """
#     Save saliency results to a compressed .npz file.

#     Each saliency dict is reduced to param → scalar (mean value).
#     This greatly reduces file size and allows posthoc aggregation (e.g., std).
#     """
#     data_to_save = {}

#     for name, sal in saliency_dicts.items():
#         safe_name = str(name).replace(" ", "_")
#         scalar_dict = {
#             k: v.mean().item() if isinstance(v, torch.Tensor) else float(np.mean(v))
#             for k, v in sal.items()
#         }
#         data_to_save[safe_name] = scalar_dict

#     # if avg_sal is not None:
#     data_to_save["avg_sal"] = avg_sal  # already scalars

#     # Flatten for .npz format
#     flattened = {}
#     for outer_key, inner_dict in data_to_save.items():
#         for param_name, scalar in inner_dict.items():
#             flat_key = f"{outer_key}__{param_name}"
#             flattened[flat_key] = scalar

#     np.savez(npz_path, **flattened)
#     print(f"Saved scalar saliency data to {npz_path}")

# save it using pickle instead of npz
def save_saliency(file_path, avg_sal, avg_sal_per_prompt, avg_sal_per_head, avg_sal_per_mlp):
    """Pickle the four saliency summaries into ``file_path``.

    The file holds a single dict with the keys "avg_sal",
    "avg_sal_per_prompt", "avg_sal_per_head" and "avg_sal_per_mlp". An
    existing file at ``file_path`` is overwritten.
    """
    payload = dict(
        avg_sal=avg_sal,
        avg_sal_per_prompt=avg_sal_per_prompt,
        avg_sal_per_head=avg_sal_per_head,
        avg_sal_per_mlp=avg_sal_per_mlp,
    )

    with open(file_path, "wb") as handle:
        pickle.dump(payload, handle)


# load_saliency_npz(npz_path)
# def load_saliency_npz(npz_path):
#     data = np.load(npz_path)
#     print(data.keys())
#     return data, data["avg_sal"]

def load_saliency(file_path):
    """Load a pickled saliency file and return ``(per_prompt, avg_sal)``.

    Two on-disk layouts are accepted:

    * the layout written by ``save_saliency`` in this module, whose
      per-prompt data lives under the key "avg_sal_per_prompt";
    * the legacy layout that stored it under "saliency_dicts".

    Previously only the legacy key was read, so every file produced by
    ``save_saliency`` failed with ``KeyError``.

    Args:
        file_path: path of the pickle file to read.

    Returns:
        A 2-tuple ``(per_prompt, avg_sal)`` where ``per_prompt`` is the value
        stored under "saliency_dicts" (legacy) or "avg_sal_per_prompt", and
        ``avg_sal`` is the value stored under "avg_sal".

    Raises:
        KeyError: if the file contains neither per-prompt key, or no "avg_sal".

    SECURITY: ``pickle.load`` can execute arbitrary code. Only open files
    that were produced by trusted code.
    """
    with open(file_path, "rb") as f:
        data = pickle.load(f)

    # Prefer the legacy key when present so old files load exactly as before.
    if "saliency_dicts" in data:
        per_prompt = data["saliency_dicts"]
    else:
        per_prompt = data["avg_sal_per_prompt"]

    return per_prompt, data["avg_sal"]

# Both patterns are applied with ``match`` and are therefore anchored at the
# start of the parameter name: only names that begin with "model.layers." are
# attributed. NOTE(review): parameter names carrying an extra prefix (e.g. a
# multimodal wrapper) will not match and are silently ignored - confirm this
# is intended for the models in use.
_ATTN_PROJ_RE = re.compile(r"model\.layers\.(\d+)\.self_attn\.(q|k|v|o)_proj\.weight")
_MLP_PROJ_RE = re.compile(r"model\.layers\.(\d+)\.mlp\.(gate_proj|up_proj|down_proj)\.weight")


def _config_value(model, *names):
    """Return the first non-None config attribute among ``names``, else None.

    ``config.text_config`` is consulted first (some multimodal configs keep
    the language-model hyper-parameters there), then the top-level config.
    """
    cfg = getattr(model, "config", None)
    if cfg is None:
        # Only touch ``model.model`` when ``model.config`` is really missing;
        # raises AttributeError if neither exists.
        cfg = getattr(model, "model").config
    text_cfg = getattr(cfg, "text_config", None)
    for candidate in (text_cfg, cfg):
        if candidate is None:
            continue
        for attr in names:
            value = getattr(candidate, attr, None)
            if value is not None:
                return value
    return None


def _format_saliency_report(head_scores, mlp_scores, num_layers, num_heads):
    """Build the human-readable per-layer report lines."""
    lines = ['Aggregating information per attention head and MLP ...']
    for layer in range(num_layers):
        lines.append(f"Layer {layer}:")
        # Heads with an exactly-zero score are omitted from the listing.
        shown = [
            (h, head_scores.get((layer, h), 0.0))
            for h in range(num_heads)
            if head_scores.get((layer, h), 0.0) != 0.0
        ]
        if shown:
            lines.extend(f"  Head {h}: {score:.6f}" for h, score in shown)
        else:
            lines.append("  (no non-zero head saliencies)")
        # The MLP line is always present, even when its score is zero.
        lines.append(f"  MLP score: {mlp_scores.get(layer, 0.0):.6f}")
    return lines


def aggregate_saliency_to_heads_and_mlps(
    grad_tensors: dict[str, "float | torch.Tensor"],
    model: torch.nn.Module,
) -> tuple[dict[tuple[int, int], float], dict[int, float], list[str]]:
    """
    Attribute per-parameter saliency to attention heads and MLP blocks.

    Args:
        grad_tensors: mapping from parameter names (e.g.
            "model.layers.3.self_attn.v_proj.weight") to either a scalar
            saliency (number or 0-dim tensor) or a per-element torch.Tensor
            of the same shape as that parameter. Names that are not
            parameters of ``model`` are skipped.
        model: object exposing ``named_parameters()`` and a ``config`` (or
            ``model.config``) with ``num_attention_heads``/``num_heads`` and
            ``num_hidden_layers``, optionally nested under ``text_config``.

    Returns:
        head_scores: {(layer_idx, head_idx) -> summed normalized L1 saliency}
        mlp_scores:  {layer_idx -> summed normalized L1 saliency}
        lines:       the report lines that were printed

    Raises:
        ValueError: if the number of attention heads cannot be determined.
        AttributeError: if no config or no ``num_hidden_layers`` is found.

    Per-head attribution:
        * q_proj: rows (output features) are grouped by head.
        * o_proj: columns (input features) are grouped by head, because the
          output projection consumes the concatenated head outputs. Earlier
          revisions split o_proj by rows, which mixed all heads together.
        * k_proj / v_proj: the matrix-wide mean |saliency| is added to every
          head of the layer (an approximation that does not depend on the
          number of key/value heads).
    """
    param_shapes = {name: param.shape for name, param in model.named_parameters()}

    num_heads = _config_value(model, "num_attention_heads", "num_heads")
    if not num_heads:
        raise ValueError("could not determine the number of attention heads from the model config")
    num_layers = _config_value(model, "num_hidden_layers")
    if num_layers is None:
        raise AttributeError("model config does not define num_hidden_layers")

    head_scores = defaultdict(float)
    mlp_scores = defaultdict(float)

    for name, sal in grad_tensors.items():
        shape = param_shapes.get(name)
        if shape is None:
            continue  # not a parameter of this model

        attn = _ATTN_PROJ_RE.match(name)
        mlp = None if attn else _MLP_PROJ_RE.match(name)
        if attn is None and mlp is None:
            # layernorms, embeddings, lm_head, ...: nothing to attribute, so
            # do not build a full-size tensor for them.
            continue

        # Expand scalars (including 0-dim tensors) to the parameter's shape.
        if isinstance(sal, torch.Tensor) and sal.dim() > 0:
            grad = sal
        else:
            grad = torch.full(tuple(shape), float(sal), dtype=torch.float32)
        magnitude = grad.abs()

        if mlp is not None:
            layer = int(mlp.group(1))
            mlp_scores[layer] += magnitude.sum().item() / grad.numel()
            continue

        layer = int(attn.group(1))
        proj = attn.group(2)
        out_dim, in_dim = grad.shape

        if proj == "q":
            # Heads are stacked along the output (row) axis.
            head_dim = out_dim // num_heads
            per_head = magnitude.reshape(num_heads, head_dim, in_dim).sum(dim=(1, 2))
            per_head = per_head / (head_dim * in_dim)
        elif proj == "o":
            # Heads are stacked along the input (column) axis.
            head_dim = in_dim // num_heads
            per_head = magnitude.reshape(out_dim, num_heads, head_dim).sum(dim=(0, 2))
            per_head = per_head / (out_dim * head_dim)
        else:
            # k/v: spread the matrix-wide average over every head.
            shared = magnitude.sum().item() / (out_dim * in_dim)
            for h in range(num_heads):
                head_scores[(layer, h)] += shared
            continue

        for h, score in enumerate(per_head.tolist()):
            head_scores[(layer, h)] += score

    lines = _format_saliency_report(head_scores, mlp_scores, num_layers, num_heads)
    print("\n".join(lines))
    return dict(head_scores), dict(mlp_scores), lines
