''' Functions to generate weighted voting games (WVGs)

See the GenerateGameDataset.ipynb notebook for an example of how to use these functions.
'''

## Dependencies
from itertools import chain, combinations as combs
from typing import List
import numpy as np
import math

# LP solver
from ortools.linear_solver import pywraplp


#############
# General
#############

def get_min_win_coals(coals: List, weights: np.array, quota: float) -> List:
    '''
    Filter a collection of coalitions down to the minimal winning ones.

    A coalition is kept when its total weight reaches the quota AND dropping any
    single member would push the remaining weight below the quota.
    :param coals: all possible coalitions in the game / the set of winning coals
    :param weights: weight of each player in the game
    :param quota: threshold that determines when a coalition can achieve the task
    :return: the minimal winning coalitions, in the order they appear in `coals`
    '''
    minimal = []
    for coal in coals:
        coal_weight = weights[list(coal)].sum()
        winning = coal_weight >= quota
        # Minimal: every member is indispensable for reaching the quota.
        if winning and all(coal_weight - weights[member] < quota for member in coal):
            minimal.append(coal)
    return minimal


def gen_quota(n_players: int, prob_dist: str) -> float:
    '''
    Sample a quota from the specified distribution.
    :param n_players: The number of players in the game
    :param prob_dist: Probability distribution from which to sample the quota;
        only 'gauss' is supported
    :return: one non-negative quota, always as a Python float
    :raises NotImplementedError: if `prob_dist` is not supported
    '''
    if prob_dist != 'gauss':
        raise NotImplementedError('Probability distribution not supported.')

    # Mean of the Gaussian: (2n + 1) / 2 * n / 2.
    # NOTE(review): presumably half the expected total weight when player weights
    # are drawn uniformly from 1..2n -- confirm against the weight generator.
    quota_loc = (((n_players * 2) + 1) / 2) * (n_players / 2)
    quota_scale = np.sqrt(n_players * 2)
    quota = np.random.normal(loc=quota_loc, scale=quota_scale)
    # Resample until the quota is non-negative. Every draw is a scalar (no `size`),
    # so the return type does not depend on whether a resample happened.
    while quota < 0:
        quota = np.random.normal(loc=quota_loc, scale=quota_scale)
    return float(quota)

#############
# Least core 
#############

def solve_optimal_payoff(n_players: int, coals_win_min: List) -> np.array:
    '''
    Compute an optimal payoff in the least core together with its epsilon.

    Solves the LP
        minimise eps
        s.t.     sum_i y_i = 1,  y_i >= 0
                 eps + sum_{i in C} y_i >= 1   for every C in coals_win_min
    with the GLOP solver.
    :param n_players: The number of players in the game
    :param coals_win_min: The minimal list of winning coalitions (iterables of player indices)
    :return: Array of shape n_players + 1: the payoffs y_0..y_{n-1}, followed by eps
    :raises RuntimeError: if GLOP is unavailable or the LP has no optimal solution
    '''
    # Instantiate a GLOP solver; CreateSolver returns None when the backend is missing.
    solver = pywraplp.Solver.CreateSolver('GLOP')
    if solver is None:
        raise RuntimeError('Failed to find a solution. (GLOP solver is not available)')
    inf = solver.infinity()

    # Decision variables: non-negative payoff per player and an unbounded epsilon.
    y = {i: solver.NumVar(0, inf, f'y[{i}]') for i in range(n_players)}
    eps = solver.NumVar(-inf, inf, 'eps')
    solver.Minimize(eps)

    # Payoffs form a distribution of the grand coalition's value of 1.
    solver.Add(sum(y[i] for i in range(n_players)) == 1)
    # Each minimal winning coalition's excess 1 - y(C) is bounded by eps.
    for c in coals_win_min:
        solver.Add(eps + sum(y[i] for i in c) >= 1)

    status = solver.Solve()
    if status != pywraplp.Solver.OPTIMAL:
        raise RuntimeError(f'Failed to find a solution. (solver status {status})')

    sol = np.zeros(n_players + 1)
    sol[:n_players] = [y[i].solution_value() for i in range(n_players)]
    sol[-1] = solver.Objective().Value()
    return sol


def compute_excess(coal: tuple, payoffs: np.array, weights: np.array, quota: float) -> float:
    '''
    Compute the excess v(C) - x(C) of a coalition C under a payoff allocation x.

    v(C) is 1 when the coalition's total weight reaches the quota and 0 otherwise.
    :param coal: A coalition (iterable of player indices)
    :param payoffs: A payoff allocation (imputation), indexed by player
    :param weights: Set of weights for the players
    :param quota: The threshold
    :return: The excess value
    '''
    # Index with an explicit integer array rather than a list wrapping a tuple;
    # NumPy deprecated that form and later changed its meaning. An integer array
    # also handles the empty coalition cleanly.
    members = np.asarray(coal, dtype=int)
    v_coal = 1 if weights[members].sum() >= quota else 0
    return v_coal - payoffs[members].sum()

#############
# Shapley 
#############

def sample_permutations(n_players: int, n_samples: int) -> np.array:
    '''
    Draw independent random orderings of the players.
    :param n_players: Number of players in each permutation
    :param n_samples: Number of permutations used to approximate the solution concept
    :return: Integer array of shape [n_samples, n_players]; each row is a permutation
    '''
    players = np.arange(n_players)
    draws = np.empty((n_samples, n_players), dtype=int)
    for row in draws:
        # Choosing every player without replacement yields one full ordering.
        row[...] = np.random.choice(players, size=(n_players), replace=False)
    return draws


def compute_shapley_vals(n_players: int, weights: np.array, quota: float, perms: np.array) -> np.array:
    '''
    Approximate the Shapley values of a weighted voting game from sampled permutations.

    Along each permutation, players join one at a time. The player whose arrival
    first brings the cumulative weight to the quota is pivotal and earns one count.
    Each estimate is a player's pivot count divided by the number of permutations.
    :param n_players: Number of players
    :param weights: Set of weights for the players
    :param quota: The threshold in the game
    :param perms: Array of shape [n_samples, n_players] with permutations over the players
    :return: Shapley value for each player in the game
    :raises ValueError: if the estimates do not sum to one, e.g. when the quota
        exceeds the total weight so that some permutation has no pivotal player
    '''
    weights = np.asarray(weights)
    perms = np.asarray(perms)

    # Running coalition weight along every permutation (one row per sample).
    cum_weights = np.cumsum(weights[perms], axis=1)
    reached = cum_weights >= quota
    # argmax gives the first position where the quota is met, i.e. the pivot;
    # rows that never reach the quota have no pivot and are masked out.
    has_pivot = reached.any(axis=1)
    pivot_pos = reached.argmax(axis=1)
    pivots = perms[np.arange(perms.shape[0]), pivot_pos][has_pivot]

    count_perm_contrib = np.bincount(pivots, minlength=n_players)
    shapley_values = count_perm_contrib / perms.shape[0]

    if not math.isclose(shapley_values.sum(), 1, abs_tol=1e-3):
        raise ValueError('Sum of the Shapley values must equal one.')
    return shapley_values

#############
# Banzhaf
#############

def compute_banzhaf_index(n_players: int, weights: np.array, quota: float, coals_win: np.array) -> np.array:
    '''
    Computes the normalised Banzhaf power index for a weighted voting game.

    For every winning coalition, each member whose removal drops the coalition
    below the quota is counted as critical (a swing). The swing counts are then
    normalised to sum to one.
    :param n_players: Number of players
    :param weights: Set of weights for the players
    :param quota: The threshold in the game
    :param coals_win: Winning coalitions (iterables of player indices); any
        non-winning coalitions included are ignored
    :return: Banzhaf power index for each player in the game
    :raises ValueError: if the indices do not sum to one (e.g. no swings were found)
    '''
    weights = np.asarray(weights)
    count_critical = np.zeros(n_players)

    for coal in coals_win:
        members = np.asarray(coal, dtype=int)
        coal_weight = weights[members].sum()
        # Only winning coalitions can contain swings. Skipping the others keeps
        # the result correct even if losing coalitions are passed in.
        if coal_weight < quota:
            continue
        # Same criticality test as get_min_win_coals: total minus member < quota.
        is_critical = coal_weight - weights[members] < quota
        count_critical[members[is_critical]] += 1

    banzhaf_index = count_critical / count_critical.sum()

    if not math.isclose(banzhaf_index.sum(), 1, abs_tol=1e-3):
        raise ValueError('Sum of the Banzhaf indices has to equal one.')
    return banzhaf_index