"""
This code builds on:
- the nGPT GitHub implementation
https://github.com/NVIDIA/ngpt
- the DyT GitHub implementation
https://github.com/jiachenzhu/DyT/blob/main/dynamic_tanh.py
- FlashAttention
https://github.com/Dao-AILab/flash-attention
- Hugging Face Transformers
https://github.com/huggingface/transformers
"""

import math
from typing import List, Optional

import torch
import torch.nn as nn
from einops import rearrange
from flash_attn.layers.rotary import RotaryEmbedding
from flash_attn.modules.mha import FlashSelfAttention
from torch.nn import functional as F

torch._inductor.config.triton.cudagraph_skip_dynamic_graphs = True

from transformers import PretrainedConfig
from transformers import PreTrainedModel, GenerationMixin
from transformers.modeling_outputs import CausalLMOutput


class DynamicTanh(nn.Module):
    """Element-wise ``tanh(alpha * x)`` followed by a learnable affine map.

    Computes ``tanh(alpha * x) * weight + bias`` where ``alpha`` is a single
    learnable scalar and ``weight`` / ``bias`` have shape ``normalized_shape``.

    Args:
        normalized_shape: Shape of the ``weight`` and ``bias`` parameters.
        channels_last: If true, ``weight`` / ``bias`` broadcast against the
            trailing dimension(s) of the input. If false, they are expanded
            with two trailing singleton axes (``[:, None, None]``) before
            broadcasting.
        alpha_init_value: Initial value of the scalar ``alpha``.
    """

    def __init__(self, normalized_shape, channels_last=True, alpha_init_value=0.5):
        super().__init__()
        # Plain attributes, kept for `extra_repr`.
        self.normalized_shape = normalized_shape
        self.alpha_init_value = alpha_init_value
        self.channels_last = channels_last

        # Registration order (alpha, weight, bias) determines the order of
        # `parameters()`, so it is kept fixed.
        self.alpha = nn.Parameter(alpha_init_value * torch.ones(1))
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))

    def forward(self, x):
        """Return ``tanh(alpha * x) * weight + bias`` for input tensor ``x``."""
        squashed = torch.tanh(self.alpha * x)
        if self.channels_last:
            gain, shift = self.weight, self.bias
        else:
            # Add two trailing axes so the affine parameters line up with a
            # dimension that is followed by two more dimensions.
            gain = self.weight[:, None, None]
            shift = self.bias[:, None, None]
        return squashed * gain + shift

    def extra_repr(self):
        """Return the constructor arguments for the module's ``repr``."""
        return (
            f"normalized_shape={self.normalized_shape}, "
            f"alpha_init_value={self.alpha_init_value}, "
            f"channels_last={self.channels_last}"
        )


class Block(nn.Module):

    def __init__(self, config, layer):
        super().__init__()
        self.config = config

        if config.mode == "GPT2":
            self.use_GPT2 = True
            self.use_nGPT = False
            self.use_aGPT = False
        elif config.mode == "nGPT":
            self.use_GPT2 = False
            self.use_nGPT = True
            self.use_aGPT = False
        elif config.mode == "aGPT":
            self.use_GPT2 = False
            self.use_nGPT = False
            self.use_aGPT = True
        else:
            raise UserWarning(f"Unknown mode {config.mode}")

        self.Wqkv = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.att_c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

        self.head_dim = config.n_embd // config.n_head

        self.num_heads = config.n_head

        rotary_emb_dim = 0.5 * self.head_dim
        rotary_emb_base = 10000.0
        rotary_emb_scale_base = None
        rotary_emb_interleaved = False

        self.rotary_emb = RotaryEmbedding(
            rotary_emb_dim,
            base=rotary_emb_base,
            scale_base=rotary_emb_scale_base,
            interleaved=rotary_emb_interleaved,
            device=None,
        )

        self.c_fc = nn.Linear(config.n_embd, 2 * 4 * config.n_embd, bias=config.bias)
        self.silu = nn.SiLU()
        self.mlp_c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)

        if self.use_GPT2:
            if config.GPT2_norm == "rms":
                self.rmsnorm_att = nn.RMSNorm(config.n_embd)
                self.rmsnorm_mlp = nn.RMSNorm(config.n_embd)
            elif config.GPT2_norm == "dyt":
                self.rmsnorm_att = DynamicTanh(config.n_embd, alpha_init_value=config.GPT2_DyT_alpha_att)
                self.rmsnorm_mlp = DynamicTanh(config.n_embd, alpha_init_value=config.GPT2_DyT_alpha_other)
            elif config.GPT2_norm == "none":
                self.rmsnorm_att = lambda x: x
                self.rmsnorm_mlp = lambda x: x

            if config.GPT2_ln_scaling:
                self.ln_Scale = nn.Parameter(torch.ones(1) * 1 / math.sqrt(1 + layer))
            else:
                self.ln_Scale = None

            if self.config.qk_norm:
                # Query-Key Normalization for Transformers (https://arxiv.org/abs/2010.04245)
                # g is init with log2(L^2 - L) which is ~20 for L=2048
                self.scalar_g = nn.Parameter(torch.ones(1) * 20)
                softmax_scale = 1.0
            else:
                sqrt_head_dim = (self.config.n_embd / self.config.n_head) ** 0.5
                softmax_scale = 1.0 / sqrt_head_dim

        elif self.use_nGPT:

            softmax_scale = math.sqrt(self.head_dim)

            self.attn_alpha_init_value = 0.05
            self.attn_alpha_init_scaling = config.base_scale
            self.attn_alpha = torch.nn.Parameter(
                self.attn_alpha_init_scaling * torch.ones(self.config.n_embd, dtype=torch.float32))

            self.mlp_alpha_init_value = 0.05
            self.mlp_alpha_init_scaling = config.base_scale
            self.mlp_alpha = torch.nn.Parameter(
                self.mlp_alpha_init_scaling * torch.ones(self.config.n_embd, dtype=torch.float32))

            self.sqk_init_value = 1.0
            self.sqk_init_scaling = config.base_scale
            self.sqk = torch.nn.Parameter(self.sqk_init_scaling * torch.ones(self.config.n_embd, dtype=torch.float32))

            self.suv_init_value = 1.0
            self.suv_init_scaling = 1.0
            self.suv = torch.nn.Parameter(
                self.suv_init_scaling * torch.ones(2 * 4 * config.n_embd, dtype=torch.float32))

        elif self.use_aGPT:

            softmax_scale = 1.0
            scaling_alpha = config.aGPT_alpha_scale
            shape_d = self.config.n_embd

            self.attn_alpha_init_value = self.config.alpha_init_value
            self.attn_alpha_init_scaling = scaling_alpha
            self.attn_alpha = torch.nn.Parameter(
                self.attn_alpha_init_scaling * torch.ones(shape_d, dtype=torch.float32),
                requires_grad=self.config.learn_alpha)

            self.mlp_alpha_init_value = self.config.alpha_init_value
            self.mlp_alpha_init_scaling = scaling_alpha
            self.mlp_alpha = torch.nn.Parameter(self.mlp_alpha_init_scaling * torch.ones(shape_d, dtype=torch.float32),
                                                requires_grad=self.config.learn_alpha)

            self.scale_g_init_value = math.sqrt(self.head_dim)
            self.scale_g_init_scaling = scaling_alpha
            self.scale_g = nn.Parameter(torch.ones(1, dtype=torch.float32) * scaling_alpha, requires_grad=True)

            sqk_init_value = 1.0 / math.sqrt(self.head_dim / self.config.n_embd)
            self.scale_qkv = torch.nn.Parameter(sqk_init_value * torch.ones(1, dtype=torch.float32),
                                                requires_grad=False)

            s_attn_out_init_value = 1.0 / math.sqrt(self.config.n_embd / self.head_dim)
            self.attn_out_scale = nn.Parameter(s_attn_out_init_value * torch.ones(1, dtype=torch.float32),
                                               requires_grad=False)

            s_mlp_uv_init_value = 1.0 / math.sqrt(config.n_embd * 4 / config.n_embd)
            self.scale_in_u = nn.Parameter(s_mlp_uv_init_value * torch.ones(1, dtype=torch.float32),
                                           requires_grad=False)
            self.scale_in_v = nn.Parameter(s_mlp_uv_init_value * torch.ones(1, dtype=torch.float32),
                                           requires_grad=False)

            s_acf_init_value = 3.74  # found by MC sampling with SwiGLU
            self.scale_acf = nn.Parameter(s_acf_init_value * torch.ones(1, dtype=torch.float32), requires_grad=False)

            s_mlp_out_init_value = 1.0 if config.post_norm else math.sqrt(config.n_embd * 4 / config.n_embd)
            self.scale_out = nn.Parameter(s_mlp_out_init_value * torch.ones(1, dtype=torch.float32),
                                          requires_grad=False)

        self.inner_attn = FlashSelfAttention(
            causal=True,
            softmax_scale=softmax_scale,
            attention_dropout=config.dropout,
            alibi_slopes=None,
            window_size=(-1, 1)
        )

    @torch.compiler.disable(recursive=True)
    def use_flash_attention(self, qkv):
        seqlen_offset = 0
        qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset, max_seqlen=None)

        cu_seqlens = None
        max_seqlen = None
        kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen}
        context = self.inner_attn(qkv, **kwargs)
        return context

    def forward(self, h):

        hin = h
        if self.use_GPT2:
            hin = self.rmsnorm_att(h)

            if self.ln_Scale is not None:
                hin = hin * self.ln_Scale

        qkv = self.Wqkv(hin)
        qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)

        if self.use_GPT2 and self.config.qk_norm:
            q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]

            q = F.normalize(q, p=2, dim=-1)
            k = F.normalize(k, p=2, dim=-1)
            q = q * self.scalar_g
            qkv = torch.stack([q, k, v], dim=2).to(qkv.dtype)


        elif self.use_nGPT:
            q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
            sqk = (self.sqk * (self.sqk_init_value / self.sqk_init_scaling)).view(1, 1, self.config.n_head,
                                                                                  self.config.n_embd // self.config.n_head)

            if self.config.qk_norm:
                q = sqk * F.normalize(q, p=2, dim=-1)
                k = sqk * F.normalize(k, p=2, dim=-1)
            else:
                q = sqk * q
                k = sqk * k

            qkv = torch.stack([q, k, v], dim=2).to(qkv.dtype)

        elif self.use_aGPT:
            q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]

            softmax_scale = self.scale_g * (self.scale_g_init_value / self.scale_g_init_scaling)
            softmax_scale = torch.abs(softmax_scale)
            if self.config.qk_norm:
                scale_q = softmax_scale
                scale_k = 1
            else:
                scale_q = softmax_scale * self.scale_qkv
                scale_k = self.scale_qkv

            v = v * self.scale_qkv

            if self.config.qk_norm:
                q = F.normalize(q, p=2, dim=-1) * scale_q
                k = F.normalize(k, p=2, dim=-1) * scale_k
            else:
                q = q * scale_q
                k = k * scale_k

            qkv = torch.stack([q, k, v], dim=2).to(qkv.dtype)

        context = self.use_flash_attention(qkv)

        h_att = self.att_c_proj(rearrange(context, "... h d -> ... (h d)"))

        if self.use_aGPT:
            h_att = h_att * self.attn_out_scale

        if self.use_GPT2:
            h = h + h_att

        elif self.use_nGPT or self.use_aGPT:
            lr = self.attn_alpha * (self.attn_alpha_init_value / self.attn_alpha_init_scaling)
            lr = torch.abs(lr)

            A_norm = h

            if self.config.post_norm:
                B_norm = F.normalize(h_att, p=2, dim=-1)
            else:
                B_norm = h_att

            if self.config.alpha_correction:
                attn_correct = 1.0 / torch.sqrt(1 - 2 * lr * (1 - lr))
                h = (A_norm + lr * (B_norm - A_norm)) * attn_correct
            else:
                res = A_norm + lr * (B_norm - A_norm)
                h = F.normalize(res, p=2, dim=-1)

        hin = h
        if self.use_GPT2:
            hin = self.rmsnorm_mlp(h)
            if self.ln_Scale is not None:
                hin = hin * self.ln_Scale

        uv = self.c_fc(hin)

        if self.use_GPT2:
            u, v = torch.chunk(uv, 2, dim=-1)
        elif self.use_nGPT:
            suv = (self.suv * ((self.suv_init_value / self.suv_init_scaling) * (self.config.n_embd ** 0.5)))
            uv = suv * uv
            u, v = torch.chunk(uv, 2, dim=-1)

        elif self.use_aGPT:
            u, v = torch.chunk(uv, 2, dim=-1)
            u = u * self.scale_in_u
            v = v * self.scale_in_v
            v = v * math.sqrt(self.config.n_embd)

        x_mlp = u * self.silu(v)

        if self.use_aGPT:
            x_mlp = x_mlp * self.scale_acf

        h_mlp = self.mlp_c_proj(x_mlp)

        if self.use_aGPT:
            h_mlp = h_mlp * self.scale_out

        if self.use_GPT2:
            h = h + h_mlp

        elif self.use_nGPT or self.use_aGPT:
            lr = self.mlp_alpha * (self.mlp_alpha_init_value / self.mlp_alpha_init_scaling)
            lr = torch.abs(lr)

            A_norm = h

            if self.config.post_norm:
                B_norm = F.normalize(h_mlp, p=2, dim=-1)
            else:
                B_norm = h_mlp

            if self.config.alpha_correction:
                attn_correct = 1.0 / torch.sqrt(1 - 2 * lr * (1 - lr))
                h = (A_norm + lr * (B_norm - A_norm)) * attn_correct
            else:
                res = A_norm + lr * (B_norm - A_norm)
                h = F.normalize(res, p=2, dim=-1)
        return h


class anGPTConfig(PretrainedConfig):
    """Configuration for the GPT2 / nGPT / aGPT transformer defined in this file.

    Args:
        block_size: Stored on the config; the model only asserts it is not None.
        vocab_size: Rows of the token embedding and outputs of the LM head.
        n_layer: Number of ``Block`` layers.
        n_head: Number of attention heads.
        n_embd: Hidden size.
        base_scale: Std of the normal weight initialisation; also the storage
            scale of the nGPT learned scaling vectors.
        mode: One of ``'GPT2'``, ``'nGPT'`` or ``'aGPT'``.
        dropout: Probability for the embedding ``nn.Dropout`` module and for
            attention dropout.
        bias: Whether the linear layers inside ``Block`` have a bias.
        qk_norm: L2-normalise queries and keys before attention.
        alpha_correction: nGPT/aGPT - rescale the blended residual by
            ``1 / sqrt(1 - 2a(1 - a))`` instead of L2-normalising it.
        explicit_norm: Stored on the config; not read by the model code
            visible in this file.
        learn_alpha: aGPT - whether the residual step sizes are trainable.
        alpha_init_value: aGPT - effective initial value of the step sizes.
        post_norm: nGPT/aGPT - L2-normalise each branch output before
            blending (also selects the aGPT MLP output scale).
        GPT2_norm: GPT2 normaliser: ``'rms'``, ``'dyt'`` or ``'none'``.
        GPT2_DyT_alpha_att: ``alpha`` init of the attention-side DynamicTanh.
        GPT2_DyT_alpha_other: ``alpha`` init of the other DynamicTanh layers.
        GPT2_ln_scaling: GPT2 - multiply normalised inputs by a learnable
            scalar initialised to ``1 / sqrt(1 + layer)``.
        scaled_projection: Re-initialise ``*c_proj.weight`` with std
            ``base_scale / sqrt(2 * n_layer)``.
        aGPT_init_normalize: aGPT - L2-normalise Linear/Embedding weight rows
            at initialisation.
        aGPT_logits_scale: aGPT - learn a per-vocabulary logit scale.
        aGPT_pre_head_scale: aGPT - learn a per-channel scale before the LM head.
        aGPT_alpha_scale: aGPT - storage scale of the re-parameterised
            learnable scalars (step sizes, attention temperature, logit
            scale). The model multiplies by ``init_value / aGPT_alpha_scale``
            on use, so this changes the parameters' stored magnitude but not
            their effective initial value. Defaults to ``base_scale`` when
            left as ``None``.
        **kwargs: Forwarded to ``PretrainedConfig``.
    """

    model_type = "anGPT"

    def __init__(
            self,
            block_size: int = 1024,
            vocab_size: int = 50304,
            n_layer: int = 12,
            n_head: int = 12,
            n_embd: int = 1024,
            base_scale: float = 1.0 / (1024.0 ** 0.5),
            mode: str = 'GPT2',
            dropout: float = 0.0,
            bias: bool = False,
            qk_norm: bool = True,
            alpha_correction: bool = False,
            explicit_norm: bool = True,
            learn_alpha: bool = True,
            alpha_init_value: float = 0.05,
            post_norm: bool = True,
            GPT2_norm: str = 'rms',
            GPT2_DyT_alpha_att: float = 1.0,
            GPT2_DyT_alpha_other: float = 1.0,
            GPT2_ln_scaling: bool = False,
            scaled_projection: bool = False,
            aGPT_init_normalize: bool = True,
            aGPT_logits_scale: bool = False,
            aGPT_pre_head_scale: bool = True,
            aGPT_alpha_scale: Optional[float] = None,
            **kwargs,
    ):
        super().__init__(**kwargs)

        self.block_size = block_size
        self.vocab_size = vocab_size
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_embd = n_embd
        self.base_scale = base_scale
        self.mode = mode
        self.dropout = dropout
        self.bias = bias
        self.qk_norm = qk_norm
        self.alpha_correction = alpha_correction
        self.explicit_norm = explicit_norm
        self.learn_alpha = learn_alpha
        self.alpha_init_value = alpha_init_value
        self.post_norm = post_norm
        self.GPT2_norm = GPT2_norm
        self.GPT2_DyT_alpha_att = GPT2_DyT_alpha_att
        self.GPT2_DyT_alpha_other = GPT2_DyT_alpha_other
        self.GPT2_ln_scaling = GPT2_ln_scaling
        self.scaled_projection = scaled_projection
        self.aGPT_init_normalize = aGPT_init_normalize
        self.aGPT_logits_scale = aGPT_logits_scale
        self.aGPT_pre_head_scale = aGPT_pre_head_scale
        # The aGPT code paths read `config.aGPT_alpha_scale`; without an
        # explicit attribute a default-constructed config cannot build an
        # aGPT model. Fall back to base_scale, as the nGPT branch does.
        self.aGPT_alpha_scale = base_scale if aGPT_alpha_scale is None else aGPT_alpha_scale


class anTransformer(nn.Module):
    """Token embedding, a stack of ``Block`` layers and an (untied) LM head.

    Mode-specific extras, selected by ``config.mode``:

    * ``"GPT2"`` - final normaliser ``rmsnorm_f`` chosen by
      ``config.GPT2_norm``; with ``"dyt"`` the embeddings are additionally
      multiplied by a learnable scalar.
    * ``"nGPT"`` - learned per-vocabulary logit scale ``sz``.
    * ``"aGPT"`` - optional per-channel scale before the head
      (``aGPT_pre_head_scale``) and optional per-vocabulary logit scale
      (``aGPT_logits_scale``). The two options are independent.

    Raises:
        UserWarning: If ``config.mode`` is unknown.
        ValueError: If GPT2 mode is selected with an unknown ``GPT2_norm``.
    """

    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config

        if config.mode not in ("GPT2", "nGPT", "aGPT"):
            raise UserWarning(f"Unknown mode {config.mode}")
        self.use_GPT2 = config.mode == "GPT2"
        self.use_nGPT = config.mode == "nGPT"
        self.use_aGPT = config.mode == "aGPT"

        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            drop=nn.Dropout(config.dropout),
            h=nn.ModuleList([Block(config, il) for il in range(config.n_layer)])
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Initialise every Linear / Embedding created so far.
        self.apply(self._init_weights)
        if self.config.scaled_projection:
            # Output projections of both sub-layers get a depth-scaled std.
            for pn, p in self.named_parameters():
                if pn.endswith('c_proj.weight'):
                    torch.nn.init.normal_(p, mean=0.0, std=config.base_scale / math.sqrt(2 * config.n_layer))

        if self.use_GPT2:
            if config.GPT2_norm == "rms":
                self.rmsnorm_f = nn.RMSNorm(config.n_embd)
            elif config.GPT2_norm == "dyt":
                self.rmsnorm_f = DynamicTanh(config.n_embd, alpha_init_value=config.GPT2_DyT_alpha_other)
                # Attribute name (including its spelling) is part of the
                # state dict and therefore left unchanged.
                self.lernable_scalar = nn.Parameter(torch.ones(1))
            elif config.GPT2_norm == "none":
                # nn.Identity instead of a lambda so the module stays picklable.
                self.rmsnorm_f = nn.Identity()
            else:
                # Fail here rather than with an AttributeError in forward().
                raise ValueError(
                    f"Unknown GPT2_norm {config.GPT2_norm!r}; expected 'rms', 'dyt' or 'none'")

        elif self.use_nGPT:
            self.sz_init_value = 1.00
            self.sz_init_scaling = config.base_scale
            self.sz = torch.nn.Parameter(self.sz_init_scaling * torch.ones(config.vocab_size, dtype=torch.float32))

        elif self.use_aGPT:
            if self.config.aGPT_pre_head_scale:
                self.pre_head_scale = nn.Parameter(torch.ones(config.n_embd, dtype=torch.float32))
                self.scale_factor = math.sqrt(config.n_embd)

            # Checked independently of aGPT_pre_head_scale: forward() reads
            # `self.sz` whenever aGPT_logits_scale is set, so creating it only
            # in an `elif` raised AttributeError when both flags were enabled.
            if self.config.aGPT_logits_scale:
                self.sz_init_value = 1.00
                self.sz_init_scaling = config.aGPT_alpha_scale
                self.sz = torch.nn.Parameter(self.sz_init_scaling * torch.ones(config.vocab_size, dtype=torch.float32))

    def get_num_params(self, non_embedding=True):
        """Return the total number of parameters.

        ``non_embedding`` is accepted for API compatibility but currently
        ignored: every parameter, embeddings included, is counted.
        """
        return sum(p.numel() for p in self.parameters())

    def _init_weights(self, module):
        """Normal(0, base_scale) init for Linear/Embedding weights; zero biases.

        In aGPT mode with ``aGPT_init_normalize`` each weight row is then
        rescaled to unit L2 norm.
        """
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.base_scale)
            if self.use_aGPT and self.config.aGPT_init_normalize:
                module.weight.data.copy_(module.weight.data / module.weight.data.norm(p=2, dim=-1, keepdim=True))
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.base_scale)
            if self.use_aGPT and self.config.aGPT_init_normalize:
                module.weight.data.copy_(module.weight.data / module.weight.data.norm(p=2, dim=-1, keepdim=True))

    def forward(self, idx):
        """Map token ids ``idx`` to logits over the vocabulary.

        NOTE(review): ``self.transformer.drop`` is constructed but never
        applied here - confirm that is intended.
        """
        x = self.transformer.wte(idx)

        if self.use_GPT2 and self.config.GPT2_norm == "dyt":
            x = x * self.lernable_scalar

        for block in self.transformer.h:
            x = block(x)

        if self.use_GPT2:
            x = self.rmsnorm_f(x)

        if self.use_aGPT and self.config.aGPT_pre_head_scale:
            x = x * self.pre_head_scale * self.scale_factor

        logits = self.lm_head(x)

        # Learned per-vocabulary logit scale (stored at sz_init_scaling,
        # effective initial value sz_init_value).
        if self.use_nGPT or (self.use_aGPT and self.config.aGPT_logits_scale):
            sz = self.sz * (self.sz_init_value / self.sz_init_scaling)
            logits = sz * logits

        return logits


class anTransformerForCausalLM(PreTrainedModel, GenerationMixin):
    """Hugging Face wrapper around ``anTransformer`` for causal language modelling.

    The wrapped model has no key/value cache: every forward pass recomputes
    attention over the full ``input_ids`` sequence.
    """

    config_class = anGPTConfig
    base_model_prefix = 'model'
    supports_gradient_checkpointing = False
    # NOTE(review): these names do not match any module class visible in this
    # file - confirm whether 'Block' was intended.
    _no_split_modules = ['VarLenMHA', 'MHA']
    _supports_cache_class = False

    def __init__(self, config, padding_idx=0):
        """Build the model from ``config``; ``padding_idx`` is accepted but unused."""
        super().__init__(config)

        self.vocab_size = config.vocab_size
        self.model = anTransformer(config)

    def compile_model(self):
        """Replace the inner model with its ``torch.compile``-d version."""
        self.model = torch.compile(self.model)

    def forward(self, input_ids, **kwargs):
        """Return logits for ``input_ids``; all other keyword arguments are ignored."""
        logits = self.model(input_ids)
        return CausalLMOutput(logits=logits)

    def get_input_embeddings(self) -> nn.Module:
        return self.model.transformer.wte

    def set_input_embeddings(self, embeddings):
        self.model.transformer.wte = embeddings

    def get_output_embeddings(self) -> nn.Module:
        return self.model.lm_head

    def set_output_embeddings(self, embeddings):
        self.model.lm_head = embeddings

    def tie_weights(self):
        """No-op: the input embedding and the LM head are not tied."""
        pass

    def prepare_inputs_for_generation(
            self,
            input_ids: torch.LongTensor = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            attention_mask: Optional[torch.Tensor] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            use_cache: bool = True,
            num_logits_to_keep: Optional[int] = None,
            **kwargs
    ):
        """Assemble the keyword arguments for one generation step.

        ``input_ids`` is always passed through in full. ``forward`` ignores
        ``past_key_values`` and the model keeps no attention cache, so
        truncating to the last token (as cache-aware models do) would make
        the model attend to a single token only.

        Returns:
            Dict with ``input_ids`` (or ``inputs_embeds``) plus
            ``past_key_values``, ``use_cache``, ``attention_mask`` and
            ``num_logits_to_keep``.
        """
        if inputs_embeds is not None and past_key_values is None:
            # NOTE(review): forward() requires `input_ids` and does not accept
            # `inputs_embeds`, so this path cannot currently be consumed by
            # the model - confirm whether it should be supported.
            model_inputs = {'inputs_embeds': inputs_embeds}
        else:
            model_inputs = {'input_ids': input_ids.contiguous()}

        model_inputs.update({
            'past_key_values': past_key_values,
            'use_cache': use_cache,
            'attention_mask': attention_mask,
            'num_logits_to_keep': num_logits_to_keep,
        })
        return model_inputs
