import torch
import torch.nn as nn
import math
from copy import deepcopy
from typing import Optional, Tuple

def rotate_half(x, group):
    """Apply the RoPE "rotate half" transform independently to ``group`` equal slices.

    The last dimension of ``x`` is cut into ``group`` consecutive slices of
    width ``x.shape[-1] // group``. Each slice, split at ``width // 2`` into
    ``[a, b]``, becomes ``[-b, a]``. Channels beyond ``group * width`` (when the
    last dimension is not divisible by ``group``) are not included in the output.

    Args:
        x: Tensor whose last dimension holds the grouped features.
        group: Number of slices to rotate independently.

    Returns:
        Tensor with the same leading dimensions as ``x`` and a last dimension
        of ``group * (x.shape[-1] // group)``.
    """
    width = x.shape[-1] // group
    half = width // 2
    pieces = []
    for g in range(group):
        lo = g * width
        mid = lo + half
        hi = (g + 1) * width
        # Negated upper part first, then the untouched lower part.
        pieces.extend((-x[..., mid:hi], x[..., lo:mid]))
    return torch.cat(pieces, dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, rope_head=1):
    """Apply RoPE to the leading ``rope_head * cos.shape[-1]`` channels of q and k.

    The remaining ("nope") channels pass through unchanged and are appended
    after the rotated ones. ``cos``/``sin`` get a singleton axis inserted at
    dim 1 and are tiled along the last dimension, so every
    ``cos.shape[-1]``-wide slice of the rope channels uses the same table.
    ``rotate_half`` is applied to each slice separately.

    Args:
        q: Query tensor; the rope/nope split sizes are derived from its last dim.
        k: Key tensor; split with the same sizes as ``q``.
        cos: Cosine table. Presumably ``(batch, seq, rope_width)`` so that after
            ``unsqueeze(1)`` it broadcasts over a head axis. TODO confirm against caller.
        sin: Sine table, same layout as ``cos``.
        rope_head: Number of ``cos.shape[-1]``-wide slices that receive RoPE.

    Returns:
        ``(q_embed, k_embed)``: rotated channels first, untouched channels after.
    """
    rope_dim = cos.shape[-1] * rope_head
    split_sizes = [rope_dim, q.shape[-1] - rope_dim]
    cos = cos.unsqueeze(1)
    sin = sin.unsqueeze(1)

    def _embed(t):
        t_rope, t_nope = t.split(split_sizes, dim=-1)
        reps = t_rope.shape[-1] // cos.shape[-1]
        rotated = t_rope * cos.repeat(1, 1, 1, reps) + rotate_half(t_rope, reps) * sin.repeat(1, 1, 1, reps)
        return torch.cat([rotated, t_nope], dim=-1)

    return _embed(q), _embed(k)

class RemoveRope(nn.Module):
    """Attention that scores queries against a single shared latent key/value head.

    Wraps an existing attention module ``self_attn``. Its ``k_proj`` / ``v_proj``
    produce ``num_key_value_heads * head_dim`` features, called the latent here.
    Two extra bias-free projections map the latent back to one ``head_dim``
    slice per query head:

    * ``k_up_proj`` / ``v_up_proj``: weights of shape
      ``(num_attention_heads * head_dim, latent_dim)``. Each is initialised so
      that query head ``h`` reads KV head ``h // (num_attention_heads //
      num_key_value_heads)``, i.e. grouped-query sharing is reproduced exactly.

    In :meth:`forward`, ``k_up_proj`` is folded into the queries, so scores are
    computed in the latent space against one shared key head. Rotary embeddings
    are then applied to the leading ``rope_head * cos.shape[-1]`` latent channels.

    If ``key_outputs`` is given, ``k_proj`` (including its bias) and
    ``k_up_proj`` are rotated by the orthogonal bases from
    :meth:`joint_complex_pca`. The rotation cancels in
    ``k_up_proj(k_proj(x))`` up to floating-point error.

    Note: ``q_proj``/``k_proj``/``v_proj``/``o_proj`` are shared by reference
    with ``self_attn``. The rotation therefore also modifies
    ``self_attn.k_proj`` in place.

    Args:
        self_attn: Source attention module. It must expose ``config`` (with
            ``hidden_size``, ``num_attention_heads``, ``num_key_value_heads``),
            ``layer_idx``, ``head_dim``, ``attention_dropout`` and the four projections.
        key_outputs: Optional list of calibration key activations. Each must
            reshape to ``(batch, seq, latent_dim)`` (outputs of ``k_proj``).
        dim2head: Number of adjacent channels, per KV head and per half-head,
            grouped into one PCA problem. ``None`` means 1.
        rope_head: Number of ``cos.shape[-1]``-wide latent slices that keep RoPE.
    """

    def __init__(self, self_attn, key_outputs=None, dim2head=None, rope_head=1):
        super().__init__()
        self.config = self_attn.config
        self.layer_idx = self_attn.layer_idx
        self.hidden_size = self_attn.config.hidden_size
        self.num_attention_heads = self_attn.config.num_attention_heads
        self.head_dim = self_attn.head_dim
        self.num_key_value_heads = self_attn.config.num_key_value_heads
        # Width of the shared key/value latent produced by k_proj / v_proj.
        self.latent_dim = self.num_key_value_heads * self.head_dim
        # Output width of the up-projections (all query heads). This may differ
        # from hidden_size when a model sets head_dim independently.
        self.up_dim = self.num_attention_heads * self.head_dim
        self.attention_dropout = self_attn.attention_dropout
        self.rope_head = rope_head

        self.q_proj = self_attn.q_proj
        self.k_proj = self_attn.k_proj
        self.v_proj = self_attn.v_proj
        self.o_proj = self_attn.o_proj
        self.__insert_kv_up_proj__()
        if key_outputs is not None:
            # The PCA/rotation helpers default dim2head to 1; treat None the same
            # way instead of crashing on integer division by None.
            dim2head = 1 if dim2head is None else dim2head
            Rk = self.joint_complex_pca(key_outputs, dim2head)
            self.rotate_k_proj(Rk, dim2head=dim2head)
            self.rotate_k_up_proj(Rk, dim2head=dim2head)

    def __insert_kv_up_proj__(self):
        """Create ``k_up_proj`` / ``v_up_proj`` initialised to reproduce GQA head sharing.

        Each weight is ``(num_attention_heads * head_dim, latent_dim)``, built by
        repeating every KV head's ``head_dim`` identity block for the query
        heads of its group. Query head ``h`` therefore reads KV head
        ``h // kv_groups``.
        """
        kv_groups = self.num_attention_heads // self.num_key_value_heads

        def build(ref: nn.Linear) -> nn.Linear:
            # Match dtype/device of the projection whose output we up-project.
            factory = {"dtype": ref.weight.dtype, "device": ref.weight.device}
            up = nn.Linear(self.latent_dim, self.up_dim, bias=False, **factory)
            eye = torch.eye(self.latent_dim, **factory)
            eye = eye.reshape(self.num_key_value_heads, self.head_dim, self.latent_dim)
            # (kv, groups, head_dim, latent) -> (kv * groups * head_dim, latent)
            up.weight.data = torch.stack([eye] * kv_groups, dim=1).reshape(self.up_dim, self.latent_dim).contiguous()
            return up

        self.k_up_proj = build(self.k_proj)
        self.v_up_proj = build(self.v_proj)

    @torch.no_grad()
    def joint_complex_pca(self, Z: list[torch.Tensor], dim2head: int = 1) -> torch.Tensor:
        """Compute per-frequency orthogonal bases mixing the KV heads' key channels.

        Each KV head's ``head_dim`` is viewed as ``(2, n_freq, dim2head)`` with
        ``n_freq = head_dim // 2 // dim2head``. The two halves are the ones
        ``rotate_half`` pairs. For every index ``i`` in ``n_freq``:

        1. Build the uncentred second-moment matrix of the
           ``num_key_value_heads * dim2head`` features at that index. It is
           summed over all batches, positions and both halves.
        2. Add 1% of its mean diagonal as damping.
        3. Take its eigenvectors, sorted by descending eigenvalue.

        Args:
            Z: Non-empty list of key activations, each reshapeable to
                ``(batch, seq, latent_dim)``.
            dim2head: Channels per half-head grouped into one PCA problem.

        Returns:
            Tensor of shape ``(2 * n_freq, P, P)`` with
            ``P = num_key_value_heads * dim2head``. It is in ``k_proj``'s
            dtype and on ``k_proj``'s device. The bases are repeated so both
            halves use the same rotation.

        Raises:
            ValueError: If ``Z`` is empty.
        """
        if not Z:
            raise ValueError("joint_complex_pca needs at least one batch of key outputs")
        device = self.k_proj.weight.device
        dtype = self.k_proj.weight.dtype
        n_freq = self.head_dim // 2 // dim2head
        width = self.num_key_value_heads * dim2head

        H = None
        for Z_batch in Z:
            b, n, _ = Z_batch.shape
            # (b, n, kv, half, freq, d2h) -> (b, n, half, kv, d2h, freq) -> (b, 2n, kv*d2h, freq)
            x = Z_batch.reshape(b, n, self.num_key_value_heads, 2, n_freq, dim2head)
            x = x.permute(0, 1, 3, 2, 5, 4).reshape(b, n * 2, width, n_freq)
            # Accumulate in float64 on the weight's device.
            x = x.to(device=device, dtype=torch.float64)
            # cov[i] = sum over (batch, sample) of outer(x[..., i], x[..., i]).
            cov = torch.einsum("bspi,bsqi->ipq", x, x)
            H = cov if H is None else H + cov

        # Damp each frequency's matrix by 1% of its own mean diagonal.
        diag = H.diagonal(dim1=-2, dim2=-1)  # view into H, shape (n_freq, width)
        diag += 0.01 * diag.mean(dim=-1, keepdim=True)

        eigvals, eigvecs = torch.linalg.eigh(H)
        order = torch.argsort(eigvals, dim=-1, descending=True)
        # Reorder eigenvector columns by descending eigenvalue, per frequency.
        eigvecs = torch.gather(eigvecs, -1, order.unsqueeze(-2).expand_as(eigvecs))
        # Same basis for both halves of each head.
        return torch.cat([eigvecs, eigvecs], dim=0).to(dtype)

    def rotate_k_proj(self, U, dim2head=1):
        """Rotate ``k_proj``'s output channels (and bias, if present) by ``U``.

        Args:
            U: ``(head_dim // dim2head, P, P)`` bases as returned by
                :meth:`joint_complex_pca`, with ``P = num_key_value_heads * dim2head``.
            dim2head: Must match the value used to compute ``U``.
        """
        weight = self.k_proj.weight.data
        bias = self.k_proj.bias
        assert weight.shape == (self.latent_dim, self.hidden_size)
        U = U.to(device=weight.device, dtype=weight.dtype)
        if bias is not None:
            assert bias.data.shape == (self.latent_dim,)
            # Append the bias as an extra input column so it is rotated with the weight.
            weight = torch.cat([weight, bias.data.unsqueeze(1)], dim=1)
        w = weight.reshape(self.num_key_value_heads, self.head_dim // dim2head, dim2head, -1)
        w = w.permute(0, 2, 1, 3).reshape(self.num_key_value_heads * dim2head, self.head_dim // dim2head, -1)
        # new[c, d, :] = sum_h U[d, h, c] * old[h, d, :]
        w = torch.einsum("dhc,hdD->cdD", U, w)
        w = w.reshape(self.num_key_value_heads, dim2head, self.head_dim // dim2head, -1)
        w = w.permute(0, 2, 1, 3).reshape(self.latent_dim, -1)
        if bias is not None:
            self.k_proj.bias.data = w[:, -1].contiguous()
            w = w[:, :-1]
        self.k_proj.weight.data = w.contiguous()

    def rotate_k_up_proj(self, U, dim2head=1):
        """Rotate ``k_up_proj``'s input channels by ``U``, undoing :meth:`rotate_k_proj`.

        Args:
            U: Same bases that were passed to :meth:`rotate_k_proj`.
            dim2head: Must match the value used to compute ``U``.
        """
        weight = self.k_up_proj.weight.data
        assert weight.shape == (self.up_dim, self.latent_dim)
        U = U.to(device=weight.device, dtype=weight.dtype)
        w = weight.reshape(self.up_dim, self.num_key_value_heads, self.head_dim // dim2head, dim2head)
        w = w.permute(0, 1, 3, 2).reshape(self.up_dim, self.num_key_value_heads * dim2head, self.head_dim // dim2head)
        # new[D, c, d] = sum_h U[d, h, c] * old[D, h, d]
        w = torch.einsum("dhc,Dhd->Dcd", U, w)
        w = w.reshape(self.up_dim, self.num_key_value_heads, dim2head, self.head_dim // dim2head)
        self.k_up_proj.weight.data = w.permute(0, 1, 3, 2).reshape(self.up_dim, self.latent_dim).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value=None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run attention with queries folded into the shared latent key space.

        Args:
            hidden_states: ``(bsz, q_len, hidden)`` input.
            attention_mask: Additive mask. It is sliced to the key length and
                added to the scores.
            position_ids: Unused here; accepted for call compatibility.
            past_key_value: Optional cache object. Its
                ``update(k, v, layer_idx, kwargs)`` is assumed to return the full
                latent keys/values (HF ``Cache``-like; TODO confirm).
            output_attentions: If False, the returned weights are ``None``.
            use_cache: Unused here; accepted for call compatibility.
            cache_position: Forwarded to the cache in ``cache_kwargs``.
            position_embeddings: Required ``(cos, sin)`` rotary tables.

        Returns:
            ``(attn_output, attn_weights)``. ``attn_output`` is ``o_proj``
            applied to ``(bsz, q_len, num_heads * head_dim)``. ``attn_weights``
            is ``(bsz, num_heads, q_len, kv_len)`` or ``None``.

        Raises:
            ValueError: If ``position_embeddings`` is not provided.
        """
        if position_embeddings is None:
            raise ValueError("RemoveRope.forward requires position_embeddings=(cos, sin)")
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Fold k_up_proj into the queries: (bsz, q_len, heads, head_dim) -> (bsz, heads, q_len, latent).
        query_states = query_states.view(bsz, q_len, self.num_attention_heads, self.head_dim)
        k_up_weight = self.k_up_proj.weight.view(self.num_attention_heads, self.head_dim, self.latent_dim)
        query_states = torch.einsum("bthd,hdc->bhtc", query_states, k_up_weight)

        # A single latent key/value "head" shared by all query heads.
        key_states = key_states.view(bsz, 1, q_len, self.latent_dim)
        value_states = value_states.view(bsz, 1, q_len, self.latent_dim)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, self.rope_head)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Latent dot products equal the per-head dot products, so scale by head_dim.
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)

        # Expand the (possibly cached) latent values to per-head values: (bsz, heads, kv_len, head_dim).
        v_up_weight = self.v_up_proj.weight.view(
            self.num_key_value_heads, self.num_attention_heads // self.num_key_value_heads, self.head_dim, self.latent_dim
        )
        value_states = torch.einsum("bhtc,hgdc->bhgtd", value_states, v_up_weight)
        value_states = value_states.reshape(bsz, self.num_attention_heads, -1, self.head_dim)
        attn_output = torch.matmul(attn_weights, value_states)

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, -1)

        attn_output = self.o_proj(attn_output)
        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights
