"""
Generates all tabular experiments. Can either run them directly (--output_type execute) or write the commands to a .bash/.txt file.
"""

import numpy as np
import argparse
from main import main as main_tabular
from copy import deepcopy
from tqdm import tqdm
from multiprocessing import Pool
from itertools import product
from pathlib import Path

# Default arguments handed to main.py for every experiment. main() copies this
# dict and overrides the entries that vary across the sweep (environment,
# algorithm, seed, exploration strategy and its parameter, ...).
AGENT_ARGS = {
    'exploration_strategy': 'epsilon-greedy',
    'algorithm': 'q-learning',
    'epsilon': 0.1,
    'seed': 1,
    'num_timesteps': 800000,
    'gamma': 1,
    'eta': 12.0,
    'temp': 1.0,
    # Set this to different values for experiments. No need for a flag atm.
    # (This key used to appear twice in the literal; the later entry silently
    # overrode this one, so edits made here had no effect.)
    'initial_optimism': 0,
    'only_store_rewards': True,
    'env_name': 'deep_sea/0',
    'save_type': 'reward_per_step',
    'log_interval': 1000000,
    'eval_episodes_num': 30,
    'save_path': 'results',
}

### Experimental Parameters ###
# Random seeds; one run per seed for every configuration.
SEEDS = list(range(30))

# Environments included in the sweep.
ENV_IDS = [
    'riverswim_variants:stochastic-riverswim-v0',
    'gym_riverswim:riverswim-v0',
]

# Default step budget shared by most environments.
NUM_TIMESTEPS = 800000

# Step budget per environment id.
ENV_TIMESTEPS = {
    'deep_sea/0': NUM_TIMESTEPS,
    'gym_riverswim:riverswim-v0': NUM_TIMESTEPS,
    'riverswim_variants:stochastic-riverswim-v0': NUM_TIMESTEPS,
    'gym_exploration:HardSquare-v0': 100000,
}

# Exploration strategies included in the sweep.
EXP_STRATEGIES = ['mellowmax', 'resmax', 'softmax', 'epsilon-greedy']

# Directory (per environment id) under which each run's numbered folder is created.
SAVE_PATH = {
    'deep_sea/0': 'results_deepsea',
    'gym_riverswim:riverswim-v0': 'results_riverswim',
    'riverswim_variants:stochastic-riverswim-v0': 'results_stochastic_rs',
    'gym_exploration:HardSquare-v0': 'results_hardsquare_step_size',
}

# Learning algorithms included in the sweep.
ALGORITHMS = ['q-learning', 'expected-sarsa']

# Step sizes included in the sweep.
STEP_SIZES = [0.1]

# Values tried for each strategy's exploration parameter.
EXP_VALUES = {
    'epsilon-greedy': np.arange(0, 1.1, 0.1),
    'softmax': [2**n for n in range(-4, 9)],
    'mellowmax': [2**n for n in range(-4, 9)],
    'resmax': [2**n for n in range(13)],
}

# Name of the main.py argument that carries each strategy's exploration parameter.
EXP_PARAM_NAME = {
    'epsilon-greedy': 'epsilon',
    'softmax': 'temp',
    'mellowmax': 'omega',
    'resmax': 'eta',
}

# Discount factor per environment id.
GAMMAS = {
    'deep_sea/0': 1,
    'gym_riverswim:riverswim-v0': 1,
    'riverswim_variants:stochastic-riverswim-v0': 1,
    'riverswim_variants:scaled-riverswim-v0': 1,
    'gym_exploration:HardSquare-v0': 0.95,
}


###############################

def to_command(dic):
    """
    Render an argument dict as one `python3 main.py ...` command line.

    dic: mapping of option name -> value. Entries are emitted in the dict's
        iteration order as `--name value`.

    Returns the command string, terminated by a newline.

    `only_store_rewards` and `save_policy` are emitted as bare switches
    (`--name` with no value), and only when their value is truthy.
    `epsilon` is rounded to one decimal before formatting.
    """
    # Options written without a value (presumably store_true flags in main.py).
    switch_keys = ('only_store_rewards', 'save_policy')

    parts = ['python3 main.py']
    for key, value in dic.items():
        if key in switch_keys:
            # A bare switch is either present or absent, so a falsy value must
            # be expressed by leaving it out rather than by printing it.
            if value:
                parts.append('--{}'.format(key))
            continue

        if key == 'epsilon':  # epsilon has fp issues (e.g. 0.30000000000000004)
            value = np.round(value, 1)
        parts.append('--{} {}'.format(key, value))

    return ' '.join(parts) + '\n'

def get_args():
    """
    Parse this script's command-line options.

    Returns a dict with the keys 'output_type', 'output_path', 'save_path'
    and 'num_threads'.
    """
    output_choices = ("bash_file", "compute_canada_format", "execute")

    cli = argparse.ArgumentParser(description='Initial Experiments for Tabular')

    # What to do with the generated experiments.
    cli.add_argument(
        '--output_type',
        type=str,
        choices=output_choices,
        default="compute_canada_format",
        help="What should be the output of this file: bash_file: generates a bash file of the commands that you should run in linux for all the experiments | compute_canada_format: generates a file that can be used to run all the experiments on compute canada | execute: will run all experiments on your computer",
    )

    # Where the generated command file goes (without extension).
    cli.add_argument(
        '--output_path',
        type=str,
        nargs='?',
        default='experiments',
        help="The path to save the output file of this script",
    )

    # Root directory for experiment results.
    cli.add_argument(
        '--save_path',
        type=str,
        nargs='?',
        default='results',
        help="The root path that should be used to save the results of our experiments (This path will be passed to main.py as an argument)",
    )

    # Degree of parallelism when executing locally.
    cli.add_argument(
        '--num_threads',
        type=int,
        nargs='?',
        default=1,
        help="How many concurrent experiments to run",
    )

    namespace = cli.parse_args()
    return vars(namespace)


def main(args):
    """
    Build every experiment configuration and run it or write it out as a command.

    One base configuration is produced for each combination of ENV_IDS,
    EXP_STRATEGIES, ALGORITHMS, STEP_SIZES and SEEDS; it is then crossed with
    every value in EXP_VALUES for the chosen strategy. Each resulting run gets
    its own numbered directory under SAVE_PATH[env_id].

    args: dict of options as returned by get_args(). Keys read here are
        'output_type', 'num_threads' and 'output_path'.

    Depending on args['output_type']:
        bash_file             -> commands are written to <output_path>.bash
        compute_canada_format -> commands are written to <output_path>.txt
        execute               -> runs are passed to main_tabular: one by one
                                 in this process when num_threads == 1
                                 (each finished command is appended to
                                 experiments_done_so_far.txt), otherwise
                                 through a multiprocessing Pool with
                                 num_threads workers.
    """
    # NOTE(review): args['save_path'] is parsed by get_args() but is not used
    # here; save paths come from SAVE_PATH only. Confirm whether that is intended.
    output_type = args['output_type']
    run_in_pool = output_type == 'execute' and args['num_threads'] != 1
    # Both file-producing modes write the same content; only the extension differs.
    command_file_suffix = {'bash_file': '.bash', 'compute_canada_format': '.txt'}

    all_tabular_args = []     # queued runs for the process pool
    bash_file_commands = []   # command lines for the output file
    run = 0                   # global run counter, used as the save directory name

    sweep = product(ENV_IDS, EXP_STRATEGIES, ALGORITHMS, STEP_SIZES, SEEDS)
    for env_id, exp_strategy, algorithm, step_size, seed in sweep:  # general configurations
        base_args = deepcopy(AGENT_ARGS)
        base_args['env_name'] = env_id
        base_args['algorithm'] = algorithm
        base_args['seed'] = int(seed)
        base_args['step_size'] = step_size
        base_args['exploration_strategy'] = exp_strategy
        base_args['num_timesteps'] = ENV_TIMESTEPS[env_id]
        base_args['gamma'] = GAMMAS[env_id]

        exp_param = EXP_PARAM_NAME[exp_strategy]
        for exp_val in EXP_VALUES[exp_strategy]:
            # Each run needs its own dict: the pool branch stores the dict for
            # later, and reusing one object would leave every queued entry
            # with the last exp_val / save_path of this loop.
            tabular_args = deepcopy(base_args)
            tabular_args[exp_param] = exp_val
            tabular_args['save_path'] = Path(SAVE_PATH[env_id]) / str(run)
            run += 1

            if output_type in command_file_suffix:
                bash_file_commands.append(to_command(tabular_args))
            elif run_in_pool:
                all_tabular_args.append(tabular_args)
            elif output_type == 'execute':
                main_tabular(tabular_args)
                # Append so the file accumulates every finished experiment
                # instead of being truncated to the most recent one.
                with open('experiments_done_so_far.txt', 'a') as done_log:
                    done_log.write(to_command(tabular_args))

    # Run the queued experiments across worker processes.
    if run_in_pool:
        with Pool(args['num_threads']) as pool:
            results = pool.imap(main_tabular, all_tabular_args)
            # Draining the iterator drives the progress bar and waits for
            # completion; the return values themselves are not used.
            for _ in tqdm(results, total=len(all_tabular_args)):
                pass

    if output_type in command_file_suffix:
        # The .txt variant can be used as a command list for GNU Parallel.
        out_name = args['output_path'] + command_file_suffix[output_type]
        with open(out_name, 'w') as command_file:
            command_file.writelines(bash_file_commands)
 

if __name__ == '__main__':
    ARGS = get_args()
    main(ARGS)
