import torch
import torch.nn as nn
from typing import Tuple, Optional, Union
import torch.nn.functional as F
from ....utils.registry import layer_registry, norm_registry, init_registry
from ...base import Base
from ..util import apply_rotary_pos_emb
from ...op import get_eff_attention, ascan
from ..cache import RATPlusSingleLayerCache, RATPlusFullSingleLayerCache


@layer_registry.register("ratplus16localprefixfgatesimple")
class RATPlus16LocalPrefixFgateSimple(Base):
    """Attention layer whose keys/values may first pass through a gated recurrence.

    Data flow (see ``forward`` / ``step``):

    1. The input is normalised and projected by ``in_proj`` into an output
       gate ``o``, queries ``q``, keys ``k``, values ``x`` and -- only when
       ``apply_re`` is true -- a recurrence gate ``g``.
    2. With ``apply_re`` the keys/values are blended along the sequence as
       ``h_t = g_t * h_{t-1} + (1 - g_t) * kx_t`` (this is the exact update
       ``step`` performs; ``forward`` delegates to ``ascan``).
    3. Rotary position embedding is applied to the queries and (gated) keys.
    4. Attention is computed, multiplied by ``sigmoid(o)``, projected by
       ``out_proj`` and added to the residual shortcut.

    Args:
        d_model: Model width; must be divisible by ``num_head``.
        num_head: Number of attention heads.
        bias: Must be ``False`` (asserted).
        init: Stored on ``self.init``; not otherwise used in this class.
        ln: Key into ``norm_registry`` selecting the input normalisation.
        chunk_size: Stored as ``self.chunk_size`` when ``mix_train`` is true;
            must then equal 1 (asserted).
        chunk_size1: Chunk size handed to ``get_eff_attention``; an int or a
            per-layer sequence indexed by ``layer_id``.
        prefix_size: Int or per-layer sequence, handed to ``get_eff_attention``.
        local_size: Int or per-layer sequence, handed to ``get_eff_attention``.
        apply_re: Whether the gated recurrence (and its ``g`` projection) is used.
        mix_train: When true, training may use dense causal attention
            depending on the ``mode`` keyword given to ``forward``.
        **kwargs: ``device`` (default ``"cuda"``), ``dtype`` (default
            ``torch.float32``) and ``layer_id`` (default 0) are read.
    """

    def __init__(
        self,
        d_model,
        num_head=16,
        bias=False,
        init=None,
        ln="rmsnorm",
        chunk_size=1,
        chunk_size1=64,
        prefix_size=0,
        local_size=0,
        apply_re=True,
        mix_train=True,
        **kwargs,
    ):
        super().__init__()
        factory_kwargs = {"device": kwargs.get("device", "cuda"),
                          "dtype": kwargs.get("dtype", torch.float32)}
        self.layer_id = kwargs.get("layer_id", 0)
        self.d_model = d_model
        self.num_head = num_head
        assert self.d_model % self.num_head == 0
        assert bias is False
        self.d_head = self.d_model // self.num_head
        self.softmax_scale = self.d_head ** -0.5
        # Projection layout matches the split in `prepare_input`:
        # o (output gate), q, k, x and, only with apply_re, g (recurrence gate).
        num_proj = 5 if apply_re else 4
        self.in_proj = nn.Linear(d_model, num_proj * self.d_model, bias=bias, **factory_kwargs)
        self.input_norm = norm_registry[ln](self.d_model, eps=1.0e-6, **factory_kwargs)
        self.out_proj = nn.Linear(self.d_model, self.d_model, bias=bias, **factory_kwargs)
        self.init = init
        self.mix_train = mix_train
        self.apply_re = apply_re
        # Resolve the per-layer settings first so everything below sees plain values.
        self.chunk_size1 = self._resolve_per_layer(chunk_size1, self.layer_id)
        self.prefix_size = self._resolve_per_layer(prefix_size, self.layer_id)
        self.local_size = self._resolve_per_layer(local_size, self.layer_id)
        if self.mix_train:
            self.chunk_size = chunk_size
            assert self.chunk_size == 1
        else:
            # Use the resolved value: the raw argument may be a per-layer list.
            self.chunk_size = self.chunk_size1
        self.eff_attention = get_eff_attention(
            softmax_scale=self.softmax_scale,
            chunk_size=self.chunk_size1,
            local_size=self.local_size,
            prefix_size=self.prefix_size,
        )

    @staticmethod
    def _resolve_per_layer(value, layer_id):
        """Return ``value`` if it is an int, otherwise its ``layer_id``-th entry."""
        return value if isinstance(value, int) else value[layer_id]

    def _split_heads(self, tensor, bs, seq_len):
        """Reshape ``(bs, seq_len, d_model)`` to ``(bs, num_head, seq_len, d_head)``."""
        return tensor.reshape(bs, seq_len, self.num_head, self.d_head).transpose(1, 2)

    def init_weight(self, init_config):
        """Delegate weight initialisation to the base class."""
        super().init_weight(init_config)

    def apply_rope(self, q, k, **kwargs):
        """Apply rotary position embedding to ``q`` and ``k``.

        Reads ``kwargs["rope"]``, whose elements 0 and 1 are each given two
        leading singleton dimensions and passed to ``apply_rotary_pos_emb``
        (presumably the cos/sin tables -- confirm against that helper).

        Returns:
            The rotated ``(q, k)``, both cast to the dtype of the input ``k``.

        Raises:
            NotImplementedError: If no ``rope`` keyword argument is supplied.
        """
        rotary_pos_emb = kwargs.get("rope")
        if rotary_pos_emb is None:
            raise NotImplementedError("a 'rope' keyword argument is required")
        first = rotary_pos_emb[0][None, None, :, :]
        second = rotary_pos_emb[1][None, None, :, :]
        q_rope, k_rope = apply_rotary_pos_emb(q, k, first, second)
        return q_rope.to(k.dtype), k_rope.to(k.dtype)

    def prepare_input(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project ``hidden_states`` and split the result into ``(o, q, k, x, g)``.

        Every chunk is ``d_model`` wide on the last dimension, except ``g``
        which has width 0 when ``apply_re`` is false.
        """
        inp = self.in_proj(hidden_states)
        gate_width = self.d_model if self.apply_re else 0
        o, q, k, x, g = torch.split(inp, [self.d_model] * 4 + [gate_width], dim=-1)
        return o, q, k, x, g

    def forward(self, hidden_states: torch.Tensor, cache: Optional[Union[RATPlusFullSingleLayerCache, RATPlusSingleLayerCache]] = None, **kwargs):
        """Process a whole sequence (training or prefill).

        Args:
            hidden_states: Tensor of shape ``(bs, seq_len, d_model)``.
            cache: Optional cache; when given, the recurrent state at
                ``seq_pos`` and the keys/values are written into it.
            **kwargs: ``rope`` (required, see ``apply_rope``), ``seq_pos``
                (default ``seq_len - 1``) and ``mode`` (``"dense"`` by
                default, or ``"sparse"``; only consulted while training
                with ``mix_train``).

        Returns:
            ``out_proj(attention * sigmoid(o)) + hidden_states``.
        """
        bs, seq_len, _ = hidden_states.shape
        seq_pos = kwargs.get("seq_pos", seq_len - 1)
        # Graph begins
        shortcut = hidden_states
        hidden_states = self.input_norm(hidden_states)
        o, q, k, x, g = self.prepare_input(hidden_states)
        # RNN begins
        if self.apply_re:
            # The gate and the scan run in float32.
            g = torch.sigmoid(g.to(torch.float32))
            q, k, x, g = [self._split_heads(m, bs, seq_len) for m in (q, k, x, g)]
            # One gate per head channel, duplicated to cover the concatenated [k, x].
            g = g.repeat(1, 1, 1, 2)
            # Presumably an associative scan evaluating
            # h_t = g_t * h_{t-1} + (1 - g_t) * kx_t -- confirm against op.ascan.
            gated_kx = ascan(g, (1.0 - g) * torch.cat([k, x], dim=-1))
            # NOTE(review): bfloat16 is hard-coded here regardless of the module dtype.
            gated_k = gated_kx[..., :self.d_head].to(torch.bfloat16)
            gated_x = gated_kx[..., self.d_head:].to(torch.bfloat16)
            if cache is not None:
                # Store the pre-RoPE state at seq_pos so `step` can continue from it.
                rows = slice(cache.bs_start, cache.bs_start + bs)
                cache.lastkcache[rows].copy_(gated_k[:, :, seq_pos: seq_pos + 1])
                cache.lastvcache[rows].copy_(gated_x[:, :, seq_pos: seq_pos + 1])
        else:
            q, k, x = [self._split_heads(m, bs, seq_len).contiguous() for m in (q, k, x)]
            gated_k, gated_x = k, x
        q, gated_k = self.apply_rope(q, gated_k, **kwargs)

        if cache is not None:
            cache.update_kv_prefill(seq_pos, bs, gated_k, gated_x)
        # cross-chunk attention
        # Dense causal attention is used only while training with mix_train and
        # a mode other than "sparse"; every other case uses eff_attention.
        use_dense = self.mix_train and self.training and kwargs.get("mode", "dense") != "sparse"
        if use_dense:
            out = F.scaled_dot_product_attention(q, gated_k, gated_x, is_causal=True)
            out = out.transpose(1, 2).reshape(bs, seq_len, -1)
        else:
            out = self.eff_attention(q, gated_k, gated_x)
        out = out * torch.sigmoid(o)
        final_out = self.out_proj(out) + shortcut
        return final_out

    def step(self, hidden_states, cache: Union[RATPlusSingleLayerCache, RATPlusFullSingleLayerCache], seq_pos=0, **kwargs): # (b, 1, d)
        """Process one decoding step using and updating ``cache``.

        Args:
            hidden_states: Tensor of shape ``(bs, seq_len, d_model)``; the
                original comment indicates ``seq_len`` is 1 here.
            cache: Cache providing ``lastkcache`` / ``lastvcache`` (recurrent
                state) plus ``get_kv_step`` / ``update_kv_step``.
            seq_pos: Position forwarded to the cache methods.
            **kwargs: Must contain ``rope`` (see ``apply_rope``).

        Returns:
            ``out_proj(attention * sigmoid(o)) + hidden_states``.
        """
        bs, seq_len, _ = hidden_states.shape
        # Graph begins
        shortcut = hidden_states
        hidden_states = self.input_norm(hidden_states)
        o, q, k, x, g = self.prepare_input(hidden_states)
        # RNN begins
        if self.apply_re:
            g = torch.sigmoid(g.to(torch.float32))
            q, k, x, g = [self._split_heads(m, bs, seq_len) for m in (q, k, x, g)]
            rows = slice(cache.bs_start, cache.bs_start + bs)
            lastkcache, lastvcache = cache.lastkcache[rows], cache.lastvcache[rows]
            # Single recurrence update: new_state = g * old_state + (1 - g) * input.
            # NOTE(review): bfloat16 is hard-coded here regardless of the module dtype.
            gated_k = (g * lastkcache + (1.0 - g) * k).to(torch.bfloat16)
            gated_x = (g * lastvcache + (1.0 - g) * x).to(torch.bfloat16)
            # Write the pre-RoPE state back in place for the next step.
            lastkcache.copy_(gated_k)
            lastvcache.copy_(gated_x)
        else:
            q, k, x = [self._split_heads(m, bs, seq_len) for m in (q, k, x)]
            gated_k, gated_x = k, x
        q, gated_k = self.apply_rope(q, gated_k, **kwargs)
        kcache, vcache = cache.get_kv_step(seq_pos, bs, gated_k, gated_x)
        out = F.scaled_dot_product_attention(q, kcache, vcache, is_causal=False)
        out = out.transpose(1, 2).reshape(bs, seq_len, self.d_model)
        # Update cache, update the new token, and move seq_start and seq_end
        cache.update_kv_step(seq_pos, bs, gated_k, gated_x)
        out = out * torch.sigmoid(o)
        final_out = self.out_proj(out) + shortcut
        return final_out

    def nflops(self, bs, seq_len):
        """FLOP estimate placeholder; not implemented, returns ``None``."""
        pass

    @property
    def nparams(self,):
        """Total number of elements across this module's parameters."""
        return sum(w.numel() for w in self.parameters())

    @staticmethod
    def get_ckpt_name(model_config):
        """Build the checkpoint name that encodes this layer's configuration.

        ``chunk_size1`` and ``local_size`` are abbreviated to their first two
        entries (concatenated as strings) when they are lists longer than two;
        otherwise they are formatted as-is.
        """
        def _abbreviate(value):
            # Exact `list` check and `> 2` threshold are kept so that existing
            # checkpoint names remain byte-identical.
            if type(value) is list and len(value) > 2:
                return str(value[0]) + str(value[1])
            return value

        chunk_size1 = _abbreviate(model_config.chunk_size1)
        local_size = _abbreviate(model_config.local_size)
        # The trailing "fFalse" previously carried an "onlyA" marker.
        return (
            model_config._name_
            + f"l{model_config.chunk_size}l{chunk_size1}p{model_config.prefix_size}w{local_size}"
            + f"re{model_config.apply_re}"
            + f"mixtrain{model_config.mix_train}"
            + "ropepostogate"
            + "fFalse"
        )

    def extra_repr(self):
        """Summarise the resolved hyper-parameters for ``repr(module)``."""
        return f"d_model={self.d_model}, nhead={self.num_head}, chunk_size1={self.chunk_size1}, prefix_size={self.prefix_size}, local_size={self.local_size}, re={self.apply_re}, mixtrain={self.mix_train}"