import random
import re
from typing import Any, Callable, Dict, Union

import matplotlib.pyplot as plt
import seaborn as sns
import torch


def get_gpt_grouping(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
    """
    Group a GPT-style state_dict into named parameter families plus metadata.

    Groups:
      - embedding        ("transformer.wte.weight")
      - unembedding      ("lm_head.weight")
      - layernorms       (final "transformer.ln_f.*" + per-block "norm_1.*"/"norm_2.*")
      - attention in/out ("transformer.h.{i}.attn.attn.*" / "...attn.proj.*")
      - MLP in/out       ("transformer.h.{i}.mlp.fc.*" / "...mlp.proj.*")

    Metadata:
      - block_size: number of rows of "transformer.wpe.weight" if present, else None
      - n_layers:   highest block index found under "transformer.h.{i}." plus one
      - n_embed:    dim 1 of "attn.attn.weight" of the lowest-index block that has
                    one (assumed layout [3*d_model, d_model]), else None
      - n_head, head_size: not inferred yet; always None

    "transformer.wpe.weight" counts as recognized but is not placed in any group.

    Example return structure:
    {
      "metadata": {
         "block_size": int or None,
         "n_layers": int,
         "n_embed": int or None,
         "n_head": int or None,
         "head_size": int or None,
      },
      "embedding": ["transformer.wte.weight"],
      "unembedding": ["lm_head.weight"],
      "layernorms": {
         "final": ["transformer.ln_f.weight", "transformer.ln_f.bias"],
         "blocks": {
            0: {
               "norm_1_weight": "...",
               "norm_1_bias": "...",
               "norm_2_weight": "...",
               "norm_2_bias": "...",
            },
            1: { ... },
            ...
         }
      },
      "attn_in": { 0: "...", 1: "...", ... },
      "attn_in_bias": {...},
      "attn_out": {...},
      "attn_out_bias": {...},
      "mlp_in": {...},
      "mlp_in_bias": {...},
      "mlp_out": {...},
      "mlp_out_bias": {...},
    }

    Raises:
        ValueError: if no "transformer.h.{i}." key exists.
        AssertionError: if some key of ``state_dict`` is not assigned to any group.
            Raised explicitly (not via ``assert``) so ``python -O`` does not disable it.
    """
    grouping: Dict[str, Any] = {
        "metadata": {
            "block_size": None,  # context length, if a positional embedding exists
            "n_layers": None,  # number of transformer.h.X blocks
            "n_embed": None,  # model width, e.g. 768, 1024...
            "n_head": None,
            "head_size": None,
        },
        "embedding": [],
        "unembedding": [],
        "layernorms": {"final": [], "blocks": {}},
        "attn_in": {},
        "attn_in_bias": {},
        "attn_out": {},
        "attn_out_bias": {},
        "mlp_in": {},
        "mlp_in_bias": {},
        "mlp_out": {},
        "mlp_out_bias": {},
    }

    # Per-block parameter-name suffixes (after "transformer.h.{i}.") and where they go.
    norm_suffixes = {
        "norm_1_weight": "norm_1.weight",
        "norm_1_bias": "norm_1.bias",
        "norm_2_weight": "norm_2.weight",
        "norm_2_bias": "norm_2.bias",
    }
    block_suffixes = {
        "attn_in": "attn.attn.weight",
        "attn_in_bias": "attn.attn.bias",
        "attn_out": "attn.proj.weight",
        "attn_out_bias": "attn.proj.bias",
        "mlp_in": "mlp.fc.weight",
        "mlp_in_bias": "mlp.fc.bias",
        "mlp_out": "mlp.proj.weight",
        "mlp_out_bias": "mlp.proj.bias",
    }

    # ----------------------------------------------------------------
    # Top-level keys: embedding, unembedding, final norm, positional embedding
    # ----------------------------------------------------------------
    for group, name in (("embedding", "transformer.wte.weight"), ("unembedding", "lm_head.weight")):
        if name in state_dict:
            grouping[group].append(name)

    for name in ("transformer.ln_f.weight", "transformer.ln_f.bias"):
        if name in state_dict:
            grouping["layernorms"]["final"].append(name)

    pos_emb_key = "transformer.wpe.weight"
    if pos_emb_key in state_dict:
        # One row per position, so the row count is the context length.
        grouping["metadata"]["block_size"] = state_dict[pos_emb_key].shape[0]

    # ----------------------------------------------------------------
    # Detect block indices from "transformer.h.{i}."
    # ----------------------------------------------------------------
    layer_pattern = re.compile(r"transformer\.h\.(\d+)\.")
    layer_indices = set()
    for name in state_dict:
        match_ = layer_pattern.search(name)
        if match_:
            layer_indices.add(int(match_.group(1)))

    if not layer_indices:
        raise ValueError("No 'transformer.h.X' layers found - not a typical GPT model?")

    sorted_layers = sorted(layer_indices)
    grouping["metadata"]["n_layers"] = sorted_layers[-1] + 1  # h.0..h.7 => 8 layers

    # ----------------------------------------------------------------
    # Guess n_embed from the first block that has a fused QKV weight.
    # Assumed layout: [3*d_model, d_model] => d_model = shape[1].
    # ----------------------------------------------------------------
    for i in sorted_layers:
        attn_in_key = f"transformer.h.{i}.{block_suffixes['attn_in']}"
        if attn_in_key in state_dict:
            grouping["metadata"]["n_embed"] = state_dict[attn_in_key].shape[1]
            break

    # TODO: infer n_head & head_size (not derivable from the fused QKV shape alone);
    # both stay None for now.

    # ----------------------------------------------------------------
    # Per-block norm / attention / MLP keys
    # ----------------------------------------------------------------
    for i in sorted_layers:
        prefix = f"transformer.h.{i}."
        grouping["layernorms"]["blocks"][i] = {
            slot: prefix + suffix
            for slot, suffix in norm_suffixes.items()
            if prefix + suffix in state_dict
        }
        for group, suffix in block_suffixes.items():
            if prefix + suffix in state_dict:
                grouping[group][i] = prefix + suffix

    # ----------------------------------------------------------------
    # Every key must have landed somewhere (or be the positional embedding)
    # ----------------------------------------------------------------
    recognized_keys = (
        set(grouping["embedding"])
        | set(grouping["unembedding"])
        | set(grouping["layernorms"]["final"])
    )
    for block in grouping["layernorms"]["blocks"].values():
        recognized_keys.update(block.values())
    for group in block_suffixes:
        recognized_keys.update(grouping[group].values())
    if pos_emb_key in state_dict:
        recognized_keys.add(pos_emb_key)

    leftover = set(state_dict.keys()) - recognized_keys
    if leftover:
        # Same exception type as the former `assert`, but survives `python -O`.
        raise AssertionError(f"Keys not grouped: {leftover}")

    return grouping


def select_layer_key(layer: int, key: str, ckpt) -> tuple[str | list[str]]:
    """
    Returns parameter keys for different parts of the model (e.g., embeddings,
    layer norms, MLP weights/biases, attention weights/biases) for a given layer index.

    Args:
        layer (int): Which layer index to query.
        key (str): One of ["embedding", "output", "norm_weight", "norm_bias",
                           "mlp_weight", "mlp_bias", "attn_weight", "attn_bias"].
        ckpt (dict): A dictionary of all parameter keys (e.g., from a state_dict).

    Returns:
        tuple[str | list[str]]: The specific key(s) that match the `layer` and `key` criteria.
    """
    # Example logic for collecting all attn weights that contain 'attn.attn' in their name
    do_collect = lambda x: "transformer.h." in x and "attn.attn" in x and "weight" in x
    attn_keys = [k for k in ckpt.keys() if do_collect(k)]

    # Single embedding and output layer keys (used once in the entire model)
    embedding_key = [k for k in ckpt.keys() if "wte" in k][0]
    output_key = [k for k in ckpt.keys() if "head" in k][0]

    # Potential norm and MLP keys
    norm_weight_keys = [k for k in ckpt.keys() if ("norm" in k or "ln" in k) and "weight" in k]
    norm_bias_keys = [k for k in ckpt.keys() if ("norm" in k or "ln" in k) and "bias" in k]
    mlp_weight_keys = [
        k
        for k in ckpt.keys()
        if ("mlp" in k and "weight" in k) or ("attn.proj" in k and "weight" in k)
    ]
    mlp_bias_keys = [
        k for k in ckpt.keys() if ("mlp" in k and "bias" in k) or ("attn.proj" in k and "bias" in k)
    ]

    match key:
        case "embedding":
            return embedding_key

        case "output":
            return output_key

        case "norm_weight" | "norm_bias":
            # Grab all norm weights or biases that match this layer index
            _keys = norm_weight_keys if key == "norm_weight" else norm_bias_keys
            extract = [_s for _s in _keys if f"h.{layer - 1}." in _s]
            # If nothing found, we might be at the final layer norm
            if len(extract) == 0 and layer - 1 == (len(_keys) - 1) // 2:
                extract = [_keys[-1]]
            if len(extract) == 0:
                raise ValueError(f"Layer {layer} not found in {key} keys")
            return extract

        case "mlp_weight" | "mlp_bias":
            _keys = mlp_weight_keys if key == "mlp_weight" else mlp_bias_keys
            extract = [_s for _s in _keys if f"h.{layer - 1}." in _s]
            if len(extract) == 0:
                raise ValueError(f"Layer {layer} not found in {key} keys")
            return extract

        case "attn_weight" | "attn_bias":
            # `layer`-th attention weight or bias from attn_keys
            # Example logic: index them by (layer - 1), then replace 'weight' -> 'bias'
            select_layer = lambda x: attn_keys[int(x) - 1]
            select_layer_bias = lambda x: select_layer(x).replace("weight", "bias")
            extract = [select_layer_bias(layer)] if key == "attn_bias" else [select_layer(layer)]
            return extract

        case _:
            raise ValueError(f"Unknown key: {key} or layer: {layer}")


def gather_layerwise_keys(layer: int, ckpt: dict) -> dict:
    """
    Collect every parameter-key selection for one layer into a single dict.

    Calls ``select_layer_key(layer, part, ckpt)`` once per part, in this order:
    embedding, output, norm_weight, norm_bias, mlp_weight, mlp_bias,
    attn_weight, attn_bias. Any error raised by one of those lookups propagates
    unchanged. The global embedding/output keys are included for every
    ``layer``, not only the first.

    Args:
        layer (int): Layer index forwarded to ``select_layer_key``.
        ckpt (dict): Mapping whose keys are parameter names (e.g. a state_dict).

    Returns:
        dict with:
          - "biases": concatenation of the "norm_weight", "norm_bias",
            "mlp_bias" and "attn_bias" selections. Despite the name, this
            includes norm weights.
          - "embeddings": ``[embedding_key, output_key]``
          - "weights_hidden": the "mlp_weight" selection
          - "weights_attn": the "attn_weight" selection
    """
    part_names = (
        "embedding",
        "output",
        "norm_weight",
        "norm_bias",
        "mlp_weight",
        "mlp_bias",
        "attn_weight",
        "attn_bias",
    )
    # The comprehension preserves the lookup order above, so failures surface
    # from the same call as before.
    found = {part: select_layer_key(layer, part, ckpt) for part in part_names}

    one_dim_like = (
        found["norm_weight"]
        + found["norm_bias"]
        + found["mlp_bias"]
        + found["attn_bias"]
    )
    return {
        "biases": one_dim_like,
        "embeddings": [found["embedding"], found["output"]],
        "weights_hidden": found["mlp_weight"],
        "weights_attn": found["attn_weight"],
    }


def apply_operation(
    base: dict,
    target: dict,
    keys: list[str],
    operation: Callable[..., torch.Tensor],
    **op_kwargs,
):
    """
    Apply ``operation`` to each key in ``keys`` and write the result into ``target``.

    For every key present in both ``base`` and ``target`` this calls::

        operation(base_tensor=base[key], target_tensor=target[key], **op_kwargs)

    and stores the return value in ``target[key]``, mutating ``target`` in place.
    Keys missing from either dict are skipped and a warning is printed.

    Args:
        base: Mapping of parameter name -> tensor for the base model.
        target: Mapping of parameter name -> tensor for the target model.
            Updated in place.
        keys: Parameter names to process (e.g. 'transformer.h.0.mlp.fc.weight').
        operation: Callable invoked with keyword arguments ``base_tensor`` and
            ``target_tensor`` plus ``op_kwargs``; its return value replaces the
            target entry.
        **op_kwargs: Extra keyword arguments forwarded to every call. A
            ``target_tensor`` entry here is overridden by the per-key target.

    Returns:
        None.
    """
    for pkey in keys:
        if pkey not in base or pkey not in target:
            print(f"Warning: key {pkey} not found in base or target.")
            continue
        # Build fresh kwargs per key instead of mutating the shared op_kwargs dict.
        call_kwargs = {**op_kwargs, "target_tensor": target[pkey]}
        target[pkey] = operation(base_tensor=base[pkey], **call_kwargs)


def plot_hypercloned_heatmaps(
    base_sd: dict,
    target_sd: dict,
    base_grouping: dict,
    target_grouping: dict,
    figsize=(10, 8),
    num_samples=1,
    random_seed=42,
    save_path="hypercloned_heatmaps.png",
):
    """
    Save side-by-side heatmaps comparing base vs. target parameters.

    The parameter groups visualized are:

      - "unembedding" (Case 1)
      - "embedding"   (Case 2)
      - "attn_in"     (Case 3) => attention input weight
      - "mlp_in"      (Case 3) => feed-forward input weight

    For each group present in both groupings, up to ``num_samples`` entries are
    picked at random. Dict-valued groups are keyed by layer and matched by layer
    index; list-valued groups are matched by position. Each pick becomes one row
    of a 2-column figure: base on the left, target on the right. The figure is
    written to ``save_path``.

    Args:
        base_sd, target_sd (dict): Base and target state dicts. Selected tensors
            are passed to ``imshow``, so they are expected to be 2-D.
        base_grouping, target_grouping (dict): Groupings from get_gpt_grouping(...).
        figsize (tuple): Size of the entire figure.
        num_samples (int): How many random entries per group to visualize.
        random_seed (int): Seed for a local RNG, making picks reproducible
            without reseeding the global ``random`` module.
        save_path (str): Output image path.

    Raises:
        ValueError: if no parameter pair could be selected for plotting.
    """
    rng = random.Random(random_seed)

    # (group name in the grouping dicts, title used in the plot)
    param_candidates = [
        ("unembedding", "Unembedding (Case 1)"),
        ("embedding", "Embedding (Case 2)"),
        ("attn_in", "Attention in (Case 3)"),
        ("mlp_in", "MLP in (Case 3)"),
    ]

    plots_to_make = []  # (title, base_key, target_key)

    for group_name, group_title in param_candidates:
        if group_name not in base_grouping or group_name not in target_grouping:
            continue
        base_group = base_grouping[group_name]
        target_group = target_grouping[group_name]

        if isinstance(base_group, dict):
            # {layer_idx: param_key}: match base and target by layer index.
            layer_idxs = sorted(base_group)
            if not layer_idxs:
                continue
            for layer_idx in rng.sample(layer_idxs, min(num_samples, len(layer_idxs))):
                if layer_idx not in target_group:
                    continue
                plots_to_make.append(
                    (
                        f"{group_title} (layer {layer_idx})",
                        base_group[layer_idx],
                        target_group[layer_idx],
                    )
                )

        elif isinstance(base_group, list):
            # ["key", ...]: match base and target by position. Sample positions
            # directly; looking keys up by value misbehaves with duplicates.
            if not base_group:
                continue
            for idx in rng.sample(range(len(base_group)), min(num_samples, len(base_group))):
                if idx >= len(target_group):
                    continue
                plots_to_make.append((group_title, base_group[idx], target_group[idx]))

    if not plots_to_make:
        raise ValueError("No matching parameter groups found in base/target groupings to plot.")

    # squeeze=False always returns a 2-D (nrows, 2) array, so a single row
    # needs no special-casing.
    fig, axes = plt.subplots(len(plots_to_make), 2, figsize=figsize, squeeze=False)

    for (param_title, base_key, target_key), (ax_left, ax_right) in zip(plots_to_make, axes):
        # .float() first: numpy cannot represent bfloat16 tensors.
        wb_cpu = base_sd[base_key].detach().cpu().float().numpy()
        wt_cpu = target_sd[target_key].detach().cpu().float().numpy()

        im_left = ax_left.imshow(wb_cpu, cmap="coolwarm", aspect="auto")
        im_right = ax_right.imshow(wt_cpu, cmap="coolwarm", aspect="auto")

        ax_left.set_title(f"{param_title}\n(base: {base_key})")
        ax_right.set_title(f"{param_title}\n(target: {target_key})")

        fig.colorbar(im_left, ax=ax_left, fraction=0.046, pad=0.04)
        fig.colorbar(im_right, ax=ax_right, fraction=0.046, pad=0.04)

    fig.tight_layout()
    fig.savefig(save_path)
