import copy, os, json
from running.shorten_names import shorten_names


def get_param_list(param_grid):
    """Expand a parameter grid into the list of all parameter combinations.

    Each value in ``param_grid`` is either a list/tuple of candidate options
    or a single fixed value (str, number, bool, None, dict, ...), which is
    treated as a one-element option list. The caller's grid is not modified.

    Combinations are enumerated mixed-radix style with the FIRST key varying
    fastest, and every returned dict keeps the grid's key order.

    Args:
        param_grid: mapping of parameter name -> options or single value.

    Returns:
        list[dict]: one dict per combination. An empty grid yields ``[{}]``;
        any empty option list yields ``[]``.
    """
    # Normalise into a local copy instead of rewriting the caller's grid.
    options = {
        key: value if isinstance(value, (list, tuple)) else [value]
        for key, value in param_grid.items()
    }

    n = 1
    for values in options.values():
        n *= len(values)

    param_list = []
    for i in range(n):  # n == 0 when any option list is empty -> no combos
        param_dict = {}
        k = i
        for key, values in options.items():
            # Decode i digit by digit; the first key is the least significant.
            param_dict[key] = values[k % len(values)]
            k //= len(values)
        param_list.append(param_dict)

    return param_list


def get_joined_param_list(grid_list):
    """Expand several parameter grids and concatenate their combinations.

    Args:
        grid_list: iterable of parameter-grid dicts, each expanded with
            get_param_list.

    Returns:
        list[dict]: all combinations of the first grid, then of the second
        grid, and so on.
    """
    return [combo for grid in grid_list for combo in get_param_list(grid)]


def get_config_list(config, folder_type):
    """Expand a raw experiment config into one concrete config per combination.

    For ``folder_type == 'learning'``, every combination from the ``'envs'``
    grids is paired with every combination from the ``'learning'`` grids,
    with envs as the outer loop. For ``'testing'``, one config is produced per
    combination of the ``'testing'`` grids. Each result is a deep copy of
    ``config`` whose grid section(s) are replaced by a single parameter dict.
    Any other folder_type yields an empty list.
    """
    def _expanded(**overrides):
        # Fresh deep copy per combination, then swap in the chosen params.
        expanded = copy.deepcopy(config)
        expanded.update(overrides)
        return expanded

    if folder_type == 'testing':
        return [
            _expanded(testing=test_params)
            for test_params in get_joined_param_list(config['testing'])
        ]

    if folder_type != 'learning':
        return []

    env_combos = get_joined_param_list(config['envs'])
    learning_combos = get_joined_param_list(config['learning'])
    return [
        _expanded(envs=env_params, learning=learning_params)
        for env_params in env_combos
        for learning_params in learning_combos
    ]


def get_folder_names(config_list, folder_type):
    """Build one folder name per expanded config.

    The sections used are ``'envs'`` and ``'learning'`` for learning runs and
    ``'testing'`` otherwise. For each section, the name part is the value of
    its ``'env_name'``/``'alg_name'`` entry, when present, followed by
    ``<key><value>`` for every other entry. Everything is passed through
    shorten_names, and all parts are joined with ``'_'``.

    Args:
        config_list: expanded configs (as produced by get_config_list).
        folder_type: ``'learning'``; any other value selects the testing
            section.

    Returns:
        list[str]: folder names in the same order as ``config_list``.
    """
    config_keys = ['envs', 'learning'] if folder_type == 'learning' else ['testing']
    name_keys = ('env_name', 'alg_name')

    folder_names = []
    for config in config_list:
        parts = []
        for key in config_keys:
            subconfig = config[key]

            # Look up the naming entry per section, so a section without one
            # never inherits the previous section's name. If several naming
            # keys are present, the last one in dict order wins.
            name_key = None
            for sub_key in subconfig:
                if sub_key in name_keys:
                    name_key = sub_key
            if name_key is not None:
                parts.append(shorten_names(subconfig[name_key]))

            for sub_key, sub_value in subconfig.items():
                if sub_key != name_key:
                    parts.append(shorten_names(sub_key) + shorten_names(sub_value))

        folder_names.append('_'.join(parts))

    return folder_names


def create_sub_path(path, folder_name):
    """Ensure the directory ``path/folder_name`` exists and return its path.

    Only the last path component is created; the parent ``path`` must
    already exist. An already existing entry is accepted as-is.
    """
    sub_path = os.path.join(path, folder_name)
    # EAFP: attempting mkdir directly avoids the race window between an
    # exists() check and the mkdir() call.
    try:
        os.mkdir(sub_path)
    except FileExistsError:
        pass
    return sub_path


def get_skip(sub_path, sub_config, folder_type):
    """Return True if the run stored in ``sub_path`` can be reused.

    A run is skippable when ``sub_path/config.json`` exists and
    ``sub_path/trs.npy`` exists. The stored config must also match
    ``sub_config`` on the relevant sections: ``'envs'`` and ``'learning'``
    for learning runs, ``'testing'`` for testing runs.

    Any of the following counts as "not skippable", so the run is redone:
    - a missing or unparsable previous config;
    - a section missing from the previous config;
    - an unknown folder_type.
    """
    sub_config_path = os.path.join(sub_path, 'config.json')
    if not os.path.exists(sub_config_path):
        return False

    try:
        with open(sub_config_path) as prev_config_file:
            prev_config = json.load(prev_config_file)
    except json.JSONDecodeError:
        # A truncated/corrupt config (e.g. from an interrupted write) cannot
        # prove the old results match, so treat the run as not done.
        return False

    if folder_type == 'learning':
        compared_keys = ('envs', 'learning')
    elif folder_type == 'testing':
        compared_keys = ('testing',)
    else:
        return False

    same_config = all(
        key in prev_config and sub_config[key] == prev_config[key]
        for key in compared_keys
    )
    trs_exists = os.path.exists(os.path.join(sub_path, 'trs.npy'))
    return same_config and trs_exists
    

def get_sub_paths(path, folder_names, sub_configs, folder_type, agent_index):
    """Create run folders, write each run's config.json and report skips.

    For each ``(folder_name, sub_config)`` pair, ``path/<agent_index><folder_name>``
    is created. Then, by folder_type:
    - ``'learning'``: one sub-folder per entry of ``sub_config['seeds']``,
      each storing a config whose ``'seeds'`` is just that seed.
    - ``'testing'``: the folder stores the config with ``'agent_index'`` set.

    Before a config.json is overwritten, get_skip checks whether the
    previous run there can be reused. The caller's ``sub_configs`` are not
    modified; per-run shallow copies are written instead.

    Returns:
        tuple[list[str], list[bool]]: run directories and their skip flags,
        in matching order.
    """
    sub_paths, skips = [], []

    def _register(run_path, run_config):
        # The skip check must read the previous config before it is overwritten.
        sub_paths.append(run_path)
        skips.append(get_skip(run_path, run_config, folder_type))
        with open(os.path.join(run_path, 'config.json'), 'w') as out_file:
            json.dump(run_config, out_file)

    for folder_name, sub_config in zip(folder_names, sub_configs):
        sub_path = create_sub_path(path, agent_index + folder_name)

        if folder_type == 'learning':
            for seed in sub_config['seeds']:
                seed_path = create_sub_path(sub_path, str(seed))
                _register(seed_path, {**sub_config, 'seeds': [seed]})

        elif folder_type == 'testing':
            _register(sub_path, {**sub_config, 'agent_index': agent_index})

    return sub_paths, skips


def prepare_folders(path, folder_type='learning'):
    """Expand ``path/config.json`` into per-run folders under ``path``.

    Args:
        path: directory containing the top-level config.json.
        folder_type: ``'learning'`` or ``'testing'``. For testing, every
            configuration is prepared twice: once with agent index ``'u'``
            and once with ``'v'``, used as the folder-name prefix.

    Returns:
        tuple[list[str], list[bool]]: run directories and whether each one
        can be skipped (see get_skip).

    Raises:
        ValueError: if folder_type is not ``'learning'`` or ``'testing'``.
    """
    # Validate first: previously an invalid type printed a message and then
    # crashed with UnboundLocalError on the return statement.
    if folder_type not in ('learning', 'testing'):
        raise ValueError(
            f"Incorrect folder_type {folder_type!r}: expected 'learning' or 'testing'."
        )

    with open(os.path.join(path, 'config.json')) as file:
        config = json.load(file)

    sub_configs = get_config_list(config, folder_type)
    folder_names = get_folder_names(sub_configs, folder_type)

    if folder_type == 'learning':
        return get_sub_paths(path, folder_names, sub_configs, folder_type, '')

    u_sub_paths, u_skips = get_sub_paths(path, folder_names, sub_configs, folder_type, 'u')
    v_sub_paths, v_skips = get_sub_paths(path, folder_names, sub_configs, folder_type, 'v')
    return u_sub_paths + v_sub_paths, u_skips + v_skips
