from __future__ import annotations

from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from typing import Any, Callable, Literal, cast, override

from sentence_transformers import SentenceTransformer
from transformers import (
    EvalPrediction,
    PreTrainedTokenizerBase,
    SchedulerType,
    TrainerCallback,
    TrainerControl,
    TrainerState,
    TrainingArguments,
)

from datasets import Dataset, concatenate_datasets
from mow.common import defaults
from mow.common.data import prepare_batch_data, prepare_graph_representation
from mow.common.trainer import (
    CustomTrainer,
    CustomTrainerConfig,
    PredictionCallback,
)
from mow.dataset import AutoChatDatasetBuilder
from mow.dataset.history import ChatHistoryMixin
from mow.modules.mow import MoW, MoWConfig
from mow.modules.routers import GraphRouterConfig
from mow.utils.config import TrainConfigMixin
from mow.utils.types import instanceof


@dataclass
class MowTrainerConfig(CustomTrainerConfig):
    """Trainer arguments for MoW training.

    Extends ``CustomTrainerConfig`` with a single extra field that controls
    how often the router embedding set is refreshed during training.
    """

    # Step interval between router embedding refreshes. The code in this file
    # only schedules refreshes when the value is greater than 0.
    update_embeddings_steps: int = field(
        metadata=dict(
            help="Number of steps between updating the router embeddings.",
        ),
        default=100,
    )


class TrainMoWConfig(TrainConfigMixin[MowTrainerConfig], key="config"):
    """
    Configuration for a MoW training run.

    Attributes:
        mow: The ``MoWConfig`` used to build a fresh model.
        train_from: Optional path of a pretrained MoW model to resume from.
        sentence_transformer_model: Name or path handed to
            ``SentenceTransformer``.
        datasets: Mapping of dataset name to its directory, stored as ``Path``.
        num_train_samples: Cap on training rows; ``0`` means no cap.
        num_eval_samples: Cap on evaluation rows; ``0`` means no cap.
    """

    # Fallback trainer arguments used when the caller passes no train_config.
    default_train_config = MowTrainerConfig(
        eval_steps=1000,
        save_steps=1000,
        logging_steps=500,
        batch_size=4,
    )

    default_lora_config = defaults.default_lora_config

    def __init__(
        self,
        *,
        mow_config: MoWConfig,
        sentence_transformer_model: str,
        train_config: MowTrainerConfig | None = None,
        train_from: str | None = None,
        datasets: dict[str, str] | dict[str, Path],
        num_train_samples: int = 0,
        num_eval_samples: int = 0,
    ):
        """Store the run settings, falling back to the class-level defaults.

        Args:
            mow_config: Model configuration, exposed as ``self.mow``.
            sentence_transformer_model: Sentence-transformer name or path.
            train_config: Trainer arguments; ``default_train_config`` is used
                when this is ``None``.
            train_from: Optional path of a pretrained model to start from.
            datasets: Dataset name -> directory (``str`` or ``Path``).
            num_train_samples: Maximum number of training rows (0 = all).
            num_eval_samples: Maximum number of evaluation rows (0 = all).
        """
        effective_train_config = (
            self.default_train_config if train_config is None else train_config
        )
        super().__init__(train_config=effective_train_config)

        self.mow = mow_config
        self.train_from = train_from
        self.sentence_transformer_model = sentence_transformer_model
        # Normalise every dataset location to a Path, whatever was passed in.
        self.datasets = {
            dataset_name: Path(location)
            for dataset_name, location in datasets.items()
        }
        self.num_train_samples = num_train_samples
        self.num_eval_samples = num_eval_samples


def train_mow(config: TrainMoWConfig):
    """Train a MoW model as described by ``config`` and return the trainer.

    The function:

    1. Loads the sentence transformer and puts it in eval mode.
    2. Loads the MoW model from ``config.train_from`` when set, otherwise
       builds a new one from ``config.mow``.
    3. For every entry of ``config.datasets`` loads ``<dir>/train`` and
       ``<dir>/test``, converts them to chat format and attaches the graph
       representation columns.
    4. Concatenates the per-dataset splits, shuffles the training split and
       optionally truncates both splits.
    5. When ``update_embeddings_steps > 0``, samples up to 100 shuffled rows
       from each training dataset as the source for router embedding updates.
    6. Runs training and, if an output directory is configured, saves the
       model under ``<output_dir>/best``.

    Args:
        config: The training configuration.

    Returns:
        The ``MoWTrainer`` instance after ``train()`` has finished.

    Raises:
        FileNotFoundError: If ``config.train_from`` is set but does not exist.
    """
    # Upper bound on the rows sampled from each training dataset when building
    # the router embedding set.
    max_embedding_samples = 100

    sentence_transformer = SentenceTransformer(
        config.sentence_transformer_model
    )
    sentence_transformer.eval()

    if config.train_from:
        path = Path(config.train_from)
        if not path.exists():
            raise FileNotFoundError(f"Path {path} does not exist.")
        mow = cast(MoW, MoW.from_pretrained(path))
        print(f"🌟 Loaded MoW model from: {path}")
    else:
        mow = MoW(config.mow)

    print(f"🚀 Trainable modules: {mow.trainable_modules}", end="\n\n")

    def prepare_dataset(
        name: str, dataset_dir: Path, phase: Literal["train", "test"]
    ):
        # One dataset / one split. ``name`` and ``phase`` are explicit
        # arguments so the lambdas below never depend on a loop variable.
        return (
            AutoChatDatasetBuilder.load(dataset_dir / phase)
            .doif(
                lambda builder: instanceof(builder, ChatHistoryMixin),
                lambda builder: builder.expand(
                    desc=f"Expanding {name} {phase} histories"
                ),
            )
            .as_chat(
                tokenizer=mow.tokenizer,
                batched=False,
                desc=f"Converting {name} {phase} dataset to chat format",
                action_only=True,
            )
            .prepare_graph_representation(
                sentence_transformer=sentence_transformer,
                desc=f"Preparing graph representation for {name} {phase} dataset",
            )
            .unwrap(
                type="pt",
                columns=[
                    "text",
                    "context",
                    "nodes",
                    "adjacency_matrix",
                    "relation_matrix",
                ],
            )
        )

    def prepare_datasets(phase: Literal["train", "test"]):
        return {
            name: prepare_dataset(name, dataset_dir, phase)
            for name, dataset_dir in config.datasets.items()
        }

    original_train_datasets = prepare_datasets("train")
    original_eval_datasets = prepare_datasets("test")

    train_dataset = concatenate_datasets(list(original_train_datasets.values()))
    eval_dataset = concatenate_datasets(list(original_eval_datasets.values()))

    train_dataset = train_dataset.shuffle()

    # Clamp the requested sizes: asking ``take`` for more rows than exist
    # fails instead of returning the whole dataset.
    if config.num_train_samples > 0:
        train_dataset = train_dataset.take(
            min(config.num_train_samples, len(train_dataset))
        )
    if config.num_eval_samples > 0:
        eval_dataset = eval_dataset.take(
            min(config.num_eval_samples, len(eval_dataset))
        )

    data_collator_lm = defaults.default_data_collator_for_lm(
        tokenizer=mow.tokenizer
    )

    if config.train_config.update_embeddings_steps > 0:
        embedding_set_source = {}
        for name, dataset in original_train_datasets.items():
            # The sample size is bounded by *this* dataset's length, not by
            # the size of the concatenated training set.
            sampled = (
                dataset.shuffle()
                .take(min(max_embedding_samples, len(dataset)))
                .map(
                    partial(
                        prepare_graph_representation,
                        sentence_transformer=sentence_transformer,
                    ),
                    batched=False,
                )
                .with_format(type="pt")
            )
            embedding_set_source[name] = prepare_batch_data(
                sampled, data_collator_lm=data_collator_lm
            )
    else:
        embedding_set_source = None

    trainer = MoWTrainer(
        model=mow,
        tokenizer=mow.tokenizer,
        args=config.train_config,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=partial(
            prepare_batch_data, data_collator_lm=data_collator_lm
        ),
        remove_unused_columns=False,
        callbacks=[
            PredictionCallback(mow, eval_dataset, config.train_config),
        ],
        embedding_set_source=embedding_set_source,
    )

    sentence_transformer.to(mow.device)
    trainer.train()

    if (output_dir := config.train_config.output_dir) is not None:
        # ``output_dir`` may be a plain string; wrap it so ``/`` always works.
        trainer.save_model(Path(output_dir) / "best")
    return trainer


def hyperparameter_search(
    config: TrainMoWConfig,
    n_trials: int = 10,
    direction: str = "minimize",
    study_name: str | None = None,
):
    """Run an Optuna study over router and trainer hyperparameters.

    Each trial builds a new ``TrainMoWConfig`` from ``config`` (same expert
    models, strategy, router model, LoRA config, datasets and sample caps),
    overrides the router depth/dropout and the trainer schedule with sampled
    values, trains via ``train_mow`` and reports the evaluation loss.

    Args:
        config: Base configuration the trials are derived from.
        n_trials: Number of trials to run.
        direction: Optimisation direction passed to ``create_study``.
        study_name: Optional study name passed to ``create_study``.

    Returns:
        The Optuna study after optimisation has finished.

    Raises:
        ImportError: If Optuna is not installed.
    """
    try:
        from optuna import create_study
    except ImportError as err:
        # Chain the original error so the real import failure stays visible.
        raise ImportError(
            "Optuna is not installed. Please install it with `pip install optuna`."
        ) from err

    def objective(trial):
        num_layers = trial.suggest_int("num_layers", 2, 16)
        dropout = trial.suggest_float("router_dropout", 0.1, 0.5)
        train_config = MowTrainerConfig(
            max_steps=trial.suggest_int("max_steps", 1000, 5000),
            eval_steps=100,
            update_embeddings_steps=trial.suggest_int(
                "update_embeddings_steps", 50, 500, log=True
            ),
            logging_steps=1,
            batch_size=64,
            # The range spans five orders of magnitude, so sample in the log
            # domain; uniform sampling would almost never try small rates.
            learning_rate=trial.suggest_float(
                "learning_rate", 1e-6, 1e-1, log=True
            ),
            lr_scheduler_type=trial.suggest_categorical(
                "lr_scheduler_type",
                [
                    SchedulerType.LINEAR,
                    SchedulerType.COSINE,
                    SchedulerType.POLYNOMIAL,
                    SchedulerType.CONSTANT,
                ],
            ),
            warmup_steps=trial.suggest_int("warmup_steps", 50, 500, log=True),
        )
        new_config = TrainMoWConfig(
            mow_config=MoWConfig(
                expert_models=config.mow.expert_models,
                strategy=config.mow.strategy,
                router_model=config.mow.router_model,
                router_config=GraphRouterConfig(
                    num_layers=num_layers,
                    dropout=dropout,
                    # Keep the base config's layer aggregation when it has a
                    # router config; otherwise pass None.
                    aggregate_layers=(
                        config.mow.router_config.aggregate_layers
                        if config.mow.router_config is not None
                        else None
                    ),
                ),
                lora_config=config.mow.lora_config,
            ),
            sentence_transformer_model=config.sentence_transformer_model,
            train_config=train_config,
            datasets=config.datasets,
            num_train_samples=config.num_train_samples,
            num_eval_samples=config.num_eval_samples,
        )
        trainer = train_mow(new_config)
        return trainer.evaluate()["eval_loss"]

    study = create_study(direction=direction, study_name=study_name)
    study.optimize(objective, n_trials=n_trials)
    return study


class MoWTrainerCallback(TrainerCallback):
    """Callback that asks the trainer to refresh its router embedding set.

    The refresh runs at the start of every step whose ``global_step`` is a
    multiple of ``update_embeddings_steps``; a non-positive interval turns
    the refresh off.
    """

    def __init__(self, trainer: MoWTrainer, trainer_args: MowTrainerConfig):
        """Remember the trainer to notify and the refresh interval."""
        self.update_embeddings_steps = trainer_args.update_embeddings_steps
        self.trainer = trainer

    @override
    def on_step_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Refresh the embedding set when the current step hits the interval."""
        interval = self.update_embeddings_steps
        if interval <= 0:
            # Periodic refresh is disabled.
            return
        if state.global_step % interval != 0:
            return
        self.trainer.update_embedding_set()


class MoWTrainer(CustomTrainer):
    """``CustomTrainer`` that also keeps the router embedding sets up to date.

    A ``MoWTrainerCallback`` is always installed ahead of any user-supplied
    callbacks; it calls ``update_embedding_set`` on the configured interval.
    """

    def __init__(
        self,
        model: MoW,
        args: MowTrainerConfig,
        train_dataset: Dataset,
        eval_dataset: Dataset,
        tokenizer: PreTrainedTokenizerBase,
        embedding_set_source: dict[str, dict] | None = None,
        compute_metrics: Callable[[EvalPrediction], dict] | None = None,
        data_collator: Callable[..., Any] | None = None,
        callbacks: list[TrainerCallback] | None = None,
        **trainer_kwargs,
    ):
        """Build the trainer and register the embedding-update callback.

        Args:
            model: The MoW model to train.
            args: Trainer arguments, including ``update_embeddings_steps``.
            train_dataset: Training dataset.
            eval_dataset: Evaluation dataset.
            tokenizer: Tokenizer forwarded to the base trainer.
            embedding_set_source: Per-dataset batches used by
                ``update_embedding_set``; ``None`` when no source is available.
            compute_metrics: Optional metrics function for the base trainer.
            data_collator: Optional collator for the base trainer.
            callbacks: Extra callbacks, appended after the built-in one.
            **trainer_kwargs: Forwarded unchanged to the base trainer.
        """
        self.model = model
        self.update_embeddings_steps = args.update_embeddings_steps

        super().__init__(
            model,
            args,
            train_dataset,
            eval_dataset,
            tokenizer,
            compute_metrics,
            data_collator,
            [
                MoWTrainerCallback(self, args),
                *(callbacks or []),
            ],
            **trainer_kwargs,
        )

        self.embedding_set_source = embedding_set_source

    @property
    def device(self):
        """Device of the wrapped model."""
        return self.model.device

    def update_embedding_set(self):
        """Push every source batch into every router's embedding set.

        Raises:
            ValueError: If no ``embedding_set_source`` was provided.
        """
        if self.embedding_set_source is None:
            raise ValueError(
                "embedding_set_source is not set. "
                "Please provide a dictionary of embedding sets."
            )

        print("🚀 Updating embedding set...")
        device = self.device
        for name, dataset in self.embedding_set_source.items():
            # Move the tensors once per dataset rather than once per router.
            hidden_states = dataset["hidden_states"].to(device)
            adjacency_matrix = dataset["adjacency_matrix"].to(device)
            context = dataset["context"].to(device)
            # The relation matrix is optional; when present it must live on
            # the same device as the other inputs.
            relation_matrix = dataset.get("relation_matrix")
            if relation_matrix is not None:
                relation_matrix = relation_matrix.to(device)

            for router in self.model.router.values():
                router.update_embedding_set(
                    name=name,
                    hidden_states=hidden_states,
                    adjacency_matrix=adjacency_matrix,
                    relation_matrix=relation_matrix,
                    context=context,
                )
