import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import torch
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
)

from .utils import get_model_max_length


# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Filesystem path of this source file.
# NOTE(review): not referenced by any definition visible in this module.
FILE_PATH = Path(__file__)


@dataclass
class ModelConfig:
    model_id: Optional[str]
    base_dir: Optional[Path | str] = None
    pretrained: bool = True
    inference_only: bool = False

    @property
    def model_id_not_none(self) -> str:
        if self.model_id is None:
            raise ValueError("model_id is None")
        return self.model_id

    @property
    def name(self) -> str:
        assert self.model_id is not None
        return self.model_id if self.pretrained else f"{self.model_id}_ut"

    @property
    def training_dtype(self) -> torch.dtype:
        if self.model_id_not_none.startswith(
            "Llama"
        ) or self.model_id_not_none.startswith("llama"):
            # Llama2 was trained in bfloat16, even though the weights
            # are stored in float16 on HF. There can be training
            # instabilities with float16, so we use bfloat16 for training.
            return torch.bfloat16
        elif self.model_id_not_none.startswith("gpt2"):
            # GPT2 is trained in float32, and we use float32 for training.
            return torch.float32
        elif self.model_id_not_none.startswith(
            "EleutherAI/pythia"
        ) or self.model_id_not_none.startswith("pythia"):
            # Pythia convergest faster with bfloat16 than with float16,
            # and might have been trained that way as well.
            if self.model_id_not_none.endswith("-70m"):
                # The loss does not converge for this (small) model
                # with bfloat16, so we use float32 for training,
                # since this model is also not big.
                return torch.float32
            else:
                return torch.bfloat16
        elif self.model_id_not_none.startswith("phi"):
            return torch.bfloat16
        elif self.model_id_not_none.startswith("opt"):
            return torch.bfloat16
        else:
            raise ValueError(f"Unknown model_id: {self.model_id}")

    @property
    def checkpoint(self) -> str:
        assert not (self.base_dir is None and self.model_id is None)

        if self.model_id is None:
            return str(self.base_dir)

        if self.model_id.startswith("phi"):
            model_id = f"microsoft/{self.model_id}"
        elif self.model_id.startswith("opt"):
            model_id = f"facebook/{self.model_id}"
        else:
            model_id = self.model_id

        if self.base_dir is None:
            return model_id
        return str(Path(self.base_dir) / model_id)

    # @property
    # def loader_args(self) -> dict:
    #     if self.inference_only:
    #         dtype = "auto"
    #     else:
    #         dtype = self.training_dtype
    #     return {
    #         "pretrained_model_name_or_path": self.checkpoint,
    #         "torch_dtype": dtype,
    #     }


def load_model(
    model_config: ModelConfig,
) -> PreTrainedModel:
    """Build the causal language model described by ``model_config``.

    With ``pretrained`` set, the weights are loaded from the checkpoint;
    otherwise only the config is fetched and the model is constructed
    from it without loading weights.

    Args:
        model_config: Selects the checkpoint, the dtype and whether the
            stored weights are used.

    Returns:
        The model returned by ``AutoModelForCausalLM``.
    """
    checkpoint = model_config.checkpoint
    # Inference-only loads pass the string "auto" as the dtype; training
    # loads use the per-family dtype from ``training_dtype``.
    dtype_to_load = (
        "auto" if model_config.inference_only else model_config.training_dtype
    )

    if not model_config.pretrained:
        logger.info(f"Loading config-only model from {checkpoint}")
        # Fetch just the config, then build the model from it (no weights).
        architecture = AutoConfig.from_pretrained(checkpoint)
        return AutoModelForCausalLM.from_config(
            architecture,
            torch_dtype=dtype_to_load,
        )

    logger.info(f"Loading pretrained model from {checkpoint}")
    return AutoModelForCausalLM.from_pretrained(
        checkpoint,
        torch_dtype=dtype_to_load,
    )


def load_tokenizer(
    model_config: ModelConfig,
    max_length: int,
    *,
    special_tokens: list[str] | None = None,
) -> PreTrainedTokenizer:
    """Load the tokenizer belonging to ``model_config``'s checkpoint.

    The tokenizer pads and truncates on the left, and its pad token is
    set to its EOS token.

    Args:
        model_config: Supplies the checkpoint to load from.
        max_length: Length forwarded to ``AutoTokenizer.from_pretrained``.
        special_tokens: Optional tokens registered as additional special
            tokens; nothing is added when None.

    Returns:
        The configured tokenizer.
    """
    checkpoint = model_config.checkpoint
    logger.info(f"Loading tokenizer for {checkpoint}")

    # NOTE(review): the length is forwarded under the keyword
    # ``max_length``; Hugging Face tokenizers document ``model_max_length``
    # as the init argument for the size limit. Confirm this keyword has
    # the intended effect.
    tokenizer_kwargs = {
        "padding_side": "left",
        "truncation_side": "left",
        "max_length": max_length,
    }
    tokenizer = AutoTokenizer.from_pretrained(
        str(checkpoint), **tokenizer_kwargs
    )
    tokenizer.pad_token = tokenizer.eos_token

    if special_tokens is None:
        return tokenizer

    tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})
    return tokenizer


def load_model_tokenizer(
    model_config: ModelConfig,
    *,
    special_tokens: list[str] | None = None,
) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
    """Load a model together with its tokenizer.

    The model is loaded first so that the length reported for it by
    ``get_model_max_length`` can be handed to the tokenizer loader.

    Args:
        model_config: Describes the model and checkpoint to load.
        special_tokens: Optional additional special tokens for the
            tokenizer.

    Returns:
        A ``(model, tokenizer)`` pair.
    """
    model = load_model(model_config)
    # NOTE(review): the model's embeddings are not resized here, so if
    # ``special_tokens`` adds new tokens the caller presumably has to
    # resize them before use. Confirm against callers.
    tokenizer = load_tokenizer(
        model_config,
        max_length=get_model_max_length(model),
        special_tokens=special_tokens,
    )
    return model, tokenizer
