import math
from functools import partial
from typing import Literal, Optional
from warnings import warn

import lightning as L
import mup
import torch
import torch.nn as nn
from lightning.fabric.strategies import FSDPStrategy
from litgpt.config import Config
from litgpt.model import GPT, Block, CausalSelfAttention, GptNeoxMLP, LLaMAMLP
from litgpt.pretrain import reset_parameters
from mup import MuReadout


class file_data_share:
    """Class-level scratch space for handing values between files and methods.

    All state lives on the class itself, so it is read and written without
    ever creating an instance.
    """

    # Values recorded by the attention modules in this file: one entry is
    # appended per attention call and the list grows until `clear_data()` runs.
    layer_wise_max_attn_weight: list = list()

    @staticmethod
    def clear_data() -> None:
        """Discard the recorded values by rebinding the attribute to a new, empty list."""
        # Rebinding (instead of `.clear()`) leaves any list a caller already
        # holds a reference to untouched.
        file_data_share.layer_wise_max_attn_weight = list()


class GPT_Scales(GPT):
    """LitGPT ``GPT`` variant prepared for muP.

    Set up according to the basic-usage instructions at
    https://github.com/microsoft/mup?tab=readme-ov-file#basic-usage

    After the parent constructor has run, the output head and the transformer
    stack are rebuilt: the head becomes a ``MuReadout`` when ``mup_init`` is
    set (a plain ``nn.Linear`` otherwise) and every block is a ``Block_Scales``.
    """

    def __init__(
        self, config: Config, mup_init: bool = False, share_embeddings: bool = False
    ) -> None:
        """Build the model.

        Args:
            config: LitGPT model configuration.
            mup_init: Use a zero-initialised ``MuReadout`` head and forward the
                flag to every block.
            share_embeddings: Tie the head weight to the token-embedding weight.
        """
        super().__init__(config)
        self.share_embeddings = share_embeddings

        # Output head: same positional arguments either way, only the class and
        # the extra muP keyword differ.
        head_kwargs = {"bias": config.lm_head_bias}
        if mup_init:
            head_cls = MuReadout
            head_kwargs["readout_zero_init"] = True
        else:
            head_cls = nn.Linear
        self.lm_head = head_cls(config.n_embd, config.padded_vocab_size, **head_kwargs)

        # Sub-modules are created in the order embedding -> blocks -> final norm.
        token_embedding = nn.Embedding(config.padded_vocab_size, config.n_embd)
        blocks = nn.ModuleList(Block_Scales(config, mup_init) for _ in range(config.n_layer))
        final_norm = config.norm_class(config.n_embd, eps=config.norm_eps)
        self.transformer = nn.ModuleDict(
            {"wte": token_embedding, "h": blocks, "ln_f": final_norm}
        )

        if self.share_embeddings:
            # Weight tying: head and embedding share one parameter tensor.
            self.lm_head.weight = self.transformer["wte"].weight

    def forward(self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the model on token ids ``idx`` and return the head output.

        Args:
            idx: Token ids; dimension 1 is taken as the sequence length.
            input_pos: Positions to read from the rope and mask caches. When
                ``None`` no explicit mask is handed to the blocks.

        Raises:
            ValueError: If the sequence is longer than ``max_seq_length``.
            TypeError: If ``input_pos`` is given but no mask cache exists.
        """
        T = idx.size(1)
        if T > self.max_seq_length:
            raise ValueError(
                f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}."
            )

        if input_pos is None:
            # Only deviation from the original LitGPT forward: the blocks get
            # `mask=None` instead of a mask taken from the cache.
            cos = self.cos[:T].clone()  # DDP requires clone() here
            sin = self.sin[:T].clone()  # DDP requires clone() here
            mask = None
        else:
            # KV-cache path: gather rope entries and mask rows for the given positions.
            cos = self.cos.index_select(0, input_pos)
            sin = self.sin.index_select(0, input_pos)
            if self.mask_cache is None:
                raise TypeError("You need to call `gpt.set_kv_cache()`")
            mask = self.mask_cache.index_select(2, input_pos)

        hidden = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
        if self.config.scale_embeddings:
            hidden = hidden * (self.config.n_embd**0.5)

        for layer in self.transformer.h:
            hidden = layer(hidden, cos, sin, mask, input_pos)

        return self.lm_head(self.transformer.ln_f(hidden))  # (b, t, vocab_size)


class GPT_Scales_Detached(GPT_Scales):
    """``GPT_Scales`` whose token embeddings are detached before entering the blocks.

    The forward pass is the same sequence of steps as in ``GPT_Scales``; the
    single difference is the ``.detach()`` on the embedding lookup, which cuts
    the autograd graph between the embedding table and the transformer blocks.
    """

    def forward(self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the model on token ids ``idx`` and return the head output.

        Args:
            idx: Token ids; dimension 1 is taken as the sequence length.
            input_pos: Positions to read from the rope and mask caches. When
                ``None`` no explicit mask is handed to the blocks.

        Raises:
            ValueError: If the sequence is longer than ``max_seq_length``.
            TypeError: If ``input_pos`` is given but no mask cache exists.
        """
        T = idx.size(1)
        if T > self.max_seq_length:
            raise ValueError(
                f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}."
            )

        if input_pos is None:
            # Only deviation from the original LitGPT forward: the blocks get
            # `mask=None` instead of a mask taken from the cache.
            cos = self.cos[:T].clone()  # DDP requires clone() here
            sin = self.sin[:T].clone()  # DDP requires clone() here
            mask = None
        else:
            # KV-cache path: gather rope entries and mask rows for the given positions.
            cos = self.cos.index_select(0, input_pos)
            sin = self.sin.index_select(0, input_pos)
            if self.mask_cache is None:
                raise TypeError("You need to call `gpt.set_kv_cache()`")
            mask = self.mask_cache.index_select(2, input_pos)

        # Detached token embeddings of shape (b, t, n_embd): no gradient flows
        # back into the embedding table through this lookup.
        hidden = self.transformer.wte(idx).detach()
        if self.config.scale_embeddings:
            hidden = hidden * (self.config.n_embd**0.5)

        for layer in self.transformer.h:
            hidden = layer(hidden, cos, sin, mask, input_pos)

        return self.lm_head(self.transformer.ln_f(hidden))


class Block_Scales(Block):
    """LitGPT ``Block`` whose attention module is a ``CausalSelfAttention_Scales``."""

    def __init__(self, config: Config, mup_init: bool = False) -> None:
        """Build the standard block, then replace its attention module.

        Args:
            config: LitGPT model configuration.
            mup_init: Forwarded to the attention module.
        """
        super().__init__(config)
        # Overwrite the attention created by the parent constructor.
        self.attn = CausalSelfAttention_Scales(config=config, mup_init=mup_init)


class CausalSelfAttention_Scales(CausalSelfAttention):
    """LitGPT causal self-attention with an explicit (non-fused) attention computation.

    Two things differ from a plain softmax attention:

    * the logit scale is ``1 / head_size`` when ``mup_init`` is set and
      ``1 / sqrt(head_size)`` otherwise;
    * the largest masked attention logit of every call is appended to
      ``file_data_share.layer_wise_max_attn_weight``.
    """

    def __init__(self, config: Config, mup_init: bool = False) -> None:
        """Build the attention module.

        Args:
            config: LitGPT model configuration.
            mup_init: Selects the ``1 / head_size`` attention scaling.
        """
        super().__init__(config)
        self.mup_init = mup_init

    def scaled_dot_product_attention(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Compute masked softmax attention and record the largest logit.

        Args:
            q: Queries; the second-to-last dimension is the query length.
                Presumably ``(batch, heads, T, head_size)`` — layout is set by
                the base class, confirm there.
            k: Keys; the second-to-last dimension is the key length.
            v: Values, aligned with ``k``.
            mask: Optional attention mask broadcastable to the logits. A boolean
                mask marks the positions that may be attended to with ``True``;
                any other dtype is added to the logits. When ``None`` a causal
                (lower-triangular) mask is built from the query/key lengths.

        Returns:
            The attention output with dimensions 1 and 2 swapped.
        """
        head_size = self.config.head_size
        scale = 1.0 / head_size if self.mup_init else 1.0 / math.sqrt(head_size)

        # a bit more memory efficient version of causal self attention
        # https://github.com/karpathy/nanoGPT/blob/9755682b981a45507f6eb9b11eadef8cb83cebd5/model.py#L67
        attn_weight = q @ k.transpose(-2, -1) * scale

        if mask is None:
            # No mask supplied: plain causal masking, query i sees keys j <= i.
            n_query, n_key = q.size(-2), k.size(-2)
            mask = torch.ones(n_query, n_key, dtype=torch.bool, device=q.device).tril(diagonal=0)

        # A caller-supplied mask (e.g. the rows selected from the mask cache on
        # the KV-cache path) must be honoured: a lower-triangular mask built
        # from the query/key lengths alone would let a single-token decode step
        # see only the very first key.
        if mask.dtype == torch.bool:
            attn_weight = attn_weight.masked_fill_(~mask, float("-inf"))
        else:
            attn_weight = attn_weight + mask
        del mask

        # Logged for monitoring; `.item()` copies the scalar to the host, and the
        # shared list keeps growing until `file_data_share.clear_data()` is called.
        file_data_share.layer_wise_max_attn_weight.append(torch.max(attn_weight).detach().item())

        attn_weight = torch.softmax(attn_weight, dim=-1)
        y = attn_weight @ v

        return y.transpose(1, 2)


def initialize_weights(
    fabric: L.Fabric,
    model: GPT_Scales,
    mup_base_scales: dict[str, int] | int | None = None,
    init_type: Literal["plain", "scaled", "GPT-NeoX", "DeepSeek"] | None = None,
) -> float:
    """Install per-module ``reset_parameters`` hooks on the model and run them.

    Every ``nn.Embedding`` / ``nn.Linear`` gets a hook that draws its weight from
    a zero-mean normal distribution and zeroes its bias. For the "scaled" and
    "GPT-NeoX" schemes the ``proj`` layer of each attention / MLP module gets a
    second hook with a depth-dependent standard deviation. Unless the fabric
    strategy is FSDP, the hooks are then executed through ``reset_parameters``;
    with FSDP they are only installed and a warning is emitted.

    Args:
        fabric (L.Fabric): Fabric object (used for printing and for the strategy check).
        model (GPT_Scales): Model to be initialized.
        mup_base_scales (dict[str, int] | int | None, optional): Base scales for muP
            initialization. An int is used directly as the model width; a dict is
            queried for ``"d_model"`` (falling back to ``model.config.n_embd``).
            Any non-``None`` value switches the sampler to ``mup.normal_``.
            Defaults to None.
        init_type (Literal["plain", "scaled", "GPT-NeoX", "DeepSeek"] | None, optional):
            Initialization type. ``None`` selects ``std = 1 / sqrt(d_model)``; any
            other unrecognised value does the same but emits a warning.
            Defaults to None.

    Returns:
        float: The last standard deviation computed — the ``proj`` std for
        "scaled" / "GPT-NeoX", otherwise the std used for all Embedding and
        Linear modules.
    """
    use_mup = mup_base_scales is not None

    def init_module_weights(
        module: nn.Module,
        std: float,
        mup_init: bool,
    ) -> None:
        """Draw ``module.weight`` from N(0, std^2) and zero ``module.bias`` if present."""
        init_funct = mup.normal_ if mup_init else nn.init.normal_
        init_funct(module.weight, mean=0.0, std=std)
        if getattr(module, "bias", None) is not None:
            nn.init.zeros_(module.bias)

    # Determine the model dimension based on the base scales
    if isinstance(mup_base_scales, int):
        d_model = mup_base_scales
    elif isinstance(mup_base_scales, dict):
        d_model = mup_base_scales.get("d_model", model.config.n_embd)
    else:
        d_model = model.config.n_embd

    if init_type in {"plain", "scaled", "GPT-NeoX"}:
        fabric.print(f"Using {init_type} weight initialization.")
        # "plain" weight initialization (https://arxiv.org/abs/2312.16903).
        std = math.sqrt(2.0 / (5 * d_model))
    elif init_type == "DeepSeek":
        fabric.print("Using DeepSeek weight initialization.")
        std = 0.006  # TODO: mention source
    else:
        if init_type is not None:
            # A misspelled scheme name would otherwise silently select the
            # default branch and train with an unintended initialization.
            warn(
                f"Unknown init_type {init_type!r}; falling back to std = 1 / sqrt(d_model).",
                stacklevel=2,
            )
        fabric.print("Using scaled parameterization")
        std = 1.0 / math.sqrt(d_model)  # TODO: mention source

    # Initialize the Embedding and Linear modules
    # (for the "plain" and "DeepSeek" initialization that's it)
    for module in model.modules():
        if isinstance(module, (nn.Embedding, nn.Linear)):
            module.reset_parameters = partial(init_module_weights, module, std=std, mup_init=use_mup)

    # Adjust initialization for scaled model types: the output projections of
    # the attention and MLP modules get a smaller, depth-dependent std. These
    # hooks replace the ones installed on the same `proj` layers above.
    if init_type in {"scaled", "GPT-NeoX"}:
        if init_type == "scaled":
            # "scaled" weight initialization (https://arxiv.org/abs/2312.16903).
            std = std / math.sqrt(model.config.n_layer * 2)
        else:
            # GPT-NeoX-20B weight initialization (https://arxiv.org/abs/2204.06745).
            std = 1 / math.sqrt(d_model) / model.config.n_layer

        for module in model.modules():
            if isinstance(module, (LLaMAMLP, CausalSelfAttention_Scales, GptNeoxMLP)):
                module.proj.reset_parameters = partial(
                    init_module_weights, module.proj, std=std, mup_init=use_mup
                )

    if not isinstance(fabric.strategy, FSDPStrategy):
        reset_parameters(model)
    else:
        warn(
            f"Cannot initialize network with current strategy {fabric.strategy}, using standard parametrization"
        )

    return std
