import argparse
import time
import os
from black import logging

import ray
import numpy as np
import torch
import nashpy as nash

import pprint

from collections import namedtuple

from torch.optim import SGD
from ray.util import ActorPool

from expground.settings import LOG_DIR
from expground.types import List, Tuple
from expground.logger import Log
from expground.envs.matrix.env import payoff_matrix_generate


# Result record of ``nash_conv``: total NashConv plus each player's best-response term.
NashConv = namedtuple(
    "NashConv", ["nash_conv", "player1_improvement", "player2_improvement"]
)


def nash_conv(urr_matrix, eqs) -> NashConv:
    """Return the NashConv of the strategy profile ``eqs``.

    For each player this takes the best pure-strategy payoff against the
    opponent's mixed strategy. The payoff of the current profile is *not*
    subtracted. The total therefore equals the usual NashConv only for
    zero-sum games, where the two profile payoffs cancel. The per-player
    fields hold best-response values rather than improvements.

    Args:
        urr_matrix: Pair ``(A, B)`` of row / column player payoff matrices.
        eqs: Pair ``(x, y)`` of row / column mixed strategies.

    Returns:
        ``NashConv(total, row_best_response_value, column_best_response_value)``.
    """
    row_payoffs, col_payoffs = urr_matrix[0], urr_matrix[1]
    row_strategy, col_strategy = eqs[0], eqs[1]
    # Best payoff the row player can get by deviating to a pure row.
    best_row = np.max(row_payoffs.dot(col_strategy))
    # Best payoff the column player can get by deviating to a pure column.
    best_col = np.max(row_strategy.dot(col_payoffs))
    return NashConv(best_row + best_col, best_row, best_col)


def smooth(learning_curves: List[np.ndarray], ratio: float):
    """Exponentially smooth every curve in ``learning_curves`` in place.

    Each output point is ``ratio * value + previous * (1 - ratio)``, with the
    running value starting at 0. A ``ratio`` of 1.0 therefore leaves values
    unchanged.

    Args:
        learning_curves: Mutable sequence of curves; each entry is replaced by
            a plain list holding its smoothed version.
        ratio: Weight given to the newest value.

    Returns:
        The same ``learning_curves`` object, after replacement.
    """

    def _ema(curve):
        out, prev = [], 0
        for point in curve:
            prev = ratio * point + prev * (1.0 - ratio)
            out.append(prev)
        return out

    for idx, curve in enumerate(learning_curves):
        learning_curves[idx] = _ema(curve)
    return learning_curves


# Equilibrium oracle: approximate Nash equilibrium via nashpy fictitious play.
def oracle(matrix, n=10000):
    """Approximate an equilibrium of the two-player game ``matrix``.

    Args:
        matrix: Pair ``(A, B)`` of row / column player payoff matrices.
        n: Number of fictitious-play iterations passed to nashpy.

    Returns:
        ``(eq, (x, y))``.
        - ``eq`` is a tuple holding the last iterate of fictitious play for each
          player, normalised to sum to one.
        - ``x`` / ``y`` are the row / column player's expected payoffs under
          ``eq``. Each is a shape-(1,) array when the strategies are 1-D, which
          is why callers index them with ``[0]``.
    """
    game = nash.Game(*matrix)
    # Only the final iterate is kept.
    *_, final_counts = game.fictitious_play(iterations=n)
    eq = tuple(counts / np.sum(counts) for counts in final_counts)
    row = eq[0].reshape(1, -1)
    x = row.dot(matrix[0]).dot(eq[1])
    y = row.dot(matrix[1]).dot(eq[1])
    return eq, (x, y)


def meta_policy(logits, to_numpy: bool = False):
    """Turn per-player logit tensors into softmax distributions (last dim).

    Args:
        logits: Iterable of torch tensors, one per player.
        to_numpy: If True, return detached numpy arrays instead of tensors.

    Returns:
        List of probability tensors, or numpy arrays when ``to_numpy`` is set.
    """
    probs = list(map(lambda t: torch.softmax(t, dim=-1), logits))
    if not to_numpy:
        return probs
    # Detach first so tensors that require grad can be converted.
    return [p.detach().numpy() for p in probs]


def look_ahead(policy, old_policy, optimizer, matrix, dim):
    """Take one optimizer step per player towards its best pure response.

    For each player, the expected payoff of every own pure action is computed
    against the opponent's ``old_policy``, with gradients detached. The loss is
    minus the policy mass on the argmax action times that action's payoff.

    Args:
        policy: Pair of current policy tensors, differentiable w.r.t. the
            parameters held by ``optimizer``.
        old_policy: Pair of opponent policies used as fixed targets.
        optimizer: Pair of optimizers; ``optimizer[i]`` steps player ``i``.
        matrix: Pair of payoff tables ``[row_payoffs, column_payoffs]``.
        dim: Length of the one-hot indicator. It presumably equals each
            player's action count, which is confirmed only when the tables are
            square.
    """
    for player, raw_table in enumerate(matrix):
        if player == 0:
            # Row player's payoff for each row action vs. the column policy.
            values = torch.matmul(
                torch.DoubleTensor(raw_table), old_policy[1].view(-1)
            ).detach()
            current = policy[0]
        else:
            # Column player's payoff for each column action vs. the row policy.
            values = torch.matmul(
                old_policy[0].view(-1), torch.DoubleTensor(raw_table)
            ).detach()
            current = policy[1]
        indicator = torch.zeros(dim, dtype=torch.double)
        indicator.scatter_(0, torch.argmax(values), 1.0)
        loss = -torch.sum(current * indicator * values)
        optimizer[player].zero_grad()
        loss.backward()
        optimizer[player].step()


def optimize(logits, matrix, dims, optimizer, br_player):
    """Apply one optimizer step per player on softmax meta-policy logits.

    Each player maximises its expected payoff against the opponent's pure
    strategy that is currently worst for it. That strategy is the argmin over
    the player's own payoff table given its current policy, so each step is a
    gradient step on a max-min objective.

    Args:
        logits: Pair of logit tensors (row player, column player), updated in
            place by ``optimizer``.
        matrix: Pair of payoff tables ``[row_payoffs, column_payoffs]``.
            Presumably numpy arrays, since ``.copy()`` is called on them.
        dims: Unused.
        optimizer: Pair of optimizers; ``optimizer[i]`` owns ``logits[i]``.
        br_player: Unused.

    Returns:
        ``logits``: the same list object, after the in-place updates.
    """
    policies = meta_policy(logits)
    for player, raw_table in enumerate(matrix):
        payoff = torch.DoubleTensor(raw_table.copy())
        if player == 0:
            own = policies[0]
            # Row player's expected payoff against each column action.
            per_reply = torch.matmul(own, payoff).reshape(-1)
            worst = torch.argmin(per_reply)
            target = payoff[:, worst].reshape(-1)
        else:
            own = policies[1]
            # Column player's expected payoff against each row action.
            per_reply = torch.matmul(payoff, own).reshape(-1)
            worst = torch.argmin(per_reply)
            target = payoff[worst, :].reshape(-1)
        loss = -torch.sum(own * target)
        optimizer[player].zero_grad()
        loss.backward()
        optimizer[player].step()
    return logits


def init_support_sets(dim_a: int, dim_b: int, symmetric: bool):
    """Pick one random initial pure strategy per player.

    Args:
        dim_a (int): Policy space size of player a.
        dim_b (int): Policy space size of player b.
        symmetric (bool): If True, both players start from the same index. This
            requires ``dim_a == dim_b`` (checked with ``assert``) and draws a
            single random number.

    Returns:
        list: ``[[(idx_a, None)], [(idx_b, None)]]``, i.e. one support list per
        player, each holding a single ``(index, None)`` entry.
    """
    if not symmetric:
        first = np.random.choice(dim_a)
        second = np.random.choice(dim_b)
    else:
        assert dim_a == dim_b, (dim_a, dim_b)
        first = second = np.random.choice(dim_a)
    return [[(first, None)], [(second, None)]]


def approximate_br(args, urr_matrix, dims, br_player, uu_matrix):
    """Approximate ``br_player``'s strategy on a restricted game by gradient minimax.

    Both players' meta-policies over ``urr_matrix`` are softmax logits, randomly
    initialised and updated with ``args.n_episode`` calls to ``optimize``
    (SGD, learning rate ``args.lr``).

    Args:
        args: Namespace providing ``lr`` and ``n_episode``.
        urr_matrix: Restricted payoff pair ``[row_payoffs, column_payoffs]``.
        dims: Action counts ``(rows, cols)`` matching ``urr_matrix``.
        br_player: 0 (row player) or 1 (column player).
        uu_matrix: Unused; kept for call compatibility.

    Returns:
        ``(epsilon, policy)``.
        - ``epsilon`` is the best pure-strategy payoff ``br_player`` obtains
          against the opponent's learned meta-policy.
        - ``policy`` is the pair of learned meta-policies as numpy arrays.
    """
    meta_logits = [
        torch.tensor(np.random.random(dim), requires_grad=True) for dim in dims
    ]
    optimizer = [SGD([logit], lr=args.lr) for logit in meta_logits]
    for _ in range(args.n_episode):
        meta_logits = optimize(meta_logits, urr_matrix, dims, optimizer, br_player)

    # Evaluate once after training, so this is also defined when n_episode == 0.
    policy = meta_policy(meta_logits, to_numpy=True)
    if br_player == 0:
        # Best row (pure strategy) against the column player's meta-policy.
        epsilon = np.max(urr_matrix[0] @ policy[1])
    else:
        # Best column (pure strategy) against the row player's meta-policy.
        epsilon = np.max(policy[0] @ urr_matrix[1])
    return epsilon, policy


def exact_br(
    args,
    urr_matrix: List[np.ndarray],
    dims: List[int],
    br_player: int,
    uu_matrix: List[np.ndarray],
    real_eqs,
    ru1,
    ru2,
):
    """Solve the restricted game with fictitious play and score it.

    Args:
        args: Namespace providing ``n_episode`` (fictitious-play iterations).
        urr_matrix: Restricted payoff pair ``[row_payoffs, column_payoffs]``.
        dims: Unused.
        br_player: 0 (row player) or 1 (column player).
        uu_matrix: Unused.
        real_eqs: Equilibrium strategies of the full game.
        ru1, ru2: Row / column payoffs at ``real_eqs``, indexed with ``[0]``.

    Returns:
        ``(epsilon, eqs)``.
        - ``eqs`` is the restricted-game equilibrium.
        - ``epsilon`` is the opponent side's full-game equilibrium payoff minus
          its payoff when its restricted strategy meets ``br_player``'s
          full-game equilibrium strategy.
    """
    restricted_eq, _ = oracle(urr_matrix, n=args.n_episode)
    if br_player != 0:
        # Row side: restricted row strategy vs. full-game column strategy.
        gap = ru1[0] - restricted_eq[0].dot(urr_matrix[0]).dot(real_eqs[1])
    else:
        # Column side: full-game row strategy vs. restricted column strategy.
        gap = ru2[0] - real_eqs[0].dot(urr_matrix[1]).dot(restricted_eq[1])
    return gap, restricted_eq


@ray.remote
def train(
    args,
    dims: Tuple[int, int],
    matrix: List[np.ndarray],
    br_player: int,
    original_matrix: List[np.ndarray],
    use_approximate: bool = True,
    real_eqs=None,
    ru1=None,
    ru2=None,
) -> Tuple[List, List, float]:
    """Ray task: learn ``br_player``'s new mixed strategy on its restricted game.

    Args:
        args: Experiment namespace forwarded to the best-response routine.
        dims: ``(rows, cols)`` of the restricted game ``matrix``.
        matrix: Restricted payoff pair for ``br_player``.
        br_player: 0 (row player A) or 1 (column player B).
        original_matrix: Full-game payoff pair.
        use_approximate: Use ``approximate_br`` if True, else ``exact_br``.
        real_eqs, ru1, ru2: Full-game equilibrium data (used only by ``exact_br``).

    Returns:
        ``(br, u_vector, epsilon)``. Only index ``br_player`` of ``br`` and
        ``u_vector`` is filled.
        - ``br[br_player]`` is ``(support_indices, mass_vector)`` of the learned
          strategy.
        - ``u_vector[br_player]`` is ``br_player``'s full-game expected payoff
          against each pure strategy of the opponent. It stays ``None`` when
          no entry has positive mass.
    """
    if not use_approximate:
        epsilon, policy = exact_br(
            args, matrix, dims, br_player, original_matrix, real_eqs, ru1, ru2
        )
    else:
        epsilon, policy = approximate_br(args, matrix, dims, br_player, original_matrix)

    learned = policy[br_player]
    support = [idx for idx, mass in enumerate(learned) if mass > 0]
    masses = np.zeros(dims[br_player])
    for idx in support:
        masses[idx] = learned[idx]

    br = [None, None]
    u_vector = [None, None]
    br[br_player] = (support, masses)
    if support:
        payoff = original_matrix[br_player]
        # Row player's strategy multiplies from the left, column player's from the right.
        u_vector[br_player] = (
            masses.dot(payoff) if br_player == 0 else payoff.dot(masses)
        )
    return br, u_vector, epsilon


def load_game(f_path: str = None, game_type: int = None, rng=None, dim=None):
    """Load a two-player payoff pair from disk, or generate one.

    Args:
        f_path: Path to a pickle file holding the row player's payoff matrix.
            The column player's payoffs are its negation (zero-sum). Takes
            precedence over ``game_type``.
        game_type: Game type forwarded to ``payoff_matrix_generate`` when no
            file is given.
        rng: Random state forwarded to ``payoff_matrix_generate``.
        dim: Dimensions forwarded to ``payoff_matrix_generate``; ``learn``
            passes ``[dim_a, dim_b]``.

    Returns:
        A pair ``[row_payoffs, column_payoffs]``.

    Raises:
        ValueError: If neither ``f_path`` nor ``game_type`` is provided. The
            previous behaviour was an implicit ``None`` return, which failed
            later with an unrelated error.
    """
    if f_path is not None:
        import pickle

        # SECURITY: pickle.load can execute arbitrary code; only load payoff
        # files from trusted sources.
        with open(f_path, "rb") as f:
            data = pickle.load(f)
        return [data, -data]

    if game_type is not None:
        return payoff_matrix_generate(2, game_type, dim, rng)

    raise ValueError("load_game requires either `f_path` or `game_type`.")


@ray.remote
def learn(args, smooth_rate: float = 0.6, seed: int = None):
    """Run one support-growing experiment and return its smoothed curves.

    Each player starts with a single-strategy support. Every epoch then:
    1. Runs one ``train`` task per player in parallel via Ray.
    2. Appends each player's learned strategy to its support set.
    3. Extends the opponent's restricted payoff matrix with that strategy.

    Args:
        args: Namespace with ``dim_a``, ``dim_b``, ``symmetric_init``,
            ``payoff_config``, ``game_type``, ``max_support_size`` plus the
            fields used by ``train``. ``debug`` (default False) and
            ``use_approximate`` (default True, as in ``train``) are optional.
        smooth_rate: Weight of the newest value in the exponential smoothing of
            the curves; 1.0 leaves them unsmoothed.
        seed: Seed of the ``RandomState`` used to generate the game.

    Returns:
        ``[player1_curve, player2_curve, summed_curve]``: per-epoch epsilon
        values from ``train`` for each player and their sum, all smoothed with
        ``smooth_rate``.
    """
    debug = getattr(args, "debug", False)
    use_approximate = getattr(args, "use_approximate", True)
    Log.setLevel(logging.DEBUG if debug else logging.INFO)

    learning_curve = [[], []]
    support_sets = init_support_sets(args.dim_a, args.dim_b, args.symmetric_init)
    original_matrix = load_game(
        args.payoff_config,
        args.game_type,
        rng=np.random.RandomState(seed),
        dim=[args.dim_a, args.dim_b],
    )

    # Restricted games, seeded with the opponent's initial strategy.
    # Player A's game has one column per B strategy; player B's has one row
    # per A strategy.
    matrix_for_a = [x[:, support_sets[1][0][0]].reshape(-1, 1) for x in original_matrix]
    matrix_for_b = [x[support_sets[0][0][0], :].reshape(1, -1) for x in original_matrix]

    # Reference equilibrium of the full game (consumed by ``exact_br``).
    real_eqs, (ru1, ru2) = oracle(original_matrix, n=2000)

    def _build_params(sizes):
        # ``train`` argument lists for A's and B's tasks, built from the current
        # restricted games (read through the closure at call time).
        tail = [original_matrix, use_approximate, real_eqs, ru1, ru2]
        return [
            [args, (args.dim_a, sizes[1]), matrix_for_a, 0, *tail],
            [args, (sizes[0], args.dim_b), matrix_for_b, 1, *tail],
        ]

    dims = [len(x) for x in support_sets]
    for epoch in range(args.max_support_size):
        start = time.time()
        res = ray.get([train.remote(*params) for params in _build_params(dims)])
        end = time.time()
        Log.debug(
            "epoch {}/{} EXP={} EPS={} sec".format(
                epoch + 1,
                args.max_support_size,
                round(res[0][-1] + res[1][-1], 4),
                round((end - start) * (args.max_support_size - epoch - 1), 2),
            )
        )
        for i, e in enumerate(res):
            learning_curve[i].append(e[-1])
            assert e[0][i] is not None, (i, e)
            support_sets[i].append(e[0][i])
            assert e[1][i] is not None, (i, e)
            new_payoffs = e[1][i]

            # The opponent's payoffs are taken as the negation, which assumes a
            # zero-sum game.
            if i == 0:
                # A's new strategy becomes the LAST ROW of B's restricted game.
                # The earlier rows keep the copied game.
                mat = [np.zeros((dims[0] + 1, args.dim_b)) for _ in range(2)]
                mat[0][: dims[0]] = matrix_for_b[0]
                mat[1][: dims[0]] = matrix_for_b[1]
                mat[0][-1] = new_payoffs
                mat[1][-1] = -new_payoffs
                matrix_for_b = mat
            else:
                # B's new strategy becomes the LAST COLUMN of A's restricted game.
                mat = [np.zeros((args.dim_a, dims[1] + 1)) for _ in range(2)]
                mat[0][:, : dims[1]] = matrix_for_a[0]
                mat[1][:, : dims[1]] = matrix_for_a[1]
                mat[0][:, -1] = -new_payoffs
                mat[1][:, -1] = new_payoffs
                matrix_for_a = mat

        # Support sizes grew by one per player this epoch.
        dims = [len(x) for x in support_sets]

    merged = [p1 + p2 for p1, p2 in zip(*learning_curve)]
    learning_curve.append(merged)
    return smooth(learning_curve, smooth_rate)


def multi_learn(args):
    """Run ``args.n_group`` independent ``learn`` experiments and aggregate them.

    Runs are dispatched through a Ray ``ActorPool`` with 5 slots. Each run uses
    ``smooth_rate=1.0`` and a game seed drawn with ``np.random.randint(1, 100)``.

    Returns:
        Nested dict with per-epoch statistics of the curves returned by ``learn``.
        - Under ``"player1"`` / ``"player2"`` -> ``"utility"``: the elementwise
          mean, min and max across runs.
        - Under ``"NashConv"`` (the summed curve): the mean, with
          ``mean - std`` / ``mean + std`` reported as ``"min"`` / ``"max"``.
    """
    pool = ActorPool([learn.remote for _ in range(5)])
    seeds = [np.random.randint(1, 100) for _ in range(args.n_group)]
    runs = list(
        pool.map_unordered(
            lambda remote_fn, seed: remote_fn(args, smooth_rate=1.0, seed=seed),
            seeds,
        )
    )
    player1, player2, merged = (np.array(curves) for curves in zip(*runs))

    def _spread(curves):
        return {
            "mean": np.mean(curves, axis=0),
            "min": np.min(curves, axis=0),
            "max": np.max(curves, axis=0),
        }

    center = np.mean(merged, axis=0)
    deviation = np.std(merged, axis=0)
    return {
        "player1": {"utility": _spread(player1)},
        "player2": {"utility": _spread(player2)},
        "NashConv": {
            "mean": center,
            "min": center - deviation,
            "max": center + deviation,
        },
    }


if __name__ == "__main__":
    # The first positional argument of ArgumentParser is ``prog``; pass the
    # text as ``description`` so it appears in --help.
    parser = argparse.ArgumentParser(
        description="Minimax optimization for simulation-free PSRO."
    )
    parser.add_argument(
        "--dim_a", type=int, help="Original policy space size for player A.", default=3
    )
    parser.add_argument(
        "--dim_b", type=int, help="Original policy space size for player B.", default=3
    )
    parser.add_argument(
        "--game_type",
        type=int,
        help="Game type, 0 for RANDOM_SYMMETRIC, 1 for RANDOM, 2 for CYCLIC",
        default=2,
    )
    parser.add_argument(
        "--symmetric_init",
        action="store_true",
        help="Init constraint policy space in symmetric mode or not.",
    )
    parser.add_argument("--lr", default=1.0, type=float)
    parser.add_argument("--n_episode", type=int, help="Training epochs.", default=1000)
    parser.add_argument("--n_group", default=5, type=int)
    parser.add_argument(
        "--max_support_size",
        type=int,
        help="Determine how big the final support is.",
        default=10,
    )
    parser.add_argument("--seed", type=int, default=100)
    parser.add_argument("--payoff_config", type=str, default=None)
    parser.add_argument("--local_mode", action="store_true")
    # ``learn`` reads ``args.debug`` and ``args.use_approximate``. Without these
    # flags it raises AttributeError.
    parser.add_argument(
        "--debug", action="store_true", help="Enable DEBUG-level logging."
    )
    parser.add_argument(
        "--exact_br",
        dest="use_approximate",
        action="store_false",
        help="Use exact (fictitious-play) best responses instead of the "
        "default gradient-based approximation.",
    )

    args = parser.parse_args()
    # NOTE(review): this overrides --symmetric_init, so the flag currently has
    # no effect. Symmetric initialisation also requires --dim_a == --dim_b.
    args.symmetric_init = True

    # Seed the driver RNG so the per-group game seeds drawn in ``multi_learn``
    # are reproducible across runs with the same --seed.
    np.random.seed(args.seed)

    if not ray.is_initialized():
        ray.init(local_mode=args.local_mode)

    log_path = os.path.join(LOG_DIR, "imgs")
    if not os.path.exists(log_path):
        os.makedirs(log_path)
    res = multi_learn(args)
    pprint.pprint(res)
