from dataclasses import dataclass
from typing import Optional, cast
import itertools

from IPython.display import display, Markdown as md
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

from utils import persistence
from utils.eval import DFAggregator
from utils.visualize.objects_2d import show_samples
from .experiment import (
    ITExperimentResult,
    EXP_NAME,
)


def load(
    config_name: str,
    seeds: list[tuple[int, int]],
) -> list[ITExperimentResult]:
    """Load the stored experiment result for each seed of one configuration.

    Args:
        config_name: Name of the configuration, used together with
            ``EXP_NAME`` as the lookup path passed to
            ``persistence.load_experiment_result``.
        seeds: Seeds whose results should be loaded, in the order returned.

    Returns:
        One ``ITExperimentResult`` per entry of ``seeds``. The loaded objects
        are only cast, not validated.
    """
    loaded: list[ITExperimentResult] = []
    for seed in seeds:
        # A fresh path list per call mirrors building it inline for each seed.
        raw_result = persistence.load_experiment_result(
            [EXP_NAME, config_name], seed
        )
        loaded.append(cast(ITExperimentResult, raw_result))
    return loaded

def summarize(
    results: list[ITExperimentResult]
) -> ITExperimentResult:
    """Combine per-seed results into a single aggregated result.

    Each seed's in-distribution and transfer performance tables are fed into
    a separate ``DFAggregator``, and the aggregators' ``get_aggregate()``
    outputs become the performance tables of the returned result. Presumably
    this is a mean over seeds (see ``DFAggregator``); confirm there.

    Args:
        results: Per-seed results of the same experiment configuration.

    Returns:
        A new ``ITExperimentResult`` whose ``config`` and
        ``expanded_data_config`` are taken from the first result and whose
        performance tables are the aggregates.
    """
    seed_results = list(results)
    in_dist_aggregator = DFAggregator()
    transfer_aggregator = DFAggregator()

    for seed_result in seed_results:
        in_dist_aggregator.append_seed_result(seed_result.in_dist_performance)
        transfer_aggregator.append_seed_result(
            seed_result.transfer_performance
        )

    # Aggregate before touching the first element so failure order on empty
    # input is unchanged.
    in_dist_aggregate = in_dist_aggregator.get_aggregate()
    transfer_aggregate = transfer_aggregator.get_aggregate()
    reference = seed_results[0]

    return ITExperimentResult(
        config=reference.config,
        expanded_data_config=reference.expanded_data_config,
        in_dist_performance=in_dist_aggregate,
        transfer_performance=transfer_aggregate,
    )

def show(
    result: ITExperimentResult,
    show_transforms_info: bool = False,
    show_features_info: bool = False,
    show_n_samples: int = -1,
) -> None:
    """Print, display and plot a summary of one experiment result.

    Args:
        result: The experiment result to show (per-seed or aggregated).
        show_transforms_info: If True, print ``transforms_1`` and
            ``transforms_2`` of the expanded data config.
        show_features_info: If True, print ``features_1`` and ``features_2``
            of the expanded data config.
        show_n_samples: If positive, call ``show_samples`` with this count
            for every (transform set, feature set) combination.
    """
    data_config = result.expanded_data_config
    if show_transforms_info:
        print("t1 transforms:", data_config.transforms_1)
        print("t2 transforms:", data_config.transforms_2)
    if show_features_info:
        print("i1 features:", data_config.features_1)
        print("i2 features:", data_config.features_2)

    if show_n_samples > 0:
        transform_sets = (data_config.transforms_1, data_config.transforms_2)
        feature_sets = (data_config.features_1, data_config.features_2)
        for (ti, transform_names), (ii, class_indices) in itertools.product(
            enumerate(transform_sets), enumerate(feature_sets)
        ):
            print(f"Data samples t{ti + 1}_i{ii + 1}:")
            show_samples(show_n_samples, transform_names, class_indices)

    print("Training (in-distribution) performance:")
    in_dist_performance = result.in_dist_performance.astype(float)
    # ``DataFrame.style`` returns a new Styler; the gradient is only visible
    # if that Styler (not the underlying frame) is what gets displayed.
    display(in_dist_performance.style.background_gradient(cmap="Reds"))
    print("Fine-tuning (OOD) performance:")
    display(result.transfer_performance)

    plot_transfer_performance([result])
    plt.show()

def plot_transfer_performance(
    results: list[ITExperimentResult],
) -> sns.FacetGrid:
    """Draw a bar chart of the transfer performance of the given results.

    The ``transfer_performance`` tables of all results are flattened with
    ``_transform_results`` and drawn with ``sns.catplot(kind="bar")``:
    datasets on the x axis, performance on the y axis, bars coloured by
    model.

    Args:
        results: Experiment results whose transfer tables are plotted
            together.

    Returns:
        The seaborn ``FacetGrid`` returned by ``sns.catplot``. The previous
        ``plt.Figure`` annotation did not match what ``catplot`` returns.
    """
    transformed_transfer_perf = _transform_results(
        [res.transfer_performance for res in results]
    )
    perf_grid = sns.catplot(
        data=transformed_transfer_perf,
        x=DATASET_KEY,
        y=PERF_KEY,
        hue=MODEL_KEY,
        kind="bar",
        legend_out=False,
        height=3,
        aspect=4/3,
    )
    perf_grid.tight_layout()
    return perf_grid

# Column names of the long-format frame produced by ``_transform_results``;
# also used as axis/hue keys when plotting.
DATASET_KEY = "Dataset"
MODEL_KEY = "Model"
PERF_KEY = "Transfer Accuracy"

def _transform_results(results: list[pd.DataFrame]) -> pd.DataFrame:
    """Flatten transfer-performance tables into one long-format DataFrame.

    Each input table is read column by column. The column label names the
    dataset and the row-index label names the model. Both labels have their
    first two characters stripped and are mapped through
    ``DATA_NAME_SUBSTITUTIONS`` to display labels.

    Args:
        results: Transfer-performance tables, one per experiment result.

    Returns:
        A frame with columns ``DATASET_KEY``, ``MODEL_KEY`` and ``PERF_KEY``.
        It has one row per (dataset, model) cell of every input table, in
        input order, with a fresh ``RangeIndex``. An empty input gives an
        empty frame with those columns.

    Raises:
        KeyError: If a stripped label has no entry in
            ``DATA_NAME_SUBSTITUTIONS``.
    """
    rows = []
    for transfer_res in results:
        # ``items()`` replaces ``iteritems()``, which pandas 2.0 removed.
        for data_name, data_res in transfer_res.items():
            for model_name, model_res in data_res.items():
                subst_data_name = DATA_NAME_SUBSTITUTIONS[
                    cast(str, data_name)[2:]
                ]
                subst_model_name = DATA_NAME_SUBSTITUTIONS[
                    cast(str, model_name)[2:]
                ]
                rows.append({
                    DATASET_KEY: subst_data_name,
                    MODEL_KEY: subst_model_name,
                    PERF_KEY: model_res,
                })
    # Build the frame once: concatenating inside the loop is quadratic, and
    # concatenating onto an empty frame emits a FutureWarning in pandas 2.1+.
    return pd.DataFrame(rows, columns=[DATASET_KEY, MODEL_KEY, PERF_KEY])

# LaTeX display labels for the short dataset/model tags (after the two-char
# prefix is stripped) that appear in transfer-performance tables.
DATA_NAME_SUBSTITUTIONS = {
    "t1": "$T_1$",
    "t2": "$T_2$",
    "mixed": "$T_1 \\oplus T_2$",
    "full": "$T_1 \\circ T_2$",
}
