import math
import numpy as np
import torch
# from copy import deepcopy
from functools import partial

from saws.warmstart.common import get_gpt_grouping
from saws.warmstart.hyperclone_utils import (
    global_params_map,
    attn_params_map,
    mlp_params_map,
    clone_and_replace_matrix, 
    clone_and_replace_vector, 
    clone_positional_embedding
)
from saws.warmstart.utils import _ddp_proof_deepcopy


def _clone_params_for_layer(
    layer_idx: int,
    param_op_map: dict[str, callable],
    base_grouping: dict,
    target_grouping: dict,
    base_sd: dict,
    target_sd: dict,
    snr_db=None,
    normalize=False,
    scale_q_if_doubled: bool=False,  # <--- small hack for Q-scaling
):
    """
    Clones multiple parameter groups (e.g. "attn_in", "mlp_out") for 
    a single 'layer_idx'. This references param_op_map, which maps 
    each group (like "attn_in") to a clone function 
    (like clone_and_replace_matrix or clone_and_replace_vector).

    In the paper's expansions (Case 1, 2, 3):
      - Typically, attention layers are "Case 3" if both input & output 
        dimension are expanded.
      - If only the output dimension is expanded, that's "Case 2" 
        (common for 'attn_out').
      - If only the input dimension is expanded, that's "Case 1" 
        (less common in attention, but might be in some architectures).

    scale_q_if_doubled:
      A quick hack that, if True and we detect "attn_in" with a 
      [3*d_in, d_in] shape, attempts to scale Q by 1/sqrt(2) if 
      the dimension was doubled from d->2d. This partially addresses 
      the paper's note that if the head dimension doubles, 
      new attention outputs ~ sqrt(2)*old, so we scale Q to keep 
      old outputs consistent.

    Differences (TODOs):
      - We do not do the paper's exact random offset expansions (eta1, eta2). 
        Instead, if snr_db is set, partial random noise is added by 
        clone_and_replace_matrix -> add_noise(...) in a 2D block-based pattern, 
        which is not identical to the paper's approach but is approximate.
      - We do not replicate heads if num_heads_multiplier>1. 
        That is left for future expansions.
    """
    for param_group, clone_func in param_op_map.items():
        if param_group in base_grouping and param_group in target_grouping:
            if (layer_idx in base_grouping[param_group] 
                and layer_idx in target_grouping[param_group]):
                b_key = base_grouping[param_group][layer_idx]
                t_key = target_grouping[param_group][layer_idx]
                if b_key in base_sd and t_key in target_sd:

                    # Perform the normal expansions
                    clone_func(
                        base_sd=base_sd,
                        target_sd=target_sd,
                        bkey=b_key,
                        tkey=t_key,
                        snr_db=snr_db,
                        normalize=normalize
                    )

                    # Optional hack: if "attn_in" and scale_q_if_doubled is True,
                    # we try to detect a [3*d_in, d_in] shape with dimension doubling,
                    # then slice out Q block and multiply by 1/sqrt(2).
                    if param_group == "attn_in" and scale_q_if_doubled:
                        w = target_sd[t_key]
                        d_out, d_in = w.shape
                        # If d_out = 3*d_in => Q,K,V each is [d_in, d_in]. 
                        # If it's 2x expansions from an old dimension 'd', 
                        # we guess old was d_in/2 => new is d_in => scale Q by 1/sqrt(2).
                        if d_out == 3 * d_in:
                            # print(f"[hyperclone] Scaling Q for layer {layer_idx} by 1/sqrt(2).")
                            with torch.no_grad():
                                q_block = w[:d_in, :]  # Q portion
                                q_block *= 1.0 / math.sqrt(2)
                            w[:d_in, :] = q_block
                            target_sd[t_key] = w


def _clone_global_params(
    param_op_map: dict[str, callable],
    base_grouping: dict,
    target_grouping: dict,
    base_sd: dict,
    target_sd: dict,
    snr_db: float = None,
    normalize: bool = False,
):
    """
    Clones "global" GPT parameters (embedding, unembedding, final LN) 
    by iterating param_op_map, which might be:
      {"embedding": clone_and_replace_matrix, 
       "unembedding": clone_and_replace_matrix}.

    Typically:
      - "embedding" => "Case 2" expansions (only output dim repeated).
      - "unembedding" => "Case 1" expansions (only input dim repeated).

    If snr_db is set, partial noise is injected in expansions 
    for 2D expansions (like weights). For 1D expansions (like bias vectors), 
    we skip or do a simpler approach if add_noise is set up for that.

    Differences from the paper:
      - We do not do explicit random offsets (eta), 
        we do partial snr_db-based noise if 2D expansions are found.
    """
    for group_key, clone_func in param_op_map.items():
        if (group_key in base_grouping 
            and group_key in target_grouping
            and base_grouping[group_key] 
            and target_grouping[group_key]):
            for b_key, t_key in zip(base_grouping[group_key], 
                                    target_grouping[group_key]):
                if b_key in base_sd and t_key in target_sd:
                    clone_func(
                        base_sd=base_sd,
                        target_sd=target_sd,
                        bkey=b_key,
                        tkey=t_key,
                        snr_db=snr_db,
                        normalize=normalize
                    )


def hyperclone(
    base: torch.nn.Module | dict,
    target: torch.nn.Module | dict,
    retain_type: bool = True,
    normalize: bool = False,
    snr_db: float = 0.01,
    **kwargs,
) -> torch.nn.Module | dict:
    """
    Initialise a wider GPT-style model from a smaller one by cloning weights.

    Parameters are grouped with ``get_gpt_grouping`` and replicated with the
    clone helpers. In the paper's terminology (per the original author):

      - Case 1 (input only expansions): e.g. unembedding
      - Case 2 (output only expansions): e.g. embedding
      - Case 3 (both in+out expansions): e.g. hidden MLP / attention

    Steps, in execution order:
      1) Convert ``base`` to a state dict if it is a module; take a working
         copy of ``target`` via ``_ddp_proof_deepcopy``.
      2) Build parameter groupings for both sides with ``get_gpt_grouping``.
      3) Require ``target n_embed / base n_embed`` to be an integer.
      4) Clone the global groups listed in ``global_params_map``.
      5) Clone the final layer norm and, if both sides have
         "transformer.wpe.weight", the positional embedding.
      6) For each layer index shared by both models: clone the block layer
         norms, then the attention groups, then the MLP groups.
      7) If the scaling factor is > 1, multiply the first unembedding tensor
         by ``1 / scaling_factor``.
      8) If ``retain_type`` and ``target`` is an ``nn.Module``, load the new
         state dict into ``target`` and return it; otherwise return the dict.

    Differences from the paper (as noted by the original author):
      - Heads are not replicated when the number of heads grows.
      - The paper's random offsets are replaced by whatever ``snr_db``-based
        noise the clone helpers apply.
      - Q is rescaled by 1/sqrt(2) only when the embedding width exactly
        doubles; other ratios get no Q rescaling.

    Args:
        base (nn.Module|dict): The smaller model or its state dict.
        target (nn.Module|dict): The bigger model or its state dict. It is
            only mutated in step 8 (via ``load_state_dict``).
        retain_type (bool): If True and ``target`` is a module, return the
            module with the cloned weights loaded; else return the dict.
        normalize (bool): Forwarded to the MLP clones only. The global,
            layer-norm and attention clones are always called with
            ``normalize=False``.
        snr_db (float): Forwarded to every clone helper. Defaults to 0.01.
            NOTE(review): how ``None`` or ``0`` are treated is decided by
            the helpers — confirm there before relying on "no noise".
        **kwargs: Accepted for call-site compatibility and ignored.

    Returns:
        ``target`` with the new weights loaded (module case with
        ``retain_type=True``), otherwise the updated state dict.

    Raises:
        ValueError: If the embedding widths are not an integer multiple.
    """
    # Convert to state dict if needed
    if isinstance(base, torch.nn.Module):
        base = base.state_dict()
    # Work on a copy so `target` itself is untouched until the final load.
    # NOTE(review): `_target` is indexed by key below, so the helper is
    # assumed to return a state-dict-like mapping — confirm.
    _target = _ddp_proof_deepcopy(target)

    # Gather param groupings
    base_grouping = get_gpt_grouping(base)
    target_grouping = get_gpt_grouping(_target)

    # Debug aid: `plot_hypercloned_heatmaps(base, _target, base_grouping,
    # target_grouping)` from `.common` can be called here.

    # We assume the model dimension is scaled, check factor
    base_n_embed = base_grouping["metadata"]["n_embed"]
    target_n_embed = target_grouping["metadata"]["n_embed"]
    scaling_factor = target_n_embed / base_n_embed
    if not scaling_factor.is_integer():
        raise ValueError(
            f"Scaling factor must be integer. Found {target_n_embed}/{base_n_embed} => {scaling_factor}"
        )
    scaling_factor = int(scaling_factor)

    def _clone_norm_param(b_key, t_key):
        """Clone one layer-norm tensor, choosing the helper by its rank."""
        if b_key not in base or t_key not in _target:
            return
        if len(base[b_key].shape) == 2:
            clone_and_replace_matrix(
                base, _target, b_key, t_key, snr_db=snr_db, normalize=False
            )
        else:
            clone_and_replace_vector(
                base, _target, b_key, t_key, snr_db=snr_db
            )

    # 1) Clone top-level items (embedding => "Case 2", unembedding => "Case 1")
    _clone_global_params(
        param_op_map=global_params_map,
        base_grouping=base_grouping,
        target_grouping=target_grouping,
        base_sd=base,
        target_sd=_target,
        snr_db=snr_db,
        normalize=False,
    )

    # 2) Clone final LN if present
    final_ln_base = base_grouping["layernorms"]["final"]
    final_ln_target = target_grouping["layernorms"]["final"]
    for b_ln, t_ln in zip(final_ln_base, final_ln_target):
        _clone_norm_param(b_ln, t_ln)

    # 3) Clone positional embedding when both sides have one
    if "transformer.wpe.weight" in base and "transformer.wpe.weight" in _target:
        clone_positional_embedding(base, _target, snr_db=snr_db)

    # 4) For each layer => LN, attention, MLP expansions.
    # Only layers present in both models are cloned.
    n_layers_base = base_grouping["metadata"]["n_layers"]
    n_layers_target = target_grouping["metadata"]["n_layers"]
    max_layer = min(n_layers_base, n_layers_target)

    # The 1/sqrt(2) Q rescale is only meaningful for an exact doubling of the
    # width; for any other ratio (including 1) it would distort attention.
    width_doubled = scaling_factor == 2

    norm_keys = ("norm_1_weight", "norm_1_bias", "norm_2_weight", "norm_2_bias")
    for layer_idx in range(max_layer):
        # LN in each block
        norm_block_b = base_grouping["layernorms"]["blocks"][layer_idx]
        norm_block_t = target_grouping["layernorms"]["blocks"][layer_idx]
        for norm_key in norm_keys:
            if norm_key in norm_block_b and norm_key in norm_block_t:
                _clone_norm_param(norm_block_b[norm_key], norm_block_t[norm_key])

        # Attention expansions, plus the optional Q-scaling hack
        _clone_params_for_layer(
            layer_idx=layer_idx,
            param_op_map=attn_params_map,
            base_grouping=base_grouping,
            target_grouping=target_grouping,
            base_sd=base,
            target_sd=_target,
            snr_db=snr_db,
            normalize=False,
            scale_q_if_doubled=width_doubled,
        )

        # MLP expansions (the only place the caller's `normalize` is used)
        _clone_params_for_layer(
            layer_idx=layer_idx,
            param_op_map=mlp_params_map,
            base_grouping=base_grouping,
            target_grouping=target_grouping,
            base_sd=base,
            target_sd=_target,
            snr_db=snr_db,
            normalize=normalize,
        )

    # 5) If embedding dimension scaled => multiply unembedding by 1/scaling_factor
    if scaling_factor > 1:
        if target_grouping["unembedding"]:
            unemb_key = target_grouping["unembedding"][0]
            if unemb_key in _target:
                _target[unemb_key] *= 1.0 / scaling_factor

    # Differences from the paper we DO NOT implement:
    #  - explicit multi-head duplication if num_heads_multiplier > 1
    #  - exact random offsets for expansions => only snr_db-based partial offsets
    #  - advanced Q-scaling for any ratio other than doubling

    # Final step: load updated dict if retain_type and target is a module
    if retain_type and isinstance(target, torch.nn.Module):
        target.load_state_dict(_target)
        return target
    else:
        return _target
