"""
Main program for running tabular RL algorithms (Q-learning, expected SARSA) with different exploration strategies.
"""

import argparse
import os
import time
import gym
import bsuite
import pandas as pd
import numpy as np

from pathlib import Path
from tqdm import tqdm

from bsuite.utils import gym_wrapper
from gym_recording_modified.wrappers import TraceRecordingWrapper

from agent.agents import DeepSeaTabularAgent, RiverSwimTabularAgent, HardSquareTabularAgent, TwoStateTabularAgent
from agent.preprocessing_tools import *
from utils.logger import Logger


def get_args():  # TODO Choices argument should be set for all of the ArgumentParsers below
    """
    Parse the command-line options for a tabular exploration experiment.

    Options that take a value always require one: a bare flag such as
    ``--seed`` with no value is rejected by argparse instead of silently
    becoming ``None``.

    Returns:
        dict: mapping from option name to parsed value (``vars`` of the
        argparse namespace).
    """
    parser = argparse.ArgumentParser(description='Exploration Strategies for Tabular Agent')

    parser.add_argument('--exploration_strategy', default='resmax', type=str,
        choices=("epsilon-greedy", "softmax", "resmax", "resmax-normalized", "softmax-normalized", "resmax-normalized-td", "softmax-normalized-td", 'mellowmax', 'log-sum-exp'), help="Exploration Strategy to be used")

    parser.add_argument('--algorithm', default='expected-sarsa', type=str, choices=['q-learning', 'expected-sarsa'], help="Algorithm that we wish to run: 1. q-learning 2. expected-sarsa")

    parser.add_argument('--epsilon', default=0.1, type=float, help="Value of epsilon for epsilon-greedy strategy")

    parser.add_argument('--seed', default=1, type=int)

    parser.add_argument('--num_timesteps', type=int, default=1000)

    parser.add_argument('--step_size', type=float, default=0.1, help="Step size for updates")

    parser.add_argument('--eta', type=float, default=2**-5, help="Exploitation pressure used for resmax exploration technique")

    parser.add_argument('--temp', type=float, default=2**-2, help="Temperature value for softmax exploration technique")

    parser.add_argument('--omega', type=float, default=1., help="Omega parameter for mellowmax")

    parser.add_argument('--only_store_rewards', default=False, action='store_true', help="Whether to only store rewards for the traces")

    parser.add_argument('--save_policy', default=False, action='store_true', help="Whether to save the policy")

    parser.add_argument('--g_min', type=float, default=0., help="Minimum return for the environment.")

    parser.add_argument('--g_max', type=float, default=1., help="Maximum return for the environment.")

    parser.add_argument('--gamma', type=float, default=1, help="Discount factor gamma for the environment.")

    parser.add_argument('--td_step_size', type=float, default=0.9, help="Step size parameter for running average of squared td error")

    # NOTE(review): 10e-9 == 1e-8; confirm whether 1e-9 was intended.
    parser.add_argument('--td_epsilon', type=float, default=10e-9, help="lower bound of td-based normalization")

    parser.add_argument('--initial_optimism', type=float, default=0, help="initial q-values")

    parser.add_argument('--zeta', type=float, default=0, help="initial value of zeta (running average of squared TD error)")

    parser.add_argument('--env_name', default='gym_riverswim:riverswim-v0', type=str,
            help="Name of the environment", choices=(
                'deep_sea/0',
                'gym_riverswim:riverswim-v0',
                'gym_exploration:HardSquare-v0',
                'gym_exploration:TwoState-v0',
                'riverswim_variants:stochastic-riverswim-v0',
                'riverswim_variants:skewed-stochastic-riverswim-v0',
                'riverswim_variants:scaled-riverswim-v0'))

    parser.add_argument('--save_path', default='results_temp', type=str,
            help="The root path that should be used to save the results of our experiments")

    parser.add_argument('--save_type', default='episodic_return', type=str, choices=['episodic_return', 'reward_per_step', 'episodic_steps'], help="The type of data that should be saved")

    parser.add_argument('--log_interval', type=int, default=1000, help="Interval that has been used to save the episodic returns or steps")

    parser.add_argument('--eval_episodes_num', type=int, default=30, help="The number of episodes that should be evaluated after each log_interval steps")

    return vars(parser.parse_args())


def _exploration_value(args: dict) -> float:
    """
    Return the hyperparameter that controls the selected exploration strategy.

    The strategy name is matched by substring, so variants such as
    ``softmax-normalized-td`` or ``resmax-normalized`` use the parameter of
    their base strategy (``temp`` and ``eta`` respectively).

    Raises:
        ValueError: if ``args['exploration_strategy']`` is not recognised.
    """
    strategy = args['exploration_strategy']
    if strategy == 'epsilon-greedy':
        return args['epsilon']
    if 'softmax' in strategy or 'log-sum-exp' in strategy:
        return args['temp']
    if 'resmax' in strategy:
        return args['eta']
    if 'mellowmax' in strategy:
        return args['omega']
    raise ValueError("Value of exploration_strategy is wrong: {}".format(strategy))


def _make_raw_env(env_name: str):
    """
    Instantiate the un-wrapped environment named ``env_name``.

    Names containing ``deep_sea`` are loaded from bsuite and adapted to the gym
    API with ``GymFromDMEnv``; every other name is passed to ``gym.make``.
    """
    if 'deep_sea' not in env_name:  # This checks whether the environment is part of bsuite or gym
        print(env_name)
        return gym.make(env_name)
    if 'deep_sea_deterministic' in env_name:
        # Expects a name of the form "<prefix>/<size>".
        # NOTE(review): 'deep_sea_deterministic' is not among get_args' choices,
        # so this branch is only reachable with a hand-built args dict.
        size = int(env_name.split("/")[1])
        return gym_wrapper.GymFromDMEnv(
            bsuite.environments.deep_sea.DeepSea(size, randomize_actions=False))
    return gym_wrapper.GymFromDMEnv(bsuite.load_from_id(env_name))


def _build_agent(env, env_name: str, agent_params: dict):
    """
    Instantiate the tabular agent class that matches ``env_name``.

    Raises:
        NotImplementedError: if no agent class is registered for the environment.
    """
    if 'deep_sea' in env_name:
        return DeepSeaTabularAgent(env, agent_params)
    if 'riverswim' in env_name:
        return RiverSwimTabularAgent(env, agent_params)
    if 'HardSquare' in env_name:
        return HardSquareTabularAgent(env, agent_params)
    if 'TwoState' in env_name:
        return TwoStateTabularAgent(env, agent_params)
    # NOTE add more environments
    raise NotImplementedError('No tabular agent for environment: {}'.format(env_name))


def main(args: dict):
    """
    Train a tabular agent on the selected environment and record its traces.

    Args:
        args: Mapping of command-line options, as produced by ``get_args``.

    Side effects:
        Creates ``args['save_path']`` if needed and writes ``args.csv`` there.
        Records traces through ``TraceRecordingWrapper``. When ``save_policy``
        is set, also saves ``policy.npy``.

    Raises:
        ValueError: if ``exploration_strategy`` is not recognised.
        NotImplementedError: if there is no agent for ``env_name``.
    """
    env_name = args['env_name']
    exp_value = _exploration_value(args)

    save_dir = args['save_path']
    # exist_ok avoids a check-then-create race if another process makes the directory.
    os.makedirs(save_dir, exist_ok=True)

    # Only deep_sea observations are preprocessed before being stored in the traces.
    preprocess_saved_obs = deep_sea_obs_preprocess if 'deep_sea' in env_name else None

    # Saving the arguments in the save_dir
    pd.Series(args).to_csv(os.path.join(save_dir, 'args.csv'))

    raw_env = _make_raw_env(env_name)
    logging = Logger(save_dir)
    env = TraceRecordingWrapper(
        raw_env,
        save_dir,
        only_reward=args['only_store_rewards'],
        preprocess_obs=preprocess_saved_obs,
        save_type=args['save_type'],
        log_interval=args['log_interval'],
        logger=logging)

    # Set random seeds
    seed: int = args['seed']
    np.random.seed(seed)
    env.seed(seed)

    # For now only environments with discrete action spaces are benchmarked.
    assert isinstance(env.action_space, gym.spaces.Discrete), 'The action space should be discrete'

    # Image-like (rank > 2) observations are not supported.
    obs_shape = env.observation_space.shape
    assert len(obs_shape) <= 2, 'Image observations are not supported'

    # Observation and action sizes
    if len(obs_shape) == 0:
        # Scalar (e.g. Discrete) observation space: use its number of states.
        ob_dim = env.observation_space.n
    else:
        ob_dim = int(np.prod(obs_shape))
    ac_dim = env.action_space.n

    save_policy = args.get('save_policy', False)
    if save_policy:
        assert 'riverswim' in env_name, 'only use this functionality with riverswim'

    agent_params: dict = {
        'algorithm': args['algorithm'],
        'ac_dim': ac_dim,
        'ob_dim': ob_dim,
        'step_size': args['step_size'],
        'input_shape': None,
        'gamma': args['gamma'],
        'exploration_strategy': args['exploration_strategy'],
        'exp_value': exp_value,
        'num_timesteps': args['num_timesteps'],
        'save_policy': save_policy,
        'g_min': args['g_min'],
        'g_max': args['g_max'],
        'td_step_size': args['td_step_size'],
        'td_epsilon': args['td_epsilon'],
        'zeta': args['zeta'],
        # Hard coded horizon for riverswim. It is set before logging so the
        # logged parameters match what the agent actually receives.
        'horizon': 20 if 'riverswim' in env_name else False,
        'initial_optimism': args['initial_optimism'],
    }

    logging.info('Run Params: {}'.format(args))
    logging.info('Agent Params: {}'.format(agent_params))

    agent = _build_agent(env, env_name, agent_params)

    # init vars at beginning of training
    start_time = time.time()
    episode_lengths = []

    logging.info('## Started ##')
    agent.start()

    step = 0
    sum_reward = 0
    for _ in tqdm(range(args['num_timesteps'])):
        done, reward = agent.step_env()
        step += 1
        sum_reward += reward
        if done:
            episode_lengths.append(step)
            step = 0

    if save_policy:  # for plotting riverswim policy over time
        np.save(Path(save_dir) / "policy.npy", agent.policy_log)

    elapsed = time.time() - start_time
    logging.info("Total reward: {}".format(sum_reward))
    logging.info("Completed episodes: {}".format(len(episode_lengths)))
    logging.info("Execution Time (s): {}".format(elapsed))
    logging.info("Execution Time (m): {}".format(elapsed / 60))
    env.close()

if __name__ == '__main__':
    # Script entry point: parse the command line, then run the experiment.
    main(get_args())
