from typing import cast
import itertools

from IPython.display import display, Markdown as md
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

from vis_analysis_utils.visualize.tables import TableFormatter
from utils import persistence
from utils.eval import DFAggregator
from utils.visualize.objects_2d import show_samples
from .experiment import (
    TMExperimentResult,
    EXP_NAME,
)


def load(
    config_name: str,
    seeds: list[tuple[int, int]],
) -> list[TMExperimentResult]:
    """Load the persisted experiment result for each seed of one config.

    Args:
        config_name: Name of the experiment configuration; combined with
            ``EXP_NAME`` to form the lookup path handed to the persistence
            layer.
        seeds: Seeds whose results should be loaded, one result per entry.

    Returns:
        The loaded results, in the same order as ``seeds``.
    """
    loaded: list[TMExperimentResult] = []
    for seed in seeds:
        stored = persistence.load_experiment_result(
            [EXP_NAME, config_name], seed
        )
        # The persistence layer is not typed for this experiment, so narrow
        # the result for static checkers (no runtime effect).
        loaded.append(cast(TMExperimentResult, stored))
    return loaded

def summarize(
    results: list[TMExperimentResult]
) -> TMExperimentResult:
    """Combine per-seed results into a single aggregated result.

    The three performance tables of every result are fed, seed by seed, into
    their own ``DFAggregator``; the aggregates are then packed into a new
    ``TMExperimentResult``.

    Args:
        results: Per-seed results of the same experiment configuration.

    Returns:
        A result whose ``config`` and ``expanded_data_config`` are taken from
        the first entry of ``results`` and whose performance tables are the
        ``DFAggregator.get_aggregate()`` outputs.

    Raises:
        IndexError: If ``results`` is empty (assuming the aggregators do not
            already fail on having received no seed results).
    """
    # NOTE: aggregation of the qualitative-mismatch performance is currently
    # disabled.
    in_dist_agg = DFAggregator()
    quant_agg = DFAggregator()
    order_agg = DFAggregator()

    for seed_result in results:
        in_dist_agg.append_seed_result(seed_result.in_dist_performance)
        quant_agg.append_seed_result(seed_result.quant_mismatch_performance)
        order_agg.append_seed_result(seed_result.order_mismatch_performance)

    in_dist_summary = in_dist_agg.get_aggregate()
    quant_summary = quant_agg.get_aggregate()
    order_summary = order_agg.get_aggregate()

    # Configuration is not aggregated; the first seed's is used as-is.
    reference = results[0]
    return TMExperimentResult(
        config=reference.config,
        expanded_data_config=reference.expanded_data_config,
        in_dist_performance=in_dist_summary,
        quant_mismatch_performance=quant_summary,
        order_mismatch_performance=order_summary,
    )

def show(
    result: TMExperimentResult,
    show_transforms_info: bool = False,
    show_features_info: bool = False,
) -> None:
    """Render a report of one experiment result via IPython ``display``.

    Shows, in order: the in-distribution performance table with a red
    background gradient, a heatmap of the quantitative-mismatch performance,
    and the order-mismatch performance table.

    Args:
        result: The (per-seed or aggregated) result to display.
        show_transforms_info: If true, first print
            ``result.expanded_data_config.all_transforms``.
        show_features_info: If true, first print
            ``result.expanded_data_config.features``.
    """
    if show_transforms_info:
        print("transforms:", result.expanded_data_config.all_transforms)
    if show_features_info:
        print("features:", result.expanded_data_config.features)

    display(md("#### Training (in-distribution) performance:"))
    in_dist_performance = result.in_dist_performance.astype(float)
    # The gradient lives on the Styler, not on the DataFrame, so the Styler
    # is what has to be displayed for the colouring to show up.
    display(in_dist_performance.style.background_gradient(cmap="Reds"))

    display(md("#### Quantitative transformation difference performance:"))
    display(md("Evaluating models on datasets with different numbers of transformations"))
    _, quant_plot = plt.subplots(figsize=(4, 3))
    quant_plot.set_title("Transfer performance (accuracy)")
    sns.heatmap(
        result.quant_mismatch_performance,
        annot=True,
        ax=quant_plot,
    )
    quant_plot.set_xlabel("# transformations")
    quant_plot.set_ylabel("models")
    plt.show()

    # NOTE: the qualitative-mismatch heatmap is currently disabled.

    display(result.order_mismatch_performance)

# Column names of the long-format frame built by `show_line_plots`.
# MODEL_KEY is also passed as the hue/style variable of the line plot.
DATA_KEY = "data"
MODEL_KEY = "# Training\nTransforms"
PERF_KEY = "perf"

# X-axis tick positions (and labels) used by `show_line_plots`.
TRANSFORM_COUNTS = [*range(1, 9)]

def show_line_plots(
    results: list[TMExperimentResult],
) -> plt.Figure:
    """Plot transfer accuracy against the number of dataset transforms.

    Every cell of every result's ``quant_mismatch_performance`` table becomes
    one row of a long-format frame (dataset, model, performance), which is
    then drawn with ``seaborn.lineplot`` using one line per model. Passing
    the un-aggregated per-seed results therefore gives seaborn several
    observations per (dataset, model) pair.

    Args:
        results: Results whose ``quant_mismatch_performance`` tables should
            be plotted.

    Returns:
        The figure containing the line plot (it is also shown via
        ``plt.show()``).

    Raises:
        ValueError: If ``results`` yields no performance values to plot.
    """
    records = []
    for result in results:
        quant_res = result.quant_mismatch_performance
        # `items()` replaces `iteritems()`, which was removed in pandas 2.0.
        # Columns are iterated as datasets, index entries as models.
        for data_name, data_res in quant_res.items():
            for model_name, model_res in data_res.items():
                records.append({
                    # The first two characters of each label are dropped;
                    # presumably a two-character prefix in front of the
                    # transform count -- confirm against the experiment code.
                    DATA_KEY: int(cast(str, data_name)[2:]),
                    MODEL_KEY: cast(str, model_name)[2:],
                    PERF_KEY: model_res,
                })
    if not records:
        raise ValueError(
            "No quantitative mismatch performance values to plot"
        )
    data_perf_aggregate = pd.DataFrame(
        records, columns=[DATA_KEY, MODEL_KEY, PERF_KEY]
    )

    perf_fig, perf_plot = plt.subplots(figsize=(4, 3))
    sns.lineplot(
        data=data_perf_aggregate,
        x=DATA_KEY,
        y=PERF_KEY,
        hue=MODEL_KEY,
        style=MODEL_KEY,
        ax=perf_plot,
    )
    perf_plot.set_xlabel("# Dataset Transforms")
    perf_plot.set_ylabel("Transfer accuracy")
    perf_plot.set_xticks(TRANSFORM_COUNTS, TRANSFORM_COUNTS)

    perf_fig.tight_layout()
    plt.show()
    return perf_fig
