## Dependencies
import itertools
from itertools import chain
from importlib import reload
from tqdm import trange
import numpy as np
import pickle
import torch
import torch.nn as nn

## Utils
import data_gen_utils as utils

# Names of the evaluation sets generated for every player count: the
# 'in_sample' set followed by four 'outdist_*' variants (outdist_1..outdist_4).
test_set_names = ['in_sample'] + [f'outdist_{k}' for k in range(1, 5)]
                
# Shapley values are computed exactly (over all N! permutations) up to this
# many players; above it they are estimated from sampled permutations when
# PARAMS['approximate'] is set.
EXACT_SHAPLEY_MAX_PLAYERS = 9


def _weight_dist_params(test_name, n_players):
    """Return the weight-sampling parameters for one test set.

    Weights are drawn in the game loop as
    ``Beta(alpha, beta) * (2N - 1) + loc``, so ``alpha``/``beta`` shape the
    distribution and ``loc`` shifts it. 'in_sample' uses the base parameters
    (alpha = beta = loc = 1); the 'outdist_*' sets reshape and/or shift it,
    some of them in proportion to the number of players.

    Args:
        test_name: one of the names in ``test_set_names``.
        n_players: number of players N in the game.

    Returns:
        dict with the lowercase keys 'alpha', 'beta' and 'loc' that the
        weight sampler reads from PARAMS.

    Raises:
        KeyError: if ``test_name`` is not a known test set (instead of
            silently reusing the previous test set's parameters).
    """
    alpha, beta, loc = {
        'in_sample': (1, 1, 1),
        'outdist_1': (1, 1, 2.5 * n_players),
        'outdist_2': (8, 12, 2),
        'outdist_3': (7, 1.5, 1.5 * n_players),
        'outdist_4': (12, 8, 3 * n_players),
    }[test_name]
    return {'alpha': alpha, 'beta': beta, 'loc': loc}


### OUTER LOOP : Iterate over N ###
for n_players in range(4, 11):

    print(n_players)

    PARAMS = {
        'n_games'     : 1000,
        'n_players'   : n_players,
        'prob_q'      : 'gauss',
        'alpha'       : 1,
        'beta'        : 1,
        'loc'         : 1,
        'path'        : '',
        'filename'    : 'players_test',
        'least_core'  : True,
        'shapley'     : True,
        'banzhaf'     : True,
        'approximate' : True,  # Approximate Shapley labels when N > EXACT_SHAPLEY_MAX_PLAYERS
        'n_samples'   : 1000,  # Permutations sampled per resample (approximate mode)
        'n_resamples' : 10,    # Independent estimates averaged (approximate mode)
    }

    N = PARAMS['n_players']
    G = PARAMS['n_games']
    approx_shapley = PARAMS['approximate'] and N > EXACT_SHAPLEY_MAX_PLAYERS

    # All non-empty coalitions of the N players. Depends only on N, so it is
    # built once per N rather than once per test set.
    combs_arr = np.array(list(chain.from_iterable(
        itertools.combinations(np.arange(N), k) for k in range(1, N + 1))), dtype='object')

    # Every ordering of the players, needed for exact Shapley values. Built
    # once per N instead of once per game (N! rows).
    # NOTE(review): assumes utils.compute_shapley_vals does not modify this
    # array in place — confirm.
    perms_arr = None
    if PARAMS['shapley'] and not approx_shapley:
        perms_arr = np.array(list(itertools.permutations(range(N), N)), dtype='object')

    ### INNER LOOP 1: Iterate over parameter sets ###
    for test_name in test_set_names:

        # Overwrite the lowercase keys that the weight sampler below reads,
        # so each test set really uses its own weight distribution.
        PARAMS.update(_weight_dist_params(test_name, n_players))

        # Create arrays
        W = np.zeros((G, N))
        X = np.zeros((G, N))
        q = np.zeros((G))
        sol_stack_lc = np.zeros((G, N + 1))
        Y_banz = np.zeros((G, N))
        cset_min_win = []

        if approx_shapley:
            # One estimate per resample; averaged after the game loop.
            Y_shap = np.zeros((G, N, PARAMS['n_resamples']))
        else:
            Y_shap = np.zeros((G, N))

        ### INNER LOOP 2: Iterate over games ###
        for game in trange(G):

            # Sample weights as a scaled, shifted Beta variate.
            weights = np.random.beta(a=PARAMS['alpha'], b=PARAMS['beta'],
                                     size=(N)) * ((2 * N) - 1) + PARAMS['loc']

            # Redraw the quota until the grand coalition can reach it, so at
            # least one winning coalition exists.
            quota = utils.gen_quota(N, PARAMS['prob_q'])
            while quota > weights.sum():
                quota = utils.gen_quota(N, PARAMS['prob_q'])

            # Winning coalitions (total weight >= quota) and the minimal ones.
            coals_win = [c for c in combs_arr if weights[list(c)].sum() >= quota]
            coals_min_win = utils.get_min_win_coals(coals_win, weights, quota)

            # Store
            W[game, :] = weights
            X[game, :] = weights / quota
            q[game] = quota
            cset_min_win.append(coals_min_win)

            # Solve
            if PARAMS['shapley']:
                if approx_shapley:
                    for s in range(PARAMS['n_resamples']):
                        sampled_perms = utils.sample_permutations(N, PARAMS['n_samples'])
                        Y_shap[game, :, s] = utils.compute_shapley_vals(N, weights, quota, sampled_perms)
                else:
                    Y_shap[game, :] = utils.compute_shapley_vals(N, weights, quota, perms_arr)

            if PARAMS['banzhaf']:
                Y_banz[game, :] = utils.compute_banzhaf_index(N, weights, quota, coals_win)

            if PARAMS['least_core']:
                sol_stack_lc[game, :] = utils.solve_optimal_payoff(N, coals_min_win)

        if approx_shapley:  # Average the per-resample estimates
            Y_shap = Y_shap.mean(axis=2)

        # Save dataset
        data_dict = {'W'            : torch.from_numpy(W),
                     'q'            : torch.from_numpy(q),
                     'cset_min_win' : cset_min_win,
                     'X'            : torch.from_numpy(X),
                     'Y_shap'       : torch.from_numpy(Y_shap),
                     'Y_banz'       : torch.from_numpy(Y_banz),
                     'sol_stack_lc' : torch.from_numpy(sol_stack_lc),
                     }

        with open(f'{PARAMS["path"]}{N}{PARAMS["filename"]}_{test_name}.pickle', 'wb') as handle:
            pickle.dump(data_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
        print(f'{n_players} players data stored succesfully.')
    
