import torch
import numpy as np
import seaborn as sns
from tqdm import tqdm

from ai_control_games.generic_game import MixedBlueTeamPolicy, MixedRedTeamPolicy

# Module-level colormaps (presumably used by plotting code elsewhere for the
# safety / usefulness / utility metrics -- verify against callers).
# Each is created by seaborn with as_cmap=True and then given explicit colours
# through set_extremes: "bad" for masked/invalid values, "under" for values
# below the colour range and "over" for values above it.

# Safety colormap: seaborn's "icefire_r" palette.
SAFETY_PALETTE = sns.color_palette("icefire_r", as_cmap=True)
SAFETY_PALETTE.set_extremes(bad="grey", under="red", over="blue")

# Usefulness colormap: a light sequential palette built around "seagreen".
USEFULNESS_PALETTE = sns.light_palette("seagreen", as_cmap=True)
USEFULNESS_PALETTE.set_extremes(bad="grey", under="white", over="green")

# Utility colormap: the "copper" palette.
UTILITY_PALETTE = sns.color_palette("copper", as_cmap=True, )
UTILITY_PALETTE.set_extremes(bad="grey", under="yellow", over="brown")


def grid_then_ternary_search(f, left, right, epsilon, f_options, grid_res,
                             show_progress: bool = False) -> tuple[float, float]:
    """
    Minimise ``f`` with a coarse grid search followed by a ternary search.

    ``f`` is evaluated at ``grid_res + 1`` evenly spaced points between
    ``left`` and ``right``. The grid point with the smallest value is
    bracketed by its immediate neighbours (clamped to the ends of the grid)
    and that bracket is handed to ``ternary_search`` for refinement.

    :param f: Objective, called as ``f(x, **f_options)``
    :param left: Lower end of the search interval
    :param right: Upper end of the search interval
    :param epsilon: Bracket width at which the ternary search stops
    :param f_options: Keyword arguments forwarded to every call of ``f``
    :param grid_res: Number of grid intervals (``grid_res + 1`` grid points)
    :param show_progress: Forwarded to ``ternary_search``; the grid-search
        progress bar is a tqdm bar configured with ``delay=5`` regardless
    :return: ``(x, f(x))`` as returned by ``ternary_search``
    """
    grid_points = np.linspace(left, right, grid_res + 1)

    grid_values = []
    for point in tqdm(grid_points, delay=5, desc="Grid search before ternary search", leave=False):
        grid_values.append(f(point, **f_options))

    # Bracket the best grid point with its neighbours, staying inside the grid.
    best_ind = np.argmin(grid_values)
    lower_ind = max(0, best_ind - 1)
    upper_ind = min(len(grid_points) - 1, best_ind + 1)

    best_x, best_f_x = ternary_search(
        f, grid_points[lower_ind], grid_points[upper_ind], epsilon, f_options, show_progress
    )
    return best_x, best_f_x


def ternary_search(f, left, right, epsilon, f_options, show_progress: bool = False) -> tuple[float, float]:
    """
    Minimise ``f`` over ``[left, right]`` by ternary search.

    Each iteration evaluates ``f`` at the two interior third-points and drops
    the outer third next to the larger value, so the bracket shrinks to two
    thirds of its width. This only converges to the minimum if ``f`` is
    unimodal on the interval; that is not checked here.

    :param f: Objective, called as ``f(x, **f_options)``
    :param left: Lower end of the search interval
    :param right: Upper end of the search interval
    :param epsilon: The loop stops once ``right - left <= epsilon``
    :param f_options: Keyword arguments forwarded to every call of ``f``
    :param show_progress: If True, print an approximate percentage each iteration
    :return: ``(x, f(x))`` where ``x`` is the midpoint of the final bracket
    """
    diff = right - left

    # Progress is reported on a log scale relative to the *initial* bracket
    # width: 0% at the start, approaching 100% as the width nears epsilon.
    report_progress = show_progress and diff > epsilon
    if report_progress:
        log_initial_diff = np.log10(diff)
        log_span = log_initial_diff - np.log10(epsilon)

    while diff > epsilon:
        if report_progress:
            percentage_ish = 100 * (log_initial_diff - np.log10(diff)) / log_span
            print(f"{percentage_ish}% done")
        m1 = left + diff / 3
        m2 = right - diff / 3
        f1 = f(m1, **f_options)
        f2 = f(m2, **f_options)
        if f1 > f2:
            # The value at m1 is larger, so discard [left, m1].
            left = m1
        else:
            right = m2
        diff = right - left

    last_x = (left + right) / 2
    last_f_x = f(last_x, **f_options)

    return last_x, last_f_x


def get_log_2_linspace(min_exp: int = -10,
                       max_exp: int = 0,
                       num_points: int = 11) -> list[float]:
    """
    Return powers of two whose exponents are evenly spaced.

    :param min_exp: Exponent of the first value
    :param max_exp: Exponent of the last value
    :param num_points: Number of values to generate
    :return: ``2 ** e`` as Python floats for each ``e`` in
        ``torch.linspace(min_exp, max_exp, num_points)``
    """
    exponents = torch.linspace(min_exp, max_exp, num_points)
    powers_of_two = 2 ** exponents
    return powers_of_two.tolist()


def floats_are_close(a: float, b: float, atol=1e-6, rtol=1e-6) -> bool:
    """
    Return whether ``a`` and ``b`` are equal within the given tolerances.

    The check is ``torch.isclose``: ``|a - b| <= atol + rtol * |b|``.

    :param a: First value; converted with ``float()``
    :param b: Second value; converted with ``float()``
    :param atol: Absolute tolerance
    :param rtol: Relative tolerance (scaled by ``|b|``)
    :return: True if the values are close, False otherwise
    """
    # Compare in float64. The default tensor dtype (normally float32) would
    # round the inputs first, which can make distinct values look identical.
    a_tensor = torch.tensor(float(a), dtype=torch.float64)
    b_tensor = torch.tensor(float(b), dtype=torch.float64)
    return bool(torch.isclose(a_tensor, b_tensor, atol=atol, rtol=rtol))


def print_source(func):
    """Print the source code of ``func`` as returned by ``inspect.getsource``."""
    from inspect import getsource

    source_text = getsource(func)
    print(source_text)


def XXXs_profile(f, file_str_to_be_included: str = '', num_rows: int = 20):
    """
    Profile a call to ``f`` and print the results, sorted by tottime.

    :param f: Zero-argument callable to profile
    :param file_str_to_be_included: Passed to ``pstats.Stats.print_stats`` as a
        restriction, to limit the output rows to a specific file; the default
        empty string does not filter anything out
    :param num_rows: Maximum number of rows to print
    :return: The return value of f
    :raises: Whatever ``f`` raises; in that case no statistics are printed
    """
    import cProfile
    import pstats

    pr = cProfile.Profile()
    pr.enable()
    try:
        x = f()
    finally:
        # Always stop profiling, even if f raises, so the profiler does not
        # stay attached to the interpreter after this call.
        pr.disable()

    p = pstats.Stats(pr)
    print("=" * 20 + "Profile" + "=" * 20)
    # strip_dirs() drops directory prefixes from file names before the rows
    # are sorted and filtered.
    p.strip_dirs().sort_stats('tottime').print_stats(file_str_to_be_included, num_rows)
    return x


def get_leader_follower_eq_from_s_u_and_weight(
        blue_team_strategies: list[object],
        red_team_strategies: list[object],
        safeties: dict,
        uses: dict,
        weight: float,
        return_mixed: bool = False,
):
    """
    Build a weighted safety/usefulness value matrix and solve for an equilibrium.

    Entry ``[i, j]`` of the matrix is
    ``(1 - weight) * uses[blue_i] + weight * safeties[blue_i, red_j]``.

    :param blue_team_strategies: Blue strategies; one matrix row each
    :param red_team_strategies: Red strategies; one matrix column each
    :param safeties: Mapping from ``(blue_strategy, red_strategy)`` to safety
    :param uses: Mapping from ``blue_strategy`` to usefulness
    :param weight: Weight on safety; usefulness is weighted by ``1 - weight``
    :param return_mixed: If true, solve with ``get_leader_follower_eq_mixed``,
        otherwise with ``get_leader_follower_eq``
    :return: Whatever the selected solver returns
    """
    num_blue = len(blue_team_strategies)
    num_red = len(red_team_strategies)
    value_matrix = np.zeros((num_blue, num_red))

    for row, blue_strategy in enumerate(blue_team_strategies):
        # Usefulness is looked up by Blue strategy only, so it fills the whole row.
        value_matrix[row, :] += (1 - weight) * uses[blue_strategy]
        for col, red_strategy in enumerate(red_team_strategies):
            value_matrix[row, col] += weight * safeties[blue_strategy, red_strategy]

    solver = get_leader_follower_eq_mixed if return_mixed else get_leader_follower_eq
    return solver(blue_team_strategies, red_team_strategies, value_matrix)


def get_leader_follower_eq(
        blue_team_strategies: list[object],
        red_team_strategies: list[object],
        value_matrix: np.ndarray,
) -> tuple[tuple[tuple[object], tuple[float]], object, float]:
    """
    Find an equilibrium of the game given by ``value_matrix``.

    ``value_matrix[i, j]`` is the value if Blue plays strategy ``i`` and Red
    plays strategy ``j``. Equilibria are enumerated with nashpy's support
    enumeration and the first one found is used.

    :param blue_team_strategies: Blue strategies, one per matrix row
    :param red_team_strategies: Red strategies, one per matrix column; must
        have the same length as ``blue_team_strategies``
    :param value_matrix: Payoff matrix; must not contain NaN
    :return: ``((policies, probs), red_response, value)`` where ``policies``
        and ``probs`` are the Blue strategies with strictly positive
        probability and those probabilities, ``red_response`` is the Red
        strategy with the highest equilibrium probability, and ``value`` is
        the payoff of Blue's mix against that response
    :raises ValueError: If no equilibrium is found, or if several equilibria
        with different values are found
    """
    assert not np.isnan(value_matrix).any()
    # Note that maxmin is the same as nash because we are in a zero-sum game (see paper)
    # value_matrix[i, j] = value if blue team plays i and red team plays j
    assert value_matrix.shape == (len(blue_team_strategies), len(red_team_strategies))
    assert len(blue_team_strategies) == len(red_team_strategies)
    if len(blue_team_strategies) == 1:
        # One pure strategy each: the game value is the single matrix entry.
        value = float(value_matrix[0, 0])
        return ((blue_team_strategies[0],), (1.0,)), red_team_strategies[0], value

    import nashpy as nash
    rps = nash.Game(value_matrix)
    nash_eqs = list(rps.support_enumeration())

    if len(nash_eqs) == 0:
        raise ValueError("No Nash Equilibria found")
    elif len(nash_eqs) > 1:
        # Several equilibria are acceptable as long as they share one value;
        # in that case the tie is resolved by taking the first.
        ne_values = [p_vec @ value_matrix @ q_vec for p_vec, q_vec in nash_eqs]
        if not np.allclose(ne_values, ne_values[0]):
            raise ValueError(f"Multiple differently-valued Nash Equilibria found: {nash_eqs}")
    p_vec, q_vec = nash_eqs[0]

    # Keep only the Blue strategies that are actually played.
    policy_p_pairs = tuple((pol, p) for pol, p in zip(blue_team_strategies, p_vec) if p > 0.0)

    # Payoff of Blue's mix against each pure Red strategy in Red's support.
    non_zero_red_payoffs = (p_vec @ value_matrix)[np.where(q_vec > 0.0)[0]]
    assert np.allclose(non_zero_red_payoffs, non_zero_red_payoffs[0]), "Red should have equal det payoffs"
    red_index = np.argmax(q_vec)  # If the above assertion passes, we can take any nonzero index

    policies, probs = zip(*policy_p_pairs)
    red_response = red_team_strategies[red_index]
    value = non_zero_red_payoffs[0]
    return (policies, probs), red_response, value


def get_leader_follower_eq_mixed(
        blue_team_strategies: list[object],
        red_team_strategies: list[object],
        value_matrix: np.ndarray,
) -> tuple[MixedBlueTeamPolicy, MixedRedTeamPolicy]:
    """
    Find a Nash equilibrium of ``value_matrix`` and return it as mixed policies.

    Cleaned-up version of get_leader_follower_eq for the Adapted Double Oracle
    algorithm. ``value_matrix[i, j]`` is the value if Blue plays strategy ``i``
    and Red plays strategy ``j``. The first equilibrium found by nashpy's
    support enumeration is used, and only strategies with strictly positive
    probability are kept in each mix.

    :param blue_team_strategies: Blue strategies, one per matrix row
    :param red_team_strategies: Red strategies, one per matrix column
    :param value_matrix: Payoff matrix; must not contain NaN
    :return: ``(blue_mix, red_mix)``
    :raises ValueError: If no equilibrium is found, or if several equilibria
        with different values are found
    """
    # TODO: try to return lower support policies where possible

    assert not np.isnan(value_matrix).any()
    # Note that maxmin is the same as nash because we are in a zero-sum game (see paper)

    import nashpy as nash
    game = nash.Game(value_matrix)
    nash_eqs = list(game.support_enumeration())

    if not nash_eqs:
        raise ValueError("No Nash Equilibria found")
    if len(nash_eqs) > 1:
        # Ties between equally valued equilibria are fine; we take the first.
        ne_values = [blue_vec @ value_matrix @ red_vec for blue_vec, red_vec in nash_eqs]
        if not np.allclose(ne_values, ne_values[0]):
            raise ValueError(f"Multiple differently-valued Nash Equilibria found: {nash_eqs}")

    blue_weights, red_weights = nash_eqs[0]

    def support_of(strategies, weights):
        # Pair each strategy with its probability, drop the unplayed ones and
        # transpose into (strategies, probabilities).
        played = [(strategy, w) for strategy, w in zip(strategies, weights) if w > 0.0]
        return zip(*played)

    blue_determ_policies, blue_probs = support_of(blue_team_strategies, blue_weights)
    red_determ_policies, red_probs = support_of(red_team_strategies, red_weights)

    blue_mix = MixedBlueTeamPolicy(blue_determ_policies, blue_probs)
    red_mix = MixedRedTeamPolicy(red_determ_policies, red_probs)
    return blue_mix, red_mix
