from dataclasses import dataclass
from typing import Optional, cast
import itertools

from IPython.display import display, Markdown as md
import matplotlib.pyplot as plt
from numpy import mod
import seaborn as sns
import pandas as pd

from utils import persistence
from utils.eval import DFAggregator
from utils.visualize.objects_2d import show_samples
from .experiment import (
    TvIExperimentResult,
    EXP_NAME,
)


def load(
    config_name: str,
    seeds: list[tuple[int, int]],
) -> list[TvIExperimentResult]:
    """Load the persisted experiment result of every seed for one config.

    Args:
        config_name: Name of the experiment configuration. Together with
            ``EXP_NAME`` it forms the key handed to
            ``persistence.load_experiment_result``.
        seeds: Seeds to load; one result is fetched per entry.

    Returns:
        The loaded results, in the same order as ``seeds``.
    """
    loaded: list[TvIExperimentResult] = []
    for seed in seeds:
        stored = persistence.load_experiment_result(
            [EXP_NAME, config_name],
            seed,
        )
        # `cast` only narrows the static type; nothing is checked at runtime.
        loaded.append(cast(TvIExperimentResult, stored))
    return loaded

def summarize(
    results: list[TvIExperimentResult]
) -> TvIExperimentResult:
    """Combine per-seed results into a single summary result.

    The in-distribution and transfer performance tables of all seeds are fed
    into one ``DFAggregator`` each; the summary carries whatever
    ``DFAggregator.get_aggregate()`` returns for them. The config and the
    expanded data config are taken from the first result.
    NOTE(review): this presumes all seeds share one configuration -- it is
    not verified here.

    Args:
        results: Per-seed experiment results; must not be empty.

    Returns:
        A ``TvIExperimentResult`` holding the aggregated performance tables.

    Raises:
        ValueError: If ``results`` is empty.
    """
    if not results:
        # Fail with an explicit message instead of an opaque IndexError (or
        # an aggregator failure) further down.
        raise ValueError("summarize() requires at least one result")

    in_dist_performances = DFAggregator()
    transfer_performances = DFAggregator()
    for result in results:
        in_dist_performances.append_seed_result(result.in_dist_performance)
        transfer_performances.append_seed_result(result.transfer_performance)

    first = results[0]
    return TvIExperimentResult(
        config=first.config,
        expanded_data_config=first.expanded_data_config,
        in_dist_performance=in_dist_performances.get_aggregate(),
        transfer_performance=transfer_performances.get_aggregate(),
    )

def show(
    result: TvIExperimentResult,
    show_transforms_info: bool = False,
    show_features_info: bool = False,
    show_n_samples: int = -1,
) -> None:
    """Print and display a human-readable report for one experiment result.

    Args:
        result: The result to report on.
        show_transforms_info: If true, print the two transform sets of the
            expanded data config.
        show_features_info: If true, print the two feature sets of the
            expanded data config.
        show_n_samples: If positive, call ``show_samples`` with this count
            for each of the four (transform set, feature set) combinations.
    """
    data_config = result.expanded_data_config
    if show_transforms_info:
        print("t1 transforms:", data_config.transforms_1)
        print("t2 transforms:", data_config.transforms_2)
    if show_features_info:
        print("i1 feature:", data_config.features_1)
        print("i2 features:", data_config.features_2)

    if show_n_samples > 0:
        transform_sets = (data_config.transforms_1, data_config.transforms_2)
        feature_sets = (data_config.features_1, data_config.features_2)
        # Numbering starts at 1 so the labels read t1/t2 and i1/i2.
        for (ti, transform_names), (ii, class_indices) in itertools.product(
            enumerate(transform_sets, start=1),
            enumerate(feature_sets, start=1),
        ):
            print(f"Data samples t{ti}_i{ii}:")
            show_samples(show_n_samples, transform_names, class_indices)

    print("Training (in-distribution) performance:")
    in_dist_performance = result.in_dist_performance.astype(float)
    # `background_gradient` returns a Styler and leaves the DataFrame
    # untouched, so the Styler itself has to be displayed for the colouring
    # to show up.
    display(in_dist_performance.style.background_gradient(cmap="Reds"))
    print("Fine-tuning (OOD) performance:")
    display(result.transfer_performance)


# Column names of the long-format DataFrame that plot_model_type_comparison
# assembles and hands to seaborn.
DATA_REL_COL = "Dataset"  # x axis: one of the *_SAME_KEY relationship labels
PERF_COL = "Transfer Accuracy"  # y axis: the performance value being plotted
MODEL_TYPE_COL = "Model"  # hue: key of the `model_type_results` dict

def plot_model_type_comparison(
    model_type_results: dict[str, list[TvIExperimentResult]],
) -> sns.FacetGrid:
    """Bar-plot transfer accuracy per model type.

    Every cell of every seed's ``transfer_performance`` table becomes one
    data point. Column labels are treated as dataset names and index labels
    as model names; both are parsed with ``split_name``. Each point is
    labelled by which of the parsed ``transforms`` / ``objects`` fields the
    two names have in common, and bars are grouped by that label with one
    hue per model type.

    Args:
        model_type_results: Per-seed results keyed by the model type label
            to show in the legend.

    Returns:
        The seaborn grid created by ``sns.catplot``. The figure is also
        shown via ``plt.show()`` as a side effect.
    """
    rows = []
    for model_type, seed_results in model_type_results.items():
        for seed_result in seed_results:
            # `items()` instead of `iteritems()`: the latter was removed in
            # pandas 2.0.
            for data_name, data_res in seed_result.transfer_performance.items():
                data_split = split_name(cast(str, data_name))
                for model_name, model_perf in data_res.items():
                    model_split = split_name(cast(str, model_name))
                    same_transforms = (
                        data_split.transforms == model_split.transforms
                    )
                    same_objects = data_split.objects == model_split.objects
                    if same_transforms and same_objects:
                        relationship = ALL_SAME_KEY
                    elif same_transforms:
                        relationship = TRANSFORMS_SAME_KEY
                    elif same_objects:
                        relationship = OBJECTS_SAME_KEY
                    else:
                        relationship = NONE_SAME_KEY
                    rows.append({
                        DATA_REL_COL: relationship,
                        PERF_COL: model_perf,
                        MODEL_TYPE_COL: model_type,
                    })

    # Build the frame once; concatenating row by row copies the whole frame
    # on every iteration.
    data_perfs = pd.DataFrame(
        rows,
        columns=[DATA_REL_COL, PERF_COL, MODEL_TYPE_COL],
    )

    fig = sns.catplot(
        data=data_perfs,
        x=DATA_REL_COL,
        y=PERF_COL,
        hue=MODEL_TYPE_COL,
        kind="bar",
        # A legend is only informative when there are several model types
        # (hue levels) to tell apart; the row count says nothing about that.
        legend=len(model_type_results) > 1,
        legend_out=False,
        height=3,
        aspect=4/3,
    )
    fig.set_axis_labels(
        "Properties shared with training data",
        "Transfer accuracy",
    )
    fig.tight_layout()
    plt.show()
    return fig

def show_data_performance_plots(
    results: list[TvIExperimentResult],
    model_rand_filters: Optional[list[str]] = None,
) -> sns.FacetGrid:
    """Bar-plot transfer accuracy, grouped by data/model relationship.

    For every column of every result's ``transfer_performance`` table and
    every requested model filter, ``get_per_data_perf`` produces a frame of
    data points. All frames are concatenated and plotted with one hue per
    model filter.

    Args:
        results: Results whose transfer performance should be plotted.
        model_rand_filters: Values forwarded one by one as
            ``model_type_filter`` to ``get_per_data_perf``. Defaults to
            ``["all"]`` (no filtering).

    Returns:
        The seaborn grid created by ``sns.catplot``. The figure is also
        shown via ``plt.show()`` as a side effect.

    Raises:
        ValueError: If there is no performance data to plot.
    """
    if model_rand_filters is None:
        # `None` stands in for the former mutable `["all"]` default.
        model_rand_filters = ["all"]

    data_perfs: dict[str, list[pd.DataFrame]] = {
        model_filter: [] for model_filter in model_rand_filters
    }
    for result in results:
        # `items()` instead of `iteritems()`: the latter was removed in
        # pandas 2.0.
        for data_name, data_res in result.transfer_performance.items():
            for model_rand_filter in model_rand_filters:
                data_perf = get_per_data_perf(
                    cast(str, data_name),
                    data_res,
                    model_type_filter=model_rand_filter,
                )
                if data_perf is not None:
                    data_perfs[model_rand_filter].append(data_perf)

    # Concatenating filter by filter and then concatenating the results is
    # the same as concatenating every frame in dict order, so do it once.
    all_perfs = list(itertools.chain.from_iterable(data_perfs.values()))
    if not all_perfs:
        raise ValueError(
            "No performance data to plot: `results` produced no entries "
            "for the requested model filters"
        )
    data_perf_aggregate = pd.concat(all_perfs, ignore_index=True)

    # Spell out the short filter names for the legend.
    data_perf_aggregate[MODEL_TYPE_KEY] = (
        data_perf_aggregate[MODEL_TYPE_KEY]
        .replace({"rw": "photographed", "rand": "random"})
    )

    perf_fig = sns.catplot(
        data=data_perf_aggregate,
        x=DATA_REL_KEY,
        y=PERF_KEY,
        hue=MODEL_TYPE_KEY,
        kind="bar",
        # Only draw a legend when several model filters are compared.
        legend=len(data_perfs) > 1,
        legend_out=False,
        height=3,
        aspect=4/3,
    )
    perf_fig.set_axis_labels(
        "Properties shared with training data",
        "Transfer accuracy",
    )
    perf_fig.tight_layout()
    plt.show()
    return perf_fig

# Column names of the frames built by get_per_data_perf and plotted by
# show_data_performance_plots.
MODEL_TYPE_KEY = "model type"  # the model type filter a row was built for
DATA_REL_KEY = "data_relationship"  # one of the *_SAME_KEY labels below
PERF_KEY = "perf"  # the performance value

# Labels for which parsed name fields (see split_name) a dataset name and a
# model name have in common.
ALL_SAME_KEY = "all"  # same transforms and same objects
TRANSFORMS_SAME_KEY = "transforms"  # same transforms, different objects
OBJECTS_SAME_KEY = "objects"  # same objects, different transforms
NONE_SAME_KEY = "none"  # neither field matches

def get_per_data_perf(
    data_name: str,
    data_res: pd.Series,
    data_type_filter: str = "all",
    model_type_filter: str = "all",
) -> Optional[pd.DataFrame]:
    """Turn one dataset's per-model results into a long-format frame.

    Args:
        data_name: Name of the dataset; parsed with ``split_name``.
        data_res: Results for that dataset. Index labels are treated as
            model names (parsed with ``split_name``), values as the
            performance numbers.
        data_type_filter: ``"all"``, or the ``rw_vs_rand`` value the dataset
            name must have for it to be included.
        model_type_filter: ``"all"``, or the ``rw_vs_rand`` value a model
            name must have for its row to be included.

    Returns:
        ``None`` if the dataset is rejected by ``data_type_filter``.
        Otherwise a frame with the columns ``MODEL_TYPE_KEY`` (always
        ``model_type_filter``), ``DATA_REL_KEY`` (which parsed fields the
        dataset and model names share) and ``PERF_KEY``, holding one row per
        model that passed the filter (possibly none).
    """
    data_split = split_name(data_name)
    if data_type_filter not in ("all", data_split.rw_vs_rand):
        return None

    rows = []
    for model_name, model_res in data_res.items():
        model_split = split_name(cast(str, model_name))
        if model_type_filter not in ("all", model_split.rw_vs_rand):
            continue
        same_transforms = data_split.transforms == model_split.transforms
        same_objects = data_split.objects == model_split.objects
        if same_transforms and same_objects:
            relationship = ALL_SAME_KEY
        elif same_transforms:
            relationship = TRANSFORMS_SAME_KEY
        elif same_objects:
            relationship = OBJECTS_SAME_KEY
        else:
            relationship = NONE_SAME_KEY
        rows.append({
            MODEL_TYPE_KEY: model_type_filter,
            DATA_REL_KEY: relationship,
            PERF_KEY: model_res,
        })

    # Build the frame in one go: growing it with `pd.concat` per row copies
    # the frame every iteration and makes the column dtypes depend on how
    # concat treats the empty seed frame.
    return pd.DataFrame(
        rows,
        columns=[MODEL_TYPE_KEY, DATA_REL_KEY, PERF_KEY],
    )
    
@dataclass
class NameSplitRes:
    """Fields that ``split_name`` extracts from a dataset or model name."""

    # Second "_"-separated field of the name.
    transforms: str
    # Third "_"-separated field of the name.
    objects: str
    # Fourth "_"-separated field of the name; other code in this module
    # compares it against filter values such as "rw" and "rand".
    rw_vs_rand: str


def split_name(name: str) -> NameSplitRes:
    """Parse a "_"-separated dataset or model name.

    The first field and anything after the fourth field are ignored.

    Args:
        name: Name with at least four "_"-separated fields.

    Returns:
        The second, third and fourth fields as a ``NameSplitRes``.

    Raises:
        IndexError: If ``name`` has fewer than four "_"-separated fields.
    """
    fields = name.split("_")
    if len(fields) < 4:
        # Same exception type as plain indexing would raise, but with a
        # message that identifies the malformed name.
        raise IndexError(
            f"expected at least 4 '_'-separated fields in name {name!r}, "
            f"got {len(fields)}"
        )
    _, transforms, objects, rw_vs_rand = fields[:4]
    return NameSplitRes(
        transforms=transforms,
        objects=objects,
        rw_vs_rand=rw_vs_rand,
    )

