import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import time
import os
from scipy import stats
import pickle

np.set_printoptions(suppress=True)
np.set_printoptions(precision=3)

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as f

np.random.seed(0)
import copy

from expground.settings import BASE_DIR

# Random antisymmetric payoff matrix: draw the strict lower triangle uniformly
# from [-1, 1) and mirror it with the opposite sign.
dim = 30
payoffs = np.tril(np.random.uniform(-1, 1, (dim, dim)), -1)
payoffs = payoffs - payoffs.T

# Default learning rate and training-iteration count.
LR = 0.1
TRAIN_ITERS = 5

expected_card = []
sizes = []

# Every run writes into its own timestamped directory under BASE_DIR/results.
time_string = time.strftime("%Y%m%d-%H%M%S")
PATH_RESULTS = os.path.join(BASE_DIR, "results", time_string)
# makedirs (rather than mkdir) also creates a missing "results" parent and
# does not fail when the timestamped directory already exists.
os.makedirs(PATH_RESULTS, exist_ok=True)

device = "cpu"

# File names used when pickling populations, keyed by algorithm.
# NOTE(review): the values already end in ".p"; callers that append ".p"
# again end up with a doubled extension.
FILE_TRAJ = {
    "rectified": "rectified.p",
    "psro": "psro.p",
    "ppsro": "p_psro.p",
    "dpp": "dpp.p",
    "epsro": "epsro.p",
    "pepsro": "pepsro.p",
    "iterative": "iterative.p",
}


class MyGaussianPDF(nn.Module):
    """Unnormalised Gaussian-shaped score of a point against a set of centres.

    For every row ``m`` of ``mu`` the module returns
    ``c * exp(-0.5 * (x - m) @ cov @ (x - m))``.

    Args:
        mu: tensor of centres; ``x - mu`` must broadcast to ``(..., 2)``.
    """

    def __init__(self, mu):
        super(MyGaussianPDF, self).__init__()
        self.mu = mu
        # Used directly as the matrix of the quadratic form (it is not inverted).
        # NOTE(review): stored as a plain attribute, not a registered buffer.
        self.cov = 0.54 * torch.eye(2)
        # Constant prefactor; the density is deliberately left unnormalised.
        self.c = 1.0

    def forward(self, x):
        """Return one score per centre for the point ``x``."""
        diff = x - self.mu
        # Row-wise quadratic form. This equals the diagonal of
        # diff @ cov @ diff.T without building the full pairwise matrix.
        quad = ((diff @ self.cov) * diff).sum(dim=-1)
        return self.c * torch.exp(-0.5 * quad)


class GMMAgent(nn.Module):
    """A single strategy in the population.

    Two flavours exist:

    * ``one_hot=False``: the agent owns a 2-D point ``x`` (an ``nn.Parameter``
      created with ``requires_grad=False``) and its output is
      ``MyGaussianPDF(mu)(x)``.
    * ``one_hot=True``: ``mu`` is an index and the output is the fixed one-hot
      vector of size ``length`` with a 1 at that index.
    """

    def __init__(self, mu, one_hot=False, length=None):
        super().__init__()
        self.one_hot = one_hot
        if one_hot:
            indicator = np.zeros(length)
            indicator[mu] = 1.0
            self.x = torch.from_numpy(indicator).float().to(device)
        else:
            self.gauss = MyGaussianPDF(mu).to(device)
            start = 0.01 * torch.randn(2, dtype=torch.float)
            self.x = nn.Parameter(start, requires_grad=False)
        # Output override installed by hard_set(); unused until then.
        self._data = None
        self.fixed = False

    def hard_set(self, x):
        """Freeze the output of a non-one-hot agent to the numpy array ``x``."""
        self.fixed = True
        self._data = torch.from_numpy(x).float().to(device)

    def forward(self):
        """Return the agent's strategy vector."""
        if self.one_hot:
            return self.x
        if self.fixed:
            return self._data
        return self.gauss(self.x)


def multivariate_gaussian(pos, mu, Sigma):
    """Evaluate the multivariate normal density N(mu, Sigma) at every point of ``pos``.

    Args:
        pos: array whose last axis holds the point coordinates.
        mu: mean vector; its length gives the dimensionality.
        Sigma: covariance matrix.

    Returns:
        Array of densities with the shape of ``pos`` minus its last axis.
    """
    n_dims = mu.shape[0]
    normaliser = np.sqrt((2 * np.pi) ** n_dims * np.linalg.det(Sigma))
    precision = np.linalg.inv(Sigma)
    centred = pos - mu
    # Vectorised (x - mu)^T Sigma^-1 (x - mu) over all leading axes.
    mahalanobis_sq = np.einsum("...k,kl,...l->...", centred, precision, centred)
    return np.exp(-mahalanobis_sq / 2) / normaliser


class TorchPop:
    """Population of ``GMMAgent`` strategies playing a fixed 7x7 matrix game.

    Each learnable agent is a 2-D point; its strategy is the vector of
    Gaussian scores of that point against the seven centres in ``self.mus``.
    Payoffs between strategies come from ``self.game`` plus a mass term
    (see ``get_payoff``).

    Args:
        num_learners: number of learners; the population starts with
            ``num_learners + 1`` agents.
        seed: seed passed to ``torch.manual_seed``.
        async_mode: when True, ``get_metagame`` (unless ``sub=True``) scores
            the seven one-hot strategies in ``full_pop`` against the population.
        mixed_oracles: when True, a parallel ``br_pop`` list of agents is kept
            and ``get_payoff_aggregate`` plays against ``pop[K - 1]`` only.
    """

    def __init__(self, num_learners, seed=0, async_mode=False, mixed_oracles=False):
        torch.manual_seed(seed)
        self.pop_size = num_learners + 1

        # Centres of the seven Gaussian components.
        mus = np.array(
            [
                [2.8722, -0.025255],
                [1.8105, 2.2298],
                [1.8105, -2.2298],
                [-0.61450, 2.8058],
                [-0.61450, -2.8058],
                [-2.5768, 1.2690],
                [-2.5768, -1.2690],
            ]
        )
        mus = torch.from_numpy(mus).float().to(device)
        self.mus = mus

        # Antisymmetric payoff matrix between the seven components.
        self.game = (
            torch.from_numpy(
                np.array(
                    [
                        [0.0, 1.0, 1.0, 1, -1, -1, -1],
                        [-1.0, 0.0, 1.0, 1.0, 1.0, -1.0, -1.0],
                        [-1.0, -1.0, 0.0, 1.0, 1.0, 1.0, -1],
                        [-1.0, -1.0, -1.0, 0, 1.0, 1.0, 1.0],
                        [1.0, -1.0, -1.0, -1.0, 0.0, 1.0, 1.0],
                        [1.0, 1.0, -1.0, -1, -1, 0.0, 1.0],
                        [1.0, 1.0, 1.0, -1.0, -1.0, -1.0, 0.0],
                    ]
                )
            )
            .float()
            .to(device)
        )

        # br_pop is created before pop so the random draws keep their order.
        if mixed_oracles:
            self.br_pop = [GMMAgent(mus) for _ in range(self.pop_size)]
        else:
            self.br_pop = None

        self.pop = [GMMAgent(mus) for _ in range(self.pop_size)]
        # Per-agent list of visited 2-D positions, used for trajectory plots.
        self.pop_hist = [
            [self.pop[i].x.detach().cpu().clone().numpy()] for i in range(self.pop_size)
        ]
        length = self.game.shape[0]
        self.async_mode = async_mode
        self.mixed_oracles = mixed_oracles
        # One fixed one-hot agent per row of the game matrix.
        self.full_pop = [
            GMMAgent(i, one_hot=True, length=length) for i in range(length)
        ]

    def visualise_pop(self, br=None, ax=None, color=None):
        """Draw agents, their trajectories and the Gaussian contours on ``ax``.

        Args:
            br: optional ``(x, y)`` point(s) plotted as black dots.
            ax: matplotlib axes; falls back to the current axes when None.
            color: single colour for every agent; a rainbow palette when None.
        """
        if ax is None:
            ax = plt.gca()

        agents = [agent.x.detach().cpu().numpy() for agent in self.pop]
        agents = list(zip(*agents))  # -> (all x coordinates, all y coordinates)

        # Colors
        if color is None:
            colors = cm.rainbow(np.linspace(0, 1, len(agents[0])))
        else:
            colors = [color] * len(agents[0])

        ax.scatter(
            agents[0],
            agents[1],
            alpha=1.0,
            marker=".",
            color=colors,
            s=8 * plt.rcParams["lines.markersize"] ** 2,
        )
        if br is not None:
            ax.scatter(br[0], br[1], marker=".", c="k")
        for i, hist in enumerate(self.pop_hist):
            if hist:
                hist = list(zip(*hist))
                ax.plot(hist[0], hist[1], alpha=0.8, color=colors[i], linewidth=4)

        # The evaluation grid does not depend on the component: build it once.
        delta = 0.025
        x = np.arange(-4.5, 4.5, delta)
        y = np.arange(-4.5, 4.5, delta)
        X, Y = np.meshgrid(x, y)
        pos = np.empty(X.shape + (2,))
        pos[:, :, 0] = X
        pos[:, :, 1] = Y
        levels = 10

        for i in range(self.mus.shape[0]):
            ax.scatter(self.mus[i, 0].item(), self.mus[i, 1].item(), marker="x", c="k")
            Z = multivariate_gaussian(
                pos, self.mus[i, :].detach().cpu().numpy(), 0.54 * np.eye(2)
            )
            # The same contour set was historically drawn four times on top of
            # itself; the repetition is kept so the figure renders as before.
            for _ in range(4):
                ax.contour(X, Y, Z, levels, colors="k", linewidths=0.5, alpha=0.2)
        ax.axes.xaxis.set_ticks([])
        ax.axes.yaxis.set_ticks([])
        ax.spines["right"].set_visible(False)
        ax.spines["top"].set_visible(False)
        ax.spines["left"].set_visible(False)
        ax.spines["bottom"].set_visible(False)
        ax.set_xlim([-4.5, 4.5])
        ax.set_ylim([-4.5, 4.5])

    def get_payoff(self, agent1, agent2):
        """Payoff of ``agent1`` against ``agent2``: ``p G q + 0.5 * sum(p - q)``."""
        p = agent1()
        q = agent2()
        return p @ self.game @ q + 0.5 * (p - q).sum()

    def mixed_brs(self, meta_nash, K):
        """Return the ``meta_nash``-weighted sum of the first ``K`` ``br_pop`` outputs."""
        agg_agent = meta_nash[0] * self.br_pop[0]()
        for k in range(1, K):
            # Each oracle is weighted by its own coefficient (was meta_nash[0]).
            agg_agent += meta_nash[k] * self.br_pop[k]()
        return agg_agent

    def get_payoff_aggregate(self, agent1, metanash, K):
        """Payoff of ``agent1`` against an aggregated opponent.

        Without mixed oracles the opponent is the ``metanash``-weighted sum of
        the first ``K`` population members; with mixed oracles it is simply
        ``pop[K - 1]`` and ``metanash`` is ignored.
        """
        if not self.mixed_oracles:
            agg_agent = metanash[0] * self.pop[0]()
            for k in range(1, K):
                agg_agent += metanash[k] * self.pop[k]()
        else:
            agg_agent = self.pop[K - 1]()
        return agent1() @ self.game @ agg_agent + 0.5 * (agent1() - agg_agent).sum()

    def get_payoff_aggregate_weights(self, agent1, weights, K):
        """Payoff of ``agent1`` against the ``weights``-weighted population mix.

        The mixture spans the first ``len(weights)`` agents; ``K`` is accepted
        for signature compatibility but not used.
        """
        agg_agent = weights[0] * self.pop[0]()
        for k in range(1, len(weights)):
            agg_agent += weights[k] * self.pop[k]()
        return agent1() @ self.game @ agg_agent + 0.5 * (agent1() - agg_agent).sum()

    def get_br_to_strat(self, metanash, lr, nb_iters=20):
        """Train a fresh agent against the ``metanash`` mixture of the whole population.

        Runs ``nb_iters * 10`` Adam steps with learning rate ``lr`` and returns
        the trained agent.
        """
        br = GMMAgent(self.mus)
        # Re-initialise the position with a wider spread than GMMAgent's default.
        br.x = nn.Parameter(
            0.1 * torch.randn(2, dtype=torch.float), requires_grad=False
        )
        br.x.requires_grad = True
        optimiser = optim.Adam(br.parameters(), lr=lr)
        for _ in range(nb_iters * 10):
            loss = -self.get_payoff_aggregate(
                br,
                metanash,
                self.pop_size,
            )
            # Optimise !
            optimiser.zero_grad()
            loss.backward()
            optimiser.step()

        return br

    def _build_metagame(self, k, sub):
        """Fill and return the payoff tensor used by ``get_metagame``."""
        if not self.async_mode or sub:
            # Square table: population member i versus population member j.
            metagame = torch.zeros(k, k)
            for i in range(k):
                for j in range(k):
                    metagame[i, j] = self.get_payoff(self.pop[i], self.pop[j])
        else:
            # Rectangular table: one-hot strategy i versus population member j.
            n_rows = self.game.shape[0]
            metagame = torch.zeros(n_rows, k)
            for i in range(n_rows):
                for j in range(k):
                    metagame[i, j] = self.get_payoff(self.full_pop[i], self.pop[j])
        return metagame

    def get_metagame(self, k=None, numpy=False, sub=False):
        """Return the empirical payoff table over the first ``k`` agents.

        Args:
            k: number of population members to include (default: all).
            numpy: when True, compute without gradients and return a numpy
                array; otherwise return a torch tensor.
            sub: force the square population-vs-population table even in
                ``async_mode``.
        """
        if k is None:
            k = self.pop_size
        if numpy:
            with torch.no_grad():
                metagame = self._build_metagame(k, sub)
            return metagame.detach().cpu().clone().numpy()
        return self._build_metagame(k, sub)

    def add_new(self):
        """Append one fresh agent (and a fresh oracle, if enabled) to the population."""
        with torch.no_grad():
            self.pop.append(GMMAgent(self.mus))
            self.pop_hist.append([self.pop[-1].x.detach().cpu().clone().numpy()])
            if self.mixed_oracles:
                self.br_pop.append(GMMAgent(self.mus))
            self.pop_size += 1

    def get_exploitability(self, metanash, lr, nb_iters=20):
        """Payoff a freshly trained best response obtains against ``metanash``."""
        br = self.get_br_to_strat(metanash, lr, nb_iters=nb_iters)
        with torch.no_grad():
            exp = self.get_payoff_aggregate(br, metanash, self.pop_size).item()
        return exp


def gradient_loss_update(
    torch_pop,
    k,
    train_iters=10,
    lambda_weight=0.1,
    lr=0.1,
    dpp=True,
    symmetric=True,
):
    """Train population member ``k`` for ``train_iters`` Adam steps.

    At every step the metagame over agents ``0..k`` is rebuilt, a
    meta-strategy over the first ``k`` agents is obtained with fictitious
    play, and agent ``k`` is updated to maximise its payoff against that
    mixture. With ``dpp=True`` a diversity term is added to the objective.

    Args:
        torch_pop: the population being trained.
        k: index of the agent to update.
        train_iters: number of optimisation steps (must be at least 1, since
            the returned values are produced inside the loop).
        lambda_weight: weight of the payoff term; with ``dpp=True`` the
            diversity term gets ``1 - lambda_weight``.
        lr: Adam learning rate.
        dpp: include the differentiable diversity term in the loss.
        symmetric: solve a symmetric metagame; otherwise solve the
            rectangular one and use the column player's mixture.

    Returns:
        ``(expected payoff, cardinality)`` from the last step, as floats.
    """
    learner = torch_pop.pop[k]

    # Make strategy k trainable
    learner.x.requires_grad = True
    optimiser = optim.Adam(learner.parameters(), lr=lr)

    for _ in range(train_iters):
        # Get metagame and metastrat
        M = torch_pop.get_metagame(k=k + 1)
        payoff_table = M.detach().cpu().clone().numpy()
        if symmetric:
            averages, _ = fictitious_play(
                payoffs=payoff_table[:k, :k],
                iters=1000,
                symmetric=True,
            )
            # fictitious_play returns (averages, exps). The mixture is the last
            # row of `averages`, not the exploitability list.
            meta_nash = averages[-1]
        else:
            averages, _ = fictitious_play(
                payoffs=payoff_table[:, :k],
                iters=1000,
                symmetric=False,
            )
            # averages == [row averages, column averages]; take the latest
            # column-player average, which has one weight per agent 0..k-1.
            meta_nash = averages[-1][:, -1]

        if dpp:
            # Cardinality of pop[:k] UNION the learner, with payoffs as
            # features. Kept differentiable so it can enter the loss.
            M = f.normalize(M, dim=1, p=2)
            L = M @ M.t()
            identity = torch.eye(L.shape[0])
            L_card = torch.trace(identity - torch.inverse(L + identity))

            exp_payoff = torch_pop.get_payoff_aggregate(learner, meta_nash, k)
            loss = -(lambda_weight * exp_payoff + (1.0 - lambda_weight) * L_card)
        else:
            # Cardinality is only reported here, so no gradient is needed.
            with torch.no_grad():
                if not symmetric:
                    M = torch_pop.get_metagame(k=k + 1, sub=True)
                M = f.normalize(M, dim=1, p=2)
                L = M @ M.t()
                identity = torch.eye(L.shape[0])
                L_card = torch.trace(identity - torch.inverse(L + identity))

            exp_payoff = torch_pop.get_payoff_aggregate(learner, meta_nash, k)
            loss = -(lambda_weight * exp_payoff)

        # Optimise !
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()

    torch_pop.pop_hist[k].append(learner.x.detach().cpu().clone().numpy())

    # Make strategy k non-trainable
    learner.x.requires_grad = False
    return exp_payoff.item(), L_card.item()


def mixed_oracle_loss(
    torch_pop,
    k,
    train_iters=10,
    lambda_weight=0.1,
    lr=0.1,
):
    """Jointly train oracle ``br_pop[k]`` and population member ``pop[k]``.

    Each step:

    * ``br_pop[k]`` is pushed to maximise ``get_payoff_aggregate`` (scaled by
      ``lambda_weight``);
    * ``pop[k]`` is pushed towards the meta-strategy-weighted mixture of the
      first ``k`` oracles through a cross-entropy style loss.

    Args:
        torch_pop: population created with ``mixed_oracles=True``.
        k: index of the oracle / agent pair to update.
        train_iters: number of optimisation steps (must be at least 1).
        lambda_weight: weight on the oracle's payoff term.
        lr: Adam learning rate for both optimisers.

    Returns:
        ``(expected payoff, cardinality)`` from the last step, as floats.
    """
    oracle = torch_pop.br_pop[k]
    agent = torch_pop.pop[k]

    # Make strategy k trainable
    oracle.x.requires_grad = True
    agent.x.requires_grad = True

    optimiser = optim.Adam(oracle.parameters(), lr=lr)
    optimiser2 = optim.Adam(agent.parameters(), lr=lr)

    for _ in range(train_iters):
        # Get metagame and metastrat
        M = torch_pop.get_metagame(k=k + 1)
        averages, _ = fictitious_play(
            payoffs=M.detach().cpu().clone().numpy()[:k, :k],
            iters=1000,
            symmetric=True,
        )
        # fictitious_play returns (averages, exps). The mixture is the last row
        # of `averages`, not the exploitability list.
        meta_nash = averages[-1]

        # Cardinality is only reported, so it is computed without gradients.
        with torch.no_grad():
            M = f.normalize(M, dim=1, p=2)
            L = M @ M.t()
            identity = torch.eye(L.shape[0])
            L_card = torch.trace(identity - torch.inverse(L + identity))

        exp_payoff = torch_pop.get_payoff_aggregate(oracle, meta_nash, k)
        loss1 = -(lambda_weight * exp_payoff)

        # Distil the (detached) oracle mixture into agent k; the clip keeps
        # log() away from zero.
        oracle_pop = torch_pop.mixed_brs(meta_nash, k).detach()
        loss2 = -torch.sum(torch.log(torch.clip(agent(), 1e-6)) * oracle_pop)

        loss = loss1 + loss2

        # Optimise !
        optimiser.zero_grad()
        optimiser2.zero_grad()
        loss.backward()
        optimiser.step()
        optimiser2.step()

    torch_pop.pop_hist[k].append(agent.x.detach().cpu().clone().numpy())

    # Make strategy k non-trainable
    agent.x.requires_grad = False
    oracle.x.requires_grad = False
    return exp_payoff.item(), L_card.item()


def psro_gradient(iters=5, num_learners=4, lr=0.2, train_iters=10, dpp=True, seed=0):
    """Run (pipeline / diverse) PSRO with gradient-trained learners.

    Args:
        iters: number of outer iterations; one agent is added per iteration.
        num_learners: number of learners updated in each iteration.
        lr: learning rate for learner updates and exploitability estimation.
        train_iters: optimisation steps per learner update.
        dpp: forwarded to ``gradient_loss_update`` (diversity term on/off).
        seed: seed for the population.

    Returns:
        ``(population, exploitabilities, cardinalities)``. The trajectory plot
        is saved under ``PATH_RESULTS``.
    """
    population = TorchPop(num_learners, seed=seed)

    def current_exploitability():
        # Solve the empirical game, then measure how exploitable the solution is.
        table = population.get_metagame(numpy=True)
        averages, _ = fictitious_play(payoffs=table, iters=1000)
        return population.get_exploitability(averages[-1], lr, nb_iters=train_iters)

    exps = [current_exploitability()]
    L_cards = []
    L_card = 0.0

    for i in range(iters):
        # Payoff weight: starts close to 1 and decays along a sigmoid centred
        # on iteration 25 (the remainder goes to diversity).
        lambda_weight = 1.0 - 0.7 / (1 + np.exp(-0.25 * (i - 25)))

        for j in range(num_learners):
            # j = 0 trains the newest agent; larger j reaches further back.
            k = torch_pop_index = population.pop_size - j - 1
            _, L_card = gradient_loss_update(
                population,
                torch_pop_index,
                train_iters=train_iters,
                lr=lr,
                lambda_weight=lambda_weight,
                dpp=dpp,
            )
            # Grow the population once the last learner has been updated.
            if j == num_learners - 1:
                population.add_new()

        exps.append(current_exploitability())
        L_cards.append(L_card)

        print(
            "ITERATION: ",
            i,
            " exp full: {:.4f}".format(exps[-1]),
            "L_CARD: {:.3f}".format(L_cards[-1]),
            "lw: {:.3f}".format(lambda_weight),
        )

    _, axis = plt.subplots(1, 1)
    population.visualise_pop(br=None, ax=axis)

    if num_learners == 1:
        fstr = "psro"
    elif dpp:
        fstr = "dppLoss_"
    else:
        fstr = "origLoss"
    plt.savefig(os.path.join(PATH_RESULTS, "trajectories_" + fstr + ".pdf"))

    return population, exps, L_cards


def psro_efficient_gradient(iters=5, num_learners=4, lr=0.2, train_iters=10, seed=0):
    """Run the asymmetric ("efficient") PSRO variant on an async-mode population.

    The metagame is rectangular (one-hot strategies vs. population) and the
    opponent mixture is the column player's latest fictitious-play average.

    Args:
        iters: number of outer iterations; one agent is added per iteration.
        num_learners: number of learners updated in each iteration.
        lr: learning rate used for exploitability estimation.
        train_iters: optimisation steps used for exploitability estimation.
        seed: seed for the population.

    Returns:
        ``(population, exploitabilities, cardinalities)``. The trajectory plot
        is saved under ``PATH_RESULTS``.
    """
    population = TorchPop(num_learners, seed=seed, async_mode=True)

    def current_exploitability():
        table = population.get_metagame(numpy=True)
        averages, _ = fictitious_play(payoffs=table, iters=1000, symmetric=False)
        # averages[-1] holds the column-player averages; take the latest one.
        column_mix = averages[-1][:, -1]
        return population.get_exploitability(column_mix, lr, nb_iters=train_iters)

    exps = [current_exploitability()]
    L_cards = []
    L_card = 0.0

    for i in range(iters):
        # Payoff weight: decays along a sigmoid centred on iteration 25.
        lambda_weight = 1.0 - 0.7 / (1 + np.exp(-0.25 * (i - 25)))

        for j in range(num_learners):
            # j = 0 trains the newest agent; larger j reaches further back.
            target = population.pop_size - j - 1
            # NOTE(review): the learner update uses fixed train_iters=10 and
            # lr=0.1 rather than this function's arguments - confirm intent.
            _, L_card = gradient_loss_update(
                population,
                target,
                train_iters=10,
                lr=0.1,
                lambda_weight=lambda_weight,
                dpp=False,
                symmetric=False,
            )
            if j == num_learners - 1:
                population.add_new()

        exps.append(current_exploitability())
        L_cards.append(L_card)

        print(
            "ITERATION: ",
            i,
            " exp full: {:.4f}".format(exps[-1]),
            "L_CARD: {:.3f}".format(L_cards[-1]),
            "lw: {:.3f}".format(lambda_weight),
        )

    _, axis = plt.subplots(1, 1)
    population.visualise_pop(br=None, ax=axis)

    fstr = "epsro" if num_learners == 1 else "pepsro"
    plt.savefig(os.path.join(PATH_RESULTS, "trajectories_" + fstr + ".pdf"))

    return population, exps, L_cards


def iterative_gradient(
    iters=5, num_learners=1, lr=0.2, train_iters=10, dpp=True, seed=0
):
    """Run the mixed-oracle ("iterative") variant.

    Args:
        iters: number of outer iterations; one agent is added per iteration.
        num_learners: number of learners updated in each iteration.
        lr: learning rate for learner updates and exploitability estimation.
        train_iters: optimisation steps per learner update.
        dpp: accepted for signature parity with ``psro_gradient``; not used.
        seed: seed for the population.

    Returns:
        ``(population, exploitabilities, cardinalities)``. The trajectory plot
        is saved under ``PATH_RESULTS``.
    """
    population = TorchPop(num_learners, seed=seed, mixed_oracles=True)

    def current_exploitability():
        table = population.get_metagame(numpy=True)
        averages, _ = fictitious_play(payoffs=table, iters=1000)
        return population.get_exploitability(averages[-1], lr, nb_iters=train_iters)

    exps = [current_exploitability()]
    L_cards = []
    L_card = 0.0

    for i in range(iters):
        # Payoff weight: decays along a sigmoid centred on iteration 25.
        lambda_weight = 1.0 - 0.7 / (1 + np.exp(-0.25 * (i - 25)))

        for j in range(num_learners):
            # j = 0 trains the newest agent; larger j reaches further back.
            target = population.pop_size - j - 1
            _, L_card = mixed_oracle_loss(
                population,
                target,
                train_iters=train_iters,
                lr=lr,
                lambda_weight=lambda_weight,
            )
            if j == num_learners - 1:
                population.add_new()

        exps.append(current_exploitability())
        L_cards.append(L_card)

        print(
            "ITERATION: ",
            i,
            " exp full: {:.4f}".format(exps[-1]),
            "L_CARD: {:.3f}".format(L_cards[-1]),
            "lw: {:.3f}".format(lambda_weight),
        )

    _, axis = plt.subplots(1, 1)
    population.visualise_pop(br=None, ax=axis)

    plt.savefig(os.path.join(PATH_RESULTS, "trajectories_" + "iterative" + ".pdf"))

    return population, exps, L_cards


def gradient_loss_update_rectified(torch_pop, k, weights, train_iters=10, lr=0.1):
    """Train agent ``k`` against a fixed ``weights`` mixture of the population.

    Args:
        torch_pop: the population being trained.
        k: index of the agent to update.
        weights: mixture weights handed to ``get_payoff_aggregate_weights``.
        train_iters: number of Adam steps.
        lr: Adam learning rate.

    Returns:
        ``(expected payoff, cardinality)`` from the last step, as floats.
    """
    learner = torch_pop.pop[k]
    learner.x.requires_grad = True  # unfreeze for the duration of training
    optimiser = optim.Adam(learner.parameters(), lr=lr)

    for _ in range(train_iters):
        metagame = torch_pop.get_metagame(k=k + 1)

        # Cardinality of agents 0..k with payoffs as features; reporting only,
        # so gradients are not tracked.
        with torch.no_grad():
            features = f.normalize(metagame, dim=1, p=2)
            kernel = features @ features.t()
            L_card = torch.trace(
                torch.eye(kernel.shape[0])
                - torch.inverse(kernel + torch.eye(kernel.shape[0]))
            )

        exp_payoff = torch_pop.get_payoff_aggregate_weights(learner, weights, k)
        objective = -exp_payoff

        optimiser.zero_grad()
        objective.backward()
        optimiser.step()

    torch_pop.pop_hist[k].append(learner.x.detach().cpu().clone().numpy())

    learner.x.requires_grad = False  # freeze again
    return exp_payoff.item(), L_card.item()


# PSRO with a rectified-Nash meta-solver
def psro_rectified_gradient(
    iters=10, eps=1e-2, seed=0, train_iters=10, num_pseudo_learners=4, lr=0.3
):
    """Run PSRO-rectified-Nash with gradient-trained learners.

    In each round every agent whose meta-strategy mass exceeds ``eps`` spawns
    a new learner, trained against the re-normalised meta-strategy restricted
    to the opponents it does not lose to.

    Args:
        iters: iteration budget; also passed as the fictitious-play iteration
            count for the per-round solve.
        eps: minimum meta-strategy mass for an agent to spawn a learner.
        seed: seed for the population.
        train_iters: optimisation steps per learner.
        num_pseudo_learners: learners counted as one iteration.
        lr: learning rate.

    Returns:
        ``(population, exploitabilities, cardinalities)``. The trajectory plot
        is saved under ``PATH_RESULTS``.
    """

    def solve(population):
        table = population.get_metagame(numpy=True)
        return fictitious_play(payoffs=table, iters=1000)[0][-1]

    def save_plot(population):
        _, axis = plt.subplots(1, 1)
        population.visualise_pop(br=None, ax=axis)
        plt.savefig(os.path.join(PATH_RESULTS, "trajectories_rectified.pdf"))

    torch_pop = TorchPop(num_pseudo_learners, seed=seed)

    exps = [torch_pop.get_exploitability(solve(torch_pop), lr, nb_iters=train_iters)]
    L_cards = []

    budget = iters * num_pseudo_learners
    counter = 0
    while counter < budget:
        if counter % (5 * num_pseudo_learners) == 0:
            print("iteration: ", int(counter / num_pseudo_learners), " exp: ", exps[-1])
            print("size of population: ", torch_pop.pop_size)

        # New learners go into a copy; the current population stays the
        # reference for this round's meta-strategy.
        new_pop = copy.deepcopy(torch_pop)
        emp_game_matrix = torch_pop.get_metagame(numpy=True)
        averages, _ = fictitious_play(payoffs=emp_game_matrix, iters=iters)
        meta_mass = averages[-1]

        for j in range(torch_pop.pop_size):
            if counter > budget:
                save_plot(torch_pop)
                return torch_pop, exps, L_cards

            # Only agents with positive meta-strategy mass spawn a learner.
            if not meta_mass[j] > eps:
                continue

            new_pop.add_new()
            idx = new_pop.pop_size - 1
            counter += 1
            print(counter)

            # Binary mask over the opponents agent j does not lose to. The row
            # is a view, so it is overwritten in place.
            mask = emp_game_matrix[j, :]
            mask += 1e-5
            mask[mask >= 0] = 1
            mask[mask < 0] = 0
            weights = np.multiply(mask, meta_mass)
            weights /= weights.sum()

            _, L_card = gradient_loss_update_rectified(
                new_pop, idx, weights, train_iters=train_iters, lr=lr
            )

            if counter % num_pseudo_learners == 0:
                exps.append(
                    new_pop.get_exploitability(solve(new_pop), lr, nb_iters=train_iters)
                )
                L_cards.append(L_card)

        torch_pop = copy.deepcopy(new_pop)

    save_plot(torch_pop)
    return torch_pop, exps, L_cards


def get_br_to_strat(strat, payoffs=None, verbose=False, symmetric=True):
    if symmetric:
        row_weighted_payouts = strat @ payoffs
        br = np.zeros_like(row_weighted_payouts)
        br[np.argmin(row_weighted_payouts)] = 1
        if verbose:
            print(
                row_weighted_payouts[np.argmin(row_weighted_payouts)], "exploitability"
            )
    else:
        # print("ssfffe", strat[0].shape, strat[1].shape, payoffs.shape)
        row_weighted_payouts, column_weighted_payouts = (
            strat[0] @ payoffs,
            -payoffs @ strat[1],
        )
        rbr = np.zeros_like(row_weighted_payouts)
        cbr = np.zeros_like(column_weighted_payouts)
        rbr[np.argmin(row_weighted_payouts)] = 1
        cbr[np.argmin(column_weighted_payouts)] = 1
        br = [rbr, cbr]
        if verbose:
            print(
                row_weighted_payouts[np.argmin(row_weighted_payouts)]
                + column_weighted_payouts[np.argmin(column_weighted_payouts)],
                "exploitability",
            )
    return br


def _fp_symmetric(iters, payoffs):
    """Single-population fictitious play; returns ``(averages, exps)``."""
    exps = []
    n_strats = payoffs.shape[0]
    played = np.random.uniform(0, 1, (1, n_strats))
    played = played / played.sum(axis=1)[:, None]
    averages = played
    for _ in range(iters):
        mean_strat = np.average(played, axis=0)
        br = get_br_to_strat(mean_strat, payoffs=payoffs)
        value_of_mean = mean_strat @ payoffs @ br.T
        value_of_br = br @ payoffs @ mean_strat.T
        exps.append(value_of_br - value_of_mean)
        averages = np.vstack((averages, mean_strat))
        played = np.vstack((played, br))
    return averages, exps


def _fp_asymmetric(iters, payoffs, verbose):
    """Two-population fictitious play; returns ``([row_avgs, col_avgs], exps)``."""
    exps = []
    rdim, cdim = payoffs.shape
    # Row histories are stacked as rows, column histories as columns.
    rpop = np.random.uniform(0, 1, (1, rdim))
    cpop = np.random.uniform(0, 1, (cdim, 1))
    rpop = rpop / rpop.sum(axis=1)[:, None]
    cpop = cpop / cpop.sum(axis=0)[None, :]
    averages = [rpop, cpop]
    for _ in range(iters):
        row_mean = np.average(rpop, axis=0)
        col_mean = np.average(cpop, axis=1)
        br = get_br_to_strat([row_mean, col_mean], payoffs=payoffs, symmetric=False)
        exp1 = row_mean @ payoffs @ br[0]
        exp2 = br[1] @ payoffs @ col_mean
        if verbose:
            print(exp1, exp2, "exploitability")
        exps.append(exp2 - exp1)
        averages = [
            np.vstack((averages[0], row_mean)),
            np.hstack((averages[1], col_mean.reshape(-1, 1))),
        ]
        # br[0] has one entry per column, br[1] one entry per row.
        cpop = np.hstack((cpop, br[0].reshape(-1, 1)))
        rpop = np.vstack((rpop, br[1]))
    return averages, exps


# Fictitious play as a Nash equilibrium solver
def fictitious_play(iters=2000, payoffs=None, verbose=False, symmetric=True):
    """Approximate an equilibrium of ``payoffs`` by fictitious play.

    Starts from a random mixed strategy and, for ``iters`` rounds, appends the
    best response to the running average of everything played so far.

    Args:
        iters: number of fictitious-play rounds.
        payoffs: payoff matrix (square in the symmetric case).
        verbose: print per-round values (asymmetric case only).
        symmetric: one shared population, or separate row/column populations.

    Returns:
        ``(averages, exps)``. In the symmetric case ``averages`` is an array
        with one row per round (the initial strategy first). Otherwise it is
        ``[row_averages, column_averages]``, with row averages stacked as rows
        and column averages stacked as columns. ``exps`` holds one value
        difference per round.
    """
    if symmetric:
        return _fp_symmetric(iters, payoffs)
    return _fp_asymmetric(iters, payoffs, verbose)


# def fictitious_play(iters=2000, payoffs=payoffs, verbose=False):
#     dim = payoffs.shape[0]
#     pop = np.random.uniform(0, 1, (1, dim))
#     pop = pop / pop.sum(axis=1)[:, None]
#     averages = pop
#     exps = []
#     for i in range(iters):
#         average = np.average(pop, axis=0)
#         br = get_br_to_strat(average, payoffs=payoffs)
#         exp1 = average @ payoffs @ br.T
#         exp2 = br @ payoffs @ average.T
#         exps.append(exp2 - exp1)
#         # if verbose:
#         #     print(exp, "exploitability")
#         averages = np.vstack((averages, average))
#         pop = np.vstack((pop, br))
#     return averages, exps


def run_experiments(
    num_experiments=1,
    num_threads=20,
    iters=40,
    rectified=False,
    psro=False,
    pipeline_psro=False,
    dpp_psro=False,
    epsro=False,
    pepsro=False,
    iterative=False,
    yscale="none",
    verbose=False,
    train_iters=10,
):
    """Run the enabled solvers repeatedly, checkpoint their traces and plot them.

    For each experiment index ``i`` in ``range(num_experiments)`` every enabled
    solver is called with ``seed=i``. Each call returns
    ``(population, exps, L_cards)``; the population is pickled to
    ``PATH_RESULTS/<FILE_TRAJ[key]>.p`` (overwritten on every experiment) and
    the two traces are accumulated per method. After each experiment a
    ``checkpoint_<i>`` pickle holding all traces collected so far is written.

    Finally two figures are saved to ``PATH_RESULTS``: ``figure_0.pdf`` for the
    ``exps`` traces and ``figure_1.pdf`` for the ``L_cards`` traces, each
    showing the mean over experiments with a standard-error band.

    Args:
        num_experiments: number of independent repetitions (seeds 0..n-1).
        num_threads: ``num_learners`` handed to the pipeline / DPP / PEPSRO /
            mixed-oracle solvers; PSRO and EPSRO always use one learner.
        iters: ``iters`` argument forwarded to every solver.
        rectified, psro, pipeline_psro, dpp_psro, epsro, pepsro, iterative:
            switches selecting which solvers are run and plotted.
        yscale: ``"log"`` puts both figures on a log y-axis, ``"both"`` only
            the first one; any other value leaves the axes linear.
        verbose: accepted for interface compatibility; not used here.
        train_iters: ``train_iters`` argument forwarded to every solver.
    """

    def save_pickle(obj, filename):
        # Context manager so the handle is closed even if pickling fails.
        with open(os.path.join(PATH_RESULTS, filename), "wb") as fh:
            pickle.dump(obj, fh)

    # One entry per method: enabled flag, console banner, FILE_TRAJ key,
    # legend label and a callable running the solver for a given seed.
    # Keeping everything in one table avoids the copy/paste mix-ups between
    # methods (e.g. writing or plotting one method's data under another name).
    methods = {
        "rectified": (
            rectified,
            "Rectified",
            "rectified",
            "PSRO-rN",
            lambda seed: psro_rectified_gradient(
                iters=iters,
                seed=seed,
                train_iters=train_iters,
                num_pseudo_learners=1,
                lr=LR,
            ),
        ),
        "dpp": (
            dpp_psro,
            "Grad DPP",
            "dpp",
            "Ours (DPP Loss)",
            lambda seed: psro_gradient(
                iters=iters,
                num_learners=num_threads,
                lr=LR,
                train_iters=train_iters,
                seed=seed,
                dpp=True,
            ),
        ),
        "pipeline": (
            pipeline_psro,
            "Grad no DPP Pipeline PSRO",
            "ppsro",
            "P-PSRO",
            lambda seed: psro_gradient(
                iters=iters,
                num_learners=num_threads,
                lr=LR,
                train_iters=train_iters,
                seed=seed,
                dpp=False,
            ),
        ),
        "psro": (
            psro,
            "PSRO no DPP",
            "psro",
            "PSRO",
            lambda seed: psro_gradient(
                iters=iters,
                num_learners=1,
                lr=LR,
                train_iters=train_iters,
                seed=seed,
                dpp=False,
            ),
        ),
        "epsro": (
            epsro,
            "EPSRO",
            "epsro",
            "EPSRO",
            lambda seed: psro_efficient_gradient(
                iters=iters,
                num_learners=1,
                lr=LR,
                train_iters=train_iters,
                seed=seed,
            ),
        ),
        "pepsro": (
            pepsro,
            "PEPSRO",
            "pepsro",
            "PEPSRO",
            lambda seed: psro_efficient_gradient(
                iters=iters,
                num_learners=num_threads,
                lr=LR,
                train_iters=train_iters,
                seed=seed,
            ),
        ),
        "iterative": (
            iterative,
            "Mixed-Oracles",
            "iterative",
            "Mixed-Oracle",
            lambda seed: iterative_gradient(
                iters=iters,
                num_learners=num_threads,
                lr=LR,
                train_iters=train_iters,
                seed=seed,
            ),
        ),
    }

    # The three orders below are kept separate on purpose: they reproduce the
    # execution order, the checkpoint key order and the legend/colour order.
    run_order = ("rectified", "dpp", "pipeline", "psro", "epsro", "pepsro", "iterative")
    checkpoint_order = (
        "rectified",
        "pipeline",
        "dpp",
        "psro",
        "epsro",
        "pepsro",
        "iterative",
    )
    plot_order = ("rectified", "psro", "pipeline", "dpp", "epsro", "pepsro", "iterative")

    exps_by_method = {name: [] for name in checkpoint_order}
    cards_by_method = {name: [] for name in checkpoint_order}

    for i in range(num_experiments):
        print("Experiment: ", i + 1)

        for name in run_order:
            enabled, banner, traj_key, _label, solve = methods[name]
            if not enabled:
                continue
            print(banner)
            torch_pop, exps, L_cards = solve(i)
            exps_by_method[name].append(exps)
            cards_by_method[name].append(L_cards)
            # FILE_TRAJ values already end in ".p"; the extra suffix is kept
            # because run_traj() loads the populations under the same name.
            save_pickle({"pop": torch_pop}, FILE_TRAJ[traj_key] + ".p")

        # Checkpoint every trace gathered so far (disabled methods stay empty).
        d = {}
        for name in checkpoint_order:
            d[name + "_exps"] = exps_by_method[name]
            d[name + "_cardinality"] = cards_by_method[name]
        save_pickle(d, "checkpoint_" + str(i))

    alpha = 0.4

    def plot_error(data, label=""):
        # Mean curve across experiments with a +/- standard-error band.
        stacked = np.array(data)
        avg = np.mean(stacked, axis=0)
        error_bars = stats.sem(stacked)
        plt.plot(avg, label=label)
        plt.fill_between(
            np.arange(avg.shape[0]),
            avg - error_bars,
            avg + error_bars,
            alpha=alpha,
        )

    # Figure 0 shows the `exps` traces, figure 1 the `L_cards` traces.
    for j, traces in enumerate((exps_by_method, cards_by_method)):
        plt.figure()
        for name in plot_order:
            enabled, _banner, _traj_key, label, _solve = methods[name]
            if not enabled:
                continue
            runs = traces[name]
            if name == "rectified":
                # Only the rectified traces are cut to a common length before
                # stacking, as in the original code (presumably because that
                # solver can stop early -- confirm in psro_rectified_gradient).
                length = min(len(run) for run in runs)
                runs[:] = [run[:length] for run in runs]
            plot_error(runs, label=label)

        plt.legend(loc="upper left")

        if yscale == "log" or (yscale == "both" and j == 0):
            plt.yscale("log")

        plt.savefig(os.path.join(PATH_RESULTS, "figure_" + str(j) + ".pdf"))


def run_traj(
    rectified=False,
    dpp=False,
    ppsro=False,
    psro=False,
    epsro=False,
    pepsro=False,
    iterative=False,
):
    """Draw the saved population of every enabled method side by side.

    For each enabled method (visited in ``FILE_TRAJ`` order) the pickle
    ``PATH_RESULTS/<FILE_TRAJ[key]>.p`` is loaded, its ``"pop"`` entry is asked
    to draw itself via ``visualise_pop(ax=..., color=...)`` on its own subplot,
    and the combined figure is saved as ``PATH_RESULTS/trajectories.pdf``.

    Args:
        rectified, dpp, ppsro, psro, epsro, pepsro, iterative: switches
            selecting which methods get a panel. Names match ``FILE_TRAJ`` keys.

    Returns:
        None. When no method is enabled nothing is drawn or written.
    """
    # Explicit mapping instead of locals(): the snapshot returned by locals()
    # can pick up unrelated local names (e.g. under a tracer/debugger).
    enabled = {
        "rectified": rectified,
        "dpp": dpp,
        "ppsro": ppsro,
        "psro": psro,
        "epsro": epsro,
        "pepsro": pepsro,
        "iterative": iterative,
    }

    n_enabled = sum(1 for flag in enabled.values() if flag)
    if n_enabled == 0:
        # plt.subplots(1, 0) would raise; there is simply nothing to draw.
        return

    titles = {
        "rectified": "PSRO-rN",
        "dpp": "DPP-PSRO",
        "ppsro": "P-PSRO",
        "psro": "PSRO",
        "epsro": "EPSRO",
        "pepsro": "PEPSRO",
        "iterative": "Mixed-Oracles",
    }
    # One colour per possible panel; the first four match the original palette.
    palette = [
        "tab:blue",
        "tab:orange",
        "tab:green",
        "tab:red",
        "tab:purple",
        "tab:brown",
        "tab:pink",
    ]

    # squeeze=False always yields a 2-D array of Axes, so flatten() also works
    # when a single method is enabled.
    fig1, axs1 = plt.subplots(
        1,
        n_enabled,
        figsize=(n_enabled * n_enabled, n_enabled * 1),
        dpi=200,
        squeeze=False,
    )
    axs1 = axs1.flatten()

    panel = 0
    for key in FILE_TRAJ:
        if not enabled.get(key, False):
            continue
        ax = axs1[panel]
        # The doubled ".p" suffix mirrors how run_experiments() names the files.
        # NOTE: pickle executes arbitrary code on load -- only read result
        # files produced by this project.
        with open(os.path.join(PATH_RESULTS, FILE_TRAJ[key]) + ".p", "rb") as fh:
            d = pickle.load(fh)
        d["pop"].visualise_pop(ax=ax, color=palette[panel % len(palette)])
        ax.set_title(titles[key])
        panel += 1

    fig1.tight_layout()
    fig1.savefig(os.path.join(PATH_RESULTS, "trajectories.pdf"))


# PATH_RESULTS = os.path.join(BASE_DIR, "results", "selected_draw")
if __name__ == "__main__":
    # Switches for the methods taking part in this run. The keys are the
    # keyword names run_traj() expects, so the mapping can be forwarded as-is.
    selected = {
        "ppsro": False,  # True
        "rectified": True,
        "psro": False,  # True  # True
        "epsro": False,  # True  # True
        "pepsro": False,  # True
        "iterative": True,
    }

    # Train/evaluate the selected solvers and write checkpoints + figures.
    # run_experiments() names the pipeline switch `pipeline_psro`; the DPP
    # variant is left disabled.
    run_experiments(
        num_experiments=10,
        num_threads=4,
        iters=50,
        rectified=selected["rectified"],
        psro=selected["psro"],
        pipeline_psro=selected["ppsro"],
        dpp_psro=False,
        epsro=selected["epsro"],
        pepsro=selected["pepsro"],
        iterative=selected["iterative"],
        yscale="none",
        train_iters=TRAIN_ITERS,
    )

    # Visualise the populations saved by the run above.
    run_traj(**selected)

    # plt.show()
