from glob import escape
from sre_parse import FLAGS
from statistics import mean
import numpy as np
import matplotlib.pyplot as plt
import time
import os
from scipy import stats
import pickle
import json
from numpy.random import RandomState
import argparse
import multiprocessing as mp
from tqdm import tqdm
from scipy.spatial.distance import cdist
from scipy.spatial.distance import pdist, squareform
from sklearn.random_projection import johnson_lindenstrauss_min_dim, SparseRandomProjection
import cvxpy as cp


# Print numpy arrays without scientific notation and with 3 decimals.
np.set_printoptions(suppress=True)
np.set_printoptions(precision=3)

# Seed the global numpy RNG at import time.
np.random.seed(0)

# Command-line interface. Note that parse_args() runs at import time, so
# importing this module consumes sys.argv.
parser = argparse.ArgumentParser(description='Random Game Skill')
parser.add_argument('--nb_iters', type=int, default=200)
parser.add_argument('--nb_exps', type=int, default=5)
# NOTE(review): with default=False and action='store_false', args.mp is False
# whether or not --mp is passed, which contradicts the help text ("leave it
# for True"). Confirm the intended default before relying on this flag.
parser.add_argument('--mp', default=False, action='store_false', help='Set --mp for False, otherwise leave it for True')
# Name of the game; it is used below in the results directory name.
parser.add_argument('--game_name', type=str, default='go(board_size=4,komi=6.5)')

args = parser.parse_args()

# Experiment hyper-parameters. Only LR and TM are used in this section (in the
# results path); the others are presumably forwarded to the solvers as
# lr / thresh / term / window / mu / div -- TODO confirm where `params` is built.
LR = 0.5
TH = 0.03
TM = 0.8
WS = 40
MU = 0.02
DIV = 0.01
expected_card = []
sizes = []

# Results go to results/<timestamp>_<game>_<LR>_<TM>; created at import time.
time_string = time.strftime("%Y%m%d-%H%M%S")
PATH_RESULTS = os.path.join('results', time_string + '_' + str(args.game_name) + '_' + str(LR)+ '_' + str(TM))
if not os.path.exists(PATH_RESULTS):
    os.makedirs(PATH_RESULTS)


# Scan the pure strategies for the best response to a given strategy
def get_br_to_strat(strat, payoffs=None, verbose=False):
    """Return a one-hot vector marking the minimiser of ``strat @ payoffs``.

    Args:
        strat: strategy vector multiplied on the left of ``payoffs``.
        payoffs: payoff matrix.
        verbose: when true, print the minimal entry of ``strat @ payoffs``.

    Returns:
        Array shaped like ``strat @ payoffs`` with a single 1 at the argmin
        (first index on ties) and 0 elsewhere.
    """
    opponent_payouts = strat @ payoffs
    best_idx = np.argmin(opponent_payouts)
    if verbose:
        print(opponent_payouts[best_idx], "exploitability")
    best_response = np.zeros_like(opponent_payouts)
    best_response[best_idx] = 1
    return best_response


# Fictitious play as a Nash equilibrium solver
def fictitious_play(iters=2000, payoffs=None, verbose=False):
    """Run fictitious play on a square payoff matrix.

    Starts from one random normalised strategy (drawn from the global numpy
    RNG) and, at every step, appends the best response to the average of all
    strategies played so far.

    Args:
        iters: number of best-response steps.
        payoffs: payoff matrix; its first dimension is the number of actions.
        verbose: accepted but not used.

    Returns:
        ``(averages, exps)``: ``averages`` stacks the initial strategy followed
        by the average strategy seen at each step; ``exps`` holds, per step,
        ``br @ payoffs @ average - average @ payoffs @ br``.
    """
    n_actions = payoffs.shape[0]
    start = np.random.uniform(0, 1, (1, n_actions))
    history = start / start.sum(axis=1)[:, None]

    running_means = [history]
    gaps = []
    for _ in range(iters):
        mean_strategy = np.average(history, axis=0)
        response = get_br_to_strat(mean_strategy, payoffs=payoffs)
        value_of_mean = mean_strategy @ payoffs @ response.T
        value_of_response = response @ payoffs @ mean_strategy.T
        gaps.append(value_of_response - value_of_mean)
        running_means.append(mean_strategy)
        history = np.vstack((history, response))
    return np.vstack(running_means), gaps


# Exploitability of the meta-Nash mixture over a fixed population
def get_exploitability(pop, payoffs, iters=1000):
    """Return the exploitability of the population's meta-strategy.

    Builds the empirical game ``pop @ payoffs @ pop.T``, solves it with
    ``fictitious_play``, mixes the population with the last averaged
    meta-strategy and measures how much a best response gains against it.

    Args:
        pop: population matrix, one strategy per row.
        payoffs: payoff matrix of the underlying game.
        iters: fictitious-play iterations for the empirical game.

    Returns:
        ``br @ payoffs @ mixture - mixture @ payoffs @ br``.
    """
    meta_game = pop @ payoffs @ pop.T
    meta_strategies, _ = fictitious_play(payoffs=meta_game, iters=iters)
    mixture = meta_strategies[-1] @ pop  # aggregate the population
    best_reply = get_br_to_strat(mixture, payoffs=payoffs)
    mixture_value = mixture @ payoffs @ best_reply.T
    reply_value = best_reply @ payoffs @ mixture
    return reply_value - mixture_value

def pop_effective_diversity(pop, payoff, iters):
    """Approximate the value of the game ``pop @ payoff`` by fictitious play.

    The row player mixes over population members and the column player over
    the columns of ``payoff``. Both sides start from a random normalised
    strategy (global numpy RNG) and repeatedly add a best response to the
    other side's running average.

    Args:
        pop: population matrix, one strategy per row.
        payoff: payoff matrix of the underlying game.
        iters: number of fictitious-play steps.

    Returns:
        ``-row_avg @ (pop @ payoff) @ column_avg`` for the final averages.
        The final row average is also printed.
    """
    game = pop @ payoff
    negated_transpose = -game.T

    row_history = np.random.uniform(0, 1, (1, game.shape[0]))
    row_history = row_history / row_history.sum()
    col_history = np.random.uniform(0, 1, (1, game.shape[1]))
    col_history = col_history / col_history.sum()

    for _ in range(iters):
        row_mean = np.average(row_history, axis=0)
        col_mean = np.average(col_history, axis=0)
        col_reply = get_br_to_strat(row_mean, game)
        row_reply = get_br_to_strat(col_mean, negated_transpose)
        row_history = np.vstack((row_history, row_reply))
        col_history = np.vstack((col_history, col_reply))

    row_avg = np.average(row_history, axis=0)
    column_avg = np.average(col_history, axis=0)
    print(f"Nash is {row_avg}")
    return -row_avg @ game @ column_avg.T

def distance_loss(pop, payoffs, meta_nash, k, lambda_weight, lr):
    """Pick a pure response by payoff or by distance to the empirical game.

    One standard-normal draw decides the branch: if it is below
    ``lambda_weight`` the payoff-maximising action against the meta-Nash
    mixture of ``pop[:k]`` is chosen; otherwise every action is tried as an
    update of ``pop[k]`` and the one with the largest ``distance_solver``
    score is chosen.

    Args:
        pop: population matrix; rows ``[:k]`` are fixed, row ``k`` is the learner.
        payoffs: payoff matrix of the underlying game.
        meta_nash: meta-strategy over ``pop[:k]``.
        k: index of the learner row.
        lambda_weight: threshold compared against ``np.random.randn()``.
        lr: step size used to form the candidate learner.

    Returns:
        One-hot vector of length ``payoffs.shape[0]``.
    """
    dim = payoffs.shape[0]
    br = np.zeros((dim,))

    if np.random.randn() < lambda_weight:
        opponent_mix = meta_nash @ pop[:k]
        br[np.argmax(payoffs @ opponent_mix.T)] = 1
        return br

    fixed_members = pop[:k]
    anchor = (1 - lr) * pop[k]
    scores = []
    for action in range(dim):
        pure = np.zeros((dim,))
        pure[action] = 1.
        candidate = lr * pure + anchor
        trial_pop = np.vstack((fixed_members, candidate))
        M = trial_pop @ payoffs @ fixed_members.T
        # Existing members' payoff rows versus the candidate's payoff row.
        scores.append(distance_solver(M[0:-1].T, M[-1].reshape(-1, 1)))
    br[np.argmax(scores)] = 1
    return br

def distance_solver(A, b):
    """Score how far column vector ``b`` lies from the columns of ``A``.

    The score is the sum of two terms:
      * ``sigma_min(A)^2 / n_cols * (1 - 1^T A^+ b)^2``, and
      * the squared norm of ``(I - A A^+) b``, i.e. the part of ``b`` outside
        the column space of ``A``.

    Args:
        A: 2-D array.
        b: column vector with ``A.shape[0]`` rows.

    Returns:
        The scalar score.
    """
    n_rows, n_cols = A.shape[0], A.shape[1]
    pinv_A = np.linalg.pinv(A)
    off_span_projector = np.identity(n_rows) - A @ pinv_A
    smallest_sv = min(np.linalg.svd(A.T, full_matrices=True)[1])

    ones_col = np.ones(shape=(n_cols, 1))
    affine_gap = 1 - (ones_col.T @ pinv_A @ b)[0, 0]
    affine_term = ((smallest_sv ** 2) / n_cols) * (affine_gap ** 2)
    off_span_term = np.square(off_span_projector @ b).sum()
    return affine_term + off_span_term


def joint_loss(pop, payoffs, meta_nash, k, lambda_weight, lr):
    """Pick a pure response by payoff or by expected cardinality.

    One standard-normal draw decides the branch: if it is below
    ``lambda_weight`` the payoff-maximising action against the meta-Nash
    mixture of ``pop[:k]`` is chosen; otherwise each action is tried as an
    update of ``pop[k]`` and the action maximising
    ``trace(I - (L + I)^-1)`` with ``L = M M^T`` is chosen, where ``M`` is the
    empirical game of the trial population.

    Args:
        pop: population matrix; rows ``[:k]`` are fixed, row ``k`` is the learner.
        payoffs: payoff matrix of the underlying game.
        meta_nash: meta-strategy over ``pop[:k]``.
        k: index of the learner row.
        lambda_weight: threshold compared against ``np.random.randn()``.
        lr: step size used to form the candidate learner.

    Returns:
        One-hot vector of length ``payoffs.shape[0]``.
    """
    dim = payoffs.shape[0]
    br = np.zeros((dim,))

    opponent_mix = meta_nash @ pop[:k]
    values = payoffs @ opponent_mix.T

    if np.random.randn() < lambda_weight:
        br[np.argmax(values)] = 1
        return br

    anchor = (1 - lr) * pop[k]
    cardinalities = []
    for action in range(dim):
        pure = np.zeros((dim, ))
        pure[action] = 1.
        candidate = lr * pure + anchor
        trial_pop = np.vstack((pop[:k], candidate))
        M = trial_pop @ payoffs @ trial_pop.T
        gram = M @ M.T
        identity = np.eye(gram.shape[0])
        cardinalities.append(np.trace(identity - np.linalg.inv(gram + identity)))
    br[np.argmax(cardinalities)] = 1
    return br

def js_divergence(n, target_dist):
    """Entropy-based divergence between pure action ``n`` and ``target_dist``.

    Computes ``2*H(e_n + q) - H(e_n) - H(q)`` where ``e_n`` is the one-hot
    vector for ``n``, ``q`` is ``target_dist`` and ``H`` smooths its argument
    by 1e-8 and renormalises it before taking the Shannon entropy.

    Args:
        n: index of the pure action.
        target_dist: numpy array holding the reference distribution.

    Returns:
        The scalar divergence value.
    """
    def entropy(p_k):
        smoothed = p_k + 1e-8
        smoothed = smoothed / sum(smoothed)
        return -(smoothed * np.log(smoothed)).sum()

    one_hot = np.zeros(shape=target_dist.shape)
    one_hot[n] = 1
    mixture_entropy = entropy(one_hot + target_dist)
    return 2 * mixture_entropy - entropy(one_hot) - entropy(target_dist)

def kl_divergence(prob_a, prob_b):
    """Smoothed KL divergence ``KL(prob_a || row)`` for every row of ``prob_b``.

    Both inputs are smoothed by adding 1e-3 to every entry and renormalised
    (``prob_a`` by its total sum, ``prob_b`` row by row). The inputs are
    copied first, so the caller's arrays are never modified; this matters
    because callers reuse the same reference set across many calls.

    Args:
        prob_a: distribution, 1-D or shape ``(1, d)``.
        prob_b: one distribution (1-D) or several stacked as rows ``(m, d)``.

    Returns:
        1-D array with one divergence per row of ``prob_b``.
    """
    # np.array(..., ndmin=2) copies and turns a 1-D vector into a single row.
    prob_a = np.array(prob_a, dtype=np.float64, ndmin=2)
    prob_b = np.array(prob_b, dtype=np.float64, ndmin=2)

    prob_a = prob_a + 1e-3
    prob_a = prob_a / prob_a.sum()
    prob_b = prob_b + 1e-3
    prob_b = prob_b / prob_b.sum(axis=1, keepdims=True)

    return (prob_a * np.log(prob_a / prob_b)).sum(axis=1)

def divergence_loss(pop, payoffs, meta_nash, k, lambda_weight, lr, i):
    """Pick a pure response by payoff, optionally adding a divergence bonus.

    One standard-normal draw decides the branch: if it is below
    ``lambda_weight`` the payoff-maximising action against the meta-Nash
    mixture of ``pop[:k]`` is chosen; otherwise the action maximising
    ``payoff + alpha * js_divergence(action, mixture)`` is chosen. ``alpha``
    is 500 while ``i <= 75``, 100 while ``i <= 150`` and 50 afterwards.

    Args:
        pop: population matrix; rows ``[:k]`` form the opponent mixture.
        payoffs: payoff matrix of the underlying game.
        meta_nash: meta-strategy over ``pop[:k]``.
        k: number of fixed population rows.
        lambda_weight: threshold compared against ``np.random.randn()``.
        lr: accepted but not used.
        i: outer iteration index, used only for the ``alpha`` schedule.

    Returns:
        One-hot vector of length ``payoffs.shape[0]``.
    """
    dim = payoffs.shape[0]
    br = np.zeros((dim,))
    alpha = 500 if i <= 75 else (100 if i <= 150 else 50)

    use_plain_br = np.random.randn() < lambda_weight
    opponent_mix = meta_nash @ pop[:k]
    values = payoffs @ opponent_mix.T

    if use_plain_br:
        br[np.argmax(values)] = 1
        return br

    flat_mix = opponent_mix.reshape(-1)
    bonus_scores = [
        values[action] + alpha * js_divergence(action, flat_mix)
        for action in range(len(flat_mix))
    ]
    br[np.argmax(bonus_scores)] = 1
    return br


def fsp_non_symmetric_game(emp_game_matrix, iters=2000):
    """Approximate the absolute value of a (non-square) zero-sum game.

    Row and column averages start from random normalised vectors (global
    numpy RNG). Each step computes both best responses against the current
    averages and folds them into the running averages incrementally.

    Args:
        emp_game_matrix: 2-D payoff matrix.
        iters: number of fictitious-play steps.

    Returns:
        ``abs(row_avg @ emp_game_matrix @ column_avg)`` for the final averages.
    """
    n_rows = emp_game_matrix.shape[0]
    n_cols = emp_game_matrix.shape[1]
    negated_transpose = -emp_game_matrix.T

    row_avg = np.random.uniform(0, 1, n_rows)
    row_avg = row_avg / row_avg.sum()
    column_avg = np.random.uniform(0, 1, n_cols)
    column_avg = column_avg / column_avg.sum()

    for step in range(1, iters + 1):
        # Both responses are computed before either average is updated.
        col_reply = get_br_to_strat(row_avg, emp_game_matrix)
        row_reply = get_br_to_strat(column_avg, negated_transpose)
        row_avg = (row_avg * step + row_reply) / (step + 1)
        column_avg = (column_avg * step + col_reply) / (step + 1)

    return abs(row_avg @ emp_game_matrix @ column_avg.T)


def convex_hull_kl_min(new_strategy, pool_strategies):
    """Minimum KL divergence from ``new_strategy`` to the pool's convex hull.

    Solves ``min_alpha KL(new_strategy || alpha @ pool)`` over the simplex
    with CVXPY. If the solver does not report an optimal status, falls back
    to the smallest smoothed KL divergence to any single pool member.

    Args:
        new_strategy: probability vector (any shape; it is flattened).
        pool_strategies: sequence/array of probability vectors, one per row.

    Returns:
        The minimal divergence as a float, or ``np.inf`` when the pool is
        empty or the solver returns no value.
    """
    new_strategy = new_strategy.reshape(-1).astype(np.float64)
    new_strategy = np.maximum(new_strategy, 1e-10)
    new_strategy = new_strategy / new_strategy.sum()

    pool_array = np.array(pool_strategies, dtype=np.float64)
    # Check for an empty pool before touching axis 1: an empty list becomes a
    # shape-(0,) array, for which the row normalisation below would fail.
    if pool_array.size == 0:
        return np.inf

    pool_array = np.maximum(pool_array, 1e-10)
    pool_array = pool_array / pool_array.sum(axis=1, keepdims=True)
    num_strategies = pool_array.shape[0]

    alpha = cp.Variable(num_strategies)
    q = alpha @ pool_array
    kl_div = cp.sum(new_strategy * cp.log(new_strategy) - new_strategy * cp.log(q))
    objective = cp.Minimize(kl_div)
    constraints = [
        alpha >= 1e-6,
        cp.sum(alpha) == 1.0
    ]

    prob = cp.Problem(objective, constraints)
    # NOTE(review): `feastol` / `reltol` are listed as ECOS options in the
    # CVXPY documentation; confirm that the installed SCS accepts them.
    prob.solve(
        solver=cp.SCS,
        max_iters=5000,
        feastol=1e-3,
        reltol=1e-3,
        verbose=True
    )

    if prob.status in [cp.OPTIMAL, cp.OPTIMAL_INACCURATE]:
        return objective.value if objective.value is not None else np.inf

    # Fallback: nearest single pool member. Copies are passed so the helper
    # cannot alter `new_strategy` or the pool between iterations.
    min_kl = np.inf
    for strat in pool_strategies:
        kl = kl_divergence(new_strategy.copy(), np.array(strat, dtype=np.float64))
        min_kl = min(min_kl, float(kl[0]))
    return min_kl


def sample_convex_hull(population, num_samples=100):
    """Generate sample points from the convex hull of a population.

    Each sample is a convex combination of the population rows with weights
    drawn from a flat Dirichlet distribution (global numpy RNG). The
    population rows themselves are appended after the samples.

    Args:
        population: 2-D array, one strategy per row.
        num_samples: number of random convex combinations to draw.

    Returns:
        Array of shape ``(num_samples + len(population), dim)``.
    """
    concentration = np.ones(population.shape[0])
    draws = [
        np.random.dirichlet(concentration) @ population
        for _ in range(num_samples)
    ]
    stacked = np.vstack([draws, population])
    return np.array(stacked)

def psd_update(pop, payoffs, meta_nash, k, lr, it, convx=False):
    """Update learner ``pop[k]`` in place with a payoff or diversity response.

    One standard-normal draw decides the branch. Below 0.85 the learner moves
    (step ``lr``) towards the payoff-maximising action against the meta-Nash
    mixture of ``pop[:k]``. Otherwise it moves (step 0.6) towards the action
    maximising ``payoff + 0.01 * min KL`` where the KL divergence is taken
    from the pure action to the reference set: sampled points of the convex
    hull of ``pop[:k]`` when ``convx`` is true, else the rows of ``pop[:k]``.

    Args:
        pop: population matrix; row ``k`` is modified in place.
        payoffs: payoff matrix of the underlying game.
        meta_nash: meta-strategy over ``pop[:k]``.
        k: index of the learner row.
        lr: step size for the payoff branch.
        it: accepted but not used.
        convx: use convex-hull samples as the reference set.

    Returns:
        None.
    """
    dim = payoffs.shape[0]
    lambda_weight = 0.85
    update_psd_term = 0.6

    aggregated_enemy = meta_nash @ pop[:k]
    values = payoffs @ aggregated_enemy.T
    br = np.zeros((dim,))

    if np.random.randn() < lambda_weight:
        br[np.argmax(values)] = 1
        pop[k] = lr * br + (1 - lr) * pop[k]
        return

    # Drawn in both modes so the random stream does not depend on `convx`.
    convex_hull_samples = sample_convex_hull(pop[:k])
    reference = convex_hull_samples if convx else pop[:k]

    cards = []
    for i in range(dim):
        pure = np.zeros((dim, ))
        pure[i] = 1.
        # A fresh copy of the reference set per call keeps it identical for
        # every action, whatever the divergence helper does to its arguments.
        min_kl = kl_divergence(pure, reference.copy()).min()
        cards.append(min_kl * 0.01 + values[i])
    br[np.argmax(cards)] = 1
    pop[k] = update_psd_term * br + (1 - update_psd_term) * pop[k]

def gaussian_kernel(X, Y, gamma=None):
    """Gaussian (RBF) kernel matrix between the rows of ``X`` and ``Y``.

    Args:
        X: 2-D array, one sample per row.
        Y: 2-D array, one sample per row.
        gamma: kernel coefficient; defaults to ``1 / X.shape[1]``.

    Returns:
        Matrix with entries ``exp(-gamma * ||x_i - y_j||^2)``.
    """
    bandwidth = 1.0 / X.shape[1] if gamma is None else gamma
    squared_distances = cdist(X, Y, 'sqeuclidean')
    return np.exp(-bandwidth * squared_distances)

def random_projection(X, n_components=None):
    """Reduce the feature dimension of ``X`` with a sparse random projection.

    Args:
        X: 2-D array, one sample per row.
        n_components: target dimension; when omitted it is taken from
            ``johnson_lindenstrauss_min_dim`` with ``eps=0.1``.

    Returns:
        ``X`` itself when the target dimension is not smaller than the number
        of features, otherwise the projected data (fixed ``random_state=0``).
    """
    target_dim = n_components
    if target_dim is None:
        target_dim = johnson_lindenstrauss_min_dim(n_samples=X.shape[0], eps=0.1)
    if target_dim >= X.shape[1]:
        return X
    projector = SparseRandomProjection(n_components=target_dim, random_state=0)
    return projector.fit_transform(X)

def linear_represented(new_strategy, pool_strategies, convx=False):
    """Distance from ``new_strategy`` to the span or convex hull of the pool.

    Args:
        new_strategy: strategy vector (any shape; it is flattened).
        pool_strategies: sequence/array of strategies, one per row.
        convx: when false, measure the least-squares residual against the
            linear span of the pool; when true, solve for the closest point
            in the pool's convex hull with CVXPY.

    Returns:
        The Euclidean distance as a float. In convex mode ``np.inf`` is
        returned whenever the solver does not produce an optimal value.
    """
    if not convx:
        X = np.array(pool_strategies).T
        Y = new_strategy.reshape(-1, 1)
        coefficients = np.linalg.lstsq(X, Y, rcond=None)[0]
        # Compute the residual explicitly: lstsq returns an empty `residuals`
        # array for rank-deficient systems, which would otherwise be read as
        # a perfect fit even when the strategy lies outside the span.
        return float(np.linalg.norm(X @ coefficients - Y))

    new_strategy = new_strategy.reshape(-1)  # (action_dim,)
    pool_array = np.array(pool_strategies)  # (num_strategies, action_dim)
    num_strategies = pool_array.shape[0]

    alpha = cp.Variable(num_strategies)
    convex_combination = cp.sum(cp.multiply(alpha[:, None], pool_array), axis=0)
    objective = cp.Minimize(cp.norm(new_strategy - convex_combination, 2))
    constraints = [
        alpha >= 1e-8,
        cp.sum(alpha) == 1.0
    ]

    prob = cp.Problem(objective, constraints)
    prob.solve(solver=cp.ECOS, max_iters=1000, feastol=1e-4)

    solved = prob.status in [cp.OPTIMAL, cp.OPTIMAL_INACCURATE]
    if solved and objective.value is not None:
        return objective.value
    # Always hand back a number so callers can do arithmetic on the result.
    return np.inf





def sparse_update(pop, payoffs, meta_nash, k, lr, selected_pop, term, div, convx=False):
    """Update learner ``pop[k]`` in place with a payoff or novelty response.

    One standard-normal draw decides the branch. Below 0.85 the learner moves
    (step ``lr``) towards the payoff-maximising action against the meta-Nash
    mixture of ``selected_pop``. Otherwise it moves (step ``term``) towards
    the action maximising ``payoff + div * linear_represented(action, ...)``.

    Args:
        pop: population matrix; row ``k`` is modified in place.
        payoffs: payoff matrix of the underlying game.
        meta_nash: meta-strategy over ``selected_pop``.
        k: index of the learner row.
        lr: step size for the payoff branch.
        selected_pop: strategies forming the opponent mixture and reference set.
        term: step size for the novelty branch.
        div: weight of the novelty term.
        convx: forwarded to ``linear_represented``.

    Returns:
        None.
    """
    dim = payoffs.shape[0]
    lambda_weight = 0.85

    opponent_mix = meta_nash @ selected_pop
    values = payoffs @ opponent_mix.T
    br = np.zeros((dim,))

    if np.random.randn() < lambda_weight:
        br[np.argmax(values)] = 1
        step = lr
    else:
        scores = []
        for action in range(dim):
            pure = np.zeros((dim, ))
            pure[action] = 1.
            novelty = linear_represented(pure.copy(), selected_pop.copy(), convx=convx)
            scores.append(novelty * div + values[action])
        br[np.argmax(scores)] = 1
        step = term

    pop[k] = step * br + (1 - step) * pop[k]

def psd_psro_steps(iters=5, payoffs=None, verbose=False, seed=0,
                        num_learners=4, improvement_pct_threshold=.03, lr=.2, loss_func='dpp', full=False):
    """Run the pipeline-style PSRO loop with ``psd_update`` as learner update.

    Args:
        iters: number of outer iterations.
        payoffs: payoff matrix of the underlying game.
        verbose: accepted but not used.
        seed: seed for the RandomState that draws the initial population.
        num_learners: number of active learner rows at the end of the population.
        improvement_pct_threshold: relative improvement below which the oldest
            learner is considered plateaued and a new random row is appended.
        lr: step size forwarded to ``psd_update``.
        loss_func: "psd" or "convx_psd".
            NOTE(review): for any other value (including the default 'dpp')
            no learner update is applied and no error is raised.
        full: accepted but not used.

    Returns:
        ``(pop, exps, l_cards, time_records)``: final population,
        exploitability per iteration, ``fsp_non_symmetric_game`` value per
        iteration (first entry is the initial exploitability) and wall-clock
        seconds per iteration (first entry is 0.0).
    """
    dim = payoffs.shape[0]

    # Initial population: seeded, row-normalised random strategies.
    r = np.random.RandomState(seed)
    pop = r.uniform(0, 1, (1 + num_learners, dim))
    pop = pop / pop.sum(axis=1)[:, None]
    exp = get_exploitability(pop, payoffs, iters=1000)
    exps = [exp]
    l_cards = [exp]
    time_records = [0.0]

    # One performance history per population row, seeded with 0.1.
    learner_performances = [[.1] for i in range(num_learners + 1)]
    for i in tqdm(range(iters)):
        iter_start_time = time.time()
        for j in range(num_learners):
            # first learner (when j=num_learners-1) plays against normal meta Nash
            # second learner plays against meta Nash with first learner included, etc.
            k = pop.shape[0] - j - 1
            emp_game_matrix = pop[:k] @ payoffs @ pop[:k].T
            meta_nash, _ = fictitious_play(payoffs=emp_game_matrix, iters=1000)
            population_strategy = meta_nash[-1] @ pop[:k]  # aggregated enemy according to nash
            if loss_func == "psd":
                psd_update(pop, payoffs, meta_nash[-1], k, lr, i)
            elif loss_func == "convx_psd":
                psd_update(pop, payoffs, meta_nash[-1], k, lr, i, convx=True)
            performance = pop[k] @ payoffs @ population_strategy.T + 1  # make it positive for pct calculation
            learner_performances[k].append(performance)

            # if the first learner plateaus, add a new policy to the population
            if j == num_learners - 1 and performance / learner_performances[k][-2] - 1 < improvement_pct_threshold:
                learner = np.random.uniform(0, 1, (1, dim))
                learner = learner / learner.sum(axis=1)[:, None]
                pop = np.vstack((pop, learner))
                learner_performances.append([0.1])

        # calculate exploitability for meta Nash of whole population
        exp = get_exploitability(pop, payoffs, iters=1000)
        print(f"Iteration: {i}, Exp: {exp}")
        exps.append(exp)

        # `k` still holds the value from the last inner-loop pass, so only the
        # rows before the oldest active learner enter this game.
        emp_game_matrix = pop[:k] @ payoffs
        l_cards.append(fsp_non_symmetric_game(emp_game_matrix))
        iter_time = time.time() - iter_start_time
        time_records.append(iter_time)

    return pop, exps, l_cards, time_records

def our_steps(iters=5, payoffs=None, verbose=False, seed=0,
               num_learners=4, improvement_pct_threshold=.03, lr=.2,term=0.8,window=100,mu=0.02,div=0.01, loss_func='dpp', full=False):
    """Run the PSRO loop with sparse updates and a novelty gate on new policies.

    When the oldest learner plateaus, a random candidate is drawn and its
    ``linear_represented`` distance to the population is recorded in a
    sliding window. The candidate joins the population only if that distance
    exceeds the window median, or on every third outer iteration.

    Args:
        iters: number of outer iterations.
        payoffs: payoff matrix of the underlying game.
        verbose: accepted but not used.
        seed: seed for the RandomState that draws the initial population.
        num_learners: number of active learner rows at the end of the population.
        improvement_pct_threshold: relative improvement below which the oldest
            learner is considered plateaued.
        lr: step size of the payoff-driven update.
        term: step size of the novelty-driven update (``sparse_update``).
        window: maximum number of distances kept for the median.
        mu: initial value only; it is overwritten by the window median before
            it is first read.
        div: weight of the novelty term (``sparse_update``).
        loss_func: 'br', "sparse" or "convx_sparse".
        full: accepted but not used.

    Returns:
        ``(pop, exps, l_cards, time_records)``: final population,
        exploitability per iteration, ``fsp_non_symmetric_game`` value per
        iteration (first entry is the initial exploitability) and wall-clock
        seconds per iteration (first entry is 0.0).

    Raises:
        ValueError: if ``loss_func`` is not one of the supported values.
    """
    dim = payoffs.shape[0]
    print(dim)
    r = np.random.RandomState(seed)
    pop = r.uniform(0, 1, (1 + num_learners, dim))
    pop = pop / pop.sum(axis=1)[:, None]
    exp = get_exploitability(pop, payoffs, iters=1000)
    l_cards = [exp]

    exps = [exp]
    time_records = [0.0]
    window_size = window
    distance_window = []

    # One performance history per population row, seeded with 0.1.
    learner_performances = [[.1] for i in range(num_learners + 1)]

    for i in tqdm(range(iters)):
        iter_start_time = time.time()
        for j in range(num_learners):
            # first learner (when j=num_learners-1) plays against normal meta Nash
            # second learner plays against meta Nash with first learner included, etc.
            k = pop.shape[0] - j - 1

            emp_game_matrix = pop[:k] @ payoffs @ pop[:k].T
            meta_nash, _ = fictitious_play(payoffs=emp_game_matrix, iters=1000)
            population_strategy = meta_nash[-1] @ pop[:k]  # aggregated enemy according to nash

            if loss_func == 'br':
                # standard PSRO
                br = get_br_to_strat(population_strategy, payoffs=payoffs)
                pop[k] = lr * br + (1 - lr) * pop[k]
            elif loss_func == "sparse":
                sparse_update(pop, payoffs, meta_nash[-1], k, lr, pop[:k], term, div)
            elif loss_func == "convx_sparse":
                sparse_update(pop, payoffs, meta_nash[-1], k, lr, pop[:k], term, div, convx=True)
            else:
                raise ValueError(f"Unknown loss_func: {loss_func!r}")

            performance = pop[k] @ payoffs @ population_strategy.T + 1  # make it positive for pct calculation
            learner_performances[k].append(performance)

            # if the first learner plateaus, consider adding a new policy to the population
            if j == num_learners - 1 and performance / learner_performances[k][-2] - 1 < improvement_pct_threshold:
                learner = np.random.uniform(0, 1, (1, dim))
                learner = learner / learner.sum(axis=1)[:, None]
                if loss_func == "sparse":
                    distance = linear_represented(learner, pop)
                else:
                    distance = linear_represented(learner, pop, convx=True)
                distance_window.append(distance)
                if len(distance_window) > window_size:
                    distance_window.pop(0)
                # The median includes the candidate's own distance.
                mu = np.percentile(distance_window, 50)

                if distance > mu or i % 3 == 0:
                    pop = np.vstack((pop, learner))
                    learner_performances.append([0.1])

        # calculate exploitability for meta Nash of whole population
        exp = get_exploitability(pop, payoffs, iters=1000)
        # `k` still holds the value from the last inner-loop pass.
        emp_game_matrix = pop[:k] @ payoffs
        l_cards.append(fsp_non_symmetric_game(emp_game_matrix))
        exps.append(exp)
        print(f"Iteration: {i}, Exp: {exp},PE: {l_cards[-1]}")

        iter_time = time.time() - iter_start_time
        time_records.append(iter_time)
    return pop, exps, l_cards,time_records

def sparse_steps(iters=5, payoffs=None, verbose=False, seed=0,
                        num_learners=4, improvement_pct_threshold=.03, lr=.2, term=0.8,div=0.01,loss_func='br', full=False):
    """Run the pipeline-style PSRO loop with a selectable learner update.

    Args:
        iters: number of outer iterations.
        payoffs: payoff matrix of the underlying game.
        verbose: accepted but not used.
        seed: seed for the RandomState that draws the initial population.
        num_learners: number of active learner rows at the end of the population.
        improvement_pct_threshold: relative improvement below which the oldest
            learner is considered plateaued and a new random row is appended.
        lr: step size of the learner update.
        term: step size of the novelty-driven update (``sparse_update``).
        div: weight of the novelty term (``sparse_update``).
        loss_func: 'br', 'dpp', "bd_rd", "psd" or "sparse" (the latter calls
            ``sparse_update`` with ``convx=True``).
        full: accepted but not used.

    Returns:
        ``(pop, exps, l_cards, time_records)``: final population,
        exploitability per iteration, ``fsp_non_symmetric_game`` value per
        iteration (first entry is the initial exploitability) and wall-clock
        seconds per iteration (first entry is 0.0).

    Raises:
        ValueError: if ``loss_func`` is not one of the supported values.
    """
    dim = payoffs.shape[0]

    r = np.random.RandomState(seed)
    pop = r.uniform(0, 1, (1 + num_learners, dim))
    pop = pop / pop.sum(axis=1)[:, None]
    exp = get_exploitability(pop, payoffs, iters=1000)
    l_cards = [exp]
    exps = [exp]

    time_records = [0.0]
    # One performance history per population row, seeded with 0.1.
    learner_performances = [[.1] for i in range(num_learners + 1)]
    # Threshold handed to the diversity losses; currently a fixed hyperparameter.
    lambda_weight = 0.85
    for i in tqdm(range(iters)):
        iter_start_time = time.time()
        for j in range(num_learners):
            # first learner (when j=num_learners-1) plays against normal meta Nash
            # second learner plays against meta Nash with first learner included, etc.
            k = pop.shape[0] - j - 1
            emp_game_matrix = pop[:k] @ payoffs @ pop[:k].T
            meta_nash, _ = fictitious_play(payoffs=emp_game_matrix, iters=1000)
            population_strategy = meta_nash[-1] @ pop[:k]  # aggregated enemy according to nash

            if loss_func == 'br':
                # standard PSRO
                br = get_br_to_strat(population_strategy, payoffs=payoffs)
                pop[k] = lr * br + (1 - lr) * pop[k]
            elif loss_func == 'dpp':
                # Diverse PSRO
                br = joint_loss(pop, payoffs, meta_nash[-1], k, lambda_weight, lr)
                pop[k] = lr * br + (1 - lr) * pop[k]
            elif loss_func == "bd_rd":
                if np.random.uniform() < 0.5:
                    br = divergence_loss(pop, payoffs, meta_nash[-1], k, lambda_weight, lr, i)
                else:
                    br = distance_loss(pop, payoffs, meta_nash[-1], k, lambda_weight, lr)
                pop[k] = lr * br + (1 - lr) * pop[k]
            elif loss_func == "psd":
                psd_update(pop, payoffs, meta_nash[-1], k, lr, i)
            elif loss_func == "sparse":
                sparse_update(pop, payoffs, meta_nash[-1], k, lr, pop[:k], term, div, convx=True)
            else:
                raise ValueError(f"Unknown loss_func: {loss_func!r}")

            performance = pop[k] @ payoffs @ population_strategy.T + 1  # make it positive for pct calculation
            learner_performances[k].append(performance)

            # if the first learner plateaus, add a new policy to the population
            if j == num_learners - 1 and performance / learner_performances[k][-2] - 1 < improvement_pct_threshold:
                learner = np.random.uniform(0, 1, (1, dim))
                learner = learner / learner.sum(axis=1)[:, None]
                pop = np.vstack((pop, learner))
                learner_performances.append([0.1])

        # calculate exploitability for meta Nash of whole population
        exp = get_exploitability(pop, payoffs, iters=1000)
        print(f"Iteration: {i}, Exp: {exp}")
        exps.append(exp)
        # `k` still holds the value from the last inner-loop pass.
        emp_game_matrix = pop[:k] @ payoffs
        l_cards.append(fsp_non_symmetric_game(emp_game_matrix))
        iter_time = time.time() - iter_start_time
        time_records.append(iter_time)

    return pop, exps, l_cards,time_records

def psro_steps(iters=5, payoffs=None, verbose=False, seed=0,
                        num_learners=4, improvement_pct_threshold=.03, lr=.2, loss_func='dpp', full=False):
    """Run the pipeline-style PSRO loop with a best-response style update.

    Args:
        iters: number of outer iterations.
        payoffs: payoff matrix of the underlying game.
        verbose: accepted but not used.
        seed: seed for the RandomState that draws the initial population.
        num_learners: number of active learner rows at the end of the population.
        improvement_pct_threshold: relative improvement below which the oldest
            learner is considered plateaued and a new random row is appended.
        lr: step size of the learner update.
        loss_func: 'br' (plain best response), 'dpp' (``joint_loss``) or
            "bd_rd" (coin flip between ``divergence_loss`` and ``distance_loss``).
        full: accepted but not used.

    Returns:
        ``(pop, exps, l_cards, time_records)``: final population,
        exploitability per iteration, ``fsp_non_symmetric_game`` value per
        iteration (first entry is the initial exploitability) and wall-clock
        seconds per iteration (first entry is 0.0).

    Raises:
        ValueError: if ``loss_func`` is not one of the supported values.
    """
    dim = payoffs.shape[0]

    r = np.random.RandomState(seed)
    pop = r.uniform(0, 1, (1 + num_learners, dim))
    pop = pop / pop.sum(axis=1)[:, None]
    exp = get_exploitability(pop, payoffs, iters=1000)
    l_cards = [exp]
    exps = [exp]

    time_records = [0.0]
    # One performance history per population row, seeded with 0.1.
    learner_performances = [[.1] for i in range(num_learners + 1)]
    # Threshold handed to the diversity losses; currently a fixed hyperparameter.
    lambda_weight = 0.85
    for i in tqdm(range(iters)):
        iter_start_time = time.time()
        for j in range(num_learners):
            # first learner (when j=num_learners-1) plays against normal meta Nash
            # second learner plays against meta Nash with first learner included, etc.
            k = pop.shape[0] - j - 1
            emp_game_matrix = pop[:k] @ payoffs @ pop[:k].T
            meta_nash, _ = fictitious_play(payoffs=emp_game_matrix, iters=1000)
            population_strategy = meta_nash[-1] @ pop[:k]  # aggregated enemy according to nash

            if loss_func == 'br':
                # standard PSRO
                br = get_br_to_strat(population_strategy, payoffs=payoffs)
            elif loss_func == 'dpp':
                # Diverse PSRO
                br = joint_loss(pop, payoffs, meta_nash[-1], k, lambda_weight, lr)
            elif loss_func == "bd_rd":
                if np.random.uniform() < 0.5:
                    br = divergence_loss(pop, payoffs, meta_nash[-1], k, lambda_weight, lr, i)
                else:
                    br = distance_loss(pop, payoffs, meta_nash[-1], k, lambda_weight, lr)
            else:
                raise ValueError(f"Unknown loss_func: {loss_func!r}")

            # Update the mixed strategy towards the pure strategy which is returned as the best response to the
            # nash equilibrium that is being trained against.
            pop[k] = lr * br + (1 - lr) * pop[k]
            performance = pop[k] @ payoffs @ population_strategy.T + 1  # make it positive for pct calculation
            learner_performances[k].append(performance)

            # if the first learner plateaus, add a new policy to the population
            if j == num_learners - 1 and performance / learner_performances[k][-2] - 1 < improvement_pct_threshold:
                learner = np.random.uniform(0, 1, (1, dim))
                learner = learner / learner.sum(axis=1)[:, None]
                pop = np.vstack((pop, learner))
                learner_performances.append([0.1])

        # calculate exploitability for meta Nash of whole population
        exp = get_exploitability(pop, payoffs, iters=1000)
        print(f"Iteration: {i}, Exp: {exp}")
        exps.append(exp)

        # `k` still holds the value from the last inner-loop pass.
        emp_game_matrix = pop[:k] @ payoffs
        l_cards.append(fsp_non_symmetric_game(emp_game_matrix))
        iter_time = time.time() - iter_start_time
        time_records.append(iter_time)

    return pop, exps, l_cards,time_records


# Self-play baseline
def self_play_steps(iters=10, payoffs=None, verbose=False, improvement_pct_threshold=.03, lr=.2, seed=0):
    """Run self-play: the newest policy trains against the one before it.

    Args:
        iters: number of iterations.
        payoffs: payoff matrix of the underlying game.
        verbose: accepted but not used.
        improvement_pct_threshold: relative improvement below which a new
            random policy is appended to the population.
        lr: step size of the best-response update.
        seed: seed for the RandomState that draws the two initial policies.

    Returns:
        ``(pop, exps, pop_eff, time_records)``: final population,
        exploitability per iteration, ``pop_effective_diversity`` per
        iteration and wall-clock seconds per iteration (first entry is 0.0).
    """
    dim = payoffs.shape[0]
    rng = np.random.RandomState(seed)
    pop = rng.uniform(0, 1, (2, dim))
    pop = pop / pop.sum(axis=1)[:, None]

    exps = [get_exploitability(pop, payoffs, iters=1000)]
    performances = [.01]
    time_records = [0.0]
    pop_eff = [pop_effective_diversity(pop, payoffs, iters=2000)]

    for _ in range(iters):
        tick = time.time()

        opponent = pop[-2]
        br = get_br_to_strat(opponent, payoffs=payoffs)
        pop[-1] = lr * br + (1 - lr) * pop[-1]
        performance = pop[-1] @ payoffs @ opponent.T + 1
        performances.append(performance)

        relative_gain = performance / performances[-2] - 1
        if relative_gain < improvement_pct_threshold:
            newcomer = np.random.uniform(0, 1, (1, dim))
            newcomer = newcomer / newcomer.sum(axis=1)[:, None]
            pop = np.vstack((pop, newcomer))

        exps.append(get_exploitability(pop, payoffs, iters=1000))
        pop_eff.append(pop_effective_diversity(pop, payoffs, iters=2000))
        time_records.append(time.time() - tick)

    return pop, exps, pop_eff,time_records


# Define the PSRO rectified nash algorithm
def psro_rectified_steps(iters=10, payoffs=None, verbose=False, eps=1e-2, seed=0,
                         num_start_strats=1, num_pseudo_learners=4, lr=0.3, threshold=0.001):
    """Run PSRO with a rectified-Nash opponent selection.

    For every population member with meta-Nash mass above ``eps`` a new
    learner is trained against the meta-Nash-weighted mixture of the members
    that this policy does not lose to (empirical payoff >= 0).

    Args:
        iters: iteration budget; also the fictitious-play iteration count for
            the empirical game.
        payoffs: payoff matrix of the underlying game.
        verbose: accepted but not used.
        eps: minimum meta-Nash mass for a member to spawn a learner.
        seed: seed for the RandomState that draws the initial population.
        num_start_strats: number of initial random strategies.
        num_pseudo_learners: number of learner steps counted as one iteration.
        lr: step size of the best-response update.
        threshold: relative improvement below which a learner stops training.

    Returns:
        ``(pop, exps, l_cards, time_records)``: final population,
        exploitability and ``fsp_non_symmetric_game`` value recorded every
        ``num_pseudo_learners`` steps (first entries are the initial
        exploitability) and wall-clock seconds per outer pass (first entry
        is 0.0).
    """
    dim = payoffs.shape[0]
    r = np.random.RandomState(seed)
    pop = r.uniform(0, 1, (num_start_strats, dim))
    pop = pop / pop.sum(axis=1)[:, None]
    exp = get_exploitability(pop, payoffs, iters=1000)
    exps = [exp]
    counter = 0

    l_cards = [exp]
    time_records = [0.0]
    budget = iters * num_pseudo_learners
    while counter < budget:
        iter_start_time = time.time()
        new_pop = np.copy(pop)
        emp_game_matrix = pop @ payoffs @ pop.T
        averages, _ = fictitious_play(payoffs=emp_game_matrix, iters=iters)
        meta_nash = averages[-1]

        # go through all policies. If the policy has positive meta Nash mass,
        # find policies it wins against, and play against meta Nash weighted mixture of those policies
        for j in range(pop.shape[0]):
            if counter > budget:
                return pop, exps, l_cards,time_records
            # if positive mass, add a new learner to pop and update it with steps, submit if over thresh
            # keep track of counter
            if meta_nash[j] > eps:
                # create learner
                learner = np.random.uniform(0, 1, (1, dim))
                learner = learner / learner.sum(axis=1)[:, None]
                new_pop = np.vstack((new_pop, learner))
                idx = new_pop.shape[0] - 1

                # Build the rectification mask as a new array. Thresholding the
                # payoff row in place would overwrite it with 0/1 values, so a
                # second pass would see every entry as >= 0.
                mask = (emp_game_matrix[j, :] >= 0).astype(float)
                weights = np.multiply(mask, meta_nash)
                weights /= weights.sum()
                # Opponent mixture and best response are fixed for this learner.
                strat = weights @ pop
                br = get_br_to_strat(strat, payoffs=payoffs)

                current_performance = 0.02
                last_performance = 0.01
                while current_performance / last_performance - 1 > threshold:
                    counter += 1
                    new_pop[idx] = lr * br + (1 - lr) * new_pop[idx]
                    last_performance = current_performance
                    current_performance = new_pop[idx] @ payoffs @ strat + 1

                    if counter % num_pseudo_learners == 0:
                        # count this as an 'iteration'

                        # exploitability
                        exp = get_exploitability(new_pop, payoffs, iters=1000)
                        exps.append(exp)

                        emp_game_matrix2 = new_pop @ payoffs
                        l_cards.append(fsp_non_symmetric_game(emp_game_matrix2))

        pop = np.copy(new_pop)
        iter_time = time.time() - iter_start_time
        time_records.append(iter_time)
    return pop, exps, l_cards,time_records


def run_experiment(param_seed):
    """Run every enabled algorithm once for a single ``(params, seed)`` pair.

    ``param_seed`` is a 2-tuple: ``params`` is the settings dict (shared
    hyper-parameters plus one truthy/falsy switch per algorithm) and ``seed``
    is the experiment index, used to seed NumPy's global RNG.

    Returns a flat dict holding, for each algorithm ``<name>``, the entries
    ``<name>_exps``, ``<name>_cardinality``, ``<name>_times`` and
    ``<name>_pop``.  Algorithms that are switched off keep an empty list
    under each of their four keys.
    """
    params, seed = param_seed

    # Shared hyper-parameters; read eagerly so a missing key fails up front.
    iters = params['iters']
    num_threads = params['num_threads']
    lr = params['lr']
    thresh = params['thresh']
    term = params['term']
    window = params['window']
    mu = params['mu']
    div = params['div']

    # Algorithm switches, listed in the order their keys appear in the result.
    names = (
        'psro',
        'pipeline_psro',
        'dpp_psro',
        'rectified',
        'self_play',
        'psd_psro',
        'bd_rd_psro',
        'our_psro',
        'sparse_psro_1',
        'sparse_psro_2',
        'convx_psd_psro',
    )
    enabled = {name: params[name] for name in names}

    # Pre-fill every result slot with its own empty list.
    out = {}
    for name in names:
        out[name + '_exps'] = []
        out[name + '_cardinality'] = []
    for suffix in ('_times', '_pop'):
        for name in names:
            out[name + suffix] = []

    print('Experiment: ', seed + 1)
    np.random.seed(seed)
    # NOTE: pickle.load can execute arbitrary code; only load trusted payoff files.
    with open("payoffs_data/" + str(args.game_name) + ".pkl", "rb") as fh:
        payoffs = pickle.load(fh)
        payoffs /= np.abs(payoffs).max()

    # Keyword sets shared by several solvers.
    learner_kwargs = dict(iters=iters, num_learners=num_threads, seed=seed + 1,
                          improvement_pct_threshold=thresh, lr=lr, payoffs=payoffs)
    ours_kwargs = dict(learner_kwargs, term=term, window=window, mu=mu, div=div)

    # (switch name, banner printed before the run, deferred solver call).
    # The tuple order is the execution order, which matters because the
    # solvers run back to back in the same process.
    schedule = (
        ('psd_psro', 'PSD PSRO',
         lambda: psd_psro_steps(loss_func='psd', **learner_kwargs)),
        ('convx_psd_psro', 'Sparse PSRO without 3.1',
         lambda: our_steps(loss_func='br', **ours_kwargs)),
        ('psro', 'PSRO',
         lambda: psro_steps(iters=iters, num_learners=1, seed=seed + 1,
                            improvement_pct_threshold=thresh, lr=lr,
                            payoffs=payoffs, loss_func='br')),
        ('pipeline_psro', 'Pipeline PSRO',
         lambda: psro_steps(loss_func='br', **learner_kwargs)),
        ('dpp_psro', 'Pipeline DPP',
         lambda: psro_steps(loss_func='dpp', **learner_kwargs)),
        ('rectified', 'Rectified',
         lambda: psro_rectified_steps(iters=iters, num_pseudo_learners=num_threads,
                                      payoffs=payoffs, seed=seed + 1,
                                      lr=lr, threshold=thresh)),
        ('self_play', 'Self-play',
         lambda: self_play_steps(iters=iters, payoffs=payoffs,
                                 improvement_pct_threshold=thresh, lr=lr,
                                 seed=seed + 1)),
        ('bd_rd_psro', 'BD-RD PSRO',
         lambda: psro_steps(loss_func='bd_rd', **learner_kwargs)),
        ('our_psro', 'Our PSRO',
         lambda: our_steps(loss_func='sparse', **ours_kwargs)),
        ('sparse_psro_1', 'convx Sparse PSRO',
         lambda: our_steps(loss_func='convx_sparse', **ours_kwargs)),
        ('sparse_psro_2', 'Sparse PSRO without 3.2',
         lambda: sparse_steps(iters=iters, num_learners=num_threads, seed=seed + 1,
                              improvement_pct_threshold=thresh, lr=lr,
                              term=term, div=div,
                              payoffs=payoffs, loss_func='sparse')),
    )

    for name, banner, solve in schedule:
        if not enabled[name]:
            continue
        print(banner)
        pop, exps, cards, times = solve()
        out[name + '_pop'] = pop
        out[name + '_exps'] = exps
        out[name + '_cardinality'] = cards
        out[name + '_times'] = times

    return out


def run_experiments(num_experiments=2, iters=40, num_threads=20, lr=0.5, thresh=0.001,term=0.8,window=100,mu=0.02,div=0.01, logscale=True,
                    psro=False,
                    pipeline_psro=False,
                    rectified=False,
                    self_play=False,
                    dpp_psro=False,
                    psd_psro=False,
                    bd_rd_psro=False,
                    our_psro=False,
                    sparse_psro_1=False,
                    sparse_psro_2=False,
                    convx_psd_psro=False,
                    ):
    """Run ``num_experiments`` seeded experiments, then save and plot the results.

    Side effects, all under ``PATH_RESULTS``:
      * ``params.json`` - the settings passed to every experiment;
      * ``data.p``      - pickled dict mapping ``<algo>_exps``,
        ``<algo>_cardinality``, ``<algo>_times`` and ``<algo>_pop`` to a list
        with one entry per experiment;
      * ``figure_<title>.pdf`` - one figure each for the ``exps``,
        ``cardinality`` and ``times`` series (mean +/- standard error).

    Args:
        num_experiments: number of seeds; experiment ``i`` uses seed ``i``.
        iters, num_threads, lr, thresh, term, window, mu, div: forwarded
            unchanged to ``run_experiment`` through the params dict.
        logscale: when true, the ``exps`` and ``cardinality`` figures use a
            logarithmic y axis; the ``times`` figure is always linear.
        psro ... convx_psd_psro: switches selecting which algorithms are run
            and plotted.

    Returns:
        None.
    """
    # Algorithm switches, in the key order used for params.json and data.p.
    flags = {
        'psro': psro,
        'pipeline_psro': pipeline_psro,
        'dpp_psro': dpp_psro,
        'rectified': rectified,
        'self_play': self_play,
        'psd_psro': psd_psro,
        'bd_rd_psro': bd_rd_psro,
        'our_psro': our_psro,
        'sparse_psro_1': sparse_psro_1,
        'sparse_psro_2': sparse_psro_2,
        'convx_psd_psro': convx_psd_psro,
    }
    params = {
        'num_experiments': num_experiments,
        'iters': iters,
        'num_threads': num_threads,
        'lr': lr,
        'thresh': thresh,
        'term': term,
        'window': window,
        'mu': mu,
        'div': div,
        **flags,
    }

    with open(os.path.join(PATH_RESULTS, 'params.json'), 'w', encoding='utf-8') as json_file:
        json.dump(params, json_file, indent=4)

    jobs = [(params, i) for i in range(num_experiments)]
    if not args.mp:
        result = [run_experiment(job) for job in jobs]
    else:
        # The context manager tears the worker pool down once map() returns,
        # so worker processes are not left behind.
        with mp.Pool() as pool:
            result = pool.map(run_experiment, jobs)

    # Collect per-experiment results: d[key][i] is experiment i's value.
    # Key order: exps/cardinality pairs per algorithm, then all times, then all pops.
    d = {}
    for name in flags:
        for suffix in ('exps', 'cardinality'):
            key = f'{name}_{suffix}'
            d[key] = [r[key] for r in result]
    for suffix in ('times', 'pop'):
        for name in flags:
            key = f'{name}_{suffix}'
            d[key] = [r[key] for r in result]

    with open(os.path.join(PATH_RESULTS, 'data.p'), 'wb') as data_file:
        pickle.dump(d, data_file)

    alpha = .4

    def plot_error(data, label=''):
        """Plot the mean curve of ``data`` with a +/- standard-error band."""
        # Runs may differ in length; truncate all of them to the shortest.
        min_len = min(len(run) for run in data)
        stacked = np.array([run[0:min_len] for run in data])
        data_mean = np.mean(stacked, axis=0)
        error_bars = stats.sem(stacked)
        plt.plot(data_mean, label=label)
        plt.fill_between(list(range(data_mean.size)),
                         np.squeeze(data_mean - error_bars),
                         np.squeeze(data_mean + error_bars), alpha=alpha)

    # (algorithm, legend label) in drawing order; the order fixes line
    # colours and legend order.
    plot_order = (
        ('psro', 'PSRO'),
        ('pipeline_psro', 'P-PSRO'),
        ('rectified', 'PSRO-rN'),
        ('self_play', 'Self-play'),
        ('dpp_psro', 'dpp_psro'),
        ('bd_rd_psro', 'bd_rd_psro'),
        ('psd_psro', 'psd_psro'),
        ('convx_psd_psro', 'convx_sparse_psro_1'),
        ('sparse_psro_1', 'convx_sparse_psro'),
        ('sparse_psro_2', 'convx_sparse_psro_2'),
        ('our_psro', 'sparse_psro(ours)'),
    )
    # (series suffix, file-name fragment, eligible for log scale).
    figures = (
        ('exps', 'Exploitability Log', True),
        ('cardinality', 'pop exploitability', True),
        ('times', 'Times', False),
    )

    for suffix, string, log_capable in figures:
        fig_handle = plt.figure()

        for name, label in plot_order:
            if not flags[name]:
                continue
            plot_error(d[f'{name}_{suffix}'], label=label)
            if suffix == 'cardinality':
                print(f"size of pop: {len(d[name + '_pop'][0])}")

        plt.legend(loc="upper right")
        plt.title(args.game_name)

        # `logscale` now governs both log-capable figures.  The previous
        # condition `logscale and (j == 0) or (j == 1)` forced a log axis on
        # the second figure even with logscale=False, because `and` binds
        # tighter than `or`.
        if logscale and log_capable:
            plt.yscale('log')

        plt.savefig(os.path.join(PATH_RESULTS, 'figure_' + string + '.pdf'))
        # Release the figure once it is on disk.
        plt.close(fig_handle)


if __name__ == "__main__":
    # Script entry point: run the selected algorithms with the module-level
    # hyper-parameters and report the wall-clock time and output directory.
    start_time = time.time()

    # Only the algorithms switched on here are run and plotted.
    algorithm_switches = {
        'psro': False,
        'pipeline_psro': False,
        'rectified': False,
        'self_play': False,
        'dpp_psro': False,
        'psd_psro': False,
        'bd_rd_psro': False,
        'our_psro': False,
        'sparse_psro_1': False,
        'sparse_psro_2': False,
        'convx_psd_psro': True,
    }
    run_experiments(
        num_experiments=args.nb_exps,
        num_threads=2,
        iters=args.nb_iters,
        lr=LR,
        thresh=TH,
        term=TM,
        window=WS,
        mu=MU,
        div=DIV,
        **algorithm_switches,
    )

    elapsed = time.time() - start_time
    print(f"Total time: {elapsed}")
    print(f'The directory is {PATH_RESULTS}')