import itertools
import re
import typing
from typing import Dict, Tuple, Union, Optional

import matplotlib.pyplot as plt
import numpy as np

from utils.misc_utils import NICE_COLORS12_RGB

# Global matplotlib configuration, applied as a side effect of importing this
# module: use a serif font family (with "CMU" listed as the serif face), render
# all text through LaTeX, and load amsmath in the LaTeX preamble (the labels
# defined below use ``\text{...}``).
plt.rc("font", **{"family": "serif", "serif": ["CMU"]})
plt.rc("text", usetex=True)
plt.rc("text.latex", preamble=r"\usepackage{amsmath}")

# Short display name substituted (via ``str.format``) into the labels of the
# "advisor_fixed_alpha_*" experiment types below.
FIXED_ADVISOR_STR = r"ADV"

# Maps an experiment-type key to its LaTeX-formatted display label. The values
# are raw strings containing inline math segments delimited by ``$``; they are
# looked up by ``run_info_to_pretty_label``.
# NOTE: some entries spell out "ADV" literally instead of formatting in
# FIXED_ADVISOR_STR, so changing FIXED_ADVISOR_STR alone does not update every
# label.
EXPERIMENT_TYPE_TO_LABEL_DICT = {
    "dagger_then_ppo": r"$\dagger \to$ PPO",
    "dagger_then_advisor_fixed_alpha_different_head_weights": r"$\dagger \to$ {}".format(
        FIXED_ADVISOR_STR
    ),
    "bc_then_ppo": r"BC$ \to$ PPO",
    "advisor_fixed_alpha_different_heads": r"{}".format(FIXED_ADVISOR_STR),
    "bc": r"BC",
    "dagger": r"DAgger $(\dagger)$",
    "ppo": r"PPO",
    "ppo_with_offpolicy_advisor_fixed_alpha_different_heads": r"ADV$^{\text{demo}} +$ PPO",
    "ppo_with_offpolicy": r"BC$^{\text{demo}} +$ PPO",
    "pure_offpolicy": r"BC$^{\text{demo}}$",
    "bc_teacher_forcing": r"BC$^{\text{tf}=1}$",
    "bc_teacher_forcing_then_ppo": r"BC$^{\text{tf}=1} \to$ PPO",
    "bc_teacher_forcing_then_advisor_fixed_alpha_different_head_weights": r"BC$^{\text{tf}=1} \to$ ADV",
}

# Canonical ordering of the experiment-type keys. The position of a key in
# this list determines the color, line style and marker assigned to it in the
# METHOD_TO_* tables built further down, so reordering changes plot styling.
METHOD_ORDER = [
    "bc",
    "dagger",
    "bc_teacher_forcing",
    "ppo",
    "bc_then_ppo",
    "dagger_then_ppo",
    "bc_teacher_forcing_then_ppo",
    "advisor_fixed_alpha_different_heads",
    "dagger_then_advisor_fixed_alpha_different_head_weights",
    "bc_teacher_forcing_then_advisor_fixed_alpha_different_head_weights",
    "pure_offpolicy",
    "ppo_with_offpolicy",
    "ppo_with_offpolicy_advisor_fixed_alpha_different_heads",
]

# Associates each of the three off-policy ("*offpolicy*") experiment types
# with another experiment-type key.
# NOTE(review): this table is not read anywhere in the visible part of the
# file; presumably consumers use it to pair a demonstration-based variant with
# its on-policy counterpart -- confirm against callers.
METHOD_TO_BASE_METHOD = {
    "pure_offpolicy": "bc",
    "ppo_with_offpolicy": "bc_then_ppo",
    "ppo_with_offpolicy_advisor_fixed_alpha_different_heads": "dagger_then_advisor_fixed_alpha_different_head_weights",
}


# Per-method plot styling, keyed by experiment-type key and filled in by the
# loop below from each method's position in METHOD_ORDER.
METHOD_TO_COLOR = {}
METHOD_TO_LINE_STYLE = {}
METHOD_TO_LINE_MARKER = {}
# Marker symbols cycled through by position; "" is an empty marker string.
NICE_MARKERS = ("", "|", "x", "^")
for i, method in enumerate(METHOD_ORDER):
    ind = i
    n = len(NICE_COLORS12_RGB)
    # Cycle through the palette; every time the index wraps past the end of
    # the palette (ind // n grows by one) the pick is shifted by one extra
    # position rather than restarting at the same color.
    METHOD_TO_COLOR[method] = NICE_COLORS12_RGB[(ind + (ind // n)) % n]
    # Line styles repeat with period 3 and markers with period
    # len(NICE_MARKERS).
    METHOD_TO_LINE_STYLE[method] = ["solid", "dashed", "dashdot"][ind % 3]
    METHOD_TO_LINE_MARKER[method] = NICE_MARKERS[ind % len(NICE_MARKERS)]


def minigrid_env_to_label(arr: np.ndarray):
    """Map a 1-D array of environment keys to their display labels.

    Args:
        arr: One-dimensional array whose elements are keys of the internal
            lookup table (currently only ``"Crossing"``).

    Returns:
        A NumPy string array of the same length holding the display label
        for each input element.

    Raises:
        AssertionError: If ``arr`` is not one-dimensional.
        KeyError: If an element has no entry in the lookup table.
    """
    assert len(arr.shape) == 1
    d = {
        "Crossing": "Lava Crossing",
    }
    # ``np.str`` was only an alias of the builtin ``str``; it was deprecated
    # in NumPy 1.20 and removed in 1.24, so the builtin is used directly.
    return np.array([d[x] for x in arr], dtype=str)


def run_info_to_pretty_label(run_info: Dict[str, Optional[Union[int, str, float]]]):
    """Return the display label for a run, starred when the run is optimized.

    Args:
        run_info: Mapping that must contain ``"exp_type"`` (a key of
            ``EXPERIMENT_TYPE_TO_LABEL_DICT``) and ``"optimized"`` (interpreted
            by truthiness).

    Returns:
        The label from ``EXPERIMENT_TYPE_TO_LABEL_DICT``; for optimized runs a
        ``^*`` superscript is attached to it.

    Raises:
        KeyError: If either key is missing from ``run_info`` or the experiment
            type has no label.
    """
    experiment_key = run_info["exp_type"]
    is_optimized = bool(run_info["optimized"])
    base_label = EXPERIMENT_TYPE_TO_LABEL_DICT[experiment_key]

    if not is_optimized:
        return base_label

    # A label that already ends inside a math segment gets the star placed
    # before its closing "$"; otherwise a fresh math segment is appended.
    # NOTE(review): for labels whose math segment already ends in a
    # superscript (e.g. "...^{...}$") this yields two consecutive superscripts,
    # which LaTeX may reject -- confirm whether such labels are ever starred.
    if base_label[-1] == "$":
        return base_label[:-1] + "^*$"
    return base_label + "$^*$"


def _scan_params_for_setting(params, patterns):
    """Scan parameter strings for a numeric setting and an "optimal" marker.

    Args:
        params: Iterable of candidate strings. Entries that are not strings
            (e.g. a missing tag stored as ``None`` or NaN) are skipped.
        patterns: Compiled regular expressions, each with one capture group
            holding the setting's textual value.

    Returns:
        A ``(value, optimized)`` tuple. ``value`` is the float parsed from the
        last matching pattern of the first parameter that yields a number, or
        ``None`` if no capture parses as a float. ``optimized`` is ``True`` if
        any capture examined before the scan stopped contains ``"optimal"``.
    """
    value = None
    optimized = False
    for param in params:
        if not isinstance(param, str):
            continue
        for pattern in patterns:
            match = pattern.search(param)
            if match is None:
                continue
            captured = match.group(1)
            optimized = optimized or "optimal" in captured
            try:
                value = float(captured)
            except ValueError:
                # Non-numeric capture (e.g. "optimal"): only the flag matters.
                pass
        if value is not None:
            # A numeric value was found; later parameters are not consulted.
            break
    return value, optimized


def add_columns_to_df(df):
    """Derive hyperparameter and label columns from ``gp_params``/``extra_tag``.

    The frame is modified in place and also returned. The columns
    ``alpha_start``, ``alpha_stop``, ``fixed_alpha``, ``lr``, ``tf_ratio``,
    ``lr_optimized``, ``tf_optimized`` and ``pretty_label`` are (re)created,
    ``optimized`` is set, and string entries of ``gp_params`` are replaced by
    their evaluated Python value.

    Args:
        df: DataFrame with at least the columns ``gp_params``, ``extra_tag``
            and ``exp_type``. Rows are addressed with ``df.loc[i]`` for
            ``i in range(len(df))``, so the index labels are expected to be
            ``0..len(df)-1``.

    Returns:
        The same ``df`` object with the additional columns filled in.

    Raises:
        KeyError: If a required column is missing or an ``exp_type`` has no
            label.
        ValueError: If an alpha setting is present but its value is not
            numeric.
    """
    alpha_queries = {
        "alpha_start": re.compile(r"hyperparams.anneal_alpha_start = (.*)"),
        "alpha_stop": re.compile(r"hyperparams.anneal_alpha_stop = (.*)"),
        "fixed_alpha": re.compile(r"hyperparams.fixed_alpha = (.*)"),
    }
    # Compiled once here rather than rebuilt for every row.
    lr_queries = [
        re.compile(r"hyperparams.lr = (.*)"),
        re.compile(r"lr_([^_].*)"),
    ]
    tf_queries = [
        re.compile(r"hyperparams.tf_ratio = (.*)"),
        re.compile(r"tf_([^_].*)"),
    ]

    for key in itertools.chain(
        alpha_queries,
        ["lr", "tf_ratio", "lr_optimized", "tf_optimized", "pretty_label"],
    ):
        df[key] = [None] * df.shape[0]

    # NOTE(security): ``eval`` executes arbitrary code. This is only acceptable
    # while ``gp_params`` comes from trusted, locally produced result files.
    df.loc[:, "gp_params"] = [
        gps if not isinstance(gps, str) else eval(gps) for gps in df.loc[:, "gp_params"]
    ]

    for i in range(df.shape[0]):
        row = df.loc[i, :]

        # Treat a missing gp_params entry (NaN / None / "None") as "no params".
        gp_params: Tuple[str, ...] = row["gp_params"]
        if (
            (isinstance(gp_params, float) and np.isnan(gp_params))
            or gp_params is None
            or gp_params == "None"
        ):
            gp_params = tuple()

        # ALPHA KEYS: the first parameter matching each query wins; values
        # that are whole numbers are stored as int.
        for alpha_query_key, alpha_query in alpha_queries.items():
            for gp_param in gp_params:
                match = alpha_query.search(gp_param)
                if match is None:
                    continue
                value = float(match.group(1))
                if value == int(value):
                    value = int(value)
                df.loc[i, alpha_query_key] = value
                break

        # The learning rate and teacher-forcing ratio may appear either in the
        # gp params or in the row's extra tag.
        candidates = list(itertools.chain(gp_params, [row["extra_tag"]]))

        # LR
        lr, lr_optimized = _scan_params_for_setting(candidates, lr_queries)
        df.loc[i, "lr_optimized"] = lr_optimized
        if lr is not None:
            df.loc[i, "lr"] = lr

        # TF: the scan stops on the teacher-forcing value itself (the previous
        # code stopped on the "lr" column) and the flag accumulates across
        # matches in the same way as for the learning rate.
        tf_ratio, tf_optimized = _scan_params_for_setting(candidates, tf_queries)
        df.loc[i, "tf_optimized"] = tf_optimized
        if tf_ratio is not None:
            df.loc[i, "tf_ratio"] = tf_ratio

    df.loc[:, "optimized"] = np.logical_or(
        df.loc[:, "lr_optimized"], df.loc[:, "tf_optimized"]
    )

    if df["optimized"].all():
        # If everything is optimized then nothing is
        df["optimized"] = False

    for i in range(df.shape[0]):
        df.loc[i, "pretty_label"] = run_info_to_pretty_label(dict(df.loc[i, :]))

    return df


def all_equal(s: typing.Sequence):
    """Return True when every element of ``s`` compares equal to the first.

    Sequences with zero or one element are trivially all-equal.
    """
    if len(s) < 2:
        return True

    remainder = s[1:]
    reference = s[0]
    for candidate in remainder:
        if not (reference == candidate):
            return False
    return True
