#!/usr/bin/env python
# coding: utf-8

# Importing python packages
import numpy as np
from tqdm import tqdm
import argparse
import os
import time
import sys

# To ignore warnings
import warnings
warnings.filterwarnings("ignore")

# GLM bandit environments
from environment import GLMEnv

# GLM bandit learners
from learners import NeuralGLM
from learners import LinearGLM

# Plotting the results
from plotting_function import cumulative_regret_plotting

# Getting current directory path and appending it to the module search path.
# NOTE(review): the project-local imports above (environment, learners,
# plotting_function) have already executed by the time this append runs, so
# it cannot affect how they were resolved -- confirm whether it is still
# needed or should move above those imports.
cwd = os.getcwd()
sys.path.append(cwd)


# Input arguments
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    Using ``type=bool`` with argparse applies ``bool()`` to the raw string,
    so every non-empty token (including ``"False"``) evaluates to ``True``.
    This converter interprets the text instead.

    Args:
        value: The raw token from the command line, or an actual ``bool``.

    Returns:
        ``True`` or ``False`` according to the token.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised
            boolean spelling.
    """
    if isinstance(value, bool):
        return value

    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    if token in ("false", "f", "no", "n", "0"):
        return False

    raise argparse.ArgumentTypeError(
        "Expected a boolean value (True|False), got {!r}".format(value)
    )


def parse_args(argv=None):
    """Build the command-line parser and parse the experiment options.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads the arguments from ``sys.argv[1:]``.

    Returns:
        argparse.Namespace holding one attribute per option defined below.
    """
    parser = argparse.ArgumentParser(description="Neural GLM Bandit")

    # ### Environment options ###
    parser.add_argument(
        "--glm_function",
        type=str,
        default="square",
        metavar="square|cosine",
        help="Name of the GLM function to use"
    )
    parser.add_argument(
        "--dim",
        type=int,
        default=5,
        help="Set the dimension of context in the bandit problem"
    )
    parser.add_argument(
        "--arms",
        type=int,
        default=5,
        help="Set the number of arms"
    )
    parser.add_argument(
        "--size",
        type=int,
        default=100,
        help="Set the size of the bandit problem"
    )
    parser.add_argument(
        "--noise",
        type=float,
        default=1.0,
        help="Set the noise level for the arms"
    )
    parser.add_argument(
        "--suboptimality_gap",
        type=float,
        default=0.0,
        help="Set the optimality gap for the arms"
    )

    # ### Learner options ###
    parser.add_argument(
        "--learner",
        type=str,
        default="linear",
        metavar="neural|linear",
        help="Name of learner's type to use"
    )
    parser.add_argument(
        "--strategy",
        type=str,
        default="ucb",
        metavar="ucb|ts",
        help="Set the strategy to use: TS or UCB."
    )
    parser.add_argument(
        "--diagonalize",
        # Not type=bool: bool("False") is True, so the flag could never be
        # switched off from the command line.
        type=_str2bool,
        default=False,
        metavar="True|False",
        help="Use diagonalize for the inverse of gram matrix or not"
    )
    parser.add_argument(
        # The option name keeps its historical spelling so existing command
        # lines and the ``args.lamdba`` attribute continue to work.
        "--lamdba",
        type=float,
        default=1.0,
        help="Set the lamdba parameter."
    )
    parser.add_argument(
        "--nu",
        type=float,
        default=1.0,
        help="Set the parameter nu."
    )
    parser.add_argument(
        "--learner_update",
        type=int,
        default=1,
        metavar="1|10|20|50|100",
        help="Set the update frequency of the learner"
    )
    parser.add_argument(
        "--hidden",
        type=int,
        default=50,
        metavar="20|32|50|100",
        help="Set the network hidden size"
    )
    parser.add_argument(
        "--layers",
        type=int,
        default=2,
        metavar="1|2",
        help="Set the number of hidden layers"
    )

    # ### Experiment options ###
    parser.add_argument(
        "--seed",
        type=int,
        default=1,
        help="Set the random seed for numpy/Torch"
    )
    parser.add_argument(
        "--plots",
        # Same reasoning as --diagonalize: parse the text, do not bool() it.
        type=_str2bool,
        default=False,
        metavar="True|False",
        help="Plot the results or not"
    )
    parser.add_argument(
        "--runs",
        type=int,
        default=2,
        help="Set the number of runs"
    )

    return parser.parse_args(argv)


# Starting the main function
if __name__ == '__main__':
    # Script entry point: build the environment and the learner from the
    # command-line options, let them interact for ``args.runs`` independent
    # runs, store the per-step regret of every run in an .npz file and
    # optionally plot it.

    # Parsing the input arguments
    args = parse_args()

    # Setting the random seed for numpy
    np.random.seed(args.seed)

    # GLM bandit environment
    glm_env = GLMEnv(
        args.glm_function,
        args.dim,
        args.arms,
        args.size,
        args.noise,
        args.suboptimality_gap,
        args.seed,
    )

    # GLM bandit learners. ``info_fields`` lists the settings that are joined
    # into the output file name; ``cases`` is the label passed to the plot.
    if args.learner == 'neural':
        learner = NeuralGLM(
            glm_env.dim,
            args.lamdba,
            args.nu,
            args.strategy,
            args.diagonalize,
            args.learner_update,
        )
        info_fields = (
            args.learner,
            args.strategy,
            args.diagonalize,
            args.lamdba,
            args.nu,
            args.learner_update,
        )
        cases = ['NeuralGLM-UCB'] if args.strategy == 'ucb' else ['NeuralGLM-TS']

    elif args.learner == 'linear':
        learner = LinearGLM(
            glm_env.dim,
            args.lamdba,
            args.nu,
            args.strategy,
            args.learner_update,
        )
        info_fields = (
            args.learner,
            args.strategy,
            args.lamdba,
            args.nu,
            args.learner_update,
        )
        cases = ['LinGLM-UCB'] if args.strategy == 'ucb' else ['LinGLM-TS']

    else:
        raise RuntimeError('Learner not exist')

    # Underscore-separated tag identifying the learner configuration
    learner_info = '_'.join(str(field) for field in info_fields)

    # ### Interaction between the learner and the environment ###
    # Starting the time
    start_time = time.time()

    # Running over multiple runs; one list of per-step regrets per run
    algo_regret = []
    for _ in tqdm(range(args.runs)):
        # Reset the environment and the learner before every run
        glm_env.reset()
        learner.reset()

        # Loop through the bandit problem
        regret = []
        for _ in range(glm_env.size):
            # Get the context-arms pair
            context_arms = glm_env.get_context_arms()

            # Get the learner's action
            a_t = learner.select(context_arms)

            # Get the feedback for the chosen action and update the learner
            feedback = glm_env.get_feedback(a_t)
            learner.update(feedback)

            # Record the regret of the chosen action
            regret.append(glm_env.get_regret(a_t))

        # Append the regret for the run
        algo_regret.append(regret)

    # Wall-clock time spent on the runs above
    time_taken = time.time() - start_time

    # Save the regret data
    file_location = "data/plots/glm_" + glm_env.problem_name + "_" + learner_info
    file_to_save = file_location + "_{}.npz".format(args.runs)

    # np.savez does not create missing directories; without this a fresh
    # checkout would fail here and discard the results of every run.
    output_dir = os.path.dirname(file_to_save)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    np.savez(
        file_to_save,
        regret=algo_regret,
        time_taken=time_taken,
    )

    # ### Plotting the regret ###
    if args.plots:
        # The plot is stored next to the data file, with a .png extension
        plot_to_save = file_location + "_{}".format(args.runs) + ".png"
        cumulative_regret_plotting(algo_regret, cases, plot_to_save, 'lower right')

    # Delete the object instances for the environment and the learner
    del glm_env
    del learner
        
