import sys
import os
import numpy as np

class Parameter:
    """Grid of hyper-parameter configurations for the experiment sweep.

    On construction, the Cartesian product of the value lists in ``self.param``
    is enumerated, visiting keys in dict-insertion order. It is pruned by the
    rules in ``self.mask_list`` and stored in ``self.param_list`` as one dict per
    configuration. ``self.active_idx`` holds the ``param_list`` indices selected
    by ``partial_run``, and ``self.n`` is the number of them.
    """

    def __init__(self, partial_run=None):
        """Enumerate all configurations and select the active ones.

        Args:
            partial_run: optional mapping ``{name: allowed_values}``. A
                configuration is active only if, for every key, its value is
                ``in`` the corresponding ``allowed_values``. ``None`` (the
                default) keeps every configuration. Keys must be names
                present in each configuration, otherwise ``KeyError`` is raised.

        Side effect: prints the number of active configurations.
        """
        # ``None`` instead of a ``{}`` default: default values are evaluated
        # once and shared between calls.
        if partial_run is None:
            partial_run = {}

        self.param = {
            'algo': ['CFR', 'CFR+', 'QFR', 'MMD', 'DCFR', 'PCFR', 'BFTRL', 'BOMD'],
            'game': ['liars_dice(dice_sides=4)', 'leduc_poker(suit_isomorphism=True)',
                     'kuhn_poker', 'dark_hex(board_size=2,gameversion=adh)'],
            'eta': [0.1, 0.01, 0.005, 0.001, 0.0001],
            'tau': [0.1, 0.05, 0.01, 0.001, 0.0001, 0.0],
            'gamma': [0.1, 0.01, 0.001, 0.0001, 0.0],
            'regularizer': ['Euclidean', 'Entropy'],
            'feedback': ['counterfactual', 'Q', 'traj-Q'],
            'sample': [True, False],
            'T': [100000],
            'out-freq': [100],
        }

        # Wrap game specs in double quotes. Presumably this keeps a spec that
        # contains parentheses/commas as one argument on a shell command line.
        self.param['game'] = [f'"{g}"' for g in self.param['game']]

        # Each rule is [type_A, key_list_A, type_B, key_list_B]: once a value
        # from key_list_A is selected for type_A, type_B may only take values
        # from key_list_B.
        self.mask_list = [
            # NOTE(review): no configured game uses dice_sides=6, so this rule
            # never fires with the current grid. Confirm whether that is intended.
            ['game', ['"liars_dice(dice_sides=6)"'], 'sample', [True]],
            ['algo', ['BFTRL', 'BOMD'], 'tau', [0.0]],
            ['algo', ['BFTRL', 'BOMD'], 'gamma', [0.0]],
            ['algo', ['BFTRL', 'BOMD'], 'regularizer', ['Entropy']],
            ['algo', ['BFTRL', 'BOMD'], 'sample', [True]],
            ['algo', ['MMD', 'BFTRL', 'BOMD'], 'gamma', [0.0]],
            ['algo', ['CFR', 'CFR+', 'PCFR'], 'eta', [0.0]],
            ['algo', ['CFR', 'CFR+', 'PCFR'], 'tau', [0.0]],
            ['algo', ['PCFR'], 'gamma', [0.0]],
            ['algo', ['CFR', 'CFR+'], 'gamma', [0.1, 0.01, 0.001, 0.0001, 0.0]],
            # DCFR sweeps its own grids for eta/tau/gamma.
            ['algo', ['DCFR'], 'eta', [-15, 15] + [j * 0.5 - 2.5 for j in range(11)]],
            ['algo', ['DCFR'], 'tau', [-15, 15] + [j * 0.5 - 2.5 for j in range(11)]],
            ['algo', ['DCFR'], 'gamma', [0.5, 1, 1.5, 2, 2.5, 3]],
            ['algo', ['DCFR', 'PCFR'], 'sample', [False]],
            ['algo', ['CFR', 'CFR+', 'DCFR', 'PCFR'], 'regularizer', ['Euclidean']],
            ['algo', ['CFR', 'CFR+', 'DCFR', 'PCFR'], 'feedback', ['counterfactual']],
            ['sample', [True], 'T', [1000000]],
            ['sample', [True], 'out-freq', [100]],
            ['sample', [True], 'feedback', ['counterfactual']],
        ]

        self.param_list = []
        self.Set_Param(0, {})

        # Non-sampled runs on liars_dice (game[0]) and leduc_poker (game[1])
        # get a short horizon T and output every iteration. This is applied
        # after enumeration, so it overrides the grid values of T / out-freq.
        short_horizon = {self.param['game'][0]: 50,
                         self.param['game'][1]: 100}
        for p in self.param_list:
            if not p['sample'] and p['game'] in short_horizon:
                p['T'] = short_horizon[p['game']]
                p['out-freq'] = 1

        # Keep the global indices of the configurations that pass partial_run.
        self.active_idx = [
            i for i, p in enumerate(self.param_list)
            if all(p[k] in allowed for k, allowed in partial_run.items())
        ]
        self.n = len(self.active_idx)
        print(self.n)

    def Set_Param(self, k, param):
        """Recursively extend the partial assignment ``param`` from key index ``k``.

        Every complete assignment that satisfies ``self.mask_list`` is appended,
        as a copy, to ``self.param_list``. ``param`` is used as a scratch dict
        and is left with its input contents when the call returns.

        Args:
            k: position, in the key order of ``self.param``, of the next key to fill.
            param: dict of the keys fixed so far.
        """
        # Prune first, including at the leaf. A rule whose type_B was fixed
        # before type_A is checked here, so it also holds when type_A is the
        # last key.
        for type_a, keys_a, type_b, keys_b in self.mask_list:
            if (type_a in param and param[type_a] in keys_a
                    and type_b in param and param[type_b] not in keys_b):
                return

        keys = list(self.param)
        if k == len(keys):
            self.param_list.append(param.copy())
            return

        name = keys[k]
        if name in param:
            self.Set_Param(k + 1, param)
            return

        # A matching rule restricts this key's choices. When several rules
        # match, the last one in mask_list supplies the list; the prune above
        # still enforces all of them.
        candidates = self.param[name]
        for type_a, keys_a, type_b, keys_b in self.mask_list:
            if name == type_b and type_a in param and param[type_a] in keys_a:
                candidates = keys_b

        for v in candidates:
            param[name] = v
            self.Set_Param(k + 1, param)
        # pop (not del): with an empty candidate list the key was never set.
        param.pop(name, None)

    def get_param(self, idx):
        """Look up active configurations.

        Args:
            idx: either an integer (Python or NumPy) position within the active
                subset, or a dict ``{name: value}`` query.

        Returns:
            For an integer, the configuration dict at ``active_idx[idx]``. For
            a dict, a NumPy integer array of the ``param_list`` indices (the
            numbering used by ``active_idx``) of the active configurations that
            equal the query on every key. The array is empty when nothing matches.

        Raises:
            TypeError: if ``idx`` is neither an integer nor a dict.
            KeyError: if a dict query names a key a configuration lacks.
        """
        if isinstance(idx, (int, np.integer)) and not isinstance(idx, bool):
            return self.param_list[self.active_idx[idx]]
        if isinstance(idx, dict):
            # Scan only the active subset. The previous range(self.n) loop
            # indexed the full list and checked the wrong entries whenever
            # partial_run filtered anything.
            matches = [
                i for i in self.active_idx
                if all(self.param_list[i][k] == v for k, v in idx.items())
            ]
            # dtype=int so that an empty result is still usable as an index array.
            return np.array(matches, dtype=int)
        raise TypeError(f'get_param expects an int or a dict, got {type(idx).__name__}')
    

if __name__ == '__main__':
    # Usage: python <this script> IDX_START STRIDE
    # Launches main.py for active configurations IDX_START, IDX_START+STRIDE, ...
    # Several workers with the same STRIDE and distinct IDX_START values can
    # therefore split the sweep between them.
    import shlex

    if len(sys.argv) < 3:
        sys.exit(f'usage: {sys.argv[0]} IDX_START STRIDE')
    idx_start, stride = int(sys.argv[1]), int(sys.argv[2])

    current_dir = os.getcwd()
    # Quote paths so a working directory containing spaces still yields a
    # valid shell command. shlex.quote leaves ordinary paths unchanged.
    main_script = shlex.quote(f'{current_dir}/main.py')
    param = Parameter()

    for idx in range(idx_start, param.n, stride):
        flags = []
        for k, v in param.get_param(idx).items():
            if isinstance(v, bool):
                # Booleans become bare switches, emitted only when True
                # (presumably store_true options in main.py).
                if v:
                    flags.append(f'--{k}')
            else:
                flags.append(f'--{k} {v}')

        # The result file is named by the global param_list index, which does
        # not depend on the partial_run filter.
        json_path = shlex.quote(f'{current_dir}/result/{param.active_idx[idx]}.json')
        command_line = f'python {main_script} {" ".join(flags)} --json {json_path}'
        print(command_line)
        os.system(command_line)
