"""
DQN algorithm with different exploration techniques
"""
import torch
import argparse 
import os
import time
import gym
# import logging
import pandas as pd
import numpy as np
import torch.nn.functional as F

from utils.snippets import *
from utils.agents import DQNAgent
from utils.networks import MLP, LFANet, MinAtarQNetwork, create_atari_q_network
from torch import nn
from gym_recording_modified.wrappers import TraceRecordingWrapper
from utils.preprocessing_tools import *
from utils.atari_wrappers import wrap_deepmind
from datetime import datetime
from sys import platform
from utils.logger import Logger
from utils.logging import NumpyLogger 
from utils.wrappers import SparseMountainCarWrapper

# Default hyperparameters referenced both by get_args() (as argparse defaults)
# and by main() (to decide how the results directory is laid out).
DEFAULT_REPLAY_BUFFER_SIZE: int = 50_000
DEFAULT_TARGET_UPDATE_FREQUENCY: int = 1_000

def get_args(argv=None):
    """
    Parse the command-line options of this experiment.

    Args:
        argv: optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is the
            behaviour when the script is run from the command line.

    Returns:
        A dict mapping every option name (e.g. ``'exploration_strategy'``)
        to its parsed value.
    """
    parser = argparse.ArgumentParser(description='Exploration Strategies for DQN')

    parser.add_argument('--exploration_strategy',  default='epsilon-greedy', type=str, choices=("epsilon-greedy", "softmax", "resmax", "mellowmax"), help="Exploration Strategy to be used")

    parser.add_argument('--algorithm', default='q-learning', type=str, choices=['q-learning', 'expected_sarsa'], help="Algorithm that we wish to run: 1. q-learning 2. expected_sarsa")

    parser.add_argument('--epsilon', default=0.1, type=float,
            nargs='?', help="Value of epsilon for epsilon-greedy strategy")

    parser.add_argument('--omega', default=0.1, type=float,
            nargs='?', help="Value of omega for mellowmax strategy")

    parser.add_argument('--reps', default=1,
                      type=int, nargs='?',
                      help='Number of executions')

    parser.add_argument('--use_gpu', '-gpu', action='store_true')

    parser.add_argument('--use_normalization_scheme', '-uns', action='store_true')

    # NOTE(review): main() in this file does not read this flag (offline
    # evaluation is currently disabled there).
    parser.add_argument('--do_offline_evaluation', '-doe', action='store_true')

    parser.add_argument('--gpu_id', type=int, default=0)

    parser.add_argument('--agent_type', type=str, default="non-linear", choices=("non-linear", "linear"))

    parser.add_argument('--seed', default=1, type=int, nargs='?')

    parser.add_argument('--batch_size', type=int, default=32)

    parser.add_argument('--num_timesteps', type=int, default=100000)

    parser.add_argument('--num_agent_train_steps_per_iter', type=int, default=1)

    parser.add_argument('--target_update_freq', type=int, default=DEFAULT_TARGET_UPDATE_FREQUENCY)

    parser.add_argument('--learning_starts', type=int, default=32)

    parser.add_argument('--nn_size', type=int, default=64)

    parser.add_argument('--n_layers', type=int, default=1)

    parser.add_argument('--verbose', type=int, choices=[0, 1], default=1)

    parser.add_argument('--gamma', type=float, default=1.0, help="Discounting factor parameter")

    parser.add_argument('--eta', type=float, default=12., help="Exploitation pressure used for resmax exploration technique")

    parser.add_argument('--temp', type=float, default=1., help="Temperature value for softmax exploration technique")

    parser.add_argument('--step_size', type=float, default=1e-3, help="Step-size that will be used to update the value function")

    parser.add_argument('--only_store_rewards', action='store_true')

    parser.add_argument('--double_q', action='store_true')

    parser.add_argument('--render', action='store_true')

    # The schedule is applied to the hyperparameter of whichever exploration
    # strategy is selected (epsilon, temp, eta or omega), not only epsilon.
    parser.add_argument('--exploration_schedule', default=0, type=int, choices=(0, 1, 2),
            help="Type of schedule applied to the exploration hyperparameter. 0: constant schedule 1: stepwise schedule 2: episode-based linear schedule")

    parser.add_argument('--outside_value', type=float, default=0.1, help="The final value of the exploration hyperparameter used by the stepwise schedule (--exploration_schedule 1)")

    parser.add_argument('--portion_decay', type=float, default=0.1, help="portion of timesteps number to be used to decay!")

    parser.add_argument('--env_name', default='deep_sea/0', type=str,
            nargs='?', help="Name of the environment: 1. deep_sea/0")

    parser.add_argument('--save_path', default='results', type=str,
            nargs='?', help="The root path that should be used to save the results of our experiments")

    parser.add_argument('--save_type', default='episodic_return', type=str, choices=['episodic_return', 'reward_per_step', 'episodic_steps'], nargs='?', help="The type of data that should be saved")

    parser.add_argument('--log_interval', type=int, default=10000, help="Interval (in steps) used to save the episodic returns or steps")

    parser.add_argument('--eval_episodes_num', type=int, default=10, help="The number of episodes that should be evaluated after each log_interval steps")

    parser.add_argument('--td_error_mg', type=float, default=1., help="The initial value of moving average of td errors.")

    parser.add_argument('--td_error_mg_lr', type=float, default=.99, help="The moving average parameter.")

    parser.add_argument('--td_error_mg_epsilon', type=float, default=0.0001, help="The epsilon value used in the denominator to keep the level of exploration higher than a certain value.")

    parser.add_argument('--td_error_scheduling', '-tes', action='store_true')

    parser.add_argument('--replay_buffer_size', type=int, default=DEFAULT_REPLAY_BUFFER_SIZE, help="Replay buffer size")

    return vars(parser.parse_args(argv))


def _get_exploration_value(args: dict) -> float:
    """Return the hyperparameter of the selected exploration strategy.

    Mapping: epsilon-greedy -> ``epsilon``, softmax -> ``temp``,
    resmax -> ``eta``, mellowmax -> ``omega``.

    Raises:
        ValueError: if ``args['exploration_strategy']`` is none of the above.
    """
    strategy_to_key = {
        'epsilon-greedy': 'epsilon',
        'softmax': 'temp',
        'resmax': 'eta',
        'mellowmax': 'omega',
    }
    strategy = args['exploration_strategy']
    if strategy not in strategy_to_key:
        raise ValueError("Value of exploration_strategy is wrong: {}".format(strategy))
    return args[strategy_to_key[strategy]]


def _build_save_dir(args: dict, exp_value: float) -> str:
    """Return the directory where the results of this run are written.

    The first matching rule decides the layout:

    * non-default ``target_update_freq``:
      {save_path}/{env}/{strategy}/{step_size}/{exp_value}/{target_update_freq}/{seed}
    * non-default ``replay_buffer_size``:
      {save_path}/{env}/{strategy}/{step_size}/{exp_value}/{replay_buffer_size}/{seed}
    * ``exploration_schedule`` 0 or 2:
      {save_path}/{env}/{strategy}/{step_size}/{exp_value}/{seed}
    * ``exploration_schedule`` 1 (stepwise):
      {save_path}/{env}/{strategy}/{portion_decay}/{outside_value}/{step_size}/{exp_value}/{seed}

    ``{env}`` is ``args['env_name']`` with forward slashes turned into
    backslashes on Windows and everything up to the first ':' removed.

    Raises:
        ValueError: for an unknown ``exploration_schedule``.
    """
    env_name = args['env_name']
    if platform == "win32":
        env_name = env_name.replace('/', '\\')  # make saving work on windows
    if ':' in env_name:
        # Making directory naming compatible with windows
        env_name = env_name[env_name.find(':') + 1:]

    prefix = [args['save_path'], env_name, args['exploration_strategy']]
    suffix = [str(args['step_size']), str(exp_value)]

    if args['target_update_freq'] != DEFAULT_TARGET_UPDATE_FREQUENCY:
        parts = prefix + suffix + [str(args['target_update_freq'])]
    elif args['replay_buffer_size'] != DEFAULT_REPLAY_BUFFER_SIZE:
        # Must test replay_buffer_size itself: comparing target_update_freq
        # against the replay default was always true here and made the
        # schedule-specific layouts below unreachable.
        parts = prefix + suffix + [str(args['replay_buffer_size'])]
    elif args['exploration_schedule'] in (0, 2):
        parts = prefix + suffix
    elif args['exploration_schedule'] == 1:
        parts = prefix + [str(args['portion_decay']), str(args['outside_value'])] + suffix
    else:
        raise ValueError("Value of exploration_schedule is wrong: {}".format(args['exploration_schedule']))
    return os.path.join(*parts, str(args['seed']))


def _make_env(args: dict, is_minatar: bool, save_dir: str, preprocess_saved_obs, info_logger):
    """Create the training environment wrapped in ``TraceRecordingWrapper``.

    Three families are handled:

    * MinAtar games (``is_minatar``), through ``MinAtariDMEnv``;
    * gym ids (no '/' or backslash in the name): environments whose
      observation space has more than two dimensions are wrapped with
      ``wrap_deepmind``; MountainCar gets a 5000-step limit plus
      ``SparseMountainCarWrapper`` and CartPole a 200-step limit;
    * anything else is treated as a bsuite id, loaded with
      ``bsuite.load_from_id`` and converted with ``GymFromDMEnv``.
    """
    env_name = args['env_name']

    def record(raw_env):
        # All families are wrapped identically so results are saved the same way.
        return TraceRecordingWrapper(raw_env, save_dir, only_reward=args['only_store_rewards'],
                                     preprocess_obs=preprocess_saved_obs, save_type=args['save_type'],
                                     log_interval=args['log_interval'], logger=info_logger)

    if is_minatar:
        from utils.min_atari_utils import MinAtariDMEnv
        from minatar import Environment
        return record(MinAtariDMEnv(Environment(env_name)))

    if '/' not in env_name and '\\' not in env_name:  # gym rather than bsuite
        # Throw-away instance used only to inspect the observation space;
        # closed so it does not leak.
        probe_env = gym.make(env_name)
        is_image = len(probe_env.observation_space.shape) > 2
        probe_env.close()

        if is_image:  # This means that it is one of the atari games
            raw_env = wrap_deepmind(gym.make(env_name))
        else:
            raw_env = gym.make(env_name)
            if 'MountainCar' in env_name:
                raw_env._max_episode_steps = 5000
                raw_env = SparseMountainCarWrapper(raw_env, raw_env._max_episode_steps)
                info_logger.info('MountainCar with SparseMountainCarWrapper, max_episode_steps=5000')
            elif 'CartPole' in env_name:
                raw_env._max_episode_steps = 200
        return record(raw_env)

    import bsuite
    from bsuite import sweep
    from bsuite.utils import gym_wrapper
    env = record(gym_wrapper.GymFromDMEnv(bsuite.load_from_id(env_name)))
    # Logging the settings of this environment
    print('bsuite_id={}, settings={}, num_episodes={}'
          .format(env_name, sweep.SETTINGS[env_name], env.bsuite_num_episodes))
    return env


def _build_exploration_schedule(args: dict, exp_value: float):
    """Return the schedule object selected by ``args['exploration_schedule']``.

    0 -> constant at ``exp_value``; 1 -> stepwise schedule built from
    ``num_timesteps``, ``outside_value``, ``portion_decay`` and the strategy
    name; 2 -> episode-based linear schedule starting from ``exp_value``.

    Raises:
        ValueError: for any other schedule id.
    """
    schedule_type = args['exploration_schedule']
    if schedule_type == 0:
        return constant_exploration_schedule(exp_value)
    if schedule_type == 1:
        return stepwise_exploration_schedule(args['num_timesteps'], args['outside_value'], args['portion_decay'], args['exploration_strategy'])
    if schedule_type == 2:
        return episode_based_linear_exploration_schedule(exp_value)
    raise ValueError("Value of exploration_schedule is wrong: {}".format(schedule_type))


def _select_q_func(args: dict, is_atari: bool, is_minatar: bool):
    """Return the Q-network constructor stored in ``agent_params['q_func']``.

    The lambdas built here take ``(ob_shape, ac_shape)``; presumably
    ``DQNAgent`` calls every constructor with that signature — TODO confirm.

    Raises:
        ValueError: for an unknown ``agent_type``.
    """
    # TODO: it is better to change the name of --agent_type parameter to fa_type or function_approximation_type
    if args['agent_type'] == "non-linear":
        if is_atari:
            return create_atari_q_network
        if is_minatar:
            return MinAtarQNetwork
        return lambda ob_shape, ac_shape: MLP(ob_shape, ac_shape, args['n_layers'], args['nn_size'], F.relu)
    if args['agent_type'] == "linear":
        return lambda ob_shape, ac_shape: LFANet(ob_shape, ac_shape, 0, args['nn_size'])
    raise ValueError("Value of agent_type is wrong: {}".format(args['agent_type']))


def _select_obs_preprocess(env_name: str, ob_dim, env):
    """Return the observation preprocessor the agent applies, or ``None``.

    Chosen by substring of ``env_name``; a scalar observation
    (``ob_dim == 1``) falls back to ``scalar_obs_preprocess``.
    """
    if 'deep_sea' in env_name:
        return deep_sea_obs_preprocess  # TODO: this should be done for other environments
    if 'cartpole' in env_name:
        return cartpole_obs_preprocess
    if 'mountain_car' in env_name:
        return mountain_car_obs_preprocess
    if 'riverswim' in env_name:
        return riverswim_obs_preprocess(env.observation_space.n)
    if ob_dim == 1:
        return scalar_obs_preprocess
    return None


def main(args: dict):
    """
    Run one DQN training experiment described by ``args``.

    ``args`` is the dict returned by ``get_args``. The run writes
    ``args.csv``, the trace recorded by ``TraceRecordingWrapper``,
    ``action_values.npy`` and the info log into the directory chosen by
    ``_build_save_dir``.

    Raises:
        ValueError: for an unknown exploration strategy, exploration schedule
            or agent type, or when the environment's action space is not
            discrete.
    """
    # Assigning GPU or CPU device
    device: str = gpu_assigner(args['use_gpu'], args['gpu_id'])

    # Setting the number of threads to 1 for running on Cedar cluster
    torch.set_num_threads(1)

    verbose = args['verbose']
    exp_value = _get_exploration_value(args)

    save_dir: str = _build_save_dir(args, exp_value)
    # Create the directory before any logger touches it.
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
        print("Save path didn't exist, so it has been created: {}".format(save_dir))

    numpy_logger = NumpyLogger(save_dir)
    info_logger = Logger(save_dir)

    # Preprocessing Saved Observations
    preprocess_saved_obs = None
    if 'deep_sea' in args['env_name']:
        preprocess_saved_obs = lambda obs: deep_sea_obs_preprocess(obs, True)

    # Saving the arguments in the save_dir
    pd.Series(args).to_csv(os.path.join(save_dir, 'args.csv'))

    is_minatar = args['env_name'] in ('asterix', 'breakout', 'freeway', 'seaquest', 'space_invaders')
    env = _make_env(args, is_minatar, save_dir, preprocess_saved_obs, info_logger)

    # Set random seeds
    seed: int = args['seed']
    torch.manual_seed(seed)
    np.random.seed(seed)
    env.seed(seed)

    # Only environments with discrete action spaces are supported. An explicit
    # raise (not assert) keeps the check active under ``python -O``.
    discrete: bool = isinstance(env.action_space, gym.spaces.Discrete)
    if not discrete:
        raise ValueError('The action space should be discrete')
    print('Is action-space discrete? ', discrete)

    # Are the observations images?
    img: bool = len(env.observation_space.shape) > 2

    # Check if it is an Atari Environment
    is_atari = img and not is_minatar
    print('Is it atari environment? ', is_atari)

    # Observation and action sizes
    if len(env.observation_space.shape) == 0:
        ob_dim = env.observation_space.n
    else:
        ob_dim = env.observation_space.shape if img else int(np.prod(env.observation_space.shape))
    print('Dimensions of Observations: ', ob_dim)

    in_channels = ob_dim[2] if is_minatar else None
    ac_dim = env.action_space.n

    exploration_schedule = _build_exploration_schedule(args, exp_value)

    # TODO: g_bound is always None. A per-environment bound table used to live
    # here (MountainCar-v0: (0, 1), CartPole-v0: (1, 99.34), Acrobot-v1: (-99.34, 0)).
    g_bound = None

    agent_params: dict = {
            'algorithm': args['algorithm'],
            'batch_size': args['batch_size'],
            'ac_dim': ac_dim,
            'ob_dim': ob_dim,
            'in_channels': in_channels,
            'input_shape': (10, 10) if is_minatar else (84, 84),
            'gamma': args['gamma'],
            'learning_starts': args['learning_starts'],
            'learning_freq': 4 if is_atari else 1,
            'target_update_freq': args['target_update_freq'],
            'exploration_strategy': args['exploration_strategy'],
            'exploration_schedule': exploration_schedule,
            'episode_based_exploration': args['exploration_schedule'] == 2,
            'optimizer_spec': default_optimizer_spec(args['step_size']),
            'device': device,
            'grad_norm_clipping': 10,
            'replay_buffer_size': args['replay_buffer_size'],
            'frame_history_len': 4 if is_atari else 1, # TODO this should be make dynamic in the future
            'double_q': args['double_q'],
            'use_normalization_scheme': args['use_normalization_scheme'],
            'g_bound': g_bound,
            'td_error_mg': args['td_error_mg'],
            'td_error_mg_lr': args['td_error_mg_lr'],
            'td_error_mg_epsilon': args['td_error_mg_epsilon'],
            'td_error_scheduling': args['td_error_scheduling']
            }

    info_logger.info('Agent Params: {}'.format(agent_params))

    agent_params['q_func'] = _select_q_func(args, is_atari, is_minatar)

    preprocess_obs = _select_obs_preprocess(args['env_name'], ob_dim, env)

    agent = DQNAgent(env, agent_params, preprocess_obs=preprocess_obs, render=args['render'])

    # NOTE(review): --do_offline_evaluation is parsed but offline evaluation
    # (agent.offline_eval_episodes) is currently disabled in this loop.
    start_time: float = time.time()
    info_logger.info('## Started ##')

    batch_size = args['batch_size']
    train_steps_per_iter = args['num_agent_train_steps_per_iter']
    for _ in range(args['num_timesteps']):
        # collect one environment transition, to be used for training
        agent.step_env()

        # train agent (using sampled data from replay buffer)
        if verbose: print('\nTraining agent using sampled data from replay buffer...')

        for _ in range(train_steps_per_iter):
            ob_batch, ac_batch, re_batch, next_ob_batch, terminal_batch = agent.sample(batch_size)
            agent.train(ob_batch, ac_batch, re_batch, next_ob_batch, terminal_batch)

    elapsed = time.time() - start_time
    info_logger.info("Execution Time (s): {}".format(elapsed))
    info_logger.info("Execution Time (m): {}".format(elapsed / 60))

    env.close()
    numpy_logger.log_arr(np.array(agent.action_values), 'action_values.npy')
    info_logger.info("Root Finder Failures: {}".format(agent.root_finder.failed_count))

if __name__ == '__main__':
    # Parse the command line and run a single experiment.
    main(get_args())
