## Dependencies
import itertools
from itertools import chain
from importlib import reload
from tqdm import trange
import numpy as np
import pickle
import seaborn as sns
import torch
import torch.nn as nn
import data_gen_utils as utils

# Evaluation datasets to generate: one in-distribution set followed by four
# out-of-distribution variants. Each name is also the output pickle file stem.
test_set_names = ['in_sample'] + [f'outdist_{k}' for k in range(1, 5)]

PARAMS = {
    'player_list'   : np.arange(4, 11),
    'n_games_play'  : 1000,   # Number of games to generate per player
    'max_len'       : 20,     # Sequence length
    'prob_q'        : 'gauss',
    'log'           : False,
    'least_core'    : True,
    'shapley'       : True,
    'banzhaf'       : True,  
    'n_samples'     : 1000, # Number of games per player
    'n_resamples'   : 10,
}

N_max = PARAMS['max_len']
G = int(PARAMS['n_games_play'] * len(PARAMS['player_list']))

# Get number of samples per player based on proportions
n_player_repeats = np.repeat(PARAMS['player_list'], PARAMS['n_games_play']).astype(int)

# Get all combinations (coalitions)
combs_dict = {}
for play in PARAMS['player_list']:
    combs = np.array(list(chain.from_iterable(itertools.combinations(np.arange(play), k) for k in range(1, play + 1))), dtype='object')
    combs_dict.update({play : combs})

perms_dict = {}
for play in range(4, 10): # Compute the set of all perms for up to 9 players
    perms = np.array(list(itertools.permutations(range(int(play)), int(play))), dtype='object')
    perms_dict.update({play : perms})

# Generate all test datasets
for iter, test_name in enumerate(test_set_names):

    print(f'Generating {test_name} dataset')
    # Make arrays
    W = np.zeros((G, N_max))
    X = np.zeros((G, N_max))
    q = np.zeros((G))
    sol_stack = np.zeros((G, N_max + 1)) # Labels
    Y_shap = np.zeros((G, N_max))
    Y_banz = np.zeros((G, N_max))
    player_to_index = {} # Map players to shuffled index in array

    for game in trange(G):

        N = n_player_repeats[game]

        if test_name == 'in_sample':
            PARAMS.update({
                'ALPHA' : 1,
                'BETA'  : 1, 
                'LOC'   : 1, 
                'prob_q': 'gauss',
            })

        if test_name == 'outdist_1':
            PARAMS.update({
                'ALPHA' : 1,
                'BETA'  : 1, 
                'LOC'   : 2.5*N,
            })

        if test_name == 'outdist_2':
            PARAMS.update({
                'ALPHA' : 8,
                'BETA'  : 12, 
                'LOC'   : 2,
            })

        if test_name == 'outdist_3':
            PARAMS.update({
                'ALPHA' : 7,
                'BETA'  : 1.5, 
                'LOC'   : 1.5*N,
            })


        if test_name == 'outdist_4':
            PARAMS.update({
                'ALPHA' : 12,
                'BETA'  : 8, 
                'LOC'   : 3*N,
            })

        # Define the weighted voting game
        weights = (np.random.beta(a=PARAMS['ALPHA'], b=PARAMS['BETA'], 
                                size=(N))) * ((2*N)-1) + PARAMS['LOC']
        quota = utils.gen_quota(N, PARAMS['prob_q'])

        while quota > weights.sum(): # Make sure there is at least one solution
            quota = utils.gen_quota(N, prob_dist=PARAMS['prob_q'])

        # Generate set of winning and minimal winning coalitions
        coals_win = [i for i in combs_dict[N] if weights[tuple([i])].sum() >= quota]
        coals_min_win = utils.get_min_win_coals(coals_win, weights, quota)

        # Solve
        if PARAMS['least_core']:
            leastcore_sol = utils.solve_optimal_payoff(N, coals_min_win)

        if PARAMS['shapley']:
            if N > 9: # Approximate 
                shapley_sol_tmp = np.zeros((PARAMS['n_resamples'], N))
                for s in range(PARAMS['n_resamples']): 
                    sampled_perms = utils.sample_permutations(N, PARAMS['n_samples'])
                    shapley_sol_tmp[s, :] = utils.compute_shapley_vals(N, weights, quota, sampled_perms)
                # Average over resamples
                shapley_sol = shapley_sol_tmp.mean(axis=0)
            
            else:
                # Use all permutations 
                shapley_sol = utils.compute_shapley_vals(N, weights, quota, perms_dict[N])
                
        if PARAMS['banzhaf']:
            banzhaf_sol = utils.compute_banzhaf_index(N, weights, quota, coals_win)

        # Generate random permutation; store elements at these locs
        index = np.random.permutation(N_max)[:N] 
        player_to_index[game] = dict(zip(list(range(N)), list(index)))
        
        # Store inputs
        W[game, index] = weights
        X[game, index] = weights / quota
        q[game] = quota

        # Store solutions at random indices
        sol_stack[game, index] = leastcore_sol[:-1] 
        sol_stack[game, N_max] = leastcore_sol[-1] # Store epsilon as the last element in the array
        Y_shap[game, index] = shapley_sol
        Y_banz[game, index] = banzhaf_sol

        if PARAMS['log']:
            print(f'Game number {game}')
            print(f'G = [ w = {weights} ; q = {quota} ]')
            print(f'Payoffs full game [{N} players] y = {leastcore_sol[:-1]}, eps = {leastcore_sol[-1]} \n\n')

    # Store
    data_dict = { 'W'                : torch.from_numpy(W),                       
                  'q'                : torch.from_numpy(q),            
                  'X'                : torch.from_numpy(X),  
                  'sol_stack_lc'     : torch.from_numpy(sol_stack),
                  'Y_shap'           : torch.from_numpy(Y_shap),
                  'Y_banz'           : torch.from_numpy(Y_banz),
                }

    with open(f'{test_name}.pickle', 'wb') as handle:
        pickle.dump(data_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)