import itertools
from typing import Optional, Union
from dataclasses import dataclass
import copy
import pickle

import torch
from torch import nn
from pytorch_lightning import LightningDataModule
import pandas as pd

from vis_models.training import Trainer
from vis_models.training.supervised import SupervisedLearning
from vis_models.architectures.utils.fine_tuning import (
    prepare_torchvision_finetuning,
    ModelConfig,
)
from .eval import eval_models, BASE_METRICS
from . import persistence


@dataclass
class TrainingConfig:
    """Settings for a single training/evaluation experiment run.

    Attributes:
        max_epochs: Epoch limit forwarded to the trainer as ``max_epochs``.
        save_checkpoints: Forwarded to the trainer as ``enable_checkpointing``.
        train: When False, previously persisted models are loaded instead of
            training new ones.
        eval: When False, previously persisted evaluation results are loaded
            instead of re-running evaluation.
    """

    # Deliberately no defaults: callers must choose these explicitly.
    max_epochs: int
    save_checkpoints: bool
    # Phase toggles; each falls back to loading persisted artifacts when off.
    train: bool = True
    eval: bool = True

# A named collection of models: model name -> torch module.
ModelSet = dict[str, torch.nn.Module]

@dataclass
class TrainingResult:
    """Outcome of ``training_experiment``: the models and their metrics."""

    # Trained (or loaded-from-persistence) models, keyed by model name.
    models: ModelSet
    # Evaluation results keyed by metric name, either freshly computed by
    # ``eval_models`` or loaded from the persisted "eval_res" result.
    metrics: dict[str, pd.DataFrame]

# Datasets accepted for training. Either:
#   * a dict of named datamodules, which ``_train_models`` pairs with the
#     models positionally (first model <-> first dataset, ...), or
#   * a single ``(name, datamodule)`` tuple that every model is trained on.
TrainingData = Union[
    dict[str, LightningDataModule],
    tuple[str, LightningDataModule]
]
        
def training_experiment(
    exp_name: list[str],
    config: TrainingConfig,
    models: ModelSet,
    datasets: TrainingData,
    metrics: Optional[list[str]] = None,
    dataset_classes: Optional[dict[str, Union[int, list[str]]]] = None
) -> TrainingResult:
    """Train (or load) models, then evaluate them (or load their metrics).

    Args:
        exp_name: Experiment path components. They are used as the
            persistence key and forwarded to training and evaluation.
        config: Selects whether to train/evaluate or to load previously
            persisted artifacts, plus the trainer settings.
        models: Models to train, keyed by name.
        datasets: Either a dict of named datamodules, paired positionally
            with ``models`` during training, or one ``(name, datamodule)``
            tuple shared by every model.
        metrics: Extra metric names forwarded to ``eval_models``. ``None``
            (the default) means no extra metrics.
        dataset_classes: Optional per-dataset class information forwarded to
            ``eval_models``.

    Returns:
        The trained (or loaded) models and the computed (or loaded)
        evaluation metrics.
    """
    # A fresh list per call. A shared ``[]`` default would leak mutations
    # between calls if any downstream consumer modified it.
    metrics = [] if metrics is None else metrics

    if config.train:
        print(
            "--------------\n"
            "Training models\n"
            "--------------"
        )
        trained_models = _train_models(
            exp_name,
            models,
            datasets,
            config.max_epochs,
            save_checkpoints=config.save_checkpoints,
        )
    else:
        # Skip training and reuse models persisted by an earlier run.
        trained_models = persistence.load_models(
            exp_name,
            models,
        )

    if config.eval:
        print(
            "--------------\n"
            "Evaluating trained models\n"
            "--------------"
        )
        training_metrics_res = eval_models(
            exp_name,
            trained_models,
            datasets,
            metrics=metrics,
            dataset_classes=dataset_classes,
        )
        # Persist so a later run with ``config.eval=False`` can reload it.
        persistence.save_result(
            exp_name,
            training_metrics_res,
            result_name="eval_res",
        )
    else:
        training_metrics_res = persistence.load_result(
            exp_name,
            "eval_res",
        )

    return TrainingResult(
        models=trained_models,
        metrics=training_metrics_res,
    )

def _dataset_num_classes(
    dataset_classes: Optional[dict[str, Union[int, list[str]]]],
    data_name: str,
) -> Optional[int]:
    """Return the class count for ``data_name``, or None when not provided.

    ``dataset_classes`` maps a dataset name either to an explicit class count
    or to the list of class labels (whose length is the count).
    """
    if dataset_classes is None:
        return None
    data_classes = dataset_classes[data_name]
    if isinstance(data_classes, list):
        return len(data_classes)
    return data_classes


def _resolve_model_config(
    model_config: Union[ModelConfig, dict[str, ModelConfig]],
    model_name: str,
    num_classes: Optional[int],
) -> ModelConfig:
    """Select the config for ``model_name`` and apply ``num_classes`` to it.

    ``model_config`` is either one config shared by all models or a dict of
    per-model configs. A deep copy is returned so the caller's configs are
    never mutated.
    """
    base_config = (
        model_config[model_name] if isinstance(model_config, dict)
        else model_config
    )
    resolved = copy.deepcopy(base_config)
    if num_classes is not None:
        resolved.num_classes = num_classes
    return resolved


def fine_tuning_experiment(
    exp_name: list[str],
    config: TrainingConfig,
    model_config: Union[ModelConfig, dict[str, ModelConfig]],
    models: ModelSet,
    datasets: dict[str, LightningDataModule],
    metrics: Optional[list[str]] = None,
    dataset_classes: Optional[dict[str, Union[int, list[str]]]] = None,
    eval_cross_roduct: bool = True,
) -> dict[str, pd.DataFrame]:
    """Fine-tune frozen copies of ``models`` on each dataset and collect metrics.

    For each dataset, copies of the relevant models are prepared with
    ``prepare_torchvision_finetuning(..., feature_extract=True)``. They are
    then trained and evaluated through ``training_experiment`` under the
    experiment path ``[*exp_name, data_name]``. That call also handles
    saving, or loading, the models and results as ``config`` dictates.

    Args:
        exp_name: Experiment path components; the dataset name is appended.
        config: Training/evaluation settings passed to each sub-experiment.
        model_config: One config shared by all models, or a dict of per-model
            configs keyed by model name. When ``dataset_classes`` is given,
            each model's config is copied and its ``num_classes`` is set to
            the dataset's class count.
        models: Base models to fine-tune. They are deep-copied and never
            modified in place.
        datasets: Datamodules keyed by dataset name.
        metrics: Extra metric names to collect in addition to
            ``BASE_METRICS``. ``None`` means no extra metrics.
        dataset_classes: Optional per-dataset class count, or list of class
            labels.
        eval_cross_roduct: (Name kept as-is for backward compatibility.)
            If True, every model is fine-tuned on every dataset. If False,
            models and datasets are paired positionally (first model with
            first dataset, and so on).

    Returns:
        One DataFrame per metric name (``BASE_METRICS`` then ``metrics``).
        Each is indexed by all model names, with one column per dataset. In
        paired mode, cells for model/dataset pairs that were not run are
        presumably left NaN by pandas index alignment.

    Raises:
        ValueError: If ``eval_cross_roduct`` is False and ``datasets`` and
            ``models`` differ in length.
    """
    metrics = [] if metrics is None else metrics
    # Explicit check rather than ``assert``, which ``python -O`` strips.
    if not eval_cross_roduct and len(datasets) != len(models):
        raise ValueError(
            "eval_cross_roduct=False requires one model per dataset, got "
            f"{len(models)} models and {len(datasets)} datasets"
        )

    fine_tune_metrics_res: dict[str, pd.DataFrame] = {}
    model_iter = iter(models.items())

    for data_name, data in datasets.items():
        num_classes = _dataset_num_classes(dataset_classes, data_name)

        if eval_cross_roduct:
            dataset_models = models
        else:
            model_name, model = next(model_iter)
            dataset_models = {model_name: model}

        # Each model gets its own config, adjusted to this dataset's class
        # count. The previous version computed this adjusted config and then
        # never passed it on.
        frozen_models = {
            name: prepare_torchvision_finetuning(
                _resolve_model_config(model_config, name, num_classes),
                copy.deepcopy(net),
                feature_extract=True,
            )
            for name, net in dataset_models.items()
        }

        data_fine_tuning_result = training_experiment(
            [*exp_name, data_name],
            config,
            frozen_models,
            (data_name, data),
            metrics=metrics,
            dataset_classes=dataset_classes,
        )

        for metric_name in BASE_METRICS + metrics:
            # Index on *all* models. Indexing only this dataset's models would
            # drop every later model's results in paired mode, because later
            # columns get aligned onto the first dataset's index.
            metric_table = fine_tune_metrics_res.setdefault(
                metric_name,
                pd.DataFrame(index=list(models), dtype=float),
            )
            metric_table[data_name] = data_fine_tuning_result.metrics[metric_name]

    # TODO: persist the aggregated per-metric tables, e.g. with
    # persistence.save_result(exp_name, ..., result_name="fine_tuning_performance"),
    # and reload them when ``config.eval`` is False, mirroring
    # ``training_experiment``.
    return fine_tune_metrics_res


def _train_models(
    task_name: list[str],
    models: dict[str, torch.nn.Module],
    datasets: TrainingData,
    max_epochs: int,
    save_checkpoints: bool,
) -> dict[str, torch.nn.Module]:
    """Train every model sequentially and return the results keyed by name.

    Args:
        task_name: Experiment path components forwarded to ``_train_on_task``.
        models: Models to train, keyed by name.
        datasets: Either one ``(name, datamodule)`` tuple used for every model,
            or a dict whose entries are paired positionally with ``models``
            (first model with first dataset, and so on). Surplus datasets are
            ignored.
        max_epochs: Epoch limit forwarded to the trainer.
        save_checkpoints: Whether the trainer should enable checkpointing.

    Returns:
        Mapping from model name to the model returned by ``_train_on_task``,
        in the same order as ``models``.

    Raises:
        ValueError: If ``datasets`` is a dict with fewer entries than
            ``models``. Previously ``zip`` silently left the extra models
            untrained and dropped them from the result.
    """
    if isinstance(datasets, tuple):
        # The same dataset is reused for every model.
        data_pairs = itertools.repeat((datasets[0], datasets[1]))
    else:
        if len(datasets) < len(models):
            raise ValueError(
                f"Got {len(models)} models but only {len(datasets)} datasets; "
                "each model needs a dataset to be trained on"
            )
        data_pairs = datasets.items()

    return {
        model_name: _train_on_task(
            task_name,
            model_name,
            model,
            data_name,
            data,
            max_epochs=max_epochs,
            save_checkpoints=save_checkpoints,
        )
        for (model_name, model), (data_name, data)
        in zip(models.items(), data_pairs)
    }

def _train_on_task(
    task_name: list[str],
    model_name: str,
    model: torch.nn.Module,
    data_name: str,
    data: LightningDataModule,
    max_epochs: int,
    save_checkpoints: bool,
) -> torch.nn.Module:
    """Fit ``model`` on ``data`` with a single-GPU ``Trainer`` and return it.

    The model is wrapped in a ``SupervisedLearning`` task. It is fitted by a
    ``Trainer`` whose task name is ``task_name`` extended with
    ``model_name``. The very object passed in as ``model`` is returned;
    training presumably updates it in place through the wrapping task
    (confirm against ``SupervisedLearning``).
    """
    rule = "--------------"
    print("\n".join((
        rule,
        f"Training model {model_name} on dataset {data_name}",
        rule,
    )))

    supervised_task = SupervisedLearning(model=model)
    # Hard-wired to one GPU. Checkpointing follows the caller's flag.
    lightning_trainer = Trainer(
        task_name=list(task_name) + [model_name],
        accelerator="gpu",
        devices=1,
        max_epochs=max_epochs,
        enable_checkpointing=save_checkpoints,
    )
    lightning_trainer.fit(supervised_task, data)
    return model
