import argparse
import itertools
import os
import shlex
from copy import deepcopy

import numpy as np

# from main import main as main_linear

# Base argument set for main.py. main() deep-copies it for every run and
# overrides the swept entries (plus 'verbose' and 'step_size', which are
# added per run). Each key appears exactly once: a duplicated literal key is
# silently collapsed by Python, which hides edits made to only one copy.
LINEAR_ARGS = {
    'algorithm': 'q-learning',
    'exploration_strategy': 'epsilon-greedy',
    'epsilon': 0.1,
    'reps': 1,
    'tile_coding': True,
    'seed': 1,
    'num_timesteps': 500000,
    'num_episodes': 0,
    'episode_based': 0,
    'gamma': 0.99,
    'eta': 12.0,
    'temp': 1.0,
    'omega': 1.0,
    'only_store_rewards': False,
    'exploration_schedule': 1,
    'env_name': 'CartPole-v0',
    'save_path': 'results',
    'num_tiles': '8',
    'num_tiling': '8',
    'agent_type': 'linear',
    'iht_size': '65536',
    'max_iter': '200',
    'init': '0.0',
    'rand_init': '0',
    'save_type': 'episodic_steps',
    'log_interval': 10000,
    'g_min': '1.0',
    'g_max': '99.34',
    'eval_episodes_num': '30',
    'normalization_scheme': 'none',
    'td_step_size': '0.9',
    'td_epsilon': '0.01',
    'zeta': '1',
    'outside_value': 0.1,
    'portion_decay': 0.1,
}

### Experimental Parameters ###
# main() emits one command per combination of the lists below, taking one
# value from EXP_VALUES[strategy] for each exploration strategy.

SEEDS = np.arange(1, 31)  # seeds 1..30 inclusive
ENV_IDS = ['CartPole-v0']
EXP_STRATEGIES = ['epsilon-greedy']
TILES = [8]
ALGORITHMS = ["q-learning", "expected_sarsa"]

STEP_SIZES = [0.1]
OUTSIDE_VALUES = [0.1, 0.05, 0.01]
PORTION_DECAYS = [0.05, 0.1, 0.2, 0.3]
# Exploration-parameter values to sweep, per exploration strategy.
EXP_VALUES = {
    EXP_STRATEGIES[0]: [0.1],
}
# Name of the LINEAR_ARGS / main.py option each strategy's value is stored under.
EXP_PARAM_NAME = {
    EXP_STRATEGIES[0]: 'epsilon',
}

###############################

def to_command(dic):
    """Render an argument dict as a ``python main.py ...`` command line.

    ``True`` booleans become bare ``--key`` flags and ``False`` booleans are
    omitted. Every other value is emitted as ``--key value``, with the value
    shell-quoted so that spaces or shell metacharacters (e.g. in a save path)
    survive when the line is executed by a shell such as GNU Parallel.

    Args:
        dic: Mapping of main.py option names to values; emitted in the
            mapping's iteration order.

    Returns:
        The command string, terminated by a newline.
    """
    parts = ['python main.py']
    for key, value in dic.items():
        # bool is checked first; note that it is a subclass of int.
        if isinstance(value, bool):
            if value:
                parts.append('--{}'.format(key))
        else:
            # shlex.quote leaves "safe" tokens such as 0.1 or CartPole-v0 as-is.
            parts.append('--{} {}'.format(key, shlex.quote(str(value))))
    return ' '.join(parts) + '\n'

def get_args(argv=None):
    """Parse this script's command-line options.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, as a normal CLI
            invocation would.

    Returns:
        dict with keys ``output_type``, ``output_path`` and ``save_path``.
    """
    parser = argparse.ArgumentParser(description='Linear Experiments')

    # These flags deliberately take exactly one value (no nargs='?'). Otherwise
    # a bare flag would parse to None and only fail later, far from the cause.
    parser.add_argument('--output_type', default='text_file', type=str,
                        choices=('text_file', 'failed_text_file'),
                        help="'text_file' lists every run of the sweep; 'failed_text_file' "
                             "lists only runs whose results under --save_path are missing "
                             "or incomplete")

    parser.add_argument('--output_path', default='linear_cartpole_experiments_adapt_3', type=str,
                        help="The path to save the output file of this script "
                             "('.txt' is appended)")

    parser.add_argument('--save_path', default='results_adapt', type=str,
                        help="The root path that should be used to save the results of our experiments (This path will be passed to main.py as an argument)")

    return vars(parser.parse_args(argv))


def _iter_configs():
    """Yield one tuple per run, in the order of the original nested sweep loops.

    Tuple layout: (env_id, algorithm, exp_strategy, outside_value,
    portion_decay, step_size, exp_value, tiles, seed).
    """
    outer = itertools.product(ENV_IDS, ALGORITHMS, EXP_STRATEGIES,
                              OUTSIDE_VALUES, PORTION_DECAYS, STEP_SIZES)
    for env_id, algorithm, exp_strategy, outside_value, portion_decay, step_size in outer:
        # The exploration values depend on the strategy, so this level cannot
        # be folded into the outer product.
        for exp_value, tiles, seed in itertools.product(EXP_VALUES[exp_strategy], TILES, SEEDS):
            yield (env_id, algorithm, exp_strategy, outside_value,
                   portion_decay, step_size, exp_value, tiles, seed)


def _build_run_args(save_path, env_id, algorithm, exp_strategy, outside_value,
                    portion_decay, step_size, exp_value, tiles, seed):
    """Return a fresh copy of LINEAR_ARGS customised for a single run."""
    linear_args = deepcopy(LINEAR_ARGS)
    # Assignment order matters: new keys ('verbose', 'step_size') are appended
    # in this order, and to_command emits options in dict order.
    linear_args['verbose'] = 0
    linear_args['env_name'] = env_id
    linear_args['seed'] = seed
    linear_args['step_size'] = step_size
    linear_args['exploration_strategy'] = exp_strategy
    linear_args['portion_decay'] = portion_decay
    linear_args['outside_value'] = outside_value
    linear_args[EXP_PARAM_NAME[exp_strategy]] = exp_value
    linear_args['save_path'] = save_path
    linear_args['algorithm'] = algorithm
    linear_args['num_tiling'] = tiles
    linear_args['num_tiles'] = tiles
    return linear_args


def _run_failed(save_dir, expected_len):
    """Return True unless save_dir holds an 'episode_returns' file of expected_len entries.

    Only the first file whose name contains 'episode_returns' (in os.listdir
    order) is inspected. A file that cannot be loaded is reported and counted
    as a failed run rather than aborting the whole scan.
    """
    if not os.path.isdir(save_dir):
        return True
    for recorded_file in os.listdir(save_dir):
        if 'episode_returns' not in recorded_file:
            continue
        path = os.path.join(save_dir, recorded_file)
        try:
            returns = np.load(path)
        except (OSError, ValueError, EOFError) as err:
            print('Could not load {}: {}'.format(path, err))
            return True
        return len(returns) != expected_len
    return True


def main(args):
    """Write the sweep's ``python main.py ...`` commands to ``<output_path>.txt``.

    Args:
        args: dict as returned by ``get_args()``; reads ``output_type``,
            ``output_path`` and ``save_path``.

    With ``output_type == 'text_file'`` every run of the sweep is written.
    With ``'failed_text_file'``, only runs whose result directory lacks an
    ``episode_returns`` file of the expected length are written, and each one
    is also echoed to stdout.
    """
    commands = []
    for (env_id, algorithm, exp_strategy, outside_value, portion_decay,
         step_size, exp_value, tiles, seed) in _iter_configs():
        linear_args = _build_run_args(
            args['save_path'], env_id, algorithm, exp_strategy, outside_value,
            portion_decay, step_size, exp_value, tiles, seed)
        command = to_command(linear_args)

        if args['output_type'] == 'text_file':
            commands.append(command)
        elif args['output_type'] == 'failed_text_file':
            # Result layout checked here is assumed to match where main.py
            # writes its output — NOTE(review): confirm against main.py.
            save_dir = os.path.join(
                linear_args['save_path'], algorithm, env_id, exp_strategy,
                str(portion_decay), str(outside_value), str(step_size),
                str(float(exp_value)), str(seed))
            expected_len = int(linear_args['num_timesteps'] // linear_args['log_interval'])
            if _run_failed(save_dir, expected_len):
                print('save_dir: ', save_dir)
                print('Failed: ', command)
                commands.append(command)

    # One command per line; the file can be used as a GNU Parallel job list.
    with open(args['output_path'] + '.txt', 'w', encoding='utf-8') as output:
        output.writelines(commands)
 
if __name__ == "__main__":
    # Script entry point: parse the CLI flags, then emit the command list.
    ARGS = get_args()
    main(ARGS)