## Dependencies
import itertools
from itertools import chain
from importlib import reload
import tqdm
from tqdm import trange
import numpy as np
import pickle
import torch
import torch.nn as nn

## Utils
import data_gen_utils as utils

# Generate one dataset per game size N (number of players). For each N, sample
# `n_games` random weighted voting games, optionally label them with Shapley
# values, Banzhaf indices and least-core solutions (via data_gen_utils), and
# pickle the results as torch tensors to f"{path}{N}{filename}".
for player in range(4, 11):

    PARAMS = {
        'n_games'     : 5000,
        'n_players'   : player,
        'prob_q'      : 'gauss',   # quota distribution name, forwarded to utils.gen_quota
        'alpha'       : 1,         # Beta(alpha, beta) shape parameters for the weights
        'beta'        : 1,
        'loc'         : 1,         # offset added to the scaled Beta sample (weight lower bound)
        'path'        : '',
        'filename'    : 'players_train.pickle',
        'least_core'  : True,
        'shapley'     : True,
        'banzhaf'     : True,
        'approximate' : True,      # Approximate Shapley value labels for large games...
        'approx_above': 9,         # ...i.e. when n_players exceeds this (N! grows too fast)
        'n_samples'   : 1000,      # permutations sampled per Monte Carlo Shapley estimate
        'n_resamples' : 10,        # independent estimates averaged per game
    }
    N = PARAMS['n_players']
    G = PARAMS['n_games']

    # Single switch for Monte Carlo vs. exact Shapley values, used for array
    # allocation, labelling and post-processing alike so they cannot diverge.
    approx_shap = PARAMS['approximate'] and N > PARAMS['approx_above']

    # All non-empty coalitions of the N players, as tuples of player indices
    # (np.arange keeps the original numpy-integer element type).
    combs_arr = list(chain.from_iterable(
        itertools.combinations(np.arange(N), k) for k in range(1, N + 1)))

    # Exact Shapley values need every ordering of the players. This is the same
    # for every game of size N, so enumerate it once here, not once per game.
    perms_arr = None
    if PARAMS['shapley'] and not approx_shap:
        perms_arr = np.array(list(itertools.permutations(range(N), N)), dtype='object')

    # Output buffers, one row per game.
    W = np.zeros((G, N))                 # sampled weights
    X = np.zeros((G, N))                 # weights divided by the quota
    q = np.zeros(G)                      # quotas
    # N + 1 values per game from utils.solve_optimal_payoff
    # (presumably N payoffs plus the least-core epsilon — confirm in data_gen_utils).
    sol_stack_lc = np.zeros((G, N + 1))
    Y_banz = np.zeros((G, N))
    cset_min_win = []                    # minimal winning coalitions, one list per game

    # When approximating, keep every resampled estimate; they are averaged after the loop.
    if approx_shap:
        Y_shap = np.zeros((G, N, PARAMS['n_resamples']))
    else:
        Y_shap = np.zeros((G, N))

    for game in trange(G):

        # Weights lie in [loc, loc + 2N - 1]; the quota is redrawn until it does
        # not exceed the total weight, so the grand coalition is always winning.
        weights = (np.random.beta(a=PARAMS['alpha'], b=PARAMS['beta'], size=N)
                   * (2 * N - 1) + PARAMS['loc'])
        quota = utils.gen_quota(N, prob_dist=PARAMS['prob_q'])
        while quota > weights.sum():
            quota = utils.gen_quota(N, prob_dist=PARAMS['prob_q'])

        # A coalition wins when its total weight reaches the quota.
        coals_win = [c for c in combs_arr if weights[list(c)].sum() >= quota]
        coals_min_win = utils.get_min_win_coals(coals_win, weights, quota)

        # Store the game description.
        W[game, :] = weights
        X[game, :] = weights / quota
        q[game] = quota
        cset_min_win.append(coals_min_win)

        # Labels.
        if PARAMS['shapley']:
            if approx_shap:
                for s in range(PARAMS['n_resamples']):
                    sampled_perms = utils.sample_permutations(N, PARAMS['n_samples'])
                    Y_shap[game, :, s] = utils.compute_shapley_vals(N, weights, quota, sampled_perms)
            else:
                Y_shap[game, :] = utils.compute_shapley_vals(N, weights, quota, perms_arr)

        if PARAMS['banzhaf']:
            Y_banz[game, :] = utils.compute_banzhaf_index(N, weights, quota, coals_win)

        if PARAMS['least_core']:
            sol_stack_lc[game, :] = utils.solve_optimal_payoff(N, coals_min_win)

    # Collapse the Monte Carlo resamples into a single estimate per player.
    if approx_shap:
        Y_shap = Y_shap.mean(axis=2)

    # Save dataset
    data_dict = {
        'W'            : torch.from_numpy(W),
        'q'            : torch.from_numpy(q),
        'cset_min_win' : cset_min_win,
        'X'            : torch.from_numpy(X),
        'Y_shap'       : torch.from_numpy(Y_shap),
        'Y_banz'       : torch.from_numpy(Y_banz),
        'sol_stack_lc' : torch.from_numpy(sol_stack_lc),
    }

    out_path = f'{PARAMS["path"]}{N}{PARAMS["filename"]}'
    with open(out_path, 'wb') as handle:
        pickle.dump(data_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
    print('Data stored succesfully.')