"""
Simple, minimal implementation of Mamba in one file of PyTorch.

Suggest reading the following before/while reading the code:
    [1] Mamba: Linear-Time Sequence Modeling with Selective State Spaces (Albert Gu and Tri Dao)
        https://arxiv.org/abs/2312.00752
    [2] The Annotated S4 (Sasha Rush and Sidd Karamcheti)
        https://srush.github.io/annotated-s4

Glossary:
    b: batch size                       (`B` in Mamba paper [1] Algorithm 2)
    l: sequence length                  (`L` in [1] Algorithm 2)
    d or d_model: hidden dim
    n or d_state: latent state dim      (`N` in [1] Algorithm 2)
    expand: expansion factor            (`E` in [1] Section 3.4)
    d_in or d_inner: d * expand         (`D` in [1] Algorithm 2)
    A, B, C, D: state space parameters  (See any state space representation formula)
                                        (B, C are input-dependent (aka selective, a key innovation in Mamba); A, D are not)
    Δ or delta: input-dependent step size
    dt_rank: rank of Δ                  (See [1] Section 3.6 "Parameterization of ∆")
"""

from __future__ import annotations

import os
import json
import math
from dataclasses import dataclass
from typing import Union
from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
from transformers.utils.hub import cached_file

from mamba_tiny.scans import selective_scan
from mamba_tiny.weights import (
    configure_mamba_weights,
    configure_mamba_block_weights,
    set_mamba_ideal_MQAR_weights,
    set_mamba_ideal_MQAR_projection_weights,
)

from utils import recorder


def load_config_hf(model_name):
    """Load a model configuration file and return its parsed JSON content.

    The file named ``CONFIG_NAME`` is resolved for ``model_name`` through
    ``cached_file`` (missing entries raise, as requested by
    ``_raise_exceptions_for_missing_entries=True``) and then parsed as JSON.

    Args:
        model_name: Identifier handed unchanged to ``cached_file``.

    Returns:
        The object produced by ``json.load`` for the resolved config file.
    """
    resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=True)

    # Open the file in a context manager so the handle is closed right away
    # instead of lingering until garbage collection. JSON is UTF-8 by spec,
    # so do not depend on the platform's default encoding.
    with open(resolved_archive_file, encoding="utf-8") as config_file:
        return json.load(config_file)


def load_state_dict_hf(model_id_or_path, device=None, dtype=None):
    """
    Load a checkpoint with ``torch.load``.

    ``model_id_or_path`` is interpreted in two steps:
      • if it names an existing file on disk, that file is loaded directly
        (e.g. a raw .pt/.bin checkpoint);
      • otherwise it is handed to ``cached_file`` together with ``WEIGHTS_NAME``
        (intended for a HuggingFace repo identifier or a local model
        directory), and the resolved file is loaded.

    Args:
        model_id_or_path: Checkpoint file path, or an identifier that
            ``cached_file`` can resolve.
        device: Forwarded to ``torch.load`` as ``map_location``.
        dtype: Accepted for call-site compatibility only; currently ignored.

    Returns:
        Whatever ``torch.load`` returns for the checkpoint file.
    """
    if os.path.isfile(model_id_or_path):
        # A concrete file on disk: use it as the checkpoint itself.
        checkpoint_path = model_id_or_path
    else:
        # Anything else is resolved to the standard weights file name.
        checkpoint_path = cached_file(
            model_id_or_path,
            WEIGHTS_NAME,
            _raise_exceptions_for_missing_entries=True
        )

    return torch.load(checkpoint_path, map_location=device)


@dataclass
class ModelArgs:
    """Hyper-parameters of the Mamba model (see the Glossary at the top of the file).

    After construction, ``__post_init__`` derives ``d_inner``, resolves
    ``dt_rank='auto'`` to an integer and pads ``vocab_size`` up to a multiple
    of ``pad_vocab_size_multiple``.
    """

    d_model: int  # hidden dim
    n_layer: int  # number of residual blocks stacked by the model
    vocab_size: int  # may be increased by padding in __post_init__
    d_state: int = 16  # latent state dim
    expand: int = 2  # expansion factor used to derive d_inner
    dt_rank: Union[int, str] = 'auto'  # 'auto' -> ceil(d_model / 16)
    d_conv: int = 4  # kernel size of the depthwise 1D convolution
    pad_vocab_size_multiple: int = 1
    conv_bias: bool = True
    bias: bool = False
    scan_mode: str = 'logcumsumexp'  # 'cumsum'

    def __post_init__(self):
        """Fill in the derived attributes described in the class docstring."""
        # Width of the expanded inner representation.
        self.d_inner = int(self.expand * self.d_model)

        # Resolve the symbolic default for the rank of Δ.
        if self.dt_rank == 'auto':
            self.dt_rank = math.ceil(self.d_model / 16)

        # Round the vocabulary size up to the next multiple, if it is not one already.
        remainder = self.vocab_size % self.pad_vocab_size_multiple
        if remainder != 0:
            self.vocab_size += self.pad_vocab_size_multiple - remainder


def _get_default_device():
    """Return the "cuda" device when CUDA is available, otherwise the "cpu" device."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


class Mamba(nn.Module):
    """Full Mamba language model.

    Pipeline: token embedding -> stack of ``ResidualBlock`` layers ->
    optional final ``RMSNorm`` -> linear LM head whose weight is tied to the
    embedding matrix.
    """

    # Output container returned by forward(). Built once at class creation
    # instead of re-creating a new namedtuple class on every forward call.
    _CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])

    def __init__(self, args: ModelArgs, **kwargs):
        """Full Mamba model.

        Args:
            args: Model hyper-parameters.
            **kwargs: Free-form configuration flags, stored as ``self.config``
                and forwarded to every ``ResidualBlock`` and to
                ``configure_mamba_weights``. Flags read in this class:
                ``skip_biases``, ``skip_normalizations``,
                ``set_ideal_mqar_weights``, ``set_ideal_mqar_projections``.
        """
        super().__init__()

        self.config = kwargs

        # NOTE: this switches the bias flags off on the `args` object that
        # was passed in (it is modified in place, not copied).
        if self.config.get('skip_biases', False):
            args.bias = False
            args.conv_bias = False

        self.args = args

        # notations
        V = self.args.vocab_size
        D = self.args.d_model
        N = self.args.d_state

        self.V = V
        self.D = D
        self.N = N

        self.device = _get_default_device()

        self.embedding = nn.Embedding(V, D)
        self.layers = nn.ModuleList([ResidualBlock(args, **self.config) for _ in range(args.n_layer)])

        # The final norm only exists when normalizations are enabled.
        if not self.config.get('skip_normalizations', False):
            self.norm_f = RMSNorm(D)

        self.lm_head = nn.Linear(D, V, bias=False)

        # Tie output projection to embedding weights. See "Weight Tying" paper
        self.lm_head.weight = self.embedding.weight

        if self.config.get('set_ideal_mqar_weights', False):
            set_mamba_ideal_MQAR_weights(self)
        elif self.config.get('set_ideal_mqar_projections', False):
            set_mamba_ideal_MQAR_projection_weights(self, freeze=True)
        else:
            # optional: (todo)
            configure_mamba_weights(self, **self.config)

    def forward(self, input_ids, num_last_tokens=0):
        """
        Args:
            input_ids (long tensor): shape (b, l)    (See Glossary at top for definitions of b, l, d_in, n...)
            num_last_tokens (int): when > 0, only the logits of the last
                ``num_last_tokens`` positions are returned.

        Returns:
            A ``CausalLMOutput`` namedtuple with a single field ``logits`` of
            shape (b, l, vocab_size), or (b, num_last_tokens, vocab_size)
            when ``num_last_tokens > 0``.

        Official Implementation:
            class MambaLMHeadModel, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py#L173

        """

        x = self.embedding(input_ids)
        recorder.record(self, x_input_ids=input_ids, x_input_embeddings=x)

        # just for recording
        E_in = self.embedding.weight
        E_out = self.lm_head.weight

        recorder.record(
            self, E_in=E_in, E_out=E_out,
        )

        # recorder.record(self, R_E=(E_out @ E_in.T))

        for i, layer in enumerate(self.layers):
            recorder.record(self, **{f"layer_{i}_x_input": x})
            x = layer(x)
            recorder.record(self, **{f"layer_{i}_y_output": x})

        if not self.config.get('skip_normalizations', False):
            x = self.norm_f(x)
            recorder.record(self, x_normalized=x)

        x = self.lm_head(x)
        recorder.record(self, y_output_logits=x)

        recorder.record(self, y_all_tokens_logits=x)

        # temp - TODO
        if num_last_tokens > 0:
            x = x[:, -num_last_tokens:]

        # Recorded unconditionally: equals the full logits when no slicing happened.
        recorder.record(self, y_last_token_logits=x)

        return self._CausalLMOutput(logits=x)

    @staticmethod
    def from_pretrained(pretrained_model_name: str, model=None, model_config=None, device=None, dtype=None, **kwargs):
        """Load pretrained weights from HuggingFace into model.

        Args:
            pretrained_model_name: One of
                * 'state-spaces/mamba-2.8b-slimpj'
                * 'state-spaces/mamba-2.8b'
                * 'state-spaces/mamba-1.4b'
                * 'state-spaces/mamba-790m'
                * 'state-spaces/mamba-370m'
                * 'state-spaces/mamba-130m'
                (or anything else ``load_state_dict_hf`` accepts)
            model: Existing model to load into. When None, a new ``Mamba`` is
                built from ``model_config``.
            model_config: Mapping of ``ModelArgs`` fields. When None (and
                ``model`` is None) it is fetched with ``load_config_hf``.
                Keys that are not ``ModelArgs`` constructor fields are ignored.
            device: ``map_location`` for loading the checkpoint; defaults to
                ``_get_default_device()``. The model itself is not moved.
            dtype: Forwarded to ``load_state_dict_hf``.
            **kwargs: Accepted but not used.

        Returns:
            model: Mamba model with weights loaded

        """
        import dataclasses  # local import: only the config filtering needs it

        if device is None:
            device = _get_default_device()

        if model is None:
            if model_config is None:
                model_config = load_config_hf(pretrained_model_name)

            # Config files can carry keys that ModelArgs does not declare
            # (e.g. a derived `d_inner` written by an earlier save_pretrained,
            # or extra entries in published configs). Passing them through
            # would raise TypeError, so keep only the constructor fields.
            init_fields = {f.name for f in dataclasses.fields(ModelArgs) if f.init}
            model_args = ModelArgs(**{k: v for k, v in model_config.items() if k in init_fields})
            model = Mamba(model_args)

        pretrained_dict = load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype)
        model_dict = model.state_dict()

        # Copy over only the tensors whose (prefix-stripped) name exists in the
        # model and whose shape matches; everything else keeps its current value.
        for k, v in pretrained_dict.items():
            k_new = k.replace('backbone.', '')
            if k_new in model_dict and v.size() == model_dict[k_new].size():
                model_dict[k_new] = v

        model.load_state_dict(model_dict)
        return model

    def save_pretrained(self, save_directory):
        """
        Minimal implementation of save_pretrained for MambaLMHeadModel.
        Save the model and its configuration file to a directory.

        Writes ``pytorch_model.bin`` (the ``state_dict``) and ``config.json``
        into ``save_directory``, creating the directory if needed.

        Only the declared dataclass fields of ``self.args`` are written to
        ``config.json``. Attributes derived after construction (such as
        ``d_inner``) are left out so the file can be fed back to the
        ``ModelArgs`` constructor.
        """
        import dataclasses  # local import: only the config export needs it

        # Ensure save_directory exists
        os.makedirs(save_directory, exist_ok=True)

        # Save the model's state_dict
        model_path = os.path.join(save_directory, 'pytorch_model.bin')
        torch.save(self.state_dict(), model_path)

        # Save the configuration of the model
        if dataclasses.is_dataclass(self.args):
            config = {f.name: getattr(self.args, f.name) for f in dataclasses.fields(self.args)}
        else:
            # Fallback for a plain attribute container.
            config = dict(vars(self.args))

        config_path = os.path.join(save_directory, 'config.json')
        with open(config_path, 'w') as f:
            json.dump(config, f, indent=4)

    def generate(
            self, prompt,
            tokenizer=None, max_length: int = 50,
            sample: bool = False, top_k: int = 40,
    ):
        """Generate a continuation of ``prompt``; thin wrapper that delegates to ``_generate``."""

        return _generate(
            model=self, prompt=prompt,
            tokenizer=tokenizer, max_length=max_length,
            sample=sample, top_k=top_k,
        )


def _generate(
        model, prompt,
        tokenizer=None, max_length: int = 50,
        sample: bool = False, top_k: int = 40,
):
    """Autoregressively extend ``prompt`` until it is ``max_length`` tokens long.

    The model is put into eval mode and called on the full sequence at every
    step; the logits of the last position select the next token.

    Args:
        model: Callable returning an object with a ``logits`` attribute of
            shape (b, l, vocab_size); must also provide ``eval()``.
        prompt: Text for ``tokenizer`` when a tokenizer is given, otherwise a
            tensor of token ids of shape (b, l).
        tokenizer: Optional tokenizer, called as
            ``tokenizer(prompt, return_tensors='pt')`` and used for decoding.
        max_length: Total sequence length (prompt + generated tokens).
        sample: Draw the next token from the distribution when True,
            otherwise take the most likely token.
        top_k: Restrict the distribution to the ``top_k`` most likely tokens
            (clamped to the vocabulary size); ``None`` disables the filter.

    Returns:
        With a tokenizer: the decoded text of the first sequence in the batch.
        Without: the tensor of token ids, shape (b, max_length).

    Raises:
        AssertionError: if the prompt already has ``max_length`` or more tokens.
    """

    model.eval()

    if tokenizer is not None:
        input_ids = tokenizer(prompt, return_tensors='pt').input_ids
    else:
        input_ids = prompt

    n_tokens_to_gen = max_length - input_ids.size(1)

    assert n_tokens_to_gen > 0, "max_length must be larger than the prompt length"

    with torch.no_grad():
        for _ in range(n_tokens_to_gen):
            next_token_logits = model(input_ids).logits[:, -1]
            probs = F.softmax(next_token_logits, dim=-1)

            if top_k is not None:
                # torch.topk raises when k exceeds the size of the dimension,
                # which happens with small vocabularies and the default
                # top_k; clamp k to the vocabulary size.
                k = min(top_k, probs.size(-1))
                values, _ = torch.topk(probs, k=k)
                # Zero everything below the k-th largest probability, then renormalize.
                probs[probs < values[:, -1, None]] = 0
                probs = probs / probs.sum(axis=1, keepdims=True)

            if sample:
                next_indices = torch.multinomial(probs, num_samples=1)
            else:
                next_indices = torch.argmax(probs, dim=-1)[:, None]

            input_ids = torch.cat([input_ids, next_indices], dim=1)

    if tokenizer is not None:
        # Only the first sequence of the batch is returned as text, so decode
        # just that row.
        return tokenizer.decode(input_ids[0].tolist())

    return input_ids


class ResidualBlock(nn.Module):
    """Wraps one ``MambaBlock`` with an optional pre-norm and an optional residual add."""

    def __init__(self, args: ModelArgs, **config):
        """Simple block wrapping Mamba block with normalization and residual connection.

        Args:
            args: Model hyper-parameters, forwarded to ``MambaBlock``.
            **config: Free-form flags, stored as ``self.config`` and forwarded
                to ``MambaBlock``. Flags read here: ``skip_normalizations``
                and ``skip_residual_connection``.
        """
        super().__init__()

        self.args = args
        self.config = config
        self.mixer = MambaBlock(args, **config)

        # The norm layer is only created when normalizations are enabled.
        normalize = not config.get("skip_normalizations", False)
        if normalize:
            self.norm = RMSNorm(args.d_model)

    def forward(self, x):
        """
        Args:
            x: shape (b, l, d)    (See Glossary at top for definitions of b, l, d_in, n...)

        Returns:
            output: shape (b, l, d)

        Official Implementation:
            Block.forward(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L297

            Note: the official repo chains residual blocks that look like
                [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> ...
            where the first Add is a no-op. This is purely for performance reasons as this
            allows them to fuse the Add->Norm.

            We instead implement our blocks as the more familiar, simpler, and numerically equivalent
                [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> ....

        """
        # The untouched input is what gets added back after the mixer.
        skip_branch = x
        recorder.record(self, block_x_residual=skip_branch)

        mixer_input = x
        if not self.config.get("skip_normalizations", False):
            mixer_input = self.norm(mixer_input)
            recorder.record(self, x_normalized=mixer_input)

        mixed = self.mixer(mixer_input)
        recorder.record(self, block_y_mixer_output=mixed)

        if self.config.get("skip_residual_connection", False):
            result = mixed
        else:
            result = mixed + skip_branch
            recorder.record(self, block_y_residual_added_output=result)

        recorder.record(self, block_y_output=result)

        return result


class MambaBlock(nn.Module):
    """One Mamba mixer: input projection, causal depthwise conv, selective SSM, optional gate, output projection."""

    def __init__(self, args: ModelArgs, **config):
        """A single Mamba block, as described in Figure 3 in Section 3.4 in the Mamba paper [1].

        Args:
            args: Model hyper-parameters.
            **config: Free-form flags, stored as ``self.config`` and forwarded
                to ``configure_mamba_block_weights``. Flags read in this
                class: ``skip_gating``, ``skip_discretization``,
                ``skip_non_linearities``, ``init_A``, ``init_D``.
        """
        super().__init__()

        self.args = args
        self.config = config

        # notations
        d_model = args.d_model
        d_inner = args.d_inner
        d_state = args.d_state

        self.d_inner = d_inner
        self.d_state = d_state

        self.device = _get_default_device()

        # Projections into and out of the expanded inner width.
        self.in_proj_x = nn.Linear(d_model, d_inner, bias=args.bias)
        self.out_proj_y = nn.Linear(d_inner, d_model, bias=args.bias)

        # The gate branch only exists when gating is enabled.
        if not config.get("skip_gating", False):
            self.in_proj_z = nn.Linear(d_model, d_inner, bias=args.bias)

        # Depthwise convolution (one group per channel). The padding of
        # d_conv - 1 on both sides is trimmed back to the input length in forward().
        kernel_size = args.d_conv
        self.conv1d = nn.Conv1d(
            in_channels=d_inner,
            out_channels=d_inner,
            kernel_size=kernel_size,
            padding=kernel_size - 1,
            groups=d_inner,
            bias=args.conv_bias,
        )

        # originally:
        # x_proj takes in `x` and outputs the input-specific Δ, B, C
        # dt_proj projects Δ from dt_rank to d_in
        # here: one separate linear map per quantity.
        self.x_proj_to_B = nn.Linear(d_inner, d_state, bias=False)
        self.x_proj_to_C = nn.Linear(d_inner, d_state, bias=False)

        if not config.get("skip_discretization", False):
            self.x_proj_to_delta = nn.Linear(d_inner, d_inner, bias=True)

        # Default A: every row is [1, 2, ..., n], stored through its log.
        if config.get("init_A", None) is None:
            state_indices = torch.arange(1, d_state + 1)
            A = repeat(state_indices, 'n -> d n', d=d_inner).to(self.device)
            self.A_log = nn.Parameter(torch.log(A), requires_grad=True)
            self.A = -torch.exp(self.A_log.float())  # shape (d_in, n)

        # Default D: all ones.
        if config.get("init_D", None) is None:
            self.D = nn.Parameter(torch.ones(d_inner))

        # optional: (todo)
        configure_mamba_block_weights(self, **config)

    def forward(self, x):
        """Mamba block forward. This looks the same as Figure 3 in Section 3.4 in the Mamba paper [1].

        Args:
            x: shape (b, l, d)    (See Glossary at top for definitions of b, l, d_in, n...)

        Returns:
            output: shape (b, l, d)

        Official Implementation:
            class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311

        """
        block_input = x
        recorder.record(self, x_in=block_input)

        (_, seq_len, _) = x.shape

        # --- main branch: projection -> conv -> (activation) -> SSM ---
        u = self.in_proj_x(block_input)  # shape (b, l, d_in)
        recorder.record(self, x_in_proj=u)

        # Conv1d wants channels before length; keep only the first l outputs.
        u = rearrange(u, 'b l d_in -> b d_in l')
        u = self.conv1d(u)[:, :, :seq_len]
        u = rearrange(u, 'b d_in l -> b l d_in')
        recorder.record(self, x_conv1d=u)

        if not self.config.get("skip_non_linearities", False):
            u = F.silu(u)  # original
            recorder.record(self, x_conv1d_act=u)

        y = self.ssm(u)
        recorder.record(self, y_ssm_output=y)

        # --- gate branch (z), computed from the raw block input ---
        if not self.config.get("skip_gating", False):
            gate = self.in_proj_z(block_input)  # shape (b, l, d_in)
            recorder.record(self, z=gate)

            if not self.config.get("skip_non_linearities", False):
                gate = F.silu(gate)
                recorder.record(self, z_silu=gate)

            y = y * gate
            recorder.record(self, y_times_z=y)

        out = self.out_proj_y(y)
        recorder.record(self, y_out_proj=out)

        return out

    def ssm(self, x):
        """Runs the SSM. See:
            - Algorithm 2 in Section 3.2 in the Mamba paper [1]
            - run_SSM(A, B, C, u) in The Annotated S4 [2]

        Args:
            x: shape (b, l, d_in)    (See Glossary at top for definitions of b, l, d_in, n...)

        Returns:
            output: shape (b, l, d_in)

        Official Implementation:
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311

        """
        # Compute ∆ A B C D, the state space parameters.
        #     A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
        #     ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
        #                                  and is why Mamba is called **selective** state spaces)

        # Without a custom init, A is rebuilt from its log parameter on every
        # call; with one, the stored `self.A` is used as-is.
        if self.config.get('init_A', None) is None:
            A = -torch.exp(self.A_log.float())  # shape (d_in, n)
        else:
            A = self.A.float()

        D = self.D.float()

        # just for recording
        proj_B = self.x_proj_to_B
        proj_C = self.x_proj_to_C
        S_B = proj_B.weight
        S_C = proj_C.weight
        R_S = S_C.T @ S_B
        recorder.record(self, S_B=S_B, S_C=S_C, R_S=R_S)

        B = proj_B(x)  # (b, l, n)
        C = proj_C(x)  # (b, l, n)

        if self.config.get("skip_discretization", False):
            # No learned step size: use a constant Δ for every position and channel.
            delta_scale = 1  # 0.01
            delta = delta_scale * torch.ones_like(x)  # (b, l, d_in)
        else:
            delta = self.x_proj_to_delta(x)  # (b, l, d_in)
            recorder.record(self, delta_proj=delta)
            delta = F.softplus(delta)  # (b, l, d_in)
            recorder.record(self, delta_softplus=delta)

        recorder.record(self, A=A, B=B, C=C, D=D, delta=delta)

        scan_output = selective_scan(x, delta, A, B, C, D, mode=self.args.scan_mode, config=self.config)
        recorder.record(self, y_scan_output=scan_output)

        return scan_output  # This is similar to run_SSM(A, B, C, u) in The Annotated S4 [2]


class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned per-feature scale."""

    def __init__(self,
                 d_model: int,
                 eps: float = 1e-5):
        """
        Args:
            d_model: size of the learned scale vector (initialized to ones).
            eps: constant added to the mean square before taking the reciprocal square root.
        """
        super().__init__()
        self.eps = eps
        self.d_model = d_model

        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        """Scale ``x`` by the reciprocal RMS of its last dimension, then by ``self.weight``."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)

        return x * inv_rms * self.weight
