import torch
import numpy as np
from torch import nn
from transformers.models.llama import LlamaConfig
# from typing import Optional
import math
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
from rebuttal_gen_X import load_embeddings

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head ``n_rep`` times along dim 1.

    Produces the same values as ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``:
    the input of shape (batch, num_key_value_heads, seqlen, head_dim) becomes
    (batch, num_key_value_heads * n_rep, seqlen, head_dim). When ``n_rep == 1`` the
    input tensor object itself is returned unchanged.
    """
    # Unpack first so a non-4-D input is rejected even when n_rep == 1.
    b, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep != 1:
        # Insert a new axis after the head axis, broadcast it to n_rep, then fold it into the heads.
        widened = hidden_states.unsqueeze(2).expand(b, kv_heads, n_rep, seq_len, dim)
        return widened.reshape(b, kv_heads * n_rep, seq_len, dim)
    return hidden_states

class GDMask(nn.Module):
    """Compute attention probabilities with and without a learnable weight mask.

    The module owns Llama-style query/key projections plus two trainable tensors,
    ``q_proj_mask`` and ``k_proj_mask``, shaped like the projection weights and
    initialised to ones. ``forward`` returns the softmax attention probabilities
    computed from the plain projection weights and from the element-wise masked
    weights, so the two can be compared.

    NOTE(review): no rotary position embedding is applied to queries/keys here —
    confirm this is intended.
    """

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        # Fall back to hidden_size // num_heads when the config has no explicit head_dim.
        self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
        self.num_key_value_heads = config.num_key_value_heads
        # Number of query heads sharing each key head (grouped-query attention).
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        # NOTE(review): not read anywhere in this class; causal masking only happens
        # when the caller passes an ``attention_mask`` to ``forward``.
        self.is_causal = True

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)

        # Learnable element-wise masks over the projection weights. The all-ones
        # initialisation makes the masked path start out identical to the plain one.
        self.q_proj_mask = nn.Parameter(
            torch.ones_like(self.q_proj.weight), requires_grad=True
        )
        self.k_proj_mask = nn.Parameter(
            torch.ones_like(self.k_proj.weight), requires_grad=True
        )

    def _attention_probs(self, hidden_states, q_weight, k_weight, attention_mask):
        """Return softmax attention probabilities for the given query/key weights.

        The biases always come from ``q_proj``/``k_proj``, so the masked and
        unmasked paths differ only in the weight matrices.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = nn.functional.linear(hidden_states, q_weight, self.q_proj.bias)
        key_states = nn.functional.linear(hidden_states, k_weight, self.k_proj.bias)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        # Expand key heads so each query head has a matching key head.
        key_states = repeat_kv(key_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            # No matter the mask's length, slice it down to the key length.
            attn_weights = attn_weights + attention_mask[:, :, :, : key_states.shape[-2]]

        # Softmax is upcast to fp32, then cast back. This is f(W_K, W_Q, X)_{1, ..., h}.
        return nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: "torch.Tensor | None" = None,
    ):
        """Return plain and masked attention probabilities for ``hidden_states``.

        Args:
            hidden_states: tensor of shape ``(bsz, q_len, hidden_size)``.
            attention_mask: optional 4-D additive mask added to the
                ``(bsz, num_heads, q_len, q_len)`` attention scores before the
                softmax (e.g. 0 where attention is allowed and a large negative
                value where it is blocked). Its last dimension is sliced to the
                key length. ``None`` (the default) applies no masking at all —
                not even causal masking.

        Returns:
            ``(f, f_m)``: softmax attention probabilities of shape
            ``(bsz, num_heads, q_len, q_len)``. ``f`` uses the plain projection
            weights and ``f_m`` uses the weights multiplied element-wise by
            ``q_proj_mask`` / ``k_proj_mask``.
        """
        f_m = self._attention_probs(
            hidden_states,
            self.q_proj.weight * self.q_proj_mask,
            self.k_proj.weight * self.k_proj_mask,
            attention_mask,
        )
        f = self._attention_probs(
            hidden_states,
            self.q_proj.weight,
            self.k_proj.weight,
            attention_mask,
        )
        return f, f_m