from abc import ABC, abstractmethod
from transformers import LlamaForCausalLM , LlamaTokenizer, AutoModelForCausalLM, AutoTokenizer, GPTNeoXForCausalLM, BitsAndBytesConfig, set_seed
import torch
import numpy as np
from openai import OpenAI
from scipy.special import softmax
from prompts import PROMPTS

# Define abstract model class
class Model(ABC):
    """
    Abstract base class for models that answer multiple-choice questions.

    Provides prompt construction (``create_prompt`` / ``get_system_prompt``);
    concrete subclasses implement ``get_answer_from_question``.
    """
    def __init__(self):
        pass

    @abstractmethod
    def get_answer_from_question(self, question, system_prompt_type, language):
        """
        Answer a multiple-choice question.

        Args:
            question: dict, question with "body" and "options" (letter -> text),
                optionally "test" (subject, needed for one-shot prompts)
            system_prompt_type: str, key into ``PROMPTS`` (e.g. "one-shot")
            language: str, language code (e.g. "en" or "pt-br")
        """
        pass

    def create_prompt(self, question, system_prompt_type, language):
        """
        Build the full text prompt for a question.

        Layout: the system prompt, the question body (stripped), one
        "(X) text" line per option sorted by letter, and a trailing
        "<Answer>: (" so that the model's next token is expected to be the
        option letter.

        Args:
            question: dict with "body", "options" (letter -> text) and,
                for one-shot prompts, "test" (the question subject)
            system_prompt_type: str, key into ``PROMPTS``
            language: str, language code; "en" gives English section labels,
                anything else gives Portuguese labels

        Returns:
            str, the assembled prompt
        """
        options = question["options"]
        # All three section labels follow one rule so they can never mix languages.
        is_english = language == "en"
        question_word = "Question" if is_english else "Questão"
        option_word = "Options" if is_english else "Alternativas"
        answer_word = "Answer" if is_english else "Resposta"
        # The subject is only needed to select one-shot examples.
        question_subject = question.get("test")

        system_prompt = self.get_system_prompt(system_prompt_type, language, question_subject)

        option_lines = "".join(f"({letter}) {options[letter]}\n" for letter in sorted(options))
        return (
            f'{system_prompt}\n\n{question_word}:\n{question["body"].strip()}\n\n{option_word}:\n'
            + option_lines
            + f"\n{answer_word}: ("
        )

    def get_system_prompt(self, system_prompt_type, language, question_subject):
        """
        Look up the system prompt text in ``PROMPTS``.

        One-shot prompts are indexed additionally by the question subject;
        every other prompt type is indexed by type and language only.

        Raises:
            KeyError: if the type/language/subject combination is not in ``PROMPTS``
        """
        if system_prompt_type == "one-shot":
            return PROMPTS[system_prompt_type][language][question_subject]
        else:
            return PROMPTS[system_prompt_type][language]

class ModelHF(Model):
    """
    Base class for Hugging Face causal-LM backends.

    Subclasses must set, in their ``__init__``:
        model: the causal LM, called with a batch of input ids
        tokenizer: the matching tokenizer
        tokenizer_map: dict mapping each option letter to a token id
        device: device the input ids are moved to
        seed: seed passed to ``set_seed`` before every query
    """
    def __init__(self):
        super().__init__()

    @torch.no_grad()
    def get_answer_from_question(self, question, system_prompt_type, language):
        """
        Get answer from question by scoring the option-letter tokens.

        Args:
            question: dict, question with body and options
            system_prompt_type: str, system prompt type
            language: str, language of the question

        Returns:
            max_token: str, option letter with highest probability
            prob_dist: np.ndarray (float32), softmax over the option-letter
                logits, ordered alphabetically by letter
            logits: np.ndarray (float32), raw logits of the option-letter
                tokens, same order as prob_dist
        """
        set_seed(self.seed)
        prompt = self.create_prompt(question, system_prompt_type, language)
        input_ids = self.tokenizer(prompt, return_tensors='pt').input_ids.to(self.device)
        outputs = self.model(input_ids, output_hidden_states=False, output_attentions=False, return_dict=True, use_cache=False)
        # Logits for the token following the prompt (last position, first batch item).
        # Upcast in torch *before* converting: Tensor.numpy() does not support
        # bfloat16, and fp16 -> fp32 is exact, so results are unchanged for fp16.
        next_token_logits = outputs.logits[0, -1, :].float().cpu().numpy()

        # Restrict to the option-letter tokens (A, B, C, D, E), alphabetically.
        # Reference: https://huggingface.co/blog/open-llm-leaderboard-mmlu
        letters = sorted(self.tokenizer_map)
        option_ids = [self.tokenizer_map[letter] for letter in letters]
        logits = next_token_logits[option_ids]
        prob_dist = softmax(logits)
        max_token = letters[int(np.argmax(prob_dist))]

        return max_token, prob_dist, logits

class LLAMA2(ModelHF):
    """
    Llama 2 model class.

    Loads ``meta-llama/Llama-2-{model_size}[-chat]-hf`` in float16 with
    ``device_map="auto"``; ``is_instruct_version`` selects the "-chat" variant.
    """
    def __init__(self, model_size, token, device, is_instruct_version=True, random_seed=0):
        super().__init__()

        self.instruct_version = "-chat" if is_instruct_version else ""
        checkpoint = f"meta-llama/Llama-2-{model_size}{self.instruct_version}-hf"
        self.model = LlamaForCausalLM.from_pretrained(checkpoint, token=token, device_map="auto", torch_dtype=torch.float16)
        self.tokenizer = LlamaTokenizer.from_pretrained(checkpoint, token=token)

        # For each option letter, keep the last token id of the encoding of "(X".
        self.tokenizer_map = {}
        for option_letter in "ABCDE":
            encoded = self.tokenizer.encode(f"({option_letter}")
            self.tokenizer_map[option_letter] = encoded[-1]

        self.model_size = model_size
        self.device = device
        self.seed = random_seed
        set_seed(self.seed)

    def get_answer_from_question(self, question, system_prompt_type, language):
        """Delegate to ``ModelHF.get_answer_from_question``."""
        return super().get_answer_from_question(question, system_prompt_type, language)
    
class Mistral(ModelHF):
    """
    Mistral / Mixtral model class.

    ``model_size`` selects the checkpoint:
        "7b"   -> mistralai/Mistral-7B[-Instruct]-v0.1, loaded in float16
        "8x7b" -> mistralai/Mixtral-8x7B[-Instruct]-v0.1, loaded 4-bit (NF4,
                  double quantization, bfloat16 compute)

    Raises:
        ValueError: if ``model_size`` is not "7b" or "8x7b"
    """

    def __init__(self, model_size, token, device, is_instruct_version=True, random_seed=0):
        super().__init__()
        self.model_size = model_size
        self.instruct_version = "-Instruct" if is_instruct_version else ""
        if self.model_size == "7b":
            checkpoint = f"mistralai/Mistral-7B{self.instruct_version}-v0.1"
            self.tokenizer = AutoTokenizer.from_pretrained(checkpoint, token=token)
            self.model = AutoModelForCausalLM.from_pretrained(checkpoint, token=token, device_map="auto", torch_dtype=torch.float16)
        elif self.model_size == "8x7b":
            self.bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16
            )

            checkpoint = f"mistralai/Mixtral-8x7B{self.instruct_version}-v0.1"
            self.tokenizer = AutoTokenizer.from_pretrained(checkpoint, token=token)
            # Pass the token to the model load too (it was missing), so that access
            # to a gated repo works the same as it does for the tokenizer.
            self.model = AutoModelForCausalLM.from_pretrained(checkpoint, token=token, device_map="auto", quantization_config=self.bnb_config)
        else:
            # Fail early with a clear message instead of an AttributeError on
            # self.tokenizer below.
            raise ValueError(f"Unsupported Mistral model_size {model_size!r}; expected '7b' or '8x7b'")

        # For each option letter, keep the last token id of the encoding of "(X".
        self.tokenizer_map = {letter: self.tokenizer.encode(f"({letter}")[-1] for letter in "ABCDE"}
        self.device = device
        self.seed = random_seed
        set_seed(random_seed)

    def get_answer_from_question(self, question, system_prompt_type, language):
        """Delegate to ``ModelHF.get_answer_from_question``."""
        return super().get_answer_from_question(question, system_prompt_type, language)
    
class Gemma(ModelHF):
    """
    Gemma model class.

    Loads ``google/gemma-{model_size}[-it]`` in float16 with
    ``device_map="auto"``; ``is_instruct_version`` selects the "-it" variant.
    """
    def __init__(self, model_size, token, device, is_instruct_version=True, random_seed=0):
        super().__init__()
        self.model_size = model_size
        self.instruct_version = "-it" if is_instruct_version else ""
        repo_id = f"google/gemma-{self.model_size}{self.instruct_version}"
        self.tokenizer = AutoTokenizer.from_pretrained(repo_id, token=token)
        self.model = AutoModelForCausalLM.from_pretrained(repo_id, token=token, device_map="auto", torch_dtype=torch.float16)

        # Map each option letter to the last token id of the encoding of "(X".
        option_letters = "ABCDE"
        last_ids = (self.tokenizer.encode(f"({ch}")[-1] for ch in option_letters)
        self.tokenizer_map = dict(zip(option_letters, last_ids))

        self.device = device
        self.seed = random_seed
        set_seed(random_seed)

    def get_answer_from_question(self, question, system_prompt_type, language):
        """Delegate to ``ModelHF.get_answer_from_question``."""
        return super().get_answer_from_question(question, system_prompt_type, language)
    
class Pythia(ModelHF):
    """
    Pythia model class.

    Loads ``EleutherAI/pythia-{model_size}-deduped`` (GPT-NeoX architecture)
    in float16 with ``device_map="auto"``. There is no instruct variant.
    """
    def __init__(self, model_size, token, device, random_seed=0):
        super().__init__()
        self.model_size = model_size
        pythia_repo = f"EleutherAI/pythia-{self.model_size}-deduped"
        self.tokenizer = AutoTokenizer.from_pretrained(pythia_repo, token=token)
        self.model = GPTNeoXForCausalLM.from_pretrained(pythia_repo, token=token, device_map="auto", torch_dtype=torch.float16)

        # Map each option letter to the last token id of the encoding of "(X".
        letter_to_id = {}
        for ch in "ABCDE":
            letter_to_id[ch] = self.tokenizer.encode(f"({ch}")[-1]
        self.tokenizer_map = letter_to_id

        self.device = device
        self.seed = random_seed
        set_seed(random_seed)

    def get_answer_from_question(self, question, system_prompt_type, language):
        """Delegate to ``ModelHF.get_answer_from_question``."""
        return super().get_answer_from_question(question, system_prompt_type, language)

class LLAMA3(ModelHF):
    """
    Llama 3 model class.

    Loads ``meta-llama/Meta-Llama-3-{MODEL_SIZE}[-Instruct]`` in float16 with
    ``device_map="auto"``. ``model_size`` is upper-cased (e.g. "8b" -> "8B")
    both for the repo name and for the stored ``self.model_size``.
    """
    def __init__(self, model_size, token, device, is_instruct_version=True, random_seed=0):
        super().__init__()

        size_tag = model_size.upper()
        self.instruct_version = "-Instruct" if is_instruct_version else ""
        repo_name = f"meta-llama/Meta-Llama-3-{size_tag}{self.instruct_version}"
        self.model = AutoModelForCausalLM.from_pretrained(repo_name, token=token, device_map="auto", torch_dtype=torch.float16)
        self.tokenizer = AutoTokenizer.from_pretrained(repo_name, token=token)

        # Map each option letter to the last token id of the encoding of "(X".
        encode = self.tokenizer.encode
        self.tokenizer_map = {opt: encode(f"({opt}")[-1] for opt in "ABCDE"}

        self.model_size = size_tag
        self.device = device
        self.seed = random_seed
        set_seed(self.seed)

    def get_answer_from_question(self, question, system_prompt_type, language):
        """Delegate to ``ModelHF.get_answer_from_question``."""
        return super().get_answer_from_question(question, system_prompt_type, language)

class CommandR(ModelHF):
    """
    Command-R model class.

    Loads ``CohereForAI/c4ai-command-r-v01`` in float16 with
    ``device_map="auto"``. Only one checkpoint is supported, so ``model_size``
    does not affect which model is loaded; it is accepted for a constructor
    signature consistent with the other model classes and stored on the
    instance like theirs.
    """
    def __init__(self, model_size, token, device, random_seed=0):
        super().__init__()
        checkpoint = "CohereForAI/c4ai-command-r-v01"
        self.model = AutoModelForCausalLM.from_pretrained(checkpoint, token=token, device_map="auto", torch_dtype=torch.float16)
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint, token=token)
        # For each option letter, keep the last token id of the encoding of "(X".
        self.tokenizer_map = {letter: self.tokenizer.encode(f"({letter}")[-1] for letter in "ABCDE"}
        self.model_size = model_size
        self.device = device
        self.seed = random_seed
        set_seed(self.seed)

    def get_answer_from_question(self, question, system_prompt_type, language):
        """Delegate to ``ModelHF.get_answer_from_question``."""
        return super().get_answer_from_question(question, system_prompt_type, language)

class GPT(Model):
    """
    OpenAI chat-completions model class.

    The answer is read from the top-20 log-probabilities of a single
    generated token (temperature 0, fixed seed) rather than from free text.
    """

    # Tokens accepted as an answer; matched exactly (no whitespace stripping).
    _OPTION_LETTERS = frozenset({"A", "B", "C", "D", "E"})

    def __init__(self, model, random_seed=0):
        super().__init__()
        self.model = model
        # Requests time out after 10 seconds and are retried at most twice.
        self.client = OpenAI(timeout=10, max_retries=2)
        self.seed = random_seed

    def get_answer_from_question(self, question, system_prompt_type, language):
        """
        Get answer from question

        Args:
            question: dict, question with body and options
            system_prompt_type: str, system prompt type
            language: str, language of the question

        Returns:
            answer: str or None, the first of the top-20 candidate tokens (in
                the order the API returns them, presumably most likely first —
                confirm) that is exactly one of "A".."E"; None if no candidate is
            system_fingerprint: ``response.system_fingerprint`` as returned by
                the API (identifies the backend configuration)
        """
        prompt = self.create_prompt(question, system_prompt_type, language)

        response = self.client.chat.completions.create(
            model=self.model,
            seed=self.seed,
            temperature=0,
            max_tokens=1,
            top_p=1,
            logprobs=True,
            top_logprobs=20,
            messages=[
                {"role": "user", "content": prompt},
            ]
        )

        # Candidates for the first (and only) generated token.
        candidates = response.choices[0].logprobs.content[0].top_logprobs
        answer = next(
            (candidate.token for candidate in candidates if candidate.token in self._OPTION_LETTERS),
            None,
        )

        return answer, response.system_fingerprint


    