import itertools
from typing import Optional, Union, cast

import numpy as np
import pandas as pd
import torch
from pytorch_lightning import LightningDataModule

from vis_models.training import Trainer
from vis_models.training.supervised import (
    SupervisedLearning,
    LOSS_METRIC,
    ACCURACY_METRIC,
    CLASSWISE_ACCURACY_METRIC,
)

BASE_METRICS: list[str] = [ACCURACY_METRIC]

# Column dtype of the per-metric result DataFrames built by ``eval_models``:
# loss and accuracy are stored as floats, classwise accuracy as objects.
# Metrics missing from this mapping fall back to ``object`` there.
METRIC_DTYPES: dict[str, type] = {
    **{name: float for name in (LOSS_METRIC, ACCURACY_METRIC)},
    CLASSWISE_ACCURACY_METRIC: object,
}

def eval_models(
    task_name: list[str],
    models: dict[str, torch.nn.Module],
    datasets: Union[
        dict[str, LightningDataModule], tuple[str, LightningDataModule]
    ],
    metrics: Optional[list[str]] = None,
    dataset_classes: Optional[dict[str, Union[int, list[str]]]] = None
) -> dict[str, pd.DataFrame]:
    """Evaluate each model on a test dataset and tabulate the metrics.

    Models are paired with datasets as follows:

    * ``datasets`` is a ``(name, datamodule)`` tuple: every model is tested
      on that one datamodule and results go into a column called ``name``.
    * ``datasets`` is a dict: the i-th model (in ``models`` iteration order)
      is tested on the i-th dataset (in ``datasets`` iteration order) and
      results go into a column called ``"accuracy"``. Surplus entries on the
      longer side are ignored, as with ``zip``.

    Args:
        task_name: Base task-name components. ``"<model>-<dataset>"`` is
            appended for each run before it is passed to the Trainer.
        models: Models to evaluate, keyed by name. The names form the index
            of every returned DataFrame.
        datasets: Either a single named datamodule or a dict of them (see
            above).
        metrics: Extra test metric names forwarded to ``SupervisedLearning``.
            ``None`` (the default) means no extra metrics.
        dataset_classes: Optional ``classes`` argument for
            ``SupervisedLearning``, looked up by dataset name.

    Returns:
        One DataFrame per metric name produced by the evaluation runs. Each
        is indexed by model name and has the single column described above.
        Its dtype is taken from ``METRIC_DTYPES`` (``object`` for unknown
        metrics). Models whose run did not report a metric keep NaN.
    """
    # Copy so that neither the caller's list nor a shared default can be
    # mutated by the code that receives it.
    metric_names = [] if metrics is None else list(metrics)

    if isinstance(datasets, tuple):
        column_name = datasets[0]
        # Pair every model with the same (name, datamodule) tuple.
        data_iter = itertools.repeat(datasets)
    else:
        column_name = "accuracy"
        data_iter = datasets.items()

    model_names = list(models.keys())
    metric_results: dict[str, pd.DataFrame] = {}
    for (model_name, model), (data_name, data) in zip(
        models.items(), data_iter
    ):
        classes = (
            dataset_classes[data_name]
            if dataset_classes is not None
            else None
        )
        res = _eval_on_task(
            [*task_name, f"{model_name}-{data_name}"],
            model,
            data,
            metric_names,
            classes=classes,
        )
        for metric_name, metric_res in res.items():
            # Create each metric's table only the first time that metric
            # appears. dict.setdefault would eagerly build a throwaway
            # DataFrame on every iteration.
            table = metric_results.get(metric_name)
            if table is None:
                table = pd.DataFrame(
                    index=model_names,
                    columns=[column_name],
                    dtype=METRIC_DTYPES.get(metric_name, object),
                )
                metric_results[metric_name] = table
            table.loc[model_name, column_name] = metric_res
    return metric_results

def _eval_on_task(
    task_name: list[str],
    model: torch.nn.Module,
    data: LightningDataModule,
    metrics: list[str],
    classes: Union[int, list[str], None] = None,
    accelerator: str = "gpu",
    devices: int = 1,
) -> dict[str, np.ndarray]:
    """Run one test pass of ``model`` on ``data`` and return its metrics.

    Args:
        task_name: Task-name components forwarded to ``Trainer``.
        model: Model wrapped in a ``SupervisedLearning`` task for testing.
        data: Datamodule passed to ``trainer.test``.
        metrics: Test metric names forwarded to ``SupervisedLearning``.
        classes: ``classes`` argument forwarded to ``SupervisedLearning``.
        accelerator: Lightning accelerator to test on (default ``"gpu"``).
        devices: Number of devices to use (default ``1``).

    Returns:
        Mapping from metric name to its computed value as a NumPy array, in
        whatever shape the metric produced.
    """
    eval_task = SupervisedLearning(
        model=model, test_metric_names=metrics, classes=classes,
    )
    trainer = Trainer(
        task_name=task_name,
        accelerator=accelerator,
        devices=devices,
        max_epochs=1,
    )
    trainer.test(eval_task, data)
    # Query the task's test metrics once the run has finished. This assumes
    # ``compute()`` yields a name -> tensor mapping; confirm against
    # SupervisedLearning.
    computed = eval_task.test_metrics.compute()
    # The tensors may still live on the accelerator or carry autograd state.
    # Tensor.numpy() only accepts detached CPU tensors.
    return {
        metric_name: metric_res.detach().cpu().numpy()
        for metric_name, metric_res in computed.items()
    }


class DFAggregator:
    """Collect per-seed result tables and average them across seeds.

    Each call to ``append_seed_result`` tags the given table with the current
    seed index, stored as an extra innermost ``GROUP_IDX_COL`` index level,
    and then advances that index. ``get_aggregate`` stacks all tables and
    averages rows that match on every index level except the seed level.
    """

    GROUP_IDX_COL = "group_idx"

    def __init__(self) -> None:
        self.dfs: list[pd.DataFrame] = []
        # Seed index assigned to the next appended result.
        self.cur_seed_idx = 0

    def append_seed_result(self, df: Union[pd.DataFrame, pd.Series]) -> None:
        """Store the result table of one seed run.

        A Series is turned into a single-column DataFrame first. The stored
        table gets ``GROUP_IDX_COL`` appended as its last index level,
        holding the current seed index.
        """
        tagged = cast(
            pd.DataFrame,
            pd.DataFrame(df)
            .assign(**{DFAggregator.GROUP_IDX_COL: self.cur_seed_idx})
            .set_index(DFAggregator.GROUP_IDX_COL, append=True),
        )
        self.dfs.append(tagged)
        # Advance so the next seed gets its own group label.
        self.cur_seed_idx += 1

    def get_aggregate(self) -> pd.DataFrame:
        """Return the mean of all appended seed results.

        Rows are grouped by every index level except the trailing seed
        level, so matching rows from different seeds are averaged together.

        Raises:
            ValueError: If no seed results have been appended.
        """
        if not self.dfs:
            raise ValueError(
                "No seed results to aggregate; call append_seed_result first."
            )
        combined = pd.concat(self.dfs)
        # Every level except the last (the seed tag) identifies a row.
        key_levels = tuple(range(combined.index.nlevels - 1))
        return combined.groupby(level=key_levels).mean()
