# %%
from enum import Enum
import json
import os
from pathlib import Path
from peft import LoraConfig, PeftModel
from peft.peft_model import PeftModelForCausalLM
from peft.utils import set_peft_model_state_dict, load_peft_weights
import time
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import Union, Literal
import warnings

from . import MODEL_PATH
from .messages import Message, Role, merge_messages
from .utils import get_adapter_path

# Locations of the Qwen2.5 instruct checkpoints, relative to the package-level MODEL_PATH.
QWEN32_PATH = MODEL_PATH / "Qwen/Qwen2.5-32B-Instruct"
QWEN14_PATH = MODEL_PATH / "Qwen/Qwen2.5-14B-Instruct"
QWEN7_PATH = MODEL_PATH / "Qwen/Qwen2.5-7B-Instruct"
QWEN3_PATH = MODEL_PATH / "Qwen/Qwen2.5-3B-Instruct"
# Device that input tensors are moved to before generation: CUDA when available, otherwise CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Default system prompt per model family; keys match the values returned by get_model_family().
SYSTEM_MESSAGES = {
    "qwen": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
    "llama": "You are a knowledgeable assistant trained to provide accurate and helpful information. Please respond to the user's queries promptly.",
}

# Short aliases -> full model identifiers.
# Note that the "llama3-*" aliases without the "-instruct" suffix also map to the Instruct checkpoints.
MODEL_FULL_NAME = {
    "llama3-8b-instruct": "meta-llama/Meta-Llama-3-8B-Instruct",
    "llama3-8b": "meta-llama/Meta-Llama-3-8B-Instruct",
    "llama3-70b-instruct": "meta-llama/Meta-Llama-3-70B-Instruct",
    "llama3-70b": "meta-llama/Meta-Llama-3-70B-Instruct",
    "qwen2.5-72b-instruct": "Qwen/Qwen2.5-72B-Instruct",
    "qwen2.5-32b-instruct": "Qwen/Qwen2.5-32B-Instruct",
    "qwen2.5-14b-instruct": "Qwen/Qwen2.5-14B-Instruct",
    "qwen2.5-7b-instruct": "Qwen/Qwen2.5-7B-Instruct",
    "qwen2.5-3b-instruct": "Qwen/Qwen2.5-3B-Instruct",
    "qwen2.5-0.5b-instruct": "Qwen/Qwen2.5-0.5B-Instruct",
}


def get_model_family(model_id: Union[str, Path]):
    """Return the model family ("llama" or "qwen") that `model_id` belongs to.

    The family is detected by a case-insensitive substring search on the
    string form of `model_id`, so both hub names and local paths work
    (e.g. "meta-llama/Meta-Llama-3-8B-Instruct" or ".../Llama-3-8B").

    Args:
        model_id: Model name or path.

    Returns:
        "llama" or "qwen". "llama" takes precedence if both substrings occur.

    Raises:
        ValueError: If neither family name occurs in `model_id`.
    """
    lowered = str(model_id).lower()
    # Order matters: "llama" is checked first, as in the original if/elif chain.
    for family in ("llama", "qwen"):
        if family in lowered:
            return family
    raise ValueError(f"Model family not recognized for {model_id}")


def get_system_message(model_id: Union[str, Path]):
    """Look up the default system prompt for the family of `model_id`.

    The family is resolved with get_model_family() (which raises ValueError
    for unrecognized models) and used as the key into SYSTEM_MESSAGES.
    """
    return SYSTEM_MESSAGES[get_model_family(model_id)]


# %%
class LLM:
    """A causal language model with its tokenizer and optional LoRA adapters.

    The tokenizer is loaded in ``__init__``; the model weights are only loaded
    when ``load_model`` is called. Prompt formatting is family specific
    ("llama" or "qwen", see ``get_model_family``).
    """

    def __init__(
        self,
        base_model_name_or_path: Union[str, os.PathLike],  # The base model's name or path
        adapter_ids: list[Path] = None,  # list of adapter's paths
        opening_message: "Message" = None,  # This message is always added to the beginning of the prompt
    ):
        """Resolve the model name, load the tokenizer and record the adapters.

        Args:
            base_model_name_or_path: A short alias from MODEL_FULL_NAME, a full
                model name, or a path to a checkpoint.
            adapter_ids: Adapters to merge into the base model in ``load_model``.
                Each one is resolved through ``get_adapter_path``.
            opening_message: Message prepended to every templated prompt.

        Raises:
            ValueError: If the model family cannot be recognized.
        """
        # Short aliases are expanded to full names; anything else is used verbatim.
        if base_model_name_or_path in MODEL_FULL_NAME:
            base_model_name_or_path = MODEL_FULL_NAME[base_model_name_or_path]
        self.model_path = base_model_name_or_path

        self.model_family = get_model_family(base_model_name_or_path)

        self.model = None  # Set by load_model()
        self.temperature = 0.5  # Default sampling temperature for generate()

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)

        # Environment variables are strings, so the rank has to be compared with "0".
        # Print when LOCAL_RANK is unset or when this is local rank 0.
        if os.environ.get("LOCAL_RANK", "0") == "0":
            print("Tokenizer loaded", flush=True)

        if self.model_family == "llama":
            # Llama 3 chat turns end with <|eot_id|>, which differs from the tokenizer's eos token.
            self.llama_eot_token = self.tokenizer.convert_tokens_to_ids("<|eot_id|>")

        self.adapter_ids = [get_adapter_path(adapter_id) for adapter_id in (adapter_ids or [])]

        # This message is always added to the beginning of the prompt
        self.opening_message = opening_message

    @classmethod
    def from_adapter(cls, adapter_id: str, opening_message: "Message" = None):
        """Create an LLM from an adapter.

        The base model and the chain of adapters are looked up with
        ``get_adapter_chain``.
        """
        assert adapter_id, "Adapter ID is empty"
        model_id, adapter_ids = get_adapter_chain(adapter_id)
        return cls(model_id, adapter_ids, opening_message)

    def get_config(self):
        """Return a JSON-serializable dict with the model path and adapter paths."""
        return {
            "model_path": str(self.model_path),
            "adapter_ids": [str(adapter_id) for adapter_id in self.adapter_ids],
        }

    def messages_to_prompt(self, messages: list, no_template: bool = False) -> str:
        """Render `messages` into a prompt string.

        Args:
            messages: Objects with ``role`` and ``content`` attributes.
            no_template: If True, ignore roles and the opening message and
                return the message contents joined by single spaces.

        Raises:
            NotImplementedError: If the model family has no chat template here.
            ValueError: If a message has a role the family formatter rejects.
        """
        if no_template:
            return " ".join(msg.content for msg in messages)

        if self.opening_message:
            messages = [self.opening_message] + messages

        if self.model_family == "llama":
            return self.llama_messages_to_prompt(messages)
        if self.model_family == "qwen":
            return self.qwen_messages_to_prompt(messages)
        # Raise (not return) so callers never receive an exception object as a prompt.
        raise NotImplementedError("Unknown model family: " + self.model_family)

    def qwen_messages_to_prompt(self, messages: list) -> str:
        """Format `messages` with the tokenizer's chat template (generation prompt appended).

        Raises:
            ValueError: If a message role is not SYSTEM, USER or AI.
        """
        chat = []
        for msg in messages:
            if msg.role == Role.SYSTEM:
                role_name = "system"
            elif msg.role == Role.USER:
                role_name = "user"
            elif msg.role == Role.AI:
                # Assistant turns are accepted so that multi-turn conversations work,
                # consistently with llama_messages_to_prompt.
                role_name = "assistant"
            else:
                raise ValueError(f"Wrong message role {msg.role}.")
            chat.append({"role": role_name, "content": msg.content})

        return self.tokenizer.apply_chat_template(
            chat,
            tokenize=False,
            add_generation_prompt=True,
        )

    def llama_messages_to_prompt(self, messages: list) -> str:
        """Format `messages` as Llama 3 header/eot blocks followed by an open assistant header.

        No begin-of-text token is added here.

        Raises:
            ValueError: If a message role is not SYSTEM, AI or USER.
        """
        parts = []
        for msg in messages:
            if msg.role == Role.SYSTEM:
                header = "system"
            elif msg.role == Role.AI:
                header = "assistant"
            elif msg.role == Role.USER:
                header = "user"
            else:
                raise ValueError(f"Wrong message role {msg.role}.")
            parts.append(f"<|start_header_id|>{header}<|end_header_id|>\n\n{msg.content}<|eot_id|>")
        # Leave the assistant turn open so that the model writes the answer.
        parts.append("<|start_header_id|>assistant<|end_header_id|>\n\n")
        return "".join(parts)

    def tokenize(self, seq: str) -> torch.Tensor:
        "Tokenize a sequence without adding special tokens."
        return self.tokenizer.encode(seq, add_special_tokens=False, return_tensors="pt")

    def add_bos(self, tokens: torch.Tensor):
        """Prepend the BOS token to `tokens`; return `tokens` unchanged if the tokenizer has none.

        `tokens` is concatenated along dim 1 with a 1x1 tensor, so it is
        expected to be 2-D with a single row.
        """
        bos_token_id = self.tokenizer.bos_token_id
        # Compare with None explicitly: token id 0 is a valid BOS id.
        if bos_token_id is None:
            return tokens
        bos = torch.tensor([[bos_token_id]])
        return torch.cat([bos, tokens], dim=1)

    def add_eos(self, tokens: torch.Tensor):
        """Append the end-of-turn token (<|eot_id|> for llama, eos otherwise) to `tokens`."""
        if self.model_family == "llama":
            eos_token_id = self.llama_eot_token
        else:
            eos_token_id = self.tokenizer.eos_token_id
        return torch.cat([tokens, torch.tensor([[eos_token_id]])], dim=1)

    def decode(self, tokens: torch.Tensor):
        """Decode the first sequence of a batch of token ids (special tokens are kept)."""
        return self.tokenizer.batch_decode(tokens)[0]

    def extract_question(self, q_str: str):
        """Extract the text of the first user turn from a rendered prompt.

        For llama, the text after the second header is returned, which assumes
        that the prompt starts with a system turn followed by the user turn.

        Raises:
            NotImplementedError: For model families other than llama and qwen.
        """
        if self.model_family == "llama":
            after_second_header = "".join(q_str.split("<|end_header_id|>")[2:])
            return after_second_header.split("<|eot_id|>")[0].strip()
        if self.model_family == "qwen":
            after_user_tag = "".join(q_str.split("<|im_start|>user")[1:])
            return after_user_tag.split("<|im_end|>")[0].strip()
        raise NotImplementedError(f"Function extract_question not implemented for LLM family {self.model_family}")

    def get_terminators(self):
        """Return the token ids that end generation (eos, plus <|eot_id|> for llama)."""
        terminators = [self.tokenizer.eos_token_id]
        if self.model_family == 'llama':
            terminators.append(self.llama_eot_token)
        return terminators

    def load_model(self, quantize: bool = False, training: bool = False, legacy: bool = False, deepspeed: bool = False):
        """Load the model and all adapters and merge them.

        Args:
            quantize: Load the base model in 8-bit. Not allowed together with adapters.
            training: If True, do not use ``device_map="auto"``.
            legacy: Currently unused; kept for backward compatibility.
            deepspeed: If True, do not use ``device_map="auto"``.

        Returns:
            The loaded model (also stored in ``self.model``).
        """
        # You can't train a model loaded with device_map='auto' in any distributed mode.
        device_map = None if (training or deepspeed) else "auto"

        t0 = time.perf_counter()
        base_model = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            device_map=device_map,
            load_in_8bit=quantize,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        t = time.perf_counter() - t0
        print(f"Time to load the model: {t:.02f} sec", flush=True)

        # No adapter
        if not self.adapter_ids:
            self.model = base_model
            return self.model

        assert not quantize, "Quantization is not tested with adapters"
        assert len(self.adapter_ids) == 1, "Only one adapter is supported at the moment"

        # Start from the base model once; every adapter is merged on top of the
        # result of the previous merge instead of restarting from the base model.
        model = base_model
        for adapter_id in self.adapter_ids:
            t0 = time.perf_counter()
            # NOTE: the directory name is only used for logging; the adapter is
            # registered in PEFT under the fixed name "lora".
            adapter_name = adapter_id.name
            print(f"Load and merge adapter {adapter_id} under name {adapter_name}")
            model = PeftModel.from_pretrained(
                model=model,
                model_id=adapter_id,
                adapter_name="lora",
                is_trainable=False,
            )
            model = model.merge_and_unload()
            t = time.perf_counter() - t0
            print(f"Adapter {adapter_id} loaded and merged: {t:.02f} sec", flush=True)

        self.model = model

        return self.model

    def generate(
        self,
        input_ids: torch.Tensor,
        *,
        attention_mask: torch.Tensor = None,
        temperature: float = None,
        max_new_tokens: int = 2000,
        do_sample: bool = True,
    ) -> tuple[str, bool]:
        """Generate a continuation of `input_ids` and decode it.

        Args:
            input_ids: Prompt token ids; the prompt length is taken from dim 1.
            attention_mask: Optional attention mask matching `input_ids`.
            temperature: Sampling temperature; a falsy value (None or 0)
                falls back to ``self.temperature``.
            max_new_tokens: Maximum number of tokens to generate.
            do_sample: Whether to sample instead of decoding greedily.

        Returns:
            A pair ``(output, truncated)``: the decoded new tokens of the first
            sequence (special tokens skipped), and a flag that is True when the
            last token of the first sequence is not a terminator (presumably
            because `max_new_tokens` was reached).

        Raises:
            RuntimeError: If ``load_model`` has not been called yet.
        """
        if self.model is None:
            raise RuntimeError("Model is not loaded, call load_model() first")

        # Keep the mask on the same device as the ids.
        input_ids = input_ids.to(DEVICE)
        if attention_mask is not None:
            attention_mask = attention_mask.to(DEVICE)

        t0 = time.perf_counter()

        terminators = self.get_terminators()

        tokens = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            temperature=temperature or self.temperature,
            do_sample=do_sample,
            eos_token_id=terminators,
        )

        # Compare a plain int (not a tensor) with the list of terminator ids.
        truncated = tokens[0, -1].item() not in terminators

        prompt_length = input_ids.size(1)
        output_tokens = tokens[:, prompt_length:]
        t = time.perf_counter() - t0
        n_generated = output_tokens.size(1)
        print(f"Generated {n_generated} tokens, time: {t:.02f} sec total, {n_generated / t:.02f} tokens/sec")

        output = self.tokenizer.decode(output_tokens[0], skip_special_tokens=True)

        return output, truncated

    def call(
        self,
        messages: list["Message"] = None,
        temperature: float = None,
        max_new_tokens: int = 2000,
        merge_messages_by_role: bool = True,
    ) -> tuple[str, bool]:
        """Render `messages` with the chat template and generate a reply.

        Args:
            messages: Conversation to answer.
            temperature: Sampling temperature, see ``generate``.
            max_new_tokens: Maximum number of tokens to generate.
            merge_messages_by_role: If True, pass `messages` through
                ``merge_messages`` before building the prompt.

        Returns:
            The ``(output, truncated)`` pair returned by ``generate``.
        """
        if merge_messages_by_role:
            messages = merge_messages(messages)
        prompt = self.messages_to_prompt(messages)

        # Use the module-level DEVICE instead of a hard-coded "cuda" so that
        # CPU-only hosts work as well.
        inputs = self.tokenizer(prompt, return_tensors="pt").to(DEVICE)
        content, truncated = self.generate(
            **inputs,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
        )
        return content, truncated


def get_adapter_chain(adapter_id: str):
    """Get the model ID and adapter chain for a given adapter.

    The adapter directory is resolved with ``get_adapter_path``. If it contains
    our own ``base_model_config.json``, the base model and the previously
    applied adapters are read from it. Otherwise a warning is issued and the
    base model is taken from PEFT's ``adapter_config.json``.

    Args:
        adapter_id: Identifier of the adapter.

    Returns:
        A pair ``(model_id, adapter_ids)`` where ``adapter_ids`` ends with the
        path of the requested adapter.

    Raises:
        ValueError: If ``adapter_config.json`` does not name a base model.
    """
    adapter_path = Path(get_adapter_path(adapter_id))

    # Our config file is in the adapter's directory
    base_model_config_file = adapter_path / "base_model_config.json"
    if base_model_config_file.exists():
        with open(base_model_config_file, 'r') as f:
            base_model_config = json.load(f)

        model_id = base_model_config["model_path"]
        adapter_ids = base_model_config["adapter_ids"] + [adapter_path]
        return model_id, adapter_ids

    warnings.warn(f"Adapter {adapter_id} does not have base_model_config.json", stacklevel=2)
    adapter_config_file = adapter_path / "adapter_config.json"
    with open(adapter_config_file, 'r') as f:
        adapter_config = json.load(f)

    model_id = adapter_config.get("base_model_name_or_path")
    if not model_id:
        # Fail with a clear message instead of an UnboundLocalError on return.
        raise ValueError(
            f"Adapter {adapter_id} has no base_model_name_or_path in {adapter_config_file}"
        )

    return model_id, [adapter_path]


class Tokenizer:
    """Thin wrapper around a Hugging Face tokenizer with BOS/EOS helpers.

    The helpers mirror the ones on ``LLM`` but do not need the model weights.
    """

    def __init__(self, tokenizer_id: str):
        """Load the tokenizer for `tokenizer_id` (a MODEL_FULL_NAME alias or a full name).

        Raises:
            ValueError: If the model family cannot be recognized.
        """
        # TODO: Implement tokenizer: Union[str, Path]
        assert isinstance(tokenizer_id, str), "Tokenizer must be a string"

        if tokenizer_id in MODEL_FULL_NAME:
            tokenizer_id = MODEL_FULL_NAME[tokenizer_id]

        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, trust_remote_code=True)
        self.model_family = get_model_family(tokenizer_id)

        if self.model_family == "llama":
            # add_eos() needs the id of <|eot_id|> for llama models.
            self.llama_eot_token = self.tokenizer.convert_tokens_to_ids("<|eot_id|>")

    def tokenize(self, text: str) -> torch.Tensor:
        """Tokenize `text` without adding special tokens; returns a PyTorch tensor."""
        return self.tokenizer.encode(text, add_special_tokens=False, return_tensors="pt")

    def add_bos(self, tokens: torch.Tensor):
        """Prepend the BOS token to `tokens`; return `tokens` unchanged if the tokenizer has none."""
        bos_token_id = self.tokenizer.bos_token_id
        # A tokenizer may define no BOS token; torch.tensor([[None]]) would fail.
        if bos_token_id is None:
            return tokens
        bos = torch.tensor([[bos_token_id]])
        return torch.cat([bos, tokens], dim=1)

    def add_eos(self, tokens: torch.Tensor):
        """Append the end-of-turn token (<|eot_id|> for llama, eos otherwise) to `tokens`."""
        if self.model_family == "llama":
            eos_token_id = self.llama_eot_token
        else:
            eos_token_id = self.tokenizer.eos_token_id
        return torch.cat([tokens, torch.tensor([[eos_token_id]])], dim=1)
