from dataclasses import dataclass
from typing import Any, cast

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    FineGrainedFP8Config,
    Mistral3ForConditionalGeneration,
    PreTrainedModel,
    TokenizersBackend,
)

from hallucinations.config import LlmConfig

# Supported model names, grouped by the loader that get_llm() dispatches to.
# get_llm() compares llm_config.name against these lists, and the loaders pass
# that same name to from_pretrained().

# Loaded via get_llama_3().
LLAMA_3_MODELS = [
    "meta-llama/Meta-Llama-3-8B",
    "meta-llama/Meta-Llama-3-8B-Instruct",
    "meta-llama/Meta-Llama-3.1-8B",
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "meta-llama/Llama-3.2-3B-Instruct",
    "meta-llama/Llama-3.2-3B",
]

# Loaded via get_phi_35().
PHI_35_MODELS = [
    "microsoft/Phi-3.5-mini-instruct",
]

# Loaded via get_phi_2().
PHI_2_MODELS = [
    "microsoft/phi-2",
]

# Loaded via get_mistral_nemo(). Despite the list name, this also contains a
# Mistral-Small checkpoint.
MISTRAL_NEMO_MODELS = [
    "mistralai/Mistral-Nemo-Instruct-2407",
    "mistralai/Mistral-Small-24B-Instruct-2501",
]

# Loaded via get_ministral_3(), which uses Mistral3ForConditionalGeneration
# rather than AutoModelForCausalLM.
MINISTRAL_3_MODELS = [
    "mistralai/Ministral-3-3B-Instruct-2512",
    "mistralai/Ministral-3-3B-Base-2512",
]

# Loaded via get_llama_3(): get_llm() reuses the Llama 3 loader for these.
QWEN_MODELS = [
    "Qwen/Qwen3-4B-Instruct-2507",
    "Qwen/Qwen3-4B",
]


# Loaded via get_olmo3().
OLMO_3_MODELS = [
    "allenai/Olmo-3-7B-Instruct",
]


@dataclass
class ModelForGeneration:
    """Bundle returned by the loaders: a model, its tokenizer and related settings."""

    # The loaded model instance.
    llm: PreTrainedModel
    # Tokenizer loaded for the same configuration.
    tokenizer: TokenizersBackend
    # Loader-specific keyword arguments; the loaders in this module fill it with
    # "eos_token_id"/"pad_token_id" or leave it empty.
    # NOTE(review): presumably forwarded to the model's generate() call — confirm
    # against the caller.
    generate_kwargs: dict[str, Any]
    # Optional system prompt supplied by the loader; None when the loader sets none.
    default_system_prompt: str | None = None


def get_llm(llm_config: LlmConfig, **kwargs: Any) -> ModelForGeneration:
    """Load the model named by ``llm_config.name`` with its family-specific loader.

    Args:
        llm_config: Configuration whose ``name`` selects the model family.
        **kwargs: Extra keyword arguments forwarded unchanged to the loader.

    Returns:
        The ``ModelForGeneration`` produced by the matching loader.

    Raises:
        ValueError: If ``llm_config.name`` is in none of the supported lists.
    """
    requested = llm_config.name

    # Probed in this order; the first list containing the name decides the loader.
    loaders_by_family = (
        (LLAMA_3_MODELS, get_llama_3),
        (PHI_35_MODELS, get_phi_35),
        (PHI_2_MODELS, get_phi_2),
        (MISTRAL_NEMO_MODELS, get_mistral_nemo),
        (MINISTRAL_3_MODELS, get_ministral_3),
        (QWEN_MODELS, get_llama_3),  # same way of loading
        (OLMO_3_MODELS, get_olmo3),
    )

    for supported_names, loader in loaders_by_family:
        if requested in supported_names:
            return loader(llm_config, **kwargs)

    raise ValueError(f"Model {llm_config.name} not supported.")


def get_llama_3(llm_config: LlmConfig, **kwargs: Any) -> ModelForGeneration:
    """Load a Llama 3 style model and a tokenizer whose pad token is its EOS token.

    Args:
        llm_config: Model/tokenizer configuration.
        **kwargs: Extra keyword arguments forwarded to ``get_model``.

    Returns:
        A ``ModelForGeneration`` whose ``generate_kwargs`` carry the tokenizer's
        EOS and pad token ids.
    """
    llm = get_model(llm_config, **kwargs)
    # set_pad_token=True makes get_tokenizer reuse the EOS id as the pad id.
    tok = get_tokenizer(llm_config, set_pad_token=True)

    special_ids = {
        "eos_token_id": tok.eos_token_id,
        "pad_token_id": tok.pad_token_id,
    }
    return ModelForGeneration(llm=llm, tokenizer=tok, generate_kwargs=special_ids)


def get_mistral_nemo(llm_config: LlmConfig, **kwargs: Any) -> ModelForGeneration:
    """Load a model from ``MISTRAL_NEMO_MODELS`` with an EOS-padded tokenizer.

    Args:
        llm_config: Model/tokenizer configuration.
        **kwargs: Extra keyword arguments forwarded to ``get_model``.

    Returns:
        A ``ModelForGeneration`` whose ``generate_kwargs`` carry the tokenizer's
        EOS and pad token ids.
    """
    loaded_model = get_model(llm_config, **kwargs)
    # The pad token id is set to the EOS token id inside get_tokenizer.
    loaded_tokenizer = get_tokenizer(llm_config, set_pad_token=True)

    generation_ids: dict[str, Any] = {}
    generation_ids["eos_token_id"] = loaded_tokenizer.eos_token_id
    generation_ids["pad_token_id"] = loaded_tokenizer.pad_token_id

    return ModelForGeneration(
        llm=loaded_model,
        tokenizer=loaded_tokenizer,
        generate_kwargs=generation_ids,
    )


def get_ministral_3(llm_config: LlmConfig, **kwargs: Any) -> ModelForGeneration:
    """Load a Ministral 3 checkpoint with ``Mistral3ForConditionalGeneration``.

    Unlike the other loaders, this one does not go through ``get_model``: it
    does not read ``llm_config.quantization`` or ``llm_config.trust_remote_code``
    for the model, and defaults to ``FineGrainedFP8Config(dequantize=True)``.

    Args:
        llm_config: Model/tokenizer configuration.
        **kwargs: Extra keyword arguments forwarded to ``from_pretrained``. A
            caller-supplied ``quantization_config`` replaces the default one.

    Returns:
        A ``ModelForGeneration`` with empty ``generate_kwargs`` and a default
        system prompt.
    """
    # TODO: due to code-compatibility issues, we use hf tokenizer instead of
    # mistral tokenizer (i.e. MistralCommonBackend.from_pretrained(llm_config.name)).
    tokenizer = get_tokenizer(llm_config)

    # Apply the FP8 config only as a default. Passing it as an explicit keyword
    # next to **kwargs raised "TypeError: got multiple values for keyword
    # argument" whenever the caller supplied its own quantization_config.
    if "quantization_config" not in kwargs:
        kwargs["quantization_config"] = FineGrainedFP8Config(dequantize=True)

    model = Mistral3ForConditionalGeneration.from_pretrained(
        llm_config.name,
        dtype=llm_config.torch_dtype,
        attn_implementation=llm_config.attn_implementation,
        **kwargs,
    )
    return ModelForGeneration(
        llm=model,
        tokenizer=tokenizer,
        generate_kwargs={},
        default_system_prompt="You are a helpful AI assistant.",
    )


def get_phi_35(llm_config: LlmConfig, **kwargs: Any) -> ModelForGeneration:
    """Load a Phi-3.5 model and its tokenizer, leaving the pad token untouched.

    Args:
        llm_config: Model/tokenizer configuration.
        **kwargs: Extra keyword arguments forwarded to ``get_model``.

    Returns:
        A ``ModelForGeneration`` whose ``generate_kwargs`` carry the tokenizer's
        EOS and pad token ids as the tokenizer reports them.
    """
    phi_model = get_model(llm_config, **kwargs)
    # No set_pad_token here: the tokenizer's own pad token id is used as is.
    phi_tokenizer = get_tokenizer(llm_config)

    eos_id = phi_tokenizer.eos_token_id
    pad_id = phi_tokenizer.pad_token_id
    return ModelForGeneration(
        llm=phi_model,
        tokenizer=phi_tokenizer,
        generate_kwargs={"eos_token_id": eos_id, "pad_token_id": pad_id},
    )


def get_phi_2(llm_config: LlmConfig, **kwargs: Any) -> ModelForGeneration:
    """Load a Phi-2 model and a tokenizer whose pad token is its EOS token.

    Args:
        llm_config: Model/tokenizer configuration.
        **kwargs: Extra keyword arguments forwarded to ``get_model``.

    Returns:
        A ``ModelForGeneration`` whose ``generate_kwargs`` carry the tokenizer's
        EOS and pad token ids.
    """
    net = get_model(llm_config, **kwargs)
    # get_tokenizer copies the EOS id into the pad id when set_pad_token is True.
    tokenizer_ = get_tokenizer(llm_config, set_pad_token=True)

    gen_kwargs = dict(
        eos_token_id=tokenizer_.eos_token_id,
        pad_token_id=tokenizer_.pad_token_id,
    )
    return ModelForGeneration(llm=net, tokenizer=tokenizer_, generate_kwargs=gen_kwargs)


def get_olmo3(llm_config: LlmConfig, **kwargs: Any) -> ModelForGeneration:
    """Load an OLMo 3 model and its tokenizer without overriding the pad token.

    Args:
        llm_config: Model/tokenizer configuration.
        **kwargs: Extra keyword arguments forwarded to ``get_model``.

    Returns:
        A ``ModelForGeneration`` with empty ``generate_kwargs``.
    """
    olmo_model = get_model(llm_config, **kwargs)
    olmo_tokenizer = get_tokenizer(llm_config, set_pad_token=False)

    # No generation overrides are supplied for this family.
    no_overrides: dict[str, Any] = {}
    return ModelForGeneration(
        llm=olmo_model, tokenizer=olmo_tokenizer, generate_kwargs=no_overrides
    )


def get_model(
    llm_config: LlmConfig,
    **kwargs: Any,
) -> PreTrainedModel:
    """Load a causal language model via ``AutoModelForCausalLM.from_pretrained``.

    Args:
        llm_config: Supplies the model name, dtype, attention implementation,
            ``trust_remote_code`` flag and optional quantization settings.
        **kwargs: Extra keyword arguments forwarded to ``from_pretrained``. When
            ``llm_config.quantization`` is set, any ``quantization_config`` in
            here is overwritten.

    Returns:
        The loaded model.
    """
    quantization_args = llm_config.quantization
    if quantization_args is not None:
        # The config's quantization mapping is expanded into a BitsAndBytesConfig.
        kwargs["quantization_config"] = BitsAndBytesConfig(**quantization_args)

    return AutoModelForCausalLM.from_pretrained(
        llm_config.name,
        dtype=llm_config.torch_dtype,
        attn_implementation=llm_config.attn_implementation,
        trust_remote_code=llm_config.trust_remote_code,
        **kwargs,
    )


def get_tokenizer(
    llm_config: LlmConfig,
    set_pad_token: bool = False,
    **kwargs: Any,
) -> TokenizersBackend:
    """Load the tokenizer named by ``llm_config.tokenizer_name`` and configure padding.

    Args:
        llm_config: Supplies the tokenizer name, ``trust_remote_code`` flag and
            padding side.
        set_pad_token: When true, the pad token id is set to the EOS token id.
        **kwargs: Extra keyword arguments forwarded to
            ``AutoTokenizer.from_pretrained``.

    Returns:
        The configured tokenizer, typed as ``TokenizersBackend``.
    """
    loaded = AutoTokenizer.from_pretrained(
        llm_config.tokenizer_name,
        trust_remote_code=llm_config.trust_remote_code,
        **kwargs,
    )
    loaded.padding_side = llm_config.tokenizer_padding_side

    if set_pad_token:
        # Reuse the end-of-sequence id as the padding id.
        loaded.pad_token_id = loaded.eos_token_id

    # cast() only informs the type checker; the object is returned unchanged.
    return cast(TokenizersBackend, loaded)
