from copy import deepcopy
import math
from pathlib import Path
import re
import torch


# Shrink-and-perturb (SnP) defaults, taken from the official PyTorch implementation
# (https://github.com/JordanAsh/warm_start; also used in arXiv:2206.10011 and arXiv:2307.15621).
DEFAULT_SHRINKING_FACTOR = 0.4  # default multiplier applied to the warm-started weights
DEFAULT_PERTURBATION_FACTOR = 0.1  # default std of the Gaussian noise added after shrinking


def split_architecture(keys):
    """Group state-dict keys by the architectural layer they belong to.

    NOTE: Designed to handle `saws.model.GPT_Scales` only for now.
    TODO: Extend for more architecture classes; can we generalize?

    Keys under a numbered transformer block (``transformer.h.<i>.*``) are grouped
    under the block index as a string (e.g. ``"0"``). Keys of ``lm_head``,
    ``transformer.ln_f`` and ``transformer.wte`` are grouped under that prefix.
    Keys matching none of these are dropped.

    Args:
        keys: An iterable of parameter names (state-dict keys).

    Returns:
        A dict mapping group name -> list of keys in that group, in input order.
    """
    key_regex = re.compile(
        r"^(lm_head|transformer\.ln_f|transformer\.wte|transformer\.h\.(\d+)\.(.+))\.([a-z0-9_]+)$"
    )
    grouped = {}

    for name in keys:
        found = key_regex.match(name)
        if found is None:
            continue
        block_name, block_index = found.group(1), found.group(2)
        # Numbered blocks group by index; top-level modules group by their prefix.
        bucket = block_index if block_index else block_name
        grouped.setdefault(bucket, []).append(name)

    return grouped


def _check_warmstart_tensors(base: torch.Tensor, target: torch.Tensor) -> None:
    """Checks if the base and target tensors have the same shape.

    Args:
        base: The base tensor.
        target: The target tensor.
    """
    assert all(t >= b for t, b in zip(target.shape, base.shape)), \
        "The target tensor must be greater or equal to the base tensor for every dimension."


def _check_warmstart_dicts(base: dict, target: dict) -> None:
    """Checks if the base and target dictionaries have the same keys.

    Args:
        base: The base dictionary.
        target: The target dictionary.
    """
    assert sorted(base.keys()) == sorted(target.keys()), \
        "The keys of the base and target models must match."


def _is_active_layer(layer: str, active_layer: str) -> bool:
    """Checks if the layer is the active layer.

    Args:
        layer: The layer to check.
        active_layer: The active layer.

    Returns:
        Whether the layer is the active layer.
    """
    if active_layer is None:
        return True
    # if active_layer is specified and not None
    if active_layer == "input" and "wte" in layer.lower():
        return True
    if active_layer == "hidden" and (".h." in layer.lower() or "ln_f" in layer.lower()):
        return True
    if active_layer == "readout" and "lm_head" in layer.lower():
        return True
    if active_layer == "embeddings" and ("lm_head" in layer.lower() or "wte" in layer.lower()):
        return True
    return False


def _target_with_base_masked(
    base: torch.Size,
    target: torch.Size,
    mask_base: bool=False
) -> torch.Tensor:
    """Returns a masked target tensor with the base tensor shapes set to 0 and rest to 1.
    
    Args:
        base: torch.Size, The base tensor shape.    
        target: torch.Size, The target tensor shape.
        mask_base: bool, Whether to mask the base tensor before applying shrink-and-perturb.

    Returns:
        torch.Tensor: The masked target tensor containing 0s and 1s.
    """
    _mask = torch.ones(target)  # .to(target.device)
    if not mask_base:
        return _mask
    if len(base) == 1:
        _mask[:base[0]] = 0
    elif len(base) == 2:
        _mask[:base[0], :base[1]] = 0
    else:
        raise NotImplementedError("Masking for tensors with more than 2 dimensions is not yet supported.")
    return _mask


def _get_shrinking_factor(
    base: torch.Size,
    target: torch.Size,
    layer: str,
    active_layer: str = None,
) -> float:
    """Return a layer-wise shrinking factor derived from how much a tensor grows.

    A growth ratio ``x`` is mapped through
    ``DEFAULT_SHRINKING_FACTOR - exp(-x + (1 - 0.9163))``. Because ``exp(-0.9163)`` is
    about 0.4, with the default factor of 0.4 this is about ``0.4 * (1 - exp(1 - x))``.
    The result is close to 0 at ``x == 1`` (no growth) and approaches 0.4 as ``x`` grows.
    NOTE(review): ``0.9163`` is hard-coded (about ``-ln(0.4)``). The "close to 0 at
    x == 1" property only holds while ``DEFAULT_SHRINKING_FACTOR`` is 0.4.

    How ``x`` is chosen:
        * layer outside ``active_layer``: no shrinking, 1.0 is returned;
        * 1-D tensors: ``target[0] / base[0]``;
        * 2-D input/readout (``wte`` / ``lm_head``) weights: the ratio along the single
          dimension that changed, or 1 when the shape is unchanged;
        * 2-D hidden weights (``.h.`` blocks): ``sqrt(target[-1] / base[-1])``.

    Args:
        base: Shape of the base tensor.
        target: Shape of the target tensor.
        layer: Parameter name (state-dict key) used to classify the tensor.
        active_layer: Layer group being warm-started; ``None`` means all layers.

    Returns:
        The shrinking factor for this tensor.

    Raises:
        ValueError: If a 2-D input/readout weight changed along both dimensions.
        NotImplementedError: For tensor ranks or layer kinds not listed above.
    """

    def shrinking_factor_fn(x):
        return DEFAULT_SHRINKING_FACTOR - math.exp(-x + (1 - 0.9163))

    if not _is_active_layer(layer, active_layer):
        return 1.0

    if len(base) == 1:
        # 1-D tensors in any layer group
        scaling_factor = target[0] / base[0]
        return shrinking_factor_fn(scaling_factor)

    if len(base) == 2 and (
        _is_active_layer(layer, "input") or _is_active_layer(layer, "readout")
    ):
        # 2-D embedding and unembedding layers: only one dimension is expected to grow
        rows_match = base[0] == target[0]
        cols_match = base[1] == target[1]
        if rows_match and cols_match:
            # Unchanged shape means no growth, i.e. a ratio of 1 (as for equal-sized 1-D tensors).
            scaling_factor = 1.0
        elif rows_match:
            scaling_factor = target[1] / base[1]
        elif cols_match:
            scaling_factor = target[0] / base[0]
        else:
            raise ValueError("Input and readout layers must have one dimension matching.")
        return shrinking_factor_fn(scaling_factor)

    if len(base) == 2 and _is_active_layer(layer, "hidden"):
        # 2-D hidden weights (attention + MLP projections): use sqrt of the last-dim growth
        scaling_factor = target[-1] / base[-1]
        return shrinking_factor_fn(math.sqrt(scaling_factor))

    raise NotImplementedError(
        f"Layer-wise shrinking is not supported for {layer!r} with {len(base)}-D shape "
        f"{tuple(base)}; only 1-D tensors and 2-D input/readout/hidden weights are handled."
    )


def _apply_shrinking(
    target_layer: str,
    target_weight: torch.Tensor,
    shrinking_factor: float,
    active_layer: str = None,
    base_mask: torch.Tensor | None = None,
) -> torch.Tensor:
    """Multiply ``target_weight`` by ``shrinking_factor`` if ``target_layer`` is active.

    If ``active_layer`` is None, every layer is shrunk. Otherwise only layers in that
    group are shrunk, and all others are returned unchanged.

    Args:
        target_layer: Parameter name (state-dict key) of the tensor being shrunk.
        target_weight: The tensor to shrink.
        shrinking_factor: The multiplier to apply.
        active_layer: The layer group to apply the shrinking to (None means all).
        base_mask: Optional 0/1 mask, broadcastable to ``target_weight``. The factor
            applies only where the mask is non-zero; entries where it is 0 keep their
            value.

    Returns:
        The shrunk tensor (a new tensor), or ``target_weight`` itself if the layer is
        inactive.
    """
    if not _is_active_layer(target_layer, active_layer):
        return target_weight

    shrinking_scaling_matrix = torch.ones_like(target_weight) * shrinking_factor
    if base_mask is not None:
        shrinking_scaling_matrix = shrinking_scaling_matrix * base_mask
        # Factor 1 where the mask is 0, so those (base-region) weights are not shrunk.
        shrinking_scaling_matrix[base_mask == 0] = 1

    return target_weight * shrinking_scaling_matrix


def _pad_tensors_to_zeros(base: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Zero-pad ``base`` up to the shape of ``target``.

    ``base`` is placed in the leading (top-left) corner of a zero tensor. That zero tensor
    is created with ``torch.zeros_like(target)``, so it has target's dtype and device.
    ``target``'s values are never read. When the shapes already match, an independent
    copy of ``base`` is returned.

    Args:
        base: The tensor to be padded.
        target: The tensor whose shape (and dtype/device, when padding) to match.

    Returns:
        A new tensor that never shares storage with ``base``.

    Raises:
        AssertionError: If the ranks differ or ``target`` is smaller than ``base``.
    """
    if len(base.shape) != len(target.shape):
        raise AssertionError("Tensors must have the same number of dimensions.")
    _check_warmstart_tensors(base, target)

    if base.shape == target.shape:
        # Copy rather than return `base` itself, which would alias the base model's storage
        # into the warm-started result.
        return base.clone()

    padded = torch.zeros_like(target)
    # n-d equivalent of padded[:base.shape[0], :base.shape[1]] = base
    padded[tuple(slice(0, dim) for dim in base.shape)] = base
    return padded


def _clone_tensor(base, target):
    """
    Resize a tensor `base` to match shape of `target` by repeating `base` along all axes.

    Slice to the required `target` shape if shapes are not exact multiples.
    
    Args:
        base (torch.Tensor): The base tensor to be resized.
        target (torch.Tensor): The target tensor whose shape the base tensor should match.
    
    Returns:
        torch.Tensor: The resized tensor.
    """
    # Calculate the repeat factors for each dimension
    repeat_factors = [math.ceil(t / b) for t, b in zip(target.shape, base.shape)]
    
    # Repeat the base tensor
    resized = base.repeat(repeat_factors)
    
    # Slice the resized tensor to match the target shape
    return resized[tuple(slice(0, dim) for dim in target.shape)]


def _clone_tensor_mirror(base, target):
    """
    Resize a tensor `base` to match shape of `target` by repeating `base` along all axes,
    with flipping along axis 1 for column-wise expansion and axis 0 for row-wise expansion.

    Slice to the required `target` shape if shapes are not exact multiples.
    
    Args:
        base (torch.Tensor): The base tensor to be resized.
        target (torch.Tensor): The target tensor whose shape the base tensor should match.
    
    Returns:
        torch.Tensor: The resized tensor.
    """
    # Initialize the current tensor as the base
    # _target = deepcopy(base)
    _target = _ddp_proof_deepcopy(base)

    # Handle row-wise expansion (axis 0)
    if target.shape[0] > base.shape[0]:
        num_row_repeats = math.ceil(target.shape[0] / base.shape[0])
        row_tiles = []
        for i in range(num_row_repeats):
            if i % 2 == 0:
                row_tiles.append(_target)
            else:
                row_tiles.append(_target.flip(dims=[0]))
        _target = torch.cat(row_tiles, dim=0)

    # Handle column-wise expansion (axis 1)
    if len(target.shape) > 1 and target.shape[1] > base.shape[1]:
        num_col_repeats = math.ceil(target.shape[1] / base.shape[1])
        col_tiles = []
        for i in range(num_col_repeats):
            if i % 2 == 0:
                col_tiles.append(_target)
            else:
                col_tiles.append(_target.flip(dims=[1]))
        _target = torch.cat(col_tiles, dim=1)
    
    # Slice the resized tensor to match the target shape
    return _target[tuple(slice(0, dim) for dim in target.shape)]


def _ddp_proof_deepcopy(T: torch.Tensor) -> dict[str, torch.Tensor]:
    torch.cuda.empty_cache()
    if isinstance(T, torch.nn.Module):
        _target = {k: v.clone() for k, v in T.state_dict().items()}
    else:
        _target = {k: v.clone() for k, v in T.items()}
    return _target


def clone_base_to_target(
    base: torch.nn.Module | dict,
    target: torch.nn.Module | dict,
    mirror: bool = False,
    retain_type: bool = True,
    **kwargs,
) -> torch.nn.Module | dict:
    """Tile every base tensor up to the shape of the matching target tensor.

    Only the target's keys and tensor shapes are used; every value is replaced.

    Args:
        base: The base model or state dict.
        target: The target model or state dict. A dict target is not mutated.
        mirror: Tile with alternating mirrored copies (`_clone_tensor_mirror`) instead of
            plain repetition (`_clone_tensor`).
        retain_type: If True and ``target`` is a Module, load the result into it and
            return the module. Otherwise return the new state dict.

    Returns:
        The warm-started target module, or the new state dict.
    """
    if isinstance(base, torch.nn.Module):
        base = base.state_dict()

    if isinstance(target, torch.nn.Module):
        # `state_dict()` already yields a fresh mapping; replacing its entries does not touch
        # the module until `load_state_dict` below.
        _target = target.state_dict()
    else:
        # Shallow copy is enough: every value is replaced below, so cloning the target
        # tensors (as done previously) only wasted a model-sized allocation.
        _target = dict(target)

    _check_warmstart_dicts(base, _target)

    resize = _clone_tensor_mirror if mirror else _clone_tensor
    for k in _target.keys():
        # Clone the base tensor to match the target tensor shape
        _target[k] = resize(base[k], _target[k])

    if retain_type and isinstance(target, torch.nn.Module):
        target.load_state_dict(_target)
        return target
    return _target


def pad_zeros_model(
    base: torch.nn.Module | dict,
    target: torch.nn.Module | dict,
    retain_type: bool = True,
    active_layer: str = None,
    **kwargs,
) -> torch.nn.Module | dict:
    """Zero-pad each base tensor up to the shape of the matching target tensor.

    Args:
        base: The base model or state dict.
        target: The target model or state dict. It is copied, not mutated.
        retain_type: If True and ``target`` is a Module, load the result into it and
            return the module. Otherwise return the padded state dict.
        active_layer: Only layers in this group are padded from ``base``; other layers
            keep a copy of the target's own values. ``None`` means all layers.

    Returns:
        The warm-started target module, or the padded state dict.
    """
    base_sd = base.state_dict() if isinstance(base, torch.nn.Module) else base

    # Explicit per-tensor copies instead of `deepcopy()`, which affects serialization under DDP.
    padded = _ddp_proof_deepcopy(target)
    _check_warmstart_dicts(base_sd, padded)

    for name in padded:
        if _is_active_layer(name, active_layer):
            padded[name] = _pad_tensors_to_zeros(base_sd[name], padded[name])

    if not retain_type:
        return padded
    if isinstance(target, torch.nn.Module):
        target.load_state_dict(padded)
        return target
    return padded


def _sample_and_fill(
    base: torch.nn.modules.module.Module | dict,
    target: torch.nn.modules.module.Module | dict,
) -> torch.nn.modules.module.Module | dict:    
    """Fills the target tensor or module with samples from the base tensor or module.

    The function first duplicates available columns from the base to complete a matrix 
    with base.shape[0] rows filled.
    Then, it samples rows from the available rows and columns in the base.

    Args:
        base (torch.nn.modules.module.Module | dict): The tensor or module to sample from. 
            It should have shape (N, M).
        target (torch.nn.modules.module.Module | dict): The tensor or module to fill with samples. 
            It should have shape (P, Q).

    Returns:
        torch.nn.modules.module.Module | dict: The target tensor or module, 
            filled with samples from the base.

    Note:
        The function modifies the target in-place for the column sampling, 
            but returns a new tensor or module for the row sampling.
    """
    assert len(base.shape) <= 2 and len(target.shape) <=2, \
        "This function supports only up to 2-dimensional tensors!"

    # sampling for columns across available rows and columns 
    # i.e., duplicating available columns to complete a matrix with base.shape[0] rows filled
    reshape_flag = False
    if len(base.shape) == 1:
        reshape_flag = True
        base = base.reshape(1, base.shape[0])
        target = target.reshape(1, target.shape[0])

    idxs = torch.randint(0, base.shape[1], (target.shape[1],))
    idxs[:base.shape[1]] = torch.arange(0, base.shape[1])
    target[:base.shape[0], :] = target[:base.shape[0], idxs]

    # sampling for rows across available rows and columns
    idxs = torch.randint(0, base.shape[0], (target.shape[0],))
    idxs[:base.shape[0]] = torch.arange(0, base.shape[0])
    target = target[idxs, :]

    if reshape_flag:
        target = target[0, :]

    return target


def slice_from_base_to_target(
    base: torch.nn.Module | dict,
    target: torch.nn.Module | dict,
    retain_type: bool = True,
    active_layer: str = None,
    **kwargs,
) -> torch.nn.Module | dict:
    """Warm-start ``target`` by zero-padding base tensors and filling the rest by sampling.

    Each active tensor is zero-padded to the target shape (`_pad_tensors_to_zeros`). The
    padded region is then filled with sampled base columns/rows (`_sample_and_fill`).

    Args:
        base: The base model or state dict.
        target: The target model or state dict.
        retain_type: If True, a Module target gets the result loaded into it, and a dict
            target is updated in place; the target is returned. If False, a new state
            dict is returned and a dict ``target`` is left untouched.
        active_layer: Only layers in this group are warm-started; ``None`` means all.

    Returns:
        The warm-started target (``retain_type=True``) or a new state dict.
    """
    if isinstance(base, torch.nn.Module):
        base = base.state_dict()

    if isinstance(target, torch.nn.Module):
        target_state_dict = target.state_dict()
    elif retain_type:
        # Existing behavior: with retain_type=True a dict target is updated in place.
        target_state_dict = target
    else:
        # The caller asked for a separate state dict; don't overwrite theirs (callers such as
        # shrink-and-perturb may still need the original target tensors afterwards).
        target_state_dict = dict(target)

    _check_warmstart_dicts(base, target_state_dict)

    for k in target_state_dict.keys():
        if not _is_active_layer(k, active_layer):
            continue
        # Pad zeros first, then fill the padded region by sampling base columns/rows
        padded = _pad_tensors_to_zeros(base[k], target_state_dict[k])
        target_state_dict[k] = _sample_and_fill(base[k], padded)

    if not retain_type:
        return target_state_dict
    if isinstance(target, torch.nn.Module):
        target.load_state_dict(target_state_dict)
        return target
    # Previously nothing was returned here (implicit None).
    return target_state_dict


def shrink_and_perturb(
    base: torch.nn.Module | dict,
    target: torch.nn.Module | dict,
    warm_type: str = "slice",
    shrinking_factor: float | str = DEFAULT_SHRINKING_FACTOR,
    perturbation_sigma: float = DEFAULT_PERTURBATION_FACTOR,
    mup_init: bool = False,
    retain_type: bool = True,
    active_layer: str = None,
    mask_base: bool = False,
    **kwargs,
) -> torch.nn.Module | dict:
    """Warm-start the target model from the base, then shrink and perturb it.

    By default, the function applies the shrink-and-perturb (SnP) method.
    The `base` model is first transferred into the target's shapes according to
    `warm_type`. Each tensor is then scaled by the shrinking factor, and a perturbation is
    added.
    The defaults of `shrinking_factor=0.4` and `perturbation_sigma=0.1` are set as per the default
    values from the official PyTorch implementation of SNP: https://github.com/JordanAsh/warm_start
    Also used in the papers: https://arxiv.org/abs/2206.10011 and https://arxiv.org/abs/2307.15621

    Args:
        base (torch.nn.Module | dict): The base model or state_dict.
        target (torch.nn.Module | dict): The target model or state_dict.
        warm_type (str): How base weights are transferred: "zeros", "slice" or "clone".
            For "clone", pass `mirror=True` in kwargs to use mirrored tiling.
        shrinking_factor (float | str): The factor to shrink the weights by, or
            "layer-wise" to derive a per-tensor factor from how much each tensor grew.
        perturbation_sigma (float): Std of the Gaussian noise perturbation (unused when
            `mup_init` is True).
        mup_init (bool): Use the target's own tensors, taken before warm-starting and
            assumed μP-initialized, as the perturbation instead of Gaussian noise.
        retain_type (bool): If True and `target` is a Module, load the result into it and
            return it; otherwise return the state dict.
        active_layer (str): The layer group to warm-start; None means all layers.
            Can be "input", "hidden", "readout" or "embeddings".
        mask_base (bool): If True, the region covered by the base tensor is neither
            shrunk nor perturbed; only the newly added region is.

    Returns:
        torch.nn.Module | dict: The warmstarted target model.

    Raises:
        ValueError: For an unknown `warm_type` or an invalid string `shrinking_factor`.
    """
    if isinstance(shrinking_factor, str) and shrinking_factor != "layer-wise":
        raise ValueError(
            f"Invalid str shrinking_factor: {shrinking_factor}. Try a float or \"layer-wise\"."
        )

    # `base[k]` is indexed below, so a Module base must be turned into its state dict.
    if isinstance(base, torch.nn.Module):
        base = base.state_dict()

    # Snapshot the target's original tensors before warm-starting. A warm-start routine may
    # rewrite a dict target's entries, and the μP perturbation must use the original values.
    # Taking the snapshot once also avoids rebuilding `target.state_dict()` for every key.
    if isinstance(target, torch.nn.Module):
        target_init = target.state_dict()
    else:
        target_init = dict(target)

    if warm_type == "zeros":
        _target = pad_zeros_model(base, target, retain_type=False, active_layer=active_layer)
    elif warm_type == "slice":
        _target = slice_from_base_to_target(base, target, retain_type=False, active_layer=active_layer)
    elif warm_type == "clone":
        _target = clone_base_to_target(
            base, target, mirror=kwargs.get("mirror", False), retain_type=False
        )
    else:
        raise ValueError(f"Invalid warm_type: {warm_type}.")

    if isinstance(_target, torch.nn.Module):
        target_state_dict = _target.state_dict()
    else:
        target_state_dict = _target

    for k, v in target_state_dict.items():
        # 0/1 mask that is 0 on the base region when `mask_base`, all ones otherwise
        _base_mask = _target_with_base_masked(
            base[k].shape, v.shape, mask_base
        ).to(v.device)
        _shrinking_factor = (
            _get_shrinking_factor(base[k].shape, v.shape, k, active_layer)
            if isinstance(shrinking_factor, str)
            else shrinking_factor
        )
        v_shrunk = _apply_shrinking(k, v, _shrinking_factor, active_layer, _base_mask)

        if mup_init:
            # Use the μP-initialized weights as the perturbation, assuming the `target` is from μP
            perturbation = target_init[k].to(v.device)
        else:
            perturbation = torch.normal(0, perturbation_sigma, v.shape).to(v.device)
        target_state_dict[k] = v_shrunk + (perturbation * _base_mask)

    if retain_type and isinstance(target, torch.nn.Module):
        target.load_state_dict(target_state_dict)
        return target
    return target_state_dict


def zero_centered_mup_perturb(
    base: torch.nn.Module | dict,
    target: torch.nn.Module | dict,
    warm_type: str = "zeros",
    retain_type: bool = True,
    **kwargs,
) -> torch.nn.Module | dict:
    """Standardize the warm-started weights instead of shrinking them in SnP.

    Each warm-started tensor ``v`` becomes ``(v - mean(v)) / std(v)`` plus the target's own
    (assumed μP-initialized) tensor. Tensors whose std is 0 or undefined (constant or
    single-element tensors) are mapped to zeros instead of NaN/inf.

    Assumes the following:
        shrinking_factor = 1
        perturbation_sigma = 0
        mup_init = True

    Args:
        base (torch.nn.Module | dict): The base model or state_dict.
        target (torch.nn.Module | dict): The target model or state_dict.
        warm_type (str): How base weights are transferred: "zeros", "slice" or "clone".
            For "clone", pass `mirror=True` in kwargs to use mirrored tiling.
        retain_type (bool): If True and `target` is a Module, load the result into it and
            return it; otherwise return the state dict.

    Returns:
        torch.nn.Module | dict: The warmstarted target model.

    Raises:
        ValueError: For an unknown `warm_type`.
    """
    # Snapshot the target's original tensors before warm-starting. A warm-start routine may
    # rewrite a dict target's entries, and the μP perturbation must use the original values.
    if isinstance(target, torch.nn.Module):
        target_init = target.state_dict()
    else:
        target_init = dict(target)

    if warm_type == "zeros":
        _target = pad_zeros_model(base, target, retain_type=False)
    elif warm_type == "slice":
        _target = slice_from_base_to_target(base, target, retain_type=False)
    elif warm_type == "clone":
        _target = clone_base_to_target(
            base, target, mirror=kwargs.get("mirror", False), retain_type=False
        )
    else:
        raise ValueError(f"Invalid warm_type: {warm_type}.")

    if isinstance(_target, torch.nn.Module):
        target_state_dict = _target.state_dict()
    else:
        target_state_dict = _target

    for k, v in target_state_dict.items():
        # NOTE: key difference from vanilla-SnP: standardize instead of shrinking
        centered = v - v.mean()
        std = v.std()
        if std > 0:
            v_modified = centered / std
        else:
            # std is 0 (constant tensor) or NaN (single element): dividing would give NaN/inf.
            v_modified = torch.zeros_like(centered)

        # Use the μP-initialized weights as the perturbation, assuming the `target` is from μP
        perturbation = target_init[k].to(v.device)
        target_state_dict[k] = v_modified + perturbation

    if retain_type and isinstance(target, torch.nn.Module):
        target.load_state_dict(target_state_dict)
        return target
    return target_state_dict
