import pandas as pd
import numpy as np
from scipy.stats.mstats import gmean
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgba
import argparse
import os
import logging


# Root directory (relative to the working directory) holding one sub-directory per benchmarking run.
STORAGE_LOCATION = "artifacts/results"
# Font family passed explicitly to every title, axis label and tick label in the figures.
FONT_FAMILY = 'Times New Roman'

# Globally select the serif family and make Times New Roman the preferred serif face.
plt.rcParams.update({'font.family': 'serif', 'font.serif': [FONT_FAMILY]})


def get_results_table_clean(
    result_file_name: str, train_mse: bool = True
) -> pd.DataFrame:
    """
    Build and save the reconstruction-MSE tables of a benchmarking run.

    Reads ``{STORAGE_LOCATION}/{result_file_name}/results.csv``. For every
    ``(model_name, dataset)`` pair it computes the mean and the variance (across
    runs) of ``final_epoch_test_mse`` and, if ``train_mse`` is set, also of
    ``final_epoch_train_mse``. In every table the dataset is on the vertical
    axis; the columns are a two-level index of (model, MSE column).

    Three tables are written as CSV and as LaTeX to
    ``./{STORAGE_LOCATION}/{result_file_name}/paper/``:

    * ``mean`` -- mean MSE per dataset and model,
    * ``mean_over_baseline`` -- each model's mean divided by the ``PCA`` mean
      of the same dataset,
    * ``variance`` -- variance of the MSE per dataset and model.

    Args:
        result_file_name: Name of the run directory inside ``STORAGE_LOCATION``.
        train_mse: Whether to include the train MSE next to the test MSE.

    Returns:
        The table of mean MSEs (datasets x (model, MSE column)).

    Raises:
        KeyError: If the results contain no ``PCA`` model to normalize by.
    """
    df = pd.read_csv(f"{STORAGE_LOCATION}/{result_file_name}/results.csv", header=0)
    value_columns = ["final_epoch_test_mse"]
    if train_mse:
        value_columns.append("final_epoch_train_mse")
    df = df[["model_name", "dataset", *value_columns]]
    agg_df = df.groupby(["model_name", "dataset"])[value_columns].agg(["mean", "var"])

    def _to_wide(stat: str) -> pd.DataFrame:
        """Pivot one aggregated statistic into a datasets x (model, value column) table."""
        return (
            agg_df.xs(stat, level=1, axis=1)
            .reset_index()
            .pivot(columns=["model_name"], index=["dataset"], values=value_columns)
            .swaplevel(axis=1)
            .sort_index(axis=1)
        )

    mean_df = _to_wide("mean")
    variance_df = _to_wide("var")

    # Normalize every model by the PCA baseline, dataset by dataset. A model
    # label appears once per value column in level 0, hence unique().
    mean_over_baseline = pd.DataFrame(columns=mean_df.columns)
    for model_col in mean_df.columns.get_level_values(0).unique():
        normalized_df = mean_df[(model_col,)] / mean_df[("PCA",)]
        for inner_col in normalized_df:
            mean_over_baseline[(model_col, inner_col)] = normalized_df[inner_col]

    # Now save the paper results.
    paper_dir = os.path.join(f"./{STORAGE_LOCATION}", result_file_name, "paper")
    os.makedirs(paper_dir, exist_ok=True)

    tables = {
        "mean": mean_df,
        "mean_over_baseline": mean_over_baseline,
        "variance": variance_df,
    }
    # Write all CSV files before any LaTeX, so the CSVs (mean_over_baseline.csv
    # is read by get_result_plots) exist even if the LaTeX export fails.
    for stem, table in tables.items():
        table.to_csv(os.path.join(paper_dir, f"{stem}.csv"))
    for stem, table in tables.items():
        with open(
            os.path.join(paper_dir, f"{stem}.tex"), "w", encoding="utf-8"
        ) as file:
            file.write(table.to_latex())

    logging.info(
        f"Saved the tables for mean, variances and the corresponding LaTex code to {STORAGE_LOCATION}/{result_file_name}/paper!"
    )
    return mean_df


def get_result_plots(result_file_name: str) -> None:
    """
    Plot each model's aggregated test MSE, normalized by the KernelPCA baseline.

    Reads ``paper/mean_over_baseline.csv`` of the run (as written by
    ``get_results_table_clean``). It takes the geometric mean of every model's
    normalized test MSE across datasets and saves a horizontal bar chart as
    ``test_mse_agg.svg`` and ``test_mse_agg.jpg`` to
    ``./{STORAGE_LOCATION}/{result_file_name}/paper/plots/``.

    The "default" matplotlib style is applied only while this figure is built.
    The serif font configured at module level is re-applied on top of it. The
    global rc state is left untouched, so later plots are unaffected.

    Args:
        result_file_name: Name of the run directory inside ``STORAGE_LOCATION``.
    """
    paper_dir = f"./{STORAGE_LOCATION}/{result_file_name}/paper"

    # 1. Load the per-dataset MSEs normalized by the PCA baseline.
    normalized_df = pd.read_csv(
        f"{paper_dir}/mean_over_baseline.csv",
        header=[0, 1],
        index_col=0,
    )

    # Aggregate by geometric mean across datasets (one value per column).
    agg_df = pd.DataFrame(normalized_df.agg(gmean, axis=0), columns=["Value"])

    # Keep the test MSE only. reset_index() exposes the two column levels as
    # "model_name" and "level_1" (the second level is unnamed).
    reset_df = agg_df.reset_index()
    reset_df = (
        reset_df.loc[reset_df["level_1"] == "final_epoch_test_mse"]
        .drop(columns="level_1")
        .copy()  # explicit copy: the frame is modified below
    )
    # Display names used in the paper.
    reset_df["model_name"] = reset_df["model_name"].replace(
        {"PCA": "KernelPCA", "JointVAE": "VAE"}
    )
    reset_df = reset_df.set_index("model_name").sort_values("Value", ascending=False)

    serif_font = {"font.family": "serif", "font.serif": [FONT_FAMILY]}
    with plt.style.context("default"), plt.rc_context(serif_font):
        fig = plt.figure(figsize=(10, 3), dpi=100)
        # Bar height 0.6; earlier variants used 0.8 (FinalResults) and 0.5 (CAE comparison).
        bars = plt.barh(reset_df.index, reset_df["Value"], height=0.6, color="skyblue")

        # Annotate each bar with its value to three decimal places.
        for bar in bars:
            width = bar.get_width()
            plt.text(
                width * 1.02,
                bar.get_y() + bar.get_height() / 2,
                f"{width:.3f}",
                ha="left",
                va="center",
                fontsize=10,
                fontweight="bold",
            )

        # Labels and title with larger, bold font.
        text_style = {"fontweight": 600, "fontfamily": FONT_FAMILY}
        plt.xlabel("Agg. Test MSE Normalized by KernelPCA", fontsize=14, **text_style)
        plt.ylabel("Model Name", fontsize=14, **text_style)
        plt.title("Test Reconstruction MSE", fontsize=18, **text_style)

        # Log-scaled x axis with ten ticks spanning the enclosing powers of ten.
        plt.xscale("log")
        ticks = np.logspace(
            np.floor(np.log10(reset_df["Value"].min())),
            np.ceil(np.log10(reset_df["Value"].max())),
            num=10,
        )
        plt.xticks(ticks)

        # Additional styling.
        plt.grid(True, axis="x", linestyle="--", alpha=0.7, color="w")
        plt.grid(False, axis="y")
        plt.xticks(fontsize=12, fontfamily=FONT_FAMILY)
        plt.yticks(fontsize=12, fontfamily=FONT_FAMILY)

        # Limit the y range to the bars so they sit closer to the center.
        plt.ylim(-0.5, len(reset_df) - 0.5)
        plt.tight_layout()

        plots_dir = f"{paper_dir}/plots"
        os.makedirs(plots_dir, exist_ok=True)

        # Save the figure (it is not displayed), then release it.
        plt.savefig(f"{plots_dir}/test_mse_agg.svg")
        plt.savefig(f"{plots_dir}/test_mse_agg.jpg")
        plt.close(fig)


def get_downstream_results_table_clean(result_file_name: str) -> pd.DataFrame:
    """
    Build and save the downstream-performance tables of a benchmarking run.

    Reads ``{STORAGE_LOCATION}/{result_file_name}/downstream_results.csv``. In
    that table the datasets are the rows and the embedding methods are the
    columns. Each cell is the string representation of a dictionary mapping
    metric names to metric values.

    Each dataset is assigned a task type from the exact set of metric names it
    reports:

    * classification: Recall, F1-Score, Precision, Accuracy;
    * regression: MAE, RMSE.

    For each task type, a datasets x (method, metric) table is built. A metric
    missing for one method becomes NaN. The table is then normalized by the
    ``RawData`` method.

    Written as CSV and LaTeX to ``./{STORAGE_LOCATION}/{result_file_name}/paper/``:
    ``raw_performance`` (the input table),
    ``classification_performance_over_baseline`` and
    ``regression_performance_over_baseline``.

    Args:
        result_file_name: Name of the run directory inside ``STORAGE_LOCATION``.

    Returns:
        The raw downstream table (datasets x methods; cells are the metric
        dictionaries as strings).

    Raises:
        ValueError: If an unknown metric name appears, or if a dataset reports
            a metric set that is neither exactly the classification set nor
            exactly the regression set.
        KeyError: If there is no ``RawData`` method to normalize by.
    """
    df = pd.read_csv(
        f"{STORAGE_LOCATION}/{result_file_name}/downstream_results.csv",
        header=0,
        index_col=0,
    )

    ### Clean and re-organize downstream results.
    df.index.name = "Dataset"
    df.rename_axis("Method", axis="columns", inplace=True)

    classification_metric_names = {"Recall", "F1-Score", "Precision", "Accuracy"}
    regression_metric_names = {"MAE", "RMSE"}
    known_metric_names = classification_metric_names | regression_metric_names

    # Parse every cell exactly once.
    # NOTE(review): eval executes arbitrary code from the CSV. That is only
    # acceptable because the file is produced by our own benchmarking run.
    # ast.literal_eval would be safer, but it fails if the dictionaries were
    # serialized with reprs such as np.float64(...). Confirm the on-disk
    # format before switching.
    parsed_rows = [
        (dataset_name, {method: eval(cell) for method, cell in row.items()})
        for dataset_name, row in df.iterrows()
    ]

    all_metric_names = set()
    for _, row_metrics in parsed_rows:
        for metrics in row_metrics.values():
            all_metric_names.update(metrics)
    unrecognized = all_metric_names - known_metric_names
    if unrecognized:
        raise ValueError(
            f"The following metrics were found in the downstream results but not recognized: {unrecognized}"
        )

    # Split the datasets into classification and regression tasks by their metrics.
    classification_sets, classification_rows = [], []
    regression_sets, regression_rows = [], []
    for dataset_name, row_metrics in parsed_rows:
        metrics = set().union(*row_metrics.values())
        if metrics == classification_metric_names:
            metric_names = classification_metric_names
            target_sets, target_rows = classification_sets, classification_rows
        elif metrics == regression_metric_names:
            metric_names = regression_metric_names
            target_sets, target_rows = regression_sets, regression_rows
        else:
            raise ValueError(f"The metrics {metrics} are neither the same as the classification metrics, nor the same as the regression metrics.")
        target_sets.append(dataset_name)
        # A metric missing for one method becomes NaN, keeping all rows aligned.
        target_rows.append(
            {
                (method, metric_name): method_metrics.get(metric_name, np.nan)
                for method, method_metrics in row_metrics.items()
                for metric_name in metric_names
            }
        )

    def _build_table(rows, index, metric_names) -> pd.DataFrame:
        """Assemble a datasets x (method, metric) table sorted on both axes."""
        # list(df.columns) drops the "Method" axis name, so the levels stay unnamed.
        columns = pd.MultiIndex.from_product([list(df.columns), sorted(metric_names)])
        table = pd.DataFrame(rows, index=index, columns=columns)
        table = table.sort_index(axis=1, level=[0, 1], sort_remaining=True)
        return table.sort_index()

    def _normalize_by_raw_data(table: pd.DataFrame) -> pd.DataFrame:
        """Divide every method's metrics by the RawData metrics of the same dataset."""
        normalized = pd.DataFrame(columns=table.columns)
        # A method label appears once per metric in level 0, hence unique().
        for model_col in table.columns.get_level_values(0).unique():
            ratio = table[(model_col,)] / table[("RawData",)]
            for inner_col in ratio:
                normalized[(model_col, inner_col)] = ratio[inner_col]
        return normalized

    classification_performance_df = _build_table(
        classification_rows, classification_sets, classification_metric_names
    )
    regression_performance_df = _build_table(
        regression_rows, regression_sets, regression_metric_names
    )

    ### Now prepare normalized tables.
    classification_performance_over_baseline = _normalize_by_raw_data(
        classification_performance_df
    )
    regression_performance_over_baseline = _normalize_by_raw_data(
        regression_performance_df
    )

    # Now save the paper results.
    paper_dir = os.path.join(f"./{STORAGE_LOCATION}", result_file_name, "paper")
    os.makedirs(paper_dir, exist_ok=True)

    tables = {
        "raw_performance": df,
        "classification_performance_over_baseline": classification_performance_over_baseline,
        "regression_performance_over_baseline": regression_performance_over_baseline,
    }
    # Write all CSV files before any LaTeX, so the CSVs read by
    # get_downstream_results_plots exist even if the LaTeX export fails.
    for stem, table in tables.items():
        table.to_csv(os.path.join(paper_dir, f"{stem}.csv"))
    for stem, table in tables.items():
        with open(
            os.path.join(paper_dir, f"{stem}.tex"), "w", encoding="utf-8"
        ) as file:
            file.write(table.to_latex())

    logging.info(
        f"Saved the tables for the raw and normalized downstream performance and the corresponding LaTex code to {STORAGE_LOCATION}/{result_file_name}/paper!"
    )
    return df


def _plot_downstream_metric_bars(
    agg_df: pd.DataFrame,
    model_order: tuple,
    title: str,
    figsize: tuple,
    group_width: float,
    bar_height: float,
    label_offset: float,
    plots_dir: str,
    output_stem: str,
) -> None:
    """
    Draw one grouped horizontal bar chart (models x metrics) and save it as SVG and JPG.

    Args:
        agg_df: Single "Value" column indexed by (model, metric).
        model_order: Display names of the models to plot, top to bottom. Models
            not listed are dropped.
        title: Figure title.
        figsize: Matplotlib figure size in inches.
        group_width: ``width`` passed to ``DataFrame.plot`` for each model group.
        bar_height: Height each bar is set to after plotting.
        label_offset: Factor applied to a bar's width to place its value label.
        plots_dir: Directory the figures are written to.
        output_stem: File name without extension.
    """
    agg_df = agg_df.rename(
        index={"JointVAE": "VAE", "PCA": "KernelPCA", "RawData": "Raw Data"}
    )
    # barh draws the first index entry at the bottom of the axis. Reverse the
    # order so the first model of model_order ends up at the top.
    agg_df = agg_df.reindex(list(reversed(model_order)), level=0)

    # Reshape to one row per model and one column per metric.
    unstacked = agg_df.unstack(level=1)
    unstacked.columns = unstacked.columns.droplevel()

    # NOTE(review): pandas' ``style`` is a per-column line-style spec, so
    # style="default" presumably does not select the "default" stylesheet.
    # A stylesheet would need plt.style.use / plt.style.context. Confirm.
    ax = unstacked.plot(
        kind="barh", figsize=figsize, style="default", logx=True, width=group_width
    )

    base_color = to_rgba("skyblue")
    for i, bars in enumerate(ax.containers):
        # Each metric gets a progressively lighter shade of the base color.
        color = tuple(
            np.clip(np.array(base_color[:3]) * (0.4 + 0.2 * i), 0, 1)
        ) + (base_color[3],)
        for bar in bars:
            bar.set_color(color)
            # Thinner bars leave a small gap between the bars of a group.
            bar.set_height(bar_height)
            # Annotate with the value to three decimal places.
            width = bar.get_width()
            ax.text(
                width * label_offset,
                bar.get_y() + bar.get_height() / 2,
                f"{width:.3f}",
                ha="left",
                va="center",
                fontsize=10,
                fontweight="bold",
            )

    # Labels and title with larger, bold font.
    text_style = {"fontweight": 600, "fontfamily": FONT_FAMILY}
    plt.xlabel("Metric Value", fontsize=14, **text_style)
    plt.ylabel("Embedding Model", fontsize=14, **text_style)
    plt.title(title, fontsize=18, **text_style)

    # Additional styling.
    plt.grid(True, axis="x", linestyle="--", alpha=0.7, color="w")
    plt.grid(False, axis="y")
    plt.xticks(fontsize=12, fontfamily=FONT_FAMILY)
    plt.yticks(fontsize=12, fontfamily=FONT_FAMILY)
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(
        handles[::-1],  # Reverse the legend to match the top-to-bottom bar order.
        labels[::-1],
        title="Metrics",
        fontsize=12,
        title_fontsize=14,
        loc="upper center",
        bbox_to_anchor=(0.5, -0.12),
        ncol=4,
    )
    plt.tight_layout()

    # Save the figure (it is not displayed), then release it.
    plt.savefig(f"{plots_dir}/{output_stem}.svg")
    plt.savefig(f"{plots_dir}/{output_stem}.jpg")
    plt.close(ax.figure)


def get_downstream_results_plots(
    result_file_name: str,
    model_order: tuple = ("DeepCAE", "StackedCAE", "KernelPCA"),
) -> None:
    """
    Plot the aggregated, RawData-normalized downstream performance per model.

    Reads ``paper/classification_performance_over_baseline.csv`` and
    ``paper/regression_performance_over_baseline.csv`` of the run. It
    aggregates every (model, metric) column with the geometric mean across
    datasets and saves one grouped bar chart per task type.
    The charts are ``classification_downstream_performance`` and
    ``regression_downstream_performance`` (SVG and JPG), written to
    ``./{STORAGE_LOCATION}/{result_file_name}/paper/plots/``.

    Args:
        result_file_name: Name of the run directory inside ``STORAGE_LOCATION``.
        model_order: Display names of the models to plot, top to bottom (after
            renaming PCA -> KernelPCA, JointVAE -> VAE, RawData -> Raw Data).
            Models not listed are omitted. Another ordering used previously:
            ("DeepCAE", "StandardAE", "ConvAE", "VAE", "KernelPCA", "TransformerAE").
    """
    paper_dir = f"./{STORAGE_LOCATION}/{result_file_name}/paper"

    # 1. Load the normalized results.
    regression_normalized_df = pd.read_csv(
        f"{paper_dir}/regression_performance_over_baseline.csv",
        header=[0, 1],
        index_col=0,
    )
    classification_normalized_df = pd.read_csv(
        f"{paper_dir}/classification_performance_over_baseline.csv",
        header=[0, 1],
        index_col=0,
    )

    # 2. Aggregate each (model, metric) column by geometric mean across datasets.
    regression_agg_df = pd.DataFrame(
        regression_normalized_df.agg(gmean, axis=0), columns=["Value"]
    )
    classification_agg_df = pd.DataFrame(
        classification_normalized_df.agg(gmean, axis=0), columns=["Value"]
    )
    logging.debug("Aggregated regression results:\n%s", regression_agg_df.head(10))
    logging.debug(
        "Aggregated classification results:\n%s", classification_agg_df.head(25)
    )

    plots_dir = f"{paper_dir}/plots"
    os.makedirs(plots_dir, exist_ok=True)

    # 3. Plot classification and regression problems.
    _plot_downstream_metric_bars(
        classification_agg_df,
        model_order,
        title="Classification Downstream Performance with Embeddings",
        figsize=(12, 8),
        group_width=0.8,
        bar_height=0.15,
        label_offset=1.0004,  # 1.002 was used for FinalResults
        plots_dir=plots_dir,
        output_stem="classification_downstream_performance",
    )
    _plot_downstream_metric_bars(
        regression_agg_df,
        model_order,
        title="Regression Downstream Performance with Embeddings",
        figsize=(12, 6),
        group_width=0.6,
        bar_height=0.25,
        label_offset=1.001,
        plots_dir=plots_dir,
        output_stem="regression_downstream_performance",
    )


def parse_bool_arg(value: str) -> bool:
    """
    Convert a command-line string into a bool for use as an argparse ``type=``.

    ``type=bool`` does not work for this: ``bool("False")`` is True, so any
    non-empty value (including "False" or "0") would switch a flag on.

    Accepted values (case-insensitive, surrounding whitespace ignored):

    * true / t / yes / y / 1;
    * false / f / no / n / 0.

    The empty string maps to False, as it did with ``type=bool``.

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognized boolean.
    """
    normalized = value.strip().lower()
    if normalized in {"true", "t", "yes", "y", "1"}:
        return True
    if normalized in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


if __name__ == "__main__":
    # The root logger defaults to WARNING, which would drop every logging.info
    # message this script emits.
    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser()
    parser.add_argument("--tables", type=parse_bool_arg, default=True)
    parser.add_argument("--plots", type=parse_bool_arg, default=True)
    parser.add_argument("--tables_downstream", type=parse_bool_arg, default=True)
    parser.add_argument("--plots_downstream", type=parse_bool_arg, default=True)

    parser.add_argument("--train_mse", type=parse_bool_arg, default=True)
    parser.add_argument(
        "--result_file_name",
        type=str,
        required=True,
        help="The name of the directory that contains the results, being the date of the benchmarking run without the .csv ending.",
    )

    args = parser.parse_args()

    if args.tables:
        get_results_table_clean(args.result_file_name, train_mse=args.train_mse)

    if args.plots:
        get_result_plots(args.result_file_name)

    if args.tables_downstream:
        get_downstream_results_table_clean(args.result_file_name)

    if args.plots_downstream:
        get_downstream_results_plots(args.result_file_name)

    logging.info("Finished extracting the results!")
