import torch
import torch.nn as nn
from typing import Tuple, Optional, Union
import torch.nn.functional as F
from ....utils.registry import layer_registry, norm_registry, init_registry
from ...base import Base
from ..util import apply_rotary_pos_emb
from ...op import get_eff_attention, ascan, ascan_fixed_size



@layer_registry.register("ratplus16localprefixfgate")
class RATPlus16LocalPrefixFgate(Base):
    """Gated attention layer registered as ``"ratplus16localprefixfgate"``.

    One forward pass does the following:

    1. Normalizes the input and projects it into five ``d_model``-wide
       streams ``o, q, k, x, g`` (see :meth:`prepare_input`).
    2. If ``apply_re`` is set, runs ``ascan_fixed_size`` over the
       concatenated keys/values using ``sigmoid(g)`` as the gate.
    3. Applies rotary position embeddings to the queries and keys.
    4. Attends either with PyTorch's causal scaled-dot-product attention
       (dense) or with the kernel returned by ``get_eff_attention``.
    5. Multiplies the result by ``sigmoid(o)``, projects it with
       ``out_proj`` and adds the residual input.
    """

    def __init__(
        self,
        d_model,
        num_head=16,
        bias=False,
        init=None,
        ln="rmsnorm",
        chunk_size=1,
        chunk_size1=64,
        total_size=64,
        prefix_size=0,
        local_size=0,
        apply_re=True,
        mix_train=True,
        **kwargs,
    ):
        """Build the projections, the input norm and the attention kernel.

        Args:
            d_model: Model width; must be divisible by ``num_head``.
            num_head: Number of attention heads.
            bias: Must be ``False`` (asserted).
            init: Stored on the instance as ``self.init``; not used here.
            ln: Key looked up in ``norm_registry`` for the input norm.
            chunk_size: Must be 1 when ``mix_train`` is true (asserted).
            chunk_size1: Chunk size handed to ``get_eff_attention``; either
                an int or a per-layer sequence indexed by ``layer_id``.
            total_size: Size argument forwarded to ``ascan_fixed_size``.
            prefix_size: Int or per-layer sequence, forwarded to
                ``get_eff_attention``.
            local_size: Int or per-layer sequence, forwarded to
                ``get_eff_attention``.
            apply_re: Whether ``forward`` runs the gated scan over k/x.
            mix_train: Whether training may pick dense or sparse attention
                per call through the ``mode`` keyword of ``forward``.
            **kwargs: Optional ``device`` (default ``"cuda"``), ``dtype``
                (default ``torch.float32``) and ``layer_id`` (default 0).
        """
        super().__init__()
        factory_kwargs = {
            "device": kwargs.get("device", "cuda"),
            "dtype": kwargs.get("dtype", torch.float32),
        }
        self.layer_id = kwargs.get("layer_id", 0)
        self.d_model = d_model
        self.num_head = num_head
        assert self.d_model % self.num_head == 0
        assert bias is False
        self.d_head = self.d_model // self.num_head
        self.softmax_scale = self.d_head ** -0.5
        # Fused projection; split order is o, q, k, x, g (see prepare_input).
        self.in_proj = nn.Linear(d_model, 5 * self.d_model, bias=bias, **factory_kwargs)
        self.input_norm = norm_registry[ln](self.d_model, eps=1.0e-6, **factory_kwargs)
        self.out_proj = nn.Linear(self.d_model, self.d_model, bias=bias, **factory_kwargs)
        self.init = init
        self.mix_train = mix_train
        self.apply_re = apply_re

        # Resolve per-layer settings first so that every attribute derived
        # from them holds the value for this layer, never the whole sequence.
        self.chunk_size1 = self._per_layer(chunk_size1, self.layer_id)
        self.prefix_size = self._per_layer(prefix_size, self.layer_id)
        self.local_size = self._per_layer(local_size, self.layer_id)
        self.total_size = total_size

        if self.mix_train:
            self.chunk_size = chunk_size
            assert self.chunk_size == 1
        else:
            self.chunk_size = self.chunk_size1

        self.eff_attention = get_eff_attention(
            softmax_scale=self.softmax_scale,
            chunk_size=self.chunk_size1,
            local_size=self.local_size,
            prefix_size=self.prefix_size,
        )

    @staticmethod
    def _per_layer(value, layer_id):
        """Return ``value`` if it is an int, otherwise ``value[layer_id]``."""
        return value if isinstance(value, int) else value[layer_id]

    def init_weight(self, init_config):
        """Delegate weight initialization to the base class."""
        super().init_weight(init_config)

    def apply_rope(self, q, k, **kwargs):
        """Apply rotary position embeddings to ``q`` and ``k``.

        Args:
            q: Query tensor.
            k: Key tensor; both outputs are cast to its dtype.
            **kwargs: Must contain ``rope``, an indexable pair of 2-D
                tensors. Each gets two leading singleton axes before being
                passed to ``apply_rotary_pos_emb``.
                # presumably (cos, sin) tables — confirm against the caller

        Returns:
            Tuple ``(q_rope, k_rope)``, both in ``k.dtype``.

        Raises:
            TypeError: If the ``rope`` keyword argument is missing or None.
        """
        rotary_pos_emb = kwargs.get("rope")
        if rotary_pos_emb is None:
            raise TypeError(
                "apply_rope() requires a 'rope' keyword argument holding the "
                "rotary position embedding pair"
            )
        first = rotary_pos_emb[0][None, None, :, :]
        second = rotary_pos_emb[1][None, None, :, :]
        q_rope, k_rope = apply_rotary_pos_emb(q, k, first, second)
        return q_rope.to(k.dtype), k_rope.to(k.dtype)

    def prepare_input(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project ``hidden_states`` and split it into five streams.

        Returns:
            ``(o, q, k, x, g)``, each ``d_model`` wide along the last axis.
            In ``forward``, ``o`` gates the attention output, ``q``/``k``/``x``
            are the attention query/key/value and ``g`` gates the scan.
        """
        inp = self.in_proj(hidden_states)
        # in_proj emits 5 * d_model features, so equal d_model-sized pieces
        # yield exactly five tensors.
        o, q, k, x, g = torch.split(inp, self.d_model, dim=-1)
        return o, q, k, x, g

    def forward(self, hidden_states, cache=None, **kwargs):
        """Run the layer on a full sequence.

        Args:
            hidden_states: 3-D tensor unpacked as ``(bs, seq_len, _)``.
            cache: Must be ``None`` (asserted).
            **kwargs: ``rope`` is required (see :meth:`apply_rope`).
                ``mode`` (default ``"dense"``) is read only when
                ``mix_train`` is set and the module is in training mode;
                ``"sparse"`` selects the ``get_eff_attention`` kernel.

        Returns:
            ``out_proj(attention * sigmoid(o)) + hidden_states``.
        """
        assert cache is None
        bs, seq_len, _ = hidden_states.shape
        # Graph begins
        shortcut = hidden_states
        hidden_states = self.input_norm(hidden_states)
        o, q, k, x, g = self.prepare_input(hidden_states)

        def to_heads(t):
            # (bs, seq_len, d_model) -> (bs, num_head, seq_len, d_head)
            return t.reshape(bs, seq_len, self.num_head, self.d_head).transpose(1, 2)

        q, k, x = to_heads(q), to_heads(k), to_heads(x)

        # RNN begins
        if self.apply_re:
            # The gate is only needed for the scan, so it is computed here
            # rather than unconditionally. Sigmoid is taken in float32.
            gate = to_heads(torch.sigmoid(g.to(torch.float32)))
            # k and x share one gate: concatenate them on the feature axis
            # and tile the gate to match, so a single scan handles both.
            gated_kx = ascan_fixed_size(gate.repeat(1, 1, 1, 2), torch.cat([k, x], dim=-1), self.total_size)
            # NOTE(review): the scan output is cast to bfloat16 regardless of
            # the input dtype — confirm this is intended outside bf16 runs.
            gated_k = gated_kx[..., :self.d_head].to(torch.bfloat16)
            gated_x = gated_kx[..., self.d_head:].to(torch.bfloat16)
        else:
            gated_k, gated_x = k, x
        q, gated_k = self.apply_rope(q, gated_k, **kwargs)

        # cross-chunk attention
        # Dense attention is used only while training with mix_train and a
        # mode other than "sparse"; every other case uses the eff kernel.
        use_dense = (
            self.mix_train
            and self.training
            and kwargs.get("mode", "dense") != "sparse"
        )
        if use_dense:
            dense = F.scaled_dot_product_attention(q, gated_k, gated_x, is_causal=True)
            out = dense.transpose(1, 2).reshape(bs, seq_len, -1)
        else:
            out = self.eff_attention(q, gated_k, gated_x)
        out = out * torch.sigmoid(o)
        final_out = self.out_proj(out) + shortcut
        return final_out

    def step(self, hidden_states, cache=None, **kwargs): # (b, 1, d)
        """Single-token decoding step; not implemented (returns ``None``)."""
        pass

    def nflops(self, bs, seq_len):
        """FLOP estimate; not implemented (returns ``None``)."""
        pass

    @property
    def nparams(self,):
        """Total number of elements across all parameters of this module."""
        return sum(w.numel() for w in self.parameters())

    @staticmethod
    def get_ckpt_name(model_config):
        """Build the checkpoint name from the fields of ``model_config``."""
        suffix = (
            f"t{model_config.total_size}"
            f"l{model_config.chunk_size}"
            f"l{model_config.chunk_size1}"
            f"p{model_config.prefix_size}"
            f"w{model_config.local_size}"
            f"re{model_config.apply_re}"
            f"mixtrain{model_config.mix_train}"
            "ropeogate"
            "fFalse"
        )
        return model_config._name_ + suffix

    def extra_repr(self):
        """Return the configuration summary shown in the module's repr."""
        return f"d_model={self.d_model}, nhead={self.num_head}, chunk_size1={self.chunk_size1}, prefix_size={self.prefix_size}, local_size={self.local_size}, re={self.apply_re}, mixtrain={self.mix_train}"