import matplotlib
matplotlib.use('Agg')
import os
import sys
import argparse
sys.path.append("..")

import numpy as np
import pandas as pd 

from Shared.data_extractors import *
from Shared.data_preprocessors import *
from Shared.visualization_tools import *

def get_args():
    """
    Parse the command-line options of the adaptation plotting script.

    Returns:
        dict: maps each option name (``root_dir_q``, ``root_dir_exp``,
        ``save_dir_q``, ``save_dir_exp``) to its value. An option that is
        not given on the command line takes its default. Because every
        option is declared with ``nargs='?'``, a flag passed without a
        value yields ``None``.
    """

    # (flag, default value, help text) for every supported option.
    option_specs = (
        ('--root_dir_q', 'results_adaptation/q-learning',
         "The path of the results for q-learning"),
        ('--root_dir_exp', 'results_adaptation/expected_sarsa',
         "The path of the results for expected sarsa"),
        ('--save_dir_q', 'plots/adaptation/q',
         "The root path that should be used to save the plots for q-learning"),
        ('--save_dir_exp', 'plots/adaptation/exp',
         "The root path that should be used to save the plots for expected sarsa"),
    )

    cli_parser = argparse.ArgumentParser(
        description='Generates all the required plots for Linear adaptation')

    for flag, fallback, help_text in option_specs:
        cli_parser.add_argument(flag, default=fallback, type=str,
                                nargs='?', help=help_text)

    parsed = cli_parser.parse_args()
    return vars(parsed)

def create_path(directory):
    """
    Make sure `directory` exists, creating it and any missing parents.

    Nothing happens (and nothing is printed) when the path already exists.
    Otherwise the directory tree is created and a message is printed.

    Args:
        directory (str): path of the directory that must exist.
    """

    if os.path.exists(directory):
        return

    # exist_ok=True closes the check-then-create race: if another process
    # creates the directory between the check above and this call, a plain
    # os.makedirs(directory) would raise FileExistsError.
    os.makedirs(directory, exist_ok=True)
    print("Save path didn't exist, so it has been created: {}".format(directory))

if __name__ == '__main__':
    # Script entry point. For each learning algorithm (q-learning first,
    # then expected sarsa) and each environment (MountainCar-v0, then
    # CartPole-v0) this loads the experiment data and draws a bar plot
    # followed by a sensitivity plot.
    #
    # The best-learning-curve plots (best_learning_curve_exp_algo, formerly
    # saved under a 'best_lr_plots' sub-directory) are currently disabled.
    #
    # NOTE(review): the sensitivity dictionaries spell the key 'resmax'
    # everywhere except for q-learning on CartPole-v0, where it is 'ResMax';
    # every bar-plot dictionary uses 'ResMax'. The spellings are kept exactly
    # as found -- confirm which one the plotting helpers expect.

    cli_args = get_args()

    def _pow2(exponents):
        """Return the list of 2**e for every e in `exponents`."""
        return [2**e for e in exponents]

    def _eps_pairs():
        """Return a fresh list of the pairs listed under 'epsilon-greedy'."""
        return [(0.05, 0.01), (0.1, 0.05), (0.2, 0.1)]

    def _mountain_car_job():
        """Build the MountainCar-v0 job, which is the same for both algorithms."""
        return {
            'env_id': "MountainCar-v0",
            'extracted_type': 'episode_steps',
            'metric': 'steps',
            # The data is loaded a second time before the sensitivity plot.
            'reload_before_sensitivity': True,
            'bar_values': {
                'epsilon-greedy': _eps_pairs(),
                'ResMax': _pow2([0, 8, 14]),
                'softmax': _pow2([0, 4, 8]),
                'mellowmax': _pow2([0, 4, 8]),
            },
            'sens_values': {
                'resmax': _pow2(range(0, 19, 2)),
                'softmax': _pow2(range(0, 19, 2)),
                'mellowmax': _pow2(range(0, 19, 2)),
            },
        }

    def _cart_pole_job(bar_softmax_exponents, sens_values):
        """Build the CartPole-v0 job; only the given pieces differ per algorithm."""
        return {
            'env_id': "CartPole-v0",
            'extracted_type': 'episode_returns',
            'metric': 'returns',
            # Bar and sensitivity plots share one loaded data object.
            'reload_before_sensitivity': False,
            'bar_values': {
                'epsilon-greedy': _eps_pairs(),
                'ResMax': _pow2([-4, 4, 10]),
                'softmax': _pow2(bar_softmax_exponents),
                'mellowmax': _pow2([0, 4, 8]),
            },
            'sens_values': sens_values,
        }

    # (results-root option, save-root option, jobs) per learning algorithm,
    # in processing order.
    sections = [
        # q-learning
        ('root_dir_q', 'save_dir_q', [
            _mountain_car_job(),
            _cart_pole_job(
                [-4, 0, 4],
                {
                    'ResMax': _pow2(range(-8, 19, 2)),
                    'softmax': _pow2(range(-8, 19, 2)),
                    'mellowmax': _pow2(range(0, 19, 2)),
                },
            ),
        ]),
        # expected sarsa
        ('root_dir_exp', 'save_dir_exp', [
            _mountain_car_job(),
            _cart_pole_job(
                [0, 4, 8],
                {
                    'resmax': _pow2(range(0, 19, 2)),
                    'softmax': _pow2(range(0, 19, 2)),
                    'mellowmax': _pow2(range(0, 19, 2)),
                },
            ),
        ]),
    ]

    for root_option, save_option, jobs in sections:
        results_root = cli_args[root_option]
        plots_root = cli_args[save_option]

        bar_dir = os.path.join(plots_root, 'bar_plots')
        sens_dir = os.path.join(plots_root, 'sensitivity_plots')
        create_path(bar_dir)
        create_path(sens_dir)

        for job in jobs:
            # One step size (0.1), stored as a string, keyed by environment id.
            # The same dict object is handed to both plotting calls.
            step_size_by_env = {job['env_id']: str(0.1)}

            loaded = collect_experiments_data(
                results_root, extracted_type=job['extracted_type'])
            bar_plot(loaded, step_size_by_env, job['bar_values'],
                     job['metric'], save_dir=bar_dir)

            if job['reload_before_sensitivity']:
                loaded = collect_experiments_data(
                    results_root, extracted_type=job['extracted_type'])
            sensitivity_plot_2(loaded, step_size_by_env, job['sens_values'],
                               job['metric'], save_dir=sens_dir)
