import copy
from functools import partial

from peft import LoraConfig, get_peft_model
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from mow.common import defaults
from mow.common.trainer import CustomTrainer, CustomTrainerConfig
from mow.dataset.auto import AutoChatDatasetBuilder
from mow.dataset.history import ChatHistoryMixin
from mow.utils.config import TrainConfigMixin
from mow.utils.types import instanceof


class TrainExpertConfig(TrainConfigMixin, key="config"):
    """
    Configuration class for training an expert model.

    Bundles the base model path and its config, the LoRA adapter config, the
    trainer config, and the train/eval dataset identifiers plus optional
    sample caps.

    Args:
        main_model: Path or hub id of the base model; also used to load the
            model config when ``model_config`` is omitted.
        model_config: Model config. Loaded with
            ``AutoConfig.from_pretrained(main_model)`` when omitted.
        lora_config: LoRA adapter config. When omitted, a private deep copy of
            ``default_lora_config`` is used.
        train_config: Trainer config. ``default_train_config`` when omitted.
        train_dataset: Identifier of the training dataset.
        eval_dataset: Identifier of the evaluation dataset.
        num_train_samples: Cap on training examples; ``0`` means no cap.
        num_eval_samples: Cap on evaluation examples; ``0`` means no cap.

    Raises:
        ValueError: If ``AutoConfig.from_pretrained`` yields ``None``.
    """

    # Class-level default. It is deep-copied per instance (see __init__)
    # because peft's ``get_peft_model`` can write to the config object it is
    # given, which would otherwise leak state between training runs.
    default_lora_config = defaults.default_lora_config

    # NOTE(review): this default is shared by reference across instances;
    # copy it too if CustomTrainer turns out to mutate its ``args``.
    default_train_config = CustomTrainerConfig(
        batch_size=4,
        logging_steps=500,
        save_steps=1000,
        eval_steps=1000,
    )

    def __init__(
        self,
        *,
        main_model: str,
        model_config: AutoConfig | None = None,
        lora_config: LoraConfig | None = None,
        train_config: CustomTrainerConfig | None = None,
        train_dataset: str,
        eval_dataset: str,
        num_train_samples: int = 0,
        num_eval_samples: int = 0,
    ):
        if model_config is None:
            model_config = AutoConfig.from_pretrained(main_model)
            # Explicit check rather than ``assert`` so it survives ``python -O``.
            if model_config is None:
                raise ValueError("Model config cannot be None")
        if lora_config is None:
            # Private copy: never hand the shared class-level default to code
            # that may mutate it.
            lora_config = copy.deepcopy(self.default_lora_config)
        if train_config is None:
            train_config = self.default_train_config

        super().__init__(train_config=train_config)

        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.num_train_samples = num_train_samples
        self.num_eval_samples = num_eval_samples
        self.main_model_path = main_model
        self.model_config = model_config
        self.lora_config = lora_config


def _load_chat_dataset(path: str, tokenizer, *, shuffle: bool = False):
    """Load the dataset at ``path``, render it as chat, and unwrap it.

    If the loaded builder is a ``ChatHistoryMixin`` instance, ``expand`` is
    applied to it first. The builder is then formatted with ``as_chat`` using
    ``tokenizer``, optionally shuffled, and unwrapped.
    """
    builder = (
        AutoChatDatasetBuilder.load(path)
        .doif(
            lambda b: instanceof(b, ChatHistoryMixin),
            lambda b: b.expand(desc="Expanding chat history"),
        )
        .as_chat(tokenizer=tokenizer)
    )
    if shuffle:
        builder = builder.shuffle()
    return builder.unwrap()


def train_expert(config: TrainExpertConfig):
    """
    Train a LoRA expert on top of ``config.main_model_path``.

    Steps:
    1. Load the base causal LM and wrap it with ``config.lora_config``.
    2. Build the train and eval chat datasets, optionally capped by
       ``num_train_samples`` / ``num_eval_samples``.
    3. Run ``CustomTrainer``.
    4. Save the model when ``train_config.output_dir`` is set.

    Args:
        config: Fully populated expert training configuration.
    """
    main_model = AutoModelForCausalLM.from_pretrained(
        config.main_model_path, config=config.model_config
    )
    peft_model = get_peft_model(main_model, config.lora_config)

    tokenizer = AutoTokenizer.from_pretrained(config.main_model_path)

    # Training data is shuffled at the builder level; eval data keeps its
    # original order unless it is subsampled below.
    train_dataset = _load_chat_dataset(
        config.train_dataset, tokenizer, shuffle=True
    )
    eval_dataset = _load_chat_dataset(config.eval_dataset, tokenizer)

    # A non-positive cap means "use the whole split".
    if config.num_train_samples > 0:
        train_dataset = train_dataset.take(config.num_train_samples)
    if config.num_eval_samples > 0:
        # Shuffle before taking so the subset is not just the head of the split.
        # NOTE(review): no seed is passed to shuffle(); if its default is
        # random, the eval subset (and thus eval metrics) varies between runs.
        eval_dataset = eval_dataset.shuffle().take(config.num_eval_samples)

    trainer = CustomTrainer(
        model=peft_model,
        tokenizer=tokenizer,
        args=config.train_config,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )
    trainer.train()
    if config.train_config.output_dir is not None:
        trainer.save_model(config.train_config.output_dir)
