from typing import List, TypedDict, Literal
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaConfig, BitsAndBytesConfig
from abc import ABC, abstractmethod
import warnings

# Llama-2 chat delimiters: instruction markers and system-prompt markers.
# In this file they are only used by Llama2_Base.translate_prompt_style_old.
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

# 8-bit bitsandbytes quantization settings.
# NOTE(review): no loader in this file passes this config to from_pretrained -- confirm whether
# it is used from another module or can be removed.
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
)


# Fallback Jinja chat template. BaseModel.messages_to_tensors installs it on the tokenizer when
# the tokenizer has no chat template of its own. System content is emitted as-is, user turns are
# prefixed with "USER: ", assistant turns with "ASSISTANT: " and followed by the EOS token; when
# add_generation_prompt is set, a trailing "ASSISTANT:" is appended after the last message.
# The template text is whitespace-sensitive -- do not reformat it.
CHAT_TEMPLATE = """\
{% for message in messages -%}
    {%- if message['role'] == 'system' -%}
        {{ message['content'].strip() + ' ' }}
    {%- elif message['role'] == 'user' -%}
        {{ 'USER: ' + message['content'].strip() + ' ' }}
    {%- elif message['role'] == 'assistant' -%}
        {{ 'ASSISTANT: '  + message['content'] + eos_token + '' }}
    {%- endif %}
    {%- if loop.last and add_generation_prompt -%}
        {{ 'ASSISTANT:' }}
    {%- endif -%}
{%- endfor %}
"""

# The closed set of speaker roles a chat message may carry.
Role = Literal["system", "user", "assistant"]

"""
This file contains the model classes for the different models that are used in the experiments.
These models are based on the Huggingface Transformers library.
All inherit from the BaseModel class, that defines some shared code and abstract methods
"""


class Message(TypedDict):
    """
    A single chat turn: who is speaking and what was said.
    """
    # Speaker of the turn; one of the values allowed by Role.
    role: Role
    # Text of the turn.
    content: str

# A conversation is an ordered list of chat messages.
Dialog = List[Message]

class ModelConfig:
    """
    Minimal configuration object holding the sequence-length and batch-size settings.
    In this file it is used by DebugModel in place of a real model config.
    """
    def __init__(self, max_seq_len=1024):
        """
        :param max_seq_len: maximum sequence length; also stored as max_position_embeddings
        """
        seq_limit = max_seq_len
        self.max_seq_len = seq_limit
        # The batch size is fixed to a single sequence.
        self.max_batch_size = 1
        # Same value under the attribute name that the rest of the file reads from model configs.
        self.max_position_embeddings = seq_limit

# config = BitsAndBytesConfig(load_in_4bit=True)

class ModelBuilder:
    """
    Factory that maps short model names to a wrapper class and a Hugging Face model card.
    """
    def __init__(self):
        # name -> [wrapper class, model card passed to the wrapper's constructor].
        # Every entry must be such a pair, because build_model unpacks it.
        self.model_map = {
            "llama2": [Llama2, "meta-llama/Llama-2-7b-chat-hf"],
            "llama3": [Llama3, "meta-llama/Meta-Llama-3-8B-Instruct"],
            "dolphin": [DolphinLlama3, "cognitivecomputations/dolphin-2.9-llama3-8b"],
            "mistral": [Mistral, "mistralai/Mistral-7B-Instruct-v0.3"],
            "uncensored": [LlamaUncensored, "georgesung/llama2_7b_chat_uncensored"],
            # DebugModel ignores the card and always loads the "gpt2" tokenizer.
            "debug": [DebugModel, "gpt2"],
            "llama2_base": [Llama2_Base, "meta-llama/Llama-2-7b-hf"],
            "guanaco": [Guanaco, "TheBloke/guanaco-7B-HF"],
            "vicuna": [Vicuna, "TheBloke/vicuna-7B-HF"],
        }

    def build_model(self, model_name: str, max_seq_len=1024):
        """
        Instantiate the wrapper registered under ``model_name``.

        Unknown names are not an error: a warning is issued and the name is treated as a
        model card that is loaded through the plain BaseModel.
        :param model_name: key of model_map, or a model card for the fallback path
        :param max_seq_len: accepted for interface compatibility; currently not forwarded
            to the wrapper constructors
        :return: the constructed model wrapper
        """
        entry = self.model_map.get(model_name)
        if entry is None:
            warnings.warn(f"Model {model_name} not found in model map, trying to load anyway")
            return BaseModel(model_name)
        model_cls, model_card = entry
        return model_cls(model_card)


class BaseModel(ABC):
    """
    Base class that concentrates the code shared by all model wrappers in this file.

    It loads a Hugging Face tokenizer and causal language model, makes sure a padding token
    exists, and offers helpers for encoding, decoding, next-token prediction and prompt
    construction. Models with their own input format override the prompt-construction
    methods (``dialogs_to_tensors`` / ``dialogs_to_str``).
    """
    def __init__(self, model_card):
        """
        Load the tokenizer and model for ``model_card`` and move the model to the GPU.
        :param model_card: model identifier passed to ``from_pretrained``
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_card)
        # The tokenizer must be prepared first: init_model_params reads the ids it sets.
        self.init_tokenizer()
        self.model = AutoModelForCausalLM.from_pretrained(model_card, do_sample=True).to('cuda')
        self.init_model_params()

    def init_model_params(self):
        """
        Copy the special-token ids into the model config and expose the config as ``self.params``.
        Requires ``init_tokenizer`` to have run (it provides bos_id / eos_id / pad_id).
        :return:
        """
        self.model.config.bos_token_id = self.bos_id
        self.model.config.eos_token_id = self.eos_id
        self.model.config.pad_token_id = self.pad_id
        self.params = self.model.config
        self.params.max_batch_size = 1
        # init_tokenizer may have added tokens, so the embedding matrix is resized to match.
        self.model.resize_token_embeddings(len(self.tokenizer))
        # Subclasses set this to the text marker at which generation should stop.
        self.stop_string = None

    def init_tokenizer(self):
        """
        Because the algorithm is token-based, we have to define a uniform padding token.

        If the tokenizer has no padding token, "<pad>", "<s>" and "</s>" are registered as its
        pad / bos / eos tokens. In every case the padding side is set to "right" and the three
        ids are stored on this object (bos_id, eos_id, pad_id) and mirrored onto the tokenizer
        under the same names.
        :return:
        """
        if not self.tokenizer.pad_token:
            # NOTE(review): this also replaces the tokenizer's existing bos/eos tokens --
            # confirm that is intended for models that define their own.
            self.tokenizer.add_special_tokens({"pad_token": "<pad>",
                                               "bos_token": "<s>",
                                               "eos_token": "</s>", })
        self.tokenizer.padding_side = "right"
        self.bos_id = self.tokenizer.bos_token_id
        self.eos_id = self.tokenizer.eos_token_id
        self.pad_id = self.tokenizer.pad_token_id
        # Mirror the ids onto the tokenizer regardless of which branch ran, so code that reads
        # tokenizer.bos_id / eos_id / pad_id works for every model.
        self.tokenizer.bos_id = self.bos_id
        self.tokenizer.eos_id = self.eos_id
        self.tokenizer.pad_id = self.pad_id

    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
        """
        Tokenize ``s`` and optionally add the BOS id in front and the EOS id at the end.
        :param s: text to tokenize
        :param bos: prepend the BOS token id
        :param eos: append the EOS token id
        :return: list of token ids
        """
        # NOTE(review): tokenizer.encode is called with its default settings, which for many
        # tokenizers already add special tokens -- confirm bos=True does not duplicate the BOS.
        t = self.tokenizer.encode(s)
        if bos:
            t = [self.bos_id] + t
        if eos:
            t = t + [self.eos_id]
        return t

    def decode(self, t: List[int]) -> str:
        """
        Turn a list of token ids back into text.
        :param t: token ids
        :return: decoded string
        """
        return self.tokenizer.decode(t)

    @torch.no_grad()
    def forward(self, tokens, prev_pos=None):
        """
        The forward method is used to generate the next token logits for a given input sequence.
        Runs without gradient tracking.
        :param tokens: input ids passed to the model
        :param prev_pos: not used; kept for interface compatibility
        :return: logits at the last sequence position, with a singleton dimension at index 1
        """
        # NOTE(review): the logits are read at the final position; the prompt helpers in this
        # class right-pad to a fixed length, so confirm callers strip padding first.
        outputs = self.model(input_ids=tokens, output_hidden_states=False, output_attentions=False)
        next_token_logits = outputs.logits[:, -1, :].unsqueeze(1)
        return next_token_logits

    def forward_2(self, tokens, start_pos):
        """
        The forward_2 method is used to generate the next token logits for a given input sequence,
        and also returns the hidden states of the generation.
        Only used when no external embedder is used.
        Unlike forward, this method does not disable gradient tracking.
        :param tokens: input ids passed to the model
        :param start_pos: not used; kept for interface compatibility
        :return: (last-position logits with a singleton dimension at index 1,
                  (None, last-layer hidden state at the last position))
        """
        outputs = self.model(tokens, output_hidden_states=True, output_attentions=True)
        next_tokens_logits = outputs.logits[:, -1, :].unsqueeze(1)
        hidden_states = outputs.hidden_states[-1][:, -1, :]
        return next_tokens_logits, (None, hidden_states)

    def _render_chat(self, messages) -> str:
        """
        Render messages to prompt text with the tokenizer's chat template, including the
        generation prompt. Tokenization is done separately, so no padding or tensor options
        are passed here.
        """
        return self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    def _tokenize_padded(self, text):
        """
        Tokenize prompt text into PyTorch tensors, padded to the configured
        max_position_embeddings.
        """
        return self.tokenizer(text, return_tensors='pt', padding="max_length",
                              max_length=self.params.max_position_embeddings)

    @staticmethod
    def _as_plain_messages(dialog):
        """Copy a dialog into plain dicts that carry only the 'role' and 'content' keys."""
        return [{"role": message["role"], "content": message["content"]} for message in dialog]

    def dialogs_to_tensors(self, dialog: "Dialog"):
        """
        Render a dialog with the tokenizer's chat template and tokenize the result.
        :param dialog: list of messages with 'role' and 'content'
        :return: the tokenizer output for the rendered prompt (PyTorch tensors, padded to
                 max_position_embeddings)
        """
        return self._tokenize_padded(self._render_chat(self._as_plain_messages(dialog)))

    def messages_to_tensors(self, messages: List[dict]):
        """
        Like dialogs_to_tensors, but takes the messages as given and falls back to the
        module-level CHAT_TEMPLATE when the tokenizer has no chat template.
        The fallback is stored on the tokenizer, so it also applies to later calls.
        :param messages: list of chat message dicts
        :return: the tokenizer output for the rendered prompt (PyTorch tensors, padded to
                 max_position_embeddings)
        """
        if self.tokenizer.chat_template is None:
            self.tokenizer.chat_template = CHAT_TEMPLATE
        return self._tokenize_padded(self._render_chat(messages))

    def dialogs_to_str(self, dialog: "Dialog") -> str:
        """
        Render a dialog with the tokenizer's chat template without tokenizing it.
        :param dialog: list of messages with 'role' and 'content'
        :return: the prompt text, including the generation prompt
        """
        return self._render_chat(self._as_plain_messages(dialog))

class DebugModel(BaseModel):
    """
    Lightweight stand-in for debugging: no language model is loaded and the overridden
    methods return fixed dummy values.
    """
    def __init__(self, model_card, max_seq_len=1024):
        """
        :param model_card: not used; present so the constructor matches the other wrappers
        :param max_seq_len: sequence length stored in the ModelConfig
        """
        # BaseModel.__init__ is not called here, so no model weights are loaded.
        self.params = ModelConfig(max_seq_len=max_seq_len)
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
        self.model = None

    def forward(self, tokens, prev_pos=None):
        """Return a constant 1x1 tensor instead of real logits; the inputs are ignored."""
        dummy_logits = torch.tensor([[0.0]])
        return dummy_logits

    def dialogs_to_tensors(self, dialog: Dialog) -> List[int]:
        """Return a single dummy token id; the dialog is ignored."""
        return [0]

    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
        """Return a single dummy token id; the arguments are ignored."""
        return [0]

    def decode(self, t: List[int]) -> str:
        """Return the fixed text "Debug"; the token ids are ignored."""
        return "Debug"

class GenericModel(BaseModel):
    """
    This is a generic model class that can be used with any model that is not specifically implemented.
    Not guaranteed to work with all models, as the input format can differ.
    Will only work if the apply_chat_template method is implemented in the model's tokenizer.
    """
    def __init__(self, model_name: str, max_seq_len=1024):
        """
        Load tokenizer and model for ``model_name`` with max_position_embeddings overridden to 4096.
        :param model_name: model identifier passed to ``from_pretrained``
        :param max_seq_len: accepted for interface compatibility; currently not used
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # init_model_params reads bos_id / eos_id / pad_id, which only init_tokenizer sets,
        # so the tokenizer has to be prepared before the model parameters are initialised.
        self.init_tokenizer()
        self.model = AutoModelForCausalLM.from_pretrained(model_name, max_position_embeddings=4096).to('cuda')
        self.init_model_params()


class DolphinLlama3(BaseModel):
    """
    Wrapper for the Dolphin Llama-3 model; uses the shared BaseModel behaviour and
    only sets its own stop string.
    """
    def __init__(self, model_card):
        """
        :param model_card: model identifier passed on to BaseModel
        """
        super().__init__(model_card)
        # Text marker stored as this model's stop string.
        self.stop_string = "<|im_end|>"

class Guanaco(BaseModel):
    """
    Wrapper for Guanaco, which is prompted with "### Human:" / "### Assistant:" headers
    instead of the tokenizer's chat template.
    """
    def __init__(self, model_card):
        """
        :param model_card: model identifier passed on to BaseModel
        """
        super().__init__(model_card)
        self.stop_string = "### Human:"

    def dialogs_to_tensors(self, dialog: Dialog):
        """
        Build the prompt text from the dialog and tokenize it.
        User turns get a "### Human:" header, system content is inserted without a header,
        and every other role gets an "### Assistant:" header.
        :param dialog: list of messages with 'role' and 'content'
        :return: the tokenizer output (PyTorch tensors, padded to max_position_embeddings)
        """
        segments = []
        for turn in dialog:
            speaker = turn["role"]
            if speaker == "system":
                segments.append(turn["content"] + "\n")
            elif speaker == "user":
                segments.append("### Human:\n" + turn["content"] + "\n")
            else:
                segments.append("### Assistant:\n" + turn["content"] + "\n")
        prompt = "".join(segments)

        return self.tokenizer(prompt, return_tensors='pt', padding="max_length",
                              max_length=self.params.max_position_embeddings)

class Llama3(BaseModel):
    """
    Wrapper for Llama-3 instruct models. Next-token scores are obtained through a
    one-token ``generate`` call that stops on either the EOS token or "<|eot_id|>".
    """
    def __init__(self, model_card):
        """
        :param model_card: model identifier passed on to BaseModel
        """
        super().__init__(model_card)
        # Token ids that end a generation: the tokenizer's EOS id and the id of "<|eot_id|>".
        self.terminators = [self.tokenizer.eos_token_id,
                            self.tokenizer.convert_tokens_to_ids("<|eot_id|>")]

    # Called with parentheses: the bare @torch.no_grad form is not accepted by every
    # PyTorch version.
    @torch.no_grad()
    def forward(self, tokens, prev_pos=None):
        """
        Produce the scores for the next token by generating exactly one new token.
        Runs without gradient tracking.
        :param tokens: input ids passed to ``generate``
        :param prev_pos: not used; kept for interface compatibility
        :return: the scores of the single generated step, with a singleton dimension at index 1
        """
        # NOTE(review): generate() returns scores after its own logits processing, so these may
        # differ from the raw model logits that BaseModel.forward returns -- confirm intended.
        outputs = self.model.generate(tokens, max_new_tokens=1, num_return_sequences=1, eos_token_id=self.terminators,
                                      return_dict_in_generate=True, output_scores=True,
                                      pad_token_id=self.tokenizer.pad_token_id)
        next_token_logits = outputs.scores[0][:, None, :]
        return next_token_logits

class Mistral(BaseModel):
    """
    Wrapper for Mistral instruct models.
    Adds nothing of its own: construction and all methods are inherited from BaseModel.
    """


class Llama2(BaseModel):
    """
    Wrapper for Llama-2 chat models.
    Adds nothing of its own: construction and all methods are inherited from BaseModel.
    """

class Llama2_Base(BaseModel):
    """
    Wrapper for the Llama-2 base (non-chat) model. Dialogs are fed to the model as plain
    text, one message per line, without any chat template.
    """
    def __init__(self, model_card, max_seq_len=1024):
        """
        :param model_card: model identifier passed on to BaseModel
        :param max_seq_len: value stored as ``max_position_embeddings`` on the model object
        """
        super().__init__(model_card)
        # NOTE(review): this sets an attribute on the model object itself, not on
        # self.model.config (which self.params points to) -- confirm the intended target.
        self.model.max_position_embeddings = max_seq_len

    def dialogs_to_tensors(self, dialog: Dialog) -> List[List[int]]:
        """
        Concatenate the contents of all messages, each followed by a newline, and tokenize
        the result. Roles are ignored.
        :param dialog: list of messages with 'role' and 'content'
        :return: the token ids as a nested Python list (the tokenizer's tensor output
                 converted with ``tolist``), not a tensor
        """
        dialog_str = "".join(message["content"] + "\n" for message in dialog)
        # The result is converted straight to a Python list, so no device transfer is needed.
        return self.tokenizer.encode(dialog_str, return_tensors="pt").tolist()

    def translate_prompt_style_old(self, dialog: Dialog) -> List[List[int]]:
        """
        Legacy prompt builder using the Llama-2 [INST] / <<SYS>> format.

        A leading system message is merged into the message that follows it. The messages
        are then encoded pairwise as "[INST] prompt [/INST] answer ", and finally the last
        message is encoded as an open "[INST] ... [/INST]" prompt.
        :param dialog: list of messages; must not be empty, and a leading system message
            must be followed by at least one more message
        :return: a list containing one list of token ids
        """
        if dialog[0]["role"] == "system":
            # Fold the system prompt into the first non-system message.
            merged_first = {
                "role": dialog[1]["role"],
                "content": B_SYS + dialog[0]["content"] + E_SYS + dialog[1]["content"],
            }
            dialog = [merged_first] + dialog[2:]

        dialog_tokens: List[int] = []
        for prompt, answer in zip(dialog[::2], dialog[1::2]):
            dialog_tokens.extend(
                self.tokenizer.encode(
                    f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
                )
            )

        # The final message is the prompt the model should answer next.
        dialog_tokens.extend(
            self.tokenizer.encode(
                f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
            )
        )

        return [dialog_tokens]


class LlamaUncensored(BaseModel):
    """
    Wrapper for the uncensored Llama-2 chat model, which is prompted with
    "### HUMAN:" / "### SYSTEM:" / "### RESPONSE:" headers instead of a chat template.
    """
    def __init__(self, model_card):
        """
        :param model_card: model identifier passed on to BaseModel
        """
        super().__init__(model_card)
        self.stop_string = "### HUMAN:"

    @staticmethod
    def _render_prompt(dialog):
        """
        Build the prompt text: each message becomes its role header, a newline, the content
        and another newline. Roles other than user and system get the RESPONSE header.
        """
        pieces = []
        for turn in dialog:
            speaker = turn["role"]
            if speaker == "user":
                header = "### HUMAN:\n"
            elif speaker == "system":
                header = "### SYSTEM:\n"
            else:
                header = "### RESPONSE:\n"
            pieces.append(header + turn["content"] + "\n")
        return "".join(pieces)

    def dialogs_to_tensors(self, dialog: Dialog):
        """
        Build the prompt text from the dialog and tokenize it.
        :param dialog: list of messages with 'role' and 'content'
        :return: the tokenizer output (PyTorch tensors, padded to max_position_embeddings)
        """
        prompt = self._render_prompt(dialog)
        return self.tokenizer(prompt, return_tensors='pt', padding="max_length",
                              max_length=self.params.max_position_embeddings)

    def dialogs_to_str(self, dialog:Dialog) -> str:
        """
        Build the prompt text from the dialog without tokenizing it.
        :param dialog: list of messages with 'role' and 'content'
        :return: the prompt text
        """
        return self._render_prompt(dialog)



