import argparse
import ray
import torch
import numpy as np
import nashpy as nash

from collections import defaultdict, namedtuple
from pprint import pprint
from ray.util import ActorPool
from torch.distributions import Categorical

from expground.types import List
from expground.logger import Log
from expground.envs.matrix.env import payoff_matrix_generate


# Result record built by nash_conv(): the summed value first, followed by the
# row-player term and the column-player term that were added together.
NashConv = namedtuple("NashConv", "nash_conv, player1_improvement, player2_improvement")


def load_game(f_path: str = None, game_type: int = None, rng=None, dim: int = None):
    """Load or generate the payoff matrices of a two-player game.

    Exactly one source is used; ``f_path`` takes precedence over ``game_type``.

    Args:
        f_path: Path to a pickle file holding the first player's payoff
            matrix. The second player's matrix is its negation.
        game_type: Game type identifier forwarded to
            ``payoff_matrix_generate``.
        rng: Random state forwarded to ``payoff_matrix_generate``.
        dim: Dimension specification forwarded to ``payoff_matrix_generate``.

    Returns:
        ``[data, -data]`` when loading from file, otherwise whatever
        ``payoff_matrix_generate(2, game_type, dim, rng)`` returns.

    Raises:
        ValueError: If neither ``f_path`` nor ``game_type`` is provided.
    """
    if f_path is not None:
        import pickle

        # NOTE(security): unpickling can execute arbitrary code; only point
        # this at payoff files that come from a trusted source.
        with open(f_path, "rb") as f:
            data = pickle.load(f)
        return [data, -data]

    if game_type is not None:
        return payoff_matrix_generate(2, game_type, dim, rng)

    # Fail loudly here instead of handing `None` to the caller, which would
    # only surface later as an unrelated-looking TypeError.
    raise ValueError("load_game requires either `f_path` or `game_type`.")


def oracle(matrix, n=10000):
    """Approximate an equilibrium of a bimatrix game with fictitious play.

    Args:
        matrix: Pair of payoff matrices, unpacked into ``nash.Game``.
        n: Accepted for interface compatibility but not used; the number of
            fictitious-play iterations is fixed at 2000.

    Returns:
        ``(eq, (x, y))`` where ``eq`` is a tuple with one normalised vector
        per player (the last item produced by fictitious play, each entry
        divided by its sum) and ``x`` / ``y`` are the payoffs of
        ``matrix[0]`` / ``matrix[1]`` under those two vectors.
    """
    game = nash.Game(*matrix)

    # Only the final item of the fictitious-play sequence is kept.
    *_, final_counts = game.fictitious_play(iterations=2000)
    eq = tuple(counts / np.sum(counts) for counts in final_counts)

    row_mix = eq[0].reshape(1, -1)
    col_mix = eq[1]
    x = row_mix.dot(matrix[0]).dot(col_mix).reshape(-1)[0]
    y = row_mix.dot(matrix[1]).dot(col_mix).reshape(-1)[0]
    return eq, (x, y)


def best_response(
    args,
    dims: List[int],
    meta_strategies: List[np.ndarray],
    br_player: int,
    original_matrix: List[np.ndarray],
):
    """Return the exact pure best response of one player.

    Args:
        args: Unused; kept so the signature matches
            ``approximate_best_response``.
        dims: Unused; kept for the same reason.
        meta_strategies: Mixed strategies of both players; only the
            opponent's entry (index ``1 - br_player``) is read.
        br_player: Responding player; ``0`` indexes the first axis of the
            payoff matrix, any other value the second axis.
        original_matrix: Payoff matrices; entry ``br_player`` is used.

    Returns:
        Index of the action with the highest expected payoff against the
        opponent's mixed strategy (first one on ties, as ``np.argmax``).
    """
    payoff = original_matrix[br_player]
    rival = meta_strategies[1 - br_player]

    if br_player == 0:
        # Rows are this player's actions: payoff @ column mixture.
        expected = payoff.dot(rival.reshape(-1, 1))
    else:
        # Columns are this player's actions: row mixture @ payoff.
        expected = rival.reshape(1, -1).dot(payoff)

    return np.argmax(expected.reshape(-1))


def approximate_best_response(
    args,
    dims: List[int],
    meta_strategies: List[np.ndarray],
    br_player: int,
    original_matrix: List[np.ndarray],
):
    """Approximate a pure best response with a sampled policy gradient.

    A softmax policy over the responding player's actions is trained with a
    REINFORCE-style loss: each step samples a batch of own actions and
    opponent actions, looks up the payoff of every sampled pair, and weights
    the log-probability of the own action by that payoff.

    Args:
        args: Namespace providing ``lr`` (SGD learning rate) and
            ``n_episode`` (number of gradient steps).
        dims: Action-space sizes per player; ``dims[br_player]`` sets the
            number of logits.
        meta_strategies: Mixed strategies of both players; only the
            opponent's entry (index ``1 - br_player``) is sampled from.
        br_player: Responding player; ``0`` indexes the first axis of the
            payoff matrix, any other value the second axis.
        original_matrix: Payoff matrices; entry ``br_player`` is used.

    Returns:
        The action with the largest logit after training, as a plain
        ``int`` so it hashes and compares like the index returned by
        ``best_response``.
    """
    batch_size = 32
    payoff = torch.as_tensor(
        np.asarray(original_matrix[br_player]), dtype=torch.float64
    )
    opponent_strategy = meta_strategies[1 - br_player]

    # Start from a random policy.
    logits = torch.tensor(np.random.random(dims[br_player]), requires_grad=True)
    optimizer = torch.optim.SGD([logits], lr=args.lr)
    opp_dist = Categorical(probs=torch.from_numpy(opponent_strategy))
    sample_shape = torch.Size([batch_size])

    for _ in range(args.n_episode):
        dist = Categorical(logits=logits)
        idx = dist.sample(sample_shape=sample_shape).reshape(-1)
        opp_idx = opp_dist.sample(sample_shape=sample_shape).reshape(-1)

        # The reward must be the payoff of the action actually sampled.
        # Summing over all own actions makes it independent of `idx`, which
        # leaves the policy gradient with no learning signal.
        if br_player == 0:
            reward = payoff[idx, opp_idx]
        else:
            reward = payoff[opp_idx, idx]

        loss = -torch.mean(dist.log_prob(idx) * reward)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Read the final logits (softmax keeps their ordering). This includes the
    # last optimiser step and is well-defined even when n_episode == 0.
    return int(torch.argmax(logits))


def nash_conv(urr_matrix, eqs) -> NashConv:
    """Sum both players' best-response values against the given mixtures.

    Args:
        urr_matrix: Pair of matrices; ``urr_matrix[0]`` is multiplied on the
            right by ``eqs[1]`` and ``urr_matrix[1]`` on the left by
            ``eqs[0]``.
        eqs: Pair of mixed-strategy vectors, one per player.

    Returns:
        ``NashConv(total, row_term, col_term)`` where each term is the
        maximum entry of the corresponding product and ``total`` is their
        sum.

    Note:
        The payoff each player already obtains under ``eqs`` is NOT
        subtracted, so the per-player fields hold best-response values
        rather than gains.
        NOTE(review): the total presumably matches the usual NashConv only
        when the two players' equilibrium payoffs cancel (zero-sum) --
        confirm against the games actually used.
    """
    row_values = urr_matrix[0].dot(eqs[1])
    col_values = eqs[0].dot(urr_matrix[1])
    row_best = np.max(row_values)
    col_best = np.max(col_values)
    return NashConv(
        nash_conv=row_best + col_best,
        player1_improvement=row_best,
        player2_improvement=col_best,
    )


@ray.remote
def learn(args, seed):
    """Run one support-expansion experiment and record the NashConv trace.

    Starting from one randomly chosen action per player, each iteration adds
    every player's best response to the current meta-strategies to that
    player's support, solves the game restricted to the supports with
    ``oracle``, evaluates ``nash_conv`` against the full action spaces, and
    lifts the restricted solution back to full-size meta-strategies.

    Args:
        args: Namespace providing ``payoff_config``, ``game_type``,
            ``dim_a``, ``dim_b`` and ``max_support_size``. When the
            approximate handler is used it also needs ``lr`` and
            ``n_episode``. ``use_approximate`` is optional and defaults to
            the exact best response.
        seed: Seed of the ``RandomState`` handed to ``load_game``.

    Returns:
        A one-element list holding the list of ``nash_conv`` totals, one per
        iteration.
    """
    original_matrix = load_game(
        args.payoff_config,
        args.game_type,
        rng=np.random.RandomState(seed),
        dim=[args.dim_a, args.dim_b],
    )
    dims = [args.dim_a, args.dim_b]

    # 0: row player, 1: column player.
    meta_strategies = [np.zeros(dim, dtype=np.float32) for dim in dims]
    supports = [set(), set()]

    # Seed each support with one uniformly chosen action.
    # NOTE(review): this draws from the global NumPy RNG, not from `seed`.
    for player, strategy in enumerate(meta_strategies):
        idx = int(np.random.choice(dims[player]))
        supports[player].add(idx)
        strategy[idx] = 1.0
    Log.debug("Initialize supports: {}".format(supports))

    report = []

    # Tolerate namespaces that do not define the flag at all.
    use_approximate = getattr(args, "use_approximate", False)
    handler = approximate_best_response if use_approximate else best_response

    for i in range(args.max_support_size):
        brs = []
        for player in range(2):
            # Normalise to a plain int: tensor-valued indices hash by
            # identity and would pile up as duplicates in the support set.
            br = int(handler(args, dims, meta_strategies, player, original_matrix))
            supports[player].add(br)
            brs.append(br)
        Log.debug(
            "step [{}/{}] find best responses: {}".format(
                i + 1, args.max_support_size, brs
            )
        )

        # =========== compute NE ====================
        ordered_support = [sorted(support) for support in supports]
        n_rows, n_cols = len(ordered_support[0]), len(ordered_support[1])

        # One-hot selection matrices: `r` picks the supported rows and `c`
        # the supported columns of a full payoff matrix.
        r = np.zeros((n_rows, args.dim_a))
        c = np.zeros((args.dim_b, n_cols))
        r[range(n_rows), np.asarray(ordered_support[0], dtype=int)] = 1.0
        c[np.asarray(ordered_support[1], dtype=int), range(n_cols)] = 1.0

        rmatrix = [r.dot(e).dot(c) for e in original_matrix]
        Log.debug("\t* current restricted game:")
        eq, (u1, u2) = oracle(rmatrix, n=2000)

        # =============== compute nash conv ==================
        # Row player: all rows vs. supported columns.
        # Column player: supported rows vs. all columns.
        urr_matrix = [original_matrix[0].dot(c), r.dot(original_matrix[1])]
        v = nash_conv(urr_matrix, eq)
        report.append(v.nash_conv)
        Log.debug("\t* NashConv: {}".format(v))

        # Lift the restricted equilibrium back to the full action spaces.
        for player in range(2):
            lifted = np.zeros(dims[player], dtype=np.float32)
            for weight, idx in zip(eq[player], ordered_support[player]):
                lifted[idx] = weight
            meta_strategies[player] = lifted

    return [report]


def multi_learn(args):
    """Run ``learn`` for ``args.n_group`` random seeds and aggregate NashConv.

    Args:
        args: Namespace forwarded unchanged to every ``learn`` task;
            ``n_group`` sets how many seeds are drawn.

    Returns:
        ``{"NashConv": {"mean": ..., "min": ..., "max": ...}}`` where the
        values are per-iteration arrays: the mean across runs, and the mean
        minus / plus one standard deviation.
    """
    # NOTE(review): the pool is populated with the `learn.remote` callable
    # itself rather than actor handles; five entries bound how many tasks the
    # pool keeps in flight.
    workers = [learn.remote for _ in range(5)]
    pool = ActorPool(workers)

    seeds = [np.random.randint(1, 100) for _ in range(args.n_group)]
    results = list(pool.map_unordered(lambda a, v: a(args, v), seeds))

    # Each result is a one-element list; transposing gathers the per-run
    # traces into a single tuple, which is stacked into a (runs, steps) array.
    traces = np.array(list(zip(*results))[0])
    mean_nc = np.mean(traces, axis=0)
    spread = np.std(traces, axis=0)

    return {
        "NashConv": {
            "mean": mean_nc,
            "min": mean_nc - spread,
            "max": mean_nc + spread,
        },
    }


if __name__ == "__main__":
    # Command-line entry point: parse the experiment settings, start Ray if
    # needed, run the grouped experiment and pretty-print the aggregate.
    parser = argparse.ArgumentParser("Minimax optimization for simulation-free PSRO.")
    parser.add_argument(
        "--dim_a", type=int, help="Original policy space size for player A.", default=3
    )
    parser.add_argument(
        "--dim_b", type=int, help="Original policy space size for player B.", default=3
    )
    parser.add_argument(
        "--game_type",
        type=int,
        help="Game type, 0 for RANDOM_SYMMETRIC, 1 for RANDOM, 2 for CYCLIC",
        default=2,
    )
    # NOTE(review): --symmetric_init and --seed are parsed but not read by
    # any code visible in this file.
    parser.add_argument(
        "--symmetric_init",
        action="store_true",
        help="Init constraint policy space in symmetric mode or not.",
    )
    parser.add_argument("--n_group", default=5, type=int)
    parser.add_argument(
        "--max_support_size",
        type=int,
        help="Determine how big the final support is.",
        default=10,
    )
    # `learn` reads args.use_approximate; without this flag every run fails
    # with AttributeError as soon as a task starts.
    parser.add_argument(
        "--use_approximate",
        action="store_true",
        help="Use the policy-gradient approximate best response instead of the exact one.",
    )
    parser.add_argument("--lr", type=float, default=1.0)
    parser.add_argument("--n_episode", type=int, default=100)
    parser.add_argument("--seed", type=int, default=100)
    parser.add_argument("--payoff_config", type=str, default=None)
    parser.add_argument("--local_mode", action="store_true")

    args = parser.parse_args()

    if not ray.is_initialized():
        ray.init(local_mode=args.local_mode)

    res = multi_learn(args)
    pprint(res)
