from typing import Union, List, Tuple

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from .llm import LLM
from .config import CFG


class HFLLM(LLM):
    """
    LLM backend that runs a Hugging Face causal language model locally.

    Loads a tokenizer and an ``AutoModelForCausalLM`` from ``model_path`` and
    exposes :meth:`generate`, which returns generated continuations together
    with their sequence log-probabilities.
    """

    def __init__(self, model_path: str = CFG["LLAMA"]["llama_small"]):
        """
        Parameters
        ----------
        model_path : str
            The path for the model on HF.
        """
        # Preferred accelerator, kept as a public attribute for callers. The
        # model itself is placed by ``device_map="auto"`` (see
        # ``_init_model_tokenizer``), so ``generate`` moves its inputs to
        # ``self.model.device`` rather than to this value.
        if torch.backends.mps.is_available():
            self.device = torch.device("mps")
        elif torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        self.tokenizer, self.model = self._init_model_tokenizer(model_path)

    def _init_model_tokenizer(self, model_path: str):
        """
        Initialize and return the tokenizer and model for ``model_path``.

        Returns
        -------
        tuple : (tokenizer, model)
        """
        tokenizer = AutoTokenizer.from_pretrained(model_path)

        # Some tokenizers define no pad token; reuse EOS (when it exists) so a
        # valid pad id can be handed to ``generate``.
        if tokenizer.pad_token is None and tokenizer.eos_token:
            tokenizer.pad_token = tokenizer.eos_token

        # device_map="auto" lets transformers decide where the weights live.
        model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
        return tokenizer, model

    def generate(
        self,
        prompt: str,
        temperature: float = CFG["general"]["temperature"],
        n_generations: int = 1,
        length_normalize: bool = CFG["general"]["length_normalize"],
        max_length: int = 50,
        repetition_penalty: float = 1.1,
        **kwargs
    ) -> Tuple[Union[str, List[str]], Union[float, List[float]]]:
        """
        Generate text in response to a prompt.

        Parameters
        ----------
        prompt : str
            The prompt you want to send to the model.
        temperature : float
            Softmax temperature for sampling. 0 (or less) means greedy decoding.
        n_generations : int
            The number of samples you want to draw from the LLM.
        length_normalize : bool
            Whether you want to length-normalize the log-probability of the sequence.
        max_length : int
            The maximum number of new tokens to generate (prompt not counted).
        repetition_penalty : float
            Penalty for repeating tokens.
        **kwargs
            Extra keyword arguments forwarded to ``model.generate``. They may
            override ``attention_mask`` / ``pad_token_id``, but not the options
            this method relies on to compute log-probabilities.

        Returns
        -------
        tuple : (text, log_prob)
            If ``n_generations == 1``, a single string and a single float;
            otherwise a list of strings and a list of floats. The text contains
            only the newly generated tokens (the prompt is excluded).
        """
        encoded = self.tokenizer(prompt, return_tensors="pt")
        # With device_map="auto" the model chooses its own placement, so the
        # inputs follow the model instead of self.device (avoids a mismatch).
        model_device = self.model.device
        input_ids = encoded["input_ids"].to(model_device)
        prompt_len = input_ids.shape[1]
        do_sample = temperature > 0

        # Defaults that callers may override through **kwargs.
        gen_kwargs = {}
        if self.tokenizer.pad_token_id is not None:
            gen_kwargs["pad_token_id"] = self.tokenizer.pad_token_id
        attention_mask = encoded.get("attention_mask")
        if attention_mask is not None:
            # Pass the mask explicitly: pad may equal EOS (see
            # _init_model_tokenizer), which makes an inferred mask ambiguous.
            gen_kwargs["attention_mask"] = attention_mask.to(model_device)
        gen_kwargs.update(kwargs)

        # Options this method depends on; these always win.
        gen_kwargs.update(
            max_length=prompt_len + max_length,
            repetition_penalty=repetition_penalty,
            do_sample=do_sample,
            return_dict_in_generate=True,
            output_scores=True,
        )
        if do_sample:
            # Temperature only affects sampling; it is meaningless for greedy.
            gen_kwargs["temperature"] = temperature

        outputs = []
        log_probs = []
        for _ in range(n_generations):
            output = self.model.generate(input_ids, **gen_kwargs)
            # Batch size is 1: take the single sequence and drop the prompt.
            new_token_ids = output.sequences[0][prompt_len:]
            outputs.append(
                self.tokenizer.decode(new_token_ids, skip_special_tokens=True)
            )
            log_probs.append(
                self._get_logprobs(output.scores, new_token_ids, length_normalize)
            )

        if n_generations == 1:
            return outputs[0], log_probs[0]
        return outputs, log_probs

    def _get_logprobs(self, scores, generated_ids, length_normalize: bool) -> float:
        """
        Compute the log-probability of a generated sequence.

        Parameters
        ----------
        scores : tuple of torch.Tensor
            ``output.scores`` from ``model.generate``: one tensor per generation
            step holding the processed next-token scores, shape
            ``(batch_size, vocab_size)`` (batch size 1 in this class).
        generated_ids : torch.Tensor
            The newly generated token ids (prompt excluded), aligned step by
            step with ``scores``.
        length_normalize : bool
            If True, divide the total by the number of scored tokens.

        Returns
        -------
        float
            Sum (or mean, when normalized) of per-token log-probabilities under
            the processed distribution (after repetition penalty / temperature);
            0.0 when no tokens were generated.
        """
        log_prob = 0.0
        n_tokens = 0
        for step_scores, token_id in zip(scores, generated_ids):
            # Each step is (batch, vocab); select the single batch row. Indexing
            # with token_id directly would index the batch axis instead.
            # NOTE(review): assumes num_beams == 1; with beam search row 0 is
            # not necessarily the returned sequence.
            if step_scores.dim() > 1:
                step_scores = step_scores[0]
            # Scores are unnormalized logits: convert to log-probabilities
            # (in float32 for numerical stability with half-precision models).
            step_log_probs = torch.log_softmax(step_scores.float(), dim=-1)
            log_prob += step_log_probs[token_id].item()
            n_tokens += 1
        if length_normalize and n_tokens > 0:
            log_prob /= n_tokens
        return log_prob
