# -*- coding: utf-8 -*-
from __future__ import annotations

import math
import warnings
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union, Dict

import torch
import torch.nn as nn
import torch.utils.checkpoint
from transformers.activations import ACT2FN
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import (BaseModelOutputWithPast,
                                           CausalLMOutputWithPast)
from dataclasses import dataclass
from transformers.utils import ModelOutput
@dataclass
class BaseModelOutputWithPast_with_two_caches(ModelOutput):
    """Base-model output container with two separate cache slots.

    NOTE(review): no live code path in this file constructs this class; its
    only use is in a commented-out ``return`` in ``TransformerModel_rnn.forward``.
    Field order is significant for ``ModelOutput`` tuple conversion, so do not
    reorder the fields.
    """

    # Hidden states produced by the model (presumably after the final norm — confirm with caller).
    last_hidden_state: Optional[torch.FloatTensor] = None
    # First cache slot (`past_key_values1` in the model's forward signature).
    past_key_values1: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # Second cache slot (`all_past_key_values` in the model's forward signature).
    all_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # Optional per-layer hidden states.
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Optional per-layer attention outputs.
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class CausalLMOutputWithPast_with_two_caches(ModelOutput):
    """Causal-LM output container with two separate cache slots.

    NOTE(review): no live code path in this file constructs this class; its
    only use is in a commented-out ``return`` in
    ``TransformerForCausalLM_rnn.forward``. Unlike the stock
    ``CausalLMOutputWithPast`` returned by that method, ``logits`` is declared
    before ``loss`` here, which changes positional/tuple order — confirm this
    is intended before switching the model over to this class.
    """

    # Language-model head output.
    logits: torch.FloatTensor = None
    # Training loss, when labels were supplied.
    loss: Optional[torch.FloatTensor] = None
    # First cache slot (`past_key_values1` in the model's forward signature).
    past_key_values1: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # Second cache slot (`all_past_key_values` in the model's forward signature).
    all_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # Optional per-layer hidden states.
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Optional per-layer attention outputs.
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging

# from fla.layers.attn import Attention
from configuration_transformer_rnn import TransformerConfig_rnn

# import sys
# from attn_rnn import Attention_rnn ###########################################################
# from attn_svd import Attention_svd ###########################################################
# from attn import Attention         ###########################################################
# from gated_deltanet import GatedDeltaNet ###########################################################
# from rwkv7 import RWKV7Attention ###########################################################
# from attn_gated_delta import GatedDeltaNet_attention ###########################################################
# from scattering_mixer2 import Scattering_Mixer ###########################################################
from task_aware_delta_net import Task_Aware_Delta_Net ###########################################################

# from moe_rnn import CustomGRUCell, CustomRNNCell
from ttt_cross_layer import TTT_Cross_Layer

# from fla.models.transformer.configuration_transformer import TransformerConfig
from fla.models.utils import Cache
from fla.modules import (FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss,
                         RMSNorm)
from fla.modules.activations import swiglu_linear
from fla.modules.layernorm import rms_norm_linear

if TYPE_CHECKING:
    from transformers.processing_utils import Unpack

logger = logging.get_logger(__name__)

class TransformerMLP(nn.Module):
    """Gated feed-forward block with an optional fused RMSNorm in front.

    ``gate_proj`` produces ``2 * intermediate_size`` features that are split
    in half and handed, together with ``down_proj``'s parameters, to
    ``swiglu_linear``. When ``norm_first`` is set, the input normalisation and
    the gate projection are done by a single ``rms_norm_linear`` call.

    Args:
        hidden_size: Width of the input and output features.
        hidden_ratio: Expansion ratio used to derive ``intermediate_size``
            when that is not given; ``None`` means 4.
        intermediate_size: Inner width. When ``None`` it is
            ``2/3 * hidden_size * hidden_ratio`` rounded up to a multiple of 256.
        hidden_act: Key into ``ACT2FN``. The resulting ``act_fn`` is stored
            but is not called by ``forward``.
        norm_first: Whether this block owns an ``RMSNorm`` applied to its input.
        norm_eps: Epsilon for that ``RMSNorm``.
    """

    def __init__(
        self,
        hidden_size: int,
        hidden_ratio: Optional[int] = None,
        intermediate_size: Optional[int] = None,
        hidden_act: str = 'swish',
        norm_first: bool = True,
        norm_eps: float = 1e-5
    ) -> TransformerMLP:
        super().__init__()

        ratio = 4 if hidden_ratio is None else hidden_ratio
        if intermediate_size is None:
            # Target 2/3 * hidden_size * ratio, then round up to the next multiple of 256.
            raw_width = int(hidden_size * ratio * 2 / 3)
            intermediate_size = -(-raw_width // 256) * 256

        self.hidden_size = hidden_size
        self.hidden_ratio = ratio
        self.intermediate_size = intermediate_size
        self.norm_first = norm_first

        if norm_first:
            self.norm = RMSNorm(hidden_size=hidden_size, eps=norm_eps)

        # One projection yields both halves (gate and value) of the gated unit.
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[hidden_act]

    def forward(
        self,
        x: torch.Tensor,
        **kwargs: Unpack[Any]
    ) -> torch.Tensor:
        """Apply the (optionally pre-normalised) gated MLP to ``x``.

        Extra keyword arguments are accepted and ignored.
        """
        if not self.norm_first:
            projected = self.gate_proj(x)
        else:
            # Fused RMSNorm + gate projection.
            projected = rms_norm_linear(
                x,
                self.norm.weight,
                self.norm.bias,
                self.gate_proj.weight,
                self.gate_proj.bias,
            )
        gate, y = projected.chunk(2, -1)
        return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias)

class TransformerMLP_svd(nn.Module):
    """Gated feed-forward block that can also emit a sigmoid "reflector" signal.

    The main path matches ``TransformerMLP``: ``gate_proj`` yields
    ``2 * intermediate_size`` features that are split in half and passed to
    ``swiglu_linear`` with ``down_proj``'s parameters. With ``reflect=True``
    the same two halves are also passed through ``swiglu_linear`` with the
    ``reflector_qkvo`` parameters (a ``Linear`` from ``intermediate_size`` to
    ``4 * hidden_size``, with bias) and squashed by a sigmoid.

    Args:
        hidden_size: Width of the input and output features.
        hidden_ratio: Expansion ratio used to derive ``intermediate_size``
            when that is not given; ``None`` means 4.
        intermediate_size: Inner width. When ``None`` it is
            ``2/3 * hidden_size * hidden_ratio`` rounded up to a multiple of 256.
        hidden_act: Key into ``ACT2FN``. The resulting ``act_fn`` is stored
            but is not called by ``forward``.
        norm_first: Whether this block owns an ``RMSNorm`` applied to its input.
        norm_eps: Epsilon for that ``RMSNorm``.
    """

    def __init__(
        self,
        hidden_size: int,
        hidden_ratio: Optional[int] = None,
        intermediate_size: Optional[int] = None,
        hidden_act: str = 'swish',
        norm_first: bool = True,
        norm_eps: float = 1e-5
    ) -> None:
        super().__init__()

        self.hidden_size = hidden_size
        # the final number of params is `hidden_ratio * hidden_size^2`
        # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio`
        if hidden_ratio is None:
            hidden_ratio = 4
        if intermediate_size is None:
            intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)
            intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256)
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.norm_first = norm_first

        if norm_first:
            self.norm = RMSNorm(hidden_size=hidden_size, eps=norm_eps)

        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[hidden_act]

        # Extra head sharing the gated activations; 4 * hidden_size outputs
        # (presumably one hidden_size-wide slice each for q, k, v and o — confirm with caller).
        self.reflector_qkvo = nn.Linear(self.intermediate_size, self.hidden_size * 4)

    def forward(
        self,
        x: torch.Tensor,
        reflect: bool = False,
        **kwargs: Unpack[Any]
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run the gated MLP, optionally returning the reflector signal too.

        Args:
            x: Input features.
            reflect: When true, also compute the sigmoid-activated
                ``reflector_qkvo`` output.
            **kwargs: Accepted and ignored.

        Returns:
            The MLP output alone when ``reflect`` is false, otherwise the pair
            ``(mlp_output, reflector_signal)``.
        """
        if self.norm_first:
            # Fused RMSNorm + gate projection.
            x = rms_norm_linear(x, self.norm.weight, self.norm.bias, self.gate_proj.weight, self.gate_proj.bias)
        else:
            x = self.gate_proj(x)
        gate, y = x.chunk(2, -1)
        hidden_states = swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias)
        if not reflect:
            return hidden_states

        reflector_qkvo = swiglu_linear(gate, y, self.reflector_qkvo.weight, self.reflector_qkvo.bias)
        # Functional sigmoid: avoids building a throwaway nn.Sigmoid module on every call.
        reflector_qkvo = torch.sigmoid(reflector_qkvo)
        return hidden_states, reflector_qkvo

class TransformerBlock_rnn(nn.Module):
    """One residual layer: a ``Task_Aware_Delta_Net`` mixer followed by a ``TransformerMLP``.

    When ``config.norm_first`` is false the block owns ``attn_norm`` and
    ``mlp_norm`` and applies them itself; otherwise normalisation is delegated
    to the sub-modules, which receive ``norm_first=True``.
    """

    def __init__(self, config: TransformerConfig_rnn, layer_idx: int):
        super().__init__()

        self.hidden_size = config.hidden_size
        self.layer_idx = layer_idx
        owns_norms = not config.norm_first

        if owns_norms:
            self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)

        self.head_dim = config.hidden_size // config.num_heads
        self.Task_Aware_Delta_Net = Task_Aware_Delta_Net(
            hidden_size=config.hidden_size,
            head_dim=self.head_dim,
            num_heads=config.num_heads,
            mode='chunk',
            rope_theta=config.rope_theta,
            max_position_embeddings=config.max_position_embeddings,
            norm_first=config.norm_first,
            norm_eps=config.norm_eps,
            layer_idx=layer_idx,
            concept_dim=config.concept_dim
        )

        if owns_norms:
            self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
        self.mlp = TransformerMLP(
            hidden_size=config.hidden_size,
            hidden_ratio=config.hidden_ratio,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            norm_first=config.norm_first,
            norm_eps=config.norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values1: Optional[Tuple[torch.Tensor]] = None,
        all_past_key_values: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        h_old: Optional[torch.Tensor] = None,
        rnn_router: Optional[nn.Module] = None,
        params: Optional[Dict] = None,
        **kwargs: Unpack[Any]
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Run the mixer and the MLP, each wrapped in a residual connection.

        ``h_old``, ``rnn_router`` and ``params`` are forwarded untouched to the
        mixer, which returns updated ``h_new`` / ``params`` values.

        Returns:
            A tuple laid out as
            ``(hidden_states, [attentions], [past_key_values1, all_past_key_values], h_new, params)``.
            ``attentions`` is present only when ``output_attentions`` is truthy
            and the two caches only when ``use_cache`` is truthy. ``h_new`` and
            ``params`` are always the last two entries.
        """
        skip = hidden_states
        if hasattr(self, 'attn_norm'):
            hidden_states = self.attn_norm(hidden_states)

        mixer_out = self.Task_Aware_Delta_Net(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            past_key_values1=past_key_values1,
            all_past_key_values=all_past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            rnn_router=rnn_router,
            h_old=h_old,
            params=params,
            **kwargs
        )
        hidden_states, attentions, past_key_values1, all_past_key_values, h_new, params = mixer_out

        if hasattr(self, 'mlp_norm'):
            # The norm is called with the residual and a positional `True` and
            # hands back both the normalised tensor and the updated residual.
            hidden_states, skip = self.mlp_norm(hidden_states, skip, True)
        else:
            hidden_states = skip + hidden_states
            skip = hidden_states

        hidden_states = skip + self.mlp(hidden_states, **kwargs)

        # Optional entries sit in the middle so the last two slots are fixed.
        collected = [hidden_states]
        if output_attentions:
            collected.append(attentions)
        if use_cache:
            collected.extend((past_key_values1, all_past_key_values))
        collected.extend((h_new, params))
        return tuple(collected)

class TransformerPreTrainedModel_rnn(PreTrainedModel):
    """Shared ``PreTrainedModel`` base: config class, checkpointing flags and weight init."""

    config_class = TransformerConfig_rnn
    supports_gradient_checkpointing = True
    _no_split_modules = ['TransformerBlock_rnn']

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(
        self,
        module: nn.Module,
        rescale_prenorm_residual: bool = False,
        num_residuals_per_layer: int = 2,
    ):
        """Initialise the parameters of a single ``module``.

        * ``nn.Linear`` / ``nn.Conv1d``: weights from
          ``N(0, config.initializer_range)``, biases zeroed.
        * ``nn.Embedding``: weights from ``N(0, config.initializer_range)``,
          padding row zeroed.
        * Any other module exposing ``reset_parameters``: that method is called.

        Args:
            module: The module to initialise.
            rescale_prenorm_residual: When true, direct children named
                ``o_proj`` / ``down_proj`` have their weight re-initialised
                and scaled by ``1 / sqrt(num_residuals_per_layer * num_hidden_layers)``.
            num_residuals_per_layer: Number of residual additions per layer
                used in that scale factor.
        """
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif hasattr(module, 'reset_parameters'):
            module.reset_parameters()

        if rescale_prenorm_residual:
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
            #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            depth_scale = math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
            for name, p in module.named_parameters():
                if name in ("o_proj.weight", "down_proj.weight"):
                    # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                    # Following Pytorch init, except scale by 1/sqrt(2 * n_layer).
                    # This hook can run more than once on the same parameter, so draw a
                    # fresh PyTorch-default init before scaling; dividing the existing
                    # values in place would shrink them again on every call.
                    nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                    with torch.no_grad():
                        p /= depth_scale


class TransformerModel_rnn(TransformerPreTrainedModel_rnn):
    """Embedding + stack of ``TransformerBlock_rnn`` layers + final ``RMSNorm``.

    A single ``TTT_Cross_Layer`` instance (``rnn_router``) is owned by the
    model and handed to every layer call. The ``h_new`` / ``params`` values
    each layer returns are fed into the next layer as ``h_old`` / ``params``.
    """

    def __init__(
        self,
        config: TransformerConfig_rnn
    ) -> None:
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.concept_dim = config.concept_dim

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([TransformerBlock_rnn(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps)

        # Built before `post_init()` so that the weight-initialisation pass covers it
        # when this model is used on its own, the same as when it is wrapped by the
        # causal-LM head (whose own `post_init()` already reaches this module).
        self.rnn_router = TTT_Cross_Layer(config)

        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embeddings

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values1: Optional[List[torch.FloatTensor]] = None,
        all_past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[Any]
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Embed the inputs, run every layer, and apply the final norm.

        Args:
            input_ids: Token ids; mutually exclusive with ``inputs_embeds``.
            attention_mask: Forwarded unchanged to every layer.
            past_key_values1: First cache; converted with
                ``Cache.from_legacy_cache`` when caching is on and it is not
                already a ``Cache``.
            all_past_key_values: Second cache; converted the same way.
            inputs_embeds: Pre-computed embeddings used instead of ``input_ids``.
            use_cache: Defaults to ``config.use_cache`` outside training and
                ``False`` during training.
            output_attentions: A truthy value triggers a warning and is reset
                to ``False``; ``None`` falls back to ``config.output_attentions``.
            output_hidden_states: Collect the input of every layer plus the
                final normalised output.
            return_dict: Return a ``BaseModelOutputWithPast`` instead of a tuple.
            **kwargs: Forwarded to every layer.

        Returns:
            ``BaseModelOutputWithPast`` whose ``past_key_values`` holds the
            first cache only, or a tuple of the non-``None`` values among
            ``(hidden_states, first cache, all_hidden_states, all_attns)``.

        Raises:
            ValueError: If both or neither of ``input_ids`` / ``inputs_embeds``
                are given.
        """
        if output_attentions:
            warnings.warn(
                "`TransformerModel` does not support output attention weights now, so `output_attentions` is set to `False`."
            )
            output_attentions = False
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if use_cache and not isinstance(past_key_values1, Cache):
            past_key_values1 = Cache.from_legacy_cache(past_key_values1)
        if use_cache and not isinstance(all_past_key_values, Cache):
            all_past_key_values = Cache.from_legacy_cache(all_past_key_values)

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_attns = () if output_attentions else None
        next_cache1 = None
        # NOTE(review): the second cache is tracked here but is not part of the
        # value returned below; only `next_cache1` is exposed to callers.
        next_cache2 = None
        # Cross-layer state: each layer's `h_new` / `params` feed the next layer.
        h_old = None
        params = None

        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # Positional arguments follow the order of `TransformerBlock_rnn.forward`.
                layer_outputs = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    attention_mask,
                    past_key_values1,
                    all_past_key_values,
                    output_attentions,
                    use_cache,
                    h_old=h_old,
                    params=params,
                    rnn_router=self.rnn_router,
                    **kwargs
                )
            else:
                layer_outputs = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    past_key_values1=past_key_values1,
                    all_past_key_values=all_past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    h_old=h_old,
                    params=params,
                    rnn_router=self.rnn_router,
                    **kwargs
                )

            hidden_states = layer_outputs[0]

            # The layer always places `h_new` and `params` in its last two slots.
            h_old = layer_outputs[-2]
            params = layer_outputs[-1]
            if use_cache:
                # The caches follow the optional attentions entry.
                next_cache1 = layer_outputs[2 if output_attentions else 1]
                next_cache2 = layer_outputs[3 if output_attentions else 2]

            if output_attentions:
                all_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache1, all_hidden_states, all_attns] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache1,
            hidden_states=all_hidden_states,
            attentions=all_attns
        )
    
class TransformerForCausalLM_rnn(TransformerPreTrainedModel_rnn, GenerationMixin):
    """``TransformerModel_rnn`` with a bias-free linear language-model head."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = TransformerModel_rnn(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's token embedding module."""
        return self.model.embeddings

    def set_input_embeddings(self, value):
        """Replace the backbone's token embedding module."""
        self.model.embeddings = value

    def get_output_embeddings(self):
        """Return the language-model head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the language-model head."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Replace the backbone model."""
        self.model = decoder

    def get_decoder(self):
        """Return the backbone model."""
        return self.model

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor = None,
        past_key_values1: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        all_past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: bool = True,
        num_logits_to_keep: Optional[int] = None,
        **kwargs
    ):
        """Build the keyword arguments for one generation step.

        Returns a dict with either ``inputs_embeds`` (first step only, when
        provided) or ``input_ids``, plus ``past_key_values1``, ``use_cache``
        and ``attention_mask``. ``num_logits_to_keep`` is included only when
        it is not ``None``, so ``forward`` keeps its own default otherwise.

        NOTE(review): ``all_past_key_values`` is accepted but not forwarded to
        the model — confirm whether the second cache should be passed along.
        """
        # only last token for `inputs_ids` if the `past_key_values` is passed along.
        if past_key_values1 is not None:
            input_ids = input_ids[:, -1:]
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values1 is None:
            model_inputs = {'inputs_embeds': inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard.
            # Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {'input_ids': input_ids.contiguous()}

        # Forwarding `None` would override the integer default in `forward`,
        # which negates this value when slicing the hidden states.
        if num_logits_to_keep is not None:
            model_inputs['num_logits_to_keep'] = num_logits_to_keep

        model_inputs.update({
            'past_key_values1': past_key_values1,
            'use_cache': use_cache,
            'attention_mask': attention_mask,
        })
        return model_inputs

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values1: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        all_past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        num_logits_to_keep: Optional[int] = 0,
        **kwargs: Unpack[Any]
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the backbone, then compute logits and/or the loss.

        Args:
            input_ids, attention_mask, past_key_values1, all_past_key_values,
            inputs_embeds, use_cache, output_attentions, output_hidden_states,
            return_dict, **kwargs: Forwarded to ``TransformerModel_rnn``.
            labels: Targets for the loss. They are used as given; this method
                does not shift them.
            num_logits_to_keep: Logits are computed for the last
                ``num_logits_to_keep`` positions only. ``0`` or ``None`` keeps
                all positions.

        Returns:
            ``CausalLMOutputWithPast`` or, when ``return_dict`` is false, a
            tuple ``([loss], logits, *backbone_outputs[1:])``. ``logits`` is
            ``None`` when ``config.fuse_cross_entropy`` is set and the module
            is in training mode, because the fused loss consumes the hidden
            states directly.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # `None` means "no restriction", same as 0: `x[:, -0:]` selects every position.
        if num_logits_to_keep is None:
            num_logits_to_keep = 0

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values1=past_key_values1,
            all_past_key_values=all_past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs
        )

        hidden_states = outputs[0]
        fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
        logits = None if fuse_linear_and_cross_entropy else self.lm_head(hidden_states[:, -num_logits_to_keep:])
        loss = None
        if labels is not None:
            if self.config.fuse_cross_entropy:
                if fuse_linear_and_cross_entropy:
                    loss_fct = FusedLinearCrossEntropyLoss()
                else:
                    loss_fct = FusedCrossEntropyLoss(inplace_backward=True)
            else:
                loss_fct = nn.CrossEntropyLoss()
            # NOTE(review): labels are not shifted here, so position t is scored
            # against labels[t]. Callers presumably pass pre-shifted labels — confirm.
            if fuse_linear_and_cross_entropy:
                loss = loss_fct(hidden_states.view(-1, self.config.hidden_size),
                                labels.view(-1),
                                self.lm_head.weight,
                                self.lm_head.bias)
            else:
                loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
    
if __name__ == '__main__':
    # Smoke test: build a model from an inline config, run one forward pass on
    # random token ids on the GPU, and print every field of the returned output.
    config = TransformerConfig_rnn(
        concept_dim=128,
        attention_bias=False,
        bos_token_id=1,
        eos_token_id=2,
        fuse_cross_entropy=True,
        fuse_norm=True,
        hidden_act="swish",
        hidden_size=1024,
        initializer_range=0.02,
        max_position_embeddings=8192,
        model_type="transformer_rnn",
        num_heads=16,
        num_hidden_layers=24,
        norm_eps=1e-06,
        tie_word_embeddings=True,
        use_cache=True,
        vocab_size=32000,
    )
    model = TransformerForCausalLM_rnn(config).cuda().to(torch.bfloat16)
    input_ids = torch.randint(0, 100, (2, 70)).cuda()
    attention_mask = torch.ones_like(input_ids).cuda()
    output = model(input_ids, attention_mask=attention_mask)
    print(output)
    print(output.loss)
    print(output.logits)
    # The model returns `CausalLMOutputWithPast`, whose cache field is
    # `past_key_values`; it has no `all_past_key_values` attribute.
    print(output.past_key_values)
    print(output.hidden_states)
    print(output.attentions)
