from copy import deepcopy
import numpy as np
import matplotlib.pyplot as plt
import time
import os
from scipy import stats
import pickle
import json
from numpy.random import RandomState
import argparse
import multiprocessing as mp


# Print arrays without scientific notation and with three decimals.
np.set_printoptions(suppress=True)
np.set_printoptions(precision=3)

# Seed NumPy's global generator as soon as the module is loaded.
np.random.seed(0)

# Command-line interface. Arguments are parsed at module load time, so importing
# this file from another program also consumes that program's sys.argv.
parser = argparse.ArgumentParser(description="All Spinning Top Payoffs DPP")
parser.add_argument("--nb_iters", type=int, default=250)
parser.add_argument("--nb_exps", type=int, default=5)
parser.add_argument("--lr", help="learning rate", default=0.5, type=float)
# `--mp` is a switch that turns the value OFF: args.mp is True unless the flag
# is given on the command line.
parser.add_argument(
    "--mp",
    default=True,
    action="store_false",
    help="Set --mp for False, otherwise leave it for True",
)
# Name of the game; used below for the results directory and later as the stem
# of the pickled payoff file.
parser.add_argument("--game_name", type=str, default="AlphaStar")
parser.add_argument("--plot", action="store_true")

args = parser.parse_args()

# Module-level constants; neither is referenced in the part of the file shown here.
# NOTE(review): LR is fixed at 0.8 independently of the --lr flag (default 0.5) —
# confirm which of the two downstream code is meant to use.
LR = 0.8
TH = 0.03

# Module-level lists; not used in the part of the file shown here.
expected_card = []
sizes = []

# Results go to results/<game_name>/<timestamp>; the directory is created at
# import time if it does not exist yet.
time_string = time.strftime("%Y%m%d-%H%M%S")
DATASET = os.path.join("dataset")
PATH_RESULTS = os.path.join("results", "{}/{}".format(args.game_name, time_string))
if not os.path.exists(PATH_RESULTS):
    os.makedirs(PATH_RESULTS)


# Search over the pure strategies to find the BR to a strategy
def get_br_to_strat(strat, payoffs=None, verbose=False, symmetric=True):
    """Return the pure-strategy response(s) that minimise the weighted payout.

    Args:
        strat: A single mixed strategy when ``symmetric`` is true; otherwise an
            indexable pair ``[row_mixture, column_mixture]``.
        payoffs: Payoff matrix. It must be supplied; the ``None`` default only
            exists as a placeholder and fails on the matrix product.
        verbose: When true, print the minimal weighted payout (symmetric) or
            the sum of the two minimal payouts (asymmetric), followed by the
            label ``"exploitability"``.
        symmetric: Selects the single-strategy or the two-strategy variant.

    Returns:
        Symmetric: a one-hot array shaped like ``strat @ payoffs`` with the 1 at
        its argmin. Asymmetric: the list ``[rbr, cbr]`` where ``rbr`` is one-hot
        at the argmin of ``strat[0] @ payoffs`` and ``cbr`` is one-hot at the
        argmin of ``-payoffs @ strat[1]``.
    """

    def _argmin_indicator(weighted):
        # One-hot array (same shape and dtype as `weighted`) plus the argmin index.
        lowest = np.argmin(weighted)
        indicator = np.zeros_like(weighted)
        indicator[lowest] = 1
        return indicator, lowest

    if symmetric:
        weighted = strat @ payoffs
        response, lowest = _argmin_indicator(weighted)
        if verbose:
            print(weighted[lowest], "exploitability")
        return response

    against_row_mix = strat[0] @ payoffs
    against_col_mix = -payoffs @ strat[1]
    rbr, row_lowest = _argmin_indicator(against_row_mix)
    cbr, col_lowest = _argmin_indicator(against_col_mix)
    if verbose:
        print(
            against_row_mix[row_lowest] + against_col_mix[col_lowest],
            "exploitability",
        )
    return [rbr, cbr]


# Fictitious play as a Nash equilibrium solver
def fictitious_play(iters=2000, payoffs=None, verbose=False, symmetric=True):
    """Run fictitious play on ``payoffs`` and record the running average strategies.

    Each iteration averages the strategies played so far, asks
    ``get_br_to_strat`` for the response to that average, records
    ``exp2 - exp1`` for the pair, and appends the response to the population.
    The starting strategies are drawn from NumPy's global generator and
    normalised to sum to one.

    Args:
        iters: Number of fictitious-play iterations.
        payoffs: Payoff matrix (square use in the symmetric case; shape
            ``(rdim, cdim)`` in the asymmetric case).
        verbose: Only used in the asymmetric case, where it prints both payoff
            terms every iteration.
        symmetric: Chooses between the single-population and the row/column
            two-population variant.

    Returns:
        ``(averages, exps)``. Symmetric: ``averages`` has one row per recorded
        strategy (the random start first, then one average per iteration).
        Asymmetric: ``averages`` is ``[row_averages, column_averages]`` where
        row averages are stacked as rows and column averages as columns.
        ``exps`` holds one value per iteration.
    """
    exps = []

    if not symmetric:
        rdim, cdim = payoffs.shape
        # Draw order matters for reproducibility: row start first, then column start.
        rpop = np.random.uniform(0, 1, (1, rdim))
        cpop = np.random.uniform(0, 1, (cdim, 1))
        rpop = rpop / rpop.sum(axis=1, keepdims=True)
        cpop = cpop / cpop.sum(axis=0, keepdims=True)
        row_history = [rpop]
        col_history = [cpop]
        for _ in range(iters):
            row_avg = np.average(rpop, axis=0)
            col_avg = np.average(cpop, axis=1)
            br = get_br_to_strat(
                [row_avg, col_avg], payoffs=payoffs, symmetric=symmetric
            )
            exp1 = row_avg @ payoffs @ br[0]
            exp2 = br[1] @ payoffs @ col_avg
            if verbose:
                print(exp1, exp2, "exploitability")
            exps.append(exp2 - exp1)
            row_history.append(row_avg)
            col_history.append(col_avg.reshape(-1, 1))
            # br[0] has one entry per column of `payoffs`, br[1] one per row.
            cpop = np.hstack((cpop, br[0].reshape(-1, 1)))
            rpop = np.vstack((rpop, br[1]))
        return [np.vstack(row_history), np.hstack(col_history)], exps

    dim = payoffs.shape[0]
    pop = np.random.uniform(0, 1, (1, dim))
    pop = pop / pop.sum(axis=1, keepdims=True)
    history = [pop]
    for _ in range(iters):
        average = np.average(pop, axis=0)
        br = get_br_to_strat(average, payoffs=payoffs)
        exp1 = average @ payoffs @ br.T
        exp2 = br @ payoffs @ average.T
        exps.append(exp2 - exp1)
        history.append(average)
        pop = np.vstack((pop, br))
    return np.vstack(history), exps


# Solve exploitability of a nash equilibrium over a fixed population
def get_exploitability(pop, payoffs, iters=1000, symmetric=True):
    """Measure how exploitable the fictitious-play mixture over ``pop`` is.

    The game restricted to ``pop`` is solved with ``fictitious_play``; the last
    recorded average is mapped back to a strategy of the full game, and the
    response from ``get_br_to_strat`` to that strategy is used to form the
    returned payoff difference.

    Args:
        pop: Symmetric case: strategies stacked as rows. Asymmetric case: a
            single array of column strategies stacked as columns.
        payoffs: Payoff matrix of the full game.
        iters: Fictitious-play iterations used on the restricted game.
        symmetric: Selects the symmetric or asymmetric computation.

    Returns:
        ``exp2 - exp1`` for the final strategy/response pair.
    """
    if not symmetric:
        # Rows keep their full strategy space; columns are restricted to `pop`.
        restricted_game = payoffs @ pop
        averages, _ = fictitious_play(
            payoffs=restricted_game, iters=iters, symmetric=symmetric
        )
        rstrat = averages[0][-1]
        cstrat = pop @ averages[1][:, -1]  # aggregate the column mixture
        replies = get_br_to_strat(
            [rstrat, cstrat], payoffs=payoffs, symmetric=symmetric
        )
        exp1 = rstrat @ payoffs @ replies[0]
        exp2 = replies[1] @ payoffs @ cstrat
        return exp2 - exp1

    restricted_game = pop @ payoffs @ pop.T
    averages, _ = fictitious_play(
        payoffs=restricted_game, iters=iters, symmetric=symmetric
    )
    strat = averages[-1] @ pop  # aggregate the mixture over the population
    reply = get_br_to_strat(strat, payoffs=payoffs)
    exp1 = strat @ payoffs @ reply.T
    exp2 = reply @ payoffs @ strat
    return exp2 - exp1


def ros(rpopulation_strategies, cpopulation_strategies, payoffs, lr, k, pop):
    """Pick the best-response pair whose trial update of learner ``k`` is least exploitable.

    For every ``(row_strategy, column_strategy)`` pair, the responses from
    ``get_br_to_strat`` are blended (with weight ``lr``) into learner ``k``'s
    original entries of ``pop``, the exploitability of the resulting
    population is evaluated, and the pair with the lowest value is kept.
    ``pop`` is restored to its original contents before returning.

    Args:
        rpopulation_strategies: Iterable of row-side strategies to respond to.
        cpopulation_strategies: Iterable of column-side strategies, consumed in
            lockstep with ``rpopulation_strategies``.
        payoffs: Payoff matrix of the full game.
        lr: Blend weight given to the response in the trial update.
        k: Index of the learner being updated.
        pop: Pair of populations; ``pop[0][k]`` and ``pop[1][:, k]`` are the
            entries updated. It is mutated during the search and restored
            afterwards.

    Returns:
        ``[rbr, cbr]`` for the lowest-exploitability trial, or ``None`` when no
        strategy pair was supplied or no trial scored below ``1e9``.
    """
    # Snapshot by value. `pop[1][:, k]` is a NumPy view (and `pop[0][k]` is one
    # when `pop[0]` is a 2-D array), so without a copy the trial writes below
    # would overwrite the "base" vectors too: later candidates would be blended
    # against an already-modified baseline and the final restore would be a no-op.
    base_cbr = np.copy(pop[0][k])
    base_rbr = np.copy(pop[1][:, k])

    min_brs = None
    min_eps = 1e9
    try:
        for rstrategy, cstrategy in zip(
            rpopulation_strategies, cpopulation_strategies
        ):
            rbr, cbr = get_br_to_strat(
                [rstrategy, cstrategy],
                payoffs=payoffs,
                symmetric=False,
            )
            pop[0][k] = lr * cbr + (1 - lr) * base_cbr
            pop[1][:, k] = lr * rbr + (1 - lr) * base_rbr
            # NOTE(review): the asymmetric branch of get_exploitability computes
            # `payoffs @ pop`, i.e. it expects a single array, whereas `pop` is
            # indexed as a pair here — confirm the intended calling convention.
            exp = get_exploitability(pop, payoffs, symmetric=False)
            if exp < min_eps:
                min_brs = [rbr, cbr]
                min_eps = exp
    finally:
        # Always put the caller's population back, even if a trial raised.
        pop[0][k] = base_cbr
        pop[1][:, k] = base_rbr

    return min_brs


def minmax(rurrgame_matrix, currgame_matrix, lr):
    """Placeholder with no implementation; calling it always fails.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()


def joint_loss(pop, payoffs, meta_nash, k, lambda_weight, lr):
    """Choose a pure strategy for learner ``k`` by payoff or by a diversity score.

    One standard-normal sample is drawn. If it is below ``lambda_weight`` the
    pure strategy with the highest payoff against the ``meta_nash`` mixture of
    ``pop[:k]`` is selected. Otherwise every pure strategy ``i`` is tried as the
    update target of ``pop[k]``; for each trial population the empirical payoff
    matrix ``M`` and its fictitious-play mixture ``p`` give
    ``L = diag(p) M M^T diag(p)``, and the strategy maximising
    ``trace(I - (L + I)^-1)`` is selected.

    Args:
        pop: Population with strategies stacked as rows; not modified.
        payoffs: Square payoff matrix of the full game.
        meta_nash: Mixture weights over ``pop[:k]``.
        k: Row index of the learner being updated.
        lambda_weight: Threshold compared against the normal sample.
        lr: Blend weight of the candidate pure strategy in the trial update.

    Returns:
        A one-hot float array of length ``payoffs.shape[0]``.
    """
    dim = payoffs.shape[0]
    br = np.zeros((dim,))
    fixed_pop = pop[:k]

    # NOTE(review): np.random.randn() samples N(0, 1), so `lambda_weight` acts as
    # a threshold on a normal sample rather than as a probability in [0, 1]. If a
    # Bernoulli(lambda_weight) choice was intended, np.random.rand() would be the
    # call — confirm before changing, since it alters the random stream.
    if np.random.randn() < lambda_weight:
        # Payoff of every pure strategy against the aggregated opponent in one
        # matrix-vector product (previously `dim` separate one-hot products;
        # results agree up to floating-point rounding).
        aggregated_enemy = meta_nash @ fixed_pop
        values = payoffs @ aggregated_enemy
        br[np.argmax(values)] = 1
    else:
        # The part of the update that does not depend on the candidate.
        retained = (1 - lr) * pop[k]
        cards = []
        for i in range(dim):
            br_tmp = np.zeros((dim,))
            br_tmp[i] = 1.0

            pop_tmp = np.vstack((fixed_pop, lr * br_tmp + retained))
            M = pop_tmp @ payoffs @ pop_tmp.T
            # fictitious_play draws its start from the global RNG, so the call
            # order across candidates is kept as-is.
            metanash_tmp, _ = fictitious_play(payoffs=M, iters=1000)
            weights = np.diag(metanash_tmp[-1])
            L = weights @ M @ M.T @ weights
            identity = np.eye(L.shape[0])
            cards.append(np.trace(identity - np.linalg.inv(L + identity)))
        br[np.argmax(cards)] = 1

    return br


def psro_steps(
    iters=5,
    payoffs=None,
    verbose=False,
    seed=0,
    num_learners=4,
    improvement_pct_threshold=0.03,
    lr=0.2,
    loss_func="dpp",
    full=False,
):
    """Run (pipeline) PSRO and track exploitability and a cardinality score.

    The population starts with ``1 + num_learners`` random strategies drawn from
    ``RandomState(seed)``. In every iteration each of the last ``num_learners``
    rows is moved (with step ``lr``) towards the strategy chosen by
    ``loss_func`` against the fictitious-play mixture of the rows before it.
    When the learner handled last in the sweep improves by less than
    ``improvement_pct_threshold`` relative to its previous performance, a new
    random strategy (drawn from NumPy's global generator) is appended.

    Args:
        iters: Number of outer iterations.
        payoffs: Square payoff matrix of the full game.
        verbose: Unused.
        seed: Seed for the initial population.
        num_learners: Number of rows updated per iteration.
        improvement_pct_threshold: Relative-improvement threshold for growing
            the population.
        lr: Step size of the strategy update.
        loss_func: ``"br"`` for the plain best response or ``"dpp"`` for
            ``joint_loss``.
        full: Unused.

    Returns:
        ``(pop, exps, l_cards)``: the final population and the per-iteration
        exploitability and cardinality values (each including the initial one).

    Raises:
        ValueError: If ``loss_func`` is neither ``"br"`` nor ``"dpp"``.
    """

    def _cardinality(population):
        # trace(I - (L + I)^-1) with L = M M^T and M the empirical payoff matrix.
        M = population @ payoffs @ population.T
        L = M @ M.T
        identity = np.eye(L.shape[0])
        return np.trace(identity - np.linalg.inv(L + identity))

    print("start pipeline psro" if num_learners > 1 else "start psro")
    dim = payoffs.shape[0]

    rng = np.random.RandomState(seed)
    pop = rng.uniform(0, 1, (1 + num_learners, dim))
    pop = pop / pop.sum(axis=1, keepdims=True)
    exps = [get_exploitability(pop, payoffs, iters=1000)]
    l_cards = [_cardinality(pop)]

    learner_performances = [[0.1] for _ in range(num_learners + 1)]
    for iteration in range(iters):
        # Weighting towards diversity; currently a fixed hyperparameter.
        lambda_weight = 0.85
        if iteration % 10 == 0:
            print("iteration: ", iteration, " exp full: ", exps[-1])
            print("size of pop: ", pop.shape[0])

        for j in range(num_learners):
            # j = 0 updates the last row against everything before it; the
            # learner with j = num_learners - 1 faces the smallest fixed population.
            k = pop.shape[0] - j - 1
            fixed = pop[:k]
            meta_nash, _ = fictitious_play(
                payoffs=fixed @ payoffs @ fixed.T, iters=1000
            )
            # Opponent aggregated according to the meta-Nash mixture.
            population_strategy = meta_nash[-1] @ fixed

            if loss_func == "br":
                br = get_br_to_strat(population_strategy, payoffs=payoffs)
            elif loss_func == "dpp":
                br = joint_loss(pop, payoffs, meta_nash[-1], k, lambda_weight, lr)
                # Plain best response is still evaluated here as in the existing
                # flow; its value is not used afterwards.
                plain_br = get_br_to_strat(population_strategy, payoffs=payoffs)
            else:
                raise ValueError("Unknow loss func")

            # Move the learner towards the chosen pure strategy.
            pop[k] = lr * br + (1 - lr) * pop[k]
            # The +1 shift keeps the value positive for the ratio below.
            performance = pop[k] @ payoffs @ population_strategy.T + 1
            track = learner_performances[k]
            track.append(performance)

            if j != num_learners - 1:
                continue
            # Plateau of the learner handled last: grow the population.
            if performance / track[-2] - 1 < improvement_pct_threshold:
                learner = np.random.uniform(0, 1, (1, dim))
                learner = learner / learner.sum(axis=1, keepdims=True)
                pop = np.vstack((pop, learner))
                learner_performances.append([0.1])

        # Metrics for the whole population after this iteration.
        exps.append(get_exploitability(pop, payoffs, iters=1000))
        l_cards.append(_cardinality(pop))
    print("end psro")

    return pop, exps, l_cards


# Define the self-play algorithm
def self_play_steps(
    iters=10,
    payoffs=None,
    verbose=False,
    improvement_pct_threshold=0.03,
    lr=0.2,
    seed=0,
):
    """Run self-play: the newest strategy trains against the one just before it.

    The population starts with two random strategies drawn from
    ``RandomState(seed)``. Each iteration moves the last row (step ``lr``)
    towards the best response to the second-to-last row; when the relative
    improvement drops below ``improvement_pct_threshold`` a new random strategy
    from NumPy's global generator is appended.

    Args:
        iters: Number of iterations.
        payoffs: Square payoff matrix of the full game.
        verbose: Unused.
        improvement_pct_threshold: Relative-improvement threshold for adding a
            new strategy.
        lr: Step size of the strategy update.
        seed: Seed for the initial population.

    Returns:
        ``(pop, exps, l_cards)``: the final population and the per-iteration
        exploitability and cardinality values (each including the initial one).
    """

    def _cardinality(population):
        # trace(I - (L + I)^-1) with L = M M^T and M the empirical payoff matrix.
        M = population @ payoffs @ population.T
        L = M @ M.T
        identity = np.eye(L.shape[0])
        return np.trace(identity - np.linalg.inv(L + identity))

    print("start self play")
    dim = payoffs.shape[0]
    rng = np.random.RandomState(seed)
    pop = rng.uniform(0, 1, (2, dim))
    pop = pop / pop.sum(axis=1, keepdims=True)
    exps = [get_exploitability(pop, payoffs, iters=1000)]
    performances = [0.01]
    l_cards = [_cardinality(pop)]

    for step in range(iters):
        if step % 10 == 0:
            print("iteration: ", step, "exploitability: ", exps[-1])

        br = get_br_to_strat(pop[-2], payoffs=payoffs)
        pop[-1] = lr * br + (1 - lr) * pop[-1]
        # The +1 shift keeps the value positive for the ratio below.
        performance = pop[-1] @ payoffs @ pop[-2].T + 1
        performances.append(performance)

        stalled = performance / performances[-2] - 1 < improvement_pct_threshold
        if stalled:
            learner = np.random.uniform(0, 1, (1, dim))
            learner = learner / learner.sum(axis=1, keepdims=True)
            pop = np.vstack((pop, learner))

        exps.append(get_exploitability(pop, payoffs, iters=1000))
        l_cards.append(_cardinality(pop))
    print("end self play")

    return pop, exps, l_cards


def psro_efficient_steps(
    iters=5,
    payoffs=None,
    verbose=False,
    seed=0,
    num_learners=4,
    improvement_pct_threshold=0.03,
    lr=0.8,
    full=False,
    loss_func="br",
):
    """Run the "epsro" variant, which keeps only a column population.

    The population is a ``(cdim, 1 + num_learners)`` array whose columns are
    strategies normalised to sum to one. In every iteration each of the last
    ``num_learners`` columns is updated against the game restricted to the
    columns before it (rows keep their full strategy space), then the
    exploitability and a cardinality score of the whole population are recorded.

    Args:
        iters: Number of outer iterations.
        payoffs: Payoff matrix of shape ``(rdim, cdim)``.
        verbose: Unused in this function.
        seed: Seed of the ``RandomState`` that draws the initial population.
        num_learners: Number of columns updated per iteration.
        improvement_pct_threshold: Relative-improvement threshold below which a
            new random column is appended.
        lr: Step size of the column update.
        full: Unused in this function.
        loss_func: ``"br"`` or ``"urr"``; see the note on ``"ros"`` below.

    Returns:
        ``(pop, exps, l_cards)``: the final column population and the
        per-iteration exploitability and cardinality values (each including the
        initial one).

    Raises:
        ValueError: If ``loss_func`` is neither ``"br"`` nor ``"urr"`` and at
            least one learner update is executed.
    """
    print("start epsro")
    rdim, cdim = payoffs.shape

    r = np.random.RandomState(seed)
    # Only a column population is kept; each column is normalised to sum to one.
    cpop = r.uniform(0, 1, (cdim, 1 + num_learners))
    cpop = cpop / cpop.sum(axis=0)[None, :]
    pop = cpop
    exp = get_exploitability(pop, payoffs, iters=1000, symmetric=False)
    exps = [exp]

    # Cardinality score trace(I - (L + I)^-1) with L = M M^T.
    # NOTE(review): `pop.T @ payoffs` multiplies (n, cdim) by (rdim, cdim), so
    # this only works when rdim == cdim — confirm payoffs is always square here.
    M = pop.T @ payoffs @ pop
    L = M @ M.T
    l_card = np.trace(np.eye(L.shape[0]) - np.linalg.inv(L + np.eye(L.shape[0])))
    l_cards = [l_card]

    # One performance history per column; 0.1 is the starting reference value.
    learner_performances = [[0.1] for i in range(num_learners + 1)]
    # NOTE(review): these lists are created for loss_func == "ros" but never
    # read below, and the learner loop raises ValueError for "ros".
    if loss_func == "ros":
        rpopulation_strategies = [[] for _ in range(num_learners)]
        cpopulation_strategies = [[] for _ in range(num_learners)]

    for i in range(iters):
        # Define the weighting towards diversity as a function of the fixed population size, this is currently a hyperparameter
        # (not read anywhere else in this function)
        lambda_weight = 0.85
        if i % 5 == 0:
            print(
                "iteration: ", i, " exp full: ", exps[-1], "cardinality: ", l_cards[-1]
            )

        for j in range(num_learners):
            # j = 0 updates the last column against all columns before it; the
            # learner with j = num_learners - 1 faces the smallest fixed population.
            k = pop.shape[1] - j - 1
            # Game restricted to the first k columns; rows stay unrestricted.
            rurrgame_matrix = payoffs @ pop[:, :k]

            # rmeta_nash[0] stacks row averages as rows, rmeta_nash[1] stacks
            # averages over the k columns as columns (see fictitious_play).
            rmeta_nash, _ = fictitious_play(
                payoffs=rurrgame_matrix, iters=1000, symmetric=False
            )

            # Column mixture of the full game induced by the last average over
            # the k fixed columns.
            cpopulation_strategy = pop[:, :k] @ rmeta_nash[1][:, -1]

            if loss_func == "br":
                # standard PSRO
                rbr, cbr = get_br_to_strat(
                    [rmeta_nash[0][-1], cpopulation_strategy],
                    payoffs=payoffs,
                    symmetric=False,
                )
            elif loss_func == "urr":
                # Use the last row average of the restricted game directly.
                cbr = rmeta_nash[0][-1]
            else:
                raise ValueError("Unknow loss func")
            # Update the mixed strategy towards the pure strategy which is returned as the best response to the
            # nash equilibrium that is being trained against.
            # NOTE(review): in both branches `cbr` has one entry per ROW of
            # payoffs while `pop[:, k]` has one entry per COLUMN, so this
            # assignment needs rdim == cdim — confirm.
            pop[:, k] = lr * cbr + (1 - lr) * pop[:, k]
            # The +1 shift keeps the value positive for the ratio below.
            performance = rmeta_nash[0][-1] @ payoffs @ pop[:, k] + 1
            learner_performances[k].append(performance)

            # if the first learner plateaus, add a new policy to the population
            if (
                j == num_learners - 1
                and performance / learner_performances[k][-2] - 1
                < improvement_pct_threshold
            ):
                # The new column comes from NumPy's global generator, not from `r`.
                clearner = np.random.uniform(0, 1, (cdim, 1))
                # sum(axis=0) has shape (1,), so [:, None] yields a (1, 1)
                # divisor that broadcasts over the single column.
                clearner = clearner / clearner.sum(axis=0)[:, None]
                pop = np.hstack((pop, clearner))
                learner_performances.append([0.1])

        # calculate exploitability for meta Nash of whole population
        exp = get_exploitability(pop, payoffs, iters=1000, symmetric=False)
        exps.append(exp)

        M = pop.T @ payoffs @ pop
        L = M @ M.T
        l_card = np.trace(np.eye(L.shape[0]) - np.linalg.inv(L + np.eye(L.shape[0])))
        l_cards.append(l_card)
    print("end epsro")

    return pop, exps, l_cards


# Define the PSRO rectified nash algorithm
def psro_rectified_steps(
    iters=10,
    payoffs=None,
    verbose=False,
    eps=1e-2,
    seed=0,
    num_start_strats=1,
    num_pseudo_learners=4,
    lr=0.3,
    threshold=0.001,
):
    """Run PSRO with a rectified meta-Nash and record exploitability/cardinality.

    In each sweep, every population member whose fictitious-play weight exceeds
    ``eps`` spawns a new random learner. That learner is trained against the
    weighted mixture of the members the spawning policy does not lose to, until
    its relative improvement falls to ``threshold`` or below. A metric sample is
    recorded every ``num_pseudo_learners`` training steps.

    Args:
        iters: Budget in "iterations"; the loop runs while the step counter is
            below ``iters * num_pseudo_learners``. Also passed as the iteration
            count of the per-sweep ``fictitious_play`` call.
        payoffs: Square payoff matrix of the full game.
        verbose: Unused in this function.
        eps: Minimum meta-Nash weight for a policy to spawn a learner.
        seed: Seed of the ``RandomState`` that draws the initial population.
        num_start_strats: Number of initial random strategies.
        num_pseudo_learners: Number of training steps counted as one iteration.
        lr: Step size of the learner update.
        threshold: Relative-improvement threshold that ends a learner's training.

    Returns:
        ``(pop, exps, l_cards)``: the population and the recorded exploitability
        and cardinality values (each including the initial one).
    """
    print("start rectified")
    dim = payoffs.shape[0]
    r = np.random.RandomState(seed)
    pop = r.uniform(0, 1, (num_start_strats, dim))
    pop = pop / pop.sum(axis=1)[:, None]
    exp = get_exploitability(pop, payoffs, iters=1000)
    exps = [exp]
    counter = 0  # total number of learner training steps so far

    # Cardinality score trace(I - (L + I)^-1) with L = M M^T.
    M = pop @ payoffs @ pop.T
    L = M @ M.T
    l_card = np.trace(np.eye(L.shape[0]) - np.linalg.inv(L + np.eye(L.shape[0])))
    l_cards = [l_card]

    while counter < iters * num_pseudo_learners:
        # `pop` stays fixed for the whole sweep; new learners go into `new_pop`.
        new_pop = np.copy(pop)
        emp_game_matrix = pop @ payoffs @ pop.T
        # NOTE(review): fictitious play runs for `iters` iterations here, while
        # the other call sites in this file use 1000 — confirm this is intended.
        averages, _ = fictitious_play(payoffs=emp_game_matrix, iters=iters)

        # go through all policies. If the policy has positive meta Nash mass,
        # find policies it wins against, and play against meta Nash weighted mixture of those policies
        for j in range(pop.shape[0]):
            # Budget exhausted mid-sweep: return `pop`, i.e. without the
            # learners added to `new_pop` during this sweep.
            if counter > iters * num_pseudo_learners:
                return pop, exps, l_cards
            # if positive mass, add a new learner to pop and update it with steps, submit if over thresh
            # keep track of counter
            if averages[-1][j] > eps:
                # create learner (drawn from NumPy's global generator, not `r`)
                learner = np.random.uniform(0, 1, (1, dim))
                learner = learner / learner.sum(axis=1)[:, None]
                new_pop = np.vstack((new_pop, learner))
                idx = new_pop.shape[0] - 1

                # Starting values chosen so that the loop body runs at least once.
                current_performance = 0.02
                last_performance = 0.01
                while current_performance / last_performance - 1 > threshold:
                    counter += 1
                    # NOTE(review): basic slicing of an ndarray returns a view, so
                    # the two assignments below overwrite row j of
                    # emp_game_matrix with 0/1 in place.
                    mask = emp_game_matrix[j, :]
                    mask[mask >= 0] = 1
                    mask[mask < 0] = 0
                    # Keep only the meta-Nash weight of policies that policy j
                    # does not lose to, then renormalise.
                    weights = np.multiply(mask, averages[-1])
                    weights /= np.maximum(1e-4, weights.sum())  # avoid zero divide.
                    strat = weights @ pop
                    br = get_br_to_strat(strat, payoffs=payoffs)
                    new_pop[idx] = lr * br + (1 - lr) * new_pop[idx]
                    last_performance = current_performance
                    # The +1 shift keeps the value positive for the ratio above.
                    current_performance = new_pop[idx] @ payoffs @ strat + 1

                    if counter % num_pseudo_learners == 0:
                        # count this as an 'iteration'

                        # exploitability
                        exp = get_exploitability(new_pop, payoffs, iters=1000)
                        exps.append(exp)

                        # NOTE(review): exploitability above uses `new_pop` but
                        # the cardinality below uses `pop` — confirm whether
                        # both should measure the same population.
                        M = pop @ payoffs @ pop.T
                        L = M @ M.T
                        l_card = np.trace(
                            np.eye(L.shape[0]) - np.linalg.inv(L + np.eye(L.shape[0]))
                        )
                        l_cards.append(l_card)

        # Commit the learners trained during this sweep.
        pop = np.copy(new_pop)
    print("end rectified")

    return pop, exps, l_cards


def iterative_steps(
    iters=5,
    payoffs=None,
    verbose=False,
    seed=0,
    num_learners=4,
    improvement_pct_threshold=0.03,
    lr=0.8,
    loss_func="br",
):
    """Run the "iterative" variant that keeps a response population beside ``pop``.

    Two populations of ``1 + num_learners`` random strategies are drawn from
    ``RandomState(seed)`` (``br_pop`` first, then ``pop``). For each learner
    ``k`` the row ``br_pop[k]`` is moved (step ``lr``) towards the strategy
    chosen by ``loss_func`` against ``pop[k - 1]``, and ``pop[k]`` is replaced
    by the meta-Nash mixture of ``br_pop[:k]``. When the learner handled last
    in the sweep plateaus, one new random row is appended to each population.

    Args:
        iters: Number of outer iterations.
        payoffs: Square payoff matrix of the full game.
        verbose: Unused.
        seed: Seed for the initial populations.
        num_learners: Number of rows updated per iteration.
        improvement_pct_threshold: Relative-improvement threshold for growing
            the populations.
        lr: Step size of the ``br_pop`` update.
        loss_func: ``"br"`` for the plain best response or ``"dpp"`` for
            ``joint_loss``.

    Returns:
        ``(pop, exps, l_cards)``: the final population and the per-iteration
        exploitability and cardinality values (each including the initial one).

    Raises:
        ValueError: If ``loss_func`` is neither ``"br"`` nor ``"dpp"``.
    """

    def _cardinality(population):
        # trace(I - (L + I)^-1) with L = M M^T and M the empirical payoff matrix.
        M = population @ payoffs @ population.T
        L = M @ M.T
        identity = np.eye(L.shape[0])
        return np.trace(identity - np.linalg.inv(L + identity))

    def _random_row():
        # One normalised random strategy from NumPy's global generator.
        row = np.random.uniform(0, 1, (1, dim))
        return row / row.sum(axis=1, keepdims=True)

    print("start iterative")
    dim = payoffs.shape[0]
    rng = np.random.RandomState(seed)
    br_pop = rng.uniform(0, 1, (1 + num_learners, dim))
    br_pop = br_pop / br_pop.sum(axis=1, keepdims=True)
    pop = rng.uniform(0, 1, (1 + num_learners, dim))
    pop = pop / pop.sum(axis=1, keepdims=True)
    exps = [get_exploitability(pop, payoffs, iters=1000)]
    l_cards = [_cardinality(pop)]

    learner_performances = [[0.1] for _ in range(num_learners + 1)]
    for iteration in range(iters):
        # Weighting towards diversity; currently a fixed hyperparameter.
        lambda_weight = 0.85
        if iteration % 10 == 0:
            print("iteration: ", iteration, " exp full: ", exps[-1])
            print("size of pop: ", pop.shape[0])

        for j in range(num_learners):
            k = pop.shape[0] - j - 1
            fixed = pop[:k]
            meta_nash, _ = fictitious_play(
                payoffs=fixed @ payoffs @ fixed.T, iters=1000
            )
            # The opponent is the single strategy just before the learner.
            population_strategy = pop[k - 1].copy()

            if loss_func == "br":
                br = get_br_to_strat(population_strategy, payoffs=payoffs)
            elif loss_func == "dpp":
                br = joint_loss(pop, payoffs, meta_nash[-1], k, lambda_weight, lr)
                # Plain best response is still evaluated here as in the existing
                # flow; its value is not used afterwards.
                plain_br = get_br_to_strat(population_strategy, payoffs=payoffs)
            else:
                raise ValueError("Unknow loss func")

            br_pop[k] = lr * br + (1 - lr) * br_pop[k]
            # Mixed oracle: meta-Nash mixture over the responses before k.
            pop[k] = meta_nash[-1] @ br_pop[:k]
            # The +1 shift keeps the value positive for the ratio below.
            performance = pop[k] @ payoffs @ population_strategy.T + 1
            track = learner_performances[k]
            track.append(performance)

            if j != num_learners - 1:
                continue
            # Plateau of the learner handled last: grow both populations
            # (the row for `pop` is drawn first, then the one for `br_pop`).
            if performance / track[-2] - 1 < improvement_pct_threshold:
                pop = np.vstack((pop, _random_row()))
                br_pop = np.vstack((br_pop, _random_row()))
                learner_performances.append([0.1])

        # Metrics for the whole population after this iteration.
        exps.append(get_exploitability(pop, payoffs, iters=1000))
        l_cards.append(_cardinality(pop))
    print("end iteartive")

    return pop, exps, l_cards


def run_experiment(param_seed):
    """Run every enabled algorithm once for one seed and collect its metrics.

    Args:
        param_seed: Pair ``(params, seed)``. ``params`` must contain the keys
            ``iters``, ``num_threads``, ``lr``, ``thresh``, ``verbose`` and one
            truthy/falsy switch per algorithm (``psro``, ``pipeline_psro``,
            ``dpp_psro``, ``rectified``, ``self_play``, ``iterative``,
            ``epsro``, ``pepsro``).

    Returns:
        A dict with ``<algorithm>_exps`` and ``<algorithm>_cardinality`` entries
        for all eight algorithms; entries of disabled algorithms are empty lists.
    """
    params, seed = param_seed
    iters = params["iters"]
    num_threads = params["num_threads"]
    lr = params["lr"]
    thresh = params["thresh"]
    enabled = {
        name: params[name]
        for name in (
            "psro",
            "pipeline_psro",
            "dpp_psro",
            "rectified",
            "self_play",
            "iterative",
            "epsro",
            "pepsro",
        )
    }
    verbose = params["verbose"]

    # Result keys in the order callers receive them.
    results = {}
    for name in (
        "psro",
        "pipeline_psro",
        "dpp_psro",
        "rectified",
        "self_play",
        "epsro",
        "pepsro",
        "iterative",
    ):
        results[name + "_exps"] = []
        results[name + "_cardinality"] = []

    print("Experiment: ", seed + 1)
    np.random.seed(seed)
    payoffs_data_dir = os.path.expanduser("~/dataset/expground/")
    # pickle can execute code while loading; only point this at trusted files.
    with open(payoffs_data_dir + str(args.game_name) + ".pkl", "rb") as fh:
        payoffs = pickle.load(fh)

    run_seed = seed + 1

    # Algorithms in execution order; the order matters because several of them
    # draw from NumPy's global generator.
    schedule = (
        (
            "psro",
            lambda: psro_steps(
                iters=iters,
                num_learners=1,
                seed=run_seed,
                improvement_pct_threshold=thresh,
                lr=lr,
                payoffs=payoffs,
                loss_func="br",
            ),
        ),
        (
            "pipeline_psro",
            lambda: psro_steps(
                iters=iters,
                num_learners=num_threads,
                seed=run_seed,
                improvement_pct_threshold=thresh,
                lr=lr,
                payoffs=payoffs,
                loss_func="br",
            ),
        ),
        (
            "dpp_psro",
            lambda: psro_steps(
                iters=iters,
                num_learners=num_threads,
                seed=run_seed,
                improvement_pct_threshold=thresh,
                lr=lr,
                payoffs=payoffs,
                loss_func="dpp",
            ),
        ),
        (
            "pepsro",
            lambda: psro_efficient_steps(
                iters=iters,
                num_learners=num_threads,
                seed=run_seed,
                improvement_pct_threshold=thresh,
                lr=0.8,
                payoffs=payoffs,
                verbose=verbose,
            ),
        ),
        (
            "epsro",
            lambda: psro_efficient_steps(
                iters=iters,
                num_learners=1,
                seed=run_seed,
                improvement_pct_threshold=thresh,
                lr=0.8,
                payoffs=payoffs,
                verbose=verbose,
            ),
        ),
        (
            "rectified",
            lambda: psro_rectified_steps(
                iters=iters,
                num_pseudo_learners=num_threads,
                payoffs=payoffs,
                seed=run_seed,
                lr=lr,
                threshold=thresh,
            ),
        ),
        (
            "self_play",
            lambda: self_play_steps(
                iters=iters,
                payoffs=payoffs,
                improvement_pct_threshold=thresh,
                lr=lr,
                seed=run_seed,
            ),
        ),
        (
            "iterative",
            lambda: iterative_steps(
                iters=iters,
                num_learners=1,
                seed=run_seed,
                improvement_pct_threshold=thresh,
                lr=lr,
                payoffs=payoffs,
            ),
        ),
    )

    for name, launch in schedule:
        if not enabled[name]:
            continue
        # The final population is not part of the reported results.
        _, exps, cards = launch()
        results[name + "_exps"] = exps
        results[name + "_cardinality"] = cards

    return results


def _plot_error(data, label="", alpha=0.4):
    """Plot the across-run mean of ``data`` with a shaded standard-error band.

    Args:
        data: one sequence of values per experiment run.
        label: legend label for the curve.
        alpha: opacity of the shaded band around the mean.

    Runs are cut to the length of the shortest one before stacking, so runs
    of different lengths still form a rectangular array. An empty ``data``
    is skipped rather than plotted.
    """
    if len(data) == 0:
        return
    length = min(len(run) for run in data)
    stacked = np.array([run[:length] for run in data])
    data_mean = np.mean(stacked, axis=0)
    error_bars = stats.sem(stacked)
    plt.plot(data_mean, label=label)
    plt.fill_between(
        np.arange(data_mean.size),
        np.squeeze(data_mean - error_bars),
        np.squeeze(data_mean + error_bars),
        alpha=alpha,
    )


def run_experiments(
    num_experiments=2,
    iters=40,
    num_threads=20,
    lr=0.6,
    thresh=0.001,
    logscale=True,
    psro=False,
    pipeline_psro=False,
    rectified=False,
    self_play=False,
    dpp_psro=False,
    epsro=False,
    pepsro=False,
    iterative=False,
    verbose=False,
):
    """Run (or reload) a batch of experiments and save summary figures.

    The parameters are written to ``params.json`` under ``PATH_RESULTS``.
    Then, depending on the command-line flags in ``args``:

    * ``args.plot`` set: previously pickled results are read from
      ``DATASET/<game_name>/data.p`` for every enabled algorithm.
    * otherwise: ``run_experiment`` is called ``num_experiments`` times
      (in a ``multiprocessing`` pool unless ``args.mp`` is false) and the
      collected results are pickled to ``PATH_RESULTS/data.p``.

    Finally three PDF figures are written to ``PATH_RESULTS``: exploitability
    (log y-axis when ``logscale``), cardinality, and exploitability on a
    linear axis.

    Args:
        num_experiments: number of independent runs; the run index is passed
            to ``run_experiment`` together with the parameters.
        iters, num_threads, lr, thresh, verbose: forwarded unchanged to
            ``run_experiment`` through the parameter dict. ``lr`` is also
            embedded in the figure file names.
        logscale: use a logarithmic y-axis on the first figure.
        psro, pipeline_psro, rectified, self_play, dpp_psro, epsro, pepsro,
            iterative: enable the corresponding algorithm; forwarded to
            ``run_experiment`` and used to select which curves are
            loaded/plotted.

    Returns:
        None.
    """
    params = {
        "num_experiments": num_experiments,
        "iters": iters,
        "num_threads": num_threads,
        "lr": lr,
        "thresh": thresh,
        "psro": psro,
        "pipeline_psro": pipeline_psro,
        "dpp_psro": dpp_psro,
        "rectified": rectified,
        "self_play": self_play,
        "epsro": epsro,
        "pepsro": pepsro,
        "iterative": iterative,
        "verbose": verbose,
    }

    # (result-key prefix, legend label, enabled), listed in plotting order so
    # that each algorithm keeps the same position in matplotlib's colour cycle.
    algorithms = (
        ("psro", "PSRO", psro),
        ("pipeline_psro", "P-PSRO", pipeline_psro),
        ("rectified", "PSRO-rN", rectified),
        ("self_play", "Self-play", self_play),
        ("dpp_psro", "DPPSRO", dpp_psro),
        ("epsro", "NEPSRO", epsro),
        ("pepsro", "EPSRO", pepsro),
        ("iterative", "Mixed-Oracles", iterative),
    )

    # Key order of the pickled result dict (kept stable for existing data.p
    # consumers).
    storage_order = (
        "psro",
        "pipeline_psro",
        "dpp_psro",
        "rectified",
        "self_play",
        "epsro",
        "pepsro",
        "iterative",
    )
    series = {}
    for name in storage_order:
        series[name + "_exps"] = []
        series[name + "_cardinality"] = []

    with open(
        os.path.join(PATH_RESULTS, "params.json"), "w", encoding="utf-8"
    ) as json_file:
        json.dump(params, json_file, indent=4)

    if args.plot:
        data_file = os.path.join(DATASET, args.game_name, "data.p")
        # NOTE: unpickling can execute arbitrary code; only point this at
        # result files produced by this script.
        with open(data_file, "rb") as f:
            data = pickle.load(f)
        for name, _, enabled in algorithms:
            if enabled:
                for key in (name + "_exps", name + "_cardinality"):
                    series[key] = data[key]
    else:
        jobs = [(params, i) for i in range(num_experiments)]
        if not args.mp:
            result = [run_experiment(job) for job in jobs]
        else:
            # The context manager tears the worker processes down once map()
            # has returned every result (or raised).
            with mp.Pool() as pool:
                result = pool.map(run_experiment, jobs)

        for r in result:
            for key in series:
                series[key].append(r[key])

        with open(os.path.join(PATH_RESULTS, "data.p"), "wb") as f:
            pickle.dump(series, f)

    # (metric suffix, file-name template, log-scaled y-axis?)
    figures = (
        ("exps", "Exploitability Log ({})", logscale),
        ("cardinality", "Cardinality ({})", False),
        ("exps", "Exploitability Standard ({})", False),
    )
    for metric, name_template, use_log in figures:
        fig = plt.figure()

        for name, label, enabled in algorithms:
            if enabled:
                _plot_error(series[name + "_" + metric], label=label, alpha=0.4)

        plt.legend(loc="upper left")
        plt.title(args.game_name)

        if use_log:
            plt.yscale("log")

        string = name_template.format(lr)
        plt.savefig(os.path.join(PATH_RESULTS, "figure_" + string + ".pdf"))
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)


if __name__ == "__main__":
    # Script entry point: launch the full sweep with the command-line
    # settings. Timestamps are taken on either side of the sweep; the elapsed
    # time (end_time - start_time) is currently not reported.
    start_time = time.time()
    run_experiments(
        # Sweep size and optimisation settings.
        num_experiments=args.nb_exps,
        iters=args.nb_iters,
        num_threads=2,
        lr=args.lr,
        thresh=TH,
        verbose=True,
        # Algorithms included in this sweep.
        psro=True,
        pipeline_psro=True,
        rectified=True,
        self_play=True,
        dpp_psro=False,
        epsro=True,
        pepsro=True,
        iterative=True,
    )
    end_time = time.time()
