import matplotlib.pyplot as plt
import numpy as np
from numpy import ediff1d, arange

from CFR import CFR
from SequenceNash import SequenceNash
from DataManipulation import *
from ExtensiveGame import ExtensiveGame, Node
from RNR import RNR
from Combination import Combination
import time
from glob import glob
from os import path
from BestNEStatic import BestNashStatic
from GenerateRnRGame import GenerateRNRGame
import scipy.stats as st
from CFRD import CFRD
from RNRD import RNRD
import copy
from SubgameCFR import SubgameCFR
from ExtensiveSubgame import ExtensiveSubgame
from ExtensiveTrunk import ExtensiveTrunk
import random
import os
import scipy
import gurobipy as g
import pylab as pl
from matplotlib import collections as mc
import sklearn.metrics
from Combination import combine_strategies
from queue import PriorityQueue
from disjoint_set import DisjointSet
from sklearn.cluster import SpectralClustering
from GenerateRnRGameNotFixed import GenerateRNRGameNotFixed
from cycler import cycler
from mpl_toolkits.axes_grid1 import make_axes_locatable
from CFRRNR import CFRRNR
# from RPS import RPS
from CDRNR_Gadget_test import CDRNR_G
import multiprocessing as mp
# import NNMatrixTraining as NetworkTraining
from sklearn.preprocessing import normalize
from utils import fibonacci_array

import time
import sys
from sklearn.cluster import KMeans
import matplotlib


def nodes_number(fname):
    """Load the extensive-form game stored in ``fname`` and return its node count.

    :param fname: path of the game file handed to ``ExtensiveGame.load``.
    :return: the ``node_count`` attribute of the loaded game.
    """
    loaded_game = ExtensiveGame()
    loaded_game.load(fname)
    count = loaded_game.node_count
    return count


def test_only_cfrqr(fname, rationality=1, norm=False, iterations=1000, verbose=0):
    """Run only the CFR-QR solver on ``fname``; nothing is evaluated, printed or returned here.

    :param fname: path of the game file.
    :param rationality: rationality forwarded to the solver constructor.
    :param norm: if true use ``CFRQRNORM``, otherwise ``CFRQRCFV``.
    :param iterations: number of solver iterations.
    :param verbose: verbosity level forwarded to ``solve``.
    """
    # Only the selected class name is evaluated, exactly like the original if/else.
    solver_cls = CFRQRNORM if norm else CFRQRCFV
    solver = solver_cls(fname, rationality=rationality)
    solver.solve(verbose=verbose, iterations=iterations)


def test_only_cfrbr(fname, iterations=1000, verbose=0):
    """Run only the CFR-BR solver on ``fname``; nothing is evaluated or returned.

    :param fname: path of the game file.
    :param iterations: number of solver iterations.
    :param verbose: verbosity level forwarded to ``solve``.
    """
    CFRBR(fname).solve(verbose=verbose, iterations=iterations)


def test_only_comb(fname, rationality=1, splits=11, iterations=1000):
    """Solve CFR-BR and CFR-QR on ``fname`` and build the mixture space of their current strategies.

    :param fname: path of the game file.
    :param rationality: rationality for the ``CFRQRCFV`` solver.
    :param splits: number of mixing points passed to ``combination_space``.
    :param iterations: iterations for both solvers.
    :return: ``Combination(qr_strategy, br_strategy, game).combination_space(splits)``.
    """
    br_solver = CFRBR(fname)
    br_solver.solve(iterations)

    qr_solver = CFRQRCFV(fname, rationality=rationality)
    qr_solver.solve(iterations)

    mixer = Combination(qr_solver.strategy, br_solver.strategy, br_solver.game)
    return mixer.combination_space(splits)


def test_only_rqr(fname, rationality=1, splits=11, iterations=1000):
    """Solve RQR on ``fname`` once for every mixing value ``p`` in ``linspace(0, 1, splits)``.

    A single solver instance is reused across all ``p`` values.
    """
    solver = RQR(fname, rationality=rationality)
    mixing_values = np.linspace(0, 1, splits)
    for mix in mixing_values:
        solver.solve(iterations, p=mix)


def test_only_qse(fname, rationality=1):
    """Solve the sequence-form QSE program for ``fname`` at the given rationality."""
    SequenceQSE(fname).solve(rationality=rationality)


def test_cfrqr(fname, rationality=1, norm=False, iterations=1000, save_strategy=False):
    """Solve CFR-QR on ``fname``, print QR/BR values and plot the regret curve.

    :param fname: path of the game file.
    :param rationality: rationality of the quantal-response player.
    :param norm: use ``CFRQRNORM`` with ``normalized_quantal_response`` instead of
        ``CFRQRCFV`` with ``quantal_response``.
    :param iterations: number of solver iterations.
    :param save_strategy: if true, write the average strategy to
        ``data/strategies/leduc_holdem/cfrqrnorm_<r>.str`` or ``cfrqrcfv_<r>.str``.
    """
    solver_cls = CFRQRNORM if norm else CFRQRCFV
    solver = solver_cls(fname, rationality=rationality)
    solver.solve(verbose=3, iterations=iterations)

    # The Nash LP is solved, but its result is not used below.
    reference_nash = SequenceNash(fname)
    reference_nash.solve()

    if save_strategy:
        file_prefix = "cfrqrnorm_" if norm else "cfrqrcfv_"
        save_to_file(solver.average_strategy,
                     "data/strategies/leduc_holdem/" + file_prefix + str(rationality) + ".str")

    # Bound once: the normalized response for the CFRQRNORM variant, the plain one otherwise.
    respond = solver.normalized_quantal_response if norm else solver.quantal_response

    def report(attr_name, label):
        # The strategy attribute is re-read for every call, as the inline version did.
        print(respond(0, getattr(solver, attr_name), rationality=rationality)[0],
              "QR to " + label + " strategy")
        print(solver.best_response(1, getattr(solver, attr_name))[0], "BR to " + label + " strategy")
        print(solver.best_response(1, respond(0, getattr(solver, attr_name), rationality=rationality)[1])[0],
              "BR to QR to " + label + " strategy")

    report("average_strategy", "average")
    print()
    report("strategy", "current")

    plot_convergence_curves([(solver.regret_sum_accumulator, "CFRQR", "r-")], ["Iterations", "Current regret"],
                            "Regret convergence",
                            hlines=[], labels=None, logscale_x=False, logscale_y=False,
                            zeroline=True, vlines=False)


def test_cfr(fname):
    """Solve CFR+ on ``fname`` and print QR/BR values for its average strategy and a BestNE strategy.

    Quantal responses use rationality 1; all values are for player index 0.
    """
    solver = CFR(fname, cfr_plus=True)
    solver.solve(verbose=3, iterations=1000)

    best_ne = BestNE(fname)
    best_ne.solve(rationality=1)

    # Each getter is invoked once per printed value, matching the original call pattern.
    probes = (
        (lambda: solver.average_strategy, ""),
        (best_ne.strategy_in_cfr_format, " to nash"),
    )
    for get_strategy, suffix in probes:
        print(solver.quantal_response(0, get_strategy(), 1)[0], "QR" + suffix)
        print(solver.best_response(0, get_strategy())[0], "BR" + suffix)


def test_plot(fname):
    """Solve CFRQRNORM on ``fname`` while recording progression, then plot it with a log-scale y axis."""
    solver = CFRQRNORM(fname)
    solver.solve(verbose=3, iterations=1000, save_progression=True)

    curves = create_plotable_progression(solver.progression)
    title, labels = data_for_graph_of_convergence_curve("CFR")
    plot_convergence_curves(curves, labels, title, logscale_y=True, logscale_x=False)


def test_qrqr(fname):
    """Solve QRQR (rationality 1) on ``fname`` and print the normalized QR value for players 0 and 1."""
    solver = QRQR(fname, rationality=1)
    solver.solve(1000, verbose=3)

    for player in (0, 1):
        print(solver.normalized_quantal_response(player, solver.average_strategy, 1)[0])


def load_and_print_strategy(dir, game, algorithm, extension):
    """Load a saved strategy and print it along with response values and responses to it.

    The game is read from ``data/<dir>/<game>.<extension>``. The strategy is read from
    ``data/strategies/<dir>_<game>_<algorithm>.str``.
    """
    game_path = "data/" + dir + "/" + game + "." + extension
    strategy_path = "data/strategies/" + dir + "_" + game + "_" + algorithm + ".str"

    evaluator = CFR(game_path)
    # One iteration, presumably just to initialise the solver for evaluation — TODO confirm.
    evaluator.solve(1)
    strategy = load_from_file(strategy_path)

    for player in (1, 0):
        evaluator.print_strategy(strategy, player)
    print(evaluator.best_response(1, strategy)[0])
    print(evaluator.normalized_quantal_response(0, strategy, rationality=1)[0])

    evaluator.strategy = strategy
    evaluator.print_strategy(evaluator.normalized_quantal_response(0, strategy, 1)[1], 0)
    evaluator.print_strategy(evaluator.best_response(1, strategy)[1], 1, compact=True)
    response_to_qr = evaluator.best_response(1, evaluator.normalized_quantal_response(0, strategy, 1)[1])[1]
    evaluator.print_strategy(response_to_qr, 1, compact=True)


def test_seq_qse(fname, rationality=1):
    """Solve ``BestNE`` on ``fname``, print comparison values and save its strategy.

    The function prints four values in order:
    1. the negated ``SequenceNash`` solution;
    2. the BR value of player index 0 against the BestNE strategy;
    3. the QR value of player index 0 against that strategy;
    4. the normalized QR value of player index 0 against that strategy.

    The strategy is written to ``data/strategies/leduc_holdem/best_cfv_nash_1.str``. That path is
    fixed regardless of ``rationality``.
    """
    best_ne = BestNE(fname)
    best_ne.solve(rationality=rationality)
    ne_strategy = best_ne.strategy_in_cfr_format()

    evaluator = CFR(fname)
    evaluator.solve(1)

    nash_lp = SequenceNash(fname)
    nash_lp.solve()
    print(-nash_lp.solution)

    print(evaluator.best_response(0, ne_strategy)[0])
    print(evaluator.quantal_response(0, ne_strategy, rationality=rationality)[0])
    print(evaluator.normalized_quantal_response(0, ne_strategy, rationality=rationality)[0])

    save_to_file(best_ne.strategy_in_cfr_format(), "data/strategies/leduc_holdem/best_cfv_nash_1.str")


def nash_and_nashqr(fname, rationality=1):
    """Solve the sequence-form Nash LP on ``fname`` and print its negated value tagged "nash".

    ``rationality`` is currently unused. A one-iteration CFR solver is still built and run.
    """
    nash_lp = SequenceNash(fname)
    nash_lp.solve()

    evaluator = CFR(fname)
    evaluator.solve(1)

    print(-nash_lp.solution, "nash")


def opt():
    """Print the size of a 4-dimensional strategy grid with step 0.01.

    The grid per dimension is ``linspace(increment, 1, int(1 / increment))``, i.e.
    ``increment, 2 * increment, ..., 1``. Its 4-fold Cartesian product therefore has
    ``len(grid) ** 4`` points: 10**8 for the default step.

    The size is computed arithmetically. Materialising the product, as the previous version
    did via ``itertools.product``, needs many gigabytes of memory just to call ``len``.
    The printed number is unchanged.
    """
    increment = 0.01
    space = np.linspace(increment, 1, int(1 / increment))
    # Same as len(list(itertools.product(space, repeat=4))) without building the product.
    print(len(space) ** 4)


def nash_value_for_michal():
    """Solve the Nash LP for ``data/trees/maze_1.efg`` .. ``maze_11.efg`` and write the values.

    Each line of ``data/trees/values.txt`` has the form ``maze_<i>: <solution>``.
    """
    results = []
    for index in range(1, 12):
        label = "maze_" + str(index)
        tree_path = "data/trees/" + label + ".efg"
        print(tree_path)
        solver = SequenceNash(tree_path)
        solver.solve()
        results.append((label, solver.solution))

    with open("data/trees/values.txt", "w+") as out:
        out.writelines(name + ": " + str(value) + "\n" for name, value in results)


def test_cfrlin(fname, rationality=0.01, iterations=1000):
    """Solve CFRLin on ``fname``, print its strategies and the linear-response value for player index 0."""
    solver = CFRLin(fname, rationality=rationality)
    solver.solve(iterations=iterations, verbose=3)

    precise = {"decimal_points": 8}
    solver.print_strategy(solver.average_strategy, 1, **precise)
    solver.print_strategy(solver.strategy, 1, **precise)
    solver.print_strategy(solver.average_strategy, 0)

    value = solver.linear_response(0, solver.average_strategy, rationality)[0]
    print(value)


def create_cool_graphs_from_times():
    """Plot solver running times from two timing files side by side, with one shared legend.

    Each file under ``data/`` has one whitespace-separated row per series. Row 0 supplies the x
    labels. Later rows are parsed after replacing "," with ".". Only case 0
    (``efg_times.txt``) and case 2 (``gs_nfg_times.txt``) are drawn. The case 1 branch
    (``zs_nfg_times.txt``) is kept for reference.
    """
    plt.rcParams.update({'font.size': 20, 'font.family': 'Times New Roman'})
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 3))
    plt.gcf().subplots_adjust(bottom=0.23, left=0.11, right=0.99, top=0.84, wspace=0.4)
    time_files = {0: "efg_times.txt", 1: "zs_nfg_times.txt"}
    for case, ax in zip([0, 2], [ax1, ax2]):
        fname = time_files.get(case, "gs_nfg_times.txt")
        values = [[] for _ in range(7)]
        with open("data/" + fname, "r") as handle:
            for row, line in enumerate(handle):
                for token in line.split():
                    # Data rows may use a decimal comma; the header row is parsed as-is.
                    values[row].append(float(token.replace(",", ".") if row > 0 else token))
        if case == 0:
            plot_times([(values[2], "COMB", "-"),
                        (values[3], "GA", "dotted"), (values[1], "RQR", "dashed"), (values[5], "NE", "dashdot"),
                        (values[4], "QNE", (0, (3, 1, 1, 1, 1, 1)))],
                       ("Game size [# of Histories]", "Time [s]"), "",
                       labels=values[0], font_size=13, logscale_x=False, x_ticks=[0, 1500, 3000, 4500],
                       y_ticks=[0, 200, 400, 600, 800], ax=ax)
        elif case == 1:
            plot_times([(values[2], "COMB", "-"),
                        (values[3], "GA", "dotted"), (values[1], "RQR", "dashed"), (values[5], "NASH", "dashdot"),
                        (values[4], "QNE", (0, (3, 1, 1, 1, 1, 1)))],
                       ("Game size [# of Actions]", "Time [s]"), "",
                       labels=values[0], font_size=13, logscale_x=False, ax=ax)
        else:
            plot_times([(values[2], "COMB", "-"),
                        (values[3], "GA", "dotted"), (values[1], "RQR", "dashed"), (values[5], "NASH", "dashdot"),
                        (values[4], "QNE", (0, (3, 1, 1, 1, 1, 1))), (values[6], "SE", (0, (3, 1, 3, 1, 1, 1)))],
                       ("Game size [# of Actions]", "Time [s]"), "",
                       labels=values[0], font_size=13, logscale_x=False, limit_y=[-0, 520],
                       x_ticks=[0, 300, 600, 900, 1200], y_ticks=[0, 100, 200, 300, 400, 500], ax=ax)
    # Legend entries are taken from the last axes drawn in the loop.
    handles, labels = ax.get_legend_handles_labels()
    plt.rcParams.update({'font.size': 20, 'font.family': 'Times New Roman'})
    plt.legend(handles, labels, bbox_to_anchor=(-1.4, 1.05, 2.4, .102), loc='lower left',
               ncol=6, borderaxespad=0., handlelength=1.3, handletextpad=0.2, mode="expand", borderpad=0.2)
    plt.show()


def create_cool_graphs_from_sequence(fname):
    """Load a list of regret curves from ``fname``, print it and plot it on a log-scaled x axis.

    Curve ``i`` gets legend label ``legend[i]`` and colour ``colours[i]``. Exactly 14 curves are
    expected, presumably in the order written by ``create_regret_progression`` (TODO confirm).
    """
    curves = load_from_file(fname)
    print(len(curves))
    print(curves)
    legend = ["Mini", "", "", "Small", "", "", "Bigger", "", "", "Big", "", "", "Leduc Hold'em", "One card poker"]
    colours = ["r", "r", "r", "c", "c", "c", "y", "y", "y", "m", "m", "m", "b", "g"]
    styled = []
    for idx, curve in enumerate(curves):
        styled.append((curve, legend[idx], colours[idx]))
    plot_convergence_curves(styled,
                            ("Iteration", "Current strategy regret"), "Regret convergence curves for EFGs",
                            font_size=16, logscale_x=True)


def measure_time(function, *args, **kwargs):
    """Call ``function(*args, **kwargs)`` and return how long the call took, in seconds.

    Uses ``time.perf_counter`` (monotonic, highest available resolution) instead of the wall clock
    ``time.time``. The wall clock can jump when the system clock is adjusted, which would corrupt
    the measurement. The function's return value is discarded.

    :param function: callable to time.
    :param args: positional arguments forwarded to ``function``. Previously only keyword
        arguments could be passed; existing keyword-only callers are unaffected.
    :param kwargs: keyword arguments forwarded to ``function``.
    :return: elapsed time in seconds as a float.
    """
    start = time.perf_counter()
    function(*args, **kwargs)
    return time.perf_counter() - start


def combine_graph(fname, iterations=1000, rationality=1, splits=11):
    """Compute or plot COMB/RQR utilities over the mixing parameter ``p`` for one game.

    Results are cached in ``data/comb_values/<game><rationality>[_spl<splits>].val``. The
    ``_spl`` suffix is used only when ``splits != 11``.

    If the cache file is missing, several solvers are run and the following values are saved:
    - the QR and BR values of player index 0 against the QSE and Nash strategies;
    - the same values for every COMB mixture;
    - the same values for every RQR solution with ``p`` in ``linspace(0, 1, splits)``.

    If the cache file exists, the cached values are loaded, printed and plotted instead.
    """
    game_name = fname[fname.rfind("/") + 1:].replace(".efg", "")
    print(game_name)

    split_suffix = "_spl" + str(splits) if splits != 11 else ""
    values_path = "data/comb_values/" + game_name + str(rationality) + split_suffix + ".val"

    if not path.exists(values_path):
        nash_solver = CFRBR(fname)
        nash_solver.solve(iterations)
        nash_str = nash_solver.average_strategy

        qr_solver = CFRQRCFV(fname, rationality=rationality)
        qr_solver.solve(iterations)
        qne_str = qr_solver.strategy

        exact_nash = SequenceNash(fname)
        exact_nash.solve(1)
        best_nash = exact_nash.strategy_in_cfr_format()

        print("Before QSE")
        qse_solver = SequenceQSE(fname)
        qse_solver.solve(rationality=rationality)
        qse_str = qse_solver.strategy_in_cfr_format()

        mixer = Combination(qne_str, nash_str, qr_solver.game)

        qse_qr = qr_solver.quantal_response(0, qse_str, rationality)[0]
        nash_qr = qr_solver.quantal_response(0, best_nash, rationality)[0]
        qse_br = qr_solver.best_response(0, qse_str)[0]
        nash_br = qr_solver.best_response(0, best_nash)[0]

        qr, br = [], []
        for mixed in mixer.combination_space(splits=splits):
            qr.append(qr_solver.quantal_response(0, mixed, rationality)[0])
            br.append(qr_solver.best_response(0, mixed)[0])

        qrr, brr = [], []
        restricted = RQR(fname, norm=False, rationality=rationality)
        for p in np.linspace(0, 1, splits):
            print(p)
            restricted.solve(iterations, p=p)
            qrr.append(qr_solver.quantal_response(0, restricted.average_strategy, rationality)[0])
            brr.append(qr_solver.best_response(0, restricted.average_strategy)[0])
        # The first RQR point is replaced by the first COMB point (presumably they coincide at p = 0).
        qrr[0] = qr[0]
        brr[0] = br[0]

        save_to_file((qse_br, qse_qr, nash_br, nash_qr, qr, br, qrr, brr), values_path)
        return

    qse_br, qse_qr, nash_br, nash_qr, qr, br, qrr, brr = load_from_file(values_path)
    print(qr, br)
    print(qrr, brr)

    plot_combine_graph(
        [(qr, "COMB", "r--"), (br, "", "b-"), (qrr, "RQR", "r--"), (brr, "", "b-")],
        ("Mixing parameter p", "Expected utility",), "",
        hlines=[(qse_qr, "GA", "r", "--", "o"), (qse_br, "", "b", "-", "s"),
                (nash_qr, "NASH", "r", "--", "o"), (nash_br, "", "b", "-", "s")], font_size=20,
        labels=list(np.linspace(0, 1, splits)), location='best')


def test_rnqr(fname, p, rationality):
    """Solve RQR with mixing parameter ``p`` and print BR/QR values of its average strategy.

    Each value is printed twice, for player index 0: first as evaluated by the RQR solver itself,
    then by a separately solved CFR-BR instance. Both solvers run 1000 silent iterations.

    :param fname: path of the game file.
    :param p: mixing parameter passed to ``RQR.solve``.
    :param rationality: rationality used both to build the RQR solver and to evaluate the
        quantal-response values.
    """
    # Build RQR with the rationality the results are evaluated at, as the other RQR call sites in
    # this module do. Previously RQR's default rationality was silently used here.
    rqr = RQR(fname, rationality=rationality)
    rqr.solve(1000, p=p, verbose=0)

    cfr = CFRBR(fname)
    cfr.solve(1000, verbose=0)
    print(rqr.best_response(0, rqr.average_strategy)[0])
    print(rqr.quantal_response(0, rqr.average_strategy, rationality)[0])
    print(cfr.best_response(0, rqr.average_strategy)[0])
    print(cfr.quantal_response(0, rqr.average_strategy, rationality)[0])


def save_all_time_and_values(iterations, rationality):
    """Run ``save_time_and_values`` on every file in ``data/mini/`` and save the collected results.

    The output file ``data/after_paper_only_values<dir>.val`` is rewritten after each game,
    presumably so that partial progress survives an interruption.
    """
    for folder in ("mini",):
        print(folder)
        collected = []
        out_path = "data/after_paper_only_values" + folder + ".val"
        for game_file in glob("data/" + folder + "/*"):
            print(game_file)
            collected.append(save_time_and_values(game_file, iterations, rationality))
            save_to_file(collected, out_path)


def save_time_and_values(fname, iterations, rationality):
    """Evaluate strategies from several solvers by their QR and BR values for player index 0.

    All values come from a ``CFRQRCFV`` evaluator. ``*_qr`` values use ``quantal_response``
    at ``rationality``; ``*_br`` values use ``best_response``.

    :return: tuple ``(nash_qr, nash_br, qse_qr, qse_br, qr, br, qrr, brr, qrrr, brrr, bn_qr, bn_br)``.
        - ``nash_*``: values for the exact sequence-form Nash strategy.
        - ``qse_*``: values for the QSE strategy.
        - ``qr``/``br``: lists of values over the default ``combination_space()``.
        - ``qrr``/``brr``: values for the ADAPTQR average strategy.
        - ``qrrr``/``brrr``: ADAPTQR's ``best_comb_qr``/``best_comb_br``.
        - ``bn_*``: values for the BestNE strategy.
    """
    cfr_solver = CFR(fname)
    cfr_solver.solve(iterations)
    nash_str = cfr_solver.average_strategy

    nash_lp = SequenceNash(fname)
    nash_lp.solve(1)
    exact_nash = nash_lp.strategy_in_cfr_format()

    evaluator = CFRQRCFV(fname, rationality=rationality)
    evaluator.solve(iterations)
    qne_str = evaluator.strategy

    def qr_value(strategy):
        return evaluator.quantal_response(0, strategy, rationality)[0]

    def br_value(strategy):
        return evaluator.best_response(0, strategy)[0]

    best_ne = BestNE(fname)
    best_ne.solve(rationality=rationality)
    best_nash = best_ne.strategy_in_cfr_format()
    bn_qr = qr_value(best_nash)
    bn_br = br_value(best_nash)

    qse_solver = SequenceQSE(fname)
    qse_solver.solve(rationality=rationality)
    qse_str = qse_solver.strategy_in_cfr_format()

    mixer = Combination(qne_str, nash_str, evaluator.game)

    qse_qr = qr_value(qse_str)
    nash_qr = qr_value(exact_nash)
    qse_br = br_value(qse_str)
    nash_br = br_value(exact_nash)

    qr, br = [], []
    for mixed in mixer.combination_space():
        qr.append(qr_value(mixed))
        br.append(br_value(mixed))

    adaptive = ADAPTQR(fname, norm=False, rationality=rationality)
    adaptive.solve()
    qrr = qr_value(adaptive.average_strategy)
    brr = br_value(adaptive.average_strategy)
    return (nash_qr, nash_br, qse_qr, qse_br, qr, br, qrr, brr,
            adaptive.best_comb_qr, adaptive.best_comb_br, bn_qr, bn_br)


def save_updated_times():
    """Average stored timing data per game-size folder, then print and plot the negated averages.

    Each record in ``data/times_and_values_<dir>.tav`` contributes three values:
    - RR: ``record[8]``;
    - CB: ``record[9]`` plus two elevenths of ``record[8]``;
    - QS: ``record[10]``.
    """
    rr_avgs, cb_avgs, qs_avgs = [], [], []
    for folder in ["mini", "small", "bigger", "big"]:
        records = load_from_file("data/times_and_values_" + folder + ".tav")
        rr_times, cb_times, qs_times = [], [], []
        for record in records:
            rr_times.append(record[8])
            cb_times.append(record[9] + (record[8] / 11) * 2)
            qs_times.append(record[10])
        rr_avgs.append(-np.average(rr_times))
        cb_avgs.append(-np.average(cb_times))
        qs_avgs.append(-np.average(qs_times))
    print(rr_avgs, cb_avgs, qs_avgs)
    plot_convergence_curves([(rr_avgs, "RR", "r-"), (cb_avgs, "CB", "b-"), (qs_avgs, "QS", "c-")],
                            ("Size", "Time[s]"), "Time")


def plot_values(best_ne=True, loss=True, comb_value=0, show_concrete_game=None, max_gain_one=False, ax=None):
    """Plot per-folder average gain or exploitability of several strategies relative to the NE.

    Reads ``data/after_paper_only_values<dir>.val`` for ``dir`` in mini/small/bigger/big. Each
    record is the tuple produced by ``save_time_and_values``:
    ``(nash_qr, nash_br, qse_qr, qse_br, qr, br, qrr, brr, qrrr, brrr, bn_qr, bn_br)``.

    Records whose NE and GA best-response values (``game[1]``, ``game[3]``) agree within 1e-4
    are skipped. At most 300 records per folder are used. The baseline for every record is the
    NE best-response value.

    :param best_ne: use the best-NE values (``game[10]``, ``game[11]``) instead of the plain NE
        values when their best-response values agree within 1e-6.
    :param loss: if true, plot exploitability (baseline minus BR value). Otherwise plot gain
        (QR value minus baseline).
    :param comb_value: index into ``game[8]`` / ``game[9]`` used for the RQR series.
    :param show_concrete_game: if given, only the record at this position in each folder is used.
    :param max_gain_one: gain mode only. Divide every gain by the record's largest gain, or by 1
        when that largest gain is below 0.01.
    :param ax: matplotlib axes forwarded to ``plot_convergence_curves``.
    """
    # Even rows hold values against the quantal response (gain), odd rows values against the
    # best response (exploitability). Exactly one parity is plotted, i.e. 5 series.
    gain_rows = (0, 2, 4, 6, 8)
    plotted_rows = [j for j in range(10) if (j in gain_rows) != bool(loss)]
    results = [[] for _ in range(14)]
    stdevs_all = [[] for _ in range(14)]
    np.set_printoptions(suppress=True)
    for folder in ["mini", "small", "bigger", "big"]:
        data = load_from_file("data/after_paper_only_values" + folder + ".val")
        # Row pairs (QR, BR): 0/1 NE, 2/3 GA, 4/5 COMB at its best-QR mix,
        # 6/7 last COMB mix (plotted as QNE), 8/9 RQR entries selected by comb_value.
        results_inter = [[] for _ in range(10)]
        count2 = 0
        count = 0
        position = 0
        for game in data:
            if show_concrete_game is not None:
                if position != show_concrete_game:
                    position += 1
                    continue
                position += 1
            thr = 0.0001
            count += 1
            if abs(game[1] - game[3]) < thr:
                continue
            count2 += 1
            if best_ne and abs(game[1] - game[11]) < 10 ** -6:
                results_inter[0].append(game[10])
                results_inter[1].append(game[11])
            else:
                results_inter[0].append(game[0])
                results_inter[1].append(game[1])
            for i in range(2, 4):
                results_inter[i].append(game[i])
            max_index = np.argmax(game[4])
            results_inter[4].append(game[4][max_index])
            results_inter[5].append(game[5][max_index])
            for i in range(6, 8):
                results_inter[i].append(game[i - 2][-1])
            for i in range(8, 10):
                results_inter[i].append(game[i][comb_value])
            if count2 == 300:
                break

        # One row per plotted series, one column per kept record.
        per_record = [[] for _ in range(5)]
        for i in range(len(results_inter[0])):
            baseline = results_inter[1][i]  # NE value against the best response
            best = -np.inf
            if not loss:
                for j in gain_rows:
                    best = max(results_inter[j][i], best)
            for index, j in enumerate(plotted_rows):
                value = results_inter[j][i]
                if loss:
                    per_record[index].append((value - baseline) * -1)
                else:
                    div = 1 if best - baseline < 0.01 else best - baseline
                    if max_gain_one:
                        per_record[index].append((value - baseline) / div)
                    else:
                        per_record[index].append(value - baseline)
        # Standard error of the mean per series.
        std_errors = np.std(per_record, 1) / np.sqrt(len(per_record[0]))
        means = np.average(per_record, 1)
        for index, i in enumerate(plotted_rows):
            results[i].append(means[index])
            stdevs_all[i].append(std_errors[index])
        print(count, count2)

    if loss:
        data_for_graph = [
            (results[5], "COMB", "ro", stdevs_all[5]),
            (results[3], "GA", "gs", stdevs_all[3]),
            (results[9], "RQR", "k>", stdevs_all[9]),
            (results[7], "QNE", "mD", stdevs_all[7]),
        ]
    else:
        data_for_graph = [
            (results[4], "COMB", "ro", stdevs_all[4]),
            (results[2], "GA", "gs", stdevs_all[2]),
            (results[8], "RQR", "k>", stdevs_all[8]),
            (results[6], "QNE", "mD", stdevs_all[6]),
            (results[0], "NASH", "mD", stdevs_all[0]),
        ]

    print(data_for_graph)

    ylabel = "Exploitability" if loss else "Gain"
    title = ""
    plot_convergence_curves(data_for_graph, ("Testing set", ylabel), title,
                            labels=["1", "2", "3", "4"], font_size=20, leading_hlines=True,
                            ax=ax)


def create_regret_progression():
    """Record CFR-QR (rationality 1) regret curves for a fixed set of games and save them.

    Runs 1000 iterations of ``CFRQRCFV`` on three games (00000, 00027, 00095) from each size
    folder, then on Leduc Hold'em and one-card poker. The list of ``regret_sum_accumulator``
    values is stored in ``data/efg_regret_progresion.rpg``.
    """
    iterations = 1000
    progressions = []

    def record(game_path):
        solver = CFRQRCFV(game_path, 1)
        solver.solve(iterations)
        progressions.append(solver.regret_sum_accumulator)

    for folder in ["mini", "small", "bigger", "big"]:
        print(folder)
        for game_id in ("00000", "00027", "00095"):
            record("data/" + folder + "/" + game_id + ".gbt")
    for game_path in ("data/leduc_holdem.efg", "data/one_card_poker.efg"):
        record(game_path)
    save_to_file(progressions, "data/efg_regret_progresion.rpg")


def test_small_rnr(fname, iterations=1000):
    """Run RNR with ``p = 0.4`` against a fixed strategy and plot the convergence of action 0.

    The fixed strategy is ``strategy[0]`` of a randomly initialised CFR solver, with its entry
    0 overwritten to ``[1, 0, 0]``. The plot compares the current and average probability of
    ``[1][0][0]`` over RNR iterations against the value from a Nash solution of the generated
    RNR game ``data/3ND_FM_RNR.efg``.

    :return: ``[rnr.fixed_strategy, rnr.average_strategy[1]]``.
    """
    p = 0.4
    base = CFR(fname)
    base.initialize()
    base.initialize_random_strategy()
    fixed = base.strategy[0]
    fixed[0][0], fixed[0][1], fixed[0][2] = 1, 0, 0

    reference_nash = SequenceNash(fname)
    nash_value = -reference_nash.solve()
    # Constructed but not used further here.
    static_nash = BestNashStatic(fname, fixed)

    rnr = RNR(fname, strategy=fixed)
    rnr.solve(iterations=iterations, p=p, verbose=0)
    # Kept for inspection only; neither array is used below.
    exploitability = -(np.asarray([rnr.best_response(0, rnr.average_strategy)[0]]) - nash_value)
    exploitation = np.asarray([rnr.against_fixed(rnr.average_strategy)]) - nash_value

    history = rnr.strategy_accumulator
    current = [[], [], []]
    average = [[], [], []]
    for cur_snapshot, avg_snapshot in zip(history['cur'], history['avg']):
        for action in range(3):
            current[action].append(cur_snapshot[1][0][action])
            average[action].append(avg_snapshot[1][0][action])

    generator = GenerateRNRGame(fname, 0, fixed)
    rnr_game_path = "data/3ND_FM_RNR.efg"
    generator.generate(rnr_game_path, p)
    rnr_nash = SequenceNash(rnr_game_path)
    rnr_nash.solve()
    optimal_action = rnr_nash.strategy_in_cfr_format()[1][0][0]

    plot_convergence_curves([(current[0], "Current action", "b-",), (average[0], "Average action", "r-")],
                            axis_labels=["Iteration", "Action"], name="Actions 3ND_FM_RNR",
                            hlines=[(optimal_action, "Optimal action", "g", "-")])

    return [rnr.fixed_strategy, rnr.average_strategy[1]]


def test_rnr(fname, iterations=1000, splits=11, equal=True, is_opponent_cfr=False):
    """Sweep RNR over ``p`` in ``linspace(0, 1, splits)`` against a random fixed strategy and plot it.

    For every ``p`` a fresh ``RNR`` is solved and two values are recorded, relative to
    ``nash_value`` (the negated ``SequenceNash.solve()`` result):
    - exploitability: ``nash_value`` minus the ``best_response(0, ...)`` value;
    - exploitation: the ``against_fixed`` value minus ``nash_value``.

    Both arrays are printed and passed to ``plot_exploitability_curves``.

    :param fname: path of the game file.
    :param iterations: RNR iterations per ``p``.
    :param splits: number of ``p`` values; must be at least 1.
    :param equal: forwarded to ``plot_exploitability_curves``. It was previously ignored and
        ``True`` was always passed.
    :param is_opponent_cfr: forwarded to ``RNR``.
    :return: ``[fixed_strategy, average_strategy[1]]`` of the RNR solved for the last ``p``.
    :raises ValueError: if ``splits`` is smaller than 1. Previously this failed with
        ``UnboundLocalError`` at the return statement, after all solves had run.
    """
    if splits < 1:
        raise ValueError("splits must be at least 1, got {}".format(splits))
    cfr = CFR(fname)
    cfr.initialize()
    cfr.initialize_random_strategy()
    strategy = cfr.strategy[0]
    seq_nash = SequenceNash(fname)
    nash_value = -seq_nash.solve()
    bnes = BestNashStatic(fname, strategy)
    # Solved but not used below.
    bnes_value = -bnes.solve()
    exploitation = []
    exploitability = []
    for p in np.linspace(0, 1, splits):
        print(p)
        rnr = RNR(fname, strategy=strategy, is_opponent_cfr=is_opponent_cfr)
        rnr.solve(iterations=iterations, p=p, verbose=0)
        exploitability.append(rnr.best_response(0, rnr.average_strategy)[0])
        exploitation.append(rnr.against_fixed(rnr.average_strategy))

    exploitation = np.asarray(exploitation) - nash_value
    exploitability = -(np.asarray(exploitability) - nash_value)

    print(exploitability)
    print(exploitation)

    plot_exploitability_curves([(exploitability, "RNR curve", "b-",), (exploitation,)],
                               axis_labels=["Exploitability", "Exploitation"], name="RNR curve" + fname, zerolines=True,
                               equal=equal)
    return [rnr.fixed_strategy, rnr.average_strategy[1]]


def test_best_ne_static():
    """Solve ``BestNashStatic`` on one-card poker with QRQR's ``strategy[1]`` held fixed.

    Prints the result of ``solve()`` and then the solver's ``obj.x``.
    """
    game_path = "data/one_card_poker.efg"
    qr_solver = QRQR(game_path)
    qr_solver.solve()
    static_nash = BestNashStatic(game_path, qr_solver.strategy[1])
    print(static_nash.solve())
    print(static_nash.obj.x)


def excel_test(game_name, sheet_name, strategy=None):
    """Generate sheet ``sheet_name`` for ``game_name`` via ``ExcelGenerate``, optionally with a strategy."""
    ExcelGenerate(game_name).generate(sheet_name, strategy=strategy)


def test_rnr_creation(fname_in, fname_out, splits, rnr_iterations, generate_new, is_opponent_cfr, strategy=None):
    """Evaluate LP-optimal strategies of generated RNR games in the original game.

    For each p in ``np.linspace(0, 1, splits)`` the RNR game at
    ``fname_out.format(p)`` is (optionally re)generated from ``fname_in`` with
    weights ``[1 - p, p]``, solved with SequenceNash, and its strategy is scored
    in the original game with ``best_response(1, ...)[0]`` and
    ``against_fixed(...)``. Scores are printed, shifted by the game value and
    plotted.

    :param fname_in: original game file.
    :param fname_out: format string with one field that receives p.
    :param splits: number of p values in [0, 1].
    :param rnr_iterations: unused; kept for backward compatibility.
    :param generate_new: regenerate the RNR game files before solving them.
    :param is_opponent_cfr: unused; kept for backward compatibility.
    :param strategy: fixed strategy for player 0; a random one is drawn when None.
    """
    if strategy is None:
        cfr = CFR(fname_in)
        cfr.initialize()
        cfr.initialize_random_strategy()
        strategy = cfr.strategy[0]
    rnr_original = RNR(fname_in, strategy, cfr_player=0)
    rnr_original.solve(iterations=1)
    seq_nash_in = SequenceNash(fname_in)
    nash_value = -seq_nash_in.solve(0)

    p_values = np.linspace(0, 1, splits)
    opt_exploitability = []
    opt_exploitation = []
    for p in p_values:
        print("{:0.2f}".format(p))
        game_file = fname_out.format(p)
        if generate_new:
            # A fresh generator per p, as before, in case generate() keeps per-call state.
            GenerateRNRGame(fname_in, 1, [strategy]).generate(game_file, [1 - p, p])
        seq_nash = SequenceNash(game_file)
        # The LP must be solved before a strategy can be read from it
        # (the other SequenceNash uses in this file also solve first).
        seq_nash.solve()
        optimal_strategy = seq_nash.strategy_in_cfr_format()
        opt_exploitability.append(rnr_original.best_response(1, optimal_strategy)[0])
        opt_exploitation.append(rnr_original.against_fixed(optimal_strategy))
        evalueate_strategy([optimal_strategy[0], strategy], 0, fname_in)

    for p, exploitability, exploitation in zip(p_values, opt_exploitability, opt_exploitation):
        print("With p = {:0.3f}".format(p))
        print("{:0.3f} - Optimal exploitability".format(exploitability))
        print("{:0.3f} - Optimal exploitation".format(exploitation))

    # Build the two curves directly. The previous 4-row list had two empty rows,
    # and np.asarray rejects such ragged input on current NumPy.
    exploitability_curve = nash_value - np.asarray(opt_exploitability, dtype=float)
    exploitation_curve = np.asarray(opt_exploitation, dtype=float) - nash_value
    print(exploitability_curve)
    print(exploitation_curve)

    figure = plt.figure(figsize=(16, 9))
    axes = figure.add_subplot(111)
    axes.plot(exploitability_curve, label="Opt br")
    axes.plot(exploitation_curve, label="Opt fix")
    axes.legend()
    plt.show()
    figure.clf()


def test_dbr_creation(fname_in, fname_out, splits, rnr_iterations):
    """Evaluate LP solutions of games built by GenerateDbrGame against a random fixed strategy.

    A random strategy for player 0 is drawn once from CFR. For each p in
    ``np.linspace(0, 1, splits)`` a game is generated into
    ``fname_out.format(p)`` with ``GenerateDbrGame(fname_in, 0, strategy, {})``.
    It is solved with SequenceNash, and the LP strategy is scored in the
    original game with ``rnr_original.best_response(0, ...)[0]`` and
    ``rnr_original.against_fixed(...)``. Both scores are printed per p. Nothing
    is returned, and the normalisation/plotting code at the end is commented out.

    :param fname_in: original game file.
    :param fname_out: format string with one field that receives p.
    :param splits: number of p values sampled evenly in [0, 1].
    :param rnr_iterations: not used by the live code.
    """
    cfr = CFR(fname_in)
    cfr.initialize()
    cfr.initialize_random_strategy()
    strategy = cfr.strategy[0]
    # Evaluator in the original game; one solver iteration is run before it is used for scoring.
    rnr_original = RNR(fname_in, strategy)
    rnr_original.solve(iterations=1)
    # nash_value and bnesh_value are only referenced by the commented-out normalisation below.
    seq_nash_in = SequenceNash(fname_in)
    nash_value = -seq_nash_in.solve()
    bnes = BestNashStatic(fname_in, strategy)
    bnesh_value = -bnes.solve()
    # Only rows 2 and 3 are filled; rows 0 and 1 belong to the disabled comparison code.
    results = [[], [], [], []]
    for p in np.linspace(0, 1, splits):
        print("{:0.1f}".format(p))
        # Unlike test_rnr_creation there is no generate_new switch: the game is generated for every p.
        generate = GenerateDbrGame(fname_in, 0, strategy, {})
        generate.generate(fname_out.format(p), p)
        seq_nash = SequenceNash(fname_out.format(p))
        seq_nash.solve()
        optimal_strategy = seq_nash.strategy_in_cfr_format()
        results[2].append(rnr_original.best_response(0, optimal_strategy)[0])
        results[3].append(rnr_original.against_fixed(optimal_strategy))
    for i, p in enumerate(np.linspace(0, 1, splits)):
        print("With p = {:0.1f}".format(p))
        print("{:0.3f} - Optimal exploitability".format(results[2][i]))
        print("{:0.3f} - Optimal exploitation".format(results[3][i]))
    # results = np.asarray(results)
    # results[0] = -results[0] + nash_value
    # results[1] = results[1] - bnesh_value
    # results[2] = -results[2] + nash_value
    # results[3] = results[3] - bnesh_value
    # plot_exploitability_curves([(results[0], "RNR curve", "bo-",), (results[1],)],
    #                            axis_labels=["Exploitability", "Exploitation"], name="RNR curve" + fname_in,
    #                            zerolines=True,
    #                            equal=True, save="images/rnr/rnr.png")
    # plot_exploitability_curves([(results[2], "Optimal", "bo-",), (results[3],)],
    #                            axis_labels=["Exploitability", "Exploitation"], name="Optimal curve" + fname_in,
    #                            zerolines=True,
    #                            equal=True, save="images/rnr/opt.png")
    # plot_exploitability_curves([(results[0], "RNR curve", "bo:",), (results[1],),
    #                             (results[2], "Optimal", "ro--",), (results[3],)],
    #                            axis_labels=["Exploitability", "Exploitation"],
    #                            name="RNR curves solved by CFR and LP" + fname_in,
    #                            zerolines=True,
    #                            equal=True, save="images/rnr/both.png")


def test_adapt_qr(fname, rationality=1, iterations=2000, splits=11):
    """Run ADAPTQR on one game and plot its progression against stored reference values.

    Reference values are unpacked from
    ``data/comb_values/<game_name><rationality>[_spl<splits>].val``. The
    ``_spl`` suffix is present only when ``splits != 11``. ADAPTQR is solved
    with ``save_progression=True``. The first four entries of every recorded
    progression step become four curves, which are plotted together with the
    stored curves and a set of horizontal reference lines.

    :param fname: game file; its base name (without ``.efg``) selects the value file.
    :param rationality: passed to ADAPTQR and used in the value-file name.
    :param iterations: number of ADAPTQR iterations.
    :param splits: only used to choose the value file.
    """
    # Drop everything up to the last "/" and remove ".efg" (every occurrence, not only a suffix).
    game_name = fname[fname.rfind("/") + 1:].replace(".efg", "")

    # NOTE(review): the meaning of each of the 8 stored values is inferred from the names and
    # plot labels only; confirm against the code that writes these .val files.
    (qse_br, qse_qr, nash_br, nash_qr, qr, br, qrr, brr) = load_from_file(
        "data/comb_values/" + game_name + str(rationality) + (
            "_spl" + str(splits) if splits != 11 else "") + ".val")
    adapt = ADAPTQR(fname, rationality=rationality)
    adapt.solve(iterations, verbose=3, save_progression=True)
    # Transpose the progression: results[i] collects entry i of every recorded step.
    results = [[], [], [], []]
    for response in adapt.progression:
        for i in range(4):
            results[i].append(response[i])
    print(adapt.best_comb_br)
    print(adapt.best_comb_qr)
    plot_convergence_curves([(results[0], "BRA", "b-",), (results[1], "BRC", "r--"), (results[2], "QRA", "g-",),
                             (results[3], "QRC", "y--"), (qr, "COMB", "r-o"), (br, "COMB", "b-s"),
                             (qrr, "CRQR", "r--o"), (brr, "CRQR", "b--s")],
                            axis_labels=["Iteration", "Value"], name=game_name + "adaptive p",
                            hlines=[(qse_qr, "GA", "r", "-.", "o"), (qse_br, "GA", "b", "-.", "s"),
                                    (nash_qr, "NASH", "r", ":", "o"), (nash_br, "NASH", "b", ":", "s"),
                                    (adapt.best_comb_qr[1], "ADAPTQR", "c", ":", "o"),
                                    (adapt.best_comb_br[1], "ADAPTBR", "m", ":", "s"),
                                    (adapt.best_comb_qr[-1], "ADAPTQR", "y", ":", "o"),
                                    (adapt.best_comb_br[-1], "ADAPTBR", "g", ":", "s")])

    # plot_convergence_curves(
    #     [(qr, "COMB", "r-o"), (br, "COMB", "b-s"), (qrr, "CRQR", "r--o"), (brr, "CRQR", "b--s")],
    #     ("p", "Expected value for player 1",), game_name + " for all p",
    #     hlines=[(qse_qr, "GA", "r", "-.", "o"), (qse_br, "GA", "b", "-.", "s"),
    #             (nash_qr, "NASH", "r", ":", "o"), (nash_br, "NASH", "b", ":", "s")], font_size=16,
    #     labels=["{:.1f}".format(x) for x in np.linspace(0, 1, 11)], location=(1.02, 0))


def save_adaptqr_values_dir(iterations, rationality):
    """Run ADAPTQR on the first 101 games of data/big and checkpoint the results.

    Files are taken in sorted name order and games with index 0..100 inclusive
    are evaluated, matching the "from_0_to_100" output name. The accumulated
    list is saved again after every game, so partial results survive an
    interrupted run.
    """
    for directory in ["big"]:
        print(directory)
        acc = []
        # glob() returns files in arbitrary order; sort so the 0..100 range is reproducible.
        for i, fname in enumerate(sorted(glob("data/" + directory + "/*"))):
            # The old counter was only incremented inside this branch, so the limit never triggered.
            if i > 100:
                break
            print(fname)
            acc.append(save_adaptqr_values(fname, iterations, rationality))
            save_to_file(acc, "data/adaptqr_values_comb_from_0_to_100" + directory + ".avs")


def save_adaptqr_values(fname, iterations, rationality):
    """Solve ADAPTQR on one game and return its values with the elapsed wall-clock seconds appended.

    :return: ``adapt.return_values() + (seconds,)``.
    """
    started_at = time.time()
    solver = ADAPTQR(fname, rationality=rationality)
    solver.solve(iterations, verbose=0, save_progression=False)
    elapsed = time.time() - started_at
    values = solver.return_values()
    return values + (elapsed,)


def rnrd_tests(fname, p=0, rnr_iterations=1000, subgame_iterations=100, trunk_iterations=100, average_from=None,
               cfr_player=1, cfr_verbose=0, cfrd_verbose=1, compact=False):
    """Compare RNR-D, RNR-D extended by full-game RNR, and plain RNR on one decomposed game.

    Reads ``fname + "_strategy.strat"`` (the fixed strategy) and ``fname + ".efg"``.
    In compact mode it also reads ``fname + "_subgames.sg"``, only to report the
    subgame count. RNR-D's average strategy for ``cfr_player`` is passed as
    ``fixed`` to a full-game RNR run. That run and a plain RNR run are each
    reported as ``against_fixed`` value and best-response value for player
    ``1 - cfr_player``.

    :param average_from: RNR-D averaging start; defaults to half of ``trunk_iterations``.
    :param compact: print one summary line instead of the verbose report.
    """
    strategy = load_from_file(fname + "_strategy.strat")

    if average_from is None:
        average_from = int(trunk_iterations / 2)
    if not compact:
        print("RNR-D")
    rnrd = RNRD(fname, p=p, strategy=copy.deepcopy(strategy), cfr_player=cfr_player)
    rnrd.solve(trunk_iterations, subgame_iterations=subgame_iterations, average_from=average_from, verbose=cfrd_verbose)
    # Keep only the learning player's average strategy; the other slot stays empty.
    if cfr_player == 1:
        rnrd_strategy = [{}, rnrd.average_strategy[1]]
    else:
        rnrd_strategy = [rnrd.average_strategy[0], {}]
    if not compact:
        print(rnrd_strategy)
        print("RNR-D-Extension")
    rnrwf = RNR(fname + ".efg", strategy=strategy, cfr_player=cfr_player)
    rnrwf.solve(rnr_iterations, p=p, fixed=rnrd_strategy, verbose=cfr_verbose)
    rnrd_strategy = rnrwf.average_strategy
    if not compact:
        print("RNR")
    rnr = RNR(fname + ".efg", strategy=strategy, cfr_player=cfr_player)
    rnr.solve(rnr_iterations, p=p, verbose=cfr_verbose)

    seq_nash = SequenceNash(fname + ".efg")
    game_value = -seq_nash.solve()

    opponent = 1 - cfr_player
    if not compact:
        print("\nRNR", rnr.against_fixed(rnr.average_strategy), end=", ")
        print(rnr.best_response(opponent, rnr.average_strategy)[0])
        rnr.print_strategy(rnr.average_strategy, cfr_player, compact=True, decimal_points=6)
        print("RNRD", rnr.against_fixed(rnrd_strategy), end=", ")
        # Only the value, like the other reports (previously the whole (value, strategy) tuple was printed).
        print(rnrwf.best_response(opponent, rnrd_strategy)[0])
        rnr.print_strategy(rnrd_strategy, cfr_player, compact=True, decimal_points=6)
        rnr.print_strategy(rnrd_strategy, opponent, compact=True, decimal_points=6)
    else:
        print("RNR", rnr.against_fixed(rnr.average_strategy), end=", ")
        print(rnr.best_response(opponent, rnr.average_strategy)[0], end=' ')
        print("RNRD", rnr.against_fixed(rnrd_strategy), end=", ")
        print(rnr.best_response(opponent, rnrd_strategy)[0], end=" ")
        print("Game value", game_value, end=" ")
        print("Subgames: ", len(load_from_file(fname + "_subgames.sg")), "\n")

    print(rnrd.average_root_counterfactual_values)
    print(rnrd.subgame_root_counterfactual_values)


def decompose_game_at_depth_two(fname, return_subgames=True):
    """Split ``data/<fname>.efg`` into one subgame per child of the root.

    Each root child contributes one group of subgame roots: its own children.
    With ``return_subgames`` the groups are saved as ExtensiveSubgame objects to
    ``data/<fname>_subgames.sg``; otherwise the trunk is saved to
    ``data/<fname>_trunk.trunk``.
    """
    game = ExtensiveGame()
    game.load("data/" + fname + ".efg")
    root_groups = [list(child.children) for child in game.root.children]
    if not return_subgames:
        save_to_file(ExtensiveTrunk(game, root_groups), "data/" + fname + "_trunk.trunk")
        return
    save_to_file([ExtensiveSubgame(group) for group in root_groups], "data/" + fname + "_subgames.sg")


def decompose_game_at_depth_one(fname, return_subgames=True):
    """Make a single subgame whose roots are all children of the root of ``data/<fname>.efg``.

    With ``return_subgames`` the subgame list is saved to ``data/<fname>_subgames.sg``;
    otherwise the trunk is saved to ``data/<fname>_trunk.trunk``.
    """
    game = ExtensiveGame()
    game.load("data/" + fname + ".efg")
    root_groups = [list(game.root.children)]
    if not return_subgames:
        save_to_file(ExtensiveTrunk(game, root_groups), "data/" + fname + "_trunk.trunk")
        return
    save_to_file([ExtensiveSubgame(group) for group in root_groups], "data/" + fname + "_subgames.sg")


def _cfrd_strategy_path(fname, kind, cfrd_iterations, subgame_iterations, full_iterations, suffix="leduc_holdem"):
    """Build the output path for one strategy saved by ``cfrd_tests``.

    ``suffix`` is removed from the end of ``fname`` when present. This replaces
    ``str.rstrip(suffix)``, which strips any trailing characters that occur in
    ``suffix`` (e.g. it turned "data/model" into "data/").
    """
    base = fname[:-len(suffix)] if suffix and fname.endswith(suffix) else fname
    return (base + "cfrd_strategies/" + kind + "_strategy_dit_" + str(cfrd_iterations)
            + "_sit_" + str(subgame_iterations) + "_fit_" + str(full_iterations))


def cfrd_tests(fname, cfrd_iterations=200, subgame_iterations=100, average_from=50, cfrd_verbose=0, full_iterations=1000, full_verbose=0):
    """Compare CFR-D, plain CFR, and CFR with player 1 fixed to the CFR-D strategy.

    Prints the sequence-form game value, the root counterfactual values of the
    fixed run, and ``best_response(0, ...)[0]`` for the fixed-run ("CFRD") and
    plain CFR ("CFR") average strategies. All three strategies are then saved
    under ``<fname without "leduc_holdem">cfrd_strategies/``.

    :param fname: game path without the ``.efg`` extension (CFRD takes it as is).
    :param average_from: CFR-D averaging start.
    """
    cfrd = CFRD(fname)
    cfrd.solve(cfrd_iterations, subgame_iterations=subgame_iterations, average_from=average_from, verbose=cfrd_verbose)
    cfrd_strategy = cfrd.average_strategy

    cfr = CFR(fname + ".efg")
    cfr.solve(full_iterations, verbose=full_verbose)
    cfr_strategy = cfr.average_strategy

    # Full-game CFR with player 1's strategy fixed to the CFR-D result.
    cfrwf = CFR(fname + ".efg")
    cfrwf.solve(full_iterations, fixed=[{}, cfrd_strategy[1]], verbose=full_verbose)
    cfrwf_strategy = cfrwf.average_strategy

    seq_nash = SequenceNash(fname + ".efg")
    game_value = -seq_nash.solve(1)

    # Compute each best response once and reuse it (previously computed twice).
    cfr_best_response = cfr.best_response(0, cfr_strategy)[0]
    cfrwf_best_response = cfr.best_response(0, cfrwf_strategy)[0]
    print("Game value", game_value, cfrwf.counterfactual_values[cfrwf.game.root])
    print("\nCFRD", cfrwf_best_response, end=", ")
    print("CFR", cfr_best_response, end=", ")
    # Recompute cfrwf's counterfactual values under its average strategy (not read below).
    cfrwf.strategy = cfrwf.average_strategy
    cfrwf.compute_counterfactual_values()

    for kind, strategy in (("cfrd", cfrd_strategy), ("cfrwf", cfrwf_strategy), ("cfr", cfr_strategy)):
        save_to_file(strategy,
                     _cfrd_strategy_path(fname, kind, cfrd_iterations, subgame_iterations, full_iterations))


def best_nash_test(fname, max_epsilon=3.1, steps=311):
    """Sweep BestNashStatic's epsilon and plot how its solution scores against the fixed strategy.

    For each epsilon in ``np.linspace(0, max_epsilon, steps)`` BestNashStatic is
    solved. Its strategy is scored with an RNR evaluator via
    ``best_response(0, ...)[0]`` (printed) and ``against_fixed(...)`` (printed
    and plotted).

    :param fname: game path without extension; ``_strategy.strat`` and ``.efg`` are appended.
    :param max_epsilon: upper end of the epsilon sweep (default matches the original 3.1).
    :param steps: number of epsilon values (default matches the original 311).
    """
    # Read the strategy once; hand each iteration a fresh copy as the old per-iteration reload did,
    # in case the solvers mutate it.
    base_strategy = load_from_file(fname + "_strategy.strat")
    values = []
    for epsilon in np.linspace(0, max_epsilon, steps):
        strategy = copy.deepcopy(base_strategy)
        best_nash = BestNashStatic(fname + ".efg", strategy=strategy, epsilon=epsilon)
        best_nash.solve()
        rnr = RNR(fname + ".efg", strategy, 1)
        rnr.solve(1)
        # The unused SequenceNash solve that used to run here on every iteration was dropped.
        candidate = best_nash.strategy_in_cfr_format()
        print(rnr.best_response(0, candidate)[0])
        exploitation = rnr.against_fixed(candidate)
        print(exploitation)
        values.append(exploitation)
    plt.plot(values)
    plt.show()


def iigs5():
    """Run 256 CFR iterations (verbose=1) on data/iigs5.efg."""
    CFR("data/iigs5.efg").solve(256, verbose=1)


def iigs6():
    """Run 64 CFR iterations (verbose=1) on data/iigs6.efg."""
    CFR("data/iigs6.efg").solve(64, verbose=1)


def decompose_leduc(return_subgames=True):
    """Decompose Leduc hold'em at the chance nodes that follow player actions.

    Subgame roots are grouped by the action path that leads to them (see
    ``decompose_leduc_step``). With ``return_subgames`` the subgames are saved to
    ``data/leduc_holdem/leduc_holdem_subgames.sg``; otherwise the trunk is saved
    to ``data/leduc_holdem/leduc_holdem_trunk.trunk``.
    """
    game = ExtensiveGame()
    game.load("data/leduc_holdem.efg")
    roots_by_path = {}
    decompose_leduc_step(game.root, roots_by_path, [])
    if return_subgames:
        pieces = [ExtensiveSubgame(group) for group in roots_by_path.values()]
        save_to_file(pieces, "data/leduc_holdem/leduc_holdem_subgames.sg")
    else:
        trunk = ExtensiveTrunk(game, list(roots_by_path.values()))
        save_to_file(trunk, "data/leduc_holdem/leduc_holdem_trunk.trunk")


def decompose_leduc_step(node, subgame_roots, sequence):
    """Collect player-2 nodes reached after at least one action, keyed by the action-index path.

    Nodes with ``player == 3`` stop the walk. A ``player == 2`` node reached with
    an empty path is passed through without extending the path. Any other node
    extends the path with the child's index. Results go into
    ``subgame_roots[tuple(path)]`` as lists of nodes. ``sequence`` is never
    mutated.
    """
    if node.player == 3:
        return
    if node.player == 2:
        if sequence:
            subgame_roots.setdefault(tuple(sequence), []).append(node)
            return
        for child in node.children:
            decompose_leduc_step(child, subgame_roots, sequence)
        return
    for action, child in enumerate(node.children):
        decompose_leduc_step(child, subgame_roots, sequence + [action])


def decompose_rnr_leduc(return_subgames=True):
    """Unfinished: load Leduc hold'em and run the (currently empty) RNR decomposition step.

    ``return_subgames`` is not used yet and nothing is saved or returned.
    """
    # Planned setup from the original draft (disabled):
    # p = 0.5
    # opponent_strategy = load_from_file("data/bad_strategies/leduc_holdem/leduc_holdem_" + str(i) + "_iterations.strat")[1]
    # temp_fname = "data/rnr_leduc.efg"
    # rnr_game = ExtensiveGame()
    # rnr_generator = GenerateRNRGameNotFixed(leduc_file)
    # rnr_generator.generate(temp_fname, [p, 1 - p], 1, [opponent_strategy])
    # rnr_game.load(temp_fname)
    leduc_game = ExtensiveGame()
    leduc_game.load("data/leduc_holdem.efg")
    collected = {}
    decompose_rnr_leduc_step(leduc_game.root, collected, [])


def decompose_rnr_leduc_step(node, subgame_roots, sequence):
    """Placeholder for the RNR Leduc decomposition step: it ignores its arguments and returns None."""
    return None


def decompose_all_small():
    """Decompose data/small_decomp/00000.efg ... 00999.efg at depth 2."""
    for game_name in map('{:05d}'.format, range(1000)):
        auto_decompose_in_depth("data/small", game_name, 2)


def _group_nodes_by_shared_sequences(node_to_sequences):
    """Partition nodes into groups connected through shared player sequences.

    Two nodes share a group when they have the same player-0 or player-1
    sequence, directly or transitively. Groups and their members come out in
    discovery order.

    :param node_to_sequences: mapping node -> (sequence of player 0, sequence of player 1).
    :return: list of groups, each a list of nodes.
    """
    sequences_to_nodes = [{}, {}]
    for node, sequences in node_to_sequences.items():
        for player in range(2):
            sequences_to_nodes[player].setdefault(sequences[player], []).append(node)

    processed = set()
    groups = []
    for node in node_to_sequences:
        if node in processed:
            continue
        processed.add(node)
        group = [node]
        in_group = {node}  # O(1) membership instead of scanning the group list
        pending = [[node_to_sequences[node][0]], [node_to_sequences[node][1]]]
        closed = [set(), set()]
        while pending[0] or pending[1]:
            player = 0 if pending[0] else 1
            sequence = pending[player].pop()
            # Mark sequences as expanded. Previously the closed sets were never
            # filled, so each sequence was re-expanded every time it was pushed.
            if sequence in closed[player]:
                continue
            closed[player].add(sequence)
            for new_node in sequences_to_nodes[player][sequence]:
                if new_node not in in_group:
                    group.append(new_node)
                    in_group.add(new_node)
                    processed.add(new_node)
                    for i in range(2):
                        pending[i].append(node_to_sequences[new_node][i])
        groups.append(group)
    return groups


def auto_decompose_in_depth(directory, name, depth):
    """Decompose ``<directory>_decomp/<name>.efg`` into subgames rooted at ``depth``.

    Nodes at ``depth`` are labelled with both players' sequences by
    ``auto_decompose_step``. Nodes sharing a sequence for either player,
    directly or transitively, form one subgame. The game is saved back to the
    same ``.efg`` path, and the subgames (``_subgames.sg``) and the trunk
    (``_trunk.trunk``) are written next to it with ``save_to_file``.
    """
    stem = directory + "_decomp/" + name
    file_name_open = stem + ".efg"
    file_name_game_out = stem + ".efg"
    file_name_trunk = stem + "_trunk.trunk"
    file_name_subgames = stem + "_subgames.sg"

    print(file_name_open)
    game = ExtensiveGame()
    game.load(file_name_open)

    node_to_sequences = {}
    auto_decompose_step(node=game.root, depth=0, decompose_depth=depth, sequence=['', ''],
                        node_to_sequences=node_to_sequences)
    subgames = _group_nodes_by_shared_sequences(node_to_sequences)
    print("=== Subgames generated")
    game.save_to_file(file_name_game_out)
    print("=== Original game saved")
    save_to_file([ExtensiveSubgame(roots) for roots in subgames], file_name_subgames)
    print("=== Subgames saved")
    trunk = ExtensiveTrunk(game, subgames)
    save_to_file(trunk, file_name_trunk)
    print("=== Trunk saved")


def auto_decompose_step(node, depth, decompose_depth, sequence, node_to_sequences):
    """Label every node at ``decompose_depth`` with both players' sequence strings.

    ``sequence`` holds one string per player. On the way down, a node of player
    0 or 1 appends ``"<i_set>:<action>,"`` to its own player's string. Nodes
    with ``player >= 2`` append nothing. At ``decompose_depth`` the node's own
    information set is appended as ``"<i_set>,"``, so nodes of one information
    set share a label. The result is stored as
    ``node_to_sequences[node] = (seq_player0, seq_player1)``.

    The ':' and ',' delimiters keep labels unambiguous. Plain concatenation made
    "information set 1, action 0" ("10") collide with "information set 10,
    empty history", which merged unrelated subgames.
    """
    if depth == decompose_depth:
        new_sequence = [sequence[0], sequence[1]]
        if node.player < 2:
            new_sequence[node.player] += str(node.i_set) + ","
        node_to_sequences[node] = new_sequence[0], new_sequence[1]
        return
    for i, child in enumerate(node.children):
        new_sequence = [sequence[0], sequence[1]]
        if node.player < 2:
            new_sequence[node.player] += str(node.i_set) + ":" + str(i) + ","
        auto_decompose_step(node=child, depth=depth + 1, decompose_depth=decompose_depth, sequence=new_sequence,
                            node_to_sequences=node_to_sequences)


def test_comb_vs_random(folder, data_type="random", number_of_games=None):
    """Compare combination strategies with a baseline strategy space and plot per-position averages.

    For each game file in ``folder`` (the first ``number_of_games``, or all when
    None), the space from ``test_only_comb`` and the baseline space selected by
    ``data_type`` are walked in lockstep. For every pair the values
    ``quantal_response(0, s, 1)[0]`` and ``best_response(0, s)[0]`` are
    recorded, averaged over games per position and plotted.

    :param data_type: one of "random", "nash", "qrnash", "br_to_uniform", "level_k".
    :raises ValueError: for any other ``data_type``. Previously this surfaced
        as an unbound ``rnd_space`` error after the first game had been partly processed.
    """
    space_builders = {
        "random": random_space,
        "nash": random_nash_space,
        "qrnash": random_qrnash_space,
        "br_to_uniform": nash_br_to_uniform_space,
        "level_k": level_k_space,
    }
    if data_type not in space_builders:
        raise ValueError("Unknown data_type {!r}; expected one of {}".format(data_type, sorted(space_builders)))
    build_space = space_builders[data_type]

    avg_combine_br = []
    avg_combine_qr = []
    avg_random_br = []
    avg_random_qr = []
    # Slicing with None keeps every file.
    for fname in glob(folder + "/*")[:number_of_games]:
        print(fname)
        cfrqr = CFRQRCFV(fname)
        cfrqr.initialize()
        combine_space = test_only_comb(fname)
        rnd_space = build_space(fname)
        combine_br = []
        combine_qr = []
        random_br = []
        random_qr = []
        for combine_strategy, random_strategy in zip(combine_space, rnd_space):
            combine_qr.append(cfrqr.quantal_response(0, combine_strategy, 1)[0])
            random_qr.append(cfrqr.quantal_response(0, random_strategy, 1)[0])
            combine_br.append(cfrqr.best_response(0, combine_strategy)[0])
            random_br.append(cfrqr.best_response(0, random_strategy)[0])
        avg_combine_br.append(combine_br)
        avg_combine_qr.append(combine_qr)
        avg_random_br.append(random_br)
        avg_random_qr.append(random_qr)
    avg_combine_qr = np.average(avg_combine_qr, axis=0)
    avg_combine_br = np.average(avg_combine_br, axis=0)
    avg_random_qr = np.average(avg_random_qr, axis=0)
    avg_random_br = np.average(avg_random_br, axis=0)
    plt.plot(avg_combine_br, label="COMB BR")
    plt.plot(avg_combine_qr, label="COMB QR")
    plt.plot(avg_random_br, label="RAND BR")
    plt.plot(avg_random_qr, label="RAND QR")
    plt.legend(loc='best')
    plt.title("Comparison of our approach with random strategies")
    plt.xlabel("Value of p")
    plt.ylabel("Expected utility")
    plt.show()


def random_space(file_name, splits=11):
    """Return the ``splits``-point combination space between two independently drawn random strategies."""
    solver = CFRBR(file_name)
    solver.initialize()
    solver.initialize_random_strategy()
    # Snapshot the first draw; the next initialize_random_strategy() call may replace or mutate it.
    first_draw = copy.deepcopy(solver.strategy)
    solver.initialize_random_strategy()
    return Combination(first_draw, solver.strategy, solver.game).combination_space(splits)


def random_nash_space(file_name, splits=11):
    """Return the combination space from a random strategy to CFR-BR's strategy after 1000 iterations.

    The random strategy is the first argument to Combination and the solved one the second.
    """
    solver = CFRBR(file_name)
    solver.solve(1000)
    solved = copy.deepcopy(solver.strategy)
    solver.initialize_random_strategy()
    return Combination(solver.strategy, solved, solver.game).combination_space(splits)


def random_qrnash_space(file_name, splits=11):
    """Return the combination space from CFRQRCFV's strategy after 1000 iterations to a random strategy.

    NOTE(review): the argument order (solved first, random second) is the reverse of
    random_nash_space; confirm this is intended when comparing the two spaces per p.
    """
    solver = CFRQRCFV(file_name)
    solver.solve(1000)
    solved = copy.deepcopy(solver.strategy)
    solver.initialize_random_strategy()
    return Combination(solved, solver.strategy, solver.game).combination_space(splits)


def nash_br_to_uniform_space(file_name, splits=11):
    """Return the combination space from a best response to the initial strategy to CFR-BR's solved strategy.

    The initial strategy comes from ``initialize_strategy()`` (presumably uniform, per the
    function name — confirm).
    """
    solver = CFRBR(file_name)
    solver.solve(1000)
    solved = copy.deepcopy(solver.strategy)
    solver.initialize_strategy()
    responder = solver.best_response(1, solver.strategy)[1]
    return Combination(responder, solved, solver.game).combination_space(splits)


def level_k_space(file_name, splits=11):
    """Return the combination space from a level-4 strategy to CFR-BR's solved strategy.

    Starting from ``initialize_strategy()``, best responses are chained for
    players 0, 1, 0, 1; the last one is the first Combination argument.
    """
    solver = CFRBR(file_name)
    solver.solve(1000)
    solved = copy.deepcopy(solver.strategy)
    solver.initialize_strategy()
    level = solver.strategy
    for player in (0, 1, 0, 1):
        level = solver.best_response(player, level)[1]
    return Combination(level, solved, solver.game).combination_space(splits)


def compact_decomp_test(name):
    """Draw and save a random strategy, decompose the game at depth 1 and run a compact RNR-D comparison.

    Works on ``data/<name>_decomp/<name>.efg``.
    """
    stem = "data/" + name + "_decomp/" + name
    solver = CFR(stem + ".efg")
    solver.initialize()
    solver.initialize_random_strategy()
    save_to_file(solver.strategy[0], stem + "_strategy.strat")

    auto_decompose_in_depth("data/" + name, name, 1)

    rnrd_tests(stem, rnr_iterations=500, trunk_iterations=500,
               subgame_iterations=20, average_from=100, p=0, cfr_player=1, cfr_verbose=0, cfrd_verbose=1,
               compact=True)


def safe_cfrbr(fname, iterations=1000, cfr_player=1, verbose=0, subgame_iteration=1000, dl=False, threshold=0.0001):
    """Check that CFR-BR with the decomposed (CFR-BR-D) strategy fixed reproduces the LP value.

    Runs CFRBRD (or CFRDLBRD when ``dl``), plain CFRBR, and CFRBR with slot 1
    fixed to the decomposed average strategy. The first component of the fixed
    run's root counterfactual values is compared with the value returned by
    ``SequenceNash.solve()``. When they differ by more than ``threshold``, a
    diagnostic line is printed.

    :param fname: game path without extension (``.efg`` is appended for the full-game solvers).
    :param verbose: values > 0 print intermediate strategies and values.
    :param threshold: allowed absolute difference (default matches the original 0.0001).
    """
    if dl:
        cfrbrd = CFRDLBRD(fname, cfr_player=cfr_player)
    else:
        cfrbrd = CFRBRD(fname, cfr_player=cfr_player)
    cfrbrd.solve(iterations=iterations, verbose=verbose, subgame_iterations=subgame_iteration)
    if verbose > 0:
        print(cfrbrd.strategy)
        print(cfrbrd.average_strategy)
        print(cfrbrd.average_root_counterfactual_values)

    cfrbr = CFRBR(fname + ".efg", cfr_player=cfr_player)
    cfrbr.solve(iterations=iterations)
    if verbose > 0:
        print(cfrbr.strategy)
        print(cfrbr.average_strategy)
        print()

    # NOTE(review): slot 1 is fixed regardless of cfr_player; confirm this is intended for cfr_player=0.
    cfrbr_fix = CFRBR(fname + ".efg", cfr_player=cfr_player)
    cfrbr_fix.solve(fixed=[{}, cfrbrd.average_strategy[1]], iterations=iterations)
    if verbose > 0:  # was `if verbose:`, inconsistent with the other checks
        print(cfrbr_fix.average_strategy)

    nash = SequenceNash(fname + ".efg")
    solution = nash.solve()
    if verbose > 0:
        print(nash.strategy_in_cfr_format())

    # Index with cfrbr_fix's own root. Each solver holds its own game object,
    # and the old code looked up cfrbr.game.root in cfrbr_fix's table.
    fixed_root_values = cfrbr_fix.counterfactual_values[cfrbr_fix.game.root]
    if abs(fixed_root_values[0] - solution) > threshold:
        print(fname, end=" ")
        print("CFRBR", cfrbr.counterfactual_values[cfrbr.game.root], end=" ")
        print("CFRBRD", cfrbrd.counterfactual_values[cfrbrd.game.root], end=" ")
        print("CFRBRD", cfrbrd.average_strategy_root_cfvs(), end=" ")
        print("NASH", solution, end=" ")
        print("CFRBR FIX:", fixed_root_values)


def nash_time(fname):
    """Print the seconds taken to build and solve SequenceNash on ``fname``, then the game's node count."""
    began = time.time()
    lp_solver = SequenceNash(fname)
    lp_solver.solve()
    finished = time.time()
    elapsed = finished - began
    print(elapsed)
    print(lp_solver.game.node_count)


def cfrqr_time(fname):
    """Print the seconds taken by 1000 CFRQRCFV iterations on ``fname``, then the game's node count."""
    began = time.time()
    quantal_solver = CFRQRCFV(fname)
    quantal_solver.solve(1000)
    finished = time.time()
    elapsed = finished - began
    print(elapsed)
    print(quantal_solver.game.node_count)


def rqr_time(fname):
    """Print the seconds taken by 1000 ADAPTQR iterations on ``fname``, then the game's node count."""
    began = time.time()
    adaptive_solver = ADAPTQR(fname)
    adaptive_solver.solve(1000)
    finished = time.time()
    elapsed = finished - began
    print(elapsed)
    print(adaptive_solver.game.node_count)


def comb_time(fname):
    """Print the seconds to solve CFRQRCFV (1000 iterations) and CFRBR, then build an 11-point combination space."""
    began = time.time()
    quantal_solver = CFRQRCFV(fname)
    quantal_solver.solve(1000)
    best_solver = CFRBR(fname)
    best_solver.solve()
    Combination(quantal_solver.strategy, best_solver.strategy, best_solver.game).combination_space(11)
    finished = time.time()
    print(finished - began)


def rnr_values():
    """Print counterfactual values of two hand-written strategy profiles next to LP-optimal RNR strategies.

    Uses the fixed game ``data/small/01724.gbt`` and its pre-generated RNR
    variants ``data/rnr/small/01724_<p>.efg`` (p formatted with one decimal).
    For the original game, ``print_cfv_values_at_depth_with_strategy`` is called
    with depth argument 2 for both hard-coded profiles. For each p, the RNR game
    is solved with SequenceNash, and the same printer is called with depth
    argument 3 for the LP solution and both hard-coded profiles.
    """
    fname_in = "data/small/01724.gbt"
    fname_out = "data/rnr/small/01724_{:0.1f}.efg"
    splits = 11
    nash_strategy_one = [{0: [1 / 8, 7 / 8], 1: [0, 1], 2: [1 / 7, 6 / 7]},
                         {0: [0.75, 0.25], 1: [0.0, 1.0], 2: [0.0, 1.0]}]
    # NOTE(review): entry 2 of the first dict below is [1/4, 1/4], which sums to 0.5 rather than 1;
    # likely a typo (e.g. [1/4, 3/4] or [3/4, 1/4]) — confirm the intended probabilities.
    nash_strategy_two = [{0: [0, 1], 1: [0.5, 0.5], 2: [1 / 4, 1 / 4]},
                         {0: [0.75, 0.25], 1: [0.0, 1.0], 2: [0.0, 1.0]}]
    # A single CFR iteration is run before the value printer is used.
    cfr_original = CFR(fname_in)
    cfr_original.solve(1)
    cfr_original.print_cfv_values_at_depth_with_strategy(nash_strategy_one, 2)
    cfr_original.print_cfv_values_at_depth_with_strategy(nash_strategy_two, 2)
    for p in np.linspace(0, 1, splits):
        print("#################")
        print("P value: ", p)
        seq_nash = SequenceNash(fname_out.format(p))
        seq_nash.solve()
        optimal_strategy = seq_nash.strategy_in_cfr_format()
        cfr = CFR(fname_out.format(p))
        cfr.solve(1)
        print("Correct values")
        cfr.print_cfv_values_at_depth_with_strategy(optimal_strategy, 3)
        print("Nash one")
        cfr.print_cfv_values_at_depth_with_strategy(nash_strategy_one, 3)
        print("Nash two")
        cfr.print_cfv_values_at_depth_with_strategy(nash_strategy_two, 3)


def test_generate_rnr_game(fname_in, fname_out, fixed_player, strategies, p):
    """Generate an RNR game from ``fname_in`` with the given fixed player and strategies, written to ``fname_out``."""
    GenerateRNRGame(fname_in, fixed_player, strategies).generate(fname_out, p)


def test_br_to_strategy(strategy, fname, player):
    """Print the best-response value computed for a profile containing ``strategy``.

    ``strategy`` is placed in slot ``1 - player`` of CFR's initial profile, and
    ``best_response(1 - player, profile)[0]`` is printed. The value is printed
    as is when ``player == 0`` and negated otherwise.
    """
    solver = CFR(fname)
    solver.initialize()
    other = 1 - player
    profile = solver.strategy
    profile[other] = strategy
    value = solver.best_response(other, profile)[0]
    if player == 0:
        print(value)
    else:
        print(-value)


def rnr_decompose_in_depth(directory, name, depth):
    """Decompose ``<directory>_decomp/<name>.efg`` into four subgames bucketed by information set.

    Nodes at ``depth`` are assigned to bucket ``i_set - 1`` by
    ``rnr_decompose_step``. The game is saved back to the same ``.efg`` path,
    and the subgames (``_subgames.sg``) and trunk (``_trunk.trunk``) are written
    next to it.
    """
    stem = directory + "_decomp/" + name
    file_name_open = stem + ".efg"

    print(file_name_open)
    game = ExtensiveGame()
    game.load(file_name_open)

    buckets = [[] for _ in range(4)]
    rnr_decompose_step(node=game.root, depth=0, decompose_depth=depth,
                       node_to_sequences=buckets)

    print("=== Subgames generated")
    game.save_to_file(stem + ".efg")
    print("=== Original game saved")
    save_to_file([ExtensiveSubgame(roots) for roots in buckets], stem + "_subgames.sg")
    print("=== Subgames saved")
    save_to_file(ExtensiveTrunk(game, buckets), stem + "_trunk.trunk")
    print("=== Trunk saved")


def rnr_decompose_step(node, depth, decompose_depth, node_to_sequences):
    """Collect every node lying exactly ``decompose_depth`` levels below the root.

    A collected node is appended to ``node_to_sequences[node.i_set - 1]``, and its subtree is
    not explored. The root itself (``depth == 0``) is never collected, even when
    ``decompose_depth`` is 0; in that case the whole tree is walked without collecting anything.
    """
    reached_cut = depth != 0 and depth == decompose_depth
    if reached_cut:
        node_to_sequences[node.i_set - 1].append(node)
        return
    for child in node.children:
        rnr_decompose_step(node=child, depth=depth + 1, decompose_depth=decompose_depth,
                           node_to_sequences=node_to_sequences)


def aaai_experiments():
    """Scratch driver for quantal-response experiments; results are only printed.

    Runs, in order:
      1. ADAPTQR on one-card poker (rationality 2).
      2. An RQR sweep over ``p`` in [0, 1] (rationality 10).
      3. A side-by-side print of two saved result files for the "big" dataset.
      4. A rationality sweep on ``data/big/00002.gbt``. It compares CFRQRCFV and ADAPTQRIMP
         strategies computed with an overestimated rationality against a rationality-1 model.
      5. A single-point version of step 4 on ``data/big/00000.gbt``.

    NOTE(review): ADAPTQR, RQR, ADAPTQRIMP and CFRQRCFV are not imported in the visible file
    header — confirm they are in scope before running this.
    """
    # 1. ADAPTQR on one-card poker; print player 0's quantal response value to its average strategy.
    rqr = ADAPTQR("data/one_card_poker.efg", rationality=2)
    rqr.solve(verbose=2)
    print(rqr.quantal_response(0, rqr.average_strategy, 2))
    # 2. RQR sweep: 101 values of p, each solved from scratch.
    rat = 10
    print(rat)
    for p in np.linspace(0, 1, 101):
        rqr = RQR("data/one_card_poker.efg", rationality=rat)
        rqr.solve(p=p)
        rr = rqr.quantal_response(0, rqr.average_strategy, rat)
        print(p, rr[0])
    # 3. Print two saved result files side by side. Note: `dir` shadows the builtin here.
    dir = "big"
    data = load_from_file("data/adaptqr_values_comb_" + dir + "(100up).avs")
    data2 = load_from_file("data/plus_times_and_values_" + dir + ".tav")
    for game, game2 in zip(data, data2):
        print(game, game2)
    # 4. Sweep the assumed rationality (i * rat) and evaluate against the rat-rationality model.
    rat = 1
    fname = "data/big/00002.gbt"
    cfrqr = CFRQRCFV(fname, rationality=rat)
    cfrqr.solve(verbose=0)
    # Running maxima of the QR value (with the matching BR value) for CFRQRCFV / ADAPTQRIMP.
    max_qr = -100
    max_br = -100
    maxa_qr = -100
    maxa_br = -100
    for i in np.linspace(1.1, 3, 29):
        print("========", i)
        cfrqrover = CFRQRCFV(fname, rationality=i * rat)
        cfrqrover.solve(verbose=0)
        over = cfrqr.quantal_response(0, cfrqrover.strategy, rat)
        over_br = cfrqr.best_response(0, cfrqrover.strategy)
        print(over[0], over_br[0])
        if over[0] > max_qr:
            max_qr = over[0]
            max_br = over_br[0]
        aqrover = ADAPTQRIMP(fname, rationality=i * rat)
        aqrover.solve(weak_rat=rat)
        aover = cfrqr.quantal_response(0, aqrover.best_strategy, rat)[0]
        aover_br = cfrqr.best_response(0, aqrover.best_strategy)[0]
        if aover > maxa_qr:
            maxa_qr = aover
            maxa_br = aover_br
        print(aover, aover_br)
    # Baselines: the rat-rationality strategy itself and plain ADAPTQR, then the sweep maxima.
    orig = cfrqr.quantal_response(0, cfrqr.strategy, rat)
    orig_br = cfrqr.best_response(0, cfrqr.strategy)
    print(orig[0], orig_br[0])
    aqr = ADAPTQR(fname, rationality=rat)
    aqr.solve()
    print(cfrqr.quantal_response(0, aqr.average_strategy, rat)[0], cfrqr.best_response(0, aqr.average_strategy)[0])
    print(aqr.best_comb_qr[1], aqr.best_comb_br[1])
    print(max_qr, max_br)
    print(maxa_qr, maxa_br)
    # 5. Single rationality point on another game. cfrqr is only initialized, not solved.
    rat = 1
    fname = "data/big/00000.gbt"
    cfrqr = CFRQRCFV(fname, rationality=rat)
    cfrqr.initialize()
    # cfrqr.solve(verbose=0)
    # Equals np.linspace(1.1, 3, 29)[19] up to rounding — presumably a point picked from the sweep above.
    i = 2.3892857142857142
    print(i)
    cfrqrover = CFRQRCFV(fname, rationality=i * rat)
    cfrqrover.solve(verbose=2)
    over = cfrqr.quantal_response(0, cfrqrover.strategy, rat)
    over_br = cfrqr.best_response(0, cfrqrover.strategy)
    print(over[0], over_br[0])
    aqrover = ADAPTQR(fname, rationality=i * rat)
    aqrover.solve(verbose=2)
    aover = cfrqr.quantal_response(0, aqrover.best_strategy, rat)[0]
    aover_br = cfrqr.best_response(0, aqrover.best_strategy)[0]
    print(aover, aover_br)


def create_exp4_strategies(im_models, fname, player, p=1):
    """Compute one response strategy per opponent model, for use as Exp4 experts.

    For each model, a temporary game ``data/temp.efg`` is generated from ``fname`` with
    ``player`` fixed to that model and weights ``[p, 1 - p]``. The sequence-form LP of that
    game is then solved for ``1 - player``.

    :param im_models: strategies of ``player`` (one per model).
    :param fname: path of the original game file.
    :param player: index of the modelled (fixed) player.
    :param p: first of the two weights given to the game generator.
    :return: list of strategies of player ``1 - player``, in the order of ``im_models``.
    """
    temp_fname = "data/temp.efg"
    responder = 1 - player
    experts = []
    for model in im_models:
        test_generate_rnr_game(fname_in=fname, fname_out=temp_fname, fixed_player=player,
                               strategies=[model], p=[p, 1 - p])
        solver = SequenceNash(temp_fname)
        solver.solve(responder)
        experts.append(solver.strategy_in_cfr_format()[responder])
    return experts


def combine_strategy_by_levels(s1, s2, node, level, player, played=0, s3=None):
    """Splice two strategies of ``player`` together by decision depth.

    At each decision node of ``player``, the count of ``player``'s own earlier decisions on the
    path from ``node`` is checked. If that count is at most ``level``, the node's action
    probabilities come from ``s1``; otherwise they come from ``s2``. Chance (2) and opponent
    nodes are passed through, and any other node type stops the walk.

    :param s1: strategy (info set -> probabilities) used for shallow decisions; deep-copied.
    :param s2: strategy used for deeper decisions. Its lists are referenced, not copied.
    :param node: root of the subtree to process.
    :param level: last own-decision depth (0-based) that still uses ``s1``.
    :param player: index of the player whose strategy is built.
    :param played: internal recursion counter of own decisions so far.
    :param s3: internal accumulator; leave as ``None`` when calling.
    :return: the combined strategy for a top-level call, ``None`` for recursive calls.
        Also ``None`` when ``node`` itself is a terminal/unknown node.
    """
    top_level = s3 is None
    if top_level:
        s3 = copy.deepcopy(s1)

    if node.player == 2 or node.player == 1 - player:
        next_played = played
    elif node.player == player:
        source = s1 if played <= level else s2
        s3[node.i_set] = source[node.i_set]
        next_played = played + 1
    else:
        return None

    for child in node.children:
        combine_strategy_by_levels(s1, s2, child, level, player, next_played, s3)
    return s3 if top_level else None


def mrnr_leduc(opponent_strategy=None, mode=None, mrnr_method=None, rnrp=0):
    """Evaluate MRNR and/or Exp4 against one opponent in Leduc hold'em, playing as player 0.

    :param opponent_strategy: strategy of player 1; a random one is drawn when omitted.
    :param mode: ``"mrnr"`` (default), ``"exp4"``, ``"test_both"`` or ``"test_exp4"``.
    :param mrnr_method: source of the opponent models, ``"clustered"`` (default) or
        ``"hand_crafted"``.
    :param rnrp: weight passed to ``create_exp4_strategies`` as ``p``.
    :return: dict mapping algorithm names to ``(mean, half_width)`` payoff estimates; always
        contains ``"best"`` (the best response to the opponent).
    :raises ValueError: for an unknown ``mode`` or ``mrnr_method``.
    """
    np.random.seed(58)

    if mrnr_method is None:
        mrnr_method = "clustered"

    if mrnr_method == "hand_crafted":
        strategy_names = glob("data/mrnr_leduc/strategies/use_*.plk")
        representative_strategies = [load_from_file(fname) for fname in strategy_names]
    elif mrnr_method == "clustered":
        representative_strategies = load_from_file("data/im_and_em/random_model_strategies/em_models")
    else:
        raise ValueError("Wrong mrnr method selected: " + str(mrnr_method))

    nash_lp = SequenceNash("data/mrnr_leduc/leduc_holdem.efg")
    nash_lp.solve(0)
    nash_strategy = nash_lp.strategy_in_cfr_format()[0]

    nash_lp.reset()
    nash_lp.solve(1)
    opponent_nash = nash_lp.strategy_in_cfr_format()[1]

    exp4_strategies = create_exp4_strategies(representative_strategies, "data/mrnr_leduc/leduc_holdem.efg", 1, p=rnrp)

    raise_fold_strategy = combine_strategy_by_levels(representative_strategies[2], representative_strategies[1], nash_lp.game.root, level=1, player=1)

    # Strategies Creation
    cfr = CFR("data/mrnr_leduc/leduc_holdem.efg")
    cfr.initialize()

    combine = Combination([nash_strategy, representative_strategies[0]], [nash_strategy, representative_strategies[2]], cfr.game)
    combined_strategy = combine.combine_strategies(0.7)

    fname_orig = "data/mrnr_leduc/leduc_holdem.efg"
    fname = "data/mrnr_leduc/leduc_holdem_mrnr.efg"
    player = 0

    repeated_plays = 100000

    game = ExtensiveGame()
    game.load(fname_orig)

    # Alternative fixed opponents: combined_strategy[1 - player], representative_strategies[2],
    # opponent_nash, raise_fold_strategy.
    if opponent_strategy is None:
        cfr.initialize_random_strategy()
        opponent_strategy = cfr.strategy[1 - player]

    full_strategy = [{}, opponent_strategy] if player == 0 else [opponent_strategy, {}]

    best_response_all = cfr.best_response(player, full_strategy)
    best_response = best_response_all[1][player]

    if mode is None:
        mode = "mrnr"

    results = {}

    if mode == "test_exp4":
        results["exp4"] = compute_exp4(repeated_plays, game, player, opponent_strategy, rnr_strategies=exp4_strategies)
        for i, exp4_strategy in enumerate(exp4_strategies):
            # online_play requires a label for its printout.
            results["exp4 part " + str(i)] = online_play(repeated_plays, game, player, opponent_strategy, exp4_strategy, "Exp4 part: ")
    elif mode == "test_both":
        results["exp4"] = compute_exp4(repeated_plays, game, player, opponent_strategy, rnr_strategies=exp4_strategies)
        results["mrnr"] = compute_mrnr(repeated_plays, game, player, opponent_strategy, fname=fname, models=representative_strategies)
    elif mode == "exp4":
        results["exp4"] = compute_exp4(repeated_plays, game, player, opponent_strategy, rnr_strategies=exp4_strategies)
    elif mode == "mrnr":
        results["mrnr"] = compute_mrnr(repeated_plays, game, player, opponent_strategy, fname=fname, models=representative_strategies)
    else:
        raise ValueError("Wrong computation selected: " + str(mode))
    results["best"] = online_play(repeated_plays, game, player, opponent_strategy, best_response, "Best response: ")
    return results


def online_play(repeated_plays, game, player, opponent_strategy, strategy, text="Online play:"):
    """Estimate the average payoff of a fixed strategy by sampling repeated games.

    Each episode walks from ``game.root`` to a terminal node (player 3). Chance nodes (player 2)
    follow ``node.chance``, ``player``'s nodes follow ``strategy`` and all other nodes follow
    ``opponent_strategy``. Terminal values are negated for player 1.

    :param repeated_plays: number of sampled episodes.
    :param game: object exposing ``root``.
    :param player: index (0 or 1) of the evaluated player.
    :param opponent_strategy: info set -> action probabilities for the other player.
    :param strategy: info set -> action probabilities for ``player``.
    :param text: label printed before the result. It is optional because several callers
        omit it.
    :return: ``(mean, half_width)`` of the 95% t-confidence interval of ``player``'s payoff.
    """
    sign = 1 if player == 0 else -1
    full_return = []

    for _ in range(repeated_plays):
        node = game.root
        while node.player != 3:
            if node.player == 2:
                probabilities = node.chance
            elif node.player == player:
                probabilities = strategy[node.i_set]
            else:
                probabilities = opponent_strategy[node.i_set]
            node = np.random.choice(node.children, p=probabilities)
            if node.player == 3:
                full_return.append(node.value * sign)
    m, se = np.mean(full_return), scipy.stats.sem(full_return)
    h = se * scipy.stats.t.ppf((1 + 0.95) / 2., repeated_plays - 1)
    print(text, m, "+-", h)
    return m, h


def normalize_probs(probs):
    """Turn a vector of (possibly negative) weights into a probability distribution.

    If the minimum is negative, all entries are first shifted so the minimum becomes zero.
    The first entry is then replaced by the mass left over by the others,
    ``max(0, 1 - sum(probs[1:]))``, and the vector is rescaled to sum to one.

    :param probs: sequence of numbers (list or array). It is not modified.
    :return: a new float ``numpy`` array that sums to one.
    """
    # Work on a float copy so the caller's array is never written to.
    probs = np.array(probs, dtype=float)
    if np.min(probs) < 0:
        probs -= np.min(probs)
    probs[0] = max(0, 1 - np.sum(probs[1:]))
    return probs / np.sum(probs)


def set_strategy(node, action, player, strategy, depth, chance_count):
    """Overwrite ``player``'s strategy in a subtree with pure (one-hot) choices, in place.

    Every decision node of ``player`` plays ``action``. The exception is two-action nodes,
    where a hard-coded rule (labelled "always call" in earlier code; presumably action 0 is
    "call" in the intended game — confirm) selects action 0 in two cases:
      * after exactly one chance node, when the first child is not terminal;
      * after exactly two chance nodes, when the second child belongs to player 0. This case
        also prints ``depth`` and the node for debugging.

    :param node: root of the subtree (player 3 = terminal, 2 = chance).
    :param action: default action index to play.
    :param player: index of the player whose strategy is overwritten.
    :param strategy: info set -> action probabilities; mutated.
    :param depth: current depth, used only by the debug print.
    :param chance_count: number of chance nodes on the path so far.
    """
    if node.player == 3:
        return

    if node.player == player:
        n_actions = len(strategy[node.i_set])
        chosen = action
        if n_actions == 2:
            after_one_chance = chance_count == 1 and node.children[0].player != 3
            after_two_chances = chance_count == 2 and node.children[1].player == 0
            if after_two_chances:
                print(depth, node)
            if after_one_chance or after_two_chances:
                chosen = 0
        strategy[node.i_set] = [0] * n_actions
        strategy[node.i_set][chosen] = 1

    next_count = chance_count + 1 if node.player == 2 else chance_count
    for child in node.children:
        set_strategy(child, action, player, strategy, depth + 1, next_count)


def test_mrnr():
    """Compare MRNR (or Exp4) with a best response against a fixed opponent in the rmp game.

    We play as player 0. The opponent models are player 1's parts of ``only_head`` and
    ``only_tail``. Exp4's experts are player 0's best responses to those models. Results are
    printed by the called routines.
    """
    only_head = [{0: [0, 1], 1: [1, 0], 2: [0, 1], 3: [1, 0], 4: [0, 1]},
                 {0: [1, 0], 1: [1, 0], 2: [1, 0], 3: [1, 0], 4: [1, 0]}]
    only_tail = [{0: [0, 1], 1: [1, 0], 2: [0, 1], 3: [1, 0], 4: [0, 1]},
                 {0: [0, 1], 1: [0, 1], 2: [0, 1], 3: [0, 1], 4: [0, 1]}]
    head_and_tail = [{0: [0, 1], 1: [1, 0], 2: [0, 1], 3: [1, 0], 4: [0, 1]},
                     {0: [0.5, 0.5], 1: [0.5, 0.5], 2: [0.5, 0.5], 3: [0.5, 0.5], 4: [0.5, 0.5]}]
    equilibrium = [{0: [0.5, 0.5], 1: [1, 0], 2: [0, 1], 3: [1, 0], 4: [0, 1]},
                   {0: [1, 0], 1: [0, 1], 2: [0.5, 0.5], 3: [0, 1], 4: [0.5, 0.5]}]
    exploit_this = [{0: [0.5, 0.5], 1: [1, 0], 2: [0, 1], 3: [1, 0], 4: [0, 1]},
                    {0: [1, 0], 1: [0.5, 0.5], 2: [0, 1], 3: [0, 1], 4: [0.5, 0.5]}]

    fname_orig = "data/rmp/rmp.gbt"
    fname = "data/rmp/rmp_rnr.gbt"
    player = 0

    cfr = CFR(fname_orig)
    cfr.initialize()

    combine = Combination(only_head, only_tail, cfr.game)
    strategy = combine.combine_strategies(0.7)

    repeated_plays = 100000

    game = ExtensiveGame()
    game.load(fname_orig)

    np.random.seed(42)

    models = [only_head[1], only_tail[1]]

    # Exp4 experts: player 0's best response to each opponent model.
    rnr_strategies = [cfr.best_response(0, [{}, model])[1][0] for model in models]

    # Alternative opponents: only_tail, strategy, head_and_tail, equilibrium.
    opponent_strategy = exploit_this

    best_response_to_opp_strategy = cfr.best_response(player, opponent_strategy)[1][player]

    opponent_strategy = opponent_strategy[1 - player]

    exp4 = False

    if exp4:
        compute_exp4(repeated_plays, game, player, opponent_strategy, rnr_strategies=rnr_strategies)
    else:
        compute_mrnr(repeated_plays, game, player, opponent_strategy, fname=fname, models=models)
    # online_play requires a label for its printout.
    online_play(repeated_plays, game, player, opponent_strategy, best_response_to_opp_strategy, "Best response: ")


def mrnr_matrix_big(fname_orig, strat_name, player, temp="temp", mode=None, mrnr_method=None, strategy=None, rnrp=0):
    """Evaluate MRNR and/or Exp4 against one opponent on a game from ``data/nfgs``.

    The opponent models come from ``data/nfgs/random_strategies/<strat_name>/``. MRNR uses
    ``em_optimized`` or ``em_models`` (chosen by ``mrnr_method``), and the Exp4 experts are built
    from ``im_models``. The MRNR game is written to ``data/nfgs/<temp>temp.efg``.

    :param fname_orig: path of the game file.
    :param strat_name: name of the game / strategy set under ``data/nfgs``.
    :param player: index of the player we control.
    :param temp: prefix of the temporary MRNR game file.
    :param mode: ``"mrnr"`` (default), ``"exp4"``, ``"test_both"`` or ``"test_exp4"``.
    :param mrnr_method: ``"optimized"`` (default) or ``"clustered"``.
    :param strategy: full strategy profile whose ``1 - player`` part is the opponent; a random
        profile is drawn when omitted.
    :param rnrp: weight passed to ``create_exp4_strategies`` as ``p``.
    :return: dict of ``(mean, half_width)`` estimates keyed by algorithm name, always including
        ``"best"`` (best response) and ``"nash"``.
    :raises ValueError: for an unknown ``mode`` or ``mrnr_method``.
    """
    np.random.seed(42)
    cfr = CFR(fname_orig)
    cfr.initialize()

    game = ExtensiveGame()
    game.load(fname_orig)

    if mrnr_method is None:
        mrnr_method = "optimized"

    if mrnr_method == "optimized":
        representative_strategies = [{0: k} for k in load_from_file("data/nfgs/random_strategies/" + strat_name + "/em_optimized")]
    elif mrnr_method == "clustered":
        representative_strategies = load_from_file("data/nfgs/random_strategies/" + strat_name + "/em_models")
    else:
        raise ValueError("Wrong mrnr method selected: " + str(mrnr_method))

    im_models = load_from_file("data/nfgs/random_strategies/" + strat_name + "/im_models")
    # The models describe the opponent (1 - player), so the experts returned are strategies for
    # `player`. A hard-coded 1 here was only correct when player == 0.
    exp4_strategies = create_exp4_strategies(im_models, "data/nfgs/" + strat_name + ".efg", 1 - player, p=rnrp)
    fname = "data/nfgs/" + temp + "temp.efg"
    n_p = len(representative_strategies) + 1

    test_generate_rnr_game(fname_orig, fname, 1 - player, representative_strategies, [1 / n_p] * n_p)

    if strategy is None:
        cfr.initialize_random_strategy()
        strategy = cfr.strategy
    opponent_strategy = strategy

    best_response = cfr.best_response(player, opponent_strategy)[1][player]

    nash_computation = SequenceNash(fname_orig)
    nash_computation.solve(player)
    nash_strategy = nash_computation.strategy_in_cfr_format()[player]

    opponent_strategy = opponent_strategy[1 - player]

    repeated_plays = 100000

    if mode is None:
        mode = "mrnr"

    results = {}
    if mode == "test_exp4":
        results["exp4"] = compute_exp4(repeated_plays, game, player, opponent_strategy, rnr_strategies=exp4_strategies)
        for i, exp4_strategy in enumerate(exp4_strategies):
            results["exp4 part " + str(i)] = online_play(repeated_plays, game, player, opponent_strategy, exp4_strategy, "Exp4 part: ")
    elif mode == "test_both":
        results["exp4"] = compute_exp4(repeated_plays, game, player, opponent_strategy, rnr_strategies=exp4_strategies)
        results["mrnr"] = compute_mrnr(repeated_plays, game, player, opponent_strategy, fname=fname, models=representative_strategies)
    elif mode == "exp4":
        results["exp4"] = compute_exp4(repeated_plays, game, player, opponent_strategy, rnr_strategies=exp4_strategies)
    elif mode == "mrnr":
        results["mrnr"] = compute_mrnr(repeated_plays, game, player, opponent_strategy, fname=fname, models=representative_strategies)
    else:
        raise ValueError("Wrong computation selected: " + str(mode))
    results["best"] = online_play(repeated_plays, game, player, opponent_strategy, best_response, "Best response: ")
    results["nash"] = online_play(repeated_plays, game, player, opponent_strategy, nash_strategy, "Nash performance: ")
    return results


def mrnr_matrix():
    """Compare MRNR (or Exp4) with a best response on the 2ND matrix game, playing as player 0.

    The opponent is ``Combination(prob_x, prob_y, ...).combine_strategies(0.7)``, and the
    opponent models are player 1's parts of ``prob_x`` and ``prob_y``. Results are printed by
    the called routines.
    """
    prob_x = [{0: [0.9, 0.1]}, {0: [0.9, 0.1]}]
    prob_y = [{0: [0.1, 0.9]}, {0: [0.1, 0.9]}]

    fname_orig = "data/2ND_mrnr/2ND.gbt"
    fname = "data/2ND_mrnr/2ND_mrnr.gbt"
    player = 0

    cfr = CFR(fname_orig)
    cfr.initialize()

    combine = Combination(prob_x, prob_y, cfr.game)
    strategy = combine.combine_strategies(0.7)

    repeated_plays = 100000

    game = ExtensiveGame()
    game.load(fname_orig)

    np.random.seed(42)

    models = [prob_x[1], prob_y[1]]

    opponent_strategy = strategy

    best_response_to_opp_strategy = cfr.best_response(player, opponent_strategy)[1][player]

    opponent_strategy = opponent_strategy[1 - player]

    exp4 = False
    both = False
    # NOTE(review): Exp4 receives the opponent models themselves as its experts, not
    # responses to them — confirm this is intended before enabling it.
    if both:
        compute_exp4(repeated_plays, game, player, opponent_strategy, rnr_strategies=models)
        compute_mrnr(repeated_plays, game, player, opponent_strategy, fname=fname, models=models)
    elif exp4:
        compute_exp4(repeated_plays, game, player, opponent_strategy, rnr_strategies=models)
    else:
        compute_mrnr(repeated_plays, game, player, opponent_strategy, fname=fname, models=models)
    # online_play requires a label for its printout.
    online_play(repeated_plays, game, player, opponent_strategy, best_response_to_opp_strategy, "Best response: ")


def compute_mrnr(repeated_plays, game, player, opponent_strategy, fname, models, solve_each=1, solve_mult=2,
                 alpha_decay=True):
    """Play repeated games with an online MRNR strategy and report ``player``'s mean payoff.

    A belief over ``models`` starts uniform. After every episode it moves a ``1 / (t + 1)`` step
    toward the Bayesian posterior given the opponent's observed actions (per-action model
    likelihoods are floored at 0.01). The response strategy comes from solving the game in
    ``fname`` with its root chance set to ``[0, *belief]``. The solve happens at ``t == 0`` and
    whenever ``t`` reaches ``solve_each``; ``solve_each`` is then multiplied by ``solve_mult``.

    :param repeated_plays: number of sampled episodes.
    :param game: original game (object exposing ``root``) used for sampling.
    :param player: index of the player we control.
    :param opponent_strategy: info set -> action probabilities of the real opponent.
    :param fname: path of the generated MRNR game for ``models``.
    :param models: list of opponent strategies (info set -> action probabilities).
    :param alpha_decay: kept for backward compatibility; it has no effect on the active update.
    :return: ``(mean, half_width)`` of the 95% t-confidence interval of ``player``'s payoff.
    """
    glob_verb = False  # set True to print the belief every time the game is re-solved
    epsilon = 0.01  # floor on per-action model likelihoods, so no model is ruled out outright
    sign = 1 if player == 0 else -1

    my_alg_nash = SequenceNash(fname)

    probs = [1 / len(models)] * len(models)
    p = 0  # weight of the first root branch, always 0 (presumably the unrestricted branch — confirm)

    full_return = []
    for t in range(repeated_plays):
        if t == solve_each or t == 0:
            if t > 0:
                solve_each *= solve_mult
            if glob_verb:
                print(t)
                print(probs)
            my_alg_nash.game.root.chance = np.concatenate(([p], probs))
            my_alg_nash.reset()
            # Solve for the controlled player; this used to be hard-coded to player 0.
            my_alg_nash.solve(player)
            mrnr_strategy = my_alg_nash.strategy_in_cfr_format()[player]
        node = game.root
        # action_model_probs[i]: likelihood under model i of each opponent action this episode.
        action_model_probs = [[] for _ in models]
        while not node.player == 3:
            if node.player == 2:
                node = np.random.choice(node.children, p=node.chance)
            elif node.player == player:
                node = np.random.choice(node.children, p=mrnr_strategy[node.i_set])
            else:
                action = np.random.choice(list(range(len(node.children))), p=opponent_strategy[node.i_set])
                for model, likelihoods in zip(models, action_model_probs):
                    likelihoods.append(model[node.i_set][action])
                node = node.children[action]
            if node.player == 3:
                full_return.append(node.value * sign)
                # EM-style step toward the posterior (np.product was removed in NumPy 2.0).
                model_probs = np.prod(np.clip(action_model_probs, epsilon, 1), 1)
                joint = probs * model_probs
                posterior = joint / np.sum(joint)
                probs = probs + (1 / (t + 1)) * (posterior - probs)
                probs = probs / np.sum(probs)
                probs[probs < 10 ** -10] = 0
    m, se = np.mean(full_return), scipy.stats.sem(full_return)
    h = se * scipy.stats.t.ppf((1 + 0.95) / 2., repeated_plays - 1)
    print("MRNR:", m, "+-", h)
    return m, h


def test_learning_mrnr(repeated_plays, game, player, opponent_strategy, models, solve_each=1, solve_mult=2):
    """Watch the MRNR belief over ``models`` evolve while ``player`` plays uniformly at random.

    The belief update matches ``compute_mrnr``: a ``1 / (t + 1)`` step toward the posterior,
    with per-action likelihoods floored at 0.01. ``t`` and the belief are printed at ``t == 0``
    and whenever ``t == solve_each``. Because ``solve_each`` is then incremented by one, this
    happens at every step. The mean payoff and 95% CI half-width are printed at the end.

    :param solve_mult: unused; kept for signature compatibility.
    """
    epsilon = 0.01
    sign = 1 if player == 0 else -1

    probs = [1 / len(models)] * len(models)

    full_return = []
    for t in range(repeated_plays):
        if t == solve_each or t == 0:
            if t > 0:
                solve_each += 1
            print(t)
            print(probs)
        node = game.root
        action_model_probs = [[] for _ in models]
        while not node.player == 3:
            if node.player == 2:
                node = np.random.choice(node.children, p=node.chance)
            elif node.player == player:
                node = np.random.choice(node.children, p=[1 / len(node.children)] * len(node.children))
            else:
                action = np.random.choice(list(range(len(node.children))), p=opponent_strategy[node.i_set])
                for model, likelihoods in zip(models, action_model_probs):
                    likelihoods.append(model[node.i_set][action])
                node = node.children[action]
            if node.player == 3:
                full_return.append(node.value * sign)
                # np.product was removed in NumPy 2.0; np.prod is the equivalent.
                model_probs = np.prod(np.clip(action_model_probs, epsilon, 1), 1)
                joint = probs * model_probs
                posterior = joint / np.sum(joint)
                probs = probs + (1 / (t + 1)) * (posterior - probs)
                probs = probs / np.sum(probs)
                probs[probs < 10 ** -10] = 0
    m, se = np.mean(full_return), scipy.stats.sem(full_return)
    h = se * scipy.stats.t.ppf((1 + 0.95) / 2., repeated_plays - 1)
    print(m, "+-", h)


def compute_exp4(repeated_plays, game, player, opponent_strategy, rnr_strategies):
    """Play repeated games with Exp4 over expert strategies and report ``player``'s mean payoff.

    At each of ``player``'s decisions, the experts' advice is mixed using the exponential weights
    ``q``, reweighted by each expert's reach probability within the episode. A fraction
    ``gamma`` of uniform exploration is then added. After the episode, importance-weighted
    rewards are added to every expert's cumulative score (``hedge``).

    :param repeated_plays: number of sampled episodes.
    :param game: object exposing ``root``.
    :param player: index of the player we control.
    :param opponent_strategy: info set -> action probabilities of the opponent.
    :param rnr_strategies: expert strategies for ``player`` (info set -> action probabilities).
    :return: ``(mean, half_width)`` of the 95% t-confidence interval of ``player``'s payoff.
    """
    hedge = np.zeros(len(rnr_strategies))
    nau = 0.025  # learning rate applied to the cumulative expert rewards
    gamma = 0.01  # uniform-exploration mass mixed into every action distribution
    sign = 1 if player == 0 else -1
    full_return = []
    for _ in range(repeated_plays):
        values = nau * hedge
        exps = np.exp(values - np.max(values))  # shifted for numerical stability
        q = exps / np.sum(exps)
        node = game.root
        probabilities = []  # per own decision: expert-specific weights of the chosen action
        probs = []  # per own decision: probability with which the chosen action was played
        reaches = [1] * len(hedge)
        while not node.player == 3:
            if node.player == 2:
                node = np.random.choice(node.children, p=node.chance)
            elif node.player == player:
                advice_vectors = np.asarray([expert_strategy[node.i_set] for expert_strategy in rnr_strategies])
                reach_probs = reaches * q
                if np.sum(reach_probs) > 0:
                    reach_probs = reach_probs / np.sum(reach_probs)
                else:
                    reach_probs = np.asarray([1 / len(q)] * len(q))
                p = np.sum([advice_vectors[i] * reach_probs[i] for i in range(len(q))], 0)
                p = (1 - gamma) * p + gamma / len(p)
                action = np.random.choice(list(range(len(p))), p=p)
                probs.append(p[action])
                weights = []
                for i in range(len(q)):
                    reaches[i] *= advice_vectors[i][action]
                    # NOTE(review): reaches[i] already includes advice_vectors[i][action] here,
                    # so the chosen action's probability enters twice — confirm this is intended.
                    weights.append(advice_vectors[i][action] * reaches[i])
                probabilities.append(weights)
                node = node.children[action]
            else:
                node = np.random.choice(node.children, p=opponent_strategy[node.i_set])
            if node.player == 3:
                value = node.value * sign
                # If `player` made no decision this episode there is nothing to credit.
                # Dividing by len(probs) == 0 would turn `hedge` into NaN for good.
                if probs:
                    reward_inter = np.zeros(len(q))
                    for probabs, prob in zip(probabilities, probs):
                        reward_inter += np.asarray(probabs) * value / prob
                    hedge += reward_inter / len(probs)
                full_return.append(value)
    m, se = np.mean(full_return), scipy.stats.sem(full_return)
    h = se * scipy.stats.t.ppf((1 + 0.95) / 2., repeated_plays - 1)
    print("EXP4", m, "+-", h)
    return m, h


def evalueate_strategy(strategy, player, fname):
    """Print the root counterfactual value of a full strategy profile, from ``player``'s side.

    :param strategy: full profile (one CFR-format strategy per player) to evaluate.
    :param player: 0 keeps the value's sign, any other value negates it.
    :param fname: path of the game file.
    """
    evaluator = CFR(fname)
    evaluator.initialize()
    evaluator.strategy = strategy
    evaluator.compute_counterfactual_values()
    sign = 1 if player == 0 else -1
    root_value = evaluator.counterfactual_values[evaluator.game.root]
    print(root_value * sign)


def side_tests():
    """Solve rmp_ot_0.25 for player 0 and hand that strategy to ``test_br_to_strategy`` on rmp."""
    nash = SequenceNash("data/rmp/rmp_ot_0.25.gbt")
    print(nash.solve(0))
    # test_br_to_strategy installs a single player's strategy (one dict) in slot
    # 1 - player, so pass player 0's part of the profile, not the whole per-player list.
    test_br_to_strategy(nash.strategy_in_cfr_format()[0], "data/rmp/rmp.gbt", 1)


def test_mrnr_lp():
    """Run OptimalMRNR's binary search on rmp against two pure opponent models.

    The models always play the first or always play the second action at info sets 0-4.
    The strategy handed to OptimalMRNR always plays the first action.
    """
    info_sets = range(5)
    models = [{i_set: [1, 0] for i_set in info_sets},
              {i_set: [0, 1] for i_set in info_sets}]
    strategy = {i_set: [1, 0] for i_set in info_sets}

    mrnr_lp = OptimalMRNR("data/rmp/rmp.gbt", strategy, models)
    mrnr_lp.binary_search(0)


def test_mrnr_optimization():
    """Fit mixture weights over two action models to synthetic data with a Gurobi LP.

    The data contains each action as often as the true mixture ``p`` over ``model_ps``
    predicts for 1000 samples. The LP maximizes the summed per-sample mixture probability
    subject to the weights summing to one, then prints the fitted weights.

    NOTE(review): the objective is linear in the weights, so its optimum lies at a vertex and
    will not in general recover ``p``; a log-likelihood objective would be needed for that.
    """
    model_ps = [[0.9, 0.1], [0.1, 0.9]]
    p = [0.3, 0.7]
    num_values = 1000
    values = [0] * int(num_values * (p[0] * model_ps[0][0] + p[1] * model_ps[1][0])) + [1] * int(num_values * (
            p[0] * model_ps[0][1] + p[1] * model_ps[1][1]))

    m = g.Model()

    # variables
    p_var = m.addVars(len(p), lb=0, ub=1, vtype=g.GRB.CONTINUOUS)
    obj = m.addVar(lb=-g.GRB.INFINITY, vtype=g.GRB.CONTINUOUS, name="objective")
    m.update()
    m.setObjective(obj, g.GRB.MAXIMIZE)

    # Iterating a tupledict yields its keys, so quicksum(p_var) would add the integer keys and
    # collapse this to the constant constraint 1 == 1. tupledict.sum() adds the variables.
    m.addConstr(p_var.sum() == 1)
    m.addConstr(obj == g.quicksum(
        [g.quicksum([p_var[key] * model_ps[i][val] for i, key in enumerate(p_var)]) for val in values]))

    m.write("lp.lp")
    m.optimize()

    print(p_var[0].x, p_var[1].x)


def test_no_max_nash():
    """Solve rmp with the sequence-form LP for player 1, then solve it with NoMaxNash."""
    game_file = "data/rmp/rmp.gbt"
    SequenceNash(game_file).solve(1)

    from NoMaxNash import NoMaxNash
    NoMaxNash(game_file).solve()


def test_qcfr():
    """Run CFR+, CFRQRCFV and QCFR on the two-stage game and print each final/average strategy."""
    fname = "data/2stagegame_rev.efg"
    iterations = 10000

    solvers = [
        CFR(fname, cfr_plus=True),
        CFRQRCFV(fname, cfr_player=1, rationality=1),
        QCFR(fname, cfr_player=1, rationality=1),
    ]
    cfr, cfrqr, qcfr = solvers

    cfr.solve(iterations=iterations)
    cfrqr.solve(iterations=iterations)
    qcfr.solve(iterations=iterations, verbose=2)

    for solver in solvers:
        print(solver.strategy, solver.average_strategy)


def brps_test():
    """Run MRNR against several fixed BRPS opponents and plot the results on the strategy simplex.

    We play as player 0. For each opponent, MRNR and a best response are evaluated by
    sampling. The plot draws the projected simplex (Rock/Paper axes), the segment between the
    two models, and one arrow per opponent.
    """
    fname_in = "data/BRPS.gbt"
    fname_out = "data/BRPS_TEMP.gbt"
    repeated_plays = 100000
    player = 0

    game = ExtensiveGame()
    game.load(fname_in)

    probs = []

    opponent_strategies = [
        {0: [0.5, 0, 0.5]},
        {0: [0, 0, 1]},
        {0: [1, 0, 0]},
        {0: [0, 1, 0]},
        {0: [0.5, 0.5, 0]},
        {0: [0.1, 0.5, 0.4]},
        {0: [0.2, 0.3, 0.5]},
        {0: [0.2, 0.8, 0]},
        {0: [0, 0.2, 0.8]},
    ]

    models = [{0: [1, 0, 0]}, {0: [0, 0.5, 0.5]}]

    # The MRNR game depends only on the models, not on the opponent, so generate it once.
    test_generate_rnr_game(fname_in, fname_out, 1 - player, models, [1 / (len(models) + 1)] * (len(models) + 1))

    for opponent_strategy in opponent_strategies:
        # NOTE(review): compute_mrnr returns (mean, half_width) of the payoff, not the final
        # belief over models, yet the arrows below treat each entry as mixture weights — confirm.
        probs.append(compute_mrnr(repeated_plays, game, player, opponent_strategy, fname_out, models, solve_each=1,
                                  solve_mult=2, alpha_decay=True))

        cfr = CFR(fname_in)
        cfr.initialize()
        full_opponent_strategy = [{}, {}]
        full_opponent_strategy[1 - player] = opponent_strategy
        best_response_to_opp_strategy = cfr.best_response(player, full_opponent_strategy)[1][player]

        # online_play requires a label for its printout.
        online_play(repeated_plays, game, player, opponent_strategy, best_response_to_opp_strategy, "Best response: ")

    print(probs)

    # Project 3-action strategies to 2D: first action on the x axis, second on the y axis.
    x = np.asarray([1, 0])
    y = np.asarray([0, 1])
    z = np.asarray([0, 0])
    model_coords = []
    for model in models:
        model_coords.append(model[0][0] * x + model[0][1] * y)

    model_lines = []
    model_colors = []
    arrows = []
    for i in range(len(model_coords)):
        model_lines.append([model_coords[i], model_coords[(i + 1) % len(model_coords)]])
        model_colors.append((1, 0, 0, 1))
    for opponent_strategy, prob in zip(opponent_strategies, probs):
        opponent_strategy_coord = opponent_strategy[0][0] * x + opponent_strategy[0][1] * y
        mapped_strategy = np.matmul(prob, [model[0] for model in models])
        mapped_strategy_coord = mapped_strategy[0] * x + mapped_strategy[1] * y
        arrow_length = mapped_strategy_coord - opponent_strategy_coord
        arrows.append((opponent_strategy_coord, arrow_length))

    lines = [[x, y], [x, z], [y, z]] + model_lines
    c = np.array([(0, 0, 0, 1), (0, 0, 0, 1), (0, 0, 0, 1)] + model_colors)
    lc = mc.LineCollection(lines, colors=c, linewidths=2)
    fig, ax = pl.subplots(figsize=(12, 12))
    ax.add_collection(lc)
    ax.margins(0.1)
    plt.xlabel("Rock")
    plt.ylabel("Paper")
    plt.ylim(-0.1, 1.1)
    plt.xlim(-0.1, 1.1)
    for arrow in arrows:
        plt.arrow(arrow[0][0], arrow[0][1], arrow[1][0], arrow[1][1], head_width=0.01, length_includes_head=True,
                  color=(0, 1, 0, 0.75), zorder=100)
    plt.show()


def test_losses():
    """Plot how several losses behave when mixing two fixed rock-paper-scissors models.

    For each hand-written opponent strategy this prints the strategy and the
    mixture weights estimated by ``em_with_full_info``. Then, for every loss, it
    plots both directions of the loss as a function of the weight ``p`` on the
    first model (``1 - p`` on the second). For each loss it prints the ``p``
    minimizing the first direction. One figure is shown per opponent strategy.
    """
    opponent_strategies = [
        {0: [0.5, 0, 0.5]},
        {0: [0, 0, 1]},
        {0: [1, 0, 0]},
        {0: [0, 1, 0]},
        {0: [0.5, 0.5, 0]},
        {0: [0.1, 0.5, 0.4]},
        {0: [0.2, 0.3, 0.5]},
        {0: [0.1, 0.9, 0]},
        {0: [0, 0.2, 0.8]},
    ]

    models = [{0: [0.5, 0.5, 0]}, {0: [0, 0.5, 0.5]}]
    losses = [(cross_entropy, "Cross entropy"), (kl_divergence, "KL divergence"), (mse, "MSE"), (l1, "L1 loss")]
    p = np.linspace(0, 1, 101)

    for strategy in opponent_strategies:
        print(strategy)
        print(em_with_full_info(strategy, models))
        for loss_fn, loss_name in losses:
            loss_one = []
            loss_two = []
            for p_part in p:
                p_full = [p_part, 1 - p_part]
                loss_one_part, loss_two_part = loss_computation(strategy, models, p_full, loss_fn)
                loss_one.append(loss_one_part)
                loss_two.append(loss_two_part)
            # Distinct labels so the two directions of the same loss can be told apart in the legend.
            plt.plot(p, loss_one, label=loss_name)
            plt.plot(p, loss_two, label=loss_name + " (reversed)")
            # Index -> weight on the grid above (step 1 / (len(p) - 1)).
            print(loss_name + ":", np.argmin(loss_one) / (len(p) - 1), end=", ")
        print()
        plt.legend()
        plt.show()


def loss_computation(strategy, models, probabilities, loss):
    """Evaluate ``loss`` between an observed strategy and a weighted mixture of models.

    :param strategy: dict whose key ``0`` holds the observed action distribution.
    :param models: list of dicts whose key ``0`` holds each model's action distribution.
    :param probabilities: mixture weight for each model (same length as ``models``).
    :param loss: callable ``loss(observed, mixture)`` such as ``cross_entropy``.
    :return: whatever ``loss`` returns (the loss functions in this file return a pair).
    """
    observed = strategy[0]
    # ``np.float`` was removed in NumPy 1.24; the builtin ``float`` is what it aliased.
    model_matrix = np.asarray([model[0] for model in models], dtype=float)
    mixture = np.asarray(np.dot(probabilities, model_matrix), dtype=float)
    return loss(observed, mixture)


def cross_entropy(p_one, p_two):
    """Return the cross entropies ``(H(p_one, p_two), H(p_two, p_one))``.

    Both inputs are clipped to ``[1e-6, 1 - 1e-6]`` before taking logs, so zero
    probabilities give a large finite value instead of ``inf``/``nan``.
    """
    floor = 10 ** -6
    first, second = (np.clip(dist, floor, 1 - floor) for dist in (p_one, p_two))

    def directed(target, predicted):
        return -np.sum(target * np.log(predicted))

    return directed(first, second), directed(second, first)


def kl_divergence(p_one, p_two):
    """Return ``(KL(p_one || p_two), KL(p_two || p_one))``.

    Inputs are clipped to ``[1e-6, 1 - 1e-6]`` first, so zero entries contribute
    a large but finite amount rather than ``inf``/``nan``.
    """
    floor = 10 ** -6
    first = np.clip(p_one, floor, 1 - floor)
    second = np.clip(p_two, floor, 1 - floor)

    def directed(reference, approx):
        # sum(reference * log(reference / approx)), written as -sum(r * log(a / r))
        return -np.sum(reference * np.log(approx / reference))

    return directed(first, second), directed(second, first)


def mse(p_one, p_two):
    """Return the summed squared difference of two distributions, once per direction.

    Despite the name this is a sum (not a mean) of squared errors. The loss is
    symmetric, so both values of the returned pair are equal. The pair matches the
    ``(loss_one, loss_two)`` interface of the other loss functions. Any array-like
    input (including plain lists) is accepted.
    """
    difference = np.asarray(p_one, dtype=float) - np.asarray(p_two, dtype=float)
    loss = np.sum(difference ** 2)
    return loss, loss


def l1(p_one, p_two):
    """Return the L1 distance between two distributions, once per direction.

    The distance is symmetric, so both values of the returned pair are equal. The
    pair matches the ``(loss_one, loss_two)`` interface of the other loss functions.
    Any array-like input (including plain lists) is accepted.
    """
    distance = np.sum(np.abs(np.asarray(p_one, dtype=float) - np.asarray(p_two, dtype=float)))
    return distance, distance


def em_with_full_info(strategy, models, threshold=10 ** -6, epsilon=10 ** -4, max_iterations=None):
    """Estimate mixture weights over ``models`` that explain ``strategy`` using EM.

    Each iteration:

    * computes, for every action, the posterior probability that each model
      produced it under the current weights;
    * re-estimates the weights as the expectation of those posteriors under the
      observed action distribution.

    :param strategy: dict whose key ``0`` holds the observed action distribution.
    :param models: non-empty list of dicts whose key ``0`` holds each model's action distribution.
    :param threshold: stop once no weight changes by more than this between iterations.
    :param epsilon: floor for the weighted action probabilities (avoids division by zero
        for actions no model plays).
    :param max_iterations: optional cap on the number of EM iterations; ``None`` iterates
        until convergence.
    :return: numpy array with one weight per model.
    :raises ValueError: if ``models`` is empty.
    """
    if not models:
        raise ValueError("em_with_full_info needs at least one model")
    n_models = len(models)
    # ``np.float`` was removed in NumPy 1.24; the builtin ``float`` is what it aliased.
    p = np.full(n_models, 1 / n_models, dtype=float)
    p_old = np.zeros(n_models, dtype=float)
    np_strategy = np.asarray(strategy[0], dtype=float)
    np_models = np.array([model[0] for model in models], dtype=float)
    iteration = 0
    while np.max(np.abs(p - p_old)) > threshold:
        if max_iterations is not None and iteration >= max_iterations:
            break
        p_old = p
        # Rows: models, columns: actions; column-normalizing gives P(model | action).
        weighted = np.clip(p_old[:, np.newaxis] * np_models, epsilon, 1)
        posterior = weighted / np.sum(weighted, axis=0)
        p = np.dot(posterior, np_strategy)
        iteration += 1
    return p


def test_refinements():
    """Solve the bigger equilibrium-refinement game with CFR and print both players' average strategies."""
    solver = CFR("data/eq_refinement_game_bigger.gbt")
    solver.solve()
    for player in (0, 1):
        solver.print_strategy(solver.average_strategy, player)


def create_cool_graphs_from_values():
    """Draw the NFG (left) and EFG (right) loss panels side by side with one shared legend.

    Cases 0 and 1 (gain plots) are kept in the dispatch table, but only the loss
    cases 2 and 3 are drawn. The legend entries are taken from the right-hand axes.
    """
    plt.rcParams.update({'font.size': 20, 'font.family': 'Times New Roman'})
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 3))
    plt.gcf().subplots_adjust(bottom=0.23, left=0.1, right=0.99, top=0.84, wspace=0.3)

    panel_painters = {
        0: lambda: plot_values(best_ne=True, loss=False, comb_value=1, ax=ax2),  # EFG gains
        1: lambda: graph_from_data(False, False, ax=ax1),  # NFG gains
        2: lambda: plot_values(best_ne=True, loss=True, comb_value=1, ax=ax2),  # EFG loss
        3: lambda: graph_from_data(False, True, ax=ax1),  # NFG loss
    }
    for case in (2, 3):
        panel_painters[case]()

    handles, labels = ax2.get_legend_handles_labels()
    plt.rcParams.update({'font.size': 20, 'font.family': 'Times New Roman'})
    plt.legend(handles, labels, bbox_to_anchor=(-1.3, 1.05, 2.3, .102), loc='lower left',
               ncol=6, borderaxespad=0., handlelength=1.3, handletextpad=0.2, mode="expand", borderpad=0.2)
    plt.show()


def generate_random_strategies():
    """Save 100 player-1 strategies for Leduc hold'em.

    Each strategy is taken after calling ``initialize_random_strategy()`` on a CFR
    instance. Files are written as ``data/step_br_leduc/strategies/000`` ... ``099``.
    """
    out_dir = "data/step_br_leduc/strategies/"
    solver = CFR("data/step_br_leduc/leduc_holdem.efg")
    solver.initialize()
    for index in range(100):
        solver.initialize_random_strategy()
        save_to_file(solver.strategy[1], out_dir + "{:03d}".format(index))


def average_from_implicit_clusters():
    """Combine the strategies of each implicit cluster into one model per cluster.

    Loads the DisjointSet saved at ``data/nfgs/random_strategies/4eqs/im_clusters``
    and the strategies ``000``-``009``. Each cluster's members are combined with
    equal weights via ``combine_strategies``, and the resulting list of models is
    saved to ``.../4eqs/im_models``.
    """
    strategy_dir = "data/nfgs/random_strategies/4eqs/"
    cluster_sets = load_from_file("data/nfgs/random_strategies/4eqs/im_clusters")
    game_file = "data/nfgs/4eqs.efg"
    strategies = [load_from_file(strategy_dir + "{:03d}".format(i)) for i in range(10)]

    models = []
    for members in cluster_sets.itersets():
        member_strategies = [strategies[member] for member in members]
        count = len(member_strategies)
        models.append(combine_strategies(member_strategies, [1 / count] * count, game_file, 1))
    save_to_file(models, "data/nfgs/random_strategies/4eqs/im_models")


def cluster_implicit():
    """Greedily merge 10 saved strategies into 3 clusters using best-response values.

    Every cluster has a score ``-cfr.best_response(0, [{}, s])[0]``, where ``s``
    is its single member or the equal-weight ``combine_strategies`` mix of its
    members. Scores are cached in ``brs`` keyed by the sorted member tuple. Each
    round merges the pair of clusters minimizing ``size-weighted mean of the two
    scores - score of the merged cluster``. The final DisjointSet is saved to
    ``data/nfgs/random_strategies/4eqs/im_clusters``.
    """
    # folder_in = "data/im_and_em/random_model_strategies/"
    folder_in = "data/nfgs/random_strategies/4eqs/"
    clusters = DisjointSet()
    strategies = []
    brs = {}
    fname = "data/nfgs/4eqs.efg"
    cfr = CFR(fname)
    cfr.initialize()
    # initialize clusters to one singleton per strategy
    for i in range(10):
        clusters.find(i)
        strategy = load_from_file(folder_in + "{:03d}".format(i))
        strategies.append(strategy)
        brs[(i,)] = -cfr.best_response(0, [{}, strategy])[0]
    k = 3
    while True:
        # Snapshot the clusters once per round. They only change at the union below,
        # so rebuilding list(clusters.itersets()) for every pair is wasted work.
        cluster_members = [tuple(sorted(cluster)) for cluster in clusters.itersets()]
        n_clusters = len(cluster_members)
        if n_clusters <= k:
            break
        print("Remaining clusters", n_clusters)
        join = [0, 0]
        best_value = np.inf
        for i in range(n_clusters):
            print("Cluster iteration", i)
            indexes_one = cluster_members[i]
            for j in range(i + 1, n_clusters):
                indexes_two = cluster_members[j]
                indexes = tuple(sorted(indexes_one + indexes_two))
                if indexes not in brs:
                    current_strategies = [strategies[index] for index in indexes]
                    weights = [1 / len(current_strategies)] * len(current_strategies)
                    combined_strategy = combine_strategies(current_strategies, weights, fname, 1)
                    brs[indexes] = -cfr.best_response(0, [{}, combined_strategy])[0]
                average_before = (brs[indexes_one] * len(indexes_one) + brs[indexes_two] * len(indexes_two)) / len(indexes)
                merge_cost = average_before - brs[indexes]
                if merge_cost < best_value:
                    best_value = merge_cost
                    join = [clusters.find(indexes_one[0]), clusters.find(indexes_two[0])]
        clusters.union(join[0], join[1])
    save_to_file(clusters, "data/nfgs/random_strategies/4eqs/im_clusters")


def cluster_explicit_models():
    """Spectrally cluster 10 saved strategies into 3 groups using a KL-based affinity.

    The affinity between strategies ``i`` and ``j`` is
    ``spectral_clustering_kl_affinity``; the diagonal of the affinity matrix is
    left at 0. The label array from ``SpectralClustering`` is printed and saved.
    """
    strategy_dir = "data/nfgs/random_strategies/4eqs/"
    # NOTE(review): strategies are read from the 4eqs folder but the game tree is
    # 10eqs.efg — confirm this pairing is intended.
    game = ExtensiveGame()
    game.load("data/nfgs/10eqs.efg")
    n_strategies = 10
    strategies = [load_from_file(strategy_dir + "{:03d}".format(i)) for i in range(n_strategies)]

    affinity_matrix = np.zeros((n_strategies, n_strategies))
    for row in range(n_strategies):
        print(row)
        for col in range(row + 1, n_strategies):
            value = spectral_clustering_kl_affinity(strategies[row], strategies[col], game)
            affinity_matrix[row][col] = value
            affinity_matrix[col][row] = value

    labels = SpectralClustering(3, affinity='precomputed').fit_predict(affinity_matrix)
    print(labels)
    save_to_file(labels, "data/nfgs/random_strategies/4eqs/em_clusters_kl")


def spectral_clustering_kl_affinity(strategy_one, strategy_two, game):
    """Return ``exp(-d ** 2)``, where ``d`` is ``efg_kl_divergence`` of the two strategies."""
    divergence = efg_kl_divergence(strategy_one, strategy_two, game)
    return np.exp(-(divergence ** 2))


def efg_kl_divergence(strategy_one, strategy_two, game):
    """Return the symmetric KL-style divergence of two strategies over ``game``.

    The (non-positive) leaf sum from ``efg_kl_divergence_step`` is taken from the
    root with both reaches equal to 1, negated, and divided by the number of
    entries in ``strategy_one`` (presumably one per information set — confirm).
    """
    leaf_total = efg_kl_divergence_step(strategy_one, strategy_two, game.root, 1, 1)
    return -leaf_total / len(strategy_one)


def efg_kl_divergence_step(strategy_one, strategy_two, node, reach_one, reach_two):
    """Recursively sum negated symmetric KL terms over the leaves below ``node``.

    ``reach_one``/``reach_two`` are the products of player-1 action probabilities
    along the path so far under each strategy.

    * At a node with ``player == 3`` (treated as a leaf), both reaches are clipped
      to ``[1e-6, 1 - 1e-6]`` and ``p1*log(p2/p1) + p2*log(p1/p2)`` is returned
      (always <= 0).
    * At player-1 nodes, child ``i`` scales each reach by ``strategy[node.i_set][i]``.
    * At any other node, the reaches are passed through unchanged, so chance
      probabilities are not applied.
    """
    if node.player == 3:
        floor = 10 ** -6
        leaf_one = np.clip(reach_one, floor, 1 - floor)
        leaf_two = np.clip(reach_two, floor, 1 - floor)
        return leaf_one * np.log(leaf_two / leaf_one) + leaf_two * np.log(leaf_one / leaf_two)
    if node.player == 1:
        return sum(
            efg_kl_divergence_step(strategy_one, strategy_two, child,
                                   reach_one * strategy_one[node.i_set][action],
                                   reach_two * strategy_two[node.i_set][action])
            for action, child in enumerate(node.children)
        )
    return sum(
        efg_kl_divergence_step(strategy_one, strategy_two, child, reach_one, reach_two)
        for child in node.children
    )


def create_average_from_explicit_clusters():
    """Combine the strategies of each explicit (spectral) cluster into one model per cluster.

    Loads 10 strategies and their cluster labels (one label in ``0..k-1`` per
    strategy). The members of every non-empty cluster are combined with equal
    weights via ``combine_strategies``, and the list of combined models is saved.
    """
    # folder_in = "data/im_and_em/random_model_strategies/"
    # clusters = load_from_file("data/im_and_em/random_model_strategies/em_clusters")
    folder_in = "data/nfgs/random_strategies/10eqs/"
    # NOTE(review): labels are read from the 6eqs folder while strategies and the
    # output use 10eqs — confirm these belong to the same experiment.
    clusters = load_from_file("data/nfgs/random_strategies/6eqs/em_clusters_kl")
    k = 3  # matches SpectralClustering(3) used when the labels were produced
    strategies = [[] for _ in range(k)]
    cluster_n = 10
    for i in range(cluster_n):
        strategy = load_from_file(folder_in + "{:03d}".format(i))
        strategies[clusters[i]].append(strategy)
    # fname = "data/leduc_holdem.efg"
    fname = "data/nfgs/10eqs.efg"
    # No CFR solver is needed here: combine_strategies only takes the game file name.
    resulting_strategies = [
        combine_strategies(members, [1 / len(members)] * len(members), fname, 1)
        for members in strategies
        if members
    ]
    save_to_file(resulting_strategies, "data/nfgs/random_strategies/10eqs/em_models")


def evaluate_all_opponent_modeling_leduc():
    """Run ``mrnr_leduc`` on the 10 saved Leduc test strategies and save the results.

    The growing results list is re-saved after every strategy, presumably so that
    partial progress survives an interrupted run.
    """
    strategy_dir = "data/im_and_em/random_test_strategies/"
    mrnr_method = "hand_crafted"
    mode = "mrnr"
    p = 0.4
    out_file = "data/im_and_em/results/results_small_" + mode + "_" + mrnr_method + "_" + str(p)
    results = []
    for index in range(10):
        opponent = load_from_file(strategy_dir + "{:03d}".format(index))
        results.append(mrnr_leduc(opponent, mode=mode, mrnr_method=mrnr_method, rnrp=p))
        save_to_file(results, out_file)


def evaluate_all_opponent_modeling_matrix(fname_part, strategy_part, strat_name, test, p=0):
    """Evaluate ``mrnr_matrix_big`` against 100 saved player-1 strategies on a matrix game.

    :param fname_part: game name; the game file is ``data/nfgs/<fname_part>.efg``.
    :param strategy_part: strategy folder name under ``data/nfgs/random_strategies/``.
    :param strat_name: model-strategy name passed to ``mrnr_matrix_big``.
    :param test: read from the ``_test`` folder (and tag results ``test_``) if true, else ``train_``.
    :param p: value forwarded as ``rnrp`` and embedded in the results file name.

    The accumulated results are re-saved after every strategy.
    """
    strategy_dir = "data/nfgs/random_strategies/" + strategy_part + ("_test/" if test else "/")
    mrnr_method = "optimized"
    mode = "test_both"
    fname = "data/nfgs/" + fname_part + ".efg"
    out_file = ("data/nfgs/results/results_" + ("test_" if test else "train_") + mode + "_" + mrnr_method + "_"
                + str(p) + "_" + "game_" + fname_part + "_test_strategies_" + strategy_part
                + "_model_strategies_" + strat_name)
    results = []
    for index in range(100):
        # Player 0 is the responder (empty dict); player 1 plays the saved strategy.
        profile = [{}, load_from_file(strategy_dir + "{:03d}".format(index))]
        results.append(mrnr_matrix_big(fname, strat_name, temp=fname_part + strategy_part + strat_name,
                                       strategy=profile, player=0, mode=mode, mrnr_method=mrnr_method, rnrp=p))
        save_to_file(results, out_file)


def load_and_average_results_leduc():
    """Print average utilities from the saved Leduc opponent-modelling result files."""
    base = "data/im_and_em/results/"

    for caption, name in (("MRNR hand crafted average", "hand_crafted_mrnr_results"),
                          ("MRNR clustered average", "clustered_mrnr_results")):
        results = load_from_file(base + name)
        print(caption, np.average([result["mrnr"][0] for result in results]))

    for title, name in (("Clustered", "results_small_test_both_clustered"),
                        ("Hand crafted", "results_small_test_both_hand_crafted")):
        results = load_from_file(base + name)
        print(title)
        for key in ("mrnr", "exp4", "best"):
            print(key.upper(), np.average([result[key][0] for result in results]))

    print([result["mrnr"][0] for result in load_from_file(base + "results_small_test_both_clustered")])


def load_and_average_results_matrix():
    """Print and bar-plot per-method averages for every matrix-game results file.

    For each file matching ``data/nfgs/results/results*`` this prints the number
    of runs, then the mean ± 95% confidence half-width of the first entry of the
    ``mrnr``, ``exp4``, ``best`` and ``nash`` results. It then shows a bar chart
    titled with the file name.
    """
    # The module binds ``glob`` to the function (``from glob import glob``), so
    # ``glob.glob`` would fail; import the module under an unambiguous name.
    import glob as glob_module

    def mean_and_ci95(values):
        # Student-t interval for a mean uses n - 1 degrees of freedom.
        return np.average(values), scipy.stats.sem(values) * scipy.stats.t.ppf(0.975, len(values) - 1)

    methods = (("mrnr", "MRNR"), ("exp4", "EXP4"), ("best", "BEST"), ("nash", "NASH"))
    for fname in glob_module.glob("data/nfgs/results/results*"):
        results = load_from_file(fname)
        print("Optimized")
        print(len(results))
        summaries = []
        for key, label in methods:
            mean, error = mean_and_ci95([result[key][0] for result in results])
            print(label, mean, "±", error)
            summaries.append((mean, error))

        plt.xticks(list(range(len(methods))), [label for _, label in methods])
        for position, (mean, error) in enumerate(summaries):
            plt.bar([position], [mean], yerr=[error])
        # basename strips the directory for both "/" and "\\" separators.
        plt.title(os.path.basename(fname))
        plt.show()


def envelope_explicit_models():
    """Fit a 3-model hull to 10 saved strategies and save it.

    Loads entry ``[0]`` of strategies ``000``-``009``, stacks them into an array,
    passes it to ``general_optimization_hull(..., 3)``, then prints and saves the result.
    """
    source_dir = "data/nfgs/random_strategies/4eqs/"
    strategies = np.asarray([load_from_file(source_dir + "{:03d}".format(i))[0] for i in range(10)])
    hull = general_optimization_hull(strategies, 3)
    print(hull)
    save_to_file(hull, "data/nfgs/random_strategies/4eqs/em_optimized")


def get_sequences(fname):
    """Debug round-trip of a hand-made player-1 sequence vector through CFR's converters.

    Prints:

    * the game's player-1 sequences;
    * the strategy obtained from a vector with mass 0.25 on sequence 0 and 0.75
      on sequence 5;
    * that strategy converted back to sequence form.
    """
    game = ExtensiveGame()
    game.load(fname)
    sequences = game.get_sequences(1)
    print(sequences)

    mass = {0: 0.25, 5: 0.75}
    # NOTE(review): the vector is sized by the largest index used, not by len(sequences) — confirm intended.
    sequence = np.zeros(max(mass) + 1)
    for index, prob in mass.items():
        sequence[index] = prob

    solver = CFR(fname)
    solver.initialize()
    index_of_sequence = {value: key for key, value in sequences.items()}
    strategy = solver.convert_sequence_to_strategy(sequence, sequences, 1)
    print(strategy)
    print(solver.convert_strategy_to_sequence(strategy, index_of_sequence, 1))


def convert_nfg_to_efg():
    """Convert every ``.gbt`` normal-form game in ``data/nfgs`` into an ``.efg`` file next to it."""
    # The module binds ``glob`` to the function (``from glob import glob``), so
    # ``glob.glob`` would fail; import the module under an unambiguous name.
    import glob as glob_module

    # Match only .gbt inputs. A bare "*" also matches sub-directories and existing
    # .efg files, whose output name would equal the input name.
    for fname in glob_module.glob("data/nfgs/*.gbt"):
        nfg = Game()
        nfg.load_from_file(fname)
        efg = ExtensiveGame()
        efg.load_from_nfg(nfg)
        # splitext swaps only the extension, not ".gbt" occurring elsewhere in the path.
        efg.save_to_file(os.path.splitext(fname)[0] + ".efg")


def evaluate_all_opponent_modeling_matrix_different_p(fname_part, strategy_part, strat_name, test=True, splits=11):
    """Run ``evaluate_all_opponent_modeling_matrix`` for ``splits`` evenly spaced ``p`` in ``[0, 1]``.

    ``p`` is rounded only to remove float noise from ``np.linspace`` (e.g.
    ``0.30000000000000004`` -> ``0.3``), because ``str(p)`` becomes part of the
    results file name. Rounding to a single decimal would only be right for
    ``splits=11``. On finer grids it merges distinct values (0.05 -> 0.0), so their
    result files overwrite each other.
    """
    for p in np.linspace(0, 1, splits):
        p = np.round(p, 10)
        evaluate_all_opponent_modeling_matrix(fname_part, strategy_part, strat_name, test=test, p=p)


def prospect_theory_rl():
    """Construct a default ``ProspectTheoryRl`` and run its ``solve()``."""
    ProspectTheoryRl().solve()


def step_best_response_test():
    """Compute and save a full and a step best response for saved Leduc opponents 001-099.

    For every opponent strategy:

    1. ``CFRD`` is solved (1000 iterations) against it to get a trunk response.
    2. Plain ``CFR`` is solved against it to get ``full_br``.
    3. ``CFR`` is solved again with player 0 fixed to the trunk response's player-0
       part to get ``step_br``.

    Both average strategies are saved to ``data/step_br_leduc/results/strategiesNNN``.
    """
    game_file = "data/step_br_leduc/leduc_holdem.efg"
    # NOTE(review): opponent 000 is skipped (range starts at 1) — confirm intended.
    for idx in range(1, 100):
        opponent = load_from_file("data/step_br_leduc/strategies/{:03d}".format(idx))

        # NOTE(review): CFRD gets the path without ".efg" while CFR gets it with — confirm.
        trunk_solver = CFRD("data/step_br_leduc/leduc_holdem")
        trunk_solver.solve(1000, fixed_strategy=[{}, opponent], verbose=2)
        trunk_br = trunk_solver.average_strategy

        full_solver = CFR(game_file)
        full_solver.solve(1000, fixed=[{}, opponent])
        full_br = full_solver.average_strategy

        step_solver = CFR(game_file)
        step_solver.solve(1000, fixed=[trunk_br[0], opponent])
        step_br = step_solver.average_strategy

        save_to_file({"full_br": full_br, "step_br": step_br},
                     "data/step_br_leduc/results/strategies{:03d}".format(idx))


def compare_step_br_results():
    """Compare Leduc utilities of Nash, step BR, CFR full BR and exact BR against saved opponents.

    For every saved result file (with ``full_br`` and ``step_br`` entries), the game
    value is computed with player 0 playing the Nash strategy, the step BR and the
    CFR full BR against the matching opponent strategy. ``cfr.best_response(0, ...)``
    against the step-BR profile is also recorded. Values are negated before averaging
    (presumably to express player 0's utility — confirm the sign convention). A bar
    chart of the means with 95% t confidence intervals is shown.
    """
    # The module binds ``glob`` to the function (``from glob import glob``), so
    # ``glob.glob`` would fail; import the module under an unambiguous name.
    import glob as glob_module
    from matplotlib.ticker import AutoMinorLocator

    def mean_and_ci95(values):
        # Student-t interval for a mean uses n - 1 degrees of freedom.
        return np.average(values), scipy.stats.sem(values) * scipy.stats.t.ppf(0.975, len(values) - 1)

    step_br_acc = []
    full_br_acc = []
    best_rr_acc = []
    nash_acc = []
    nash = load_from_file("data/step_br_leduc/strategies/nash")
    for fname in glob_module.glob("data/step_br_leduc/results/*"):
        # Result files are named "strategiesNNN"; the last 3 characters pick the opponent file.
        number_part = fname[-3:]
        responses = load_from_file(fname)
        full_br = responses["full_br"]
        step_br = responses["step_br"]

        cfr = CFR("data/step_br_leduc/leduc_holdem.efg")
        cfr.initialize()

        strategy = load_from_file("data/step_br_leduc/strategies/" + number_part)

        step_br_strat = [step_br[0], strategy]
        full_br_strat = [full_br[0], strategy]
        nash_strat = [nash[0], strategy]

        cfr.compute_game_value(nash_strat)
        print(fname)
        print("Nash val:", cfr.game_value)
        nash_acc.append(-cfr.game_value)

        cfr.compute_game_value(step_br_strat)
        print("Step br:", cfr.game_value)
        step_br_acc.append(-cfr.game_value)

        cfr.compute_game_value(full_br_strat)
        print("Full br:", cfr.game_value)
        full_br_acc.append(-cfr.game_value)

        normal_br = cfr.best_response(0, step_br_strat)[0]
        print("Normal br", normal_br)
        best_rr_acc.append(-normal_br)

    plt.xticks([0, 1, 2, 3], ["NASH", "SBR", "CFRBR", "BR"])
    for position, values in enumerate([nash_acc, step_br_acc, full_br_acc, best_rr_acc]):
        mean, error = mean_and_ci95(values)
        plt.bar([position], [mean], yerr=[error])
    plt.title("Comparison of utility in leduc hold'em")
    plt.xlabel("Method (Lines show 95% confidence intervals.)")
    plt.ylabel("Expected utility.")
    # plt.gca(): plt.axes() with no arguments adds a new, overlapping Axes.
    plt.gca().yaxis.set_minor_locator(AutoMinorLocator(5))
    plt.grid(which="major", axis="y")
    plt.grid(which="minor", axis="y", alpha=0.5)
    plt.show()


def _load_step_plot_data(fname):
    """Return ``(iterations, step_values, br_values)`` for the step-response bar plots.

    If ``fname`` is None, built-in sample values are returned, with x labels from
    ``fibonacci_array``. Otherwise the file is read as whitespace-separated columns
    ``iterations  step_value  br_value``. Blank lines and lines starting with ``#``
    or ``steps`` are skipped.
    """
    if fname is None:
        br_values = [0.555556, 0.222222, 0.166667, 0.14, 0.115816, 0.0590131, 0.0486752, 0.0212319, 0.0117931, 0.0126874, 0.00789421]
        step_values = [0.24728, 0.212844, 0.131312, 0.0971159, 0.0902819, 0.0497432, 0.0365682, 0.0147012, 0.00736657, 0.00788091, 0.00410631]
        return fibonacci_array(len(br_values)), step_values, br_values

    iterations, step_values, br_values = [], [], []
    with open(fname, "r") as file:
        for line in file:
            # Blank lines would otherwise crash on tokens[0].
            if not line.strip() or line.startswith(("#", "steps")):
                continue
            tokens = line.split()
            iterations.append(int(tokens[0]))
            step_values.append(float(tokens[1]))
            br_values.append(float(tokens[2]))
    return iterations, step_values, br_values


def _plot_step_comparison(fname, step_label, metric, method_abbrev, name_offset, title_suffix):
    """Shared body of the four ``plot_s*_data`` functions: grouped BR vs. step-response bars.

    :param fname: results file for ``_load_step_plot_data``, or ``None`` for sample data.
    :param step_label: legend label of the step-response bars.
    :param metric: y-axis label and first word of the title (``"Gain"``/``"Exploitability"``).
    :param method_abbrev: abbreviation of the step method used in the title.
    :param name_offset: characters after the last ``/`` in ``fname`` where the game
        name starts in the title; the name ends right before ``_step``.
    :param title_suffix: text appended to the title's second line.
    """
    from matplotlib.ticker import AutoMinorLocator

    iterations, step_values, br_values = _load_step_plot_data(fname)
    positions = list(range(len(br_values)))
    half_gap = 0.2
    plt.bar([pos + half_gap for pos in positions], br_values, width=0.4, label="Best response")
    plt.bar([pos - half_gap for pos in positions], step_values, width=0.4, label=step_label)
    plt.xlabel("Cfr iterations of opponent strategy")
    plt.ylabel(metric)
    plt.grid(which="major", axis="y")
    plt.grid(which="minor", axis="y", alpha=0.3)
    # plt.gca(): plt.axes() with no arguments adds a new, overlapping Axes.
    ax = plt.gca()
    ax.yaxis.set_minor_locator(AutoMinorLocator(4))
    ax.set_axisbelow(True)
    plt.xticks(positions, iterations)
    if fname is None:
        game_name = "built-in sample data"
    else:
        game_name = fname[fname.rfind("/") + name_offset:fname.rfind("_step")]
    plt.title(metric + " comparison of BR and " + method_abbrev + " against CFR with low iterations\n"
              "on " + game_name + title_suffix)
    plt.legend()
    plt.show()


def plot_sbr_gain_data(fname, steps):
    """Bar-plot gain of best response vs. step best response per opponent CFR iteration count.

    :param fname: results file (columns: iterations, SBR value, BR value) or ``None`` for sample data.
    :param steps: step size, shown in the title.
    """
    _plot_step_comparison(fname, "Step best response", "Gain", "SBR", 5,
                          " with step size " + str(steps))


def plot_sbr_expl_data(fname, steps):
    """Bar-plot exploitability of best response vs. step best response per opponent CFR iteration count.

    :param fname: results file (columns: iterations, SBR value, BR value) or ``None`` for sample data.
    :param steps: step size, shown in the title.
    """
    _plot_step_comparison(fname, "Step best response", "Exploitability", "SBR", 5,
                          " with step size " + str(steps))


def plot_srnr_gain_data(fname, steps, p):
    """Bar-plot gain of best response vs. step restricted Nash response per opponent CFR iteration count.

    :param fname: results file (columns: iterations, SRNR value, BR value) or ``None`` for sample data.
    :param steps: step size, shown in the title.
    :param p: restricted-Nash-response ``p``, shown in the title.
    """
    _plot_step_comparison(fname, "Step restricted Nash response", "Gain", "SRNR", 6,
                          " with step size " + str(steps) + " and p = " + str(p))


def plot_srnr_expl_data(fname, steps, p):
    """Bar-plot exploitability of best response vs. step restricted Nash response per opponent CFR iteration count.

    :param fname: results file (columns: iterations, SRNR value, BR value) or ``None`` for sample data.
    :param steps: step size, shown in the title.
    :param p: restricted-Nash-response ``p``, shown in the title.
    """
    _plot_step_comparison(fname, "Step restricted Nash response", "Exploitability", "SRNR", 6,
                          " with step size " + str(steps) + " and p = " + str(p))


def plot_sbr_steps():
    """Bar-plot step-best-response gain on Leduc for step sizes 1-5.

    The full best-response gain is drawn as a horizontal reference line.
    """
    from matplotlib.ticker import AutoMinorLocator

    br = 2.1731064240779956
    sbrs = 1.0565203511255392, 1.3281244185957957, 1.4578286463002184, 1.9022730907446637, 2.085606424077993
    pos = [1, 2, 3, 4, 5]
    plt.xlim(0.3, 5.7)
    plt.bar(pos, sbrs, width=0.8, label="Step best responses")
    plt.hlines([br], xmin=0, xmax=len(pos) + 2, label="Best response", color="orange")
    plt.xlabel("Step size")
    plt.ylabel("Gain")
    plt.grid(which="major", axis="y")
    plt.grid(which="minor", axis="y", alpha=0.3)
    # plt.gca(): plt.axes() with no arguments adds a new, overlapping Axes.
    ax = plt.gca()
    ax.yaxis.set_minor_locator(AutoMinorLocator(4))
    ax.set_axisbelow(True)
    plt.title("Step best response with varying step size on Leduc")
    plt.legend(loc=(0.05, 0.7))
    plt.show()


def _step_best_response_stages(game, step_scheme, dynamic_steps):
    """Return the ``(fixed_player_depth, responder_depth)`` pair for each stage of ``step_best_response``."""
    if dynamic_steps or step_scheme is None:
        # A missing step in dynamic mode means "no look-ahead", i.e. the step_scheme=None schedule.
        offset = 0 if step_scheme is None else step_scheme
        return [(depth + offset, depth - 1) for depth in range(game.max_depth() + 1)]
    if isinstance(step_scheme, (int, np.integer)):
        return [((stage + 1) * step_scheme, stage * step_scheme) for stage in range(game.max_depth() + 1)]
    stages = []
    depth = 0
    for step in step_scheme:
        depth += step
        stages.append((depth, depth - step))
    return stages


def step_best_response(game, fixed_strategy, fixed_player, step_scheme, dynamic_steps=True):
    """Build a step best response to ``fixed_strategy`` by solving ``game`` in depth stages.

    At every stage:

    1. ``fixed_player``'s strategy is fixed down to the stage's fixed-player depth.
    2. If a response has been computed already, it is fixed for the responder
       (``1 - fixed_player``) down to the responder depth.
    3. The game is solved for the responder with ``SequenceNash``.
    4. The resulting responder strategy is merged into the result; later stages
       overwrite earlier information sets.

    Stage schedule (``d`` runs over ``0 .. game.max_depth()``):

    * ``dynamic_steps`` true: fixed player to ``d + step_scheme``, responder to ``d - 1``.
      ``step_scheme=None`` is treated as 0 (it used to raise ``TypeError``).
    * ``dynamic_steps`` false, ``step_scheme is None``: fixed player to ``d``,
      responder to ``d - 1``.
    * ``dynamic_steps`` false, integer ``step_scheme`` ``s``: stage ``i`` fixes the
      fixed player to ``(i + 1) * s`` and the responder to ``i * s``.
    * ``dynamic_steps`` false, sequence ``step_scheme``: stage depths are the running
      sums of the sequence; the responder goes to the previous running sum.

    :return: dict mapping responder information sets to action distributions.
    """
    sbr_strategy = {}
    # NOTE(review): the solver is always constructed from one_card_poker.efg and then
    # pointed at ``game`` — confirm this file is meant to be hard-coded.
    nash_solver = SequenceNash("data/one_card_poker.efg")
    for fixed_depth, responder_depth in _step_best_response_stages(game, step_scheme, dynamic_steps):
        game.fix_strategy_until_depth(fixed_player, fixed_strategy, fixed_depth)
        if sbr_strategy:
            game.fix_strategy_until_depth(1 - fixed_player, sbr_strategy, responder_depth)
        nash_solver.game = game
        nash_solver.solve(1 - fixed_player)
        sbr_strategy.update(nash_solver.strategy_in_cfr_format()[1 - fixed_player])
    return sbr_strategy


def step_restricted_nash_response2(game, fixed_strategy, fixed_player, step_scheme, dynamic_steps=True):
    """Unfinished draft of :func:`step_restricted_nash_response`.

    Only the ``dynamic_steps`` path exists. For every depth ``0 .. game.max_depth()``
    it fixes ``fixed_strategy`` for ``fixed_player`` inside
    ``game.root.children[0]`` up to ``depth + step_scheme``. No equilibrium is
    solved and the function always returns ``None``.

    NOTE(review): ``response`` is never populated, so the branch that freezes
    the responder's strategy can never run, and the ``SequenceNash`` instance
    is created but never used. Finish or remove this draft.
    """
    response = {}
    _unused_solver = SequenceNash("data/one_card_poker.efg")
    if not dynamic_steps:
        return None
    for level in range(game.max_depth() + 1):
        game.fix_strategy_until_depth_from_node(fixed_player, fixed_strategy, level + step_scheme, game.root.children[0])
        if response:
            game.fix_strategy_until_depth(1 - fixed_player, response, level)
    # TODO: iterate over game.root.children[0] and solve for the responder.
    return None


def find_infosets_in_depth(game, max_depth, player):
    """Group the nodes of ``player`` found exactly ``max_depth`` edges below the root.

    Returns a dict mapping each node's ``i_set`` to the list of matching
    nodes, in depth-first, left-to-right order. The dict is empty when no
    node of ``player`` lies at that depth.
    """
    grouped = {}
    _find_infosets_in_depth(game.root, max_depth, player, 0, grouped)
    return grouped


def _find_infosets_in_depth(node, max_depth, player, depth, info_dict):
    """Depth-first helper for :func:`find_infosets_in_depth`.

    ``node`` is taken to sit at ``depth``. Nodes of ``player`` located at
    ``max_depth`` in its subtree are appended to ``info_dict[node.i_set]``,
    and ``info_dict`` is mutated in place. Subtrees are not explored below
    ``max_depth``. Returns ``None``.
    """
    pending = [(node, depth)]
    while pending:
        current, level = pending.pop()
        if level == max_depth:
            if current.player == player:
                info_dict.setdefault(current.i_set, []).append(current)
            continue
        # Push children in reverse so they are popped in their original order.
        pending.extend((child, level + 1) for child in reversed(list(current.children)))


###
#   TODO:   Function which creates subgame from the current game (with chance as root)
#           Function which fixes strategy for only some nodes (likely done)
#


def ses(game, fixed_strategy, fixed_player, step_scheme, dynamic_steps=True):
    """Build a responder strategy by repeatedly re-solving ``game`` with a growing fixed prefix.

    ``game`` is modified in place over several rounds. In each round:

    * ``fixed_strategy`` is fixed for ``fixed_player`` inside the subtree
      ``game.root.children[0]`` up to the round's depth;
    * once the accumulated responder strategy is non-empty, it is frozen for
      ``1 - fixed_player`` up to the same depth;
    * the game is solved for the responder with ``SequenceNash``, and the
      responder's new strategy is merged into the accumulated one.

    The depth used in each round depends on the arguments:

    * ``dynamic_steps`` truthy, or ``step_scheme is None``: depths
      ``0 .. game.max_depth()``. ``step_scheme`` is ignored in this case.
    * ``step_scheme`` an ``int``: ``game.max_depth() + 1`` rounds. A running
      depth ``D`` is increased by ``step_scheme`` and the round uses
      ``D - step_scheme + 1``.
    * otherwise ``step_scheme`` is a sequence of step sizes with one round per
      entry, and each round uses ``D - step + 1``.

    After the last round the modified game is written to ``data/test.efg``.
    Returns the accumulated responder strategy as a dict.
    """
    responder = 1 - fixed_player
    collected = {}
    # Constructed from a fixed file; its ``game`` is replaced by ``game`` before every solve.
    solver = SequenceNash("data/one_card_poker.efg")

    def depth_schedule():
        # Yields (depth for fixed_player, depth for the responder), one pair per round.
        if dynamic_steps or step_scheme is None:
            for level in range(game.max_depth() + 1):
                yield level, level
        elif type(step_scheme) == int:
            reached = 0
            for _ in range(game.max_depth() + 1):
                reached += step_scheme
                start = reached - step_scheme + 1
                yield start, start
        else:
            reached = 0
            for position in range(len(step_scheme)):
                reached += step_scheme[position]
                start = reached - step_scheme[position] + 1
                yield start, start

    for fixed_depth, responder_depth in depth_schedule():
        game.fix_strategy_until_depth_from_node(fixed_player, fixed_strategy, fixed_depth, game.root.children[0])
        if collected:
            game.fix_strategy_until_depth(responder, collected, responder_depth)
        solver.game = game
        solver.solve(responder)
        for info_set, distribution in solver.strategy_in_cfr_format()[responder].items():
            collected[info_set] = distribution
    game.save_to_file("data/test.efg")
    return collected


def step_restricted_nash_response(game, fixed_strategy, fixed_player, step_scheme, dynamic_steps=True):
    """Compute a step-wise response strategy for ``1 - fixed_player`` in ``game``.

    ``game`` is modified in place over several rounds. In each round:

    * ``fixed_strategy`` is fixed for ``fixed_player`` inside the subtree
      ``game.root.children[0]`` up to a "fixed depth";
    * once the accumulated responder strategy is non-empty, it is frozen for
      the responder up to a "responder depth";
    * the game is solved for the responder with ``SequenceNash``, and the
      result is merged into the accumulated strategy.

    The ``(fixed depth, responder depth)`` pair for each round is:

    * ``dynamic_steps`` truthy: ``(d + step_scheme, d)`` for ``d`` in
      ``0 .. game.max_depth()``. ``step_scheme`` must then be a number.
    * ``step_scheme is None``: ``(d, d)`` for the same range.
    * ``step_scheme`` an ``int``: ``game.max_depth() + 1`` rounds. A running
      depth ``D`` is first increased by ``step_scheme``, then the pair is
      ``(D, D - step_scheme + 1)``.
    * otherwise ``step_scheme`` is a sequence of step sizes with one round per
      entry: ``(D, D - step + 1)``.

    Callers in this module pass a game generated by
    ``GenerateRNRGameNotFixed``.
    After the last round the game is written to ``data/test.efg``.
    Returns the accumulated responder strategy as a dict.
    """
    responder = 1 - fixed_player
    collected = {}
    # Constructed from a fixed file; its ``game`` is replaced by ``game`` before every solve.
    solver = SequenceNash("data/one_card_poker.efg")

    def depth_schedule():
        # Yields (depth for fixed_player, depth for the responder), one pair per round.
        if dynamic_steps:
            for level in range(game.max_depth() + 1):
                yield level + step_scheme, level
        elif step_scheme is None:
            for level in range(game.max_depth() + 1):
                yield level, level
        elif type(step_scheme) == int:
            reached = 0
            for _ in range(game.max_depth() + 1):
                reached += step_scheme
                yield reached, reached - step_scheme + 1
        else:
            reached = 0
            for position in range(len(step_scheme)):
                reached += step_scheme[position]
                yield reached, reached - step_scheme[position] + 1

    for fixed_depth, responder_depth in depth_schedule():
        game.fix_strategy_until_depth_from_node(fixed_player, fixed_strategy, fixed_depth, game.root.children[0])
        if collected:
            game.fix_strategy_until_depth(responder, collected, responder_depth)
        solver.game = game
        solver.solve(responder)
        for info_set, distribution in solver.strategy_in_cfr_format()[responder].items():
            collected[info_set] = distribution
    game.save_to_file("data/test.efg")
    return collected


def sbr_combination_level_lp_test(fname, steps=1, p=0.5, dynamic_steps=True):
    """Compare a combination of SBR and SRNR for player 0 with the exact best response.

    Opponent (player 1) strategies are read from
    ``data/bad_strategies/<game>/<game>_<i>_iterations.strat`` for every ``i``
    in ``fibonacci_array(15, zero_start=True)``. For each opponent:

    * the best response (BR) is obtained by fixing the opponent and solving
      with ``SequenceNash``;
    * the step best response (SBR) comes from ``step_best_response``;
    * the SBR is combined with a step restricted Nash response (SRNR) via
      ``Combination(...).combine_strategies(p)``.

    Two files are written to ``results/br_sbr/``:

    * ``comb_<game>_step_<steps>_p=<p>_cfr_var_iter[_dynamic].txt`` with
      lines ``i comb_value br_value``, where each value is
      ``-cfr.game_value - game_value`` for the profile against the opponent;
    * ``..._cfr_var_iter_expl[_dynamic].txt`` with lines
      ``i comb_expl br_expl``, where each value is player 1's best-response
      value plus ``game_value``.
    """
    game = ExtensiveGame()
    game.load(fname)
    print(game.max_depth())
    nash_solver = SequenceNash(fname)
    game_value = nash_solver.solve(0)
    cfr = CFR(fname)
    cfr.initialize()
    # The SRNR half of the combination is computed once, against player 1's
    # strategy right after cfr.initialize(), not against each opponent below.
    # NOTE(review): the RNR game uses weights [0, 1] instead of [p, 1 - p] as
    # in the srnr_* experiments; confirm this is intended.
    rnr_generator = GenerateRNRGameNotFixed(fname)
    temp_fname = "data/temp.efg"
    rnr_generator.generate(temp_fname, [0, 1], 1, [cfr.strategy[1]])
    rnr_game = ExtensiveGame()
    rnr_game.load(temp_fname)
    srnr = step_restricted_nash_response(rnr_game, cfr.strategy[1], 1, steps, dynamic_steps)
    game_name = fname[fname.rfind("/") + 1:fname.rfind(".")]
    dynamic_suffix = "_dynamic" if dynamic_steps else ""
    path_prefix = "results/br_sbr/comb_" + game_name + "_step_" + str(steps) + "_p=" + str(p) + "_cfr_var_iter"
    gain_path = path_prefix + dynamic_suffix + ".txt"
    expl_path = path_prefix + "_expl" + dynamic_suffix + ".txt"
    with open(gain_path, "w+") as gain_file, open(expl_path, "w+") as expl_file:
        # Header lines start with "#" (load_sbrs skips them). The column list used to be
        # glued to the sentence by adjacent-literal concatenation ("...formatnum_iterations").
        gain_file.write("# comparison of gain of COMB and BR on " + game_name
                        + ". Each line is one CFR iteration count in format: num_iterations comb_value br_value\n")
        gain_file.write("steps " + str(steps) + "\n")
        expl_file.write("# comparison of exploitability of COMB and BR on " + game_name
                        + ". Each line is one CFR iteration count in format: num_iterations comb_expl br_expl\n")
        expl_file.write("steps " + str(steps) + "\n")
        for i in fibonacci_array(15, zero_start=True):
            opponent_strategy = load_from_file("data/bad_strategies/" + game_name + "/" + game_name + "_" + str(i) + "_iterations.strat")[1]
            # Best response: fix player 1's strategy (depth -1) and solve for player 0.
            game.load(fname)
            game.fix_strategy_until_depth(1, opponent_strategy, -1)
            nash_solver.game = game
            nash_solver.solve(0)
            best_response = nash_solver.strategy_in_cfr_format()

            game.load(fname)
            sbr = step_best_response(game, opponent_strategy, 1, steps, dynamic_steps)
            game.save_to_file("results/games_to_check/sbr.efg")
            game.load(fname)
            combination_tool = Combination([sbr, opponent_strategy], [srnr, opponent_strategy], game)
            combination = combination_tool.combine_strategies(p)[0]
            game.load(fname)
            joint_strategy = [combination, opponent_strategy]
            cfr.compute_game_value(joint_strategy)
            print("Iterations:", i)
            combination_value = -cfr.game_value - game_value
            print("COMB:", combination_value)
            joint_strategy = [best_response[0], opponent_strategy]
            cfr.compute_game_value(joint_strategy)
            br_value = -cfr.game_value - game_value
            print("BR:", br_value)
            gain_file.write(str(i) + " " + str(combination_value) + " " + str(br_value) + "\n")
            print("BR check:", -(cfr.best_response(0, [{}, opponent_strategy])[0] + game_value))
            combination_expl = cfr.best_response(1, [combination, {}])[0] + game_value
            print("Exploitability:", combination_expl)
            br_expl = cfr.best_response(1, [best_response[0], {}])[0] + game_value
            print("BR Exploitability:", br_expl)
            expl_file.write(str(i) + " " + str(combination_expl) + " " + str(br_expl) + "\n")


def sbr_level_lp_test(fname, steps=1, dynamic_steps=True):
    """Compare player 0's step best response (SBR) with the exact best response (BR).

    Opponent (player 1) strategies are read from
    ``data/bad_strategies/<game>/<game>_<i>_iterations.strat`` for every ``i``
    in ``fibonacci_array(15, zero_start=True)``. For each opponent:

    * the BR is obtained by fixing the opponent and solving with
      ``SequenceNash``;
    * the SBR comes from ``step_best_response(..., steps, dynamic_steps)``.

    Two files are written to ``results/br_sbr/``:

    * ``sbr_<game>_step_<steps>_cfr_var_iter[_dynamic].txt`` with lines
      ``i sbr_value br_value``, where each value is
      ``-cfr.game_value - game_value``;
    * ``..._cfr_var_iter_expl[_dynamic].txt`` with lines
      ``i sbr_expl br_expl``, where each value is player 1's best-response
      value plus ``game_value``.
    """
    game = ExtensiveGame()
    game.load(fname)
    print(game.max_depth())
    nash_solver = SequenceNash(fname)
    game_value = nash_solver.solve(0)
    cfr = CFR(fname)
    cfr.initialize()
    game_name = fname[fname.rfind("/") + 1:fname.rfind(".")]
    dynamic_suffix = "_dynamic" if dynamic_steps else ""
    path_prefix = "results/br_sbr/sbr_" + game_name + "_step_" + str(steps) + "_cfr_var_iter"
    gain_path = path_prefix + dynamic_suffix + ".txt"
    expl_path = path_prefix + "_expl" + dynamic_suffix + ".txt"
    with open(gain_path, "w+") as gain_file, open(expl_path, "w+") as expl_file:
        # Header lines start with "#" (load_sbrs skips them). The column list used to be
        # glued to the sentence by adjacent-literal concatenation ("...formatnum_iterations").
        gain_file.write("# comparison of gain of SBR and BR on " + game_name
                        + ". Each line is one CFR iteration count in format: num_iterations sbr_value br_value\n")
        gain_file.write("steps " + str(steps) + "\n")
        expl_file.write("# comparison of exploitability of SBR and BR on " + game_name
                        + ". Each line is one CFR iteration count in format: num_iterations sbr_expl br_expl\n")
        expl_file.write("steps " + str(steps) + "\n")
        for i in fibonacci_array(15, zero_start=True):
            opponent_strategy = load_from_file("data/bad_strategies/" + game_name + "/" + game_name + "_" + str(i) + "_iterations.strat")[1]
            # Best response: fix player 1's strategy (depth -1) and solve for player 0.
            game.load(fname)
            game.fix_strategy_until_depth(1, opponent_strategy, -1)
            nash_solver.game = game
            nash_solver.solve(0)
            best_response = nash_solver.strategy_in_cfr_format()

            game.load(fname)
            sbr = step_best_response(game, opponent_strategy, 1, steps, dynamic_steps)
            game.save_to_file("results/games_to_check/sbr.efg")
            game.load(fname)
            joint_strategy = [sbr, opponent_strategy]
            cfr.compute_game_value(joint_strategy)
            print("Iterations:", i)
            sbr_value = -cfr.game_value - game_value
            print("SBR:", sbr_value)
            joint_strategy = [best_response[0], opponent_strategy]
            cfr.compute_game_value(joint_strategy)
            br_value = -cfr.game_value - game_value
            print("BR:", br_value)
            gain_file.write(str(i) + " " + str(sbr_value) + " " + str(br_value) + "\n")
            print("BR check:", -(cfr.best_response(0, [{}, opponent_strategy])[0] + game_value))
            sbr_expl = cfr.best_response(1, [sbr, {}])[0] + game_value
            print("Exploitability:", sbr_expl)
            br_expl = cfr.best_response(1, [best_response[0], {}])[0] + game_value
            print("BR Exploitability:", br_expl)
            expl_file.write(str(i) + " " + str(sbr_expl) + " " + str(br_expl) + "\n")


def sbr_level_lp_test_random_strategies(fname, steps=1, dynamic_steps=True):
    """Compare SBR with the exact BR against ten seeded random opponents.

    Same experiment as :func:`sbr_level_lp_test`, except the opponent
    (player 1) strategies are read from
    ``data/random_strategies/<game>_seed_<seed>.strat`` for seeds ``0..9``.
    The first column of both output files is therefore the seed:

    * ``results/br_sbr/sbr_<game>_step_<steps>_random_seeded[_dynamic].txt``
      with lines ``seed sbr_value br_value``;
    * ``..._random_seeded_expl[_dynamic].txt`` with lines
      ``seed sbr_expl br_expl``.
    """
    game = ExtensiveGame()
    game.load(fname)
    print(game.max_depth())
    nash_solver = SequenceNash(fname)
    game_value = nash_solver.solve(0)
    cfr = CFR(fname)
    cfr.initialize()
    game_name = fname[fname.rfind("/") + 1:fname.rfind(".")]
    dynamic_suffix = "_dynamic" if dynamic_steps else ""
    path_prefix = "results/br_sbr/sbr_" + game_name + "_step_" + str(steps) + "_random_seeded"
    gain_path = path_prefix + dynamic_suffix + ".txt"
    expl_path = path_prefix + "_expl" + dynamic_suffix + ".txt"
    with open(gain_path, "w+") as gain_file, open(expl_path, "w+") as expl_file:
        # Header lines start with "#" (load_sbrs skips them). The column list used to be
        # glued to the sentence by adjacent-literal concatenation ("...formatnum_iterations").
        gain_file.write("# comparison of gain of SBR and BR on " + game_name
                        + ". Each line is one random opponent in format: seed sbr_value br_value\n")
        gain_file.write("steps " + str(steps) + "\n")
        expl_file.write("# comparison of exploitability of SBR and BR on " + game_name
                        + ". Each line is one random opponent in format: seed sbr_expl br_expl\n")
        expl_file.write("steps " + str(steps) + "\n")
        for i in range(10):
            opponent_strategy = load_from_file("data/random_strategies/" + game_name + "_seed_" + str(i) + ".strat")[1]
            # Best response: fix player 1's strategy (depth -1) and solve for player 0.
            game.load(fname)
            game.fix_strategy_until_depth(1, opponent_strategy, -1)
            nash_solver.game = game
            nash_solver.solve(0)
            best_response = nash_solver.strategy_in_cfr_format()

            game.load(fname)
            sbr = step_best_response(game, opponent_strategy, 1, steps, dynamic_steps)
            game.save_to_file("results/games_to_check/sbr.efg")
            game.load(fname)
            joint_strategy = [sbr, opponent_strategy]
            cfr.compute_game_value(joint_strategy)
            print("Iterations:", i)
            sbr_value = -cfr.game_value - game_value
            print("SBR:", sbr_value)
            joint_strategy = [best_response[0], opponent_strategy]
            cfr.compute_game_value(joint_strategy)
            br_value = -cfr.game_value - game_value
            print("BR:", br_value)
            gain_file.write(str(i) + " " + str(sbr_value) + " " + str(br_value) + "\n")
            print("BR check:", -(cfr.best_response(0, [{}, opponent_strategy])[0] + game_value))
            sbr_expl = cfr.best_response(1, [sbr, {}])[0] + game_value
            print("Exploitability:", sbr_expl)
            br_expl = cfr.best_response(1, [best_response[0], {}])[0] + game_value
            print("BR Exploitability:", br_expl)
            expl_file.write(str(i) + " " + str(sbr_expl) + " " + str(br_expl) + "\n")


def srnr_level_lp_test(fname, steps=1, p=0.5, dynamic_steps=True, iterations=(2,)):
    """Compare the step restricted Nash response (SRNR) with the exact best response (BR).

    For every CFR iteration count ``i`` in ``iterations``, the opponent
    (player 1) strategy is read from
    ``data/bad_strategies/<game>/<game>_<i>_iterations.strat``. Then:

    * an RNR game with weights ``[p, 1 - p]`` is generated into
      ``data/temp.efg``;
    * the SRNR is computed with ``step_restricted_nash_response``;
    * the SRNR is compared with the BR obtained through ``SequenceNash``.

    ``iterations`` defaults to ``(2,)``, the single count the experiment was
    previously hard-coded to. Pass ``fibonacci_array(15, zero_start=True)``
    for the full sweep used by the sibling experiments.

    Two files are written to ``results/br_sbr/``:

    * ``srnr_<game>_step_<steps>_p=<p>_cfr_var_iter[_dynamic].txt`` with
      lines ``i srnr_value br_value``;
    * ``..._cfr_var_iter_expl[_dynamic].txt`` with lines
      ``i srnr_expl br_expl``.

    Raises ``AssertionError`` if the Nash strategy's measured exploitability
    is not below ``1e-6`` (sanity check of the solvers).
    """
    game = ExtensiveGame()
    game.load(fname)
    rnr_game = ExtensiveGame()
    print(game.max_depth())
    nash_solver = SequenceNash(fname)
    game_value = nash_solver.solve(0)
    nash_strategy = nash_solver.strategy_in_cfr_format()
    cfr = CFR(fname)
    cfr.initialize()
    game_name = fname[fname.rfind("/") + 1:fname.rfind(".")]
    dynamic_suffix = "_dynamic" if dynamic_steps else ""
    path_prefix = "results/br_sbr/srnr_" + game_name + "_step_" + str(steps) + "_p=" + str(p) + "_cfr_var_iter"
    gain_path = path_prefix + dynamic_suffix + ".txt"
    expl_path = path_prefix + "_expl" + dynamic_suffix + ".txt"
    with open(gain_path, "w+") as gain_file, open(expl_path, "w+") as expl_file:
        # Header lines start with "#" (load_sbrs skips them). The column list used to be
        # glued to the sentence by adjacent-literal concatenation ("...formatnum_iterations").
        gain_file.write("# comparison of SRNR and BR on " + game_name
                        + ". Each line is one CFR iteration count in format: num_iterations srnr_value br_value\n")
        gain_file.write("steps " + str(steps) + "\n")
        expl_file.write("# comparison of exploitability of SRNR and BR on " + game_name
                        + ". Each line is one CFR iteration count in format: num_iterations srnr_expl br_expl\n")
        expl_file.write("steps " + str(steps) + "\n")
        for i in iterations:
            opponent_strategy = load_from_file("data/bad_strategies/" + game_name + "/" + game_name + "_" + str(i) + "_iterations.strat")[1]
            # Best response: fix player 1's strategy (depth -1) and solve for player 0.
            game.load(fname)
            game.fix_strategy_until_depth(1, opponent_strategy, -1)
            nash_solver.game = game
            nash_solver.solve(0)
            best_response = nash_solver.strategy_in_cfr_format()

            game.load(fname)
            rnr_generator = GenerateRNRGameNotFixed(fname)
            temp_fname = "data/temp.efg"
            rnr_generator.generate(temp_fname, [p, 1 - p], 1, [opponent_strategy])
            rnr_game.load(temp_fname)
            srnr = step_restricted_nash_response(rnr_game, opponent_strategy, 1, steps, dynamic_steps)
            rnr_game.save_to_file("results/games_to_check/srnr.efg")
            joint_strategy = [srnr, opponent_strategy]
            cfr.compute_game_value(joint_strategy)
            print("Iterations:", i)
            srnr_value = -cfr.game_value - game_value
            print("SRNR:", srnr_value)
            joint_strategy = [best_response[0], opponent_strategy]
            cfr.compute_game_value(joint_strategy)
            br_value = -cfr.game_value - game_value
            print("BR:", br_value)
            print("BR check:", -(cfr.best_response(0, [{}, opponent_strategy])[0] + game_value))
            gain_file.write(str(i) + " " + str(srnr_value) + " " + str(br_value) + "\n")
            srnr_expl = cfr.best_response(1, [srnr, {}])[0] + game_value
            nash_expl = cfr.best_response(1, [nash_strategy[0], {}])[0] + game_value
            assert nash_expl < 0.000001
            print("Nash Exploitability:", nash_expl)
            print("Exploitability:", srnr_expl)
            br_expl = cfr.best_response(1, [best_response[0], {}])[0] + game_value
            print("BR Exploitability:", br_expl, flush=True)
            expl_file.write(str(i) + " " + str(srnr_expl) + " " + str(br_expl) + "\n")


def ses_lp_test(fname, steps=1, p=0.5, dynamic_steps=True):
    """Compare the strategy produced by :func:`ses` with the exact best response (BR).

    For every ``i`` in ``fibonacci_array(15, zero_start=True)``, the opponent
    (player 1) strategy is read from
    ``data/bad_strategies/<game>/<game>_<i>_iterations.strat``. Then:

    * an RNR game with weights ``[p, 1 - p]`` is generated into
      ``data/temp.efg``;
    * ``ses`` is run on it;
    * the result is compared with the BR obtained through ``SequenceNash``.

    Two files are written to ``results/br_sbr/``:

    * ``ses_<game>_step_<steps>_p=<p>_cfr_var_iter[_dynamic].txt`` with lines
      ``i ses_value br_value``;
    * ``..._cfr_var_iter_expl[_dynamic].txt`` with lines
      ``i ses_expl br_expl``.

    Raises ``AssertionError`` if the Nash strategy's measured exploitability
    is not below ``1e-6``.
    """
    game = ExtensiveGame()
    game.load(fname)
    rnr_game = ExtensiveGame()
    print(game.max_depth())
    nash_solver = SequenceNash(fname)
    game_value = nash_solver.solve(0)
    nash_strategy = nash_solver.strategy_in_cfr_format()
    cfr = CFR(fname)
    cfr.initialize()
    game_name = fname[fname.rfind("/") + 1:fname.rfind(".")]
    dynamic_suffix = "_dynamic" if dynamic_steps else ""
    path_prefix = "results/br_sbr/ses_" + game_name + "_step_" + str(steps) + "_p=" + str(p) + "_cfr_var_iter"
    gain_path = path_prefix + dynamic_suffix + ".txt"
    expl_path = path_prefix + "_expl" + dynamic_suffix + ".txt"
    with open(gain_path, "w+") as gain_file, open(expl_path, "w+") as expl_file:
        # Header lines start with "#" (load_sbrs skips them). The column list used to be
        # glued to the sentence by adjacent-literal concatenation ("...formatnum_iterations").
        gain_file.write("# comparison of SES and BR on " + game_name
                        + ". Each line is one CFR iteration count in format: num_iterations ses_value br_value\n")
        gain_file.write("steps " + str(steps) + "\n")
        expl_file.write("# comparison of exploitability of SES and BR on " + game_name
                        + ". Each line is one CFR iteration count in format: num_iterations ses_expl br_expl\n")
        expl_file.write("steps " + str(steps) + "\n")
        for i in fibonacci_array(15, zero_start=True):
            opponent_strategy = load_from_file("data/bad_strategies/" + game_name + "/" + game_name + "_" + str(i) + "_iterations.strat")[1]
            # Best response: fix player 1's strategy (depth -1) and solve for player 0.
            game.load(fname)
            game.fix_strategy_until_depth(1, opponent_strategy, -1)
            nash_solver.game = game
            nash_solver.solve(0)
            best_response = nash_solver.strategy_in_cfr_format()

            game.load(fname)
            rnr_generator = GenerateRNRGameNotFixed(fname)
            temp_fname = "data/temp.efg"
            rnr_generator.generate(temp_fname, [p, 1 - p], 1, [opponent_strategy])
            rnr_game.load(temp_fname)
            srnr = ses(rnr_game, opponent_strategy, 1, steps, dynamic_steps)
            rnr_game.save_to_file("results/games_to_check/srnr.efg")
            joint_strategy = [srnr, opponent_strategy]
            cfr.compute_game_value(joint_strategy)
            print("Iterations:", i)
            srnr_value = -cfr.game_value - game_value
            print("SRNR:", srnr_value)
            joint_strategy = [best_response[0], opponent_strategy]
            cfr.compute_game_value(joint_strategy)
            br_value = -cfr.game_value - game_value
            print("BR:", br_value)
            print("BR check:", -(cfr.best_response(0, [{}, opponent_strategy])[0] + game_value))
            gain_file.write(str(i) + " " + str(srnr_value) + " " + str(br_value) + "\n")
            srnr_expl = cfr.best_response(1, [srnr, {}])[0] + game_value
            nash_expl = cfr.best_response(1, [nash_strategy[0], {}])[0] + game_value
            assert nash_expl < 0.000001
            print("Nash Exploitability:", nash_expl)
            print("Exploitability:", srnr_expl)
            br_expl = cfr.best_response(1, [best_response[0], {}])[0] + game_value
            print("BR Exploitability:", br_expl)
            expl_file.write(str(i) + " " + str(srnr_expl) + " " + str(br_expl) + "\n")


def srnr_level_lp_test_random_strategies(fname, steps=1, p=0.5, dynamic_steps=True):
    """Compare SRNR with the exact BR against ten seeded random opponents.

    Same experiment as :func:`srnr_level_lp_test`, except the opponent
    (player 1) strategies are read from
    ``data/random_strategies/<game>_seed_<seed>.strat`` for seeds ``0..9``.
    The first column of both output files is therefore the seed:

    * ``results/br_sbr/srnr_<game>_step_<steps>_p=<p>_random_seeded[_dynamic].txt``
      with lines ``seed srnr_value br_value``;
    * ``..._random_seeded_expl[_dynamic].txt`` with lines
      ``seed srnr_expl br_expl``.

    Raises ``AssertionError`` if the Nash strategy's measured exploitability
    is not below ``1e-6``.
    """
    game = ExtensiveGame()
    game.load(fname)
    rnr_game = ExtensiveGame()
    print(game.max_depth())
    nash_solver = SequenceNash(fname)
    game_value = nash_solver.solve(0)
    nash_strategy = nash_solver.strategy_in_cfr_format()
    cfr = CFR(fname)
    cfr.initialize()
    game_name = fname[fname.rfind("/") + 1:fname.rfind(".")]
    dynamic_suffix = "_dynamic" if dynamic_steps else ""
    path_prefix = "results/br_sbr/srnr_" + game_name + "_step_" + str(steps) + "_p=" + str(p) + "_random_seeded"
    gain_path = path_prefix + dynamic_suffix + ".txt"
    expl_path = path_prefix + "_expl" + dynamic_suffix + ".txt"
    with open(gain_path, "w+") as gain_file, open(expl_path, "w+") as expl_file:
        # Header lines start with "#" (load_sbrs skips them). The column list used to be
        # glued to the sentence by adjacent-literal concatenation ("...formatnum_iterations").
        gain_file.write("# comparison of SRNR and BR on " + game_name
                        + ". Each line is one random opponent in format: seed srnr_value br_value\n")
        gain_file.write("steps " + str(steps) + "\n")
        expl_file.write("# comparison of exploitability of SRNR and BR on " + game_name
                        + ". Each line is one random opponent in format: seed srnr_expl br_expl\n")
        expl_file.write("steps " + str(steps) + "\n")
        for i in range(10):
            opponent_strategy = load_from_file("data/random_strategies/" + game_name + "_seed_" + str(i) + ".strat")[1]
            # Best response: fix player 1's strategy (depth -1) and solve for player 0.
            game.load(fname)
            game.fix_strategy_until_depth(1, opponent_strategy, -1)
            nash_solver.game = game
            nash_solver.solve(0)
            best_response = nash_solver.strategy_in_cfr_format()

            game.load(fname)
            rnr_generator = GenerateRNRGameNotFixed(fname)
            temp_fname = "data/temp.efg"
            rnr_generator.generate(temp_fname, [p, 1 - p], 1, [opponent_strategy])
            rnr_game.load(temp_fname)
            srnr = step_restricted_nash_response(rnr_game, opponent_strategy, 1, steps, dynamic_steps)
            rnr_game.save_to_file("results/games_to_check/srnr.efg")
            joint_strategy = [srnr, opponent_strategy]
            cfr.compute_game_value(joint_strategy)
            print("Iterations:", i)
            srnr_value = -cfr.game_value - game_value
            print("SRNR:", srnr_value)
            joint_strategy = [best_response[0], opponent_strategy]
            cfr.compute_game_value(joint_strategy)
            br_value = -cfr.game_value - game_value
            print("BR:", br_value)
            print("BR check:", -(cfr.best_response(0, [{}, opponent_strategy])[0] + game_value))
            gain_file.write(str(i) + " " + str(srnr_value) + " " + str(br_value) + "\n")
            srnr_expl = cfr.best_response(1, [srnr, {}])[0] + game_value
            nash_expl = cfr.best_response(1, [nash_strategy[0], {}])[0] + game_value
            assert nash_expl < 0.000001
            print("Nash Exploitability:", nash_expl)
            print("Exploitability:", srnr_expl)
            br_expl = cfr.best_response(1, [best_response[0], {}])[0] + game_value
            print("BR Exploitability:", br_expl)
            expl_file.write(str(i) + " " + str(srnr_expl) + " " + str(br_expl) + "\n")


def generate_random_strategies(fname):
    """Save ten seeded random strategy profiles for the game stored in ``fname``.

    For each seed ``0..9``, ``cfr.initialize_random_strategy(seed)`` is called
    and ``cfr.average_strategy`` is written to
    ``data/random_strategies/<game>_seed_<seed>.strat``. These are the files
    the ``*_random_strategies`` experiments read.
    """
    name_start = fname.rfind("/") + 1
    name_stop = fname.rfind(".")
    game_name = fname[name_start:name_stop]
    solver = CFR(fname)
    solver.initialize()
    for seed in range(10):
        solver.initialize_random_strategy(seed)
        target = "data/random_strategies/" + game_name + "_seed_" + str(seed) + ".strat"
        save_to_file(solver.average_strategy, target)


def generate_bad_strategies(fname):
    """Save CFR average strategies after a Fibonacci schedule of iteration counts.

    For each ``count`` in ``fibonacci_array(15, zero_start=True)``, this calls
    ``cfr.solve(count, initialize=False)`` and writes
    ``cfr.average_strategy`` to
    ``data/bad_strategies/<game>/<game>_<count>_iterations.strat``.

    NOTE(review): because ``initialize=False`` is passed, CFR state carries
    over between calls. If ``solve`` runs ``count`` *additional* iterations,
    the total behind each file is the running sum rather than ``count``.
    Confirm against ``CFR.solve``.
    """
    name_start = fname.rfind("/") + 1
    name_stop = fname.rfind(".")
    game_name = fname[name_start:name_stop]
    solver = CFR(fname)
    solver.initialize()
    for count in fibonacci_array(15, zero_start=True):
        solver.solve(count, initialize=False)
        target = "data/bad_strategies/" + game_name + "/" + game_name + "_" + str(count) + "_iterations.strat"
        save_to_file(solver.average_strategy, target)


def load_lbrs(file_name, value_limit, include_average):
    """Read local-best-response results from ``file_name``.

    File format:

    * each data line is ``<iteration> <value> <confidence>``;
    * lines starting with ``#`` are skipped;
    * blank lines are skipped too (previously they raised ``IndexError``).

    At most ``value_limit`` data lines are read. A limit that is never
    reached, such as ``-1``, reads the whole file.

    With ``include_average``, one extra entry is appended to each list: the
    label ``"a"`` to the iterations, and the mean of the values and of the
    confidences read.

    Returns ``(iterations, values, confidences)`` as parallel lists.
    """
    iterations = []
    values = []
    confidences = []
    with open(file_name, "r") as handle:
        value_count = 0
        for line in handle:
            if value_count == value_limit:
                break
            tokens = line.split()
            if not tokens or line.startswith("#"):
                continue
            iterations.append(int(tokens[0]))
            values.append(float(tokens[1]))
            confidences.append(float(tokens[2]))
            value_count += 1
    if include_average:
        iterations.append("a")
        values.append(np.average(values))
        confidences.append(np.average(confidences))
    return iterations, values, confidences


def load_sbrs(file_name, steps, value_limit, include_average):
    """Load per-step SBR/BR result files.

    ``file_name`` is a ``str.format`` pattern; ``{0}`` is replaced by each
    entry of ``steps`` to get one file per step.

    File format:

    * lines starting with ``#`` or ``steps`` are skipped, as are blank lines
      (previously blank lines raised ``IndexError``);
    * each remaining line is ``<iteration> <sbr value> <br value>``.

    At most ``value_limit`` data lines are read per file. A limit that is
    never reached, such as ``-1``, reads everything.

    With ``include_average``, each per-step list gets an extra entry: the
    label ``"a"`` for the iterations, and the mean of the values read.

    Returns ``(iterations, sbr_values, br_values)``. Each is a list holding
    one inner list per step, in ``steps`` order.
    """
    iterations = []
    sbr_values = []
    br_values = []
    for step in steps:
        local_iterations = []
        local_sbr_values = []
        local_br_values = []
        value_count = 0
        with open(file_name.format(step), "r") as handle:
            for line in handle:
                if value_count == value_limit:
                    break
                tokens = line.split()
                if not tokens or line.startswith(("#", "steps")):
                    continue
                local_iterations.append(int(tokens[0]))
                local_sbr_values.append(float(tokens[1]))
                local_br_values.append(float(tokens[2]))
                value_count += 1
        if include_average:
            local_iterations.append("a")
            local_sbr_values.append(np.average(local_sbr_values))
            local_br_values.append(np.average(local_br_values))
        iterations.append(local_iterations)
        sbr_values.append(local_sbr_values)
        br_values.append(local_br_values)
    return iterations, sbr_values, br_values


def check_iterations_consistency_only_sbr(sbr_iterations):
    """Return True when every per-step iteration list equals the first one.

    ``sbr_iterations`` is printed first (debugging output). The result is
    trivially True for zero or one list.
    """
    print(sbr_iterations)
    return all(np.array_equal(sbr_iterations[0], other) for other in sbr_iterations[1:])


def check_iterations_consistency(lbr_iterations, sbr_iterations):
    """Return True if every list in ``sbr_iterations`` starts with ``lbr_iterations``.

    Extra trailing entries in an SBR list are tolerated, as before. An SBR
    list shorter than ``lbr_iterations`` now yields ``False`` instead of
    raising ``IndexError``, so callers' ``assert ..., "Iterations differ"``
    reports the real problem.
    """
    expected = list(lbr_iterations)
    for part in sbr_iterations:
        if len(part) < len(expected):
            return False
        for index, iteration in enumerate(expected):
            if iteration != part[index]:
                return False
    return True


def check_br_consistency(br_values):
    """Return True when every per-step best-response value list equals the first one.

    Comparison uses ``np.array_equal``, so lists of different lengths are
    unequal. The result is trivially True for zero or one list.
    """
    return all(np.array_equal(br_values[0], candidate) for candidate in br_values[1:])


def plot_leduc_lbr_sbr(steps, mode, dynamic_steps=True, values=-1, include_average=False, ax=None):
    """Grouped bar plot of BR, LBR, CDBRNN and CDBR gains on Leduc hold'em.

    Each x position is one opponent (CFR iteration count or random seed); its
    group holds, in order, BR, LBR (with ``yerr`` confidences), CDBRNN and one
    CDBR bar per entry of ``steps``.

    Args:
        steps: CDBR step sizes; each selects one ``sbr_leduc_holdem_step_<s>`` file.
        mode: ``"cfr"`` or ``"random"``; selects the result files and x label.
        dynamic_steps: read the ``_dynamic`` variants of the CDBR files.
        values: maximum number of rows read per file (negative = all).
        include_average: append an averaged group labelled ``"a"``.
        ax: Axes to draw on. If None, the current pyplot Axes is used and the
            figure is finished with a legend and ``plt.show()``.

    Raises:
        ValueError: if ``mode`` is neither ``"cfr"`` nor ``"random"``.
        AssertionError: if the loaded iteration labels or BR values disagree.
    """
    # Imported locally so the name does not depend on a star import.
    from matplotlib.ticker import AutoMinorLocator

    plt.rcParams.update({'font.size': 20, 'font.family': 'Times New Roman'})
    dynamic_suffix = "_dynamic" if dynamic_steps else ""
    if mode == "cfr":
        # NOTE(review): load_general_index_value_br returns (index, value, br_value); the names below
        # are swapped, so the CDBRNN bars show the file's third column - confirm against the file format.
        sbr_nn_iterations, br_nn_values, sbr_nn_values = load_general_index_value_br("results/br_sbr/leduc_net_step_2_cfr.txt", values, include_average)
        lbr_iterations, lbr_values, lbr_confidences = load_lbrs("results/br_sbr/lbr_leduc.txt", values, include_average)
        sbr_iterations, sbr_values, br_values = load_sbrs("results/br_sbr/sbr_leduc_holdem_step_{0}_cfr_var_iter" + dynamic_suffix + ".txt", steps, values, include_average)
    elif mode == "random":
        sbr_nn_iterations, br_nn_values, sbr_nn_values = load_general_index_value_br("results/br_sbr/leduc_net_step_2_random.txt", values, include_average)
        lbr_iterations, lbr_values, lbr_confidences = load_lbrs("results/br_sbr/lbr_leduc_seeded.txt", values, include_average)
        sbr_iterations, sbr_values, br_values = load_sbrs("results/br_sbr/sbr_leduc_holdem_step_{0}_random_seeded" + dynamic_suffix + ".txt", steps, values, include_average)
    else:
        # A real exception: an assert would be stripped under -O and fall through to a NameError.
        raise ValueError("Mode needs to be 'cfr' or 'random'")
    assert check_iterations_consistency(lbr_iterations, sbr_iterations), "Iterations differ"
    assert check_iterations_consistency(sbr_nn_iterations, sbr_iterations), "Iterations differ"
    assert check_br_consistency(br_values), "Best responses differ"
    br_values = br_values[0]
    positions = list(range(len(br_values)))
    total_bars = len(sbr_values) + 3
    diff = 1. / (total_bars + 1)
    width = 1. / (total_bars + 2)

    def slot(k):
        # x coordinates of the k-th (1-based) bar inside every group.
        return [pos - 0.5 + diff * k for pos in positions]

    # Draw everything on one Axes. The old no-argument plt.axes() call created a new, empty
    # Axes on top of the bars in current matplotlib, so locator, ticks and legend were lost.
    target = ax if ax is not None else plt.gca()
    target.bar(slot(1), br_values, width=width, label="BR")
    target.bar(slot(2)[:len(lbr_values)], lbr_values, yerr=lbr_confidences, width=width, label="LBR")
    target.bar(slot(3), sbr_nn_values, width=width, label="CDBRNN")
    for index, step_values in enumerate(sbr_values):
        target.bar(slot(index + 4), step_values, width=width, label="CDBR" + str(steps[index]))
    if mode == "cfr":
        target.set_xlabel("CFR iterations of opponent's strategy")
    else:
        target.set_xlabel("Seed of random strategy")
    target.set_ylabel("Gain")
    target.set_ylim(-0.03, 3 if mode == "cfr" else 3.1)
    target.grid(which="major", axis="y")
    target.grid(which="minor", axis="y", alpha=0.3)
    target.yaxis.set_minor_locator(AutoMinorLocator(4))
    target.set_axisbelow(True)
    if ax is None:
        target.set_xticks(positions)
        target.set_xticklabels(sbr_iterations[0])
        plt.gcf().subplots_adjust(bottom=0.14, left=0.09, right=0.995, top=0.89)
        target.legend(bbox_to_anchor=(0, 1.02, 1., .102), loc='lower left',
                      ncol=6, mode="expand", borderaxespad=0., handlelength=0.5, handletextpad=0.1, borderpad=0.3)
        plt.show()
    else:
        target.set_yticks([0, 1, 2, 3])
        target.set_xticks(positions)
        target.set_xticklabels(sbr_iterations[0])


def plot_br_sbr_steps_gain(game_name, steps, mode, dynamic_steps=True, values=-1, include_average=False, ax=None, lim=None, ticks=None):
    """Grouped bar plot of BR and CDBR gains for several step sizes.

    Reads ``results/br_sbr/sbr_<game_name>_step_<s>_cfr_var_iter[_dynamic].txt``
    (or ``..._random_seeded...``) for every ``s`` in ``steps``. Each x position
    is one opponent; its group holds the BR bar and one CDBR bar per step.

    Args:
        game_name: game identifier used in the file names.
        steps: CDBR step sizes.
        mode: ``"cfr"`` or ``"random"``.
        dynamic_steps: read the ``_dynamic`` file variants.
        values: maximum number of rows read per file (negative = all).
        include_average: append an averaged group labelled ``"a"``.
        ax: Axes to draw on. If None, the current pyplot Axes is used and the
            figure is finished with a legend and ``plt.show()``.
        lim: optional ``(ymin, ymax)``.
        ticks: explicit y ticks, applied only when ``ax`` is given.

    Raises:
        ValueError: if ``mode`` is neither ``"cfr"`` nor ``"random"``.
        AssertionError: if the BR values of the step files disagree.
    """
    # Imported locally so the name does not depend on a star import.
    from matplotlib.ticker import AutoMinorLocator

    plt.rcParams.update({'font.size': 20, 'font.family': 'Times New Roman'})
    if mode == "cfr":
        variant = "_cfr_var_iter"
    elif mode == "random":
        variant = "_random_seeded"
    else:
        raise ValueError("Mode needs to be 'cfr' or 'random'")
    file_template = "results/br_sbr/sbr_" + game_name + "_step_{0}" + variant + ("_dynamic" if dynamic_steps else "") + ".txt"
    sbr_iterations, sbr_values, br_values = load_sbrs(file_template, steps, values, include_average)
    assert check_br_consistency(br_values), "Best responses differ"
    br_values = br_values[0]
    positions = list(range(len(br_values)))
    total_bars = len(sbr_values) + 1
    diff = 1. / (total_bars + 1)
    width = 1. / (total_bars + 2)
    # A single target Axes; the no-argument plt.axes() used before created a new, empty Axes.
    target = ax if ax is not None else plt.gca()
    target.bar([pos - 0.5 + diff for pos in positions], br_values, width=width, label="BR")
    for index, step_values in enumerate(sbr_values):
        target.bar([pos - 0.5 + diff * (index + 2) for pos in positions], step_values, width=width, label="CDBR" + str(steps[index]))
    if mode == "cfr":
        target.set_xlabel("CFR iterations of opponent's strategy")
    else:
        target.set_xlabel("Seed of random strategy")
    target.set_ylabel("Gain")
    if lim is not None:
        target.set_ylim(lim[0], lim[1])
    target.grid(which="major", axis="y")
    target.grid(which="minor", axis="y", alpha=0.3)
    if ax is None:
        target.yaxis.set_minor_locator(AutoMinorLocator(4))
        target.set_axisbelow(True)
        target.set_xticks(positions)
        target.set_xticklabels(sbr_iterations[0])
        plt.gcf().subplots_adjust(bottom=0.14, left=0.09, right=0.995, top=0.89)
        target.legend(bbox_to_anchor=(0, 1.02, 1., .102), loc='lower left',
                      ncol=6, mode="expand", borderaxespad=0., handlelength=0.5, handletextpad=0.1, borderpad=0.3)
        plt.show()
    else:
        target.yaxis.set_minor_locator(AutoMinorLocator(5))
        target.set_axisbelow(True)
        if ticks is not None:
            target.set_yticks(ticks)
        target.set_xticks(positions)
        target.set_xticklabels(sbr_iterations[0])


def plot_br_sbr_steps_expl(game_name, steps, mode, dynamic_steps=True, values=-1, include_average=False, ax=None, lim=None, ticks=None):
    """Grouped bar plot of BR and CDBR exploitability for several step sizes.

    Reads ``results/br_sbr/sbr_<game_name>_step_<s>_cfr_var_iter_expl[_dynamic].txt``
    (or ``..._random_seeded_expl...``) for every ``s`` in ``steps``. Each x
    position is one opponent; its group holds the BR bar and one CDBR bar per step.

    Args:
        game_name: game identifier used in the file names.
        steps: CDBR step sizes.
        mode: ``"cfr"`` or ``"random"``.
        dynamic_steps: read the ``_dynamic`` file variants.
        values: maximum number of rows read per file (negative = all).
        include_average: append an averaged group labelled ``"a"``.
        ax: Axes to draw on. If None, the current pyplot Axes is used and the
            figure is finished with a legend and ``plt.show()``.
        lim: optional ``(ymin, ymax)``.
        ticks: explicit y ticks, applied only when ``ax`` is given.

    Raises:
        ValueError: if ``mode`` is neither ``"cfr"`` nor ``"random"``.
        AssertionError: if the BR values of the step files disagree.
    """
    # Imported locally so the name does not depend on a star import.
    from matplotlib.ticker import AutoMinorLocator

    plt.rcParams.update({'font.size': 20, 'font.family': 'Times New Roman'})
    if mode == "cfr":
        variant = "_cfr_var_iter_expl"
    elif mode == "random":
        variant = "_random_seeded_expl"
    else:
        raise ValueError("Mode needs to be 'cfr' or 'random'")
    file_template = "results/br_sbr/sbr_" + game_name + "_step_{0}" + variant + ("_dynamic" if dynamic_steps else "") + ".txt"
    sbr_iterations, sbr_values, br_values = load_sbrs(file_template, steps, values, include_average)
    assert check_br_consistency(br_values), "Best responses differ"
    br_values = br_values[0]
    positions = list(range(len(br_values)))
    total_bars = len(sbr_values) + 1
    diff = 1. / (total_bars + 1)
    width = 1. / (total_bars + 2)
    # A single target Axes; the no-argument plt.axes() used before created a new, empty Axes.
    target = ax if ax is not None else plt.gca()
    target.bar([pos - 0.5 + diff for pos in positions], br_values, width=width, label="BR")
    for index, step_values in enumerate(sbr_values):
        target.bar([pos - 0.5 + diff * (index + 2) for pos in positions], step_values, width=width, label="CDBR" + str(steps[index]))
    if mode == "cfr":
        target.set_xlabel("CFR iterations of opponent's strategy")
    else:
        target.set_xlabel("Seed of random strategy")
    target.set_ylabel("Exploitability")
    if lim is not None:
        target.set_ylim(lim[0], lim[1])
    target.grid(which="major", axis="y")
    target.grid(which="minor", axis="y", alpha=0.3)
    if ax is None:
        target.yaxis.set_minor_locator(AutoMinorLocator(4))
        target.set_axisbelow(True)
        target.set_xticks(positions)
        target.set_xticklabels(sbr_iterations[0])
        plt.gcf().subplots_adjust(bottom=0.14, left=0.09, right=0.995, top=0.89)
        target.legend(bbox_to_anchor=(0, 1.02, 1., .102), loc='lower left',
                      ncol=6, mode="expand", borderaxespad=0., handlelength=0.5, handletextpad=0.1, borderpad=0.3)
        plt.show()
    else:
        target.yaxis.set_minor_locator(AutoMinorLocator(5))
        target.set_axisbelow(True)
        if ticks is not None:
            target.set_yticks(ticks)
        target.set_xticks(positions)
        target.set_xticklabels(sbr_iterations[0])


def load_general_index_value_br(file_name, value_limit, include_average):
    """Load a single ``<index> <value> <br_value>`` result file.

    Lines starting with ``#`` or ``steps`` are headers and are skipped, as are
    blank lines.

    Args:
        file_name: path of the result file.
        value_limit: maximum number of data rows read; a negative value
            (e.g. ``-1``) never matches the row count, i.e. no limit.
        include_average: if True, append the label ``"a"`` to the indices and
            the ``np.average`` of each value column.

    Returns:
        ``(iterations, rnr_values, br_values)`` - the three columns as lists.
    """
    iterations = []
    rnr_values = []
    br_values = []
    with open(file_name, "r") as result_file:
        for line in result_file:
            if len(iterations) == value_limit:
                break
            if line.startswith(("#", "steps")):
                continue
            tokens = line.split()
            if not tokens:
                # Tolerate blank lines instead of failing with an IndexError on tokens[0].
                continue
            iterations.append(int(tokens[0]))
            rnr_values.append(float(tokens[1]))
            br_values.append(float(tokens[2]))
    if include_average:
        iterations.append("a")
        rnr_values.append(np.average(rnr_values))
        br_values.append(np.average(br_values))
    return iterations, rnr_values, br_values


def plot_br_srnr_steps_gain(game_name, steps, p, dynamic_steps=True, value_limit=-1, include_average=True, ax=None, cmap=None, mode="cfr", yticks=None, print_set=None):
    """Gain variant of ``plot_br_srnr_steps``: forwards all arguments with ``metric="gain"``."""
    options = dict(dynamic_steps=dynamic_steps, value_limit=value_limit, include_average=include_average,
                   ax=ax, cmap=cmap, mode=mode, yticks=yticks, print_set=print_set)
    plot_br_srnr_steps(game_name, steps, p, metric="gain", **options)


def plot_br_srnr_steps_expl(game_name, steps, p, dynamic_steps=True, value_limit=-1, include_average=True, ax=None, cmap=None, mode="cfr", yticks=None, print_set=None):
    """Exploitability variant of ``plot_br_srnr_steps``: forwards all arguments with ``metric="expl"``."""
    options = dict(dynamic_steps=dynamic_steps, value_limit=value_limit, include_average=include_average,
                   ax=ax, cmap=cmap, mode=mode, yticks=yticks, print_set=print_set)
    plot_br_srnr_steps(game_name, steps, p, metric="expl", **options)


def plot_br_srnr_steps(game_name, steps_in, p, dynamic_steps=True, value_limit=-1, include_average=True, ax=None, cmap=None, mode="cfr", yticks=None, metric="gain", print_set=None):
    """Grouped bar plot comparing BR, RNR-style and step-based methods.

    Result files are read from ``results/br_sbr/``. Methods with one file per
    game ("other" bars: BR, RNR, BNE, VF, VFG, VFU) are drawn first in each
    group, followed by methods with one file per step size ("step" bars:
    S, C, SES), in that order.

    Args:
        game_name: game identifier used in the file names.
        steps_in: step-size lists indexed ``[srnr_steps, ses_steps, comb_steps]``;
            an entry is only accessed when the corresponding method is plotted.
        p: value embedded in the file names as ``_p=<p>``.
        dynamic_steps: read the ``_dynamic`` variants of the step-based files.
        value_limit: maximum number of rows read per file (negative = all).
        include_average: append an averaged group labelled ``"a"``.
        ax: Axes to draw on (short legend labels). If None, the current pyplot
            Axes is used with long labels, a title, a legend and ``plt.show()``.
        cmap: unused; kept for backward compatibility.
        mode: ``"cfr"`` (``_cfr_var_iter`` files) or ``"random"``
            (``_random_seeded`` files).
        yticks: explicit y ticks, applied only when ``ax`` is given.
        metric: ``"gain"`` or ``"expl"``; any other value is used verbatim as
            the file-name suffix (a value other than ``""`` labels the y axis
            "Exploitability").
        print_set: method keys to plot, among ``"br"``, ``"srnr"``, ``"rnr"``,
            ``"bne"``, ``"srnrvf"``, ``"srnrg"``, ``"srnru"``, ``"comb"``,
            ``"ses"``. None selects every method for ``"cfr"`` and, for
            ``"random"``, the subset that mode previously loaded
            (br, srnr, rnr, bne, srnrvf, comb).

    Raises:
        ValueError: for an unknown ``mode`` or when nothing is selected.
    """
    # Imported locally so the name does not depend on a star import.
    from matplotlib.ticker import AutoMinorLocator

    if metric == "gain":
        metric = ""
    elif metric == "expl":
        metric = "_expl"
    if mode == "cfr":
        variant = "_cfr_var_iter"
        xlabel = "Cfr iterations of opponent strategy"
        title = "Gain comparison of SRNR and BR against CFR with low iterations\n"
    elif mode == "random":
        variant = "_random_seeded"
        xlabel = "Seed of random strategy"
        title = "Gain comparison of SRNR and BR against random strategies\n"
        if print_set is None:
            # Methods the random mode loaded before it was unified with the cfr mode.
            print_set = {"br", "srnr", "rnr", "bne", "srnrvf", "comb"}
    else:
        raise ValueError("Mode needs to be 'cfr' or 'random'")

    base = "results/br_sbr/"
    p_part = "_p=" + str(p)
    dynamic = "_dynamic" if dynamic_steps else ""

    def wanted(key):
        # True if the method identified by key should be plotted.
        return print_set is None or key in print_set

    def single_path(prefix, with_p=True):
        # Path of a method with one result file per game.
        return base + prefix + game_name + (p_part if with_p else "") + variant + metric + ".txt"

    other_values, other_labels = [], []
    step_values, step_labels, step_size, steps = [], [], [], []

    def add_steps(prefix, step_list, labels):
        # Load a method with one result file per step size ({0} = step).
        template = base + prefix + game_name + "_step_{0}" + p_part + variant + metric + dynamic + ".txt"
        step_values.append(load_sbrs(template, step_list, value_limit, include_average))
        step_labels.append(labels)
        step_size.append(len(step_list))
        steps.append(step_list)

    rnr_data = None
    if wanted("br") or wanted("rnr"):
        # The RNR file also holds the BR column, so it is read once for both.
        rnr_data = load_general_index_value_br(single_path("rnr_"), value_limit, include_average)
    if wanted("br"):
        rnr_iterations, _, rnr_br_values = rnr_data
        other_values.append((rnr_iterations, rnr_br_values, rnr_br_values))
        other_labels.append(("BR", "Best response"))
    if wanted("srnr"):
        add_steps("srnr_", steps_in[0], ("S", "Continual depth limited best response"))
    if wanted("rnr"):
        other_values.append(rnr_data)
        other_labels.append(("RNR", "Restricted Nash response"))
    if wanted("bne"):
        other_values.append(load_general_index_value_br(single_path("bne_", with_p=False), value_limit, include_average))
        other_labels.append(("BNE", "Best Nash equilibrium"))
    if wanted("srnrvf"):
        other_values.append(load_general_index_value_br(single_path("srnr_vf_"), value_limit, include_average))
        other_labels.append(("VF", "Continual depth limited restricted Nash response with value function"))
    if wanted("srnrg"):
        other_values.append(load_general_index_value_br(single_path("srnr_g_"), value_limit, include_average))
        other_labels.append(("VFG", "Continual depth limited restricted Nash response with value function using gadget"))
    if wanted("srnru"):
        other_values.append(load_general_index_value_br(single_path("srnr_nog_"), value_limit, include_average))
        other_labels.append(("VFU", "Continual depth limited restricted Nash response with unsafe resolving"))
    if wanted("comb"):
        add_steps("comb_", steps_in[2], ("C", "Combination of NE and BR"))
    if wanted("ses"):
        add_steps("ses_", steps_in[1], ("SES", "Safe exploitation search"))

    # x tick labels: first per-game method, otherwise first step of the first step-based method.
    tick_sources = [data[0] for data in other_values] + [iterations for data in step_values for iterations in data[0]]
    if not tick_sources:
        raise ValueError("print_set does not select any method to plot")
    tick_labels = tick_sources[0]
    positions = list(range(len(tick_labels)))

    other_count = len(other_values)
    total_bars = other_count + sum(step_size)
    diff = 1. / (total_bars + 1)
    width = 1. / (total_bars + 2)
    label_index = 1 if ax is None else 0  # long labels for standalone plots, short ones for subplots
    # A single target Axes; the no-argument plt.axes() used before created a new, empty Axes.
    target = ax if ax is not None else plt.gca()
    for slot, (data, labels) in enumerate(zip(other_values, other_labels), start=1):
        target.bar([pos - 0.5 + diff * slot for pos in positions], data[1], width=width, label=labels[label_index])
    slot = other_count + 1
    for data, labels, step_list in zip(step_values, step_labels, steps):
        for step_index, step in enumerate(step_list):
            target.bar([pos - 0.5 + diff * slot for pos in positions], data[1][step_index], width=width,
                       label=str(labels[label_index]) + str(step))
            slot += 1

    ylabel = "Gain" if metric == "" else "Exploitability"
    if ax is None:
        target.set_xlabel(xlabel)
        target.set_ylabel(ylabel)
        target.grid(which="major", axis="y")
        target.grid(which="minor", axis="y", alpha=0.3)
        target.yaxis.set_minor_locator(AutoMinorLocator(4))
        target.set_axisbelow(True)
        target.set_xticks(positions)
        target.set_xticklabels(tick_labels)
        target.set_title(title + "on " + game_name + " with varying SBR step sizes and p = " + str(p))
        target.legend()
        plt.show()
    else:
        target.set_ylabel(ylabel)
        target.grid(which="major", axis="y")
        target.grid(which="minor", axis="y", alpha=0.3)
        if yticks is not None:
            target.set_yticks(yticks)
        target.yaxis.set_minor_locator(AutoMinorLocator(4))
        target.set_axisbelow(True)
        target.set_xticks(positions)
        target.set_xticklabels(tick_labels)


def save_cfr_strategies_for_open_spiel(game_name):
    """Export the CFR strategies of ``game_name`` as plain-text files.

    For every iteration count in ``fibonacci_array(15, zero_start=True)`` the
    strategy ``data/bad_strategies/<game>/<game>_<n>_iterations.strat`` is
    written to ``data/strategies_for_os/<game>_<n>_iterations``. Each of the
    two players gets a ``Player <k>`` header followed by one line per
    information set: ``<observation history> [ <action>:<prob> ... ]``.
    Every observation history is also printed to stdout.
    """
    game = ExtensiveGame()
    game.load("data/" + game_name + ".efg")
    histories = game.infoset_to_observation_history_only_for_poker()
    action_labels = game.infoset_to_action_labels()

    def write_player(out, strategy, player):
        # Header plus one bracketed action:probability line per information set.
        out.write("Player " + str(player) + "\n")
        player_strategy = strategy[player]
        for i_set in player_strategy:
            history = histories[player][i_set]
            print(history)
            out.write(history + " [")
            for label, probability in zip(action_labels[player][i_set], player_strategy[i_set]):
                out.write(" ")
                out.write(label + ":" + str(probability))
            out.write(" ]\n")

    for iteration_count in fibonacci_array(15, zero_start=True):
        suffix = "_" + str(iteration_count) + "_iterations"
        strategy = load_from_file("data/bad_strategies/" + game_name + "/" + game_name + suffix + ".strat")
        with open("data/strategies_for_os/" + game_name + suffix, "w") as out:
            for player in range(2):
                write_player(out, strategy, player)


def save_random_strategies_for_open_spiel(game_name):
    """Export the seeded random strategies of ``game_name`` as plain-text files.

    For seeds 0-9 the strategy ``data/random_strategies/<game>_seed_<s>.strat``
    is written to ``data/strategies_for_os/<game>_seed_<s>``. Each of the two
    players gets a ``Player <k>`` header followed by one line per information
    set: ``<observation history> [ <action>:<prob> ... ]``. Every observation
    history is also printed to stdout.
    """
    game = ExtensiveGame()
    game.load("data/" + game_name + ".efg")
    histories = game.infoset_to_observation_history_only_for_poker()
    action_labels = game.infoset_to_action_labels()

    def write_player(out, strategy, player):
        # Header plus one bracketed action:probability line per information set.
        out.write("Player " + str(player) + "\n")
        player_strategy = strategy[player]
        for i_set in player_strategy:
            history = histories[player][i_set]
            print(history)
            out.write(history + " [")
            for label, probability in zip(action_labels[player][i_set], player_strategy[i_set]):
                out.write(" ")
                out.write(label + ":" + str(probability))
            out.write(" ]\n")

    for seed in range(10):
        seed_part = game_name + "_seed_" + str(seed)
        strategy = load_from_file("data/random_strategies/" + seed_part + ".strat")
        with open("data/strategies_for_os/" + seed_part, "w") as out:
            for player in range(2):
                write_player(out, strategy, player)


def rnr_computation(fname, p=0.5, mode="cfr"):
    """Compare a restricted Nash response (RNR) with a best response (BR) against fixed opponents.

    Opponent (player 1) strategies are loaded from
    ``data/bad_strategies/<game>/<game>_<i>_iterations.strat`` for ``i`` in
    ``fibonacci_array(10, zero_start=True)`` when ``mode == "cfr"``, otherwise
    from ``data/random_strategies/<game>_seed_<i>.strat`` for ``i`` in 0-9.
    For each opponent, two player-0 strategies are computed:

    * BR: ``SequenceNash`` solved on the game after
      ``fix_strategy_until_depth(1, opponent_strategy, -1)``;
    * RNR: ``SequenceNash`` solved on the game written to ``data/temp.efg`` by
      ``GenerateRNRGame.generate`` with weights ``[p, 1 - p]``.

    One row ``<i> <rnr> <br>`` is written to each of two files named with
    ``p=<round(1 - p, 1)>`` under ``results/br_sbr/``:

    * ``rnr_<game>_p=..._cfr_var_iter.txt`` (``_random_seeded.txt``): the gain
      ``-cfr.game_value - game_value`` of the strategy against the opponent;
    * the matching ``..._expl.txt``: ``cfr.best_response(1, ...)[0] + game_value``.

    Progress and sanity-check values are printed to stdout.

    Args:
        fname: path of the ``.efg`` game file.
        p: weight passed to ``GenerateRNRGame.generate`` as ``[p, 1 - p]``.
        mode: ``"cfr"`` for CFR opponents; any other value selects random opponents.

    Raises:
        AssertionError: if the Nash strategy's exploitability is not below 1e-6.
    """
    game = ExtensiveGame()
    game.load(fname)
    print(game.max_depth())
    nash_solver = SequenceNash(fname)
    game_value = nash_solver.solve(0)
    nash_strategy = nash_solver.strategy_in_cfr_format()
    cfr = CFR(fname)
    cfr.initialize()
    game_name = fname[fname.rfind("/") + 1:fname.rfind(".")]
    file_prefix = "results/br_sbr/rnr_" + game_name + "_p=" + str(round(1 - p, 1))
    if mode == "cfr":
        gain_path = file_prefix + "_cfr_var_iter.txt"
        expl_path = file_prefix + "_cfr_var_iter_expl.txt"
        opponents = fibonacci_array(10, zero_start=True)
    else:
        gain_path = file_prefix + "_random_seeded.txt"
        expl_path = file_prefix + "_random_seeded_expl.txt"
        opponents = range(10)
    with open(gain_path, "w+") as gain_file, open(expl_path, "w+") as expl_file:
        gain_file.write("# comparison of SRNR and BR on " + game_name + ". Each line are different iterations in format "
                        "num_iterations sbr_value br_value\n")
        expl_file.write("# comparison of exploitability of SRNR and BR on " + game_name + ". Each line are different iterations in format "
                        "num_iterations sbr_value br_value\n")
        for i in opponents:
            if mode == "cfr":
                opponent_strategy = load_from_file("data/bad_strategies/" + game_name + "/" + game_name + "_" + str(i) + "_iterations.strat")[1]
            else:
                opponent_strategy = load_from_file("data/random_strategies/" + game_name + "_seed_" + str(i) + ".strat")[1]
            # BR: solve the game with the opponent strategy applied (presumably fixed for player 1).
            game.load(fname)
            game.fix_strategy_until_depth(1, opponent_strategy, -1)
            nash_solver.game = game
            nash_solver.solve(0)
            best_response = nash_solver.strategy_in_cfr_format()

            # Reload the unmodified game (nash_solver.game refers to the same object).
            game.load(fname)
            rnr_generator = GenerateRNRGame(fname)
            temp_fname = "data/temp.efg"
            rnr_generator.generate(temp_fname, [p, 1 - p], 1, [opponent_strategy])
            rnr_nash = SequenceNash(temp_fname)
            rnr_nash.solve(player=0)
            rnr_strategy = rnr_nash.strategy_in_cfr_format()[0]
            cfr.compute_game_value([rnr_strategy, opponent_strategy])
            if mode == "cfr":
                print("Iterations:", i)
            else:
                print("Seed:", i)
            rnr_value = -cfr.game_value - game_value
            print("RNR:", rnr_value)
            cfr.compute_game_value([best_response[0], opponent_strategy])
            br_value = -cfr.game_value - game_value
            print("BR:", br_value)
            print("BR check:", -(cfr.best_response(0, [{}, opponent_strategy])[0] + game_value))
            gain_file.write(str(i) + " " + str(rnr_value) + " " + str(br_value) + "\n")
            srnr_expl = cfr.best_response(1, [rnr_strategy, {}])[0] + game_value
            nash_expl = cfr.best_response(1, [nash_strategy[0], {}])[0] + game_value
            # Sanity check: the Nash strategy must be (numerically) unexploitable.
            assert nash_expl < 0.000001
            print("Nash Exploitability:", nash_expl)
            print("Exploitability:", srnr_expl)
            br_expl = cfr.best_response(1, [best_response[0], {}])[0] + game_value
            print("BR Exploitability:", br_expl)
            expl_file.write(str(i) + " " + str(srnr_expl) + " " + str(br_expl) + "\n")


def best_ne_computation(fname, mode="cfr", indices=None):
    """Compare the best Nash equilibrium (BNE) with a best response (BR) against fixed opponents.

    Opponent (player 1) strategies are loaded from
    ``data/bad_strategies/<game>/<game>_<i>_iterations.strat`` when
    ``mode == "cfr"``, otherwise from ``data/random_strategies/<game>_seed_<i>.strat``.
    For each opponent, the BR of player 0 is obtained with ``SequenceNash`` on the
    game after ``fix_strategy_until_depth(1, opponent_strategy, -1)``, and the BNE
    strategy with ``BestNashStatic``. One row ``<i> <bne> <br>`` is written to

    * ``results/br_sbr/bne_<game>_cfr_var_iter.txt`` (``_random_seeded.txt``):
      gain ``-cfr.game_value - game_value`` against the opponent;
    * the matching ``..._expl.txt``: ``cfr.best_response(1, ...)[0] + game_value``.

    Args:
        fname: path of the ``.efg`` game file.
        mode: ``"cfr"`` for CFR opponents; any other value selects random opponents.
        indices: iteration counts (cfr) or seeds (random) to evaluate. Defaults to
            ``fibonacci_array(10, zero_start=True)`` resp. ``range(10)``. The loop
            was previously hard-coded to ``[2]`` (a debugging leftover), which
            produced a one-row file that does not line up with the other
            per-opponent result files. Pass ``indices=[2]`` for the old behavior.

    Raises:
        AssertionError: if the Nash strategy's exploitability is not below 1e-6.
    """
    if indices is None:
        indices = fibonacci_array(10, zero_start=True) if mode == "cfr" else range(10)
    game = ExtensiveGame()
    game.load(fname)
    nash_solver = SequenceNash(fname)
    game_value = nash_solver.solve(0)
    nash_strategy = nash_solver.strategy_in_cfr_format()
    cfr = CFR(fname)
    cfr.initialize()
    game_name = fname[fname.rfind("/") + 1:fname.rfind(".")]
    with open("results/br_sbr/bne_" + game_name + ("_cfr_var_iter.txt" if mode == "cfr" else "_random_seeded.txt"), "w+") as gain_file, open(
            "results/br_sbr/bne_" + game_name + ("_cfr_var_iter_expl.txt" if mode == "cfr" else "_random_seeded_expl.txt"), "w+") as expl_file:
        gain_file.write("# comparison of Best Nash and BR on " + game_name + ". Each line are different iterations in format "
                        "num_iterations sbr_value br_value\n")
        expl_file.write("# comparison of exploitability of Best Nash and BR on " + game_name + ". Each line are different iterations in format "
                        "num_iterations sbr_value br_value\n")
        for i in indices:
            if mode == "cfr":
                opponent_strategy = load_from_file("data/bad_strategies/" + game_name + "/" + game_name + "_" + str(i) + "_iterations.strat")[1]
            else:
                opponent_strategy = load_from_file("data/random_strategies/" + game_name + "_seed_" + str(i) + ".strat")[1]
            # BR: solve the game with the opponent strategy applied (presumably fixed for player 1).
            game.load(fname)
            game.fix_strategy_until_depth(1, opponent_strategy, -1)
            nash_solver.game = game
            nash_solver.solve(0)
            best_response = nash_solver.strategy_in_cfr_format()

            best_nash = BestNashStatic(fname, opponent_strategy)
            best_nash.solve()
            bne_strategy = best_nash.strategy_in_cfr_format()[0]
            cfr.compute_game_value([bne_strategy, opponent_strategy])
            if mode == "cfr":
                print("Iterations:", i)
            else:
                print("Seed:", i)
            bne_value = -cfr.game_value - game_value
            print("Best Nash:", bne_value)
            cfr.compute_game_value([best_response[0], opponent_strategy])
            br_value = -cfr.game_value - game_value
            print("BR:", br_value)
            print("BR check:", -(cfr.best_response(0, [{}, opponent_strategy])[0] + game_value))
            gain_file.write(str(i) + " " + str(bne_value) + " " + str(br_value) + "\n")
            bne_expl = cfr.best_response(1, [bne_strategy, {}])[0] + game_value
            nash_expl = cfr.best_response(1, [nash_strategy[0], {}])[0] + game_value
            # Sanity check: the Nash strategy must be (numerically) unexploitable.
            assert nash_expl < 0.000001
            print("Nash Exploitability:", nash_expl)
            print("Exploitability:", bne_expl)
            br_expl = cfr.best_response(1, [best_response[0], {}])[0] + game_value
            print("BR Exploitability:", br_expl)
            expl_file.write(str(i) + " " + str(bne_expl) + " " + str(br_expl) + "\n")


def cdrnr_counterexample_plot(gadget_type="sequence"):
    """Plot the utilities of actions a, b and c in the CDRNR counterexample.

    The left panel shows the three utility curves over x in [0, 1] (the axis
    is labelled "1-p").
    The right panel shows the same curves zoomed into a gadget-specific window.
    A second, undecorated figure with the sequence-gadget curves is shown afterwards.

    Args:
        gadget_type: Which gadget's utility curves to draw: "sequence"
            (default), "resolving" or "maxmargin".

    Raises:
        ValueError: If ``gadget_type`` is not one of the supported names.
    """
    # Local import: AutoMinorLocator is not among this file's visible top-level imports.
    from matplotlib.ticker import AutoMinorLocator

    if gadget_type not in ("sequence", "resolving", "maxmargin"):
        raise ValueError("Unknown gadget type: " + str(gadget_type))

    x = np.linspace(0, 1, 10000)
    a1, a2, a3, a4 = 4.5, -5, -3, -4
    # c1 is 3.500001 rather than 3.5, which moves the zero of the sequence
    # curve for action c slightly away from x = 0.5 (see the zoomed panel).
    c1, c2, c3, c4 = 3.500001, -4, -3, -3

    zeros = 0 * x
    # Utility curves (action a, action b, action c) for each gadget.
    curves = {
        "sequence": (x * (a2 + a4) + (1 - x) * a1 * 2, zeros, x * (c2 + c4) + (1 - x) * c1 * 2),
        "resolving": (x * (a2 + a3 + a4) + (1 - x) * a1 * 2, zeros, x * (c2 + c3 + c4) + (1 - x) * c1 * 2),
        "maxmargin": (x * a2 + (1 - x) * a1 * 2, zeros, x * c2 + (1 - x) * c1 * 2),
    }
    sequence_a, sequence_b, sequence_c = curves["sequence"]

    # Dump (1 - x, utility of c, utility of a) triples for the sequence gadget.
    for a, (b, c) in zip(1 - x, zip(sequence_c, sequence_a)):
        print(a, b, c)

    # Hand-tuned figure margins and legend anchor for each gadget's plot.
    left, wspace, legend_x = {
        "maxmargin": (0.095, 0.4, -1.2),
        "sequence": (0.1, 0.3, -1.15),
        "resolving": (0.12, 0.3, -1.15),
    }[gadget_type]

    plt.rcParams.update({'font.size': 20, 'font.family': 'Times New Roman'})
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 3))
    plt.gcf().subplots_adjust(bottom=0.23, left=left, right=0.99, top=0.84, wspace=wspace)
    cmap = plt.get_cmap('CMRmap')
    indices = np.linspace(0, cmap.N, 4)
    my_colors = [cmap(int(i)) for i in indices[:-1]]
    ax1.set_prop_cycle(color=my_colors)
    ax2.set_prop_cycle(color=my_colors)

    ax1.set_xlim(0, 1)
    ax1.set_xticks([0, 0.2, 0.4, 0.6, 0.8, 1])
    for ax in (ax1, ax2):
        ax.set_axisbelow(True)
        ax.yaxis.set_minor_locator(AutoMinorLocator(5))
        ax.yaxis.grid(which='major')
        ax.yaxis.grid(which='minor', alpha=0.5)
        ax.xaxis.set_minor_locator(AutoMinorLocator(5))
        ax.xaxis.grid(which='minor', alpha=0.5)
        ax.xaxis.grid(which='major')

    for ax in (ax1, ax2):
        for label, values in zip(("Action a", "Action b", "Action c"), curves[gadget_type]):
            ax.plot(x, values, label=label)

    # Zoomed window of the right panel, specific to each gadget.
    if gadget_type == "sequence":
        ax2.set_xlim(0.4999988, 0.5000012)
        ax2.set_ylim(-0.000011, 0.000011)
        ax2.set_yticks([-0.00001, 0, 0.00001])
        ax2.set_xticks([0.499999, 0.5, 0.500001])
        ax2.set_yticklabels([-1, 0, 1])
        ax2.set_xticklabels(["0.5-x", 0.5, "0.5+x"])
    elif gadget_type == "resolving":
        ax2.set_xlim(0.4, 0.5)
        ax2.set_ylim(-2, 2)
        ax2.set_xticks([0.4, 0.42, 0.44, 0.46, 0.48])
        ax2.set_yticks([-1, 0, 1])
        ax2.xaxis.set_minor_locator(AutoMinorLocator(4))
    else:  # "maxmargin"
        ax2.set_xlim(0.625, 0.675)
        ax2.set_ylim(-0.5, 0.5)
        ax2.set_xticks([0.63, 0.65, 0.67])
        ax2.set_yticks([-0.4, -0.2, 0, 0.2, 0.4])
        ax2.xaxis.set_minor_locator(AutoMinorLocator(4))
    ax1.set_xlabel("1-p")
    ax1.set_ylabel("Utility")
    ax2.set_xlabel("1-p")
    ax2.set_ylabel("Utility")

    # The legend is attached to the current axes (ax2) and shifted left to span both panels.
    plt.legend(bbox_to_anchor=(legend_x, 1.02, 2, .102), loc='lower left',
               ncol=6, mode="expand", borderaxespad=0, borderpad=0.1)

    plt.show()

    # Second, undecorated figure with the sequence-gadget curves.
    plt.plot(x, sequence_c)
    plt.plot(x, sequence_b)
    plt.plot(x, sequence_a)
    plt.show()


def strategy_exploitability(fname, game_name):
    """Print the exploitability of a saved strategy profile for both players.

    The best-response value against ``strategy`` is compared with the Nash
    value of the corresponding player (from ``SequenceNash``). The average of
    the two exploitabilities is printed last.

    Args:
        fname: Strategy file; ``load_from_file`` must yield a profile usable by
            ``CFR.best_response``.
        game_name: Game file passed to both ``CFR`` and ``SequenceNash``.
    """
    strategy = load_from_file(fname)
    cfr = CFR(game_name)
    sequence_nash = SequenceNash(game_name)
    zero_nash_value = sequence_nash.solve(0)
    one_nash_value = sequence_nash.solve(1)
    cfr.initialize()
    # Each best response is a full tree traversal: compute it once and reuse
    # the value for both the exploitability and the diagnostic print.
    br_zero_value = -cfr.best_response(0, strategy)[0]
    br_one_value = cfr.best_response(1, strategy)[0]
    one_expl = br_zero_value - zero_nash_value
    two_expl = br_one_value - one_nash_value
    print(br_zero_value)
    print(br_one_value)
    print(zero_nash_value)
    print("Exploitability player 1:", one_expl)
    print("Exploitability player 2:", two_expl)
    print((one_expl + two_expl) / 2)


def cdbr_convergence_tests(fname, fixed_player, strategy=None, cfrd_iterations=1000, subgame_iterations=1000, average_from=0, cfrd_verbose=0, full_iterations=1000, full_verbose=0):
    """Compare CFR-D responses with full-game CFR responses to a fixed strategy.

    Prints the game value of a full-game CFR response to CFR-D's current
    strategy. It then extends CFR-D's average and current strategies with
    full-game CFR and prints both values, and finally prints player 0's
    best-response value against ``strategy``.

    Args:
        fname: Game path without extension (CFR is given ``fname + ".efg"``,
            CFRD is given ``fname``).
        fixed_player: Index into ``strategy`` of the strategy handed to CFR-D
            as fixed. NOTE(review): later evaluations always use
            ``strategy[1]`` regardless of this value.
        strategy: Strategy profile; when None a random profile is generated.
        cfrd_iterations, subgame_iterations, average_from, cfrd_verbose:
            Forwarded to ``CFRD.solve``.
        full_iterations, full_verbose: Forwarded to every ``CFR.solve``.
    """
    efg_file = fname + ".efg"
    if strategy is None:
        random_source = CFR(efg_file)
        random_source.initialize()
        random_source.initialize_random_strategy(0)
        strategy = random_source.strategy

    decomposed = CFRD(fname)
    decomposed.solve(cfrd_iterations, subgame_iterations=subgame_iterations, average_from=average_from, verbose=cfrd_verbose, fixed_strategy=[{}, strategy[fixed_player]])
    decomposed_average = decomposed.average_strategy
    decomposed_current = decomposed.strategy

    full_cfr = CFR(efg_file)
    full_cfr.solve(full_iterations, verbose=full_verbose, fixed=[{}, decomposed_current[fixed_player]])
    full_cfr.compute_game_value([full_cfr.average_strategy[0], strategy[1]])
    print("Cfr exploitation:", full_cfr.game_value)

    # Extend CFR-D's average, then current, strategy with full-game CFR.
    for caption, seed in (("Average strategy: ", decomposed_average), ("Current strategy: ", decomposed_current)):
        extension = CFR(efg_file)
        extension.solve(full_iterations, fixed=[seed[0], strategy[1]], verbose=full_verbose)
        extension.compute_game_value(extension.average_strategy)
        print(caption, extension.game_value)

    max_exploitation = full_cfr.best_response(0, strategy)[0]
    print("Maximum exploitation: ", max_exploitation)


def generate_leduc_rnr():
    """Write restricted-Nash-response variants of Leduc hold'em to data/cdrnr/.

    One game file is generated for every saved low-iteration CFR opponent
    strategy (iteration counts from ``fibonacci_array(15, zero_start=True)``)
    and every p in ``np.arange(0.1, 1, 0.1)``. The mixture ``[p, 1 - p]`` is
    passed to ``GenerateRNRGameNotFixed.generate``.
    """
    base_game = "data/leduc_holdem.efg"
    for iters in fibonacci_array(15, zero_start=True):
        strat_path = "data/bad_strategies/leduc_holdem/leduc_holdem_" + str(iters) + "_iterations.strat"
        fixed_opponent = load_from_file(strat_path)[1]
        for p in np.arange(0.1, 1, 0.1):
            generator = GenerateRNRGameNotFixed(base_game)
            out_path = "data/cdrnr/leduc_holdem_{}_{:.1f}.efg".format(iters, p)
            generator.generate(out_path, [p, 1 - p], 1, [fixed_opponent])


def compose_leduc_strategy_step(cfr, public_node, node, results):
    """Copy per-subgame strategies into ``cfr.strategy`` for a subtree.

    The subtree under ``node`` is walked in preorder. Every player-0 or player-1
    node gets a deep copy of ``results[public_node][player][i_set]`` stored in
    ``cfr.strategy[player][i_set]``. Player-3 nodes end the walk on their branch;
    other nodes are traversed without being copied.

    Args:
        cfr: Object with a ``strategy`` list indexed by player, mutated in place.
        public_node: Key into ``results`` selecting the subgame's strategies.
        node: Root of the subtree to walk.
        results: Mapping public_node -> per-player mapping i_set -> strategy.
    """
    pending = [node]
    while pending:
        current = pending.pop()
        if current.player == 3:
            continue
        if current.player in (0, 1):
            cfr.strategy[current.player][current.i_set] = copy.deepcopy(results[public_node][current.player][current.i_set])
        # Reverse so children are popped in their original order (preorder).
        pending.extend(reversed(current.children))


def compose_cfrrnr_leduc_strategies(fname, strat_file):
    """Assemble a complete CFR-RNR Leduc strategy from saved trunk/subgame results.

    Starts from ``results["Main strategy"]`` (used directly, not copied). For
    every subgame it overwrites that subgame's player-0 and player-1 info sets
    listed in ``results["ISets"]`` with the entries from
    ``results["Strategy"]``. Finally it walks every public set of the CFRRNR
    instance with ``compose_leduc_strategy_step``.

    Args:
        fname: Result file written by one of the ``solve_leduc*`` functions.
        strat_file: Opponent strategy file passed to ``CFRRNR``.

    Returns:
        The CFRRNR instance; its ``average_strategy`` is the same object as its
        ``strategy``.
    """
    results = load_from_file(fname)
    subgame_strategies = results["Strategy"]
    composed = CFRRNR(strat_file=strat_file)
    composed.strategy = results["Main strategy"]
    for subgame_index, isets in enumerate(results["ISets"]):
        for player in (0, 1):
            for i_set in isets[player]:
                composed.strategy[player][i_set] = subgame_strategies[subgame_index][player][i_set]
    for public_index, public_set in enumerate(composed.public_sets):
        for root in public_set:
            compose_leduc_strategy_step(composed, public_index, root, subgame_strategies)
    composed.average_strategy = composed.strategy
    return composed


def eval_cfrrnr(result_file, strat_file):
    """Evaluate a composed CFR-RNR Leduc strategy against its fixed opponent.

    Prints:
    - the gain over the Nash game value for the composed strategy and for a
      best response to the opponent (plus a cross-check of the latter);
    - the exploitability of the Nash strategy, the composed strategy and the
      best response.

    Args:
        result_file: Result file passed to ``compose_cfrrnr_leduc_strategies``.
        strat_file: Opponent strategy file passed to the same function.
    """
    responder = compose_cfrrnr_leduc_strategies(result_file, strat_file)
    game_file = "data/leduc_holdem.efg"
    evaluator = CFR(game_file)
    evaluator.initialize()
    restricted_game = ExtensiveGame()
    solver = SequenceNash(game_file)
    game_value = solver.solve(0)
    nash_profile = solver.strategy_in_cfr_format()

    # Solve the game with the opponent's strategy fixed to obtain a best response.
    opponent = responder.opponent_strategy
    restricted_game.load(game_file)
    restricted_game.fix_strategy_until_depth(1, opponent, -1)
    solver.game = restricted_game
    solver.solve(0)
    br_profile = solver.strategy_in_cfr_format()

    def gain(own):
        # Player 0's value above the game value when ``own`` meets the opponent.
        evaluator.compute_game_value([own, opponent])
        return -evaluator.game_value - game_value

    def exploitability(own):
        # Player 1's best-response value against ``own``, relative to the game value.
        return evaluator.best_response(1, [own, {}])[0] + game_value

    print("SRNR:", gain(responder.strategy[0]))
    print("BR:", gain(br_profile[0]))
    print("BR check:", -(evaluator.best_response(0, [{}, opponent])[0] + game_value))
    srnr_expl = exploitability(responder.strategy[0])
    nash_expl = exploitability(nash_profile[0])
    assert nash_expl < 0.000001
    print("Nash Exploitability:", nash_expl)
    print("Exploitability:", srnr_expl)
    print("BR Exploitability:", exploitability(br_profile[0]))


def solve_leduc_rest(cfr, i_public_set, iterations):
    """Solve one public subgame and return a deep copy of the average strategy.

    Used as a multiprocessing worker by ``solve_leduc``.
    """
    subgame_root = cfr.public_sets[i_public_set]
    cfr.solve_leduc_rest(subgame_root, i_public_set, iterations)
    solved = cfr.average_strategy
    return copy.deepcopy(solved)


def solve_leduc_only_trunk(val_iterations, trunk_iterations, result_file, strat_file="data/bad_strategies/leduc_holdem/leduc_holdem_21_iterations.strat", p=0.5, use_vf=True):
    """Solve only the Leduc CFR-RNR trunk and save it to ``result_file``.

    The saved dict holds deep copies under "Main strategy" (the solver's
    average strategy) and "Main ISets" (its trunk info sets).
    """
    solver = CFRRNR(p=p, vf_iterations=val_iterations, strat_file=strat_file, use_vf=use_vf)
    solver.solve_leduc_trunk(iterations=trunk_iterations)
    save_to_file({
        "Main strategy": copy.deepcopy(solver.average_strategy),
        "Main ISets": copy.deepcopy(solver.trunk_i_sets),
    }, result_file)


def solve_leduc_gadget(val_iterations, trunk_iterations, subgame_iterations, result_file, strat_file="data/bad_strategies/leduc_holdem/leduc_holdem_21_iterations.strat", p=0.5, use_vf=True):
    """Solve the Leduc CFR-RNR trunk, then each public subgame via ``solve_leduc_rest_with_gadget``.

    Saves a dict to ``result_file`` with these keys:
    - "Main strategy", "Main ISets": the trunk solution;
    - "ISets": the solver's subgame info sets;
    - "Strategy": public-set index -> average strategy of that subgame solve.
    """
    solver = CFRRNR(p=p, vf_iterations=val_iterations, strat_file=strat_file, use_vf=use_vf)
    solver.solve_leduc_trunk(iterations=trunk_iterations)
    result = {
        "Main strategy": copy.deepcopy(solver.average_strategy),
        "Main ISets": copy.deepcopy(solver.trunk_i_sets),
    }
    subgame_strategies = {}
    for public_index, public_set in enumerate(solver.public_sets):
        # Each subgame is solved on a deep copy of the trunk solver.
        # NOTE(review): public_set belongs to the original solver, not the copy.
        local_solver = copy.deepcopy(solver)
        local_solver.solve_leduc_rest_with_gadget(public_set, public_index, subgame_iterations)
        subgame_strategies[public_index] = copy.deepcopy(local_solver.average_strategy)
    result["ISets"] = copy.deepcopy(solver.subgame_i_sets)
    result["Strategy"] = subgame_strategies
    save_to_file(result, result_file)


def solve_leduc_without_gadget(val_iterations, trunk_iterations, subgame_iterations, result_file, strat_file="data/bad_strategies/leduc_holdem/leduc_holdem_21_iterations.strat", p=0.5, use_vf=True):
    """Solve the Leduc CFR-RNR trunk, then each public subgame via ``solve_leduc_rest_without_gadget``.

    Saves a dict to ``result_file`` with these keys:
    - "Main strategy", "Main ISets": the trunk solution;
    - "ISets": the solver's subgame info sets;
    - "Strategy": public-set index -> average strategy of that subgame solve.
    """
    solver = CFRRNR(p=p, vf_iterations=val_iterations, strat_file=strat_file, use_vf=use_vf)
    solver.solve_leduc_trunk(iterations=trunk_iterations)
    result = {
        "Main strategy": copy.deepcopy(solver.average_strategy),
        "Main ISets": copy.deepcopy(solver.trunk_i_sets),
    }
    subgame_strategies = {}
    for public_index, public_set in enumerate(solver.public_sets):
        # Each subgame is solved on a deep copy of the trunk solver.
        # NOTE(review): public_set belongs to the original solver, not the copy.
        local_solver = copy.deepcopy(solver)
        local_solver.solve_leduc_rest_without_gadget(public_set, public_index, subgame_iterations)
        subgame_strategies[public_index] = copy.deepcopy(local_solver.average_strategy)
    result["ISets"] = copy.deepcopy(solver.subgame_i_sets)
    result["Strategy"] = subgame_strategies
    save_to_file(result, result_file)


def solve_leduc(val_iterations, trunk_iterations, subgame_iterations, result_file, strat_file="data/bad_strategies/leduc_holdem/leduc_holdem_21_iterations.strat", p=0.5):
    """Solve the Leduc CFR-RNR trunk, then all public subgames in parallel.

    Each public subgame is solved by ``solve_leduc_rest`` in a process pool, on
    a deep copy of the trunk solver. The saved dict in ``result_file`` holds:
    - "Main strategy", "Main ISets": the trunk solution;
    - "ISets": the solver's subgame info sets;
    - "Strategy": public-set index -> subgame average strategy.
    """
    e = CFRRNR(p=p, vf_iterations=val_iterations, strat_file=strat_file)
    e.solve_leduc_trunk(iterations=trunk_iterations)
    result = {
        "Main strategy": copy.deepcopy(e.average_strategy),
        "Main ISets": copy.deepcopy(e.trunk_i_sets),
    }
    strategies = {}
    # The context manager shuts the worker processes down once every result
    # has been collected, so no workers are left behind.
    with mp.Pool() as pool:
        pending = {
            i_public_set: pool.apply_async(solve_leduc_rest, args=(copy.deepcopy(e), i_public_set, subgame_iterations))
            for i_public_set, _ in enumerate(e.public_sets)
        }
        for i_public_set, async_result in pending.items():
            # Results are unpickled in this process, so they are already private copies.
            strategies[i_public_set] = async_result.get()
    result["ISets"] = copy.deepcopy(e.subgame_i_sets)
    result["Strategy"] = strategies
    save_to_file(result, result_file)


def solve_rps(trunk_iterations, subgame_iterations, res_file, strat_file, split_p):
    """Solve the two-round RPS game: trunk first, then every public subgame.

    For each public subtree, a deep copy of the trunk solver re-solves that
    part, and its average strategy is copied into the combined profile for:
    - player-0 nodes of the public set;
    - their player-1 children.

    The combined profile is installed on the solver and its game value is printed.

    Args:
        trunk_iterations: Iterations for ``solve_rps_trunk``.
        subgame_iterations: Iterations for each ``solve_rps_rest``.
        res_file, strat_file, split_p: Accepted for interface compatibility; unused.

    Returns:
        The combined average strategy (list indexed by player).
    """
    # The module-level ``from RPS import RPS`` is commented out, which left the
    # name undefined here; import it on demand instead.
    from RPS import RPS

    rps = RPS("data/RPS_2round.gbt")
    rps.solve_rps_trunk(trunk_iterations)
    avs = copy.deepcopy(rps.average_strategy)
    for i, public in enumerate(rps.public_tree):
        rps_new = copy.deepcopy(rps)
        rps_new.solve_rps_rest(subgame_iterations, i)
        for node in public:
            avs[0][node.i_set] = copy.deepcopy(rps_new.average_strategy[0][node.i_set])
            for child in node.children:
                avs[1][child.i_set] = copy.deepcopy(rps_new.average_strategy[1][child.i_set])
    rps.average_strategy = copy.deepcopy(avs)
    rps.strategy = copy.deepcopy(avs)
    print(rps.compute_game_value(avs))
    return avs


def fix_strategy(s, node, public_chances):
    """Turn player-0 decision nodes under ``node`` into chance nodes.

    Each reachable player-0 node gets ``player = 2`` and a deep copy of
    ``s[0][i_set]`` as its ``chance`` distribution. The walk stops at:
    - terminal nodes;
    - chance nodes (player 2) whose ``id_val`` is in ``public_chances``.

    Args:
        s: Strategy profile indexed by player, then info set.
        node: Root of the subtree to modify in place.
        public_chances: Collection of chance-node ids not to descend into.
    """
    to_visit = [node]
    while to_visit:
        current = to_visit.pop()
        if current.is_terminal():
            continue
        if current.player == 2 and current.id_val in public_chances:
            continue
        if current.player == 0:
            current.player = 2
            # Copy so later edits of the strategy cannot alter the fixed node.
            current.chance = copy.deepcopy(s[0][current.i_set])
        to_visit.extend(reversed(current.children))


def add_to_strategy(c, s, node, public_chances, l):
    """Copy player 0's strategy from ``s`` into ``c.average_strategy[0]`` for a subtree.

    Every reachable player-0 node's info set gets a deep copy of
    ``s[0][i_set]``, and the info set id is added to ``l``. The walk stops at:
    - terminal nodes;
    - chance nodes (player 2) whose ``id_val`` is in ``public_chances``.

    Args:
        c: Object whose ``average_strategy`` is updated in place.
        s: Source strategy profile indexed by player, then info set.
        node: Root of the subtree to walk.
        public_chances: Collection of chance-node ids not to descend into.
        l: Set-like collector of the copied info set ids (mutated).
    """
    to_visit = [node]
    while to_visit:
        current = to_visit.pop()
        if current.is_terminal() or (current.player == 2 and current.id_val in public_chances):
            continue
        if current.player == 0:
            c.average_strategy[0][current.i_set] = copy.deepcopy(s[0][current.i_set])
            l.add(current.i_set)
        to_visit.extend(reversed(current.children))


def plot_srnr_comb_expl_gain(steps, ps, game_name, value_limit=9, include_average=True, dynamic_steps=True):
    """Plot gain against exploitability for CDRNR (SRNR) and the combination baseline.

    For every p in ``ps`` and both metrics, the last value per step count is
    taken from the results/br_sbr/ files loaded by ``load_sbrs``. Each step
    count yields one CDRNR curve and one combination curve; CDRNR points are
    annotated with their p.

    Args:
        steps: Step counts to plot (substituted for "{0}" in the file names).
        ps: p values whose result files are read.
        game_name: Game name used in the result file names and the title.
        value_limit, include_average: Forwarded to ``load_sbrs``.
        dynamic_steps: Read the "_dynamic" variants of the result files.
    """
    srnr_values = {"gain": {}, "expl": {}}
    comb_values = {"gain": {}, "expl": {}}
    for step in steps:
        for metric in ("gain", "expl"):
            srnr_values[metric][step] = []
            comb_values[metric][step] = []
    dynamic_suffix = "_dynamic" if dynamic_steps else ""
    for metric, file_suffix in (("gain", ""), ("expl", "_expl")):
        for p in ps:
            _, srnr_values_local, _ = load_sbrs(
                "results/br_sbr/srnr_" + game_name + "_step_{0}_p=" + str(p) + "_cfr_var_iter" + file_suffix + dynamic_suffix + ".txt", steps, value_limit, include_average)
            for i, step in enumerate(steps):
                srnr_values[metric][step].append(srnr_values_local[i][-1])
            _, comb_values_local, _ = load_sbrs(
                "results/br_sbr/comb_" + game_name + "_step_{0}" + "_p=" + str(p) + "_cfr_var_iter" + file_suffix + dynamic_suffix + ".txt", steps, value_limit, include_average)
            for i, step in enumerate(steps):
                comb_values[metric][step].append(comb_values_local[i][-1])
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.set_aspect('equal', adjustable='box')
    ax.grid()
    ax.set_xlabel("Exploitability")
    ax.set_ylabel("Gain")
    # Name every plotted step count in the title, not only one of them.
    steps_text = ", ".join(str(step) for step in steps)
    ax.set_title("Gain vs Exploitability of combination and CDRNR with " + steps_text + " steps on " + game_name)
    for step in steps:
        ax.plot(srnr_values["expl"][step], srnr_values["gain"][step], label="CDRNR, " + str(step) + " steps")
        ax.plot(comb_values["expl"][step], comb_values["gain"][step], label="Combination, " + str(step) + " steps")
        for p, (x, y) in zip(ps, zip(srnr_values["expl"][step], srnr_values["gain"][step])):
            ax.text(x, y, p)
    if steps:
        # Labels make the CDRNR and combination curves distinguishable.
        ax.legend()
    plt.show()
    plt.close(fig)


def plot_gain_exploitability_graph(ps, gains, exploitabilities, labels, reverse_p=True, title="", show_numbers=True, show_points=False, labels_to_ignore=None):
    """Plot one gain-vs-exploitability curve per method and show the figure.

    Args:
        ps: Per-curve sequences of p values used to annotate the points.
        gains: Per-curve y values.
        exploitabilities: Per-curve x values.
        labels: Per-curve legend labels.
        reverse_p: Annotate the points of curve i with ``ps[i]`` in reverse order.
        title: Axes title.
        show_numbers: Annotate every point with its p value.
        show_points: Also mark every point with an "x".
        labels_to_ignore: Labels of curves to skip (default: none).
    """
    if labels_to_ignore is None:
        labels_to_ignore = []
    plt.rcParams.update({'font.size': 14, 'font.family': 'Times New Roman'})
    fig, ax = plt.subplots(figsize=(6, 4))
    plt.gcf().subplots_adjust(bottom=0.13, left=0.1, right=0.99, top=0.99)
    ax.grid()
    ax.set_xlabel("Exploitability")
    ax.set_ylabel("Gain")
    ax.set_title(title)
    linetypes = ['-', '--', '--', ':', '-.', (0, (1, 5)), (0, (1, 3))]
    # plt.get_cmap is used because matplotlib.cm.get_cmap was removed in Matplotlib 3.9.
    cmap = plt.get_cmap('CMRmap')
    colors = [cmap(0.2), cmap(0.8), cmap(0.8), cmap(0.4), cmap(0.5), cmap(0.6), cmap(0.7)]
    for i in range(len(gains)):
        if labels[i] in labels_to_ignore:
            continue
        # Styles repeat cyclically when there are more curves than predefined styles.
        ax.plot(exploitabilities[i], gains[i], label=labels[i],
                linestyle=linetypes[i % len(linetypes)], color=colors[i % len(colors)])
        if show_points:
            ax.scatter(exploitabilities[i], gains[i], marker="x")
        if show_numbers:
            annotations = reversed(ps[i]) if reverse_p else ps[i]
            for p, (x, y) in zip(annotations, zip(exploitabilities[i], gains[i])):
                ax.text(x, y, p)
    plt.legend()
    plt.show()
    plt.close(fig)


def generate_repeated_rps(rounds=2):
    """Build and return an ExtensiveGame for rock-paper-scissors repeated ``rounds`` times."""
    game = ExtensiveGame()
    generate_repeated_rps_round(game=game, parent=None, parent_action=None, rounds=rounds,
                                current_round=0, current_value=0)
    return game


# Magnitude of the rock-vs-paper outcomes in repeated RPS; the other decisive
# pairs are worth +/-1.
bias = 2
# reward_matrix[i][j] is added to the running value when player 0 plays action i
# and player 1 plays action j (0 = R, 1 = P, 2 = S); used by generate_repeated_rps_round.
reward_matrix = [[0, -bias, 1], [bias, 0, -1], [-1, 1, 0]]


def generate_repeated_rps_round(game, parent, parent_action, rounds, current_round, current_value):
    """Recursively build the repeated-RPS subtree starting at ``current_round``.

    Each round adds a player-0 node with actions R, P, S. Each action leads to
    a player-1 node; the three player-1 nodes of a round share one information
    set. Once ``rounds`` rounds have been played, a terminal node (player 3)
    carrying the accumulated ``current_value`` is attached.

    Args:
        game: ExtensiveGame being built; supplies ids and info sets, and its
            ``root`` is set when ``parent`` is None.
        parent: Node to attach the new subtree to, or None for the root.
        parent_action: Edge label from ``parent`` to the new subtree.
        rounds: Total number of rounds.
        current_round: Number of rounds already played.
        current_value: Sum of ``reward_matrix`` entries collected so far.
    """
    # ``>=`` (not ``==``) so a non-positive ``rounds`` cannot recurse forever.
    if current_round >= rounds:
        terminal_node = Node(3, None, parent, game.getid(), current_value)
        if parent is None:
            # No rounds to play: the whole game is a single terminal node.
            game.root = terminal_node
        else:
            parent.children.append(terminal_node)
            parent.labels.append(parent_action)
        return
    top_node = Node(0, game.get_actual_set(0), parent, game.getid())
    if parent is None:
        game.root = top_node
    else:
        parent.children.append(top_node)
        parent.labels.append(parent_action)
    bottom_infoset = game.get_actual_set(1)
    for action_one_index, action_one_label in [(0, "R"), (1, "P"), (2, "S")]:
        bottom_node = Node(1, bottom_infoset, top_node, game.getid())
        top_node.children.append(bottom_node)
        top_node.labels.append(action_one_label)
        for action_two_index, action_two_label in [(0, "R"), (1, "P"), (2, "S")]:
            generate_repeated_rps_round(game=game, parent=bottom_node, parent_action=action_two_label, rounds=rounds, current_round=current_round + 1,
                                        current_value=current_value + reward_matrix[action_one_index][action_two_index])


def generate_results_from_counterexample_game_gadget_vs_full_tree(ps, gadgets, val_func_iterations=1000, trunk_iterations=1000, subgame_iterations=1000):
    """Print gain and exploitability of CDRNR_G on the gadget counterexample game.

    For each resolving type, its name is printed, followed by one line
    ``p gain exploitability`` per p. Gain is player 0's value against the
    opponent minus the Nash game value. Exploitability is player 1's
    best-response value plus the game value.

    Args:
        ps: Values of p passed to ``CDRNR_G``.
        gadgets: Resolving types, each one of "unsafe", "full" or "gadget".
        val_func_iterations: Value-function iterations for ``CDRNR_G``.
        trunk_iterations: Iterations for ``solve_trunk_partially``.
        subgame_iterations: Iterations for the subgame re-solve.

    Raises:
        ValueError: If ``gadgets`` contains an unknown resolving type; checked
            before any solving so nothing is computed for a typo.
    """
    # CDRNR_G method that re-solves the rest of the game for each resolving type.
    rest_solvers = {
        "unsafe": "solve_rest_chance",
        "full": "solve_rest_full",
        "gadget": "solve_rest_gadget",
    }
    gadgets = list(gadgets)
    unknown = [name for name in gadgets if name not in rest_solvers]
    if unknown:
        raise ValueError("Unknown resolving type(s): " + ", ".join(map(str, unknown)))
    game = "data/gadget_game.efg"
    sequential_form = SequenceNash(game)
    game_value = sequential_form.solve(0)
    cfr = CFR(game)
    cfr.initialize()
    for resolving_type in gadgets:
        print(resolving_type)
        for p in ps:
            print(p, end=" ")
            cdrnr = CDRNR_G(game, p, val_func_iterations)
            cdrnr.solve_trunk_partially(trunk_iterations)
            getattr(cdrnr, rest_solvers[resolving_type])(subgame_iterations)
            joint_strategy = [cdrnr.average_strategy[0], cdrnr.opponent_strategy]
            cfr.compute_game_value(joint_strategy)
            gain = -cfr.game_value - game_value
            print(gain, end=" ")
            expl = cfr.best_response(1, [cdrnr.average_strategy[0], {}])[0] + game_value
            print(expl)


def compute_one_lp_cd_response():
    """Print gain and exploitability of one-step CD-RNR for several gadgets.

    For every (game, opponent strategy) pair, ``cd_rnr_one_step`` is run for
    every gadget and every p. The collected gains and exploitabilities are
    then printed per gadget.
    """
    import ContinualResponsesWithGadgets
    games = ["data/repeated_rps_two_rounds.efg"]
    opponent_strategies = [
        load_from_file("data/bad_strategies/repeated_rps_two_rounds/repeated_rps_two_rounds_1_iterations.strat")[1]]
    gadgets = ["full", "exp-strat", "ses", "resolving", "max-margin"]
    ps = [0.5]
    for game, opponent_strategy in zip(games, opponent_strategies):
        data = {"labels": [], "gain": [], "expl": []}
        for gadget in gadgets:
            data["labels"].append(gadget)
            data["gain"].append([])
            data["expl"].append([])
            for p in ps:
                gain, expl = ContinualResponsesWithGadgets.cd_rnr_one_step(file_name=game, opponent_strategy=opponent_strategy, p=p, gadget=gadget, decomposition_type="steps", verbose=True)
                data["gain"][-1].append(gain)
                data["expl"][-1].append(expl)
        # Report inside the game loop so every game's results are printed.
        for i, gadget in enumerate(gadgets):
            print(gadget)
            print(data["gain"][i])
            print(data["expl"][i])


def leduc_cd_responses_tests_generate(iters):
    """Compute and save gain/exploitability of one-step CD-RNR for every gadget and p.

    Despite the name, the game used is repeated_rps_two_rounds. The opponent is
    the saved CFR strategy after ``iters`` iterations. For each game, a dict
    with keys "labels", "gain", "expl" and "p" (one entry per gadget) is saved
    to results/br_sbr/.
    """
    import ContinualResponsesWithGadgets
    game_name = "repeated_rps_two_rounds"
    decomposition_type = "steps"
    strategy_path = f"data/bad_strategies/{game_name}/{game_name}_{iters}_iterations.strat"
    games_and_opponents = [(f"data/{game_name}.efg", load_from_file(strategy_path)[1])]
    gadgets = ["full", "max-margin", "resolving", "ses", "exp-strat", "unsafe", "comb"]
    ps = np.linspace(0, 1, 11)
    print(ps)
    for game, opponent_strategy in games_and_opponents:
        print("Game:", game)
        data = {"labels": [], "gain": [], "expl": []}
        for gadget in gadgets:
            print("Gadget:", gadget)
            gains, expls = [], []
            data["labels"].append(gadget)
            data["gain"].append(gains)
            data["expl"].append(expls)
            print("p:", end=" ")
            for p in ps:
                print(p, end=", ")
                gain, expl = ContinualResponsesWithGadgets.cd_rnr_one_step(file_name=game, opponent_strategy=opponent_strategy, p=p, gadget=gadget, decomposition_type=decomposition_type, verbose=False)
                gains.append(gain)
                expls.append(expl)
            print()

        data["p"] = [ps] * len(data["labels"])
        save_to_file(data, f"results/br_sbr/all_{game_name}_lp_cfr_iter={iters}")


def leduc_cd_responses_tests_plot(file_name, iters, title="constructed", show_numbers=True, labels_to_ignore=None):
    """Plot saved CD-RNR gain/exploitability results with ``plot_gain_exploitability_graph``.

    Args:
        file_name: Result file written by ``leduc_cd_responses_tests_generate``.
        iters: Opponent CFR iteration count, used only in the constructed title.
        title: Controls the plot title:
            - "constructed" (default) builds "on <file_name> with <iters> iterations";
            - None gives an empty title;
            - any other string is used verbatim.
        show_numbers: Annotate points with their p value.
        labels_to_ignore: Gadget labels to leave out of the plot (default: none).
    """
    if labels_to_ignore is None:
        labels_to_ignore = []
    data = load_from_file(file_name)
    if title == "constructed":
        title = f"on {file_name} with {iters} iterations"
    elif title is None:
        title = ""
    plot_gain_exploitability_graph(data["p"], data["gain"], data["expl"], data["labels"], reverse_p=False, title=title, show_numbers=show_numbers,
                                   labels_to_ignore=labels_to_ignore)


if __name__ == '__main__':
    #### Psychologist experiments
    # prospect_theory_rl()

    #### Perfect value function with RNR experiments
    # step_best_response_test()
    # generate_random_strategies()
    # compare_step_br_results()

    #### general SBR simulations
    ### generate bad strategies
    # fname = "data/iigs6.efg"
    # generate_bad_strategies(fname)

    ### evaluate SBR
    ## against cfr strategies
    # for game_name in ["iigs5"]:
    #     for steps in [1]:
    # # steps = 6
    # # game_name = "leduc_holdem"
    #         fname = "data/" + game_name + ".efg"
    #         sbr_level_lp_test(fname, steps, dynamic_steps=True)

    ## against random strategies
    # for game_name in ["iigs5"]:
    #     for steps in [1, 3, 5]:
    #         fname = "data/" + game_name + ".efg"
    #         sbr_level_lp_test_random_strategies(fname, steps, dynamic_steps=True)

    ## plot sbr evaluation
    # game_name = "iigs5"
    # steps = 4
    # fname = "results/br_sbr/sbr_" + game_name + "_step_" + str(steps) + "_cfr_var_iter.txt"
    # plot_sbr_gain_data(fname, steps)
    # fname = "results/br_sbr/sbr_" + game_name + "_step_" + str(steps) + "_cfr_var_iter_expl.txt"
    # plot_sbr_expl_data(fname, steps)

    #### general SRNR simulations
    # ## evaluate SRNR
    # for game_name in ["iigs6"]:
    #     for steps in [6]:
    #         for p in [0, 0.2, 0.4, 0.6, 0.8, 1]:
    #             fname = "data/" + game_name + ".efg"
    #             srnr_level_lp_test(fname, steps, p, dynamic_steps=False)
    # srnr_level_lp_test_random_strategies(fname, steps, p, dynamic_steps=True)

    #### general SES simulations
    # # evaluate SES
    # for game_name in ["iigs5", "leduc_holdem", "ld"]:
    #     for steps in [5]:
    #         for p in [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]:
    #             fname = "data/" + game_name + ".efg"
    #             ses_lp_test(fname, steps, p, dynamic_steps=False)
    # srnr_level_lp_test_random_strategies(fname, steps, p, dynamic_steps=True)

    ## plot the evaluation
    # game_name = "iigs4"
    # steps = 2
    # p = 0.9
    # fname = "results/br_sbr/srnr_" + game_name + "_step_" + str(steps) + "_p=" + str(p) + "_cfr_var_iter.txt"
    # plot_srnr_gain_data(fname, steps, p)
    # fname = "results/br_sbr/srnr_" + game_name + "_step_" + str(steps) + "_p=" + str(p) + "_cfr_var_iter_expl.txt"
    # plot_srnr_expl_data(fname, steps, p)

    #### Local best response tests
    ## Against low iteration cfr strategies
    # for i in fibonacci_array(15, zero_start=True):
    #     strategy = load_from_file("data/bad_strategies/leduc_holdem/leduc_holdem_" + str(i) + "_iterations.strat")[1]
    #     wins = []
    #     games = 100000
    #     for j in range(games):
    #         won = leduc_local_best_response(strategy)
    #         wins.append(won)
    #     m, se = np.mean(wins), scipy.stats.sem(wins)
    #     h = se * scipy.stats.t.ppf((1 + 0.95) / 2., games - 1)
    #     print("Iterations: ", i)
    #     print("LBR:", m, "+-", h)

    ## Against random strategies
    # for i in range(10):
    #     strategy = load_from_file("data/random_strategies/leduc_holdem_seed_" + str(i) + ".strat")[1]
    #     wins = []
    #     games = 100000
    #     for j in range(games):
    #         won = leduc_local_best_response(strategy)
    #         wins.append(won)
    #     m, se = np.mean(wins), scipy.stats.sem(wins)
    #     h = se * scipy.stats.t.ppf((1 + 0.95) / 2., games - 1)
    #     print("Iterations: ", i)
    #     print("LBR:", m, "+-", h)

    #### Full lbr sbr br comparison on leduc
    ### Cfr strategies
    # plt.rcParams.update({'font.size': 20, 'font.family': 'Times New Roman'})
    # fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 3))
    # plt.gcf().subplots_adjust(bottom=0.23, left=0.07, right=0.99, top=0.84, wspace=0.2)
    # cmap = plt.get_cmap('CMRmap')
    # indices = np.linspace(0, cmap.N, 7)
    # my_colors = [cmap(int(i)) for i in indices[:-1]]
    # ax1.set_prop_cycle(color=my_colors)
    # ax2.set_prop_cycle(color=my_colors)
    # steps = [1, 3, 5]
    # plot_leduc_lbr_sbr(steps, "cfr", dynamic_steps=True, values=9, include_average=True, ax=ax1)
    #
    # ### Random strategies
    # steps = [1, 3, 5]
    # plot_leduc_lbr_sbr(steps, "random", dynamic_steps=True, values=-1, include_average=True, ax=ax2)
    # plt.legend(bbox_to_anchor=(-1.2, 1.02, 2.2, .102), loc='lower left',
    #            ncol=6, mode="expand", borderaxespad=0., handlelength=1, handletextpad=0.3, borderpad=0.1)
    # plt.show()

    ### Full sbr br comparison
    ## Cfr strategies
    # game_name = "leduc_holdem"
    # plt.rcParams.update({'font.size': 20, 'font.family': 'Times New Roman'})
    # fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 3))
    # plt.gcf().subplots_adjust(bottom=0.23, left=0.07, right=0.99, top=0.84, wspace=0.2)
    # cmap = plt.get_cmap('CMRmap')
    # indices = np.linspace(0, cmap.N, 5)
    # my_colors = [cmap(int(i)) for i in indices[:-1]]
    # ax1.set_prop_cycle(color=my_colors)
    # ax2.set_prop_cycle(color=my_colors)
    # steps = [1, 3, 5]
    # plot_br_sbr_steps_expl(game_name, steps, "random", dynamic_steps=True, values=9, include_average=True, ax=ax1, ticks=[0, 1, 2, 3, 4])
    # plot_br_sbr_steps_expl(game_name, steps, "cfr", dynamic_steps=True, values=9, include_average=True, ax=ax2, ticks=[0, 1, 2, 3, 4])
    # plt.legend(bbox_to_anchor=(-1.2, 1.02, 2.2, .102), loc='lower left',
    #            ncol=6, mode="expand", borderaxespad=0., handlelength=1, handletextpad=0.3, borderpad=0.1)
    # plt.show()
    # #
    # # ### Random strategies
    # steps = [1, 3, 5]
    # plot_br_sbr_steps_gain(game_name, steps, "random", dynamic_steps=True, values=-1, include_average=True, ax=ax2, ticks=[0,1,2,3,4,5,6])
    # plt.legend(bbox_to_anchor=(-1.2, 1.02, 2.2, .102), loc='lower left',
    #            ncol=6, mode="expand", borderaxespad=0., handlelength=1, handletextpad=0.3, borderpad=0.1)
    # plt.show()

    ### Plot multiple steps against BR
    ## SBR part
    # game_name = "iigs5"
    # mode = "cfr"
    # steps = [1, 2, 3, 4, 5]
    # plot_br_sbr_steps_gain(game_name, steps, dynamic_steps=True, mode=mode)
    # plot_br_sbr_steps_expl(game_name, steps, dynamic_steps=True, mode=mode)

    # ## SRNR part
    # br = True
    # rnr = True
    # bne = True
    # srnrg = False
    # srnru = False
    # srnrvf = False
    # srnr = True
    # comb = True
    # ses = True
    #
    # dynamic_steps = True
    # cdrnr_steps = [1, 5]
    # ses_steps = [5]
    # comb_steps = [1, 5]
    # steps = [cdrnr_steps, ses_steps, comb_steps]
    # game_name = "iigs5"
    # mode = "cfr"
    # p = 0.2
    # print_set = set()
    # bar_count = 0
    #
    # if br:
    #     print_set.add("br")
    #     bar_count += 1
    # if rnr:
    #     print_set.add("rnr")
    #     bar_count += 1
    # if bne:
    #     print_set.add("bne")
    #     bar_count += 1
    # if srnrg:
    #     print_set.add("srnrg")
    #     bar_count += 1
    # if srnru:
    #     print_set.add("srnru")
    #     bar_count += 1
    # if srnrvf:
    #     print_set.add("srnrvf")
    #     bar_count += 1
    # if srnr:
    #     print_set.add("srnr")
    #     bar_count += len(cdrnr_steps)
    # if comb:
    #     print_set.add("comb")
    #     bar_count += len(comb_steps)
    # if ses:
    #     print_set.add("ses")
    #     bar_count += len(ses_steps)
    #
    # y_ticks_gain = [0, 1, 2, 3, 4]
    # y_ticks_expl = [0, 1, 2, 3, 4]
    # # y_ticks_gain = [0, 0.25, 0.5, 0.75, 1]
    # # y_ticks_expl = [0, 0.25, 0.5, 0.75, 1]
    #
    # broken_axe = 0
    # plt.rcParams.update({'font.size': 20, 'font.family': 'Times New Roman'})
    # fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 3))
    # plt.gcf().subplots_adjust(bottom=0.23, left=0.07, right=0.99, top=0.84, wspace=0.2)
    #
    # cmap = plt.get_cmap('CMRmap')
    # indices = np.linspace(0, cmap.N, bar_count + 1)
    # my_colors = [cmap(int(i)) for i in indices[:-1]]
    # if broken_axe > 0:
    #     divider = make_axes_locatable(ax2)
    #     ax3 = divider.new_vertical(size="100%", pad=0.1)
    #     fig.add_axes(ax3)
    # ax1.set_prop_cycle(color=my_colors)
    # ax2.set_prop_cycle(color=my_colors)
    # if broken_axe > 0:
    #     ax3.set_prop_cycle(color=my_colors)
    #     ax2.set_ylim(0, broken_axe)
    #     ax3.set_ylim(1, 8)
    #     ax2.spines['top'].set_visible(False)
    #     ax3.tick_params(bottom=False, labelbottom=False)
    #     ax3.spines['bottom'].set_visible(False)
    # plot_br_srnr_steps_gain(game_name, steps, p, dynamic_steps=dynamic_steps, value_limit=9, include_average=True, ax=ax1, mode=mode, yticks=y_ticks_gain, print_set=print_set)
    # plot_br_srnr_steps_expl(game_name, steps, p, dynamic_steps=dynamic_steps, value_limit=9, include_average=True, ax=ax2, mode=mode, yticks=y_ticks_expl, print_set=print_set)
    # if broken_axe > 0:
    #     plot_br_srnr_steps_expl(game_name, steps, p, dynamic_steps=True, value_limit=9, include_average=True, ax=ax3)
    # plt.legend(bbox_to_anchor=(-1.2, 1.02, 2.2, .102), loc='lower left',
    #            ncol=bar_count, mode="expand", borderaxespad=0., handlelength=1, handletextpad=0.3, borderpad=0.1)
    # fig.add_subplot(111, frameon=False)
    # plt.tick_params(labelcolor='none', which='both', top=False, bottom=False, left=False, right=False)
    # plt.xlabel("CFR iterations of opponent's strategy (p=" + str(p) + ")")
    # plt.show()

    #### Random strategies with seeds
    # fname = "data/iigs5.efg"
    # generate_random_strategies(fname)

    #### Save strategies in OpenSpiel format
    # CFR strategies
    # game_name = "leduc_holdem"
    # save_cfr_strategies_for_open_spiel(game_name)

    # Random strategies
    # game_name = "leduc_holdem"
    # save_random_strategies_for_open_spiel(game_name)

    #### Create plain RNR results
    # mode = "cfr"
    # for game_name in ["iigs5"]:
    #     for p in [0.3, 0.7, 0.5]:
    #         fname = "data/" + game_name + ".efg"
    #         rnr_computation(fname, p, mode)

    #### Create results for best NE against strategy
    # mode = "cfr"
    # for game_name in ["iigs6"]:
    #     fname = "data/" + game_name + ".efg"
    #     best_ne_computation(fname, mode)

    #### Plots for counterexample
    # cdrnr_counterexample_plot()

    #### CFR-D tests
    ### Compute strategies
    # for main_iterations in [10, 100, 1000]:
    #     for subgame_iterations in [100, 1000]:
    #         cfrd_tests("data/leduc_holdem/leduc_holdem", cfrd_verbose=1, cfrd_iterations=main_iterations, subgame_iterations=subgame_iterations, full_iterations=1000, full_verbose=1,
    #                    average_from=int(main_iterations / 2))

    ### Evaluate strategies
    # cfrd_iterations = 100
    # subgame_iterations = 100
    # full_iterations = 1000
    # strategy_exploitability(f"data/leduc_holdem/cfrd_strategies/cfrwf_strategy_dit_{cfrd_iterations}_sit_{subgame_iterations}_fit_{full_iterations}", "data/leduc_holdem/leduc_holdem.efg")
    # strategy_exploitability(f"data/leduc_holdem/cfrd_strategies/cfr_strategy_dit_{cfrd_iterations}_sit_{subgame_iterations}_fit_{full_iterations}", "data/leduc_holdem/leduc_holdem.efg")

    #### CDBR current iterations convergence tests
    # cdbr_convergence_tests("data/synt_decomp/synt", 1, None, cfrd_iterations=10000, subgame_iterations=1000, average_from=0, cfrd_verbose=0, full_iterations=10000, full_verbose=0)


    #### Comparison of gadget, full tree and unsafe solving
    # generate_results_from_counterexample_game_gadget_vs_full_tree([0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1], ["gadget", "full"])

    #### Compute combination between CDBR and NE
    # for game_name in ["iigs5", "leduc_holdem", "ld"]:
    #     for steps in [1, 2, 3, 4, 5]:
    #         for p in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]:
    #             # steps = 6
    #             # game_name = "leduc_holdem"
    #             fname = "data/" + game_name + ".efg"
    #             sbr_combination_level_lp_test(fname, steps, p, dynamic_steps=False)

    #### Plot CDRNR and COMB comparison as Gain X Exploitability graph
    # plot_srnr_comb_expl_gain(steps=[1, 3, 5], ps=[0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1], game_name="leduc_holdem")

    #### Theoretical results for exploitability vs gains of the gadgets (resolving vs full)

    # values = [
    #     (0, 0), (-1, 4), (-3, 7), (-6, 9), (-10, 10)
    # ]
    # values_other = [
    #     (0, 0), (-1, 4), (-3, 7), (-6, 9), (0, 10)
    # ]
    # actions = 5
    #
    # for value, value_other in zip(values, values_other):
    #     plt.plot([0, 1], [value_other[0] * (actions - 1) + value[0], value[1]])
    # plt.ylabel("Expected utility")
    # plt.xlabel("p")
    # plt.title("Best actions for different p in CDRNR using resolving gadet")
    # plt.show()
    #
    # for value in values:
    #     plt.plot([0, 1], value)
    # plt.ylabel("Expected utility")
    # plt.xlabel("p")
    # plt.title("Best actions for different p in CDRNR using full gadget")
    # plt.show()

    ### Max margin vs full
    # values = [
    #     ([0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]),
    #     ([0, -6, 0, 0, 0, 0], [30, -6, 0, 0, 0, 0]),
    #     ([0, -18, 0, 0, 0, 0], [60, -18, 0, 0, 0, 0]),
    #     ([0, -36, 0, 0, 0, 0], [90, -36, 0, 0, 0, 0]),
    #     ([0, -12, -12, -12, -12, -12], [120, -12, -12, -12, -12, -12]),
    # ]

    # for value in values:
    #     plt.plot([0, 1], [np.min(value[0]) * 1./6, np.average(value[1])])
    # plt.ylabel("Expected utility")
    # plt.xlabel("p")
    # plt.title("Best actions for different p in CDRNR using max-margin gadet")
    # plt.show()
    # # #
    # for value in values:
    #     plt.plot([0, 1], [np.average(value[0]), np.average(value[1])])
    # plt.ylabel("Expected utility")
    # plt.xlabel("p")
    # plt.title("Best actions for different p in CDRNR using full gadget")
    # plt.show()


    #### Practical results from codebase comparing SES
    # cfr = CFR("data/leduc_holdem.efg")
    # cfr.solve(10)
    # print(cfr.average_strategy)
    # print(cfr.compute_game_value(cfr.average_strategy))
    for iters in fibonacci_array(9, True):
        leduc_cd_responses_tests_generate(iters)
    # leduc_cd_responses_tests_generate(1)
    # compute_one_lp_cd_response()
    # for iters in fibonacci_array(9, True):
    #     leduc_cd_responses_tests_plot(f"results/br_sbr/all_leduc_holdem_lp_cfr_iter={iters}", iters, show_numbers=False, labels_to_ignore=[])
    # leduc_cd_responses_tests_plot(f"results/br_sbr/all_gadgets_two_round_rps_lp_cfr_iter={iters}", iters)
    # leduc_cd_responses_tests_plot(f"results/br_sbr/all_gadget_game_lp_cfr_iter={0}", 0, title=None, show_numbers=False, labels_to_ignore=["ses", "exp-strat", "unsafe", "comb"])
    # leduc_cd_responses_tests_plot(f"results/br_sbr/all_repeated_rps_two_rounds_lp_cfr_iter={1}", 1, show_numbers=False, title=None, labels_to_ignore=["unsafe"])
    # leduc_cd_responses_tests_plot(f"results/br_sbr/all_200p_leduc_holdem_lp_cfr_iter={3}", 3, show_numbers=False)

    # save_to_file([{},strategy], "data/bad_strategies/gadget_game/gadget_game_0_iterations.strat")
    pass
