from functools import partial
from pathlib import Path
from typing import Literal, Self, TypeGuard, override

import torch
from sentence_transformers import SentenceTransformer
from torch.utils.data import RandomSampler
from torch.utils.data.sampler import Sampler
from transformers import SchedulerType, Trainer, TrainingArguments

from datasets import concatenate_datasets
from mow.common.data import (
    collapse_data_by_trajectory,
    indexing,
    prepare_batch_data,
)
from mow.common.loss import contrastive_loss
from mow.common.trainer import CustomTrainerConfig
from mow.dataset import AutoChatDatasetBuilder
from mow.dataset.history import ChatHistoryMixin
from mow.modules.routers import GraphRouter, GraphRouterConfig
from mow.utils.config import TrainConfigMixin
from mow.utils.types import instanceof


class TrainRouterConfig(TrainConfigMixin, key="config"):
    """
    Configuration for training a router model.

    Groups the graph router settings, the name of the sentence transformer
    model, the dataset locations, optional sample limits and the loss
    hyperparameters alongside the generic trainer configuration.
    """

    # Trainer settings used when ``__init__`` receives no ``train_config``.
    default_train_config = CustomTrainerConfig(
        batch_size=4,
        logging_steps=500,
        save_steps=1000,
        eval_steps=1000,
    )

    def __init__(
        self,
        *,
        router_config: GraphRouterConfig,
        sentence_transformer_model: str,
        train_config: CustomTrainerConfig | None = None,
        datasets: dict[str, str] | dict[str, Path],
        num_train_samples: int = 0,
        num_eval_samples: int = 0,
        temperature: float = 0.1,
        lambda_: float = 0,
        collapse_data: bool = False,
    ):
        """
        Store the router training settings.

        Args:
            router_config: Configuration of the graph router; exposed as
                ``self.router``.
            sentence_transformer_model: Name or path handed to
                ``SentenceTransformer`` by ``train_router``.
            train_config: Trainer settings; ``default_train_config`` is
                used when this is ``None``.
            datasets: Mapping from dataset name to its location. Every
                value is converted to a ``Path``.
            num_train_samples: Limit on training samples; ``train_router``
                only applies it when it is greater than 0.
            num_eval_samples: Limit on evaluation samples; ``train_router``
                only applies it when it is greater than 0.
            temperature: Forwarded to the contrastive loss.
            lambda_: Forwarded to the contrastive loss.
            collapse_data: Whether datasets are collapsed by trajectory
                during preparation.
        """
        super().__init__(
            train_config=(
                self.default_train_config
                if train_config is None
                else train_config
            )
        )

        # Model-related settings.
        self.router = router_config
        self.sentence_transformer_model = sentence_transformer_model

        # Dataset locations are normalised to ``Path`` objects.
        self.datasets = {
            name: Path(location) for name, location in datasets.items()
        }
        self.num_train_samples = num_train_samples
        self.num_eval_samples = num_eval_samples

        # Loss and preprocessing knobs.
        self.temperature = temperature
        self.lambda_ = lambda_
        self.collapse_data = collapse_data

    @override
    @classmethod
    def from_file(cls, path: str | Path) -> Self:
        """
        Load the configuration from ``path``.

        The loaded trainer output directory is redirected to its
        ``router`` subdirectory before the configuration is returned.
        """
        loaded = super().from_file(path)
        trainer_settings = loaded.train_config
        assert trainer_settings.output_dir is not None
        trainer_settings.output_dir = trainer_settings.output_dir / "router"
        return loaded


def train_router(config: TrainRouterConfig):
    """
    Train a ``GraphRouter`` on the configured chat datasets.

    The function loads the sentence transformer named in ``config``, builds
    the router with a contrastive loss, prepares and concatenates the
    per-dataset "train" and "test" splits, runs a Hugging Face ``Trainer``,
    registers one embedding set per training dataset on the model and
    finally saves the model to ``<output_dir>/best``.

    Side effects: ``config.router.embed_dim`` is overwritten with the
    sentence transformer's embedding dimension.

    Args:
        config: Training configuration. ``config.train_config.output_dir``
            must be set.

    Returns:
        The trainer instance after training has finished.

    Raises:
        ValueError: If ``config.train_config.output_dir`` is ``None``.
    """
    # Fail fast: the output directory is needed to save the model at the
    # very end, so a missing value must not surface only after training.
    if config.train_config.output_dir is None:
        raise ValueError(
            "config.train_config.output_dir must be set to train the router."
        )
    output_dir = Path(config.train_config.output_dir)

    sentence_transformer = SentenceTransformer(
        config.sentence_transformer_model
    )
    sentence_transformer.to(
        torch.device("cuda" if torch.cuda.is_available() else "cpu")
    )
    sentence_transformer.eval()

    # Fall back to 384 when the model does not report a dimension.
    sentence_embedding_dimension = (
        sentence_transformer.get_sentence_embedding_dimension() or 384
    )
    if config.router.embed_dim is not None:
        print(
            f"⚠️ Provided embed_dim {config.router.embed_dim} will be "
            f"overridden by the sentence transformer dimension "
            f"{sentence_embedding_dimension}."
        )
    config.router.embed_dim = sentence_embedding_dimension

    model = GraphRouter(
        config=config.router,
        compute_loss=partial(
            contrastive_loss,
            temperature=config.temperature,
            lambda_=config.lambda_,
        ),
    )
    print(
        f"Model size: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M"
    )

    def prepare_dataset(phase: Literal["train", "test"]):
        """Build one prepared dataset per configured name for ``phase``."""
        return {
            name: AutoChatDatasetBuilder.load(dataset / phase)
            .mapif(
                lambda _: config.collapse_data,
                collapse_data_by_trajectory,
                batched=True,
                batch_size=None,
                desc=f"Collapsing {name} {phase} dataset",
            )
            .map(
                # ``i`` is the position of the dataset in ``config.datasets``.
                partial(indexing, dataset_idx=i),
                batched=False,
                desc=f"Indexing {name} {phase} dataset",
            )
            .doif(
                lambda builder: instanceof(builder, ChatHistoryMixin),
                lambda builder: builder.expand(
                    desc=f"Expanding {name} {phase} histories"
                ),
            )
            .prepare_graph_representation(
                sentence_transformer=sentence_transformer,
                batched=False,
                desc=f"Preparing graph representation for {name} {phase} dataset",
            )
            .shuffle(seed=42)
            .unwrap(
                type="pt",
                columns=[
                    "context",
                    "nodes",
                    "adjacency_matrix",
                    "relation_matrix",
                    "labels",
                ],
            )
            for i, (name, dataset) in enumerate(config.datasets.items())
        }

    original_train_datasets = prepare_dataset("train")
    original_eval_datasets = prepare_dataset("test")

    train_dataset = concatenate_datasets(list(original_train_datasets.values()))
    eval_dataset = concatenate_datasets(list(original_eval_datasets.values()))

    print(
        f"Train dataset size: {len(train_dataset)} samples, "
        f"Eval dataset size: {len(eval_dataset)} samples"
    )

    # A non-positive limit means "use everything".
    if config.num_train_samples > 0:
        train_dataset = train_dataset.take(config.num_train_samples)
    if config.num_eval_samples > 0:
        eval_dataset = eval_dataset.take(config.num_eval_samples)

    sentence_transformer.to(model.device)

    trainer_args = TrainingArguments(
        output_dir=str(output_dir),
        per_device_train_batch_size=config.train_config.batch_size,
        per_device_eval_batch_size=config.train_config.batch_size,
        logging_steps=config.train_config.logging_steps,
        save_steps=config.train_config.save_steps,
        evaluation_strategy="steps",
        max_steps=config.train_config.max_steps,
        eval_steps=config.train_config.eval_steps,
        learning_rate=config.train_config.learning_rate,
        lr_scheduler_type=config.train_config.lr_scheduler_type,
        warmup_steps=config.train_config.warmup_steps,
        save_total_limit=1,
        remove_unused_columns=False,
        run_name=config.train_config.run_name,
        load_best_model_at_end=True,
    )

    class RouterTrainer(Trainer):
        """Trainer that samples training indices with replacement."""

        @override
        def _get_train_sampler(self) -> Sampler | None:
            # Each pass over the sampler yields ``batch_size`` indices drawn
            # with replacement from the training set.
            # NOTE(review): this presumably relies on ``max_steps`` to set
            # the training length — confirm that is intended.
            return RandomSampler(
                train_dataset,  # type: ignore
                replacement=True,
                num_samples=config.train_config.batch_size,
            )

    trainer = RouterTrainer(
        model=model,
        args=trainer_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=partial(prepare_batch_data, graph_augmentation=True),
    )
    trainer.train()

    # Register an embedding set per dataset from at most 20 of its
    # (already shuffled) training samples, collated without augmentation.
    # NOTE(review): whether ``update_embedding_set`` disables gradients and
    # dropout internally is not visible here — verify.
    for name, dataset in original_train_datasets.items():
        num_samples = min(20, len(dataset))
        sample = dataset.take(num_samples)
        sample = prepare_batch_data(sample, graph_augmentation=False)
        model.update_embedding_set(
            name=name,
            hidden_states=sample["hidden_states"].to(model.device),
            adjacency_matrix=sample["adjacency_matrix"].to(model.device),
            relation_matrix=sample["relation_matrix"].to(model.device),
            context=sample["context"].to(model.device),
        )

    # Passed as ``str`` for consistency with ``TrainingArguments`` above.
    trainer.save_model(str(output_dir / "best"))
    return trainer


def hyperparameter_search(
    config: TrainRouterConfig,
    n_trials: int = 10,
    direction: str = "maximize",
    study_name: str | None = None,
):
    """
    Run an Optuna study over router and trainer hyperparameters.

    Every trial trains a fresh router through ``train_router`` with sampled
    values for the number of layers, dropout, step budget, learning rate,
    scheduler type and warmup steps. All remaining settings are taken from
    ``config``. The trial score is the ``"eval_similarity"`` entry of
    ``trainer.evaluate()``.

    Args:
        config: Base configuration whose non-searched settings are reused
            for every trial.
        n_trials: Number of trials to run.
        direction: Optimisation direction passed to ``create_study``.
        study_name: Optional study name passed to ``create_study``.

    Returns:
        The Optuna study after optimisation.

    Raises:
        ImportError: If Optuna is not installed.
    """
    try:
        from optuna import create_study
    except ImportError as err:
        # Chain the original error so the real import failure stays visible.
        raise ImportError(
            "Optuna is not installed. Please install it with `pip install optuna`."
        ) from err

    def objective(trial):
        """Train one router with sampled hyperparameters and score it."""
        num_layers = trial.suggest_int("num_layers", 2, 16)
        dropout = trial.suggest_float("router_dropout", 0.1, 0.5)
        # NOTE(review): no output directory is set on the per-trial trainer
        # config, while ``train_router`` saves under
        # ``train_config.output_dir`` — confirm the default of
        # ``CustomTrainerConfig`` is usable here.
        train_config = CustomTrainerConfig(
            max_steps=trial.suggest_int("max_steps", 1000, 5000),
            eval_steps=100,
            logging_steps=1,
            batch_size=64,
            # The range spans five orders of magnitude; sample on a log
            # scale so small learning rates are actually explored.
            learning_rate=trial.suggest_float(
                "learning_rate", 1e-6, 1e-1, log=True
            ),
            lr_scheduler_type=trial.suggest_categorical(
                "lr_scheduler_type",
                [
                    SchedulerType.LINEAR,
                    SchedulerType.COSINE,
                    SchedulerType.POLYNOMIAL,
                    SchedulerType.CONSTANT,
                ],
            ),
            warmup_steps=trial.suggest_int("warmup_steps", 50, 500, log=True),
        )
        new_config = TrainRouterConfig(
            router_config=GraphRouterConfig(
                hidden_size=config.router.hidden_size,
                context_size=config.router.context_size,
                embed_dim=config.router.embed_dim,
                output_dim=config.router.output_dim,
                use_mlp=config.router.use_mlp,
                aggregate_layers=config.router.aggregate_layers,
                num_layers=num_layers,
                router_dropout=dropout,
            ),
            sentence_transformer_model=config.sentence_transformer_model,
            train_config=train_config,
            datasets=config.datasets,
            num_train_samples=config.num_train_samples,
            num_eval_samples=config.num_eval_samples,
            # Keep the caller's loss and preprocessing settings instead of
            # silently falling back to the constructor defaults.
            temperature=config.temperature,
            lambda_=config.lambda_,
            collapse_data=config.collapse_data,
        )
        trainer = train_router(new_config)
        return trainer.evaluate()["eval_similarity"]

    study = create_study(direction=direction, study_name=study_name)
    study.optimize(objective, n_trials=n_trials)
    return study
