#!/usr/bin/env python
# coding: utf-8

# Importing python packages
import numpy as np
from tqdm import tqdm
import argparse
import os
import time
import sys

# To ignore warnings
import warnings
warnings.filterwarnings("ignore")

# Synthetic Dueling bandit environments for active learning
from environments.synthetic import SyntheticEnv

# Dueling bandit learners
from learners.baselines import Random, AEBorda, APO
from learners.neural_adb import NeuralAPO, NeuralADB, NeuralADBIG
from learners.neural_adb import AEBordaNNGrad, NeuralAPOGrad, NeuralADBGrad, NeuralADBIGGrad
from learners.neural_adb import NeuralADBGradAbl

# Plotting the results
from utils.plotting_functions import average_plotting, cumulative_regret_plotting

# Append the current working directory to the module search path.
# NOTE(review): this runs after the project imports above (environments,
# learners, utils), so it cannot influence how those modules are resolved; it
# only affects imports performed later. Confirm whether it is still needed.
cwd = os.getcwd()
sys.path.append(cwd)


# Input arguments
def _str_to_bool(value):
    """Convert a command-line token such as "True" or "false" into a bool.

    argparse's ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value (including "False") becomes True. This converter parses
    the text instead.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value

    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string is kept as False because bool("") was False before.
    if token in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (True|False), got {!r}".format(value)
    )


def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes db_function, contexts, arms,
        dim, learner, strategy, diagonalize, lamdba, nu, learner_update,
        hidden, layers, seed, rounds, runs and plots.

    The boolean options (--diagonalize, --plots) accept an explicit value
    ("--plots True" / "--plots False"); passing the flag without a value
    means True.
    """
    parser = argparse.ArgumentParser(description="Active Dueling Bandit")
    parser.add_argument(
        "--db_function",
        type=str,
        default="square",
        metavar="linear|square|cosine",
        help="Name of dataset to use"
    )
    parser.add_argument(
        "--contexts",
        type=int,
        default=1000,
        help="Set the number of contexts in the dueling bandit problem"
    )
    parser.add_argument(
        "--arms",
        type=int,
        default=5,
        help="Set the number of arms"
    )
    parser.add_argument(
        "--dim",
        type=int,
        default=5,
        help="Set the dimension of context-arm feature vector in the dueling bandit problem"
    )
    # NOTE(review): the metavar below does not list every name accepted by the
    # learner registry in the main block (and 'AEBordaNN' is not one of its
    # keys); the registry is what actually decides validity.
    parser.add_argument(
        "--learner",
        type=str,
        default="NeuralADBGrad",
        metavar="Random|AEBordaNN|APO|NeuralADB|NeuralADBIG",
        help="Name of dueling bandit learner's type to use"
    )
    parser.add_argument(
        "--strategy",
        type=str,
        default="ucb",
        metavar="ucb|ts",
        help="Set the strategy to use: TS or UCB."
    )
    parser.add_argument(
        "--diagonalize",
        type=_str_to_bool,  # type=bool would turn "False" into True
        nargs="?",
        const=True,
        default=False,
        metavar="True|False",
        help="Use diagonalize for the inverse of gram matrix or not"
    )
    # The option name keeps its historical spelling ("lamdba") because the
    # rest of the script reads args.lamdba.
    parser.add_argument(
        "--lamdba",
        type=float,
        default=1.0,
        help="Set the lamdba parameter."
    )
    parser.add_argument(
        "--nu",
        type=float,
        default=1.0,
        help="Set the parameter nu."
    )
    parser.add_argument(
        "--learner_update",
        type=int,
        default=10,
        metavar="10|20|50|100",
        help="Set the update frequency of the learner"
    )
    parser.add_argument(
        "--hidden",
        type=int,
        default=32,
        metavar="32|100",
        help="Set the network hidden size"
    )
    parser.add_argument(
        "--layers",
        type=int,
        default=2,
        metavar="2|1",
        help="Set the number of hidden layers"
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=1,
        help="Set the random seed for numpy/Torch"
    )
    parser.add_argument(
        "--rounds",
        type=int,
        default=200,
        help="Set the number of rounds"
    )
    parser.add_argument(
        "--runs",
        type=int,
        default=2,
        help="Set the number of runs"
    )
    parser.add_argument(
        "--plots",
        type=_str_to_bool,  # type=bool would turn "False" into True
        nargs="?",
        const=True,
        default=False,
        metavar="True|False",
        help="Plot the results or not"
    )

    return parser.parse_args()


# Starting the main function
if __name__ == '__main__':
    # Script entry point: build the synthetic environment and the selected
    # learner, run `args.runs` independent runs of `args.rounds` rounds each,
    # save the per-round metrics to an .npz file and optionally plot them.

    # Parsing the input arguments
    args = parse_args()

    # Learners dictionary: maps the --learner name to its class
    learners = {
        'Random'            : Random,
        'AEBorda'           : AEBorda,
        'APO'               : APO,
        'NeuralAPO'         : NeuralAPO,
        'NeuralADB'         : NeuralADB,
        'NeuralADBIG'       : NeuralADBIG,
        'AEBordaNNGrad'     : AEBordaNNGrad,
        'NeuralAPOGrad'     : NeuralAPOGrad,
        'NeuralADBGrad'     : NeuralADBGrad,
        'NeuralADBIGGrad'   : NeuralADBIGGrad,
        'NeuralADBGradAbl'  : NeuralADBGradAbl,
    }

    # Validate the inputs up front, before anything expensive is built.
    # Explicit raises are used instead of assert so the checks survive
    # `python -O`.
    if args.learner not in learners:
        raise RuntimeError('Learner not exist')
    if args.rounds > args.contexts:
        raise ValueError("Number of rounds should not be more than number of contexts")

    # Setting the random seed for numpy
    # NOTE(review): only NumPy's global RNG is seeded here (the environment
    # also receives the seed); the --seed help text mentions Torch as well.
    np.random.seed(args.seed)

    # ### Dueling bandit environment ###
    syn_env = SyntheticEnv(
        args.db_function,
        args.contexts,
        args.arms,
        args.dim,
        args.seed
    )

    # ### Dueling bandit learners ###
    # Constructor arguments shared by every learner, and the fields that make
    # up the output file name.
    learner_args = [
        syn_env.all_context_arms,
        args.dim,
        args.lamdba,
        args.nu,
        args.strategy,
        args.learner_update,
    ]
    info_fields = [
        args.learner,
        args.strategy,
        args.lamdba,
        args.nu,
        args.rounds,
        args.learner_update,
    ]

    # The ablation learner additionally takes the network depth and width.
    # The learner is built exactly once; previously it was first built with
    # the common arguments and then rebuilt with the extended ones.
    if args.learner == 'NeuralADBGradAbl':
        learner_args += [args.layers, args.hidden]
        info_fields += [args.layers, args.hidden]
    info_fields.append(args.runs)

    learner = learners[args.learner](*learner_args)
    learner_info = '_'.join(str(field) for field in info_fields)

    cases = [args.learner]

    # ### Interaction between the learner and the environment ###
    # Starting the time (the saved time_taken covers all runs)
    start_time = time.time()

    # Per-run metric histories, keyed by the names used in the saved .npz.
    # Each entry becomes a list (one per run) of per-round values.
    metric_names = ('subg', 'rmse', 'mae', 'average_regret', 'weak_regret')
    results = {name: [] for name in metric_names}

    # Running over multiple runs
    for _ in tqdm(range(args.runs)):

        # Reset the learner
        learner.reset()

        # Metric values after each round of this run
        run_history = {name: [] for name in metric_names}

        # Loop through the bandit problem
        for _ in range(args.rounds):
            # Get the next context-arm pair
            c_t, xt_1, xt_2 = learner.select()

            # Ensure the selected arms are not the same.
            # NOTE(review): this is an identity check, so two distinct objects
            # holding equal values would pass; confirm that is intended.
            assert xt_1 is not xt_2, "Selected arms should not be the same"

            # Get the preference feedback
            feedback = syn_env.get_preference(xt_1, xt_2)

            # Update the learner's model
            learner.update(feedback)

            # Get the regret
            rt_avg, rt_weak = syn_env.get_regret(c_t, xt_1, xt_2)

            # Get the estimated best arm for all contexts
            best_arms = learner.get_policy()

            # Get and record the sub-optimality gap and the error metrics
            suboptimality_gap, rmse, mae = syn_env.get_suboptimality_gap(best_arms)
            run_history['subg'].append(suboptimality_gap)
            run_history['rmse'].append(rmse)
            run_history['mae'].append(mae)
            run_history['average_regret'].append(rt_avg)
            run_history['weak_regret'].append(rt_weak)

        # Append this run's histories to the overall results
        for name in metric_names:
            results[name].append(run_history[name])

    # ### Save the regret data for plotting results ###
    file_location = "data/plots/" + syn_env.problem_name + "_" + learner_info
    file_to_save_subg = file_location + ".npz"

    # Create the output directory first so a finished experiment is not lost
    # to a FileNotFoundError at save time.
    os.makedirs(os.path.dirname(file_location), exist_ok=True)

    np.savez(
        file_to_save_subg,
        **results,
        time_taken=time.time() - start_time
    )

    # ### Plotting the regret ###
    if args.plots:
        # Plot location
        plot_location = file_location + "_{}".format(args.runs)

        # (metric key, plotting function, file suffix, y-axis label)
        plot_specs = [
            ('subg', average_plotting, "_subg.png", 'Sub-optimality Gap'),
            ('rmse', average_plotting, "_rmse.png", 'RMSE'),
            ('mae', average_plotting, "_mae.png", 'MAE'),
            ('average_regret', cumulative_regret_plotting, "_average_regret.png", 'Average Regret'),
            ('weak_regret', cumulative_regret_plotting, "_weak_regret.png", 'Weak Regret'),
        ]
        for name, plot_fn, suffix, label in plot_specs:
            plot_fn(results[name], cases, plot_location + suffix, 'upper right', label)

    # Delete the object instances for the environment and the learner
    del syn_env
    del learner
        
