from dataclasses import dataclass
from typing import Optional, cast

from IPython.display import display, Markdown as md
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

from vis_analysis_utils.visualize.tables import TableFormatter
from utils import persistence
from utils.eval import (
    DFAggregator,
)
from .experiment import (
    IFEExperimentResult,
    EXP_NAME,
    CIFAR_VS_NO_CIFAR_COL,
    PATCH_VS_NO_PATCH_COL,
)
from .data import (
    AVAILABILITY_VALUES,
    CIFAR_ONLY_KEY,
    OBJECTS_ONLY_KEY,
    MIXED_OBJECTS_KEY,
    MIXED_CIFAR_KEY,
    MIXED_CIFAR_AVAILABILITY_KEY,
)

def load(
    config_name: str,
    seeds: list[tuple[int, int]],
    transforms: list[str],
) -> list[dict[str, IFEExperimentResult]]:
    """Load the stored experiment result for every (seed, transform) pair.

    Args:
        config_name: Experiment configuration name; combined with ``EXP_NAME``
            into the lookup path passed to ``persistence.load_experiment_result``.
        seeds: Seeds to load.
        transforms: Transform names to load for each seed.

    Returns:
        One dict per seed (in the order of ``seeds``), mapping each transform
        name to its loaded ``IFEExperimentResult``.
    """

    def _load_single(seed, transform_name):
        loaded = persistence.load_experiment_result(
            [EXP_NAME, config_name], seed, [transform_name],
        )
        # ``cast`` only informs the type checker; nothing is converted.
        return cast(IFEExperimentResult, loaded)

    return [
        {name: _load_single(seed, name) for name in transforms}
        for seed in seeds
    ]

def summarize(
    results: list[dict[str, IFEExperimentResult]],
) -> dict[str, IFEExperimentResult]:
    """Aggregate per-seed results into a single result per transform.

    Args:
        results: One dict per seed mapping transform name to that seed's
            ``IFEExperimentResult`` (the shape returned by ``load``).

    Returns:
        A dict mapping each transform name seen in ``results`` (in first-seen
        order) to an ``IFEExperimentResult`` whose performance tables are the
        ``DFAggregator.get_aggregate()`` of that transform's per-seed tables.
        ``cka_rep_dists`` is always ``None``. An empty input yields ``{}``.
    """
    seen_configs = []
    seen_objects = []
    in_dist_by_transform = {}
    transfer_by_transform = {}

    for per_seed in results:
        for name, seed_result in per_seed.items():
            seen_configs.append(seed_result.config)
            seen_objects.append(seed_result.objects)
            in_dist_by_transform.setdefault(name, DFAggregator()).append_seed_result(
                seed_result.in_dist_performance
            )
            transfer_by_transform.setdefault(name, DFAggregator()).append_seed_result(
                seed_result.transfer_performance
            )

    in_dist_summary = {
        name: aggregator.get_aggregate()
        for name, aggregator in in_dist_by_transform.items()
    }
    transfer_summary = {
        name: aggregator.get_aggregate()
        for name, aggregator in transfer_by_transform.items()
    }

    # NOTE(review): every transform reuses the config/objects of the very first
    # loaded result; confirm these do not vary across transforms or seeds.
    # Representation (CKA) distances are not aggregated, hence ``None``.
    summary = {}
    for name, in_dist in in_dist_summary.items():
        summary[name] = IFEExperimentResult(
            config=seen_configs[0],
            objects=seen_objects[0],
            in_dist_performance=in_dist,
            transfer_performance=transfer_summary[name],
            cka_rep_dists=None,
        )
    return summary


def show(
    result: IFEExperimentResult,
) -> None:
    """Display ``result``: its in-distribution table, then its transfer heatmap.

    Uses IPython's ``display`` for the table and shows the heatmap figure
    immediately via ``plt.show()``.
    """
    print("In distribution/training performance:")
    display(result.in_dist_performance)

    print("Transfer performance:")
    # Representation-distance output is currently disabled.
    plot_transfer_heatmap(result, True)
    plt.show()

def plot_transfer_heatmap(
    result: IFEExperimentResult,
    with_title: bool = False,
    full_results: bool = False,
) -> plt.Figure:
    """Draw ``result.transfer_performance`` as an annotated heatmap.

    Rows of the table are pre-training models (names carry an ``m_`` prefix)
    and columns are target datasets.

    Args:
        result: Experiment result providing the transfer table.
        with_title: Add a "Transfer Accuracy" title.
        full_results: Plot the whole table with its raw row/column names.
            Otherwise only the reference setups listed in
            ``TRANSFER_SUBSTITUTIONS`` are kept, relabelled with their
            ``X = ... / Y = ...`` descriptions.

    Returns:
        The newly created figure.
    """
    table = result.transfer_performance
    if full_results:
        fig_size = (5, 7)
    else:
        # Keep only the reference setups, in TRANSFER_SUBSTITUTIONS order;
        # rows are model names ("m_" + key), columns are dataset keys.
        keys = list(TRANSFER_SUBSTITUTIONS.keys())
        table = table.loc[[f"m_{key}" for key in keys], keys]
        fig_size = (4, 3)

    fig, axes = plt.subplots(1, figsize=fig_size, squeeze=False)
    ax = axes[0][0]
    if with_title:
        ax.set_title("Transfer Accuracy")

    x_labels, y_labels = (
        (table.columns, table.index)
        if full_results
        else _get_transfer_axis_labels(table)
    )
    sns.heatmap(
        table,
        annot=True,
        ax=ax,
        xticklabels=x_labels,
        yticklabels=y_labels,
    )
    ax.set_xlabel("Target Dataset")
    ax.set_ylabel("Pre-training Dataset")
    fig.tight_layout()
    return fig

def _get_transfer_axis_labels(df: pd.DataFrame) -> tuple[list[str], list[str]]:
    """Translate ``df``'s axis names into ``TRANSFER_SUBSTITUTIONS`` labels.

    Column names are looked up directly. Index names first lose their first
    two characters (the ``m_`` model prefix).

    Returns:
        ``(column_labels, row_labels)``.

    Raises:
        KeyError: if a name has no entry in ``TRANSFER_SUBSTITUTIONS``.
    """
    lookup = TRANSFER_SUBSTITUTIONS.__getitem__
    column_labels = list(map(lookup, df.columns))
    # Drop the two-character "m_" prefix from each model name before lookup.
    row_labels = [lookup(model_name[2:]) for model_name in df.index]
    return column_labels, row_labels

# Display labels for the reference setups shown by plot_transfer_heatmap.
# Keys are dataset keys. The heatmap uses them directly as columns and, with an
# "m_" prefix, as rows, in this dict's order, so the key order is significant.
# Values are two-line tick labels. "X" appears to describe the input content and
# "Y" the label source, with C = CIFAR and O = objects (cf. the paired legend
# names in _plot_label_type_performance). TODO confirm.
TRANSFER_SUBSTITUTIONS: dict[str, str] = {
    CIFAR_ONLY_KEY: "X = C\nY = C",
    # Same key as the first (correlation 0) entry of the correlation sweep rows.
    f"{MIXED_CIFAR_KEY}0": "X = C + O\nY = C",
    MIXED_OBJECTS_KEY: "X = C + O\nY = O",
    OBJECTS_ONLY_KEY: "X = O\nY = O",
}


@dataclass
class CorrelationPerformancePlots:
    """Figures produced by ``plot_correlation_performance``.

    The representation-distance figures are optional and default to ``None``.
    ``plot_correlation_performance`` does not currently produce them.
    """

    # Accuracy on the CIFAR target column vs. label correlation.
    cifar_target_perf: plt.Figure
    # Accuracy on the objects target column vs. label correlation.
    objects_target_perf: plt.Figure
    # Representation-sensitivity figures; ``None`` when not produced.
    cifar_rep_dists: Optional[plt.Figure] = None
    objects_rep_dists: Optional[plt.Figure] = None

# Label-correlation levels (alpha) of the correlation sweep, in plotting order.
# The first entry is the int 0, which renders as "0" in row names. That makes
# the first sweep row identical to f"m_{CIFAR_PERF_COL}"; a float 0.0 would
# render as "0.0" instead.
CORRELATION_VALUES = [0, 0.2, 0.4, 0.6, 0.8, 0.85, 0.9, 0.95, 1.0]
# Transfer-table row (model name) for each correlation level.
CORRELATION_ROWS = [f"m_{MIXED_CIFAR_KEY}{alpha}" for alpha in CORRELATION_VALUES]
# Target-dataset columns plotted as "CIFAR Accuracy" / "Objects Accuracy".
CIFAR_PERF_COL = f"{MIXED_CIFAR_KEY}0"
OBJECTS_PERF_COL = MIXED_OBJECTS_KEY
# Name of the swept-parameter column; also used as the x-axis label (mathtext).
CORRELATION_COL = "CIFAR-Object Label Correlation $\\alpha$"

def plot_correlation_performance(
    results: list[IFEExperimentResult],
) -> CorrelationPerformancePlots:
    """Plot transfer accuracy against the CIFAR-object label correlation.

    Args:
        results: Experiment results (presumably one per seed). Each
            ``transfer_performance`` table must contain the rows in
            ``CORRELATION_ROWS`` plus the reference-model rows used by
            ``_plot_label_type_performance``.

    Returns:
        ``CorrelationPerformancePlots`` holding the CIFAR- and objects-target
        accuracy figures. The representation-distance figures are not
        produced and are set to ``None``.
    """
    full_tables = [res.transfer_performance for res in results]
    sweep_tables = [table.loc[CORRELATION_ROWS] for table in full_tables]

    # (target column, y-axis label) for each figure, in drawing order.
    targets = (
        (CIFAR_PERF_COL, "CIFAR Accuracy"),
        (OBJECTS_PERF_COL, "Objects Accuracy"),
    )
    cifar_fig, objects_fig = [
        _plot_label_type_performance(
            sweep_tables,
            full_tables,
            CORRELATION_COL,
            CORRELATION_VALUES,
            perf_col,
            y_label,
        )
        for perf_col, y_label in targets
    ]

    # Representation-distance plots are currently disabled.
    return CorrelationPerformancePlots(
        cifar_target_perf=cifar_fig,
        objects_target_perf=objects_fig,
        cifar_rep_dists=None,
        objects_rep_dists=None,
    )


# Transfer-table rows for the availability sweep, in plotting order:
# the CIFAR-only model, the mixed model at each AVAILABILITY_VALUES level,
# then the mixed model whose name matches CIFAR_PERF_COL.
AVAILABILITY_ROWS = [
    f"m_{CIFAR_ONLY_KEY}",
    *(f"m_{MIXED_CIFAR_AVAILABILITY_KEY}{beta}" for beta in AVAILABILITY_VALUES),
    f"m_{CIFAR_PERF_COL}",
]
# x-positions matched positionally to AVAILABILITY_ROWS: 0.0 for the first
# (CIFAR-only) row and 1.0 for the last row.
FULL_AVAILABILITY_VALUES = [0.0] + AVAILABILITY_VALUES + [1.0]
# Name of the swept-parameter column; also used as the x-axis label (mathtext).
AVAILABILITY_COL = "Object Availability $\\beta$"

def plot_availability_performance(
    results: list[IFEExperimentResult],
) -> tuple[plt.Figure, plt.Figure]:
    """Plot transfer accuracy against object availability.

    Args:
        results: Experiment results (presumably one per seed). Each
            ``transfer_performance`` table must contain the rows in
            ``AVAILABILITY_ROWS`` plus the reference-model rows used by
            ``_plot_label_type_performance``.

    Returns:
        ``(cifar_figure, objects_figure)``: accuracy on the CIFAR and on the
        objects target columns. Only the CIFAR figure keeps its legend.
    """
    full_tables = [res.transfer_performance for res in results]
    sweep_tables = [table.loc[AVAILABILITY_ROWS] for table in full_tables]

    def _plot_target(perf_col, y_label, with_legend=False):
        # Both figures share the sweep data and x-axis; only the target differs.
        return _plot_label_type_performance(
            sweep_tables,
            full_tables,
            AVAILABILITY_COL,
            FULL_AVAILABILITY_VALUES,
            perf_col,
            y_label,
            with_legend=with_legend,
        )

    cifar_fig = _plot_target(CIFAR_PERF_COL, "CIFAR Accuracy", with_legend=True)
    objects_fig = _plot_target(OBJECTS_PERF_COL, "Objects Accuracy")
    return cifar_fig, objects_fig

def _plot_label_type_performance(
    results: list[pd.DataFrame],
    full_results: list[pd.DataFrame],
    target_col: str,
    target_values: list[float],
    perf_col: str,
    y_label: str,
    with_legend: bool = False,
) -> plt.Figure:
    """Plot one accuracy column against a swept training parameter.

    The swept models are drawn as a solid line; seaborn aggregates over the
    stacked per-run tables at each x value. Three reference models are drawn
    as dashed lines with a constant value per run, because their accuracy does
    not depend on the swept parameter.

    Args:
        results: Per-run transfer tables restricted to the swept models. Row
            ``i`` of each table is plotted at ``target_values[i]``, so every
            table must have exactly ``len(target_values)`` rows.
        full_results: Unrestricted per-run transfer tables. They must contain
            the ``m_``-prefixed rows of the CIFAR-only, mixed-object-label and
            objects-only models.
        target_col: Name given to the swept-parameter column; seaborn also
            uses it as the x-axis label.
        target_values: Swept-parameter value for each row of ``results``.
        perf_col: Transfer-table column (target dataset) to plot.
        y_label: y-axis label.
        with_legend: Keep the legend; otherwise it is removed.

    Returns:
        The newly created figure.
    """
    # Stack the per-run tables; repeated x values let seaborn aggregate runs.
    sweep_perf = pd.concat(
        [
            result[[perf_col]].assign(**{target_col: target_values})
            for result in results
        ],
        ignore_index=True,
    )

    perf_fig, perf_axes = plt.subplots(1, figsize=(3, 2.5), squeeze=False)
    perf_plot = perf_axes[0][0]
    sns.lineplot(
        data=sweep_perf,
        x=target_col,
        y=perf_col,
        label="mixed CIFAR labels",  # alt. notation: "X = C + O, Y = C"
        ax=perf_plot,
    )

    # Legend label -> transfer-table row of each reference model.
    reference_rows = {
        "CIFAR-only": f"m_{CIFAR_ONLY_KEY}",  # alt. notation: "X = C, Y = C"
        "mixed object labels": f"m_{MIXED_OBJECTS_KEY}",  # "X = C + O, Y = O"
        "objects-only": f"m_{OBJECTS_ONLY_KEY}",  # "X = O, Y = O"
    }
    const_results = {
        label: _get_const_res(
            full_results,
            target_col,
            target_values,
            model_row,
            perf_col,
        )
        for label, model_row in reference_rows.items()
    }
    for model_name, model_res in const_results.items():
        sns.lineplot(
            data=model_res,
            x=target_col,
            y=perf_col,
            label=model_name,
            ax=perf_plot,
            linestyle="dashed",
        )
    perf_plot.set_ylabel(y_label)
    if not with_legend:
        # ``legend()`` returns the axes' legend, which is then detached.
        perf_plot.legend().remove()

    perf_fig.tight_layout()
    return perf_fig

def _get_const_res(
    results: list[pd.DataFrame],
    target_col: str,
    target_values: list[float],
    model_name: str,
    col_name: str,
) -> pd.DataFrame:
    return pd.concat([
        pd.DataFrame({
            col_name: [res.loc[model_name, col_name]] * len(target_values)
        })
        .assign(**{target_col: target_values})
        for res in results
    ], ignore_index=True)
