import argparse
import os
import shlex
from copy import deepcopy
from itertools import product

import numpy as np

from main import main as main_dqn

# Baseline argument set for one `python main.py` invocation. main() copies
# this dict for every run and overrides the swept entries; to_command() then
# renders it as command-line flags in this insertion order (True booleans
# become bare flags, False booleans are omitted).
DQN_ARGS = {
    'exploration_strategy': 'epsilon-greedy',
    'algorithm': 'expected_sarsa',
    'epsilon': 0.1,
    'reps': 1,
    'use_gpu': False,
    'gpu_id': 0,
    'agent_type': 'non-linear',
    'seed': 1,
    'batch_size': 64,
    'num_timesteps': int(5e7),
    'num_agent_train_steps_per_iter': 1,
    'target_update_freq': 1000,
    'learning_starts': 5000,
    'nn_size': 64,
    'n_layers': 2,
    'gamma': 0.99,
    'eta': 12.0,
    'temp': 1.0,
    'step_size': 0.001,
    'only_store_rewards': True,
    'double_q': False,
    'exploration_schedule': 0,
    'env_name': 'deep_sea/0',
    'save_path': 'DQN_RESULTS',
    'log_interval': 10000,
    'td_error_mg': 1.0,
    'td_error_mg_lr': 0.9,
    'td_error_mg_epsilon': 0.0001,
    'td_error_scheduling': False,
    'replay_buffer_size': 50000,
    'use_normalization_scheme': False,
}

### Experimental Parameters ###

# One run per seed for every configuration.
SEEDS = list(range(1, 31))

# Number of timesteps to request for each environment; the key order is
# the order in which environments are swept.
ENV_TIMESTEPS = {
    'MountainCar-v0': int(5e5),
    'LunarLander-v2': int(10e5),
    'Acrobot-v1': int(5e5),
}
ENV_IDS = list(ENV_TIMESTEPS)

# Name of the DQN argument that carries each exploration strategy's
# parameter; the key order is the order in which strategies are swept.
EXP_PARAM_NAME = {
    'epsilon-greedy': 'epsilon',
    'resmax': 'eta',
    'softmax': 'temp',
}
EXP_STRATEGIES = list(EXP_PARAM_NAME)

# Values tried for each strategy's parameter.
EXP_VALUES = {
    'epsilon-greedy': [0.1, 0.2, 0.3, 0.4, 0.5],
    'resmax': [1 / 2 ** n for n in (4, 6, 8, 10, 12)],
    'softmax': [1 / 2 ** n for n in (0, 2, 4, 6, 8, 10, 12)],
}

STEP_SIZES = [1e-4]

REPLAY_BUFFER_SIZE = [1000, 5000, 10000, 20000, 50000, 100000]

# Currently disabled sweeps, kept for reference:
# EXP_STRATEGIES = ['softmax', 'resmax']
# OUTSIDE_VALUES = [0.1, 0.05, 0.01]
# PORTION_DECAYS = [0.05, 0.1, 0.2, 0.3]
# TARGET_UPDATE_FREQ = [8, 32, 128, 512, 1024, 4096]

###############################

def to_command(dic):
    """Render an argument dict as one ``python main.py`` command line.

    Args:
        dic: Mapping of option name to value. A ``True`` value becomes a
            bare ``--name`` flag, a ``False`` value is omitted, and any
            other value is emitted as ``--name value``.

    Returns:
        The command string, terminated by a newline.
    """
    parts = ['python main.py']
    for key, value in dic.items():
        if isinstance(value, bool):
            # Booleans are rendered as presence/absence of the flag.
            if value:
                parts.append('--{}'.format(key))
        else:
            # Quote the value so paths containing spaces or shell
            # metacharacters survive when the line is run by a shell.
            # Plain numbers and simple identifiers pass through unchanged.
            parts.append('--{} {}'.format(key, shlex.quote(str(value))))
    return ' '.join(parts) + '\n'

def get_args():
    """
    Parse this script's command-line options.

    Returns a dict with the keys 'output_type', 'output_path' and
    'save_path'.
    """
    cli = argparse.ArgumentParser(description='Final Experiments for DQN')

    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = (
        ('--output_type', {
            'default': 'text_file',
            'type': str,
            'choices': ('text_file', 'failed_text_file'),
            'help': "",
        }),
        ('--output_path', {
            'default': 'dqn_experiments_replay_buffer_study',
            'type': str,
            'nargs': '?',
            'help': "The path to save the output file of this script",
        }),
        ('--save_path', {
            'default': 'dqn_results_replay_buffer_study',
            'type': str,
            'nargs': '?',
            'help': "The root path that should be used to save the results of our experiments (This path will be passed to main.py as an argument)",
        }),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    namespace = cli.parse_args()
    return vars(namespace)

def _run_is_incomplete(save_dir, expected_len):
    """Return True unless save_dir holds a full-length 'episode_returns' file.

    A run counts as incomplete when ``save_dir`` is not a directory, when no
    entry in it has 'episode_returns' in its name, or when the first such
    entry (in ``os.listdir`` order) loads to a different length than
    ``expected_len``.
    """
    if not os.path.isdir(save_dir):
        return True
    returns_file = next(
        (name for name in os.listdir(save_dir) if 'episode_returns' in name),
        None,
    )
    if returns_file is None:
        return True
    recorded = np.load(os.path.join(save_dir, returns_file))
    return len(recorded) != expected_len


def main(args):
    """Write the sweep's ``python main.py`` commands to ``<output_path>.txt``.

    Args:
        args: Dict with the keys 'output_type', 'output_path' and
            'save_path' (as returned by get_args()). With output_type
            'text_file' every configuration is written; with
            'failed_text_file' only configurations whose results look
            missing or incomplete are written (and printed).
    """
    only_failed = args['output_type'] == 'failed_text_file'
    emit_all = args['output_type'] == 'text_file'

    bash_file_commands = []
    for env_id, exp_strategy, step_size in product(ENV_IDS, EXP_STRATEGIES, STEP_SIZES):
        # The exploration values depend on the strategy, so they are
        # expanded in a second product; iteration order matches the
        # equivalent nested loops.
        for exp_value, replay_buffer_size, seed in product(
                EXP_VALUES[exp_strategy], REPLAY_BUFFER_SIZE, SEEDS):
            dqn_args = deepcopy(DQN_ARGS)
            dqn_args['verbose'] = 0
            dqn_args['env_name'] = env_id
            dqn_args['seed'] = seed
            dqn_args['step_size'] = step_size
            dqn_args['num_timesteps'] = ENV_TIMESTEPS[env_id]
            dqn_args['exploration_strategy'] = exp_strategy
            dqn_args[EXP_PARAM_NAME[exp_strategy]] = exp_value
            dqn_args['save_path'] = args['save_path']
            dqn_args['replay_buffer_size'] = replay_buffer_size
            command = to_command(dqn_args)

            if only_failed:
                # NOTE(review): this path has no replay_buffer_size
                # component, so every buffer size of an otherwise identical
                # configuration checks the same directory - confirm this
                # matches the layout main.py actually writes.
                save_dir = os.path.join(
                    dqn_args['save_path'],
                    env_id,
                    dqn_args['exploration_strategy'],
                    str(dqn_args['step_size']),
                    str(float(exp_value)),
                    str(dqn_args['seed']),
                )
                # A complete run is expected to have one recorded entry
                # per log interval.
                expected_len = dqn_args['num_timesteps'] // dqn_args['log_interval']
                if _run_is_incomplete(save_dir, expected_len):
                    print('save_dir: ', save_dir)
                    print('Failed: ', command)
                    bash_file_commands.append(command)
            elif emit_all:
                bash_file_commands.append(command)

    # The resulting .txt file can be used as a command list for GNU Parallel.
    with open(args['output_path'] + '.txt', 'w') as output:
        output.writelines(bash_file_commands)
 
if __name__ == '__main__':
    # Script entry point: parse the CLI options and generate the command file.
    main(get_args())
