import os

import numpy as np
import argparse
from main import main as main_dqn
from copy import deepcopy

# Baseline arguments for main.py. main() deep-copies this dict for every
# experiment and overrides env_name, seed, step_size, num_timesteps,
# exploration_strategy, save_path and the strategy's own parameter
# (epsilon / temp / eta). to_command() renders the entries as command-line
# flags in this key order: True booleans become bare flags, False booleans
# are left out.
DQN_ARGS = {
    'exploration_strategy': 'epsilon-greedy',
    'algorithm': 'q-learning',
    'epsilon': 0.1,
    'reps': 1,
    'use_gpu': True,
    'replay_buffer_size': int(1e6),
    'gpu_id': 0,
    'agent_type': 'non-linear',
    'seed': 1,
    'batch_size': 64,
    'num_timesteps': int(5e7),
    'num_agent_train_steps_per_iter': 1,
    'target_update_freq': 1000,
    'learning_starts': 50000,
    'nn_size': 64,
    'n_layers': 2,
    'gamma': 0.99,
    'eta': 12.0,
    'temp': 1.0,
    'step_size': 0.001,
    'only_store_rewards': True,
    'double_q': False,
    'exploration_schedule': 0,
    'env_name': 'deep_sea/0',
    'save_path': 'DQN_RESULTS',
    'log_interval': 10000,
    'td_error_mg': 1.0,
    'td_error_mg_lr': 0.9,
    'td_error_mg_epsilon': 0.0001,
    'td_error_scheduling': False,
}

### Experimental Parameters ###

# main() emits one command per combination of
# environment x exploration strategy x step size x strategy value x seed.

# Seeds 1..10 (inclusive).
SEEDS = list(range(1, 11))

ENV_IDS = [
    'AsterixNoFrameskip-v0',
    'BreakoutNoFrameskip-v0',
    'FreewayNoFrameskip-v0',
]

# Value used for 'num_timesteps' per environment (the same for all of them).
ENV_TIMESTEPS = {env_id: int(5e6) for env_id in ENV_IDS}

EXP_STRATEGIES = ['epsilon-greedy', 'softmax', 'resmax']
STEP_SIZES = [1e-4]

# Values swept for each exploration strategy; each one is written to the
# argument named in EXP_PARAM_NAME below.
EXP_VALUES = {
    'epsilon-greedy': [0.1, 0.3, 0.5],
    'softmax': [1 / 2**exponent for exponent in (16, 8, 0)],
    'resmax': [1 / 2**exponent for exponent in (24, 16, 8)],
}

# Name of the main.py argument that carries each strategy's swept value.
EXP_PARAM_NAME = {
    'epsilon-greedy': 'epsilon',
    'softmax': 'temp',
    'resmax': 'eta',
}

###############################

def to_command(dic):
    """
    Render a dict of arguments as one `python main.py ...` command line.

    Boolean values are treated as switches: True emits the bare flag
    `--key`, False emits nothing. Every other value is emitted as
    `--key value`, with the value shell-quoted so that paths containing
    spaces or shell metacharacters survive being run from a command list.
    Values made only of shell-safe characters are emitted unchanged.

    Args:
        dic: mapping of argument name -> value; flags are emitted in the
            mapping's iteration order.

    Returns:
        The command as a single string terminated by a newline.
    """
    import shlex  # local import keeps this function self-contained

    parts = ['python main.py']
    for key, value in dic.items():
        if isinstance(value, bool):
            # Switch-style argument: only present when enabled.
            if value:
                parts.append('--{}'.format(key))
        else:
            parts.append('--{} {}'.format(key, shlex.quote(str(value))))
    return ' '.join(parts) + '\n'

def get_args():
    """
    Parse this script's command-line options.

    Returns:
        dict with the keys 'output_type', 'output_path' and 'save_path'.
        A path option given without a value falls back to its default
        instead of becoming None.
    """
    default_output_path = 'dqn_atari_experiments'
    default_save_path = 'dqn_atari_results'

    parser = argparse.ArgumentParser(description='Final Experiments for DQN')

    parser.add_argument('--output_type', default='text_file', type=str,
            choices=('text_file', 'failed_text_file'),
            help="'text_file' writes the command of every experiment; "
                 "'failed_text_file' writes only the commands of experiments whose "
                 "results under --save_path are missing or incomplete")

    # nargs='?' lets the option appear without a value; without `const`
    # argparse would store None in that case, which breaks the later path
    # handling, so the bare option maps to the default.
    parser.add_argument('--output_path', default=default_output_path, const=default_output_path,
            type=str, nargs='?', help="The path to save the output file of this script")

    parser.add_argument('--save_path', default=default_save_path, const=default_save_path,
            type=str, nargs='?',
            help="The root path that should be used to save the results of our experiments (This path will be passed to main.py as an argument)")

    return vars(parser.parse_args())

def _iter_run_args(save_path):
    """Yield the main.py argument dict of every experiment in the grid.

    The grid is environment x exploration strategy x step size x
    strategy value x seed, iterated in that nesting order.
    """
    for env_id in ENV_IDS:
        for exp_strategy in EXP_STRATEGIES:
            for step_size in STEP_SIZES:
                for exp_value in EXP_VALUES[exp_strategy]:
                    for seed in SEEDS:
                        dqn_args = deepcopy(DQN_ARGS)
                        dqn_args['verbose'] = 0
                        dqn_args['env_name'] = env_id
                        dqn_args['seed'] = seed
                        dqn_args['step_size'] = step_size
                        dqn_args['num_timesteps'] = ENV_TIMESTEPS[env_id]
                        dqn_args['exploration_strategy'] = exp_strategy
                        dqn_args[EXP_PARAM_NAME[exp_strategy]] = exp_value
                        dqn_args['save_path'] = save_path
                        yield dqn_args


def _result_dir(dqn_args):
    """Return the directory in which one experiment's results are looked up."""
    exp_strategy = dqn_args['exploration_strategy']
    exp_value = dqn_args[EXP_PARAM_NAME[exp_strategy]]
    return os.path.join(
        dqn_args['save_path'],
        dqn_args['env_name'],
        exp_strategy,
        str(dqn_args['step_size']),
        str(float(exp_value)),
        str(dqn_args['seed']),
    )


def _run_failed(save_dir, expected_len):
    """Return True unless save_dir holds a complete 'episode_returns' file.

    A run counts as failed when the directory is missing, when no entry
    has 'episode_returns' in its name, when that file cannot be loaded, or
    when its length differs from expected_len. Only the first matching
    entry is inspected.
    """
    if not os.path.isdir(save_dir):
        return True
    for recorded_file in os.listdir(save_dir):
        if 'episode_returns' not in recorded_file:
            continue
        try:
            returns = np.load(os.path.join(save_dir, recorded_file))
        except (OSError, ValueError, EOFError):
            # An unreadable or truncated file means the run must be redone;
            # it must not abort the scan of the remaining experiments.
            return True
        return len(returns) != expected_len
    return True


def main(args):
    """
    Write the experiment commands to '<output_path>.txt', one per line.

    Args:
        args: dict with the keys
            'output_type': 'text_file' writes the command of every
                experiment; 'failed_text_file' writes (and prints) only the
                commands of experiments whose results are missing or
                incomplete. Any other value writes an empty file.
            'output_path': output file path without the '.txt' suffix.
            'save_path': results root, passed on to main.py and used to
                locate existing results.
    """
    output_type = args['output_type']

    bash_file_commands = []
    for dqn_args in _iter_run_args(args['save_path']):
        command = to_command(dqn_args)
        if output_type == 'failed_text_file':
            save_dir = _result_dir(dqn_args)
            # One recorded return is expected per logging interval.
            expected_len = int(dqn_args['num_timesteps'] // dqn_args['log_interval'])
            if _run_failed(save_dir, expected_len):
                print('save_dir: ', save_dir)
                print('Failed: ', command)
                bash_file_commands.append(command)
        elif output_type == 'text_file':
            bash_file_commands.append(command)

    # The resulting .txt file can be used as a command list for GNU Parallel.
    with open(args['output_path'] + '.txt', 'w') as output:
        output.writelines(bash_file_commands)
 
if __name__ == '__main__':
    # Script entry point: parse the command line, then write the command list.
    main(get_args())
