from mb_jpsro.utils import *
from mb_jpsro.utils import _eliminate_dominated_decorator, _expand_meta_game, _unexpand_meta_dist, _ace_constraints, \
    _partition_by_player, _try_two_solvers, _qp_ce, _linear, _qp_cce, _cce_constraints
from open_spiel.python.algorithms import projected_replicator_dynamics
from open_spiel.python.egt import alpharank as alpharank_lib
# Probability mass below this threshold is zeroed (then the distribution is
# renormalized) by the PRD and AlphaRank meta-solvers.
DIST_TOL = 1.0e-8


# Meta-solvers - Baselines.
def _uni(meta_game, per_player_repeats, ignore_repeats=False):
    """Uniform meta-solver.

    Args:
      meta_game: Meta-game payoff array; ``meta_game.shape[1:]`` is used as the
        shape of the joint policy space. Only read when ``ignore_repeats`` is
        True.
      per_player_repeats: One array of policy repeat counts per player. Only
        read when ``ignore_repeats`` is False.
      ignore_repeats: If True, every joint policy gets equal mass. Otherwise
        each player's marginal is proportional to its repeat counts and the
        joint is the outer product of those marginals.

    Returns:
      ``(meta_dist, info)`` where ``meta_dist`` is the joint distribution and
      ``info`` is an empty dict.
    """
    if not ignore_repeats:
        marginals = [repeats / np.sum(repeats) for repeats in per_player_repeats]
        axes = string.ascii_lowercase[:len(marginals)]
        # e.g. "a,b->ab": outer product of the per-player marginals.
        subscripts = f"{','.join(axes)}->{axes}"
        return np.einsum(subscripts, *marginals), dict()
    joint_shape = meta_game.shape[1:]
    return np.full(joint_shape, 1. / np.prod(joint_shape)), dict()


@_eliminate_dominated_decorator
def _undominated_uni(meta_game, per_player_repeats, ignore_repeats=False):
    """Uniform meta-solver wrapped by ``_eliminate_dominated_decorator``.

    Delegates directly to ``_uni``. The decorator (from mb_jpsro.utils)
    presumably prunes dominated policies before this runs — confirm there.
    """
    result = _uni(meta_game, per_player_repeats, ignore_repeats)
    return result


def _rj(meta_game, per_player_repeats, ignore_repeats=False):
    """Random joint meta-solver.

    Draws one joint policy uniformly at random from the joint policy space
    ``meta_game.shape[1:]`` and returns it as a one-hot float64 distribution.

    The ``ignore_repeats`` argument is not honoured. Repeats are always
    ignored, so every distinct joint policy is equally likely.

    Returns:
      ``(meta_dist, info)`` with ``info`` an empty dict.
    """
    del ignore_repeats  # Always sample over distinct joint policies.
    uniform, _ = _uni(meta_game, per_player_repeats, ignore_repeats=True)
    # A single multinomial draw yields a one-hot count vector.
    one_hot = np.random.multinomial(1, uniform.flat)
    return one_hot.reshape(uniform.shape).astype(np.float64), dict()


@_eliminate_dominated_decorator
def _undominated_rj(meta_game, per_player_repeats, ignore_repeats=False):
    """Random-joint meta-solver wrapped by ``_eliminate_dominated_decorator``.

    Delegates directly to ``_rj``. The decorator (from mb_jpsro.utils)
    presumably prunes dominated policies before this runs — confirm there.
    """
    result = _rj(meta_game, per_player_repeats, ignore_repeats)
    return result


def _rd(meta_game, per_player_repeats, ignore_repeats=False):
    """Random Dirichlet meta-solver.

    Samples a joint distribution from a flat Dirichlet (every concentration
    parameter equal to one) over the joint policy space ``meta_game.shape[1:]``.

    Policy repeats are always ignored, as in ``_rj``: the sample does not
    depend on ``per_player_repeats`` or ``ignore_repeats``. Both are kept only
    so the signature matches the other meta-solvers.

    Args:
      meta_game: Meta-game payoff array; only its shape is used.
      per_player_repeats: Unused.
      ignore_repeats: Unused (treated as always True).

    Returns:
      ``(meta_dist, info)`` where ``meta_dist`` is a float64 array of shape
      ``meta_game.shape[1:]`` summing to one, and ``info`` is an empty dict.
    """
    # The previous code forced ignore_repeats=True and then branched on it,
    # which left a repeat-weighted branch that could never run. That branch
    # was removed; the behavior is unchanged.
    del per_player_repeats, ignore_repeats
    joint_shape = meta_game.shape[1:]
    alpha = np.ones(joint_shape)
    meta_dist = np.reshape(
        np.random.dirichlet(alpha.flat), alpha.shape).astype(np.float64)
    return meta_dist, dict()


@_eliminate_dominated_decorator
def _undominated_rd(meta_game, per_player_repeats, ignore_repeats=False):
    """Random-Dirichlet meta-solver wrapped by ``_eliminate_dominated_decorator``.

    Delegates directly to ``_rd``. The decorator (from mb_jpsro.utils)
    presumably prunes dominated policies before this runs — confirm there.
    """
    result = _rd(meta_game, per_player_repeats, ignore_repeats)
    return result


def _prd(meta_game, per_player_repeats, ignore_repeats=False):
    """Projected replicator dynamics (PRD) meta-solver.

    Runs OpenSpiel's projected replicator dynamics on the meta-game. The joint
    distribution is the outer product of the per-player strategies it returns.
    Entries below ``DIST_TOL`` are then zeroed and the result is renormalized.

    Args:
      meta_game: Meta-game payoff array passed to PRD.
      per_player_repeats: Per-player policy repeat counts. They are used to
        expand the game before solving and to collapse the result afterwards.
      ignore_repeats: If True, solve ``meta_game`` as given and skip both the
        expansion and the collapse.

    Returns:
      ``(meta_dist, info)`` where ``meta_dist`` is the joint distribution and
      ``info`` is an empty dict.
    """
    if not ignore_repeats:
        meta_game = _expand_meta_game(meta_game, per_player_repeats)
    per_player_dists = (
        projected_replicator_dynamics.projected_replicator_dynamics(meta_game))
    # Outer product of the per-player strategies, e.g. "a,b->ab".
    labels = string.ascii_lowercase[:len(per_player_dists)]
    comma_labels = ",".join(labels)
    meta_dist = np.einsum(
        "{}->{}".format(comma_labels, labels), *per_player_dists)
    meta_dist[meta_dist < DIST_TOL] = 0.0
    meta_dist /= np.sum(meta_dist)
    # Collapse only when the game was expanded above, mirroring `_alpharank`.
    # Unexpanding a distribution over a game that was never expanded is
    # presumably wrong whenever any repeat count exceeds one.
    if not ignore_repeats:
        meta_dist = _unexpand_meta_dist(meta_dist, per_player_repeats)
    return meta_dist, dict()


@_eliminate_dominated_decorator
def _alpharank(meta_game, per_player_repeats, ignore_repeats=False):
    """AlphaRank meta-solver.

    Passes the per-player payoff tables (``meta_game`` iterated over its
    leading axis) to ``alpharank_lib.sweep_pi_vs_epsilon``. The resulting
    distribution is processed as follows:
      * entries below ``DIST_TOL`` are zeroed and the rest renormalized;
      * the result is reshaped to the joint policy shape;
      * the repeat expansion is undone when it was applied.

    Returns:
      ``(meta_dist, info)`` with ``info`` an empty dict.
    """
    expanded = not ignore_repeats
    if expanded:
        meta_game = _expand_meta_game(meta_game, per_player_repeats)
    ranking = alpharank_lib.sweep_pi_vs_epsilon(list(meta_game))
    ranking[ranking < DIST_TOL] = 0.0
    ranking /= np.sum(ranking)
    joint_dist = np.reshape(ranking, meta_game.shape[1:])
    if expanded:
        joint_dist = _unexpand_meta_dist(joint_dist, per_player_repeats)
    return joint_dist, dict()


# Meta-solvers - CEs.
@_eliminate_dominated_decorator
def _mgce(meta_game, per_player_repeats, ignore_repeats=False):
    """Maximum Gini correlated equilibrium (CE) meta-solver.

    Builds zero-slack CE constraints with ``_ace_constraints`` and splits them
    per player. It then solves with ``_qp_ce`` through ``_try_two_solvers``.

    Args:
      meta_game: Meta-game payoff array.
      per_player_repeats: Per-player policy repeat counts. They are forwarded
        as ``action_repeats`` unless ``ignore_repeats`` is True.
      ignore_repeats: If True, the solver gets ``action_repeats=None``.

    Returns:
      ``(dist, info)``: the first value returned by ``_try_two_solvers`` and
      an empty dict.
    """
    num_players = len(per_player_repeats)
    a_mat, e_vec, meta = _ace_constraints(
        meta_game, [0.0] * num_players, remove_null=True,
        zero_tolerance=1e-8)
    # meta["p_vec"] presumably tags each constraint row with its player;
    # confirm in mb_jpsro.utils.
    row_players = meta["p_vec"]
    a_mats = _partition_by_player(a_mat, row_players, num_players)
    e_vecs = _partition_by_player(e_vec, row_players, num_players)
    action_repeats = None if ignore_repeats else per_player_repeats
    dist, _ = _try_two_solvers(
        _qp_ce, meta_game, a_mats, e_vecs, action_repeats=action_repeats)
    return dist, dict()


@_eliminate_dominated_decorator
def _min_epsilon_mgce(meta_game, per_player_repeats, ignore_repeats=False):
    """Min-epsilon Maximum Gini correlated equilibrium (CE) meta-solver.

    Identical to ``_mgce`` except that ``min_epsilon=True`` is forwarded to
    the ``_qp_ce`` solve through ``_try_two_solvers``.

    Args:
      meta_game: Meta-game payoff array.
      per_player_repeats: Per-player policy repeat counts. They are forwarded
        as ``action_repeats`` unless ``ignore_repeats`` is True.
      ignore_repeats: If True, the solver gets ``action_repeats=None``.

    Returns:
      ``(dist, info)``: the first value returned by ``_try_two_solvers`` and
      an empty dict.
    """
    num_players = len(per_player_repeats)
    a_mat, e_vec, meta = _ace_constraints(
        meta_game, [0.0] * num_players, remove_null=True,
        zero_tolerance=1e-8)
    row_players = meta["p_vec"]
    a_mats = _partition_by_player(a_mat, row_players, num_players)
    e_vecs = _partition_by_player(e_vec, row_players, num_players)
    action_repeats = None if ignore_repeats else per_player_repeats
    dist, _ = _try_two_solvers(
        _qp_ce, meta_game, a_mats, e_vecs,
        action_repeats=action_repeats, min_epsilon=True)
    return dist, dict()


@_eliminate_dominated_decorator
def _approx_mgce(meta_game, per_player_repeats, ignore_repeats=False,
                 epsilon=0.01):
    """Approximate Maximum Gini correlated equilibrium (CE) meta-solver.

    Like ``_mgce``, but every player's constraints get a slack of
    ``epsilon * max_ab``. Here ``max_ab`` is the largest row mean of the
    zero-slack CE constraint matrix, or 0.0 when that matrix is empty.

    Args:
      meta_game: Meta-game payoff array.
      per_player_repeats: Per-player policy repeat counts. They are forwarded
        as ``action_repeats`` unless ``ignore_repeats`` is True.
      ignore_repeats: If True, the solver gets ``action_repeats=None``.
      epsilon: Relative slack, scaled by ``max_ab`` before use.

    Returns:
      ``(dist, info)``: the first value returned by ``_try_two_solvers`` and
      an empty dict.
    """
    num_players = len(per_player_repeats)
    # First pass (zero slack) is only used to measure the constraint scale.
    scale_mat, _, _ = _ace_constraints(
        meta_game, [0.0] * num_players, remove_null=True,
        zero_tolerance=1e-8)
    max_ab = np.max(scale_mat.mean(axis=1)) if scale_mat.size else 0.0
    # Second pass: constraints relaxed by the scaled slack.
    a_mat, e_vec, meta = _ace_constraints(
        meta_game, [epsilon * max_ab] * num_players, remove_null=True,
        zero_tolerance=1e-8)
    row_players = meta["p_vec"]
    a_mats = _partition_by_player(a_mat, row_players, num_players)
    e_vecs = _partition_by_player(e_vec, row_players, num_players)
    action_repeats = None if ignore_repeats else per_player_repeats
    dist, _ = _try_two_solvers(
        _qp_ce, meta_game, a_mats, e_vecs, action_repeats=action_repeats)
    return dist, dict()


@_eliminate_dominated_decorator
def _rmwce(meta_game, per_player_repeats, ignore_repeats=False):
    """Random maximum welfare correlated equilibrium (CE) meta-solver.

    The ``cost`` passed to ``_linear`` is ``meta_game`` summed over its
    leading axis (presumably total welfare across players). Gaussian noise
    scaled by 1e-6 is added to it, presumably to break ties between
    equal-welfare optima at random. ``ignore_repeats`` is accepted for
    interface compatibility and is unused.

    Returns:
      ``(dist, info)`` where ``dist`` is the LP solution reshaped to
      ``meta_game.shape[1:]`` and ``info`` is an empty dict.
    """
    del ignore_repeats
    joint_shape = meta_game.shape[1:]
    cost = np.ravel(np.sum(meta_game, axis=0))
    cost += 1e-6 * np.ravel(np.random.normal(size=joint_shape))
    a_mat, e_vec, _ = _ace_constraints(
        meta_game, [0.0] * len(per_player_repeats), remove_null=True,
        zero_tolerance=1e-8)
    solution, _ = _linear(meta_game, a_mat, e_vec, cost=cost)
    return np.reshape(solution, joint_shape), dict()


@_eliminate_dominated_decorator
def _mwce(meta_game, per_player_repeats, ignore_repeats=False):
    """Maximum welfare correlated equilibrium (CE) meta-solver.

    The ``cost`` passed to ``_linear`` is ``meta_game`` summed over its
    leading axis (presumably total welfare across players).
    ``ignore_repeats`` is accepted for interface compatibility and is unused.

    Returns:
      ``(dist, info)`` where ``dist`` is the LP solution reshaped to
      ``meta_game.shape[1:]`` and ``info`` is an empty dict.
    """
    del ignore_repeats
    welfare = np.ravel(np.sum(meta_game, axis=0))
    a_mat, e_vec, _ = _ace_constraints(
        meta_game, [0.0] * len(per_player_repeats), remove_null=True,
        zero_tolerance=1e-8)
    solution, _ = _linear(meta_game, a_mat, e_vec, cost=welfare)
    return np.reshape(solution, meta_game.shape[1:]), dict()


@_eliminate_dominated_decorator
def _rvce(meta_game, per_player_repeats, ignore_repeats=False):
    """Random vertex correlated equilibrium (CE) meta-solver.

    Solves the CE linear program with a standard-normal random ``cost`` over
    the joint policies. The solution is therefore presumably a random vertex
    of the CE polytope. ``ignore_repeats`` is accepted for interface
    compatibility and is unused.

    Returns:
      ``(dist, info)`` where ``dist`` is the LP solution reshaped to
      ``meta_game.shape[1:]`` and ``info`` is an empty dict.
    """
    del ignore_repeats
    joint_shape = meta_game.shape[1:]
    random_direction = np.ravel(np.random.normal(size=joint_shape))
    a_mat, e_vec, _ = _ace_constraints(
        meta_game, [0.0] * len(per_player_repeats), remove_null=True,
        zero_tolerance=1e-8)
    solution, _ = _linear(meta_game, a_mat, e_vec, cost=random_direction)
    return np.reshape(solution, joint_shape), dict()


# Meta-solvers - CCEs.
def _mgcce(meta_game, per_player_repeats, ignore_repeats=False):
    """Maximum Gini coarse correlated equilibrium (CCE) meta-solver.

    Builds zero-slack CCE constraints with ``_cce_constraints`` and splits
    them per player. It then solves with ``_qp_cce`` through
    ``_try_two_solvers``, passing a per-player list of zeros alongside.

    Unlike the CE solvers in this module, this function is not wrapped in
    ``_eliminate_dominated_decorator``. That is presumably deliberate: a CCE
    can put mass on dominated policies.

    Args:
      meta_game: Meta-game payoff array.
      per_player_repeats: Per-player policy repeat counts. They are forwarded
        as ``action_repeats`` unless ``ignore_repeats`` is True.
      ignore_repeats: If True, the solver gets ``action_repeats=None``.

    Returns:
      ``(dist, info)``: the first value returned by ``_try_two_solvers`` and
      an empty dict.
    """
    num_players = len(per_player_repeats)
    a_mat, meta = _cce_constraints(
        meta_game, [0.0] * num_players, remove_null=True,
        zero_tolerance=1e-8)
    a_mats = _partition_by_player(a_mat, meta["p_vec"], num_players)
    action_repeats = None if ignore_repeats else per_player_repeats
    dist, _ = _try_two_solvers(
        _qp_cce, meta_game, a_mats, [0.0] * num_players,
        action_repeats=action_repeats)
    return dist, dict()


def _min_epsilon_mgcce(meta_game, per_player_repeats, ignore_repeats=False):
    """Min-epsilon Maximum Gini coarse correlated equilibrium (CCE) solver.

    Identical to ``_mgcce`` except that ``min_epsilon=True`` is forwarded to
    the ``_qp_cce`` solve through ``_try_two_solvers``.

    Args:
      meta_game: Meta-game payoff array.
      per_player_repeats: Per-player policy repeat counts. They are forwarded
        as ``action_repeats`` unless ``ignore_repeats`` is True.
      ignore_repeats: If True, the solver gets ``action_repeats=None``.

    Returns:
      ``(dist, info)``: the first value returned by ``_try_two_solvers`` and
      an empty dict.
    """
    num_players = len(per_player_repeats)
    a_mat, meta = _cce_constraints(
        meta_game, [0.0] * num_players, remove_null=True,
        zero_tolerance=1e-8)
    a_mats = _partition_by_player(a_mat, meta["p_vec"], num_players)
    action_repeats = None if ignore_repeats else per_player_repeats
    dist, _ = _try_two_solvers(
        _qp_cce, meta_game, a_mats, [0.0] * num_players,
        action_repeats=action_repeats, min_epsilon=True)
    return dist, dict()


def _approx_mgcce(meta_game, per_player_repeats, ignore_repeats=False,
                  epsilon=0.01):
    """Approximate Maximum Gini coarse correlated equilibrium (CCE) solver.

    Like ``_mgcce``, but every player's CCE constraints get a slack of
    ``epsilon * max_ab``. Here ``max_ab`` is the largest row mean of the
    zero-slack CCE constraint matrix, or 0.0 when that matrix is empty.

    Args:
      meta_game: Meta-game payoff array.
      per_player_repeats: Per-player policy repeat counts. They are forwarded
        as ``action_repeats`` unless ``ignore_repeats`` is True.
      ignore_repeats: If True, the solver gets ``action_repeats=None``.
      epsilon: Relative slack, scaled by ``max_ab`` before use.

    Returns:
      ``(dist, info)``: the first value returned by ``_try_two_solvers`` and
      an empty dict.
    """
    num_players = len(per_player_repeats)
    # First pass (zero slack) is only used to measure the constraint scale.
    scale_mat, _ = _cce_constraints(
        meta_game, [0.0] * num_players, remove_null=True,
        zero_tolerance=1e-8)
    max_ab = 0.0
    if scale_mat.size:
        max_ab = np.max(scale_mat.mean(axis=1))
    # Second pass: constraints relaxed by the scaled slack.
    a_mat, meta = _cce_constraints(
        meta_game, [epsilon * max_ab] * num_players, remove_null=True,
        zero_tolerance=1e-8)
    a_mats = _partition_by_player(a_mat, meta["p_vec"], num_players)
    dist, _ = _try_two_solvers(
        _qp_cce,
        meta_game, a_mats, [0.0] * num_players,
        action_repeats=(None if ignore_repeats else per_player_repeats))
    return dist, dict()


def _rmwcce(meta_game, per_player_repeats, ignore_repeats=False):
    """Random maximum welfare coarse correlated equilibrium (CCE) solver.

    The ``cost`` passed to ``_linear`` is ``meta_game`` summed over its
    leading axis (presumably total welfare across players). Gaussian noise
    scaled by 1e-6 is added to it, presumably to break ties at random.
    ``ignore_repeats`` is accepted for interface compatibility and is unused.

    Returns:
      ``(dist, info)`` where ``dist`` is the LP solution reshaped to
      ``meta_game.shape[1:]`` and ``info`` is an empty dict.
    """
    del ignore_repeats
    joint_shape = meta_game.shape[1:]
    cost = np.ravel(np.sum(meta_game, axis=0))
    cost += 1e-6 * np.ravel(np.random.normal(size=joint_shape))
    a_mat, _ = _cce_constraints(
        meta_game, [0.0] * len(per_player_repeats), remove_null=True,
        zero_tolerance=1e-8)
    # _cce_constraints returns no e_vec, so give _linear one zero per row.
    zero_e_vec = np.zeros(a_mat.shape[0])
    solution, _ = _linear(meta_game, a_mat, zero_e_vec, cost=cost)
    return np.reshape(solution, joint_shape), dict()


def _mwcce(meta_game, per_player_repeats, ignore_repeats=False):
    """Maximum welfare coarse correlated equilibrium (CCE) meta-solver.

    The ``cost`` passed to ``_linear`` is ``meta_game`` summed over its
    leading axis (presumably total welfare across players).
    ``ignore_repeats`` is accepted for interface compatibility and is unused.

    Returns:
      ``(dist, info)`` where ``dist`` is the LP solution reshaped to
      ``meta_game.shape[1:]`` and ``info`` is an empty dict.
    """
    del ignore_repeats
    welfare = np.ravel(np.sum(meta_game, axis=0))
    a_mat, _ = _cce_constraints(
        meta_game, [0.0] * len(per_player_repeats), remove_null=True,
        zero_tolerance=1e-8)
    # _cce_constraints returns no e_vec, so give _linear one zero per row.
    zero_e_vec = np.zeros(a_mat.shape[0])
    solution, _ = _linear(meta_game, a_mat, zero_e_vec, cost=welfare)
    return np.reshape(solution, meta_game.shape[1:]), dict()


def _rvcce(meta_game, per_player_repeats, ignore_repeats=False):
    """Random vertex coarse correlated equilibrium (CCE) meta-solver.

    Solves the CCE linear program with a standard-normal random ``cost`` over
    the joint policies. The solution is therefore presumably a random vertex
    of the CCE polytope. ``ignore_repeats`` is accepted for interface
    compatibility and is unused.

    Returns:
      ``(dist, info)`` where ``dist`` is the LP solution reshaped to
      ``meta_game.shape[1:]`` and ``info`` is an empty dict.
    """
    del ignore_repeats
    joint_shape = meta_game.shape[1:]
    random_direction = np.ravel(np.random.normal(size=joint_shape))
    a_mat, _ = _cce_constraints(
        meta_game, [0.0] * len(per_player_repeats), remove_null=True,
        zero_tolerance=1e-8)
    # _cce_constraints returns no e_vec, so give _linear one zero per row.
    zero_e_vec = np.zeros(a_mat.shape[0])
    solution, _ = _linear(meta_game, a_mat, zero_e_vec, cost=random_direction)
    return np.reshape(solution, joint_shape), dict()


# Meta-solver flag name -> implementing function.
_FLAG_TO_FUNC = {
    # Baselines.
    "uni": _uni,
    "undominated_uni": _undominated_uni,
    "rj": _rj,
    "undominated_rj": _undominated_rj,
    "rd": _rd,
    "undominated_rd": _undominated_rd,
    "prd": _prd,
    "alpharank": _alpharank,
    # Correlated equilibria.
    "mgce": _mgce,
    "min_epsilon_mgce": _min_epsilon_mgce,
    "approx_mgce": _approx_mgce,
    "rmwce": _rmwce,
    "mwce": _mwce,
    "rvce": _rvce,
    # Coarse correlated equilibria.
    "mgcce": _mgcce,
    "min_epsilon_mgcce": _min_epsilon_mgcce,
    "approx_mgcce": _approx_mgcce,
    "rmwcce": _rmwcce,
    "mwcce": _mwcce,
    "rvcce": _rvcce,
}

INIT_POLICIES = (
    "uniform",  # Unopinionated but slower to evaluate.
    "random_deterministic",  # Faster to evaluate but requires samples.
)
UPDATE_PLAYERS_STRATEGY = (
    "all",
    "cycle",
    "random",
)
# NOTE(review): only "ce" best responses are enabled; "cce" is switched off.
BRS = (
    # "cce",  # Currently disabled.
    "ce",
)
BR_SELECTIONS = (
    "all",  # All policies.
    "all_novel",  # All novel policies.
    "random",  # Random.
    "random_novel",  # Random novel BR (one that has not been considered before).
    "largest_gap",  # The BR with the largest gap.
)
# NOTE(review): only "rvcce" is enabled here, although _FLAG_TO_FUNC
# registers many more meta-solver names.
META_SOLVERS = (
    "rvcce",
)