import torch
import torch.nn as nn
import torch.nn.functional as F

import copy
from typing import Optional, Any, Union, Callable, Tuple

import torch
from torch import Tensor
from torch.nn import MultiheadAttention
from torch.nn import Dropout
from torch.nn import Linear
from torch.nn import LayerNorm
from transformers.models.gptj.modeling_gptj import GPTJMLP, GPTJBlock, GPTJAttention


def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
    """Build a sinusoidal position table.

    Args:
        num_pos: Number of positions (rows) in the table.
        dim: Width of the table; frequencies are generated for every second
            index in ``range(0, dim)``.

    Returns:
        A float tensor with ``num_pos`` rows whose left half holds the sines
        and whose right half holds the cosines of ``position * inv_freq``.
    """
    # Geometric frequency schedule: 1 / 10000^(2i / dim).
    exponents = torch.arange(0, dim, 2) / dim
    inv_freq = 1.0 / (10000 ** exponents)

    # Outer product position x frequency, written as a broadcasted multiply.
    positions = torch.arange(num_pos, dtype=torch.float)
    angles = (positions.unsqueeze(1) * inv_freq.unsqueeze(0)).float()

    # Sines first, then cosines, concatenated along the feature axis.
    return torch.cat((angles.sin(), angles.cos()), dim=1)

def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
    """Rotate adjacent feature pairs of a 4-D tensor: ``(a, b) -> (-b, a)``.

    The last axis is read as consecutive pairs; each pair ``(x[2i], x[2i+1])``
    becomes ``(-x[2i+1], x[2i])``. The output has the same shape as the input.
    """
    even_feats = x[:, :, :, 0::2]
    odd_feats = x[:, :, :, 1::2]
    # Stack to (..., d/2, 2) so that flattening the last two axes interleaves
    # the negated odd features with the even ones.
    pairs = torch.stack((odd_feats.neg(), even_feats), dim=-1)
    return pairs.flatten(start_dim=-2)

def apply_rotary_pos_emb(tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embedding to ``tensor``.

    ``sin`` and ``cos`` get a singleton axis inserted at position 2 and have
    each entry of axis 3 duplicated, so that one sin/cos value covers each
    adjacent feature pair of ``tensor``. The result is
    ``tensor * cos + rotate_every_two(tensor) * sin``.
    """
    # Insert a broadcast axis at dim 2, then duplicate every entry of the
    # feature axis so it lines up with the interleaved pairs.
    sin_full = sin.unsqueeze(2).repeat_interleave(2, dim=3)
    cos_full = cos.unsqueeze(2).repeat_interleave(2, dim=3)

    cos_part = tensor * cos_full
    sin_part = rotate_every_two(tensor) * sin_full
    return cos_part + sin_part


class GPTJBiAttention(GPTJAttention):
    """GPT-J attention without a causal mask, usable as self- or cross-attention.

    Differences from the parent visible in this class:

    * ``_attn`` never applies a causal mask; only the additive
      ``attention_mask`` supplied by the caller restricts attention.
    * ``forward`` takes separate query and key/value inputs
      (``kv_hidden_states`` defaults to ``q_hidden_states``) and returns only
      the attention output tensor.
    """

    def __init__(self, config):
        super().__init__(config)

        # Rotary embeddings are applied to per-head tensors after the head
        # split, so when no explicit ``rotary_dim`` is configured the table
        # must be ``head_dim`` wide. Falling back to ``hidden_size`` produces
        # sin/cos tensors that cannot broadcast against the heads unless there
        # is a single head.
        pos_embd_dim = config.rotary_dim or self.head_dim
        max_positions = config.max_position_embeddings
        # Plain tensor attribute (not registered as a buffer here); it is moved
        # to the right device lazily by ``_get_embed_positions``.
        self.embed_positions = create_sinusoidal_positions(max_positions, pos_embd_dim)

    def _attn(
        self,
        query,
        key,
        value,
        attention_mask=None,
        head_mask=None,
    ):
        """Scaled dot-product attention with no causal masking.

        Args:
            query, key, value: Per-head tensors; the last two axes of ``query``
                and ``key`` are multiplied as ``query @ key^T``.
            attention_mask: Optional additive mask added to the raw scores
                before the softmax.
            head_mask: Optional multiplicative mask applied to the attention
                probabilities after dropout.

        Returns:
            Tuple ``(attn_output, attn_weights)``.
        """
        # Keep the attention weights computation in fp32 to avoid overflow issues
        query = query.to(torch.float32)
        key = key.to(torch.float32)

        attn_weights = torch.matmul(query, key.transpose(-1, -2))
        attn_weights = attn_weights / self.scale_attn

        if attention_mask is not None:
            # Apply the attention mask
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
        # Back to the value dtype so the final matmul runs in that precision.
        attn_weights = attn_weights.to(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights

    def _get_embed_positions(self, position_ids):
        """Return the sinusoidal table on ``position_ids``' device, tiled per batch row.

        The device-moved table is cached back on ``self`` so the transfer
        happens only when the device changes.
        """
        embed_positions = self.embed_positions
        if embed_positions.device != position_ids.device:
            embed_positions = embed_positions.to(position_ids.device)
            self.embed_positions = embed_positions
        return embed_positions.repeat(position_ids.shape[0], 1, 1)

    def forward(
        self,
        q_hidden_states: torch.FloatTensor,
        kv_hidden_states: Optional[torch.FloatTensor] = None,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tensor:
        """Attend from ``q_hidden_states`` to ``kv_hidden_states``.

        Args:
            q_hidden_states: Input projected to queries
                (assumed ``(batch, seq, hidden)`` -- TODO confirm with callers).
            kv_hidden_states: Input projected to keys and values; defaults to
                ``q_hidden_states`` (self-attention).
            layer_past: Optional ``(past_key, past_value)`` pair that is
                prepended to the new keys/values along the sequence axis.
            attention_mask: Optional additive mask forwarded to ``_attn``.
            position_ids: Integer positions used to look up the rotary sin/cos
                rows. When ``None``, sequential positions starting after any
                cached keys are generated for the query sequence.
            head_mask: Optional per-head multiplicative mask forwarded to
                ``_attn``.
            use_cache: Accepted for signature compatibility; no cache is
                returned by this method.
            output_attentions: Accepted for signature compatibility; attention
                weights are not returned by this method.

        Returns:
            The attention output after head merge, output projection and
            residual dropout.

        Note:
            The same sin/cos rows are applied to both queries and keys, so
            ``kv_hidden_states`` must have a sequence length that broadcasts
            against ``position_ids`` (in practice, the same length as the
            query sequence).
        """
        if kv_hidden_states is None:
            kv_hidden_states = q_hidden_states

        if position_ids is None:
            # The signature advertises ``None`` as valid, so fall back to
            # sequential positions, offset by the length of any cached keys.
            past_length = layer_past[0].size(-2) if layer_past is not None else 0
            batch_size, query_length = q_hidden_states.shape[0], q_hidden_states.shape[1]
            position_ids = torch.arange(
                past_length,
                past_length + query_length,
                dtype=torch.long,
                device=q_hidden_states.device,
            ).unsqueeze(0).expand(batch_size, -1)

        query = self.q_proj(q_hidden_states)
        key = self.k_proj(kv_hidden_states)
        value = self.v_proj(kv_hidden_states)

        query = self._split_heads(query, self.num_attention_heads, self.head_dim, True)
        key = self._split_heads(key, self.num_attention_heads, self.head_dim, True)
        value = self._split_heads(value, self.num_attention_heads, self.head_dim, False)

        embed_positions = self._get_embed_positions(position_ids)

        # Gather the table rows selected by ``position_ids`` and split them
        # into their sine and cosine halves.
        repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1])
        sincos = torch.gather(embed_positions, 1, repeated_position_ids)
        sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)

        if self.rotary_dim is not None:
            # Rotate only the first ``rotary_dim`` features of every head and
            # pass the remaining features through unchanged.
            k_rot = key[:, :, :, : self.rotary_dim]
            k_pass = key[:, :, :, self.rotary_dim :]

            q_rot = query[:, :, :, : self.rotary_dim]
            q_pass = query[:, :, :, self.rotary_dim :]

            k_rot = apply_rotary_pos_emb(k_rot, sin, cos)
            q_rot = apply_rotary_pos_emb(q_rot, sin, cos)

            key = torch.cat([k_rot, k_pass], dim=-1)
            query = torch.cat([q_rot, q_pass], dim=-1)
        else:
            key = apply_rotary_pos_emb(key, sin, cos)
            query = apply_rotary_pos_emb(query, sin, cos)

        # Swap axes 1 and 2 so key/query use the layout ``_attn`` multiplies over.
        key = key.permute(0, 2, 1, 3)
        query = query.permute(0, 2, 1, 3)

        if layer_past is not None:
            past_key = layer_past[0]
            past_value = layer_past[1]
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        # Attention weights are computed by ``_attn`` but not exposed here.
        attn_output, _ = self._attn(query, key, value, attention_mask, head_mask)

        attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
        attn_output = self.out_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        return attn_output


class GPT2EncoderBlock(GPTJBlock):
    """Post-norm block: self-attention, attention over a memory, then an MLP.

    Each sub-layer output is added to its input and the sum is passed through
    its own LayerNorm (``ln_1``, ``ln_2``, ``ln_3``).
    """

    def __init__(self, config):
        super().__init__(config)
        # MLP width: explicit ``n_inner`` if configured, else 4x the embedding size.
        if config.n_inner is not None:
            inner_dim = config.n_inner
        else:
            inner_dim = 4 * config.n_embd

        # One LayerNorm per sub-layer (self-attn, memory-attn, MLP).
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.ln_3 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        # Non-causal attention layers: one over the target itself, one over memory.
        self.attn = GPTJBiAttention(config)
        self.mem_attn = GPTJBiAttention(config)
        self.mlp = GPTJMLP(inner_dim, config)

    def forward(
        self,
        tgt: Optional[torch.FloatTensor],
        memory: Optional[torch.FloatTensor],
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Tensor:
        """Run the three post-norm sub-layers and return the updated ``tgt``.

        Args:
            tgt: Stream that supplies the queries and carries the residual.
            memory: Stream that supplies keys/values for the second attention.
            attention_mask: Additive mask passed unchanged to both attention
                layers.
            position_ids: Position indices passed unchanged to both attention
                layers.

        Returns:
            The output of the final LayerNorm (``ln_3``).
        """
        hidden = tgt

        # 1) Attention of the target stream over itself.
        self_attn_out = self.attn(
            q_hidden_states=hidden,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )
        hidden = self.ln_1(hidden + self_attn_out)

        # 2) Attention of the target stream over the memory stream.
        mem_attn_out = self.mem_attn(
            q_hidden_states=hidden,
            kv_hidden_states=memory,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )
        hidden = self.ln_2(hidden + mem_attn_out)

        # 3) Position-wise feed-forward network.
        mlp_out = self.mlp(hidden)
        hidden = self.ln_3(hidden + mlp_out)

        return hidden

    
    
#### norm-first (pre-LayerNorm) variant of GPT2EncoderBlock.forward, kept commented out for reference
        
#         ## self cross attn
#         x = x + self.attn(
#             q_hidden_states=self.ln_1(x), 
#             attention_mask=attention_mask,
#             position_ids=position_ids,
#         )
        
#         ## self-mem attn
#         x = x + self.mem_attn(
#             q_hidden_states=self.ln_2(x), 
#             kv_hidden_states=x1, 
#             attention_mask=attention_mask,
#             position_ids=position_ids,
#         )
        
#         ## ffn
#         x = x + self.mlp(self.ln_3(x))

#         return x
