import argparse
import shutil
import typing as ty
from pathlib import Path

import numpy as np
import scipy
import sklearn.metrics as skm
import torch
from ablator import (
    ConfigBase,
    Literal,
    ModelWrapper,
    Optional,
    ParallelConfig,
    ParallelTrainer,
    ProtoTrainer,
    RunConfig,
)
from ablator import TrainConfig as TrainConfigBase
from ablator import configclass
from torch.utils.data import DataLoader

import tablator as tablator_lib
from tablator import data_directory, package_directory
from tablator.dataset import TabDataset
from tablator.model import Tablator, TablatorConfig
from tablator.optim import make_optimizer


@configclass
class OptimizerConfig(ConfigBase):
    """Optimizer settings for a training run.

    ``name`` selects the optimizer; ``lr`` and ``weight_decay`` are forwarded,
    together with the model, to ``tablator.optim.make_optimizer``.
    """

    name: Literal["adabelief", "adam", "adamw", "radam", "sgd"] = "adam"
    lr: float = 1e-2
    weight_decay: float = 0

    def make_optimizer(self, model: torch.nn.Module):
        """Build an optimizer for ``model`` from this config.

        Args:
            model: module whose parameters the optimizer will manage.

        Returns:
            Whatever ``tablator.optim.make_optimizer`` returns for
            ``(name, model, lr, weight_decay)``.
        """
        # The bare name resolves to the module-level import from tablator.optim,
        # not to this method (class scope is not searched from inside a method).
        return make_optimizer(self.name, model, self.lr, self.weight_decay)


@configclass
class TrainConfig(TrainConfigBase):
    """Training configuration: ablator's base ``TrainConfig`` plus dataset options.

    The dataset-related fields below are forwarded to ``TabDataset`` by
    ``load_dataset`` in this module.
    """

    # Configurable attributes
    # Optimizer choice and hyper-parameters.
    # NOTE(review): a single OptimizerConfig instance is the class-level default;
    # confirm configclass copies it per TrainConfig instance rather than sharing it.
    optimizer_config: OptimizerConfig = OptimizerConfig()
    # Dataset name; passed to TabDataset as its second positional argument.
    dataset: Literal[
        "year",
        "yahoo",
        "helena",
        "covtype",
        "epsilon",
        "jannis",
        "adult",
        "aloi",
        "higgs_small",
        "microsoft",
        "california_housing",
    ] = "helena"
    # Root directory handed to TabDataset (defaults to tablator's data_directory).
    dataset_root: str = data_directory
    # Normalization scheme passed to TabDataset; presumably None disables
    # normalization -- confirm in TabDataset.
    normalization: Optional[Literal["standard", "quantile"]] = "quantile"
    # Categorical-feature options, forwarded unchanged to TabDataset; their exact
    # semantics are defined there.
    cat_nan_policy: Literal["new", "most_frequent"] = "new"
    cat_policy: Literal["ohe", "indices", "counter"] = "indices"
    cat_min_frequency: float = 0.0
    # Passed to TabDataset as ``seed``.
    dataset_seed: int = 0


@configclass
class ParallelRunConfig(ParallelConfig):
    """Run config for ``ParallelTrainer`` (loaded from YAML in ``mp_train``).

    Specializes ablator's ``ParallelConfig`` with this project's model and
    training config types.
    """

    # Dataset-derived fields (categories, data_type, d_out, d_numerical) are
    # overwritten at runtime by MyModelWrapper.config_parser.
    model_config: TablatorConfig
    train_config: TrainConfig


@configclass
class BaseRunConfig(RunConfig):
    """Run config for the single-run ``ProtoTrainer`` (loaded in ``base_train``).

    Specializes ablator's ``RunConfig`` with this project's model and training
    config types.
    """

    # Dataset-derived fields (categories, data_type, d_out, d_numerical) are
    # overwritten at runtime by MyModelWrapper.config_parser.
    model_config: TablatorConfig
    train_config: TrainConfig


class MyModelWrapper(ModelWrapper):
    """ablator ``ModelWrapper`` for training on a ``TabDataset``.

    Supplies the train/val dataloaders, copies dataset-derived settings into the
    model config, and selects evaluation metrics from the dataset's task type.
    Constructor arguments (e.g. ``model_class``) go straight to ``ModelWrapper``.
    """

    def make_dataloader_train(self, run_config: ParallelRunConfig):  # type: ignore
        """Return the training ``DataLoader`` (shuffled by ``load_dataset``)."""
        return load_dataset(run_config, flag="train")

    def make_dataloader_val(self, run_config: ParallelRunConfig):  # type: ignore
        """Return the validation ``DataLoader`` (not shuffled)."""
        return load_dataset(run_config, flag="val")

    def config_parser(self, config: ParallelRunConfig):
        """Fill dataset-derived model settings, then defer to the base parser.

        ``categories``, ``data_type``, ``d_out`` and ``d_numerical`` of
        ``config.model_config`` are taken from the loaded training dataset,
        overriding any configured value, before ``ModelWrapper.config_parser``
        runs.

        Returns:
            Whatever the base ``config_parser`` returns for the updated config.
        """
        dataset: TabDataset = self.train_dataloader.dataset
        model_config = config.model_config
        model_config.categories = dataset.categories
        model_config.data_type = dataset.task_type
        model_config.d_out = dataset.d_out
        model_config.d_numerical = dataset.n_num_features
        return super().config_parser(config)

    def evaluation_functions(self) -> dict[str, ty.Callable] | None:
        """Return the metrics to compute, keyed by metric name.

        * regression: ``rmse``, multiplied by ``dataset.y_std`` (presumably the
          targets are standardized, so this reports it in original units);
          entries whose prediction is NaN are excluded.
        * binary classification: ``acc`` (threshold 0.5) and ``auc``, both on
          sigmoid-transformed predictions.
        * multiclass classification: ``acc`` on the arg-max class.

        Raises:
            NotImplementedError: if the dataset is none of the three task types.
        """
        # Explicit submodule import: a bare ``import scipy`` does not load
        # ``scipy.special`` on SciPy versions without lazy submodule loading.
        from scipy.special import expit

        dataset: TabDataset = self.train_dataloader.dataset

        if dataset.is_regression:

            def rmse(target, pred):
                valid = ~np.isnan(pred)  # ignore rows with NaN predictions
                mse = skm.mean_squared_error(target[valid], pred[valid])
                return mse**0.5 * dataset.y_std

            return {"rmse": rmse}

        if dataset.is_binclass:

            def bin_acc(target, pred):
                return skm.accuracy_score(target, expit(pred) > 0.5)

            def auc(target, pred):
                return skm.roc_auc_score(target, expit(pred))

            return {"acc": bin_acc, "auc": auc}

        if dataset.is_multiclass:

            def multi_acc(target, pred):
                # atleast_* keeps a single-sample evaluation set 2-D/1-D after
                # squeeze(); otherwise argmax(1) would fail on a 1-D array.
                labels = np.atleast_2d(pred.squeeze()).argmax(1)
                return skm.accuracy_score(np.atleast_1d(target.squeeze()), labels)

            return {"acc": multi_acc}

        raise NotImplementedError()


def mp_train(mp_config):
    """Run the parallel (Ray-launched) study described by ``mp_config``, then evaluate it.

    Warning: ``~/tablator-save-dir/parallel`` is deleted before launching.
    """
    model_wrapper = MyModelWrapper(model_class=Tablator)
    run_config = ParallelRunConfig.load(mp_config)  # type: ignore
    run_config.experiment_dir = Path.home() / "tablator-save-dir" / "parallel"

    # Always start from an empty experiment directory.
    shutil.rmtree(run_config.experiment_dir, ignore_errors=True)
    working_dir = package_directory.parent

    trainer = ParallelTrainer(wrapper=model_wrapper, run_config=run_config)
    # NOTE: to run on a cluster, start ray with `ray start --head` and pass
    # ray_head_address="auto".
    trainer.launch(
        working_directory=working_dir,
        auxilary_modules=[tablator_lib],
        ray_head_address=None,
    )
    trainer.evaluate()


def base_train(config):
    """Run a single prototyping training run described by ``config``.

    Warning: ``~/tablator-save-dir/prototyping`` is deleted before launching.
    """
    model_wrapper = MyModelWrapper(model_class=Tablator)
    run_config = BaseRunConfig.load(config)  # type: ignore
    run_config.experiment_dir = Path.home() / "tablator-save-dir" / "prototyping"

    # Always start from an empty experiment directory.
    shutil.rmtree(run_config.experiment_dir, ignore_errors=True)
    trainer = ProtoTrainer(wrapper=model_wrapper, run_config=run_config)
    trainer.launch()


def run(mp: bool):
    """Train using a YAML config from the ``configs`` folder next to this package.

    Args:
        mp: if true, run ``mp_train`` on ``parallel_config.yaml``; otherwise
            run ``base_train`` on ``proto_config.yaml``.
    """
    configs_folder = Path(__file__).parent.parent / "configs"
    if not mp:
        base_train(configs_folder / "proto_config.yaml")
        return
    mp_train(configs_folder / "parallel_config.yaml")


def load_dataset(config: ParallelRunConfig, flag="train") -> DataLoader:
    """Build a ``DataLoader`` over the ``flag`` split of the configured dataset.

    All dataset options and the batch size come from ``config.train_config``;
    batches are shuffled only for the ``"train"`` split.
    """
    tc = config.train_config
    is_training_split = flag == "train"
    tab_dataset = TabDataset(
        tc.dataset_root,
        tc.dataset,
        split=flag,
        normalization=tc.normalization,
        cat_nan_policy=tc.cat_nan_policy,
        cat_policy=tc.cat_policy,
        cat_min_frequency=tc.cat_min_frequency,
        seed=tc.dataset_seed,
    )
    return DataLoader(
        tab_dataset,
        batch_size=tc.batch_size,
        shuffle=is_training_split,
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Train Tablator with ablator (single prototyping run by default)."
    )
    parser.add_argument(
        "--mp",
        action="store_true",
        help="run the parallel study (parallel_config.yaml) instead of proto_config.yaml",
    )
    # Dispatch on the parsed flag so `--mp` actually selects the parallel trainer.
    run(**vars(parser.parse_args()))
