from typing import cast

import pandas as pd
from IPython.display import display, Markdown as md
import matplotlib.pyplot as plt
import seaborn as sns

from vis_analysis_utils.visualize.tables import TableFormatter
from utils.eval import DFAggregator
from utils import persistence
from .experiment import (
    RIExperimentResult,
    EXP_NAME,
    REPRESENTATION_DISTANCE_KEY,
)
from .invariance_estimation import CIFAR_RESNET_18_LAYERS

# Layer names taken from element [0] of every CIFAR_RESNET_18_LAYERS entry,
# kept in that list's order (presumably the network's forward order; each
# entry is presumably a (name, ...) tuple — confirm in invariance_estimation).
LAYER_ORDER = [entry[0] for entry in CIFAR_RESNET_18_LAYERS]

def load(
    config_name: str,
    seeds: list[tuple[int, int]],
) -> list[RIExperimentResult]:
    """Load the stored representation distances for every seed of a run.

    Args:
        config_name: Name of the experiment configuration; combined with
            ``EXP_NAME`` to locate the stored results.
        seeds: Seeds to load; each one is passed as-is to
            ``persistence.load_experiment_result``.

    Returns:
        One ``RIExperimentResult`` per seed, in the order of ``seeds``. Only
        ``representation_distances`` is populated; ``config`` is ``None``
        because no configuration is loaded here.
    """
    seed_results: list[RIExperimentResult] = []
    for seed in seeds:
        # Only the representation-distance table is requested from storage.
        rep_dists = cast(
            pd.DataFrame,
            persistence.load_experiment_result(
                [EXP_NAME, config_name], seed, [REPRESENTATION_DISTANCE_KEY]
            ),
        )
        seed_results.append(
            RIExperimentResult(
                config=None,
                representation_distances=rep_dists,
            )
        )
    return seed_results

def _sort_by_layer_order(
    df: pd.DataFrame, layer_order: list[str],
) -> pd.DataFrame:
    """Return ``df`` with rows reordered so the "layer" index level follows ``layer_order``.

    Rows are moved, never relabelled, so every layer name stays attached to
    its own data. The other index levels are sorted normally and take
    precedence over the layer level. Layers missing from ``layer_order`` are
    placed last.

    Raises:
        KeyError: If the index has no level named "layer".
    """
    if "layer" not in df.index.names:
        raise KeyError("Level layer not found")
    rank = {name: pos for pos, name in enumerate(layer_order)}

    def sort_key(level_values: pd.Index) -> pd.Index:
        # For a MultiIndex, sort_index applies the key to each level
        # separately; only the layer level is mapped to its position.
        if level_values.name == "layer":
            return level_values.map(rank)
        return level_values

    return df.sort_index(key=sort_key)

def summarize(
    results: list[RIExperimentResult],
) -> RIExperimentResult:
    """Aggregate per-seed results into a single result.

    The per-seed ``representation_distances`` tables are combined with
    ``DFAggregator``. The aggregate's rows are then ordered so the "layer"
    index level follows ``LAYER_ORDER``.

    Args:
        results: Per-seed results; every one must carry
            ``representation_distances``.

    Returns:
        A result holding the aggregated distances and the config of the first
        seed result.

    Raises:
        ValueError: If ``results`` is empty or any result has no
            representation distances.
    """
    if not results:
        raise ValueError("summarize() needs at least one seed result")

    aggregator = DFAggregator()
    for seed_idx, result in enumerate(results):
        if result.representation_distances is None:
            raise ValueError(
                f"Seed result {seed_idx} has no representation distances"
            )
        aggregator.append_seed_result(result.representation_distances)

    # Reorder rows rather than calling index.set_levels(LAYER_ORDER, ...):
    # set_levels renames the level's stored labels by position. pandas usually
    # keeps those labels sorted, so renaming them would attach layer names to
    # the wrong data.
    aggregate = _sort_by_layer_order(aggregator.get_aggregate(), LAYER_ORDER)
    return RIExperimentResult(
        config=results[0].config,
        representation_distances=aggregate,
    )

def show(result: RIExperimentResult) -> None:
    """Print a heading, then render the representation-distance table.

    The table is rendered through ``TableFormatter``'s
    ``display_with_heatmap()`` followed by ``show()``.

    Raises:
        ValueError: If ``result`` has no representation distances.
    """
    rep_dists = result.representation_distances
    if rep_dists is None:
        raise ValueError("Result has no representation distances to show")

    print("Representation distances:")
    TableFormatter(rep_dists).display_with_heatmap().show()

# Column names of the frame produced by to_ratios().
TRAINING_RATIO_COL = "training_ratio"
HOLDOUT_RATIO_COL = "holdout_ratio"

def to_ratios(result: RIExperimentResult) -> RIExperimentResult:
    """Turn between/within class distances into between-within ratios.

    For the training and the holdout objects separately, the "between"
    distance column is divided element-wise by the matching "within" column.
    Under pandas division semantics, a zero within-distance gives ``inf``
    (``NaN`` for 0/0) rather than raising.

    Returns:
        A new result with the same config. Its ``representation_distances``
        keeps the original index and has the columns ``TRAINING_RATIO_COL``
        and ``HOLDOUT_RATIO_COL``.

    Raises:
        ValueError: If ``result`` has no representation distances.
    """
    rep_dists = result.representation_distances
    if rep_dists is None:
        raise ValueError("Result has no representation distances to convert")

    ratios = pd.DataFrame({
        TRAINING_RATIO_COL: (
            cast(pd.Series, rep_dists["between_training"])
            / cast(pd.Series, rep_dists["within_training"])
        ),
        HOLDOUT_RATIO_COL: (
            cast(pd.Series, rep_dists["between_holdout"])
            / cast(pd.Series, rep_dists["within_holdout"])
        ),
    })
    return RIExperimentResult(
        config=result.config,
        representation_distances=ratios,
    )

MODEL_COL = "Model"
DATASET_COL = "Dataset"
# Layer selected by plot_intra_vs_inter and plot_untrained_vs_trained.
# Alternative: TARGET_LAYER = "layer4"
TARGET_LAYER = "avgpool"
SELF_RATIO_COL = "Between-within class $l_2$-ratio"

def plot_intra_vs_inter(
    results: list[RIExperimentResult], training_obj: bool = True,
) -> sns.FacetGrid:
    """Bar-plot the between-within class ratio at ``TARGET_LAYER`` per model.

    Rows whose dataset is ``"none"`` are skipped. A leading ``"m_rw_"`` is
    stripped from model names for display.

    Args:
        results: Results whose ``representation_distances`` have a
            (model, dataset, layer) row index and the between/within columns
            expected by ``to_ratios``.
        training_obj: Plot the training-object ratio if True, otherwise the
            holdout-object ratio.

    Returns:
        The grid returned by ``sns.catplot`` (bars per model, coloured by
        dataset). The figure is also shown via ``plt.show()``.
    """
    ratio_col = TRAINING_RATIO_COL if training_obj else HOLDOUT_RATIO_COL

    # Collect plain records and build the frame once; concatenating inside
    # the loop would be quadratic.
    records = []
    for result in results:
        rep_ratios = to_ratios(result).representation_distances
        if rep_ratios is None:
            raise ValueError("to_ratios() returned no representation distances")
        for row_idx, row in rep_ratios.iterrows():
            model, dataset, layer = cast(tuple, row_idx)
            if layer != TARGET_LAYER or dataset == "none":
                continue
            records.append({
                MODEL_COL: model.removeprefix("m_rw_"),
                DATASET_COL: dataset,
                SELF_RATIO_COL: row[ratio_col],
            })
    ratios = pd.DataFrame(
        records, columns=[MODEL_COL, DATASET_COL, SELF_RATIO_COL],
    )

    grid = sns.catplot(
        data=ratios,
        x=MODEL_COL,
        y=SELF_RATIO_COL,
        hue=DATASET_COL,
        kind="bar",
        legend_out=False,
        height=3,
        aspect=4 / 3,
    )
    grid.tight_layout()
    plt.show()
    return grid

# Column names of the long-format frame built by plot_untrained_vs_trained.
# NOTE: RATIO_COL holds raw within/between distances there, not ratios.
RATIO_COL = "ratio_val"
RATIO_TYPE_COL = "ratio_type"

def plot_untrained_vs_trained(
    results: list[RIExperimentResult], training_obj: bool = True,
) -> sns.FacetGrid:
    """Bar-plot within- and between-class distances at ``TARGET_LAYER``.

    Each (model, dataset) row at ``TARGET_LAYER`` contributes two values: its
    within-class and its between-class distance. Rows where the dataset equals
    the model name (after stripping a leading ``"m_rw_"``) are tagged
    ``"within_target"``/``"between_target"``. All other rows are tagged
    ``"within_others"``/``"between_others"``. Several rows sharing a
    (dataset, tag) pair are combined by seaborn's bar estimator.

    NOTE(review): despite the function name, rows of the "untrained" model
    are skipped, so it is not part of the plot — confirm this is intended.

    Args:
        results: Results whose ``representation_distances`` have a
            (model, dataset, layer) row index and the within/between columns.
        training_obj: Use the training-object distances if True, otherwise
            the holdout-object distances.

    Returns:
        The grid returned by ``sns.catplot`` (bars per dataset, coloured by
        tag). The figure is also shown via ``plt.show()``.

    Raises:
        ValueError: If a result has no representation distances.
    """
    if training_obj:
        within_col, between_col = "within_training", "between_training"
    else:
        within_col, between_col = "within_holdout", "between_holdout"

    records = []
    for result in results:
        rep_dists = result.representation_distances
        if rep_dists is None:
            raise ValueError("Result has no representation distances to plot")
        for row_idx, row in rep_dists.iterrows():
            model, dataset, layer = cast(tuple, row_idx)
            model_name = model.removeprefix("m_rw_")
            if layer != TARGET_LAYER or model_name == "untrained":
                continue
            group = "target" if model_name == dataset else "others"
            records.append({
                RATIO_TYPE_COL: f"within_{group}",
                DATASET_COL: dataset,
                RATIO_COL: row[within_col],
            })
            records.append({
                RATIO_TYPE_COL: f"between_{group}",
                DATASET_COL: dataset,
                RATIO_COL: row[between_col],
            })
    # Build the frame once instead of concatenating per row (quadratic).
    ratios = pd.DataFrame(
        records, columns=[RATIO_TYPE_COL, DATASET_COL, RATIO_COL],
    )

    grid = sns.catplot(
        data=ratios,
        x=DATASET_COL,
        y=RATIO_COL,
        hue=RATIO_TYPE_COL,
        kind="bar",
        legend_out=False,
        height=4,
        aspect=4 / 3,
    )
    grid.tight_layout()
    plt.show()
    return grid

LAYER_COL = "Layer"

def plot_layerwise_ratios(
    results: list[RIExperimentResult], training_obj: bool = True,
) -> plt.Figure:
    """Line-plot each model's between-within class ratio across layers.

    Only rows whose dataset equals the model name (after stripping a leading
    ``"m_rw_"``) are used, so each model is shown on its own dataset. Layers
    are taken in the order the rows are encountered.

    Args:
        results: Results whose ``representation_distances`` have a
            (model, dataset, layer) row index and the between/within columns
            expected by ``to_ratios``.
        training_obj: Plot the training-object ratio if True, otherwise the
            holdout-object ratio.

    Returns:
        The matplotlib figure holding the single line plot. It is also shown
        via ``plt.show()``.
    """
    ratio_col = TRAINING_RATIO_COL if training_obj else HOLDOUT_RATIO_COL

    records = []
    for result in results:
        rep_ratios = to_ratios(result).representation_distances
        if rep_ratios is None:
            raise ValueError("to_ratios() returned no representation distances")
        for row_idx, row in rep_ratios.iterrows():
            model, dataset, layer = cast(tuple, row_idx)
            model_name = model.removeprefix("m_rw_")
            if dataset != model_name:
                continue
            records.append({
                MODEL_COL: model_name,
                LAYER_COL: layer,
                SELF_RATIO_COL: row[ratio_col],
            })
    # Build the frame once instead of concatenating per row (quadratic).
    ratios = pd.DataFrame(
        records, columns=[MODEL_COL, LAYER_COL, SELF_RATIO_COL],
    )

    fig, ax = plt.subplots(figsize=(4, 3))
    sns.lineplot(
        data=ratios,
        x=LAYER_COL,
        y=SELF_RATIO_COL,
        hue=MODEL_COL,
        style=MODEL_COL,
        markers=True,
        # seaborn's target-axes parameter is `ax`; an `axes=` keyword is not
        # recognised and would be forwarded to matplotlib's plot() instead.
        ax=ax,
    )
    fig.tight_layout()
    plt.show()
    return fig
