import logging
import itertools
import string

import cvxpy as cp
import numpy as np
import scipy as sp


# Default keyword arguments, one dictionary per solver backend. The solver
# helpers in this file unpack the chosen dictionary into
# ``prob.solve(**solver_kwargs)`` when the caller does not supply its own.
DEFAULT_ECOS_SOLVER_KWARGS = {
    "solver": "ECOS",
    "max_iters": 100000000,
    "abstol": 1e-7,
    "reltol": 1e-7,
    "feastol": 1e-7,
    "abstol_inacc": 1e-7,
    "reltol_inacc": 1e-7,
    "feastol_inacc": 1e-7,
    "verbose": False,
}
DEFAULT_OSQP_SOLVER_KWARGS = {
    "solver": "OSQP",
    "max_iter": 1000000000,
    "eps_abs": 1e-8,
    "eps_rel": 1e-8,
    "eps_prim_inf": 1e-8,
    "eps_dual_inf": 1e-8,
    "polish_refine_iter": 100,
    "check_termination": 1000,
    "sigma": 1e-7,  # Default 1e-6
    "delta": 1e-7,  # Default 1e-06
    "verbose": False,
}
DEFAULT_CVXOPT_SOLVER_KWARGS = {
    "solver": "CVXOPT",
    "maxiters": 200000,
    "abstol": 5e-8,
    "reltol": 5e-8,
    "feastol": 5e-8,
    "refinement": 10,
    "verbose": False,
}

# Helper Functions - Dominated strategy elimination.
def _eliminate_dominated_payoff(
        payoff, epsilon, action_labels=None, action_repeats=None, weakly=False):
    """Iteratively removes dominated actions from a payoff tensor.

    An action of player p is removed when some other action of p gives p a
    higher payoff against every joint action of the other players (higher or
    equal when `weakly` is set). Passes over the players repeat until a full
    pass removes nothing.

    Args:
      payoff: Array indexed as [player, action_0, action_1, ...].
      epsilon: Elimination is skipped entirely when this is greater than zero.
      action_labels: Optional per-player arrays of labels for the actions
        currently in `payoff`. Defaults to `arange` over each action axis. The
        caller's list is not modified.
      action_repeats: Optional per-player arrays of repeat counts, filtered
        alongside the actions. The caller's list is not modified.
      weakly: Whether to eliminate weakly (rather than strictly) dominated
        actions.

    Returns:
      payoff: The payoff tensor restricted to the surviving actions.
      action_labels: Per-player labels of the surviving actions.
      action_repeats: Per-player repeats of the surviving actions, or None if
        none were given.
    """
    num_players = payoff.shape[0]
    if action_labels is None:
        action_labels = [np.arange(na, dtype=np.int32) for na in payoff.shape[1:]]
    else:
        # Copy the outer list: entries are rebound below and the caller's list
        # must stay intact (mirrors the handling of action_repeats).
        action_labels = list(action_labels)
    if action_repeats is not None:
        action_repeats = list(action_repeats)

    # The epsilon test does not depend on the player, so decide once up front.
    if epsilon > 0.0:
        return payoff, action_labels, action_repeats

    is_worse = np.less_equal if weakly else np.less
    eliminated = True
    while eliminated:
        eliminated = False
        for p in range(num_players):
            # Re-read the shape: earlier players may have shrunk the tensor.
            num_actions = payoff.shape[1:]
            if num_actions[p] <= 1:
                continue
            other_axes = tuple(range(p)) + tuple(range(p + 1, num_players))
            for a in range(num_actions[p]):
                index = [slice(None)] * num_players
                index[p] = slice(a, a + 1)  # Keep the axis so it broadcasts.
                diff = is_worse(payoff[p], payoff[p][tuple(index)])
                less = np.all(diff, axis=other_axes)
                less[a] = False  # Action cannot eliminate itself.
                if np.any(less):
                    dominated = np.flatnonzero(less)
                    payoff = np.delete(payoff, dominated, axis=p + 1)
                    action_labels[p] = np.delete(action_labels[p], dominated)
                    if action_repeats is not None:
                        action_repeats[p] = np.delete(
                            action_repeats[p], dominated)
                    eliminated = True
                    # Action indices shifted; move on to the next player.
                    break
    return payoff, action_labels, action_repeats


def _reconstruct_dist(eliminated_dist, action_labels, num_actions):
    """Scatters a reduced distribution back into the full action space.

    Entries belonging to actions that are absent from `action_labels` are left
    at zero.

    Args:
      eliminated_dist: Array of shape [A0E, A1E, ...].
      action_labels: List of length N and shapes [[A0E], [A1E], ...] giving,
        per player, the full-space index of each retained action.
      num_actions: List of length N and values [A0, A1, ...].

    Returns:
      reconstructed_dist: Array of shape [A0, A1, ...].
    """
    full_dist = np.zeros(num_actions)
    # Open mesh addressing the cross product of the retained actions.
    retained = np.ix_(*action_labels)
    full_dist[retained] = eliminated_dist
    return full_dist


def _eliminate_dominated_decorator(func):
    """Wraps a solver so it runs on the game with dominated actions removed.

    The wrapped solver is called as
    ``func(payoff, per_player_repeats, *args, **kwargs)`` and must return a
    ``(dist, meta)`` pair. The wrapper adds a keyword-only
    `eliminate_dominated` flag (default True). When the flag is set, the
    solver sees the reduced payoff and repeats, the reduced solution is stored
    in `meta`, and the returned distribution is padded back to the original
    action space with zeros.

    Args:
      func: Solver with signature ``(payoff, per_player_repeats, ...)``.

    Returns:
      The wrapping function.
    """
    import functools  # Local import: only needed to preserve func's metadata.

    @functools.wraps(func)
    def wrapper(payoff, per_player_repeats, *args, eliminate_dominated=True,
                **kwargs):
        if not eliminate_dominated:
            # Forward per_player_repeats too: func expects it as its second
            # positional argument.
            return func(payoff, per_player_repeats, *args, **kwargs)
        # kwargs is a dict, so the value must be looked up by key. Only an
        # epsilon passed by keyword is visible here.
        epsilon = kwargs.get("epsilon", 0.0)
        num_actions = payoff.shape[1:]
        eliminated_payoff, action_labels, eliminated_action_repeats = (
            _eliminate_dominated_payoff(
                payoff, epsilon, action_repeats=per_player_repeats))
        eliminated_dist, meta = func(
            eliminated_payoff, eliminated_action_repeats, *args, **kwargs)
        meta["eliminated_dominated_dist"] = eliminated_dist
        meta["eliminated_dominated_payoff"] = eliminated_payoff
        dist = _reconstruct_dist(
            eliminated_dist, action_labels, num_actions)
        return dist, meta

    return wrapper


# Optimization.
def _try_two_solvers(func, *args, **kwargs):
    """Calls `func` with CVXOPT settings, retrying with OSQP settings on error.

    The default settings are supplied as the `solver_kwargs` keyword. A
    `solver_kwargs` entry already present in `kwargs` takes precedence, in
    which case both attempts run with the caller's settings.

    Args:
      func: Callable accepting a `solver_kwargs` keyword argument.
      *args: Positional arguments forwarded to `func`.
      **kwargs: Keyword arguments forwarded to `func`.

    Returns:
      Whatever `func` returns on the first attempt that does not raise.

    Raises:
      Exception: Whatever the second attempt raises, if it also fails.
    """
    try:
        # stdlib logging accepts no `flush` keyword; passing one raises a
        # TypeError as soon as DEBUG logging is enabled.
        logging.debug("Trying CVXOPT.")
        kwargs_ = {"solver_kwargs": DEFAULT_CVXOPT_SOLVER_KWARGS, **kwargs}
        res = func(*args, **kwargs_)
    except Exception:  # pylint: disable=broad-except
        # Catch Exception rather than everything, so KeyboardInterrupt and
        # SystemExit still propagate instead of triggering a second solve.
        logging.debug("CVXOPT failed. Trying OSQP.", exc_info=True)
        kwargs_ = {"solver_kwargs": DEFAULT_OSQP_SOLVER_KWARGS, **kwargs}
        res = func(*args, **kwargs_)
    return res


# Helper Functions - CCEs.
def _indices(p, a, num_players):
    """Returns a per-axis index list selecting action `a` on axis `p`.

    Every other axis gets a full slice. Convert the list to a tuple before
    indexing an array with it.
    """
    everything = slice(None)
    return [everything if axis != p else a for axis in range(num_players)]


def _sparse_indices_generator(player, action, num_actions):
    """Iterates over all joint actions in which `player` plays `action`.

    Each yielded item is a tuple with one action index per player. The tuples
    come in `itertools.product` order, with the other players ranging over all
    of their actions.
    """
    choices = []
    for p, na in enumerate(num_actions):
        if p == player:
            choices.append((action,))  # This player's action is pinned.
        else:
            choices.append(range(na))
    return itertools.product(*choices)


def _partition_by_player(val, p_vec, num_players):
    """Splits `val` into one part per player according to `p_vec`.

    Part p holds the entries of `val` at the positions where `p_vec == p`.
    A part is None only when the comparison mask itself has no elements; a
    player that merely matches no position gets an empty selection.
    """

    def part_for(player):
        mask = p_vec == player
        return val[mask] if mask.size > 0 else None

    return [part_for(player) for player in range(num_players)]


def _cce_constraints(payoff, epsilons, remove_null=True, zero_tolerance=1e-8):
    """Returns the coarse correlated constraints.

    One constraint row is built per (player, deviation action) pair. For each
    joint action, the row holds the player's payoff had they played the
    deviation action instead, minus their actual payoff, minus the player's
    epsilon.

    Args:
      payoff: A [NUM_PLAYER, NUM_ACT_0, NUM_ACT_1, ...] shape payoff tensor.
      epsilons: Per player floats corresponding to the epsilon.
      remove_null: Whether to drop rows that are entirely zero and rows whose
        largest entry is negative.
      zero_tolerance: Entries with absolute value below this are set to zero.

    Returns:
      a_mat: The gain matrix for deviating to an action, of shape
        [SUM(A), PROD(A)] before row removal.
      meta: Dictionary with, for each retained row, the player (`p_vec`) and
        the deviation action (`i_vec`), plus the input `epsilons`.
    """
    num_players = payoff.shape[0]
    num_actions = payoff.shape[1:]
    num_dists = int(np.prod(num_actions))

    cor_cons = int(np.sum(num_actions))

    a_mat = np.zeros([cor_cons] + list(num_actions))
    p_vec = np.zeros([cor_cons], dtype=np.int32)
    i_vec = np.zeros([cor_cons], dtype=np.int32)
    con = 0
    for p in range(num_players):
        player_payoff = payoff[p]
        for a1 in range(num_actions[p]):
            # Payoff obtained by always playing a1, broadcast along the
            # player's own action axis.
            deviation_payoff = np.expand_dims(
                np.take(player_payoff, a1, axis=p), axis=p)
            a_mat[con] = deviation_payoff
            a_mat[con] -= player_payoff
            a_mat[con] -= epsilons[p]

            p_vec[con] = p
            # Record the deviation action that this row describes.
            i_vec[con] = a1

            con += 1

    a_mat = np.reshape(a_mat, [cor_cons, num_dists])
    a_mat[np.abs(a_mat) < zero_tolerance] = 0.0
    if remove_null:
        # Both masks are "keep" masks: True marks a row worth retaining.
        has_nonzero = np.any(a_mat != 0.0, axis=-1)
        # Presumably a row with only negative entries can never be violated
        # by a probability vector, which is why such rows are dropped.
        has_nonnegative = np.max(a_mat, axis=1) >= 0
        keep = has_nonzero & has_nonnegative
        a_mat = a_mat[keep, :].copy()
        p_vec = p_vec[keep].copy()
        i_vec = i_vec[keep].copy()

    meta = dict(
        p_vec=p_vec,
        i_vec=i_vec,
        epsilons=epsilons,
    )

    return a_mat, meta


def _ace_constraints(payoff, epsilons, remove_null=True, zero_tolerance=0.0):
    """Builds the sparse pairwise-deviation constraints Ax - epsilon <= 0.

    For every player and every ordered pair of distinct actions (a0, a1), one
    row holds the payoff change from playing a1 instead of a0. The values sit
    in the columns of the joint actions in which that player plays a0.

    Args:
      payoff: Dense payoff tensor indexed as [player, action_0, action_1, ...].
      epsilons: Per-player epsilon values; `epsilons[p]` is used for every row
        of player p.
      remove_null: Whether to drop rows that are entirely zero and rows whose
        largest entry is below their epsilon.
      zero_tolerance: Gains whose absolute value does not exceed this are not
        stored.

    Returns:
      a_csr: Sparse gain matrix from switching from one action to another.
      e_vec: Epsilon vector, one entry per retained row.
      meta: Dictionary with the player (`p_vec`) and action pair (`i_vec`) of
        each retained row, the input `epsilons`, and the removal counts (None
        when no removal pass ran).
    """
    num_players = payoff.shape[0]
    num_actions = payoff.shape[1:]
    num_dists = int(np.prod(num_actions))

    # One row per ordered pair of distinct actions, per player.
    num_cons = sum(
        num_actions[p] * (num_actions[p] - 1) for p in range(num_players))

    a_dok = sp.sparse.dok_matrix((num_cons, num_dists))
    e_vec = np.zeros([num_cons])
    p_vec = np.zeros([num_cons], dtype=np.int32)
    i_vec = np.zeros([num_cons, 2], dtype=np.int32)

    removal_counts = dict(
        num_null_cons=None,
        num_redundant_cons=None,
        num_removed_cons=None,
    )

    if num_cons > 0:
        row = 0
        for player in range(num_players):
            player_payoff = payoff[player]
            pairs = itertools.permutations(range(num_actions[player]), 2)
            for src, dst in pairs:
                # The two generators advance in lockstep, so each pair of joint
                # actions differs only in this player's own action.
                joint_pairs = zip(
                    _sparse_indices_generator(player, src, num_actions),
                    _sparse_indices_generator(player, dst, num_actions))
                for src_joint, dst_joint in joint_pairs:
                    gain = player_payoff[dst_joint] - player_payoff[src_joint]
                    if abs(gain) > zero_tolerance:
                        column = np.ravel_multi_index(src_joint, num_actions)
                        a_dok[row, column] = gain

                e_vec[row] = epsilons[player]
                p_vec[row] = player
                i_vec[row] = (src, dst)
                row += 1

    a_csr = a_dok.tocsr()

    if num_cons > 0 and remove_null:
        row_max = np.ravel(a_csr.max(axis=1).todense())
        row_min = np.ravel(a_csr.min(axis=1).todense())
        # Both masks are "keep" masks: True marks a row worth retaining.
        has_nonzero = (row_max != 0.0) | (row_min != 0.0)
        reaches_epsilon = row_max >= e_vec
        keep = has_nonzero & reaches_epsilon
        a_csr = a_csr[keep, :]
        e_vec = e_vec[keep].copy()
        p_vec = p_vec[keep].copy()
        i_vec = i_vec[keep].copy()
        removal_counts["num_null_cons"] = np.sum(~has_nonzero)
        removal_counts["num_redundant_cons"] = np.sum(~reaches_epsilon)
        removal_counts["num_removed_cons"] = np.sum(~keep)

    meta = dict(
        p_vec=p_vec,
        i_vec=i_vec,
        epsilons=epsilons,
        num_null_cons=removal_counts["num_null_cons"],
        num_redundant_cons=removal_counts["num_redundant_cons"],
        num_removed_cons=removal_counts["num_removed_cons"],
    )

    return a_csr, e_vec, meta


def _get_repeat_factor(action_repeats):
    """Returns the flattened outer products of the per-player repeat vectors.

    Args:
      action_repeats: One vector of repeat counts per player.

    Returns:
      repeat_factor: Flattened outer product of all the vectors.
      indiv_repeat_factors: One flattened outer product per player, computed
        with that player's own vector replaced by ones.
    """
    num_players = len(action_repeats)
    out_labels = string.ascii_lowercase[:num_players]
    # E.g. "a,b,c->abc": an outer product over one axis per player.
    subscripts = "{}->{}".format(",".join(out_labels), out_labels)

    def flat_outer(vectors):
        return np.ravel(np.einsum(subscripts, *vectors))

    repeat_factor = flat_outer(action_repeats)

    indiv_repeat_factors = []
    for player in range(num_players):
        masked = [
            np.ones_like(ar) if player == p else ar
            for p, ar in enumerate(action_repeats)]
        indiv_repeat_factors.append(flat_outer(masked))
    return repeat_factor, indiv_repeat_factors


# Solvers.
def _linear(
        payoff,
        a_mat,
        e_vec,
        action_repeats=None,
        solver_kwargs=None,
        cost=None):
    """Solves a linear program over joint distributions.

    With at least one constraint row, a non-negative vector x summing to one
    is found subject to `a_mat @ x - e_vec <= 0`. It maximises `cost` weighted
    by x, or the sum of all players' payoffs weighted by x when `cost` is None.
    Without constraint rows no solver runs: the result is uniform, or
    proportional to the repeat factors when `action_repeats` is given.

    Args:
      payoff: A [NUM_PLAYER, NUM_ACT_0, NUM_ACT_1, ...] shape payoff tensor.
      a_mat: Constraint matrix.
      e_vec: Epsilon vector.
      action_repeats: List of action repeat counts; only read when `a_mat` has
        no rows.
      solver_kwargs: Solver kwargs; defaults to DEFAULT_ECOS_SOLVER_KWARGS.
      cost: Optional array whose flattened entries are multiplied elementwise
        with x. Note that the resulting objective is maximised.

    Returns:
      dist: The distribution reshaped to [NUM_ACT_0, NUM_ACT_1, ...].
      meta: Dictionary of solver artefacts (x, a_mat, val, status, ...).
    """
    num_players = payoff.shape[0]
    num_actions = payoff.shape[1:]
    num_dists = int(np.prod(num_actions))

    if solver_kwargs is None:
        solver_kwargs = DEFAULT_ECOS_SOLVER_KWARGS

    unconstrained = not a_mat.shape[0] > 0
    if unconstrained:
        if action_repeats is None:
            x = np.ones([num_dists]) / num_dists
        else:
            repeat_factor, _ = _get_repeat_factor(action_repeats)
            x = repeat_factor / np.sum(repeat_factor)
        val = 0.0  # Fix me.
        dist = np.reshape(x, num_actions)
        status = None
    else:
        # Decision variable: one probability per joint action.
        x = cp.Variable(num_dists, nonneg=True)

        # Slack of every deviation constraint.
        epsilon_dists = cp.matmul(a_mat, x) - e_vec

        dist_eq_con = cp.sum(x) == 1
        cor_lb_con = epsilon_dists <= 0

        if cost is not None:
            reward = cp.sum(cp.multiply(cost.flat, x))
        else:
            player_totals = [
                cp.sum(cp.multiply(payoff[p].flat, x))
                for p in range(num_players)]
            reward = cp.sum(player_totals)

        prob = cp.Problem(cp.Maximize(reward), [dist_eq_con, cor_lb_con])
        prob.solve(**solver_kwargs)

        status = prob.status
        dist = np.reshape(x.value, num_actions)
        val = reward.value

    meta = {
        "x": x,
        "a_mat": a_mat,
        "val": val,
        "status": status,
        "payoff": payoff,
        "consistent": True,
        "unique": False,
    }
    return dist, meta


def _qp_cce(
        payoff,
        a_mats,
        e_vecs,
        assume_full_support=False,
        action_repeats=None,
        solver_kwargs=None,
        min_epsilon=False):
    """Solves a quadratic program for a maximum Gini impurity distribution.

    With at least one non-empty constraint matrix, the objective
    `1 - quad_form(x, eye)` is maximised subject to the distribution summing
    to one and `a_mat @ x <= e_vec` for every non-empty `a_mat`. When
    `action_repeats` is given, x is scaled by the joint repeat factor in the
    constraints and `eye` is the diagonal matrix of that factor. Without
    constraints no solver runs and a uniform (or repeat-proportional)
    distribution is returned.

    Args:
      payoff: A [NUM_PLAYER, NUM_ACT_0, NUM_ACT_1, ...] shape payoff tensor.
      a_mats: Per-player gain matrices, each of shape [NUM_CON, PROD(A)].
      e_vecs: Per-player epsilon vectors; replaced by one shared non-positive
        variable when `min_epsilon` is set.
      assume_full_support: When True, x is not declared non-negative.
      action_repeats: Vector of action repeats for each player.
      solver_kwargs: Additional kwargs for solver; defaults to
        DEFAULT_OSQP_SOLVER_KWARGS.
      min_epsilon: Whether to make epsilon a variable and subtract twice its
        value from the maximised objective.

    Returns:
      dist: The distribution reshaped to [NUM_ACT_0, NUM_ACT_1, ...].
      meta: Dictionary of solver artefacts (x, duals, status, cost, ...).
    """
    num_players = payoff.shape[0]
    num_actions = payoff.shape[1:]
    num_dists = int(np.prod(num_actions))

    if solver_kwargs is None:
        solver_kwargs = DEFAULT_OSQP_SOLVER_KWARGS

    epsilon = None
    has_rows = [a_mat.shape[0] > 0 for a_mat in a_mats if a_mat is not None]
    if not any(has_rows):
        # No constraints: return the unconstrained answer directly.
        cost_value = 0.0
        val = 1 - 1 / num_dists
        if action_repeats is None:
            x = np.ones([num_dists]) / num_dists
        else:
            repeat_factor, _ = _get_repeat_factor(action_repeats)
            x = repeat_factor / np.sum(repeat_factor)
        dist = np.reshape(x, num_actions)
        status = None
        alphas = [np.zeros([])]
        lamb = None
    else:
        x = cp.Variable(num_dists, nonneg=(not assume_full_support))
        if min_epsilon:
            epsilon = cp.Variable(nonpos=True)
            e_vecs = [epsilon] * num_players

        cor_lb_cons = []
        if action_repeats is None:
            x_repeated = x
            dist_eq_con = cp.sum(x_repeated) == 1
            for a_mat, e_vec in zip(a_mats, e_vecs):
                if a_mat.size > 0:
                    cor_lb_cons.append(cp.matmul(a_mat, x) <= e_vec)
            eye = sp.sparse.eye(num_dists)
        else:
            repeat_factor, _ = _get_repeat_factor(action_repeats)
            x_repeated = cp.multiply(x, repeat_factor)
            dist_eq_con = cp.sum(x_repeated) == 1
            for a_mat, e_vec in zip(a_mats, e_vecs):
                if a_mat.size > 0:
                    cor_lb_cons.append(
                        cp.matmul(a_mat, cp.multiply(x, repeat_factor))
                        <= e_vec)
            eye = sp.sparse.diags(repeat_factor)

        # This is more memory efficient than using cp.sum_squares.
        cost = 1 - cp.quad_form(x, eye)
        if min_epsilon:
            cost -= cp.multiply(2, epsilon)

        prob = cp.Problem(cp.Maximize(cost), [dist_eq_con] + cor_lb_cons)
        cost_value = prob.solve(**solver_kwargs)
        status = prob.status
        alphas = [con.dual_value for con in cor_lb_cons]
        lamb = dist_eq_con.dual_value

        val = cost.value
        x = x_repeated.value
        dist = np.reshape(x, num_actions)

    meta = {
        "x": x,
        "a_mats": a_mats,
        "status": status,
        "cost": cost_value,
        "val": val,
        "alphas": alphas,
        "lamb": lamb,
        "unique": True,
        "min_epsilon": None if epsilon is None else epsilon.value,
    }
    return dist, meta


def _qp_ce(
        payoff,
        a_mats,
        e_vecs,
        assume_full_support=False,
        action_repeats=None,
        solver_kwargs=None,
        min_epsilon=False):
    """Returns the correlated equilibrium with maximum Gini impurity.

    With at least one non-empty constraint matrix, the objective
    `1 - quad_form(x, eye)` is maximised subject to the distribution summing
    to one and `a_mat @ x <= e_vec` for every non-empty `a_mat`. When
    `action_repeats` is given, each player's constraint scales x by that
    player's individual repeat factor, and `eye` is the diagonal matrix of the
    joint repeat factor. Without constraints no solver runs and a uniform (or
    repeat-proportional) distribution is returned.

    Args:
      payoff: A [NUM_PLAYER, NUM_ACT_0, NUM_ACT_1, ...] shape payoff tensor.
      a_mats: Per-player gain matrices, each of shape [NUM_CON, PROD(A)].
      e_vecs: Per-player epsilon vectors; replaced by one shared non-positive
        variable when `min_epsilon` is set.
      assume_full_support: When True, x is not declared non-negative.
      action_repeats: Vector of action repeats for each player.
      solver_kwargs: Additional kwargs for solver; defaults to
        DEFAULT_OSQP_SOLVER_KWARGS.
      min_epsilon: Whether to make epsilon a variable and subtract twice its
        value from the maximised objective.

    Returns:
      dist: The distribution reshaped to [NUM_ACT_0, NUM_ACT_1, ...].
      meta: Dictionary of solver artefacts (x, duals, status, cost, ...).
    """
    num_players = payoff.shape[0]
    num_actions = payoff.shape[1:]
    num_dists = int(np.prod(num_actions))

    if solver_kwargs is None:
        solver_kwargs = DEFAULT_OSQP_SOLVER_KWARGS

    epsilon = None
    has_rows = [a_mat.shape[0] > 0 for a_mat in a_mats if a_mat is not None]
    if not any(has_rows):
        # No constraints: return the unconstrained answer directly.
        cost_value = 0.0
        val = 1 - 1 / num_dists
        if action_repeats is None:
            x = np.ones([num_dists]) / num_dists
        else:
            repeat_factor, _ = _get_repeat_factor(action_repeats)
            x = repeat_factor / np.sum(repeat_factor)
        dist = np.reshape(x, num_actions)
        status = None
        alphas = [np.zeros([])]
        lamb = None
    else:
        x = cp.Variable(num_dists, nonneg=(not assume_full_support))
        if min_epsilon:
            epsilon = cp.Variable(nonpos=True)
            e_vecs = [epsilon] * num_players

        cor_lb_cons = []
        if action_repeats is None:
            x_repeated = x
            dist_eq_con = cp.sum(x_repeated) == 1
            for a_mat, e_vec in zip(a_mats, e_vecs):
                if a_mat.size > 0:
                    cor_lb_cons.append(cp.matmul(a_mat, x) <= e_vec)
            eye = sp.sparse.eye(num_dists)
        else:
            repeat_factor, indiv_repeat_factors = _get_repeat_factor(
                action_repeats)
            x_repeated = cp.multiply(x, repeat_factor)
            dist_eq_con = cp.sum(x_repeated) == 1
            # Each player's constraint uses that player's own repeat factor.
            per_player = zip(a_mats, e_vecs, indiv_repeat_factors)
            for a_mat, e_vec, rf in per_player:
                if a_mat.size > 0:
                    cor_lb_cons.append(
                        cp.matmul(a_mat, cp.multiply(x, rf)) <= e_vec)
            eye = sp.sparse.diags(repeat_factor)

        # This is more memory efficient than using cp.sum_squares.
        cost = 1 - cp.quad_form(x, eye)
        if min_epsilon:
            cost -= cp.multiply(2, epsilon)

        prob = cp.Problem(cp.Maximize(cost), [dist_eq_con] + cor_lb_cons)
        cost_value = prob.solve(**solver_kwargs)
        status = prob.status
        alphas = [con.dual_value for con in cor_lb_cons]
        lamb = dist_eq_con.dual_value

        val = cost.value
        x = x_repeated.value
        dist = np.reshape(x, num_actions)

    meta = {
        "x": x,
        "a_mats": a_mats,
        "status": status,
        "cost": cost_value,
        "val": val,
        "alphas": alphas,
        "lamb": lamb,
        "unique": True,
        "min_epsilon": None if epsilon is None else epsilon.value,
    }
    return dist, meta


def _expand_meta_game(meta_game, per_player_repeats):
    """Repeats each player's actions along that player's axis of `meta_game`.

    Axis 0 indexes players; player p's actions lie on axis p + 1 and are
    repeated according to `per_player_repeats[p]`.
    """
    expanded = meta_game
    for player in range(meta_game.shape[0]):
        repeats = per_player_repeats[player]
        expanded = np.repeat(expanded, repeats, axis=player + 1)
    return expanded


def _unexpand_meta_dist(meta_dist, per_player_repeats):
    """Sums `meta_dist` over consecutive groups of entries along every axis.

    Along axis p, the group sizes are given by `per_player_repeats[p]`, so
    each group collapses into a single entry holding the group's total.
    """
    collapsed = meta_dist
    for axis in range(len(meta_dist.shape)):
        group_ends = np.cumsum(per_player_repeats[axis]).tolist()
        # Each group starts where the previous one ended; the first at 0.
        group_starts = [0] + group_ends[:-1]
        collapsed = np.add.reduceat(collapsed, group_starts, axis=axis)
    return collapsed


