#!/usr/bin/env python
# coding: utf-8

# Importing python packages
import numpy as np
from tqdm import tqdm
import argparse
import os
import time
import sys

# To ignore warnings
import warnings
warnings.filterwarnings("ignore")

# Dueling bandit environments
from environment import DuelingEnv

# Dueling bandit learners
from learners import NeuralDB
from learners import LinearDB
from learners import CoLSTIM

# Plotting the results
from plotting_function import cumulative_regret_plotting

# Append the current working directory to the module search path.
# NOTE(review): this runs *after* the project-local imports above
# (environment, learners, plotting_function), so it has no effect on how
# those modules are resolved; it only affects imports performed later.
# Confirm whether it is still needed, or whether it should move above them.
cwd = os.getcwd()
sys.path.append(cwd)


# Input arguments
def _str2bool(value):
    """Convert a command-line string such as ``"False"`` or ``"yes"`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool(value)``, which is ``True`` for
    every non-empty string (including ``"False"``), so boolean options need a
    dedicated converter.

    Args:
        value: The raw option string (a ``bool`` is returned unchanged).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean
            spelling; argparse turns this into a usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (True/False), got {!r}".format(value)
    )


def parse_args(argv=None):
    """Define and parse the command-line options of the experiment.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before; passing a list
            allows calling this from tests or notebooks.

    Returns:
        argparse.Namespace with one attribute per option defined below
        (note: the lambda option is exposed as ``args.lamdba``).
    """
    parser = argparse.ArgumentParser(description="Neural Dueling Bandit")
    parser.add_argument(
        "--db_function",
        type=str,
        default="square",
        metavar="square|cosine",
        help="Name of dataset to use"
    )
    parser.add_argument(
        "--dim",
        type=int,
        default=5,
        help="Set the dimension of context in the bandit problem"
    )
    parser.add_argument(
        "--arms",
        type=int,
        default=5,
        help="Set the number of arms"
    )
    parser.add_argument(
        "--size",
        type=int,
        default=100,
        help="Set the size of the bandit problem"
    )
    parser.add_argument(
        "--noise",
        type=float,
        default=1.0,
        help="Set the noise level for the arms"
    )
    parser.add_argument(
        "--suboptimality_gap",
        type=float,
        default=0.0,
        help="Set the optimality gap for the arms"
    )
    parser.add_argument(
        "--learner",
        type=str,
        default="neural",
        # The main script also accepts "colstim".
        metavar="neural|linear|colstim",
        help="Name of learner's type to use"
    )
    parser.add_argument(
        "--strategy",
        type=str,
        default="ucb",
        metavar="ucb|ts",
        help="Set the strategy to use: TS or UCB."
    )
    # Boolean flags: "--diagonalize" alone means True, and an explicit value
    # ("True"/"False", "yes"/"no", "1"/"0") is parsed properly instead of
    # every non-empty string being treated as True.
    parser.add_argument(
        "--diagonalize",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
        metavar="True|False",
        help="Use diagonalize for the inverse of gram matrix or not"
    )
    # The flag name keeps its historical spelling so existing command lines
    # and the ``args.lamdba`` attribute keep working.
    parser.add_argument(
        "--lamdba",
        type=float,
        default=1.0,
        help="Set the lamdba parameter."
    )
    parser.add_argument(
        "--nu",
        type=float,
        default=1.0,
        help="Set the parameter nu."
    )
    parser.add_argument(
        "--learner_update",
        type=int,
        default=20,
        metavar="10|20|50|100",
        help="Set the update frequency of the learner"
    )
    parser.add_argument(
        "--hidden",
        type=int,
        default=100,
        metavar="32|100",
        help="Set the network hidden size"
    )
    parser.add_argument(
        "--layers",
        type=int,
        default=1,
        metavar="2|1",
        help="Set the number of hidden layers"
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=1,
        help="Set the random seed for numpy/Torch"
    )
    parser.add_argument(
        "--plots",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
        metavar="True|False",
        help="Plot the results or not"
    )
    parser.add_argument(
        "--runs",
        type=int,
        default=2,
        help="Set the number of runs"
    )

    return parser.parse_args(argv)


# Starting the main function
if __name__ == '__main__':
    # Parse the command-line options (see parse_args for the full list).
    args = parse_args()

    # Seed NumPy's global RNG for reproducibility; the same seed is also
    # forwarded to the environment below.
    # NOTE(review): --seed's help mentions Torch, but Torch is not seeded in
    # this script — confirm whether the learners seed it themselves.
    np.random.seed(args.seed)

    # Build the dueling-bandit environment from the CLI options.
    db_env = DuelingEnv(
        args.db_function,
        args.dim,
        args.arms,
        args.size,
        args.noise,
        args.suboptimality_gap,
        args.seed,
    )

    # Instantiate the requested learner. ``learner_info`` becomes part of the
    # output file name and ``cases`` is the legend label used for plotting.
    # NOTE(review): --hidden and --layers are parsed but never passed to
    # NeuralDB here — confirm whether the network size is configured elsewhere.
    if args.learner == 'neural':
        learner = NeuralDB(db_env.dim, args.lamdba, args.nu, args.strategy, args.diagonalize, args.learner_update)
        learner_info = '{}_{}_{}_{}_{}_{}'.format(
            args.learner,
            args.strategy,
            args.diagonalize,
            args.lamdba,
            args.nu,
            args.learner_update
        )
        cases = ['NeuralDB-UCB'] if args.strategy == 'ucb' else ['NeuralDB-TS']

    elif args.learner == 'linear':
        learner = LinearDB(db_env.dim, args.lamdba, args.nu, args.strategy, args.learner_update)
        learner_info = '{}_{}_{}_{}_{}'.format(
            args.learner,
            args.strategy,
            args.lamdba,
            args.nu,
            args.learner_update
        )
        cases = ['LinDB-UCB'] if args.strategy == 'ucb' else ['LinDB-TS']

    elif args.learner == 'colstim':
        learner = CoLSTIM(db_env.dim, args.lamdba, args.nu, db_env.size, args.learner_update)
        # The strategy is not passed to CoLSTIM, but it stays in the file name
        # so output names remain identical to earlier runs.
        learner_info = '{}_{}_{}_{}_{}'.format(
            args.learner,
            args.strategy,
            args.lamdba,
            args.nu,
            args.learner_update
        )
        cases = ['CoLSTIM']

    else:
        raise RuntimeError('Learner not exist')

    # Create the output directory up front: np.savez does not create missing
    # directories, so otherwise the script would crash only after all runs
    # finished and every result would be lost.
    output_dir = "data/plots"
    os.makedirs(output_dir, exist_ok=True)

    # ### Interaction between the learner and the environment ###
    start_time = time.time()

    # One list of per-round regrets per run, for each regret type.
    algo_average_regret = []
    algo_weak_regret = []
    for _ in tqdm(range(args.runs)):
        # Start every run from a fresh environment and learner.
        db_env.reset()
        learner.reset()

        average_regret = []
        weak_regret = []

        # Play one round per step of the bandit problem.
        for _ in range(db_env.size):
            # Observe the context-arm features for this round.
            context_arms = db_env.get_context_arms()

            # The learner picks two arms to duel.
            at_1, at_2 = learner.select(context_arms)

            # Observe the preference feedback for the chosen pair.
            feedback = db_env.get_feedback(at_1, at_2)

            # Only update on distinct arms — presumably because dueling an arm
            # against itself carries no preference information (confirm
            # against the learner implementations).
            if at_1 != at_2:
                learner.update(feedback)

            # Record the round's regrets as returned by the environment.
            rt_avg, rt_weak = db_env.get_regret(at_1, at_2)
            average_regret.append(rt_avg)
            weak_regret.append(rt_weak)

        algo_average_regret.append(average_regret)
        algo_weak_regret.append(weak_regret)

    # Save the regret data together with the total wall-clock time of all runs.
    file_location = output_dir + "/" + db_env.problem_name + "_" + learner_info
    file_to_save = file_location + "_{}.npz".format(args.runs)
    np.savez(
        file_to_save,
        average_regret=algo_average_regret,
        weak_regret=algo_weak_regret,
        time_taken=time.time() - start_time,
    )

    # ### Plotting the regret ###
    if args.plots:
        plot_location = file_location + "_{}".format(args.runs)

        # Average regret plot.
        plot_to_save = plot_location + "_average.png"
        cumulative_regret_plotting(algo_average_regret, cases, plot_to_save, 'lower right')

        # Weak regret plot.
        plot_to_save = plot_location + "_weak.png"
        cumulative_regret_plotting(algo_weak_regret, cases, plot_to_save, 'lower right')

    # Explicitly drop the environment and learner instances.
    del db_env
    del learner
        
