import copy
import glob
import itertools
import os
import re
from typing import Tuple, Sequence, Dict, Optional, List

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.stats import ttest_rel
from scipy.stats.mstats_basic import Ttest_relResult

from extensions.rl_lighthouse.lighthouse_scripts.summarize_lr_optimization import unzip
from extensions.rl_lighthouse.lighthouse_scripts.summarize_pairwise_imitation_data import (
    set_size,
)
from extensions.rl_minigrid.minigrid_scripts.summarize_minigrid_data import (
    run_info_to_pretty_label,
    METHOD_ORDER,
    METHOD_TO_COLOR,
    METHOD_TO_LINE_MARKER,
)
from main import _get_args
from utils.misc_utils import (
    bootstrap_max_of_subset_statistic,
    expected_max_of_subset_statistic,
)

# Global matplotlib styling applied to every figure this script produces:
# a 16pt serif (CMU) font, 12pt tick labels, and LaTeX text rendering with
# the `amsmath` package loaded in the preamble.
plt.rc("font", family="serif", serif=["CMU"], size=16)

plt.rc("xtick", labelsize=12)
plt.rc("ytick", labelsize=12)

plt.rc("text", usetex=True)
plt.rc("text.latex", preamble=r"\usepackage{amsmath}")

# Display labels for metric columns, keyed by the metric's column name.
METRIC_TO_LABEL = dict(
    reward="Reward",
    avg_ep_length="Avg. Ep. Length",
    success="Success",
)


def _extract_float_hyperparam(params, queries):
    """Search `params` for a numeric hyperparameter matched by `queries`.

    Every string in `params` is tested against every regex in `queries`; the
    first capture group of a match is the candidate value. Non-string entries
    (for example a NaN produced by an empty TSV cell) are skipped instead of
    raising inside `re.search`.

    Returns a `(value, optimized)` pair where `value` is the last capture that
    parsed as a float in the first parameter that yielded one (`None` if no
    capture parsed), and `optimized` is `True` if any inspected capture
    contained the substring "optimal".
    """
    value = None
    optimized = False
    for param in params:
        if not isinstance(param, str):
            continue
        for query in queries:
            match = re.search(query, param)
            if match is None:
                continue
            captured = match.group(1)
            optimized = optimized or "optimal" in captured
            try:
                value = float(captured)
            except ValueError:
                # Non-numeric capture such as "optimal": only the flag is set.
                pass
        if value is not None:
            # Stop at the first parameter that produced a numeric value.
            break
    return value, optimized


def add_columns_to_df(df):
    """Add parsed hyperparameter columns and a display label to `df`.

    The frame is modified in place and also returned. Rows are addressed with
    `df.loc[i, ...]` for `i in range(len(df))`, so the frame must carry the
    default 0..n-1 index.

    Columns written:
      * `alpha_start`, `alpha_stop`, `fixed_alpha`: parsed from `gp_params`
        (stored as `int` when the value is integral, else `float`);
      * `lr`, `tf_ratio`: parsed from `gp_params` or `extra_tag`;
      * `lr_optimized`, `tf_optimized`: whether a matching entry contained
        the text "optimal";
      * `optimized`: logical OR of the two flags, reset to `False` for every
        row when it would be `True` for every row;
      * `pretty_label`: result of `run_info_to_pretty_label` on the row.

    String-valued `gp_params` cells are replaced by their evaluated value.
    """
    alpha_queries = {
        "alpha_start": r"hyperparams.anneal_alpha_start = (.*)",
        "alpha_stop": r"hyperparams.anneal_alpha_stop = (.*)",
        "fixed_alpha": r"hyperparams.fixed_alpha = (.*)",
    }
    lr_queries = [
        r"hyperparams.lr = (.*)",
        r"lr_([^_].*)",
    ]
    tf_queries = [
        r"hyperparams.tf_ratio = (.*)",
        r"tf_([^_].*)",
    ]

    for key in itertools.chain(
        alpha_queries,
        ["lr", "tf_ratio", "lr_optimized", "tf_optimized", "pretty_label"],
    ):
        df[key] = [None] * df.shape[0]

    # NOTE(security): `eval` executes arbitrary code stored in the table. This
    # is only acceptable for trusted, locally generated result files; consider
    # `ast.literal_eval` if the cells are guaranteed to be plain literals.
    df.loc[:, "gp_params"] = [
        gps if not isinstance(gps, str) else eval(gps) for gps in df.loc[:, "gp_params"]
    ]

    for i in range(df.shape[0]):
        row = df.loc[i, :]

        # Missing gin parameters (NaN / None / "None") behave like an empty tuple.
        gp_params = row["gp_params"]
        if (
            (isinstance(gp_params, float) and np.isnan(gp_params))
            or gp_params is None
            or gp_params == "None"
        ):
            gp_params = tuple()

        # ALPHA KEYS: the first gp_param matching each query wins.
        for alpha_query_key, alpha_query in alpha_queries.items():
            for gp_param in gp_params:
                match = re.search(alpha_query, gp_param)
                if match is None:
                    continue
                value = float(match.group(1))
                # `is_integer` is False for inf/nan, where `int(value)` would raise.
                if value.is_integer():
                    value = int(value)
                df.loc[i, alpha_query_key] = value
                break

        candidates = list(itertools.chain(gp_params, [row["extra_tag"]]))

        # LR
        lr, lr_optimized = _extract_float_hyperparam(candidates, lr_queries)
        df.loc[i, "lr_optimized"] = lr_optimized
        if lr is not None:
            df.loc[i, "lr"] = lr

        # TF: the search stops once `tf_ratio` itself has been found.
        tf_ratio, tf_optimized = _extract_float_hyperparam(candidates, tf_queries)
        df.loc[i, "tf_optimized"] = tf_optimized
        if tf_ratio is not None:
            df.loc[i, "tf_ratio"] = tf_ratio

    df.loc[:, "optimized"] = np.logical_or(
        df.loc[:, "lr_optimized"], df.loc[:, "tf_optimized"]
    )

    if df["optimized"].all():
        # If everything is optimized then nothing is
        df["optimized"] = False

    for i in range(df.shape[0]):
        df.loc[i, "pretty_label"] = run_info_to_pretty_label(dict(df.loc[i, :]))

    return df


def all_equal(s: Sequence):
    """Return True when every element of `s` equals its first element.

    Sequences with zero or one element are trivially all-equal.
    """
    if len(s) <= 1:
        return True
    for candidate in s[1:]:
        if not (s[0] == candidate):
            return False
    return True


def plot_max_hp_curves(
    subset_size_to_expected_mas_est_list: List[Dict[int, float]],
    subset_size_to_bootstrap_points_list: Sequence[Dict[int, Sequence[float]]],
    method_labels: Sequence[str],
    colors: Sequence[Tuple[int, int, int]],
    line_styles: Optional[Sequence] = None,
    line_markers: Optional[Sequence] = None,
    title: str = "",
    ylabel: str = "",
    fig_size=(4, 4 * 3.0 / 5.0),
    save_path: Optional[str] = None,
    put_legend_outside: bool = True,
    include_legend: bool = False,
):
    """Plot one "expected max vs. number of hyperparameter evaluations" curve per method.

    For each method, the solid curve shows the expected-max estimate for every
    subset size and the shaded band spans the 25th to 75th percentile of the
    bootstrap points at that subset size. The per-method inputs are consumed
    in lockstep with `zip`, so extra entries in longer sequences are ignored.

    Args:
        subset_size_to_expected_mas_est_list: per method, a mapping from
            subset size to the value drawn as the curve.
        subset_size_to_bootstrap_points_list: per method, a mapping from
            subset size to the bootstrap samples used for the shaded band. Its
            keys define the x values; the curve mapping must contain them all.
        method_labels: legend label per method (prefixed with its 1-based rank).
        colors: per-method RGB tuples on a 0-255 scale.
        line_styles: per-method line styles; defaults to "solid" everywhere.
        line_markers: per-method markers; defaults to no marker.
        title: figure title.
        ylabel: y-axis label.
        fig_size: size forwarded to `set_size`.
        save_path: where to save the figure; when `None` the figure is shown.
        put_legend_outside: place the legend to the right of the axes.
        include_legend: whether to draw a legend at all.
    """
    line_styles = ["solid"] * len(colors) if line_styles is None else line_styles
    line_markers = [""] * len(colors) if line_markers is None else line_markers

    # The visibility flag is passed positionally: the `b=` keyword was renamed
    # to `visible=` in newer Matplotlib releases, the positional form works
    # with both.
    plt.grid(
        True,
        which="major",
        color=np.array([0.93, 0.93, 0.93]),
        linestyle="-",
        zorder=-2,
    )
    plt.minorticks_on()
    plt.grid(
        True,
        which="minor",
        color=np.array([0.97, 0.97, 0.97]),
        linestyle="-",
        zorder=-2,
    )
    ax = plt.gca()
    ax.set_axisbelow(True)
    # Hide the right and top spines
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)

    per_method_inputs = zip(
        subset_size_to_expected_mas_est_list,
        subset_size_to_bootstrap_points_list,
        method_labels,
        colors,
        line_styles,
        line_markers,
    )
    for (
        index,
        (
            subset_size_to_expected_max_est,
            subset_size_to_bootstrap_points,
            method_label,
            color,
            line_style,
            line_marker,
        ),
    ) in enumerate(per_method_inputs):
        xvals = list(sorted(subset_size_to_bootstrap_points.keys()))
        bootstrap_samples = [subset_size_to_bootstrap_points[x] for x in xvals]
        expected_max = [subset_size_to_expected_max_est[x] for x in xvals]

        try:
            lower, _, upper = unzip(
                [np.percentile(samples, [25, 50, 75]) for samples in bootstrap_samples]
            )
        except Exception:
            # Deliberately best-effort: a method without usable bootstrap
            # samples is reported and left out of the figure.
            print(
                "Could not generate max_hp_curve for {}, too few points".format(
                    method_label
                )
            )
            continue

        # Interquartile band: the method colour with alpha 25/255.
        ax.fill_between(
            xvals, lower, upper, color=np.array(color + (25,)) / 255, zorder=1
        )
        # LaTeX spacing after the rank number; raw string avoids the invalid
        # "\ " escape sequence while producing the same characters.
        rank_padding = r"\ \ " if index + 1 < 10 else " "
        plt.plot(
            xvals,
            expected_max,
            label=r"{}.{}".format(index + 1, rank_padding) + method_label,
            color=np.array(color) / 255,
            lw=1.5,
            linestyle=line_style,
            marker=line_marker,
            markersize=3,
            markevery=4,
            zorder=2,
        )

    plt.title(title)
    plt.xlabel("Hyperparam. Evals")
    plt.ylabel(ylabel)

    plt.tight_layout()

    if include_legend:
        if put_legend_outside:
            # Shrink the axes to 80% width to make room for the legend.
            ax = plt.gca()
            box = ax.get_position()
            ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
            ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
        else:
            plt.legend()

    set_size(*fig_size)

    if save_path is None:
        plt.show()
    else:
        plt.savefig(
            save_path, bbox_inches="tight",
        )
        plt.close()


def paired_t_test(exp_type0, exp_type1, metric_key, df):
    """Run a paired t-test between two experiment types, pairing runs by seed.

    Rows of `df` are selected by comparing the `exp_type` column to each name,
    so names containing quotes or other special characters work. The selected
    rows are sorted by `seed`, and `df` itself is left untouched.

    Args:
        exp_type0: value of the `exp_type` column for the first sample.
        exp_type1: value of the `exp_type` column for the second sample.
        metric_key: column holding the metric to compare.
        df: frame with `exp_type`, `seed` and `metric_key` columns.

    Returns:
        `(mean_difference, test_result)` where `mean_difference` is the mean of
        `metric(exp_type0) - metric(exp_type1)` over seeds and `test_result`
        is the value returned by `scipy.stats.ttest_rel`. When the two
        experiment types were not run on exactly the same seeds, returns
        `(None, Ttest_relResult(None, None))`.
    """

    def runs_sorted_by_seed(exp_type):
        # Boolean-mask selection yields a new frame, so no copy is needed.
        return df[df["exp_type"] == exp_type].sort_values(by=["seed"])

    exp0_df = runs_sorted_by_seed(exp_type0)
    exp1_df = runs_sorted_by_seed(exp_type1)

    seeds0 = np.array(exp0_df.loc[:, "seed"])
    seeds1 = np.array(exp1_df.loc[:, "seed"])

    # Pairing is only meaningful when both experiments share the same seeds.
    if len(seeds0) != len(seeds1) or (seeds0 != seeds1).any():
        return None, Ttest_relResult(None, None)

    x0 = np.array(exp0_df.loc[:, metric_key])
    x1 = np.array(exp1_df.loc[:, metric_key])
    return (x0 - x1).mean(), ttest_rel(x0, x1)


def create_comparison_hp_plots_from_tsv(
    env_type: str,
    tsv_file_path: str,
    overwrite=True,
    include_subset_max_plots: bool = False,
    include_legend: bool = False,
):
    """Summarize a hyperparameter-search results table into plots and t-tests.

    Reads the tab-separated table at `tsv_file_path`, augments it with
    `add_columns_to_df`, groups runs by `exp_type`, and for each of the
    "reward" and "success" metrics writes into `<tsv dir>/plots/<task name>/`:

      * `<stem>__maxofsubset_<task>_<metric>.pdf`, the expected-max curves
        (only when `include_subset_max_plots` is true);
      * `<stem>__ttest_<task>_<metric>.tsv`, pairwise paired t-tests between
        all experiment types.

    Only experiment types listed in `METHOD_ORDER` are plotted, in that order;
    the t-tests cover every experiment type in the table.

    Args:
        env_type: prefix of the `<env_type>_env` column naming the task; that
            column must hold a single value.
        tsv_file_path: path of the results table; must exist. A bare file name
            (no directory component) is resolved relative to the current
            working directory.
        overwrite: when false, a metric is skipped if its
            `<stem>__box_<task>_<metric>.pdf` file already exists.
        include_subset_max_plots: whether to produce the expected-max curves.
        include_legend: forwarded to `plot_max_hp_curves`.
    """
    assert os.path.exists(tsv_file_path)

    file_dir, file_name = os.path.split(tsv_file_path)

    df = pd.read_csv(tsv_file_path, sep="\t")

    df = add_columns_to_df(df)
    df["gp_params"] = df["gp_params"].where(pd.notnull(df["gp_params"]), "None")
    df["gp_params_str"] = [str(gps) for gps in df["gp_params"]]

    env_type_key = env_type + "_env"
    task_name = df[env_type_key][0]
    assert (
        df[env_type_key] == task_name
    ).all(), "{} must be the same for all elements of df".format(env_type_key)

    del df[env_type_key]

    df = df.sort_values(by=["exp_type", "seed"])

    # One row per experiment type; every cell holds the list of per-run values.
    df_grouped_lists = df.groupby(by=["exp_type"]).agg(list)
    method_keys = list(df_grouped_lists.index)

    # `os.path.join` tolerates an empty `file_dir` (bare file name), so the
    # plots directory is created relative to the current working directory.
    plots_dir = os.path.join(file_dir, "plots", task_name)
    os.makedirs(plots_dir, exist_ok=True)
    file_stem = file_name.replace(".tsv", "")

    def output_path(kind, metric_key, extension):
        # Built from its parts, not by string-replacing inside another path,
        # so directory or file names containing "_box_" / ".pdf" stay intact.
        return os.path.join(
            plots_dir,
            "{}__{}_{}_{}{}".format(file_stem, kind, task_name, metric_key, extension),
        )

    # Plot order, colours and markers follow METHOD_ORDER and do not depend on
    # the metric; experiment types missing from METHOD_ORDER are not plotted.
    ranked_positions = sorted(
        (METHOD_ORDER.index(method_key), position)
        for position, method_key in enumerate(method_keys)
        if method_key in METHOD_ORDER
    )
    plotted_keys = [method_keys[position] for _, position in ranked_positions]
    colors = [METHOD_TO_COLOR.get(key, (0, 0, 0)) for key in plotted_keys]
    line_styles = None
    line_markers = [METHOD_TO_LINE_MARKER.get(key, "") for key in plotted_keys]

    for metric_key in [
        "reward",
        "success",
    ]:
        box_save_path = output_path("box", metric_key, ".pdf")
        # NOTE(review): nothing in this function writes `box_save_path`;
        # presumably another script produces it. Confirm that this is the
        # intended skip marker.
        if (not overwrite) and os.path.exists(box_save_path):
            print(
                "Plot {} exists and overwrite is `False`, skipping...".format(
                    box_save_path
                )
            )
            continue

        points_list = [
            df_grouped_lists.loc[key, metric_key] for key in plotted_keys
        ]

        result_lens = {
            key: len(points) for key, points in zip(plotted_keys, points_list)
        }
        print(result_lens)
        print(sum(result_lens.values()))

        pretty_label_lists = [
            df_grouped_lists.loc[key, "pretty_label"] for key in plotted_keys
        ]
        assert all(all_equal(l) for l in pretty_label_lists)

        method_labels = [l[0] for l in pretty_label_lists]

        if include_subset_max_plots:
            subset_size_to_bootstrap_points_list = []
            subset_size_to_expected_mas_est_list = []
            for points in points_list:
                # Subset sizes run from 1 up to len(points) - 5 inclusive.
                subset_sizes = range(1, len(points) + 1 - 5)
                subset_size_to_expected_mas_est_list.append(
                    {
                        m: expected_max_of_subset_statistic(points, m=m)
                        for m in subset_sizes
                    }
                )
                subset_size_to_bootstrap_points_list.append(
                    {
                        m: bootstrap_max_of_subset_statistic(
                            points, m=m, reps=500, seed=m
                        )
                        for m in subset_sizes
                    }
                )

            plot_max_hp_curves(
                subset_size_to_expected_mas_est_list=subset_size_to_expected_mas_est_list,
                subset_size_to_bootstrap_points_list=subset_size_to_bootstrap_points_list,
                method_labels=method_labels,
                ylabel="$E[${}$]$".format(METRIC_TO_LABEL[metric_key]),
                colors=colors,
                line_styles=line_styles,
                line_markers=line_markers,
                fig_size=(5 * 2 / 5, 5 * 3.0 / 5),
                save_path=output_path("maxofsubset", metric_key, ".pdf"),
                put_legend_outside=True,
                include_legend=include_legend,
            )

        # Compute pairwise t-tests
        paired_t_test_results = {
            (key0, key1): paired_t_test(key0, key1, metric_key, df)
            for key0, key1 in itertools.combinations(method_keys, 2)
        }
        if any(
            mean_diff is None
            for (key0, key1), (mean_diff, _) in paired_t_test_results.items()
            if (key0 in METHOD_ORDER and key1 in METHOD_ORDER)
        ):
            print("Different seeds encountered!")

        # Columns are assembled directly so that a table with fewer than two
        # experiment types still yields a (header-only) file.
        ttest_rows = [
            (key0, key1, mean_diff, test_result.pvalue, test_result.statistic)
            for (key0, key1), (mean_diff, test_result) in paired_t_test_results.items()
        ]
        pd.DataFrame(
            dict(
                method0=[r[0] for r in ttest_rows],
                method1=[r[1] for r in ttest_rows],
                mean_diff=[r[2] for r in ttest_rows],
                pval=[r[3] for r in ttest_rows],
                stats=[r[4] for r in ttest_rows],
            )
        ).to_csv(
            output_path("ttest", metric_key, ".tsv"), sep="\t", index=False,
        )


if __name__ == "__main__":
    # With no environment name, summarize every `random_*.tsv` table in the
    # results directory; otherwise only the table of the named environment.
    args = _get_args()
    results_dir = "experiment_output/minigrid_random_hp_runs"
    overwrite = True

    paths = (
        glob.glob(os.path.join(results_dir, "random_*.tsv"))
        if args.env_name == ""
        else [
            "experiment_output/minigrid_random_hp_runs/random_hp_search_minigrid_runs_{}.tsv".format(
                args.env_name
            )
        ]
    )

    for path in paths:
        print()
        print(os.path.basename(path))
        create_comparison_hp_plots_from_tsv(
            env_type="minigrid",
            tsv_file_path=path,
            overwrite=overwrite,
            include_subset_max_plots=True,
        )
