import argparse
import os
import shlex
from copy import deepcopy

import numpy as np

from main import main as main_dqn

# Baseline command-line arguments for main.py. main() deep-copies this dict
# for every sweep configuration and overrides env_name, seed, step_size,
# num_timesteps, exploration_strategy, save_path and the exploration
# parameter. Key order matters: to_command() emits flags in this order.
DQN_ARGS = dict(
    # Exploration / learning rule.
    exploration_strategy='epsilon-greedy',
    algorithm='q-learning',
    epsilon=0.1,
    reps=1,
    # Hardware and replay memory.
    use_gpu=True,
    replay_buffer_size=int(1e6),
    gpu_id=0,
    # Agent / network and training schedule.
    agent_type='non-linear',
    seed=1,
    batch_size=64,
    num_timesteps=int(5e7),
    num_agent_train_steps_per_iter=1,
    target_update_freq=1000,
    learning_starts=50000,
    nn_size=64,
    n_layers=2,
    gamma=0.99,
    # Exploration-strategy parameters and optimiser step size.
    eta=12.0,
    temp=1.0,
    step_size=0.001,
    # Miscellaneous flags.
    only_store_rewards=True,
    double_q=False,
    exploration_schedule=0,
    env_name='deep_sea/0',
    # Output; main() expects num_timesteps // log_interval logged returns.
    save_path='DQN_RESULTS',
    log_interval=100000,
    # TD-error related settings.
    td_error_mg=1.0,
    td_error_mg_lr=0.9,
    td_error_mg_epsilon=0.0001,
    td_error_scheduling=False,
)

### Experimental Parameters ###

# Every configuration is repeated for seeds 1..10 (inclusive).
SEEDS = list(range(1, 11))

# Atari environments included in the sweep.
ENV_IDS = [
    'BreakoutNoFrameskip-v0',
    'FreewayNoFrameskip-v0',
    'AsterixNoFrameskip-v0',
    'PitfallNoFrameskip-v0',
    'VentureNoFrameskip-v0',
    'GravitarNoFrameskip-v0',
    'PrivateEyeNoFrameskip-v0',
]

# NOTE(review): these two lists are not referenced in the visible part of
# this script — confirm whether they are still needed.
OUTSIDE_VALUES = [0.1, 0.05, 0.01]
PORTION_DECAYS = [0.05, 0.1, 0.2, 0.3]

# Training budget (environment timesteps) per environment: 5M for every game.
ENV_TIMESTEPS = {env_id: int(5e6) for env_id in ENV_IDS}

EXP_STRATEGIES = ['epsilon-greedy', 'softmax', 'resmax', 'mellowmax']
STEP_SIZES = [1e-4]

# Values swept for each exploration strategy. Epsilon-greedy uses explicit
# epsilons; the other strategies share a grid of inverse powers of two,
# 2**-24 ... 2**0 (each strategy gets its own list object).
EXP_VALUES = {EXP_STRATEGIES[0]: [0.01, 0.1, 0.3, 0.5]}
EXP_VALUES.update(
    (strategy, [1 / 2 ** power for power in (24, 20, 16, 12, 8, 4, 0)])
    for strategy in EXP_STRATEGIES[1:]
)

# Name of the DQN argument that receives the swept value for each strategy.
EXP_PARAM_NAME = dict(zip(EXP_STRATEGIES, ('epsilon', 'temp', 'eta', 'omega')))


###############################

def to_command(dic):
    """Render an argument mapping as a shell command line that runs main.py.

    Every key becomes a ``--key`` flag, in the mapping's iteration order.
    Boolean values act as on/off switches: ``True`` emits the bare flag and
    ``False`` omits it entirely. Any other value is emitted as ``--key value``
    with the value passed through ``shlex.quote``. Values that are already
    shell-safe (numbers, env ids, plain paths) come out unchanged. Values with
    spaces or shell metacharacters, such as a user-supplied ``save_path``, are
    quoted so the line still works when executed by a shell / GNU Parallel.

    Args:
        dic: mapping of argument name -> value.

    Returns:
        The command string, terminated by a newline.
    """
    parts = ['python', 'main.py']
    for key, value in dic.items():
        if isinstance(value, bool):
            # store_true-style switch: only present when enabled.
            if value:
                parts.append('--{}'.format(key))
        else:
            parts.append('--{}'.format(key))
            parts.append(shlex.quote(str(value)))
    return ' '.join(parts) + '\n'


def get_args():
    """Parse this script's command-line options.

    Options:
        --output_type: 'text_file' (every command) or 'failed_text_file'
            (only runs whose results are missing or incomplete).
        --output_path: output file path, without the '.txt' suffix.
        --save_path: results root, forwarded to main.py as --save_path.

    Returns:
        dict with keys 'output_type', 'output_path' and 'save_path'.
    """
    parser = argparse.ArgumentParser(description='Final Experiments for DQN')
    add = parser.add_argument

    add('--output_type',
        type=str,
        default='text_file',
        choices=('text_file', 'failed_text_file'),
        help="")
    add('--output_path',
        type=str,
        nargs='?',
        default='dqn_atari_experiments_epsilon_extra',
        help="The path to save the output file of this script")
    add('--save_path',
        type=str,
        nargs='?',
        default='dqn_atari_results_epsilon_extra',
        help="The root path that should be used to save the results of our experiments (This path will be passed to main.py as an argument)")

    namespace = parser.parse_args()
    return vars(namespace)


def _iter_sweep():
    """Yield (env_id, exp_strategy, step_size, exp_value, seed) for every sweep configuration."""
    for env_id in ENV_IDS:
        for exp_strategy in EXP_STRATEGIES:
            for step_size in STEP_SIZES:
                for exp_value in EXP_VALUES[exp_strategy]:
                    for seed in SEEDS:
                        yield env_id, exp_strategy, step_size, exp_value, seed


def _build_dqn_args(save_path, env_id, exp_strategy, step_size, exp_value, seed):
    """Return a fresh deep copy of DQN_ARGS specialised to one sweep configuration."""
    dqn_args = deepcopy(DQN_ARGS)
    dqn_args['verbose'] = 0
    dqn_args['env_name'] = env_id
    dqn_args['seed'] = seed
    dqn_args['step_size'] = step_size
    dqn_args['num_timesteps'] = ENV_TIMESTEPS[env_id]
    dqn_args['exploration_strategy'] = exp_strategy
    # The swept value goes under the strategy's own parameter name
    # (e.g. 'epsilon' for epsilon-greedy, 'temp' for softmax).
    dqn_args[EXP_PARAM_NAME[exp_strategy]] = exp_value
    dqn_args['save_path'] = save_path
    return dqn_args


def _is_failed_run(save_dir, expected_len):
    """Return True unless save_dir holds a complete 'episode_returns' file.

    A run counts as complete when at least one file in ``save_dir`` whose name
    contains 'episode_returns' can be loaded with ``np.load`` and has exactly
    ``expected_len`` entries. A missing directory, no such file, or only
    unreadable / wrong-length files all count as failed.
    """
    if not os.path.isdir(save_dir):
        return True
    # Sorted so the result and the diagnostics do not depend on os.listdir order.
    return_files = sorted(name for name in os.listdir(save_dir) if 'episode_returns' in name)
    if len(return_files) > 1:
        print('======= More than one file =======')
        print(save_dir)
        print(return_files)
    for recorded_file in return_files:
        file_path = os.path.join(save_dir, recorded_file)
        try:
            num_logged = len(np.load(file_path))
        except (OSError, ValueError, EOFError, TypeError) as err:
            # A truncated/corrupt (or scalar) results file means the run did not
            # finish cleanly: treat it as incomplete instead of aborting the scan.
            print('Could not load {}: {}'.format(file_path, err))
            continue
        if num_logged == expected_len:
            return False
    if return_files:
        print(return_files)
    return True


def main(args):
    """Write one ``python main.py ...`` command per experiment configuration.

    The sweep covers ENV_IDS x EXP_STRATEGIES x STEP_SIZES x
    EXP_VALUES[strategy] x SEEDS. The commands are written one per line to
    ``args['output_path'] + '.txt'``, which can be used as a job list for
    GNU Parallel.

    Args:
        args: dict produced by ``get_args()`` with keys:
            - 'output_type': 'text_file' emits a command for every
              configuration. 'failed_text_file' emits a command only for
              configurations whose results directory
              (save_path/env/strategy/step_size/float(value)/seed) has no
              complete 'episode_returns' file.
            - 'output_path': output file path without the '.txt' suffix.
            - 'save_path': results root, forwarded to main.py.
    """
    output_type = args['output_type']
    bash_file_commands = []
    for env_id, exp_strategy, step_size, exp_value, seed in _iter_sweep():
        dqn_args = _build_dqn_args(args['save_path'], env_id, exp_strategy, step_size, exp_value, seed)
        command = to_command(dqn_args)

        if output_type == 'text_file':
            bash_file_commands.append(command)
        elif output_type == 'failed_text_file':
            save_dir = os.path.join(dqn_args['save_path'], env_id, exp_strategy,
                                    str(step_size), str(float(exp_value)), str(seed))
            # main.py is expected to log one return every log_interval timesteps.
            expected_len = int(dqn_args['num_timesteps'] // dqn_args['log_interval'])
            if _is_failed_run(save_dir, expected_len):
                print('save_dir: ', save_dir)
                print('Failed: ', command)
                bash_file_commands.append(command)

    # Each command already ends with '\n', so the file gets one command per line.
    with open(args['output_path'] + '.txt', 'w') as output:
        output.writelines(bash_file_commands)


if __name__ == '__main__':
    # Script entry point: parse CLI options and generate the command list.
    main(get_args())

