import torch
import os
import logging
from typing import Any, Optional, Mapping, Sequence
from transformers import AutoModelForCausalLM, AutoTokenizer  # type: ignore[import-untyped]
import niuload  # type: ignore[import-untyped]

from src.language_models.language_model import LanguageModel
from src.typings import (
    Role,
    ChatHistoryItem,
    LanguageModelContextLimitException,
    LanguageModelOutOfMemoryException,
    ChatHistory,
)


class HuggingfaceLanguageModel(LanguageModel):
    def __init__(
        self,
        model_name_or_path: str,
        role_dict: Mapping[str, str],
        dtype: torch.dtype | str = torch.bfloat16,
        device_map: str | Mapping[str, Any] = "auto",
    ):
        """
        Config explanations
        dtype: `dtype` can be one of: `torch.dtype`, `"auto"` or a string of a valid `torch.dtype`
        device_map: I cannot find the detail documents.
            But it seems that it can be set to "cuda" or {"": "cuda"} to use GPU.
            Set "auto" can use multiple GPUs. (Amazing!)
        """
        super().__init__(role_dict)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        if device_map == "niuload":
            device_map = niuload.balanced_load(
                model_name_or_path, return_device_map_only=True
            )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path, device_map=device_map, torch_dtype=dtype
        )

    def _convert_message_list_to_model_input_dict(
        self, batch_message_list: Sequence[Sequence[Mapping[str, str]]]
    ) -> Mapping[str, torch.Tensor]:
        batch_input_ids: torch.Tensor = self.tokenizer.apply_chat_template(
            batch_message_list,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
            padding=True,
        ).to(self.model.device)
        no_padding_mask: torch.Tensor = batch_input_ids != self.tokenizer.pad_token_id
        first_no_padding_token_index: torch.Tensor
        _, first_no_padding_token_index = (
            (no_padding_mask.cumsum(-1) == 1) & no_padding_mask
        ).max(-1)
        row_repeat: torch.Tensor = (
            torch.arange(batch_input_ids.shape[1])
            .unsqueeze(0)
            .repeat(batch_input_ids.shape[0], 1)
        ).to(self.model.device)
        column_repeat: torch.Tensor = first_no_padding_token_index.unsqueeze(1).repeat(
            1, batch_input_ids.shape[1]
        )
        batch_attention_mask_ne: torch.Tensor = row_repeat < column_repeat
        batch_attention_mask: torch.Tensor = ~batch_attention_mask_ne
        return {
            "batch_input_ids": batch_input_ids,
            "batch_attention_mask": batch_attention_mask,
        }

    @staticmethod
    def _is_any_gpu_memory_high() -> bool:
        for i in list(range(torch.cuda.device_count())):
            device = torch.device(f"cuda:{i}")
            free, total = torch.cuda.mem_get_info(device)
            # mem_used_gb = (total - free) / 1024**3
            if free / total < 0.1:  # Threshold set to 90%
                return True
        return False

    def _inference(
        self,
        batch_chat_history: Sequence[ChatHistory],
        inference_config_dict: Mapping[str, Any],
        system_prompt: str,
    ) -> Sequence[ChatHistoryItem]:
        original_tokenizer_padding_side = self.tokenizer.padding_side
        original_tokenizer_pad_token = self.tokenizer.pad_token
        self.tokenizer.padding_side = "left"
        self.tokenizer.pad_token = self.tokenizer.eos_token
        # endregion
        # region Construct batch_message_list
        message_list_prefix: list[Mapping[str, str]]
        if len(system_prompt) > 0:
            message_list_prefix = [{"role": "system", "content": system_prompt}]
        else:
            message_list_prefix = []
        batch_message_list: Sequence[Sequence[Mapping[str, str]]] = [
            message_list_prefix
            + self._convert_chat_history_to_message_list(chat_history)
            for chat_history in batch_chat_history
        ]
        # endregion
        # region Generate output
        model_input_dict: Mapping[str, torch.Tensor] = (
            self._convert_message_list_to_model_input_dict(batch_message_list)
        )
        batch_input_ids, batch_attention_mask = (
            model_input_dict["batch_input_ids"],
            model_input_dict["batch_attention_mask"],
        )
        del model_input_dict
        max_len = self.model.config.max_position_embeddings
        if batch_input_ids.shape[-1] > max_len:
            original_len = batch_input_ids.shape[-1]
            logging.warning(
                f"Input length {original_len} exceeds the model's max_position_embeddings {max_len}. "
                f"The input will be truncated to the last {max_len} tokens."
            )
            batch_input_ids = batch_input_ids[:, -max_len:]
            batch_attention_mask = batch_attention_mask[:, -max_len:]
        torch.cuda.synchronize()
        try:
            output_tensor: torch.Tensor = self.model.generate(
                batch_input_ids,
                attention_mask=batch_attention_mask,
                pad_token_id=self.tokenizer.eos_token_id,  # Mute warning
                **inference_config_dict,
            )
        except Exception as e:
            if (
                isinstance(e, torch.cuda.OutOfMemoryError)
                or HuggingfaceLanguageModel._is_any_gpu_memory_high()
            ):
                torch.cuda.empty_cache()
                raise LanguageModelOutOfMemoryException(str(e)) from e
            else:
                raise e
        finally:
            torch.cuda.synchronize()
        # endregion
        # region Convert output to ChatHistoryItem
        output_str_list: Sequence[str] = self.tokenizer.batch_decode(
            output_tensor[:, batch_input_ids.shape[1] :], skip_special_tokens=True
        )
        output_list: Sequence[ChatHistoryItem] = [
            ChatHistoryItem(role=Role.AGENT, content=output_str)
            for output_str in output_str_list
        ]
        # endregion
        # region Reset the tokenizer attributes
        self.tokenizer.padding_side = original_tokenizer_padding_side
        self.tokenizer.pad_token = original_tokenizer_pad_token
        # endregion
        return output_list
