from typing import List, Optional, Tuple

from numpy import array, dtype, float16, ndarray
from torch import no_grad, Tensor
from transformers import AutoTokenizer, AutoModelForCausalLM

def get_latent_representations(
        tokenizer: AutoTokenizer,
        model: AutoModelForCausalLM,
        text: str,
        layer_indices: Optional[List[int]] = None,
        _attention_weights: bool = True,
        _token_embeddings: bool = False,
        _numpy_float_type: dtype = float16,
) -> Tuple[Optional[List[ndarray]], Optional[List[ndarray]]]:
    """
    Runs a single forward pass over ``text`` and returns the requested
    per-layer attention weights and/or hidden states as numpy arrays.

    Only the first item of the batch dimension is kept (``[0]``), which is
    the single sequence produced by tokenizing ``text``.

    Args:
        tokenizer (AutoTokenizer): Tokenizer for the language model.
        model (AutoModelForCausalLM): The language model.
        text (str): Input text to generate latent representations for.
        layer_indices (List[int], optional): Indices of layers to extract latent representations from.
            Defaults to None, which is treated as [-1] (last layer). The same indices are applied to
            both the attention tuple and the hidden-state tuple.
            NOTE(review): Hugging Face models usually return one more hidden state than attention
            layers (the embedding output comes first), so a non-negative index may refer to
            different layers in the two outputs - confirm for the model in use.
        _attention_weights (bool, optional): Whether to extract attention weights. Defaults to True.
        _token_embeddings (bool, optional): Whether to extract token embeddings. Defaults to False.
        _numpy_float_type (dtype, optional): Numpy data type for the extracted representations. Defaults to float16.

    Returns:
        Tuple[Optional[List[ndarray]], Optional[List[ndarray]]]: ``(attention_matrices, token_embeddings)``.
        ``attention_matrices`` is a tuple with one array per requested layer when
        ``_attention_weights`` is True, otherwise an empty list. ``token_embeddings`` is a list with
        one array per requested layer when ``_token_embeddings`` is True, otherwise an empty list.

    Raises:
        IndexError: If an entry of ``layer_indices`` is out of range for the requested output.
    """
    # Local import: bfloat16 is only needed for the dtype check below.
    from torch import bfloat16

    # A None sentinel avoids sharing one mutable default list across calls.
    if layer_indices is None:
        layer_indices = [-1]

    # Put the inputs on the model's own device; fall back to 'cuda' for
    # model objects that do not expose a `device` attribute.
    device = getattr(model, 'device', None)
    if device is None:
        device = 'cuda'

    input_ids = tokenizer(text, return_tensors='pt')['input_ids'].to(device)

    # Inference only: no_grad avoids building an autograd graph that would
    # keep every intermediate activation alive.
    with no_grad():
        outputs = model(
            input_ids,
            output_attentions=_attention_weights,
            output_hidden_states=_token_embeddings,
        )

    def _first_item_as_numpy(layer_tensor: Tensor) -> ndarray:
        """Convert batch item 0 of a layer output to a numpy array of the requested dtype."""
        first_item = layer_tensor[0].detach().cpu()
        # numpy has no bfloat16, so upcast before calling .numpy().
        if first_item.dtype == bfloat16:
            first_item = first_item.float()
        return first_item.numpy().astype(_numpy_float_type)

    # outputs.attentions / outputs.hidden_states are only read when they were
    # requested; they are not expected to be populated otherwise.
    # tuple of n layers, each (batch_size, num_heads, sequence_length, sequence_length)
    attention_matrices = tuple(
        _first_item_as_numpy(outputs.attentions[index])
        for index in layer_indices) if _attention_weights else []
    token_embeddings = [
        _first_item_as_numpy(outputs.hidden_states[index])
        for index in layer_indices] if _token_embeddings else []

    # Drop the reference to the full model output (all layers) before returning.
    del outputs

    return attention_matrices, token_embeddings


def generate_response(
        tokenizer: AutoTokenizer,
        model: AutoModelForCausalLM,
        text: str,
        temperature: float = 0,
        max_new_tokens: int = 30
) -> Tensor:
    """
    Generates a response given an input text.

    Args:
        tokenizer (AutoTokenizer): Tokenizer for text encoding.
        model (AutoModelForCausalLM): Pretrained language model.
        text (str): Input text for response generation.
        temperature (float, optional): Sampling temperature for text generation.
            A value of 0 or less (the default is 0) disables sampling and decodes
            greedily; a positive value enables sampling at that temperature.
        max_new_tokens (int, optional): Maximum number of new tokens to generate. Defaults to 30.

    Returns:
        Tensor: Whatever ``model.generate`` returns for the single requested sequence.
    """
    inputs = tokenizer(text, return_tensors='pt')

    # Put the inputs on the model's own device; fall back to 'cuda' for
    # model objects that do not expose a `device` attribute.
    device = getattr(model, 'device', None)
    if device is None:
        device = 'cuda'

    generation_kwargs = {
        'max_new_tokens': max_new_tokens,
        'num_return_sequences': 1,
    }

    # Temperature only has an effect when sampling, so sampling is enabled
    # exactly when a positive temperature is requested. The temperature is
    # not forwarded for greedy decoding, where it would be ignored.
    use_sampling = temperature > 0
    generation_kwargs['do_sample'] = use_sampling
    if use_sampling:
        generation_kwargs['temperature'] = temperature

    # Forward the tokenizer's attention mask when it provides one, instead
    # of leaving generate() to infer it.
    if 'attention_mask' in inputs:
        generation_kwargs['attention_mask'] = inputs['attention_mask'].to(device)

    with no_grad():
        return model.generate(
            inputs['input_ids'].to(device),
            **generation_kwargs
        )
        
def decode_tokens(
    tokenizer:AutoTokenizer,
    token_ids:Tensor
) -> str:
    """
    Decodes a batch of token IDs and returns the text of its first sequence.

    Args:
        tokenizer (AutoTokenizer): Tokenizer for decoding.
        token_ids: Batch of token ID sequences; it is passed unchanged to ``tokenizer.batch_decode``.

    Returns:
        str: Decoded text of the first sequence in the batch.
    """
    # Special tokens are dropped and tokenization spaces are left as-is.
    decode_options = {
        'skip_special_tokens': True,
        'clean_up_tokenization_spaces': False,
    }
    decoded_sequences = tokenizer.batch_decode(token_ids, **decode_options)
    return decoded_sequences[0]
