from typing import cast

from IPython.display import display, Markdown as md
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

from utils.eval import (
    DFAggregator,
)
from utils import persistence
from .experiment import (
    CTExperimentResult,
    EXP_NAME,
)

def load(
    config_name: str,
    seeds: list[tuple[int, int]],
) -> list[CTExperimentResult]:
    """Load one stored experiment result per seed for the given config.

    Args:
        config_name: Name of the configuration; combined with ``EXP_NAME`` to
            form the path passed to ``persistence.load_experiment_result``.
        seeds: Seeds to load, one result each, in the given order.

    Returns:
        The loaded results, in the same order as ``seeds``.
    """
    loaded: list[CTExperimentResult] = []
    for seed in seeds:
        raw_result = persistence.load_experiment_result(
            [EXP_NAME, config_name], seed
        )
        # Loader is untyped from here; narrow for the type checker only.
        loaded.append(cast(CTExperimentResult, raw_result))
    return loaded

def summarize(
    results: list[CTExperimentResult]
) -> CTExperimentResult:
    """Combine per-seed results into a single result with aggregated metrics.

    The performance frames of every result are fed to a ``DFAggregator`` and
    replaced by its ``get_aggregate()`` output. ``config``, ``transforms`` and
    ``objects`` are taken from the first result; the other seeds' values for
    these fields are not inspected (presumably identical across seeds —
    NOTE(review): confirm).

    Args:
        results: Per-seed results to combine; must be non-empty.

    Returns:
        A new ``CTExperimentResult`` holding the aggregated performance frames.

    Raises:
        ValueError: If ``results`` is empty.
    """
    if not results:
        # Previously this surfaced as an opaque IndexError on results[0].
        raise ValueError("summarize() requires at least one result")

    in_dist_performances = DFAggregator()
    cross_transforms_performances = DFAggregator()
    for result in results:
        in_dist_performances.append_seed_result(
            result.in_dist_performance
        )
        cross_transforms_performances.append_seed_result(
            result.cross_transforms_performance
        )

    in_dist_aggregate = in_dist_performances.get_aggregate()
    cross_transforms_aggregate = cross_transforms_performances.get_aggregate()

    # Only the first seed's metadata was ever used, so read it directly
    # instead of collecting every seed's copy into throwaway lists.
    first = results[0]
    return CTExperimentResult(
        config=first.config,
        transforms=first.transforms,
        objects=first.objects,
        in_dist_performance=in_dist_aggregate,
        cross_transforms_performance=cross_transforms_aggregate,
    )

def show(
    result: CTExperimentResult,
) -> None:
    """Print an experiment's transforms/objects and plot its performance.

    Draws a bar chart of the ``"accuracy"`` column of
    ``result.in_dist_performance`` (one bar per index entry), then the
    cross-transforms heatmap via ``plot_transfer_performance``, and shows
    the figures.

    Args:
        result: The (typically summarized) experiment result to display.
    """
    print("Transforms:", result.transforms)
    print("Objects:", result.objects)

    _, idp_axes = plt.subplots(1, figsize=(6, 3), squeeze=False)
    idp_plot = idp_axes[0][0]
    idp_plot.set_title("Training (source-dataset) performance")
    sns.barplot(
        x=result.in_dist_performance.index,
        y=result.in_dist_performance["accuracy"],
        ax=idp_plot,
    )
    # Rotate the labels in place rather than reading them back and
    # re-assigning them as fixed text.
    idp_plot.tick_params(axis="x", labelrotation=80)

    # The returned figure is not needed here.
    plot_transfer_performance(result)
    plt.show()

def plot_transfer_performance(
    result: CTExperimentResult,
    with_title: bool = False,
    show_figure: bool = True,
) -> plt.Figure:
    """Plot ``result.cross_transforms_performance`` as an annotated heatmap.

    Columns of the frame are labelled "Datasets" (x axis) and rows "Models"
    (y axis); tick labels come from ``_get_transfer_axis_ticks``.

    Args:
        result: Experiment result whose ``cross_transforms_performance``
            frame is plotted.
        with_title: Whether to add a title to the axes.
        show_figure: Whether to call ``plt.show()`` before returning. Pass
            ``False`` to get the figure without displaying it (e.g. to save
            it to a file).

    Returns:
        The matplotlib figure containing the heatmap.
    """
    ood_fig, ood_axes = plt.subplots(1, figsize=(5, 4), squeeze=False)
    ood_plot = ood_axes[0][0]
    if with_title:
        ood_plot.set_title("Cross-transforms (transfer) performance (accuracy)")

    sns.heatmap(
        result.cross_transforms_performance,
        annot=True,
        ax=ood_plot,
    )
    x_ticks, y_ticks = _get_transfer_axis_ticks(
        result.cross_transforms_performance
    )
    ood_plot.set_xticklabels(x_ticks, rotation=45)
    ood_plot.set_yticklabels(y_ticks, rotation=0)
    ood_plot.set_xlabel("Datasets")
    ood_plot.set_ylabel("Models")

    ood_fig.tight_layout()
    if show_figure:
        plt.show()
    return ood_fig

def _get_transfer_axis_ticks(df: pd.DataFrame) -> tuple[list[str], list[str]]:
    x_ticks = [col_name[5:] for col_name in df.columns]
    y_ticks = [row_name[5:] for row_name in df.index]
    return x_ticks, y_ticks
