from src.models.modeling_gpt2 import ExtendedGPT2Config, ExtendedGPT2Block, ExtendedGPT2LMHeadModel, GPT2LinearAttention
from src.models.modeling_gpt_neox import ExtendedGPTNeoXConfig, ExtendedGPTNeoXLayer, ExtendedGPTNeoXModel, ExtendedGPTNeoXForCausalLM, GPTNeoXLinearAttention

import torch
from typing import Union

# Type aliases covering the two supported model families (GPT-2 and GPT-NeoX).

# Either of the extended config classes.
CausalLMConfig = Union[ExtendedGPT2Config, ExtendedGPTNeoXConfig]

# Either of the extended causal-LM model classes.
CausalLM = Union[ExtendedGPT2LMHeadModel, ExtendedGPTNeoXForCausalLM]

# Either of the extended per-layer (block) classes.
CausalLMLayer = Union[ExtendedGPT2Block, ExtendedGPTNeoXLayer]

# Either of the linear-attention classes.
LinearAttention = Union[GPT2LinearAttention, GPTNeoXLinearAttention]

def get_qkv(model: CausalLM, input_ids: torch.LongTensor, attention_mask: torch.Tensor, hidden_states: torch.Tensor, layer_idx: int):
    """
    Compute the query, key and value of the layer `layer_idx` from `hidden_states`.

    For GPT-NeoX, this function mimics the behavior when position_ids, head_mask
    and input_embeds are None. The whole `model` (rather than a single layer) is
    required because the GPT-NeoX branch needs the model-level embedding and
    rotary-embedding modules to build the position ids / position embeddings.

    Args:
        model: An `ExtendedGPT2LMHeadModel` or `ExtendedGPTNeoXForCausalLM`.
        input_ids: Token ids. Only read in the GPT-NeoX branch, where the
            sequence length is taken from dimension 1 and the ids are embedded
            to obtain the tensor handed to the rotary embedding.
        attention_mask: Currently unused; kept for interface compatibility.
        hidden_states: Input hidden states of the target layer, i.e. the tensor
            fed to that layer's first layer norm.
        layer_idx: Index of the target layer in `model.transformer.h`.

    Returns:
        The `(query, key, value)` triple. For GPT-2 these are the three chunks
        of the `c_attn` projection split along dimension 2; for GPT-NeoX they
        are the first three values returned by
        `layer.attention._attn_projections_and_rope`.

    Raises:
        NotImplementedError: If `model` is neither of the supported model types.
    """
    if isinstance(model, ExtendedGPT2LMHeadModel):
        layer: ExtendedGPT2Block = model.transformer.h[layer_idx]
        hidden_states_before_attn = layer.ln_1(hidden_states)
        # c_attn yields query, key and value concatenated along the last axis;
        # cut it into embed_dim-sized pieces.
        query, key, value = layer.attn.c_attn(hidden_states_before_attn).split(split_size=layer.attn.embed_dim, dim=2)

    elif isinstance(model, ExtendedGPTNeoXForCausalLM):
        seq_length = input_ids.size(1)

        transformer: ExtendedGPTNeoXModel = model.transformer
        inputs_embeds = transformer.embed_in(input_ids)
        # Positions 0..seq_length-1, matching the default used when no
        # position_ids are supplied and there is no cache.
        cache_position = torch.arange(0, seq_length, device=inputs_embeds.device)
        position_ids = cache_position.unsqueeze(0)
        first_hidden_states = transformer.emb_dropout(inputs_embeds)
        position_embeddings = transformer.rotary_emb(first_hidden_states, position_ids)

        layer: ExtendedGPTNeoXLayer = transformer.h[layer_idx]
        hidden_states_before_attn = layer.input_layernorm(hidden_states)
        # The fourth return value is discarded; no cache is used here
        # (layer_past=None, use_cache=False).
        query, key, value, _ = layer.attention._attn_projections_and_rope(
            hidden_states=hidden_states_before_attn,
            position_ids=position_ids,
            layer_past=None,
            use_cache=False,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
        )

    else:
        # `layer` is never bound on this path, so report the model type that
        # failed the dispatch instead of raising UnboundLocalError.
        raise NotImplementedError(f"Unknown model type: {type(model)}")

    return query, key, value