"""
DQN algorithm with different exploration techniques
"""
# import torch
import argparse
import os
import time
import gym
import bsuite
import pandas as pd
import numpy as np

from bsuite import sweep
from bsuite.utils import gym_wrapper
from utils.snippets import constant_exploration_schedule, stepwise_exploration_schedule
from utils.agents import *
from gym_recording_modified.wrappers import TraceRecordingWrapper
from utils.preprocessing_tools import *
from datetime import datetime
from sys import platform
from MountainCarRecording import SparseMountainCarWrapper
from utils.logger import Logger
from utils.logging import NumpyLogger 

def get_args():
    """
    Declare the command-line options of the experiment script and parse them.

    Returns:
        dict: option name -> parsed value (the ``vars`` of the argparse namespace).
    """
    # TODO: `choices` should be declared for every option below that has a closed value set.
    cli = argparse.ArgumentParser(description='Exploration Strategies for DQN')
    add = cli.add_argument  # short alias; options are registered in their original order

    # ---- algorithm / exploration selection -------------------------------------------
    add('--exploration_strategy', type=str, default='ResMax',
        choices=("epsilon-greedy", "softmax", "ResMax", "mellowmax"),
        help="Exploration Strategy to be used")
    add('--algorithm', type=str, default='q-learning',
        choices=['q-learning', 'expected_sarsa'],
        help="Algorithm that we with to run: 1. q-learning 2. expected_sarsa")
    add('--epsilon', type=float, default=0.1, nargs='?',
        help="Value of epsilon for epsilon-greedy strategy")
    add('--reps', type=int, default=1, nargs='?',
        help='Number of executions')
    add('--agent_type', type=str, default="linear", choices=("linear", "LSVI"))
    add('--seed', type=int, default=1, nargs='?')

    # ---- run length -------------------------------------------------------------------
    add('--num_timesteps', type=int, default=10000)
    add('--num_episodes', type=int, default=50)
    add('--episode_based', type=int, default=0)
    add('--num_agent_train_steps_per_iter', type=int, default=1)

    # ---- learning and exploration hyperparameters -------------------------------------
    add('--gamma', type=float, default=1.0,
        help="Discounting factor parameter")
    add('--eta', type=float, default=12.,
        help="Exploitation pressure used for ResMax exploration technique")
    add('--temp', type=float, default=1.,
        help="Temperature value for softmax exploration technique")
    add('--omega', type=float, default=0.1,
        help="Value of omega for mellowmax strategy")
    add('--step_size', type=float, default=0,
        help="Step-size that will be used to update the value function")

    # ---- function approximation -------------------------------------------------------
    add('--tile_coding', action='store_true')
    add('--num_tilings', type=int, default=8)
    add('--init', type=float, default=0.0)
    add('--rand_init', type=int, default=0, choices=(0, 1))
    add('--max_iter', type=int, default=200)
    add('--num_tiles', type=int, default=8)
    add('--iht_size', type=int, default=4096)

    # ---- environment, output and logging ----------------------------------------------
    add('--only_store_rewards', action='store_true')
    add('--verbose', type=int, default=1, choices=[0, 1])
    add('--exploration_schedule', type=int, default=0, choices=(0, 1),
        help="Type of exploration schedule used for epsilon-greedy exploration. 0: constant schedule 1: stepwise schedule")
    add('--env_name', type=str, default='deep_sea/0', nargs='?',
        help="Name of the enviornment: 1. deep_sea/0")
    add('--save_path', type=str, default='results', nargs='?',
        help="The root path that should be used to save the results of our experiments")
    add('--save_type', type=str, default='episodic_return', nargs='?',
        choices=['episodic_return', 'reward_per_step', 'episodic_steps'],
        help="The type of data that should be saved")
    add('--log_interval', type=int, default=10000,
        help="Interval that has been used to save the episodic returns or steps")

    # ---- return normalization ---------------------------------------------------------
    add('--g_min', type=float, default=0.,
        help="Minimum return for the environment.")
    add('--g_max', type=float, default=1.,
        help="Maximum return for the environment.")
    add('--normalization_scheme', '-ns', type=str, default="none",
        choices=("none", "fixed", "td_squared", "td_absolute"))
    add('--eval_episodes_num', type=int, default=10,
        help="The number of episodes that should be evaluated after each log_interval steps")
    add('--td_step_size', type=float, default=0.9,
        help="Step size parameter for running average of squared td error")
    add('--td_epsilon', type=float, default=10e-9,
        help="lower bound of td-based normalization")
    add('--zeta', type=float, default=0,
        help="initial value of zeta (running average of squared TD error)")

    # ---- stepwise exploration schedule ------------------------------------------------
    add('--outside_value', type=float, default=0.1,
        help="The final value of the epsilon used with exploration scheduling")
    add('--portion_decay', type=float, default=0.1,
        help="portion of timesteps number to be used to decay!")

    parsed = cli.parse_args()
    return vars(parsed)


def _exploration_value(args: dict):
    """Return the hyperparameter value that belongs to the selected exploration strategy."""
    hyperparameter_of = {
        'epsilon-greedy': 'epsilon',
        'softmax': 'temp',
        'ResMax': 'eta',
        'mellowmax': 'omega',
    }
    strategy = args['exploration_strategy']
    if strategy not in hyperparameter_of:
        raise ValueError("Value of exploration_strategy is wrong: {}".format(strategy))
    return args[hyperparameter_of[strategy]]


def _filesystem_safe_env_name(env_name: str) -> str:
    """Return `env_name` rewritten so it can be used inside a save path on this OS."""
    if platform == "win32":
        # Both replacements must be chained on the same string; restarting from the raw
        # name for the second one would silently drop the first.
        return env_name.replace('/', '\\').replace(':', '_')
    return env_name


def _build_save_dir(args: dict, env_name: str, exp_value) -> str:
    """Build the results directory path for this run from its hyperparameters."""
    schedule = args['exploration_schedule']
    if schedule == 0:
        schedule_parts = []
    elif schedule == 1:
        # The stepwise schedule adds its own two hyperparameters to the path.
        schedule_parts = [str(args['portion_decay']), str(args['outside_value'])]
    else:
        raise ValueError("Value of exploration_schedule is wrong: {}".format(schedule))
    return os.path.join(
        args['save_path'], args['algorithm'], env_name, args['exploration_strategy'],
        *schedule_parts,
        str(args['step_size']), str(exp_value), str(args['seed']))


def main(args: dict):
    """
    Run one training experiment described by `args` and store its results on disk.

    The function builds the save directory, creates the (gym or bsuite) environment wrapped in
    a TraceRecordingWrapper, constructs a linear agent, steps it for at most
    ``args['num_timesteps']`` steps (or ``args['num_episodes']`` episodes when
    ``args['episode_based'] == 1``), runs an offline evaluation every ``args['log_interval']``
    steps and finally writes TD errors, weight changes and evaluation statistics as ``.npy``
    files through NumpyLogger.

    Args:
        args: dictionary of run options as produced by `get_args`.

    Raises:
        ValueError: if `exploration_strategy` or `exploration_schedule` has an unknown value.
        NotImplementedError: if `agent_type` is not 'linear', if a bsuite environment other than
            deep_sea is requested, or if no linear agent exists for the environment.
    """
    # Extracting the exploration hyperparameter
    exp_value = _exploration_value(args)

    # Fail fast, before any directory or environment is created: only the linear agent is
    # implemented below, any other type would otherwise crash later on an unbound `agent`.
    if args['agent_type'] != 'linear':
        raise NotImplementedError("agent_type '{}' is not implemented".format(args['agent_type']))

    env_name = _filesystem_safe_env_name(args['env_name'])
    verbose = args['verbose']

    # Directory layout:
    # {save_path}/{algorithm}/{env_name}/{exploration_strategy}/[{portion_decay}/{outside_value}/]{step_size}/{exp_value}/{seed}
    save_dir: str = _build_save_dir(args, env_name, exp_value)

    # Create the directory before any logger is pointed at it.
    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
        print("Save path didn't exist, so it has been created: {}".format(save_dir))

    logger = NumpyLogger(save_dir)
    logging = Logger(save_dir)

    # Preprocessing Saved Observations
    preprocess_saved_obs = None
    if 'deep_sea' in args['env_name']:
        preprocess_saved_obs = lambda obs: deep_sea_obs_preprocess(obs, True)

    # Saving the arguments in the save_dir
    pd.Series(args).to_csv(os.path.join(save_dir, 'args.csv'))

    # An id without a path separator is treated as a gym environment, otherwise as a bsuite id.
    # TODO: find a way to combine load_and_record and TraceRecordingWrapper file structure
    is_gym_env = '/' not in args['env_name'] and '\\' not in args['env_name']

    # `raw_env_class` is a zero-argument factory for un-wrapped environments. It is handed to
    # agent.offline_eval below (presumably to build fresh evaluation environments - confirm in
    # utils.agents), so it has to exist for BOTH environment families.
    if is_gym_env:
        def raw_env_class():
            tmp_env = gym.make(args['env_name'])
            tmp_env._max_episode_steps = args['max_iter']
            if 'MountainCar' in args['env_name']:
                tmp_env = SparseMountainCarWrapper(tmp_env, steps_limit=args['max_iter'])
                logging.info('MountainCar with')
            return tmp_env
    else:
        def raw_env_class():
            return gym_wrapper.GymFromDMEnv(bsuite.load_from_id(args['env_name']))

    raw_env = raw_env_class()
    env = TraceRecordingWrapper(
        raw_env, save_dir,
        only_reward=args['only_store_rewards'],
        preprocess_obs=preprocess_saved_obs,
        save_type=args['save_type'],
        log_interval=args['log_interval'],
        logger=logging)

    if is_gym_env:
        # NOTE(review): this reads the spec's limit, not the `max_iter` override set in the
        # factory above - confirm which episode length the agent is supposed to see.
        ep_len = env.spec.max_episode_steps
    else:
        if 'deep_sea' in args['env_name']:
            ep_len = env.observation_space.shape[0]
        else:
            raise NotImplementedError

        # Logging the settings of this environment
        bsuite_id = args['env_name']
        print('bsuite_id={}, settings={}, num_episodes={}'
            .format(bsuite_id, sweep.SETTINGS[bsuite_id], env.bsuite_num_episodes))

    # Set random seeds
    seed: int = args['seed']
    np.random.seed(seed)
    env.seed(seed)

    # In the initial stage of our research project, we only benchmark environments with discrete action spaces
    discrete: bool = isinstance(env.action_space, gym.spaces.Discrete)
    assert discrete, 'The action space shuold be discrete'

    # Are the observations images? # TODO: functionality to work with images should be added for playing Atari
    img: bool = len(env.observation_space.shape) > 2
    assert not img

    # Observation and action sizes
    if len(env.observation_space.shape) == 0:
        ob_dim = env.observation_space.n
    else:
        ob_dim = env.observation_space.shape if img else int(np.prod(env.observation_space.shape))

    # Setting Exploration Schedule
    if args['exploration_schedule'] == 0:
        exploration_schedule = constant_exploration_schedule(exp_value)
    elif args['exploration_schedule'] == 1:
        exploration_schedule = stepwise_exploration_schedule(args['num_timesteps'], args['outside_value'], args['portion_decay'], args['exploration_strategy'])
    else:
        raise ValueError("Value of exploration_schedule is wrong: {}".format(args['exploration_schedule']))

    agent_params: dict = {
            'algorithm': args['algorithm'],
            'ac_dim': env.action_space.n,
            'ob_dim': ob_dim,
            'ob_low': env.observation_space.low if args['tile_coding'] else None,
            'ob_high': env.observation_space.high if args['tile_coding'] else None,
            'tile_coding': args['tile_coding'],
            'num_tilings': args['num_tilings'],
            'num_tiles': args['num_tiles'],
            'iht_size': args['iht_size'],
            'gamma': args['gamma'],
            'exploration_strategy': args['exploration_strategy'],
            'exploration_schedule': exploration_schedule,
            'env_name': args['env_name'],
            'ep_len': ep_len,
            'step_size': args['step_size'],
            'init': args['init'],
            'rand_init': args['rand_init'],
            'normalization_scheme': args['normalization_scheme'],
            'g_min': args['g_min'],
            'g_max': args['g_max'],
            'td_step_size': args['td_step_size'],
            'td_epsilon': args['td_epsilon'],
            'zeta': args['zeta'],
            }

    logging.info('Run Params: {}'.format(args))
    logging.info('Agent Params: {}'.format(agent_params))

    # Environments that are handled by the generic LinearAgent (only with tile coding enabled).
    tile_coded_envs = ('CartPole', 'MountainCar', 'Acrobot-v1', 'VarianceWorld-v0',
                       'ContinuousRiverswim-v0', 'Antishaping-v0', 'Hypercube-v0')
    if 'deep_sea' in args['env_name']:
        agent = DeepSeaLinearAgent(env, agent_params)
    elif 'riverswim' in args['env_name']:
        agent = RiverSwimLinearAgent(env, agent_params)
    elif any(name in args['env_name'] for name in tile_coded_envs) and args['tile_coding']:
        agent = LinearAgent(env, agent_params)
    else:
        # NOTE add more environments
        raise NotImplementedError

    online_ep_lens = []
    episode_start = 0  # index of the first timestep of the episode currently being played
    num_eps = 0

    td_errors = []
    weights_changes = []

    offline_eval_returns = []
    offline_eval_steps = []

    # Variables needed for logging evaluation time
    evaluation_time_spent = 0
    evaluation_nums = 0

    # init vars at beginning of training
    start_time: float = time.time()

    try:
        agent.start()
        logging.info('## Started ##')

        for itr in range(args['num_timesteps']):

            if itr % args['log_interval'] == 0:
                print('offline')
                eval_started = time.time()
                off_returns, off_steps = agent.offline_eval(args['eval_episodes_num'], raw_env_class, seed)
                # Account for this evaluation BEFORE reporting the running average, otherwise
                # the reported value lags one evaluation behind (and is 0 the first time).
                evaluation_time_spent += time.time() - eval_started
                evaluation_nums += 1

                logging.info("Offline Evaluation Returns -> mean: {}, median: {}, max: {}, min: {}".format(np.mean(off_returns), np.median(off_returns), np.max(off_returns), np.min(off_returns)))
                logging.info("Offline Evaluation Steps -> mean: {}, median: {}, max: {}, min: {}".format(np.mean(off_steps), np.median(off_steps), np.max(off_steps), np.min(off_steps)))
                logging.info('Evaluation time: ' + str(evaluation_time_spent/evaluation_nums))
                logging.info('======================= Timestep: {} =========================='.format(itr))

                offline_eval_returns.append(np.mean(off_returns))
                offline_eval_steps.append(np.mean(off_steps))

            if verbose:
                print("\n\n********** Iteration %i ************"%itr)
                print("\n\n********** Episode   %i ************"%num_eps)
            done, td_error, weight_change = agent.step_env()
            td_errors.append(td_error)
            weights_changes.append(weight_change)

            if done:
                # The step at index `itr` belongs to the finished episode, hence the +1.
                online_ep_lens.append(itr + 1 - episode_start)
                episode_start = itr + 1
                num_eps += 1
                if args['episode_based'] == 1 and num_eps == args['num_episodes']:
                    break

        logger.log_arr(np.array(td_errors), 'td_errors.npy')
        logger.log_arr(np.array(weights_changes), 'weight_differences.npy')
        logger.log_arr(np.array(offline_eval_returns), 'offline_eval_returns.npy')
        logger.log_arr(np.array(offline_eval_steps), 'offline_eval_steps.npy')

        logging.info("Execution Time (s): {}".format(time.time() - start_time))
        logging.info("Execution Time (m): {}".format((time.time() - start_time)/60))
        # max(..., 1) keeps this safe when no evaluation ran (e.g. num_timesteps == 0).
        logging.info('Evaluation Time: ' + str(evaluation_time_spent/max(evaluation_nums, 1)))
    finally:
        # Close the recording wrapper even when training fails part-way.
        env.close()

    if verbose:
        # A negative-start slice such as [num_eps - 100:num_eps] drops episodes whenever
        # 50 < num_eps < 100; [-100:] always yields the (up to) 100 most recent ones.
        recent_ep_lens = online_ep_lens[-100:]
        average_len = np.average(recent_ep_lens) if recent_ep_lens else float('nan')
        print("Average num steps in last 100 episodes:", average_len)

if __name__ == "__main__":
    # Script entry point: parse the command-line options, then run a single experiment.
    ARGS = get_args()
    main(args=ARGS)
