
import torch
from torch import nn
from typing import Optional, Tuple, List, Union, Callable


from transformers.models.gpt2.modeling_gpt2 import Conv1D, GPT2MLP#, eager_attention_forward
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
)
from transformers import (
    GPT2Model,
)

from .mot_module import MoTMod, MoTLayerNorm, apply_residual

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every key/value head ``n_rep`` times along the head axis.

    Equivalent to ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``.

    Args:
        hidden_states: 4-D tensor; the four axes are unpacked as
            ``(batch, num_key_value_heads, seq_len, head_dim)``.
        n_rep: Number of copies to make of each head.

    Returns:
        Tensor of shape ``(batch, num_key_value_heads * n_rep, seq_len, head_dim)``.
        When ``n_rep == 1`` the input tensor is returned unchanged.
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a singleton "repeat" axis and expand it (no copy) before folding it
    # into the head axis with the final reshape.
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

def sdpa_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    # scaling: Optional[float] = None,
    is_causal: Optional[bool] = None,
    **kwargs,
) -> Tuple[torch.Tensor, None]:
    """Run attention through ``torch.nn.functional.scaled_dot_product_attention``.

    Args:
        module: Attention module. If it exposes ``num_key_value_groups`` greater
            than 1, ``key`` and ``value`` heads are repeated that many times
            along dim 1 before attention.
        query: Query tensor, indexed here as ``(batch, heads, q_len, head_dim)``.
        key: Key tensor with the key length on dim ``-2``.
        value: Value tensor laid out like ``key``.
        attention_mask: Optional 4-D mask forwarded as ``attn_mask``; its last
            axis is cropped to the key length.
        dropout: Dropout probability forwarded as ``dropout_p``.
        is_causal: Whether SDPA should apply its own causal mask. When ``None``
            it is enabled only if no mask is given and ``q_len > 1``.
        **kwargs: An optional ``scaling`` entry is forwarded as SDPA's ``scale``;
            everything else is ignored.

    Returns:
        ``(attn_output, None)`` where ``attn_output`` has the head and sequence
        axes swapped back, i.e. ``(batch, q_len, heads, head_dim)``. Attention
        weights are not available from the fused kernel, hence ``None``.
    """
    # Expand grouped key/value heads inline (same result as repeat_interleave on dim 1).
    kv_groups = getattr(module, "num_key_value_groups", 1)
    if kv_groups > 1:
        key = key.repeat_interleave(kv_groups, dim=1)
        value = value.repeat_interleave(kv_groups, dim=1)

    if attention_mask is not None:
        # The mask may have been built for a longer key length; crop it.
        attention_mask = attention_mask[:, :, :, : key.shape[-2]]

    if is_causal is None:
        # An explicit mask already encodes causality; a single query token needs none.
        is_causal = attention_mask is None and query.shape[2] > 1

    sdpa_kwargs = {}
    if kwargs.get("scaling") is not None:
        sdpa_kwargs["scale"] = kwargs["scaling"]

    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query.contiguous(),
        key.contiguous(),
        value.contiguous(),
        attn_mask=attention_mask,
        dropout_p=dropout,
        is_causal=bool(is_causal),
        **sdpa_kwargs,
    )
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, None
    
        
class MoTGPT2Block(nn.Module):
    """GPT-2 transformer block whose sub-layers are wrapped in modality-aware modules.

    The QKV projection, output projection, MLP and both layer norms are wrapped
    in ``MoTMod`` / ``MoTLayerNorm`` (from ``.mot_module``) with ``modality_num``
    modalities, while attention itself is computed jointly over the whole
    sequence by the eager implementation in ``_attn``.

    NOTE(review): only the layer norms receive ``type_ids`` via
    ``update_typeids``; how the ``MoTMod`` wrappers learn the per-token modality
    is not visible in this file -- confirm against ``mot_module``.
    """

    def __init__(self, config, layer_idx=None, modality_num=2):
        """Build the block.

        Args:
            config: GPT-2 style config; reads ``hidden_size``,
                ``num_attention_heads``, ``scale_attn_weights``,
                ``max_position_embeddings``, ``attn_pdrop``, ``resid_pdrop``,
                ``layer_norm_epsilon`` and ``n_inner``.
            layer_idx: Index of this block in the stack. Stored so the
                inverse-layer-index scaling branch of ``_attn`` can read it.
            modality_num: Number of modalities passed to the MoT wrappers.
        """
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.split_size = self.embed_dim
        # `_attn` reads this when `scale_attn_by_inverse_layer_idx` is enabled.
        self.layer_idx = layer_idx
        self.type_ids = None  # torch.LongTensor=None

        self.scale_attn_weights = config.scale_attn_weights
        self.is_cross_attention = False
        self.scale_attn_by_inverse_layer_idx = False
        max_positions = config.max_position_embeddings
        # Lower-triangular boolean causal mask, sliced per call in `_attn`.
        self.register_buffer(
            "bias",
            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
                1, 1, max_positions, max_positions
            ),
            persistent=False,
        )

        self.modality_num = modality_num
        modality_channels = [self.embed_dim] * modality_num
        self.modality_channels = modality_channels

        self.attn_out_dim = self.embed_dim * 3
        self.c_attn = MoTMod(Conv1D(3 * self.embed_dim, self.embed_dim), modality_num, out_dim=self.attn_out_dim)
        self.c_proj = MoTMod(Conv1D(self.embed_dim, self.embed_dim), modality_num)

        # TODO: add dropout
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)

        # TODO: add gate ffn layers
        self.ln_1 = MoTLayerNorm(nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon), self.modality_num)
        self.ln_2 = MoTLayerNorm(nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon), self.modality_num)

        inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size

        self.mlp = MoTMod(GPT2MLP(inner_dim, config), modality_num)

    def update_typeids(self, type_ids):
        """Record the per-token modality ids and forward them to both layer norms.

        The ids are stored under ``valid_pos`` (kept for existing readers) and
        under ``type_ids``, which is the attribute ``apply_qkv_proj`` consults.
        """
        self.valid_pos = type_ids
        self.type_ids = type_ids
        self.ln_1.update_typeids(type_ids)
        self.ln_2.update_typeids(type_ids)

    def apply_qkv_proj(self, hidden_states: torch.Tensor, module):
        """Project each token with the sub-module selected by its modality id.

        Args:
            hidden_states: 3-D input; the output buffer triples its last axis.
            module: Indexable collection of callables, one per modality; entry
                ``i`` is applied to the tokens whose type id equals ``i``.

        Returns:
            Tensor shaped like ``hidden_states`` with the last axis tripled.
            Tokens whose type id matches no modality stay zero.
        """
        type_ids = self.type_ids
        if type_ids is None:
            # No ids registered: treat every token as modality 0. The default is
            # computed per call (not cached) so it always matches the input shape.
            type_ids = torch.zeros_like(hidden_states[..., 0]).long()

        projected = torch.zeros_like(hidden_states).repeat(1, 1, 3)
        for type_id in range(len(self.modality_channels)):
            selected = type_ids == type_id
            # Skip modalities with no tokens in this batch.
            if selected.any().item():
                projected[selected] = module[type_id](hidden_states[selected])

        return projected

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        """Eager causal scaled-dot-product attention.

        Args:
            query, key, value: Tensors with sequence length on dim ``-2`` and the
                head dimension on dim ``-1``.
            attention_mask: Optional additive mask added to the raw scores.
            head_mask: Optional multiplicative mask applied to the attention
                probabilities after dropout.

        Returns:
            ``(attn_output, attn_weights)``.
        """
        attn_weights = torch.matmul(query, key.transpose(-1, -2))

        if self.scale_attn_weights:
            attn_weights = attn_weights / torch.full(
                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
            )
        # Layer-wise attention scaling
        if self.scale_attn_by_inverse_layer_idx:
            attn_weights = attn_weights / float(self.layer_idx + 1)

        if not self.is_cross_attention:
            query_length, key_length = query.size(-2), key.size(-2)
            # With cached keys the queries are the last `query_length` positions.
            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
            mask_value = torch.full(
                [],
                torch.finfo(attn_weights.dtype).min,
                dtype=attn_weights.dtype,
                device=attn_weights.device,
            )
            attn_weights = torch.where(causal_mask, attn_weights, mask_value)

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)
        return attn_output, attn_weights

    def forward_attn_shared(
        self,
        hidden_states: torch.Tensor,
        # type_ids: torch.Tensor=None,
        layer_past: List[torch.Tensor] = None,    # FIXME: do not support kv-cache here
        use_cache: bool = False,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        **kwargs
    ):
        """Self-attention over the full sequence, shared by all modalities.

        Args:
            hidden_states: Input to ``c_attn``; its output is split into
                query/key/value along dim 2.
            layer_past: Optional ``(past_key, past_value)`` pair that is
                concatenated in front of the new keys/values on dim ``-2``.
            use_cache: When exactly ``True``, the (concatenated) key/value pair
                is returned for reuse.
            attention_mask: Optional additive mask forwarded to ``_attn``.
            head_mask: Optional per-head mask forwarded to ``_attn``.
            **kwargs: Ignored.

        Returns:
            ``(attn_output, present)`` where ``present`` is ``(key, value)`` or
            ``None``.
        """
        qkv = self.c_attn(hidden_states)
        query_states, key_states, value_states = qkv.split(self.split_size, dim=2)

        # (..., embed) -> (..., heads, head_dim), then swap dims 1 and 2 so the
        # head axis precedes the sequence axis.
        shape_q = (*query_states.shape[:-1], -1, self.head_dim)
        shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
        query_states = query_states.view(shape_q).transpose(1, 2)
        key_states = key_states.view(shape_kv).transpose(1, 2)
        value_states = value_states.view(shape_kv).transpose(1, 2)

        if layer_past is not None:
            past_key, past_value = layer_past
            key_states = torch.cat((past_key, key_states), dim=-2)
            value_states = torch.cat((past_value, value_states), dim=-2)

        present = (key_states, value_states) if use_cache is True else None

        attn_output, _ = self._attn(
            query_states, key_states, value_states,
            attention_mask=attention_mask, head_mask=head_mask,
        )
        # Merge the heads back into a single feature axis.
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(*attn_output.shape[:-2], -1)

        final_attn_output = self.c_proj(attn_output)
        return (final_attn_output, present)

    def forward(self,
            hidden_states: Optional[Tuple[torch.FloatTensor]],
            layer_past: Optional[Tuple[torch.Tensor]] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            use_cache: Optional[bool] = False,
            head_mask: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.Tensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
            output_attentions: Optional[bool] = False,
        ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        """Pre-norm attention and MLP sub-layers, each followed by ``apply_residual``.

        Args:
            hidden_states: Block input.
            layer_past: Optional cached ``(key, value)`` pair.
            attention_mask: Optional additive attention mask.
            use_cache: Whether to return the key/value pair.
            head_mask: Optional per-head mask applied to attention probabilities.
            encoder_hidden_states: Accepted so the block can be called with the
                keyword set ``MoTGPT2Model.forward`` uses; ignored (no
                cross-attention).
            encoder_attention_mask: Accepted for the same reason; ignored.
            output_attentions: Accepted for the same reason; ignored, attention
                weights are never returned.

        Returns:
            ``(hidden_states, present)`` when ``use_cache`` is truthy, otherwise
            ``(hidden_states,)``.
        """
        residual = hidden_states
        normed = self.ln_1(hidden_states)
        attn_output, present = self.forward_attn_shared(
            normed,
            layer_past=layer_past,
            use_cache=use_cache,
            attention_mask=attention_mask,
            head_mask=head_mask,
        )
        hidden_states = apply_residual(attn_output, residual)

        residual = hidden_states
        normed = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(normed)
        hidden_states = apply_residual(feed_forward_hidden_states, residual)

        if use_cache:
            return (hidden_states, present)
        return (hidden_states,)
    

class MoTGPT2Model(GPT2Model):
    """GPT-2 backbone built from ``MoTGPT2Block`` layers with a modality-aware final norm.

    After the parent constructor runs, the token/position embeddings, the
    embedding dropout, the block stack ``h`` and the final norm ``ln_f`` are
    re-created here; ``h`` holds ``MoTGPT2Block`` instances and ``ln_f`` is a
    ``MoTLayerNorm``.
    """

    def __init__(self, config, modality_num=2):
        """Build the model.

        Args:
            config: GPT-2 style config; reads ``hidden_size``, ``vocab_size``,
                ``max_position_embeddings``, ``embd_pdrop``,
                ``num_hidden_layers`` and ``layer_norm_epsilon``.
            modality_num: Number of modalities passed to every block and to the
                final norm.
        """
        super().__init__(config)
        self.embed_dim = config.hidden_size

        self.modality_num = modality_num
        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)

        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([MoTGPT2Block(config, layer_idx=i, modality_num=modality_num) for i in range(config.num_hidden_layers)])
        self.ln_f = MoTLayerNorm(nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon), self.modality_num)

    def update_typeids(self, type_ids):
        """Store ``type_ids`` as ``valid_pos`` and push them to every block and to ``ln_f``."""
        self.valid_pos = type_ids
        for block in self.h:
            block.update_typeids(type_ids)
        self.ln_f.update_typeids(type_ids)

    def forward(
        self,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Run the block stack on pre-computed input embeddings.

        Args:
            inputs_embeds: Input embeddings (required; there is no ``input_ids``
                path). Shape, batch size and device are read from
                ``inputs_embeds[0]``.
                NOTE(review): for a plain ``(batch, seq, dim)`` tensor that makes
                ``batch_size`` equal to the sequence length -- confirm the
                layout callers actually pass.
            past_key_values: Optional per-layer ``(key, value)`` pairs.
            attention_mask: Optional mask; it is viewed as ``(batch_size, -1)``
                and then converted according to ``self._attn_implementation``.
            token_type_ids: Optional ids embedded with ``wte`` and added to the
                hidden states.
            position_ids: Optional position ids; default to a range starting at
                the cached length.
            head_mask: Optional head mask, expanded with ``get_head_mask``. A
                per-layer entry is forwarded to the block only when it is not
                ``None``.
            encoder_hidden_states: Unused; the blocks have no cross-attention.
            encoder_attention_mask: Unused.
            use_cache: Whether to collect each block's key/value pair.
            output_attentions: The blocks in this file do not return attention
                weights, so the attention tuples stay empty unless a block
                returns extra outputs.
            output_hidden_states: Whether to collect the input to every block
                plus the final output.
            return_dict: When falsy (including the default ``None``) a tuple of
                the non-``None`` results is returned instead of a
                ``BaseModelOutputWithPastAndCrossAttentions``.

        Returns:
            The last hidden state is a Python list built by iterating the
            ``ln_f`` output and viewing every item as ``output_shape``.
        """
        input_shape = inputs_embeds[0].size()[:-1]
        batch_size = inputs_embeds[0].shape[0]
        device = inputs_embeds[0].device

        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * len(self.h))
        else:
            past_length = past_key_values[0][0].size(-2)
        if position_ids is None:
            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0)

        position_embeds = self.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds

        # Attention mask.
        _use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
        attention_mask = attention_mask.view(batch_size, -1) if attention_mask is not None else None
        if self._attn_implementation == "flash_attention_2":
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        elif _use_sdpa:
            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                attention_mask=attention_mask,
                input_shape=(batch_size, input_shape[-1]),
                inputs_embeds=inputs_embeds[0],
                past_key_values_length=past_length,
            )
        elif attention_mask is not None:
            # Turn the 0/1 mask into an additive mask broadcastable over heads
            # and query positions.
            attention_mask = attention_mask[:, None, None, :]

            attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        if token_type_ids is not None:
            token_type_embeds = self.wte(token_type_ids)
            hidden_states = hidden_states + token_type_embeds
        hidden_states = self.drop(hidden_states)

        output_shape = (-1,) + input_shape[1:] + (hidden_states[0].size(-1),)

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
        all_hidden_states = () if output_hidden_states else None

        # Position of the attention weights in a block's output tuple, if present.
        attn_index = 2 if use_cache else 1

        for i, block in enumerate(self.h):
            layer_past = past_key_values[i]

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # Pass only the keywords `MoTGPT2Block.forward` declares. The
            # encoder inputs and `output_attentions` are not consumed by the
            # block, and passing undeclared keywords raises TypeError.
            block_kwargs = {
                "layer_past": layer_past,
                "attention_mask": attention_mask,
                "use_cache": use_cache,
            }
            # `get_head_mask` yields None per layer when no mask was requested;
            # forward it only when the caller asked for head masking.
            if head_mask[i] is not None:
                block_kwargs["head_mask"] = head_mask[i]

            outputs = block(hidden_states, **block_kwargs)
            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)

            # Collect attention outputs only if the block actually returned
            # them, instead of indexing past the end of the tuple.
            if output_attentions and len(outputs) > attn_index:
                all_self_attentions = all_self_attentions + (outputs[attn_index],)
                if self.config.add_cross_attention and len(outputs) > attn_index + 1:
                    all_cross_attentions = all_cross_attentions + (outputs[attn_index + 1],)

        hidden_states = self.ln_f(hidden_states)

        hidden_states = [hid.view(output_shape) for hid in hidden_states]
        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
                if v is not None
            )

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
