import os.path
import sys
import numpy as np
import argparse

from util.configuration import Configuration
from Algorithms.action_selector import ActionSelector
from util.util import env_dict, learner_dict, mins_dict, maxes_dict

# Output switches for this script.
# sweep: when True, three runs are performed and the episode-length file name
#        encodes the swept hyper-parameters, so sweep runs with different
#        settings do not overwrite each other.
sweep: bool = False
# retain: when True, also save the goals of the retained auxiliary tasks.
retain: bool = False
# save_num_steps: save per-episode step counts (non-sweep mode).
save_num_steps: bool = True
# save_return: save the per-episode accumulated reward.
save_return: bool = True

# Command-line hyper-parameters. The help strings surface what used to live only
# in inline comments, so `--help` is informative. (Avoid a bare percent sign in
# help text: argparse applies %-formatting to it.)
parser = argparse.ArgumentParser()
parser.add_argument('--alpha', '-a', type=float, default=0.005,
                    help='step size of the learner')
parser.add_argument('--hidden_size', '-hs', type=int, default=128,
                    help='hidden layer width of the network')
parser.add_argument('--buffer_capacity', '-bc', type=int, default=2500,
                    help='replay buffer capacity')
parser.add_argument('--json', '-j', type=str, default='puddle_world',
                    help='experiment config name, loaded from Experiments/<name>.json '
                         '(e.g. puddle_world, pinball, MS_pinball)')
parser.add_argument('--aux_type', '-at', type=str, default='None',
                    help="auxiliary-task type: 'MSGT' (enables generate-and-test) or 'None'")
parser.add_argument('--num_aux_tasks', '-nat', type=int, default=0,
                    help='number of auxiliary tasks, e.g. 4, 8, 16 (0 for none)')
parser.add_argument('--starting_run', '-sr', type=int, default=0,
                    help='index of the first run; also used in output file names')
parser.add_argument('--feature_trace', '-ft', type=float, default=0.05,
                    help='important for MSGT: averaging rate of how important a feature is '
                         'on average (0.01 averages over roughly the last 100 steps)')
parser.add_argument('--test_freq', '-tf', type=int, default=1000,
                    help='how often features are tested and replaced')
parser.add_argument('--replace_rate', '-rr', type=float, default=0.25,
                    help='fraction of features replaced at each test; '
                         'high values can make learning unstable')
parser.add_argument('--age_threshold', '-ath', type=int, default=0)
parser.add_argument('--pinball_random_goal_radius', '-prgr', type=float, default=0.035)
parser.add_argument('--pinball_random_goal_increment', '-prgi', type=float, default=0.025)
args = parser.parse_args()


policy_name = 'main'
# Generate-and-test feature replacement is switched on only for the MSGT
# auxiliary-task type.
generate_and_test = args.aux_type == 'MSGT'

config = Configuration('Experiments/{}.json'.format(args.json))

# === general parameters ===

# A sweep repeats each configuration over 3 runs; otherwise a single run.
num_runs = 3 if sweep else 1
num_episodes = config.num_episodes
max_steps = config.max_steps
episode_cutoff = config.episode_cutoff
max_episode = config.max_episode

# NOTE(review): the configured max_episode and max_steps read above are
# overridden here. Every configured episode is run, and the run-wide step
# budget is effectively disabled (10e10 == 1e11 steps). The config reads are
# kept unchanged.
max_episode = num_episodes
max_steps = 10e10

env = env_dict[config.env_name]

# The network output has num_actions units per task: one block for the main
# task (when enabled) plus one per auxiliary task.
main_task_on = True
main_task_ind = 1 if main_task_on else 0
output_size = env.action_space.n * (main_task_ind + args.num_aux_tasks)

# MiniGrid observations are flattened to a vector; other environments take the
# observation size from the experiment config.
if config.env_name in ['minigrid', 'minigrid_door_key']:
    obs_size = np.prod(env.observation_space.shape)
else:
    obs_size = config.obs_size

# Keyword arguments for the learner constructor, gathered from the JSON config,
# the command line and the values derived above. Keys and their order are
# unchanged; every key is a valid identifier, so dict(...) can build it.
learner_config = dict(
    net_type=config.net_type,
    cumulant_net=config.cumulant_net_type,
    obs_size=obs_size,
    num_actions=env.action_space.n,
    output_size=output_size,
    layer_number=config.layer_number,
    hidden_size=args.hidden_size,
    num_aux_tasks=args.num_aux_tasks,
    epsilon=config.epsilon,
    epsilon_annealing=config.epsilon_annealing,
    alpha=args.alpha,
    gamma=config.gamma,
    aux_weight_loss=1,  # fixed; not exposed on the command line or in the config
    buffer_capacity=args.buffer_capacity,
    num_replay=config.num_replay,
    replay_start_size=config.replay_start_size,
    buffer_batch_size=config.buffer_batch_size,
    target_net_update_frequency=config.target_net_update_frequency,
    main_task_on=main_task_on,
    main_task_ind=main_task_ind,
    # per-environment bounds (presumably for observation normalisation - confirm)
    mins=mins_dict[config.env_name],
    maxes=maxes_dict[config.env_name],
    aux_type=args.aux_type,
    generate_and_test=generate_and_test,
    feature_trace=args.feature_trace,
    test_freq=args.test_freq,
    replace_rate=args.replace_rate,
    age_threshold=args.age_threshold,
    env_name=config.env_name,
    pinball_random_goal_radius=args.pinball_random_goal_radius,
    pinball_random_goal_increment=args.pinball_random_goal_increment,
    HER_sample_state_constant=config.HER_sample_state_constant,
    optimizer=config.optimizer,
    head_activation=config.head_activation,
    end_epsilon=config.end_epsilon,
)

# Per-(run, episode) result buffers; row r holds run number starting_run + r.
episode_grid = (num_runs, max_episode)
aux_scores = np.zeros(episode_grid + (args.num_aux_tasks,))
num_steps_over_runs = np.zeros(episode_grid)
avg_reward_over_episodes = np.zeros(episode_grid)
stable_rank_over_time = np.zeros(episode_grid)

# Per-run collections of auxiliary-task goals, appended after every run.
retained_auxiliary_tasks = []
high_score_auxiliary_tasks = []
num_high_score_auxiliary_tasks = 4

for run in np.arange(args.starting_run, args.starting_run + num_runs):
    print(run)
    # Row of the per-run result arrays; run numbers start at args.starting_run.
    run_idx = run - args.starting_run
    total_steps = 0
    terminate_flag = False
    # Each run starts from a freshly constructed learner and action selector.
    learner = learner_dict[config.learner_name](learner_config)

    action_selector = ActionSelector(env.action_space.n, policy_name, learner, args.num_aux_tasks)

    for episode in np.arange(max_episode):
        # Set once the run-wide step budget (max_steps) has been used up.
        if terminate_flag:
            break
        obs_t = env.reset()
        learner.reset_history(obs_t)
        num_steps = 0
        # if episode == 80:
        #     learner.gvf.remove_best_aux()
        while True:
            a_t = action_selector.select_action(obs_t)
            obs_tp1, r, terminal, _ = env.step(a_t)
            # The raw environment reward is recorded, before the clipping below.
            avg_reward_over_episodes[run_idx, episode] += r
            if config.env_name == 'minigrid_door_key':
                # In this environment the learner only sees the sign of the reward.
                r = np.sign(r)
            # NOTE(review): this check runs after env.step, so the transition
            # taken at the cutoff adds to the recorded return but is neither
            # learned from nor counted in num_steps. Confirm this is intended.
            if num_steps >= episode_cutoff:
                break
            learner.learn(obs_t, a_t, r, terminal, obs_tp1)
            if main_task_on and terminal:
                break
            obs_t = obs_tp1
            num_steps += 1
            total_steps += 1
            if total_steps >= max_steps:
                terminate_flag = True
                break
        num_steps_over_runs[run_idx, episode] = num_steps
        print('episode: ', episode)
        print('num_steps: ', num_steps)
        # avg_reward_over_episodes[run - args.starting_run, episode] /= num_steps
        # With the normalisation above disabled, this is the summed (undiscounted)
        # episode reward rather than a per-step average.
        print('avg reward: ', avg_reward_over_episodes[run_idx, episode])
        # aux_scores[run - args.starting_run, episode] = learner.aux_score

        # NOTE(review): computed every episode, but the matching np.save further
        # down this script is commented out, so the values are never written.
        stable_rank_over_time[run_idx, episode] = learner.network.get_stable_rank()

    # NOTE(review): assumes the learner exposes `gvf` even when aux_type is 'None'
    # - confirm. The slice [num_high_score_auxiliary_tasks:] keeps every task
    # except the first 4 of auxiliary_tasks_ordered; if "high score" means the
    # top 4, check the ordering. high_score_auxiliary_tasks is never saved here.
    retained_auxiliary_tasks.append(learner.gvf.random_goals[learner.gvf.retained_auxiliary_tasks])
    high_score_auxiliary_tasks.append(learner.gvf.random_goals[learner.gvf.auxiliary_tasks_ordered[num_high_score_auxiliary_tasks:]])


# Every result file is written into a directory named after the environment.
# exist_ok=True replaces the racy exists()-then-makedirs() pair.
os.makedirs(config.env_name, exist_ok=True)

if sweep:
    # Sweep mode always saves the episode lengths, encoding the swept
    # hyper-parameters in the file name (np.save appends ".npy").
    # NOTE(review): hidden_size, buffer_capacity and starting_run are not part of
    # this name, so sweeps that vary only those overwrite each other.
    np.save('{}/num_steps_over_runs_aux_type_{}_policy_{}_aux_num_{}_rep_nonlinear_step_size_{}_feature_trace_{}_test_freq_{}_pinball_random_goal_radius_{}_pinball_random_goal_increment_{}_replace_rate_{}'.
            format(config.env_name,
                   args.aux_type,
                   policy_name,
                   learner_config['num_aux_tasks'],
                   args.alpha,
                   args.feature_trace,
                   args.test_freq,
                   args.pinball_random_goal_radius,
                   args.pinball_random_goal_increment,
                   args.replace_rate),
            num_steps_over_runs)
elif save_num_steps:
    np.save('{}/num_steps_over_runs_aux_type_{}_policy_{}_aux_num_{}_rep_nonlinear_run_{}'.
            format(config.env_name,
                   args.aux_type,
                   policy_name,
                   learner_config['num_aux_tasks'],
                   args.starting_run),
            num_steps_over_runs)

# print(time.time() - start)
# np.save('{}/aux_scores_{}'.format(config.env_name, args.aux_type), aux_scores)
# np.save('{}/stable_ranks_{}_run_{}'.format(config.env_name, args.aux_type, args.starting_run), stable_rank_over_time)
if retain:
    np.save('{}/retained_auxiliary_tasks_run_{}'.format(config.env_name, args.starting_run), retained_auxiliary_tasks)
if save_return:
    # NOTE(review): this name carries no sweep hyper-parameters, so in sweep mode
    # runs with different settings overwrite the same return file.
    np.save('{}/avg_reward_aux_type_{}_policy_{}_aux_num_{}_rep_nonlinear_run_{}'.
            format(config.env_name,
                   args.aux_type,
                   policy_name,
                   learner_config['num_aux_tasks'],
                   args.starting_run),
            avg_reward_over_episodes)