import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import AutoMinorLocator
from DataManipulation import load_from_file
from utils import fibonacci_array


def load_data_from_different_setup_and_save_them():
    """Split the combined Leduc Hold'em LP/CFR results into per-gadget, per-p text files.

    For each ``p`` in ``[0.1, 0.3, 0.5, 0.7, 0.9]``:

    * every result file ``results/br_sbr/all_leduc_holdem_lp_cfr_iter={i}`` (``i`` taken
      from ``fibonacci_array(9, True)``) is loaded with ``load_from_file``;
    * for every gadget row (``data["labels"]``) the gain/exploitability entries whose
      ``data["p"]`` value rounds to ``p`` (one decimal) are collected;
    * one gain file and one exploitability file per gadget are written to
      ``results/br_sbr/split_subgame_setup/``, one line per iteration count;
    * the collected dictionary is printed.
    """
    gadget_names = ("full", "ses", "exp-strat", "unsafe", "comb")
    for p in [0.1, 0.3, 0.5, 0.7, 0.9]:
        results = {name: {"gain": [], "expl": []} for name in gadget_names}

        for cfr_iterations in fibonacci_array(9, True):
            data = load_from_file(f"results/br_sbr/all_leduc_holdem_lp_cfr_iter={cfr_iterations}")
            for name in gadget_names:
                for row, label in enumerate(data["labels"]):
                    if label != name:
                        continue
                    for column, p_value in enumerate(data["p"][row]):
                        # Only keep entries recorded for the p currently being exported.
                        if round(p_value, 1) != p:
                            continue
                        results[name]["gain"].append(data["gain"][row][column])
                        results[name]["expl"].append(data["expl"][row][column])

        for name in gadget_names:
            # (metric key, output path, header line) for the two files of this gadget.
            outputs = (
                ("gain",
                 f"results/br_sbr/split_subgame_setup/{name}_leduc_holdem_step_r_p={p}_cfr_var_iter.txt",
                 f"# gain of {name} on Leduc Hold'em. Each line are different iterations in format num_iterations gain 0\n"),
                ("expl",
                 f"results/br_sbr/split_subgame_setup/{name}_leduc_holdem_step_r_p={p}_cfr_var_iter_expl.txt",
                 f"# exploitability of {name} on Leduc Hold'em. Each line are different iterations in format num_iterations expl 0\n"),
            )
            for metric_key, path, header in outputs:
                with open(path, "w+") as handle:
                    handle.write(header)
                    handle.write("steps r\n")
                    for position, iteration in enumerate(fibonacci_array(9, True)):
                        handle.write(f"{iteration} {results[name][metric_key][position]} 0\n")
        print(results)


def load_general_index_value_br(file_name, value_limit, include_average):
    """Read an ``iteration value br_value`` table from a results text file.

    Lines starting with ``#`` or ``steps`` (header lines) and blank lines are skipped.
    Every other line must contain at least three whitespace-separated tokens: an
    integer iteration count followed by two floats.

    :param file_name: path of the text file to read.
    :param value_limit: maximum number of data rows to read; a value that is never
        reached (e.g. ``-1``) reads the whole file, ``0`` reads nothing.
    :param include_average: when true, append the label ``"a"`` to the iterations and
        ``np.average`` of each value column to the corresponding value list.
    :return: ``(iterations, values, br_values)`` lists of equal length.
    """
    iterations = []
    values = []
    br_values = []
    with open(file_name, "r") as file:
        for line in file:
            if len(iterations) == value_limit:
                break
            if line.startswith(("#", "steps")):
                continue
            tokens = line.split()
            if not tokens:
                # Blank lines (e.g. a trailing empty line) carry no data; skip them
                # instead of failing on tokens[0].
                continue
            iterations.append(int(tokens[0]))
            values.append(float(tokens[1]))
            br_values.append(float(tokens[2]))
    if include_average:
        iterations.append("a")
        values.append(np.average(values))
        br_values.append(np.average(br_values))
    return iterations, values, br_values


def load_sbrs(file_name, steps, value_limit, include_average):
    """Read one ``iteration value br_value`` table per step.

    :param file_name: path template containing a ``{0}`` placeholder that is filled
        with each step via ``str.format``.
    :param steps: iterable of step identifiers, one file is read per step.
    :param value_limit: maximum number of data rows read per file; a value that is
        never reached (e.g. ``-1``) reads whole files.
    :param include_average: when true, append ``"a"`` and the ``np.average`` of each
        value column to every per-step list.
    :return: ``(iterations, values, br_values)``; each is a list holding one inner
        list per step, in ``steps`` order.

    Lines starting with ``#`` or ``steps`` and blank lines are skipped, matching
    ``load_general_index_value_br``.
    """
    iterations = []
    sbr_values = []
    br_values = []
    for step in steps:
        step_iterations = []
        step_sbr_values = []
        step_br_values = []
        with open(file_name.format(step), "r") as file:
            for line in file:
                if len(step_iterations) == value_limit:
                    break
                if line.startswith(("#", "steps")):
                    continue
                tokens = line.split()
                if not tokens:
                    # Blank lines carry no data; skip them instead of failing on tokens[0].
                    continue
                step_iterations.append(int(tokens[0]))
                step_sbr_values.append(float(tokens[1]))
                step_br_values.append(float(tokens[2]))
        if include_average:
            step_iterations.append("a")
            step_sbr_values.append(np.average(step_sbr_values))
            step_br_values.append(np.average(step_br_values))
        iterations.append(step_iterations)
        sbr_values.append(step_sbr_values)
        br_values.append(step_br_values)
    return iterations, sbr_values, br_values


def plot_br_srnr_steps_gain(game_name, steps, p, dynamic_steps=True, value_limit=-1, include_average=True, ax=None, cmap=None, mode="cfr", yticks=None, print_set=None, folder="", print_steps=True):
    """Plot the gain metric; every argument is forwarded to ``plot_br_srnr_steps`` with ``metric="gain"``."""
    plot_br_srnr_steps(
        game_name,
        steps,
        p,
        metric="gain",
        dynamic_steps=dynamic_steps,
        value_limit=value_limit,
        include_average=include_average,
        ax=ax,
        cmap=cmap,
        mode=mode,
        yticks=yticks,
        print_set=print_set,
        folder=folder,
        print_steps=print_steps,
    )


def plot_br_srnr_steps_expl(game_name, steps, p, dynamic_steps=True, value_limit=-1, include_average=True, ax=None, cmap=None, mode="cfr", yticks=None, print_set=None, folder="", print_steps=True):
    """Plot the exploitability metric; every argument is forwarded to ``plot_br_srnr_steps`` with ``metric="expl"``."""
    plot_br_srnr_steps(
        game_name,
        steps,
        p,
        metric="expl",
        dynamic_steps=dynamic_steps,
        value_limit=value_limit,
        include_average=include_average,
        ax=ax,
        cmap=cmap,
        mode=mode,
        yticks=yticks,
        print_set=print_set,
        folder=folder,
        print_steps=print_steps,
    )


def plot_br_srnr_steps(game_name, steps_in, p, dynamic_steps=True, value_limit=-1, include_average=True, ax=None, cmap=None, mode="cfr", yticks=None, metric="gain", print_set=None, folder="", print_steps=True):
    """Draw a grouped bar chart comparing the selected response methods.

    Every x position is one data row of the loaded result files (an opponent
    iteration count, plus an ``"a"`` average column when ``include_average`` is set).
    Within a position there is one bar per single-file series followed by one bar
    per step of each stepped series.

    :param game_name: game identifier used in the result file names.
    :param steps_in: step lists indexed as ``[srnr, ses, comb, full, unsafe, exp]``;
        only the entries of selected stepped series are read.
    :param p: value inserted after ``_p=`` in the file names.
    :param dynamic_steps: read the ``_dynamic`` variants of the stepped files.
    :param value_limit: maximum number of rows read per file (``-1`` reads all).
    :param include_average: append an average column labelled ``"a"``.
    :param ax: axes to draw into. When ``None`` the current pyplot axes is used,
        long legend labels, an x label, a title and a legend are added and
        ``plt.show()`` is called; otherwise short labels are used.
    :param cmap: unused; kept for backward compatibility.
    :param mode: ``"cfr"`` or ``"random"``; selects the family of result files.
    :param yticks: optional explicit y tick positions.
    :param metric: ``"gain"`` or ``"expl"``; any other string is used verbatim as
        the file-name suffix.
    :param print_set: keys of the series to plot (``"br"``, ``"rnr"``, ``"bne"``,
        ``"srnrvf"``, ``"srnrg"``, ``"srnru"``, ``"srnr"``, ``"comb"``, ``"ses"``,
        ``"full"``, ``"unsafe"``, ``"exp"``); ``None`` plots every series of the mode.
    :param folder: sub-folder of ``results/br_sbr/`` used in ``"cfr"`` mode.
    :param print_steps: append the step value to each stepped bar's label.
    :raises ValueError: for an unknown ``mode`` or when no series is selected.
    """
    if metric == "gain":
        metric = ""
    elif metric == "expl":
        metric = "_expl"

    other_values, other_labels, step_values, step_labels, step_lists = _load_plot_series(
        game_name, steps_in, p, dynamic_steps, value_limit, include_average, mode, metric, print_set, folder)

    # The x tick labels come from the iteration column of the first loaded series;
    # falling back to stepped series allows plots containing only stepped methods.
    iteration_columns = [series[0] for series in other_values]
    iteration_columns += [iterations for series in step_values for iterations in series[0]]
    if not iteration_columns:
        raise ValueError("print_set selects no series to plot")
    x_labels = iteration_columns[0]
    positions = list(range(len(x_labels)))

    # One ordered list of (values, legend label): single-file series first, then
    # every step of every stepped series. Long names for a standalone figure,
    # short names when drawing into a caller-provided panel.
    label_index = 1 if ax is None else 0
    bars = [(series[1], labels[label_index]) for series, labels in zip(other_values, other_labels)]
    for series, labels, step_list in zip(step_values, step_labels, step_lists):
        for step, values in zip(step_list, series[1]):
            label = str(labels[label_index])
            bars.append((values, (label + str(step)) if print_steps else label))

    # Bar k of a group sits at slot k + 1 of (total + 1) equal slots across
    # [pos - 0.5, pos + 0.5]; bars are narrower than a slot so they do not touch.
    total_bars = len(bars)
    slot = 1. / (total_bars + 1)
    bar_width = 1. / (total_bars + 2)
    target = plt.gca() if ax is None else ax
    for bar_index, (values, label) in enumerate(bars):
        xs = [pos - 0.5 + slot * (bar_index + 1) for pos in positions]
        target.bar(xs, values, width=bar_width, label=label)

    if metric == "":
        target.set_ylabel("Gain")
    else:
        target.set_ylabel("Exploitability")
    target.grid(which="major", axis="y")
    target.grid(which="minor", axis="y", alpha=0.3)
    if yticks is not None:
        target.set_yticks(yticks)
    target.yaxis.set_minor_locator(AutoMinorLocator(4))
    target.set_axisbelow(True)
    target.set_xticks(positions)
    target.set_xticklabels(x_labels)

    if ax is None:
        plt.xlabel("Cfr iterations of opponent strategy")
        plt.title("Gain comparison of SRNR and BR against CFR with low iterations\n"
                  "on " + game_name + " with varying SBR step sizes and p = " + str(p))
        plt.legend()
        plt.show()


def _load_plot_series(game_name, steps_in, p, dynamic_steps, value_limit, include_average, mode, metric, print_set, folder):
    """Load the result series selected by ``print_set`` for ``plot_br_srnr_steps``.

    ``metric`` is the already-translated file-name suffix (``""``, ``"_expl"``, ...).
    Returns ``(other_values, other_labels, step_values, step_labels, step_lists)``:
    ``other_values`` holds one ``(iterations, plotted_values, br_values)`` tuple per
    single-file series, ``step_values`` holds the ``load_sbrs`` result of each stepped
    series and ``step_lists`` the steps that result was loaded for. Labels are
    ``(short, long)`` pairs.

    In ``"random"`` mode the files come from ``results/br_sbr/`` with the
    ``_random_seeded`` suffix, ``comb`` is a single file (``step_1``) and the SRNR
    steps are taken from ``steps_in[0]``.
    """
    labels = {
        "br": ("BR", "Best response"),
        "rnr": ("RNR", "Restricted Nash response"),
        "bne": ("BNE", "Best Nash equilibrium"),
        "srnrvf": ("VF", "Continual depth limited restricted Nash response with value function"),
        "srnrg": ("VFG", "Continual depth limited restricted Nash response with value function using gadget"),
        "srnru": ("VFU", "Continual depth limited restricted Nash response with unsafe resolving"),
        "srnr": ("S", "Continual depth limited best response"),
        "comb": ("CDBR-NE", "Combination of NE and BR"),
        "ses": ("SES", "Safe exploitation search"),
        "full": ("CDRNR", "Continual depth-limited response"),
        "unsafe": ("U", "Unsafe solving"),
        "exp": ("EXP", "Exp-strat from chinese paper"),
    }
    dynamic = "_dynamic" if dynamic_steps else ""
    if mode == "cfr":
        prefix = f"results/br_sbr/{folder}"
        tail = "_cfr_var_iter" + metric
        # (print_set key, file path, plot the file's BR column instead of its value column)
        single_series = [
            ("br", f"{prefix}rnr_{game_name}_p={p}{tail}.txt", True),
            ("rnr", f"{prefix}rnr_{game_name}_p={p}{tail}.txt", False),
            ("bne", f"{prefix}bne_{game_name}{tail}.txt", False),
            ("srnrvf", f"{prefix}srnr_vf_{game_name}_p={p}{tail}.txt", False),
            ("srnrg", f"{prefix}srnr_g_{game_name}_p={p}{tail}.txt", False),
            ("srnru", f"{prefix}srnr_nog_{game_name}_p={p}{tail}.txt", False),
        ]
        # (print_set key, path template with a "{0}" step placeholder, index into steps_in)
        stepped_series = [
            ("srnr", f"{prefix}srnr_{game_name}_step_{{0}}_p={p}{tail}{dynamic}.txt", 0),
            ("comb", f"{prefix}comb_{game_name}_step_{{0}}_p={p}{tail}{dynamic}.txt", 2),
            ("ses", f"{prefix}ses_{game_name}_step_{{0}}_p={p}{tail}{dynamic}.txt", 1),
            ("full", f"{prefix}full_{game_name}_step_{{0}}_p={p}{tail}{dynamic}.txt", 3),
            ("unsafe", f"{prefix}unsafe_{game_name}_step_{{0}}_p={p}{tail}{dynamic}.txt", 4),
            ("exp", f"{prefix}exp-strat_{game_name}_step_{{0}}_p={p}{tail}{dynamic}.txt", 5),
        ]
    elif mode == "random":
        prefix = "results/br_sbr/"
        tail = "_random_seeded" + metric
        single_series = [
            ("rnr", f"{prefix}rnr_{game_name}_p={p}{tail}.txt", False),
            ("bne", f"{prefix}bne_{game_name}{tail}.txt", False),
            ("srnrvf", f"{prefix}srnr_vf_{game_name}_p={p}{tail}.txt", False),
            ("comb", f"{prefix}comb_{game_name}_step_1_p={p}{tail}{dynamic}.txt", False),
        ]
        stepped_series = [
            ("srnr", f"{prefix}srnr_{game_name}_step_{{0}}_p={p}{tail}{dynamic}.txt", 0),
        ]
    else:
        raise ValueError(f"unknown mode {mode!r}; expected 'cfr' or 'random'")

    def selected(key):
        return print_set is None or key in print_set

    other_values, other_labels = [], []
    for key, path, use_br_column in single_series:
        if selected(key):
            iterations, values, br_values = load_general_index_value_br(path, value_limit, include_average)
            other_values.append((iterations, br_values if use_br_column else values, br_values))
            other_labels.append(labels[key])

    step_values, step_labels, step_lists = [], [], []
    for key, template, steps_index in stepped_series:
        if selected(key):
            # The same step list drives both the files read and the bars drawn.
            step_list = steps_in[steps_index]
            step_values.append(load_sbrs(template, step_list, value_limit, include_average))
            step_labels.append(labels[key])
            step_lists.append(step_list)
    return other_values, other_labels, step_values, step_labels, step_lists


## SRNR part
def plot():
    """Render the SRNR comparison figure: gain (left panel) and exploitability (right panel).

    The experiment is configured by the local constants below; flip the flags in
    ``series_switches`` to choose which methods are drawn. A shared legend is placed
    above both panels and a common x label below them, then the figure is shown.
    """
    print_steps = True
    dynamic_steps = False
    game_name = "iigs5"
    mode = "cfr"
    p = 0.9
    folder = ""

    cdrnr_steps = [1, 2, 3, 4, 5]
    ses_steps = [5]
    comb_steps = ["r"]
    full_steps = ["r"]
    unsafe_steps = ["r"]
    exp_steps = ["r"]
    # Positional order matters: plot_br_srnr_steps indexes this list.
    steps = [cdrnr_steps, ses_steps, comb_steps, full_steps, unsafe_steps, exp_steps]

    # (print_set key, shown?, bars contributed per x position)
    series_switches = [
        ("br", True, 1),
        ("rnr", True, 1),
        ("bne", True, 1),
        ("srnrg", False, 1),
        ("srnru", False, 1),
        ("srnrvf", False, 1),
        ("srnr", True, len(cdrnr_steps)),
        ("comb", False, len(comb_steps)),
        ("ses", True, len(ses_steps)),
        ("full", False, len(full_steps)),
        ("unsafe", False, len(unsafe_steps)),
        ("exp", False, len(exp_steps)),
    ]
    print_set = {key for key, shown, _ in series_switches if shown}
    bar_count = sum(n_bars for _, shown, n_bars in series_switches if shown)

    # Alternative finer scale for both panels: [0, 0.25, 0.5, 0.75, 1]
    y_ticks_gain = [0, 1, 2, 3, 4]
    y_ticks_expl = [0, 1, 2, 3, 4]

    plt.rcParams.update({'font.size': 20, 'font.family': 'Times New Roman'})
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 3))
    fig.subplots_adjust(bottom=0.23, left=0.07, right=0.99, top=0.84, wspace=0.2)

    # Sample bar_count colours evenly from the colormap, then interleave even and odd
    # samples so neighbouring bars get visually distant colours.
    cmap = plt.get_cmap('CMRmap')
    sample_points = list(np.linspace(0, cmap.N, bar_count + 1))[:-1]
    interleaved = sample_points[0::2] + sample_points[1::2]
    colours = [cmap(int(point)) for point in interleaved]
    for panel in (ax1, ax2):
        panel.set_prop_cycle(color=colours)

    shared = {
        "dynamic_steps": dynamic_steps,
        "value_limit": 9,
        "include_average": True,
        "mode": mode,
        "print_set": print_set,
        "folder": folder,
        "print_steps": print_steps,
    }
    plot_br_srnr_steps_gain(game_name, steps, p, ax=ax1, yticks=y_ticks_gain, **shared)
    plot_br_srnr_steps_expl(game_name, steps, p, ax=ax2, yticks=y_ticks_expl, **shared)

    # The legend belongs to the current (right) axes but is stretched over both panels.
    plt.legend(bbox_to_anchor=(-1.2, 1.02, 2.2, .102), loc='lower left',
               ncol=bar_count, mode="expand", borderaxespad=0., handlelength=1, handletextpad=0.3, borderpad=0.1)
    # Invisible full-figure axes that only carries the common x label.
    fig.add_subplot(111, frameon=False)
    plt.tick_params(labelcolor='none', which='both', top=False, bottom=False, left=False, right=False)
    plt.xlabel("CFR iterations of opponent's strategy (p=" + str(p) + ")")
    plt.show()


if __name__ == "__main__":
    # To regenerate the per-gadget split files before plotting, also call
    # load_data_from_different_setup_and_save_them() here.
    plot()
