import os
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import numpy as np
import torch
from loguru import logger

from tabicl.config.config_run import ConfigRun
from tabicl.core.dataset_split import make_dataset_split
from tabicl.core.enums import DatasetSize, DataSplit, ModelName, Task
from tabicl.core.get_model import get_model
from tabicl.core.get_trainer import get_trainer
from tabicl.data.dataset_openml import OpenMLDataset
from tabicl.results.run_metrics import RunMetrics
from tabicl.utils.debugger import debugger_is_active
from tabicl.utils.paths_and_filenames import CONFIG_RUN_FILE_NAME
from tabicl.utils.set_seed import set_seed


def run_experiment(cfg: ConfigRun) -> Optional[RunMetrics]:
    """
    Run a single experiment described by ``cfg`` and persist its results.

    Steps: pin the process to ``cfg.cpus`` (if given), select ``cfg.device`` as the
    CUDA device, save the run config to ``cfg.output_dir``, seed via ``set_seed``,
    run every split through ``run_experiment_``, save the metrics to
    ``cfg.output_dir / "metrics.nc"`` and finally log per-split and averaged scores.

    When a debugger is attached, exceptions from ``run_experiment_`` propagate so
    they can be inspected; otherwise they are logged with traceback and ``None``
    is returned.

    Returns:
        The collected ``RunMetrics``, or ``None`` if the experiment raised.
    """

    set_cpus(cfg)
    torch.cuda.set_device(cfg.device)
    cfg.save(cfg.output_dir / CONFIG_RUN_FILE_NAME)

    logger.info(f"Start experiment on {cfg.openml_dataset_name} (id={cfg.openml_dataset_id}) with {cfg.model_name.value} doing {cfg.task.value}")

    set_seed(cfg.seed)
    logger.info(f"Set seed to {cfg.seed}")

    logger.info("We are using the following hyperparameters:")
    for key, value in cfg.hyperparams.items():
        logger.info(f"    {key}: {value}")

    if debugger_is_active():
        # Let exceptions propagate so the debugger stops at the failure point.
        metrics = run_experiment_(cfg)
    else:
        try:
            metrics = run_experiment_(cfg)
        except Exception:
            # Top-level boundary: record the full traceback and signal failure to the caller.
            logger.exception("Exception occurred while running experiment")
            return None

    # Persist the results before the summary logging, so a failure while
    # formatting the log lines cannot lose a finished run.
    metrics.save(cfg.output_dir / "metrics.nc")

    logger.info(f"Finished experiment on {cfg.openml_dataset_name} (id={cfg.openml_dataset_id}) with {cfg.model_name.value} doing {cfg.task.value}")
    _log_final_scores(metrics)

    return metrics


def _log_final_scores(metrics: RunMetrics) -> None:
    """Log the train/val/test score of every CV split, followed by their mean over splits."""

    logger.info("Final scores: ")

    scores = metrics.ds['score']
    # Select each data split once instead of re-selecting inside the loop.
    train_scores = scores.sel(data_split=DataSplit.TRAIN.value)
    val_scores = scores.sel(data_split=DataSplit.VALID.value)
    test_scores = scores.sel(data_split=DataSplit.TEST.value)

    for i in range(metrics.ds.sizes['cv_split']):
        logger.info((
            f"cv_split_{i} :: "
            f"train: {train_scores.sel(cv_split=i):.4f}, "
            f"val: {val_scores.sel(cv_split=i):.4f}, "
            f"test: {test_scores.sel(cv_split=i):.4f}"
        ))

    logger.info((
        f"cv_average :: "
        f"train: {train_scores.mean():.4f}, "
        f"val: {val_scores.mean():.4f}, "
        f"test: {test_scores.mean():.4f}"
    ))


def run_experiment_(cfg: ConfigRun) -> RunMetrics:
    """
    Train and evaluate the configured model on every split of the dataset file.

    For each split yielded by ``OpenMLDataset.split_iterator``, the raw arrays are
    bundled into a ``Data`` object. A fresh model and trainer are built, and the
    trainer is fit on ``x_train_cut`` with ``x_val_earlystop`` as early-stopping data.
    The trainer is then evaluated on the training, validation (``x_val_hyperparams``)
    and test data, always passing the full training split as the first argument pair
    of ``trainer.evaluate``.

    Exceptions are not handled here; ``run_experiment`` decides what to do with them.

    Returns:
        The ``RunMetrics`` accumulated over all splits, after ``post_process()``.
    """

    dataset = OpenMLDataset(cfg.datafile_path, cfg.task)
    run_metrics = RunMetrics()

    for idx, (x_train, x_val, x_test, y_train, y_val, y_test, categorical_indicator) in enumerate(dataset.split_iterator()):

        logger.info(f"Start split {idx+1}/{dataset.n_splits} of {cfg.openml_dataset_name} (id={cfg.openml_dataset_id}) with {cfg.model_name.name} doing {cfg.task.name}")

        es_data_split = cfg.hyperparams['early_stopping_data_split']
        es_max_samples = cfg.hyperparams['early_stopping_max_samples']

        data = Data.from_standard_datasplits(
            x_train=x_train,
            x_val=x_val,
            x_test=x_test,
            y_train=y_train,
            y_val=y_val,
            y_test=y_test,
            task=cfg.task,
            early_stopping_data_split=es_data_split,
            early_stopping_max_samples=es_max_samples,
        )

        model = get_model(cfg, data.x_train_cut, data.y_train_cut, categorical_indicator)
        trainer = get_trainer(cfg, model, dataset.n_classes)
        trainer.train(data.x_train_cut, data.y_train_cut, data.x_val_earlystop, data.y_val_earlystop)

        # Evaluate on the three splits in a fixed order: train, validation, test.
        evaluation_targets = (
            ("training", data.x_train, data.y_train),
            ("validation", data.x_val_hyperparams, data.y_val_hyperparams),
            ("test", data.x_test, data.y_test),
        )
        results = []
        for label, x_query, y_query in evaluation_targets:
            logger.info(f"Testing on {label} data...")
            results.append(trainer.evaluate(data.x_train, data.y_train, x_query, y_query))

        train_result, val_result, test_result = results

        logger.info(f"split_{idx} :: train: {train_result.score:.4f}, val: {val_result.score:.4f}, test: {test_result.score:.4f}")

        run_metrics.append(train_result, val_result, test_result)

    run_metrics.post_process()
    return run_metrics


@dataclass
class Data():
    """
    x_train: the training data
    x_train_cut: in case of early stopping on the training data, 
                 this is a cut of the training data that excludes the early stopping part,
                 otherwise it is the full training data
    x_train_and_val: the training data and the validation data combined
    x_val_earlystop: the data used for early stopping, either from the validation or the training dataset
    x_val_hyperparams: the data used for hyperparameter search, always from the validation dataset
    x_test: the test data
    """

    x_train: np.ndarray
    x_train_cut: np.ndarray
    x_train_and_val: np.ndarray
    x_val_earlystop: np.ndarray
    x_val_hyperparams: np.ndarray
    x_test: np.ndarray
    y_train: np.ndarray
    y_train_cut: np.ndarray
    y_train_and_val: np.ndarray
    y_val_earlystop: np.ndarray
    y_val_hyperparams: np.ndarray
    y_test: np.ndarray


    @classmethod
    def from_standard_datasplits(
        cls, 
        x_train, 
        x_val, 
        x_test, 
        y_train, 
        y_val, 
        y_test, 
        task: Task, 
        early_stopping_data_split: str,
        early_stopping_max_samples: Optional[int] = None
    ):

        match early_stopping_data_split:
            case "VALID":
                # Use the full validation set for early stopping and for hyperparameter search
                x_train_cut = x_train
                y_train_cut = y_train
                x_val_earlystop = x_val
                y_val_earlystop = y_val
            case "TRAIN":
                # Use a cut of the training set for early stopping and the full validation set for hyperparameter search
                x_train_cut, x_val_earlystop, y_train_cut, y_val_earlystop = make_dataset_split(x_train, y_train, task=task)
            case _:
                raise NotImplementedError(f"DataSplit {early_stopping_data_split} not implemented")
            
        if early_stopping_max_samples is not None:
            # Use only a subset of the early stopping data, because otherwise it is too slow
            early_stopping_indices_count = min(early_stopping_max_samples, len(x_val_earlystop))
            early_stopping_indices = np.random.choice(len(x_val_earlystop), early_stopping_indices_count, replace=False)

            x_val_earlystop = x_val_earlystop[early_stopping_indices]
            y_val_earlystop = y_val_earlystop[early_stopping_indices]
            
        x_train_and_val = np.concatenate([x_train, x_val], axis=0)
        y_train_and_val = np.concatenate([y_train, y_val], axis=0)

        return cls(
            x_train=x_train,
            y_train=y_train,
            x_train_cut=x_train_cut,
            y_train_cut=y_train_cut,
            x_val_earlystop=x_val_earlystop,
            y_val_earlystop=y_val_earlystop,
            x_val_hyperparams=x_val,
            y_val_hyperparams=y_val,
            x_train_and_val=x_train_and_val,
            y_train_and_val=y_train_and_val,
            x_test=x_test,
            y_test=y_test
        )



def set_cpus(cfg: "ConfigRun") -> None:
    """
    Pin the current process to the CPU ids listed in ``cfg.cpus``.

    Does nothing when ``cfg.cpus`` is ``None``. Uses ``os.sched_setaffinity``,
    which is not available on every platform.

    Raises:
        RuntimeError: if the number of CPUs on this machine cannot be determined.
        ValueError: if ``cfg.cpus`` is empty or contains ids outside ``[0, os.cpu_count())``.
    """

    if cfg.cpus is None:
        return

    total_cpus = os.cpu_count()
    if total_cpus is None:
        raise RuntimeError("Could not determine number of cpus")

    # Materialize once so a one-shot iterable is not exhausted by the validation below.
    cpus = list(cfg.cpus)
    if not cpus:
        raise ValueError("cpus must contain at least one cpu id")
    # Explicit checks instead of `assert`, which is stripped when Python runs with -O.
    if not all(0 <= cpu < total_cpus for cpu in cpus):
        raise ValueError(f"cpus {cfg.cpus} contain cpu ids that are not available on this machine")

    os.sched_setaffinity(os.getpid(), cpus)


if __name__ == "__main__":

    # Entry point for a quick local run. `torch` is already imported at module level.

    # Alternative configuration using ModelName.AUTOGLUON; swap it in for the
    # TAB2D configuration below to run that model instead.
    # cfg = ConfigRun.create(
    #     output_dir = Path("output_run_experiment"),
    #     device = torch.device("cuda:6"),
    #     cpus = list(range(64)),
    #     model_name = ModelName.AUTOGLUON,
    #     seed = 0,
    #     task = Task.REGRESSION,
    #     dataset_size = DatasetSize.MEDIUM,
    #     datafile_path = Path("data/datasets_evaluation/whytrees_44065_MEDIUM.nc"),
    #     hyperparams = dict({
    #         'time_limit': 60,
    #         'num_cpus': 64,
    #         'early_stopping_data_split': 'VALID',
    #         'early_stopping_max_samples': 1e9
    #     })
    # )

    cfg = ConfigRun.create(
        output_dir = Path("output_run_experiment"),
        device = torch.device("cuda:0"),
        cpus = None,
        model_name = ModelName.TAB2D,
        seed = 0,
        task = Task.REGRESSION,
        dataset_size = DatasetSize.MEDIUM,
        datafile_path = Path("data/datasets_evaluation/whytrees_44065_MEDIUM.nc"),
        hyperparams = {
            "dim": 512,
            "dim_embedding": None,
            "dim_output": 1,
            "n_layers": 12,
            "n_heads": 4,
            "task": Task.REGRESSION,
            "max_samples_support": 8192,
            "max_samples_query": 1024,
            "max_epochs": 300,
            "optimizer": "adamw",
            "lr": 1.e-3,
            "weight_decay": 0.0,
            "lr_scheduler": False,
            "lr_scheduler_patience": 25,
            "warmup_steps": 0,
            "early_stopping_patience": 40,
            "early_stopping_data_split": "VALID",
            "early_stopping_max_samples": 2048,
            "precision": "bfloat16",
            "grad_scaler_enabled": False,
            "grad_scaler_scale_init": 65536.,
            "grad_scaler_scale_min": 65536.,
            "grad_scaler_growth_interval": 1000,
            "label_smoothing": 0.0,
            "use_pretrained_weights": False,
            "path_to_weights": "outputs/runs/2024-08-08/23-58-03/weights/model_step_12000.pt",
            "use_quantile_transformer": False,
            "use_feature_count_scaling": False,
            "shuffle_classes": True,
            "shuffle_features": False,
            "random_mirror_x": True,
            "random_mirror_regression": True
        }
    )

    run_experiment(cfg)



