import os
import warnings
from abc import ABC, abstractmethod
from typing import List, TypedDict, Literal

import torch
from accelerate import init_empty_weights, infer_auto_device_map
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaConfig, BitsAndBytesConfig
from transformers import AutoModelForCausalLM, AutoTokenizer

from Params import MAX_SEQ

# Authenticate with the Hugging Face Hub so that gated checkpoints can be
# downloaded. The access token is read from the HF_TOKEN environment variable
# instead of being committed to source control; the token that used to be
# hard-coded on this line must be treated as leaked and revoked.
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)
else:
    warnings.warn("HF_TOKEN is not set; skipping Hugging Face Hub login, gated models may fail to download")

# Begin/end marker strings wrapped around an instruction turn.
B_INST = "[INST]"
E_INST = "[/INST]"
# Begin/end marker strings wrapped around a system-prompt section.
B_SYS = "<<SYS>>\n"
E_SYS = "\n<</SYS>>\n\n"

# Quantization settings shared by the loaders in this module that pass
# `quantization_config`: weights are loaded in 4-bit.
quantization_config = BitsAndBytesConfig(load_in_4bit=True)

# System prompt text for Vicuna-style conversations.
VICUNA_SYSTEM_PROMPT = (
    "A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions."
)

# Same text as the Vicuna prompt, followed by one extra sentence about safety.
LLAMA_SYSTEM_PROMPT = (
    "A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions. "
    "The assistant is designed to be safe and helpful."
)

# Jinja chat template used as a fallback when a tokenizer ships without its own
# `chat_template` (see BaseModel.dialogs_to_input and Vicuna.__init__).
# Rendering per message role:
#   system    -> stripped content followed by a space
#   user      -> "USER: " + stripped content + a space
#   assistant -> "ASSISTANT: " + content (not stripped) + the eos token
# After the last message, "ASSISTANT:" is appended when `add_generation_prompt`
# is set. Whitespace inside the literal is part of the template, so the text
# must not be reformatted.
VICUNA_CHAT_TEMPLATE = """\
{% for message in messages -%}
    {%- if message['role'] == 'system' -%}
        {{ message['content'].strip() + ' ' }}
    {%- elif message['role'] == 'user' -%}
        {{ 'USER: ' + message['content'].strip() + ' ' }}
    {%- elif message['role'] == 'assistant' -%}
        {{ 'ASSISTANT: '  + message['content'] + eos_token + '' }}
    {%- endif %}
    {%- if loop.last and add_generation_prompt -%}
        {{ 'ASSISTANT:' }}
    {%- endif -%}
{%- endfor %}
"""


"""
This file contains the model classes for the different models used in the experiments.
The models are based on the Hugging Face Transformers library.
All of them inherit from the BaseModel class, which defines shared code and abstract methods.
"""

# The speaker roles a chat message may carry.
Role = Literal["system", "user", "assistant"]

class Message(TypedDict):
    """A single chat message: the role of the speaker and the message text."""
    role: Role
    content: str

# A dialog is a list of messages.
Dialog = List[Message]

class ModelConfig:
    """Sequence-length and batch-size settings.

    :param max_seq_len: stored as both ``max_seq_len`` and
        ``max_position_embeddings``; ``max_batch_size`` is always 1.
    """

    def __init__(self, max_seq_len=1024):
        limit = max_seq_len
        # Assigned left to right, so the attribute order matches the
        # declaration order: length, batch size, position embeddings.
        self.max_seq_len, self.max_batch_size, self.max_position_embeddings = limit, 1, limit



class ModelBuilder:
    """Factory that turns a short model name into a loaded model wrapper.

    ``model_map`` maps each known name to a ``(wrapper class, model card)`` pair.
    """

    def __init__(self):
        # One row per supported model: short name, wrapper class, model card.
        registry = (
            ("vicuna", Vicuna, "lmsys/vicuna-13b-v1.5"),
            ("vicuna_uncensored", Vicuna, "cognitivecomputations/Wizard-Vicuna-13B-Uncensored"),
            ("llama-3", Llama3, "meta-llama/Meta-Llama-3-8B-Instruct"),
            ("mistral", Mistral, "mistralai/Mistral-7B-Instruct-v0.3"),
            ("llama-3-70b", Llama3_70b, "meta-llama/Meta-Llama-3-70B-Instruct"),
            ("mistral_uncensored", Mistral, "cognitivecomputations/dolphin-2.9.3-mistral-7B-32k"),
            ("llama-3_uncensored", Llama3, "cognitivecomputations/dolphin-2.9.3-llama-3-8b"),
            ("llama-3-70b_uncensored", Llama3_70b, "cognitivecomputations/dolphin-2.9.1-llama-3-70b"),
        )
        self.model_map = {name: (wrapper, card) for name, wrapper, card in registry}

    def build_model(self, model_name: str, max_seq_len=1024):
        """Instantiate the wrapper registered under ``model_name``.

        :param model_name: key of ``model_map``; any other value is passed to
            ``BaseModel`` directly as a model card, after emitting a warning.
        :param max_seq_len: accepted but not used by this method.
        :return: the constructed model wrapper.
        """
        entry = self.model_map.get(model_name)
        if entry is None:
            warnings.warn(f"Model {model_name} not found in model map, trying to load anyway")
            return BaseModel(model_name)
        wrapper_cls, model_card = entry
        return wrapper_cls(model_card)

class BaseModel(ABC):
    """Shared loading, tokenisation and decoding logic for the model wrappers.

    Constructing with a model card loads the model and tokenizer. Constructing
    with ``None`` skips all loading and only sets ``stop_string``; the caller
    is then responsible for assigning ``model``, ``tokenizer`` and
    ``terminators`` itself.
    """

    def __init__(self, model_card: str):
        """Load the model and tokenizer identified by ``model_card``.

        :param model_card: model id or path handed to ``from_pretrained``,
            or ``None`` to skip loading entirely.
        """
        if model_card is None:
            self.stop_string = None
            return
        # Vicuna wrappers are loaded with the module-level quantization config;
        # every other model is loaded without it and moved to the GPU.
        if isinstance(self, Vicuna):
            warnings.warn("Loading Vicuna with quantization")
            model = AutoModelForCausalLM.from_pretrained(model_card, quantization_config=quantization_config)
        else:
            model = AutoModelForCausalLM.from_pretrained(model_card).to('cuda')
        # Cap the configured context length of both model and tokenizer at MAX_SEQ.
        model.config.max_position_embeddings = MAX_SEQ
        tokenizer = AutoTokenizer.from_pretrained(model_card)
        tokenizer.model_max_length = MAX_SEQ
        self.model = model

        self.tokenizer = tokenizer
        self.stop_string = None
        # Default terminators so that `forward` also works on a plain BaseModel
        # (the ModelBuilder fallback path); subclasses overwrite this list.
        self.terminators = [tokenizer.eos_token_id]

    def dialogs_to_input(self, dialogs: "List[Dialog]", padding="do_not_pad"):
        """Tokenize dialogs with the tokenizer's chat template.

        If the tokenizer has no chat template, VICUNA_CHAT_TEMPLATE is
        installed on it first (this mutates the tokenizer).

        :param dialogs: the dialogs to render and tokenize.
        :param padding: forwarded unchanged to ``apply_chat_template``.
        :return: the result of ``apply_chat_template`` called with
            ``return_tensors='pt'`` and ``return_dict=True``.
        """
        if self.tokenizer.chat_template is None:
            self.tokenizer.chat_template = VICUNA_CHAT_TEMPLATE
        # NOTE(review): max_length is passed without an explicit truncation
        # flag - verify that over-long inputs are actually truncated.
        tokenized_input = self.tokenizer.apply_chat_template(dialogs, tokenize=True, add_generation_prompt=True,
                                                             return_tensors='pt', return_dict=True, padding=padding,
                                                             max_length=MAX_SEQ)
        return tokenized_input

    def forward(self, input_ids, **kwargs):
        """Call the wrapped model on ``input_ids``.

        ``self.terminators`` is passed as ``eos_token_id``; any extra keyword
        arguments are forwarded unchanged.
        """
        # NOTE(review): eos_token_id is usually a generation argument - confirm
        # that the wrapped model's call accepts it.
        return self.model(input_ids, eos_token_id=self.terminators, **kwargs)

    def forward_legacy(self, tokens, prev_pos=None):
        """Return the logits for the position after the last input token.

        :param tokens: input ids passed to the model as ``input_ids``.
        :param prev_pos: unused; kept for signature compatibility.
        :return: ``outputs.logits`` sliced at the last sequence position, with
            a singleton dimension re-inserted at index 1.
        """
        outputs = self.model(input_ids=tokens, output_hidden_states=False, output_attentions=False)
        next_token_logits = outputs.logits[:, -1, :].unsqueeze(1)
        return next_token_logits

    def decode(self, output):
        """Decode token ids to text, skipping special tokens."""
        return self.tokenizer.decode(output, skip_special_tokens=True)
class Vicuna(BaseModel):
    """Wrapper for Vicuna-style checkpoints."""

    def __init__(self, model_card):
        super().__init__(model_card)
        tok = self.tokenizer
        # Install the module-level Vicuna template when the tokenizer has none.
        if tok.chat_template is None:
            tok.chat_template = VICUNA_CHAT_TEMPLATE
        # The tokenizer's end-of-sequence token is the only terminator.
        self.terminators = [tok.eos_token_id]
        # NOTE(review): this stores B_SYS (the "<<SYS>>" marker), not
        # VICUNA_SYSTEM_PROMPT - confirm that this is intended.
        self.system_prompt = B_SYS

class Llama3(BaseModel):
    """Wrapper for Llama-3 checkpoints; the tokenizer must provide a chat template."""

    def __init__(self, model_card: str):
        super().__init__(model_card)
        tok = self.tokenizer
        assert tok.chat_template is not None, "Chat template not found in llama-3"
        # Two terminators: the tokenizer's eos token and the "<|eot_id|>" token.
        eos_id = tok.eos_token_id
        eot_id = tok.convert_tokens_to_ids("<|eot_id|>")
        self.terminators = [eos_id, eot_id]

class Mistral(BaseModel):
    """Wrapper for Mistral checkpoints."""

    def __init__(self, model_card: str):
        super().__init__(model_card)
        # The tokenizer's end-of-sequence token is the only terminator.
        eos_id = self.tokenizer.eos_token_id
        self.terminators = [eos_id]

class Llama3_70b(BaseModel):
    """Wrapper for the 70B Llama-3 checkpoints, loaded from fixed local directories."""

    def __init__(self, model_card: str):
        """Load the model and tokenizer from the directory mapped to ``model_card``.

        :param model_card: one of the two supported 70B model cards.
        :raises ValueError: if ``model_card`` has no local directory.
        """
        # Passing None makes the base initialiser skip its own loading.
        super().__init__(None)
        # Supported model cards and the local directory holding each checkpoint.
        local_dirs = {
            "meta-llama/Meta-Llama-3-70B-Instruct": "/dt/shabtaia/dt-fujitsu-gai/llama-3-70b/",
            "cognitivecomputations/dolphin-2.9.1-llama-3-70b": "/dt/shabtaia/dt-fujitsu-gai/llama-3-70b_uncencsored/",
        }
        if model_card not in local_dirs:
            raise ValueError(f"Model card {model_card} not found")
        cache_directory = local_dirs[model_card]

        load_kwargs = dict(
            device_map='cuda:0',
            torch_dtype=torch.float16,
            quantization_config=quantization_config,
        )
        self.model = AutoModelForCausalLM.from_pretrained(cache_directory, **load_kwargs)
        tok = AutoTokenizer.from_pretrained(cache_directory)
        self.tokenizer = tok
        assert tok.chat_template is not None, "Chat template not found in llama-3-70b"
        # Two terminators: the tokenizer's eos token and the "<|eot_id|>" token.
        eos_id = tok.eos_token_id
        eot_id = tok.convert_tokens_to_ids("<|eot_id|>")
        self.terminators = [eos_id, eot_id]
