# !/usr/bin/env python
# coding: utf-8

# Importing python packages
import numpy as np
from tqdm import tqdm
import argparse
import os
import time
import sys

# To ignore warnings
import warnings
warnings.filterwarnings("ignore")

# Strategic contextual bandit environments
from environments.strategic import StrategicEnv, StrategicEnvGrad

# Contextual bandit learners
from learners.baselines import Random, Linear, OptGTM
from learners.cobra import COBRA, CobraLoo

# Plotting the results
from plotting_functions import cumulative_regret_error

# Getting current directory path and appending it (lowest priority) to the
# module search path.
# NOTE(review): these lines run *after* the project imports above
# (environments, learners, plotting_functions), so they cannot influence how
# those imports are resolved. If the intent is to make the project importable
# from the working directory, move these lines above those imports.
cwd = os.getcwd()
sys.path.append(cwd)


# Input arguments
def parse_args(argv=None):
    """Build the experiment's command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the original
            behaviour of reading the real command line; passing a list is
            handy for tests and notebooks.

    Returns:
        argparse.Namespace with one attribute per option defined below.
        The ``--lamdba`` value is always stored as ``args.lamdba`` (the
        historical spelling), whether it was given as ``--lamdba`` or via
        the correctly spelled alias ``--lambda``.
    """
    parser = argparse.ArgumentParser(description="Contextual Bandits")

    # Environment parameters
    parser.add_argument(
        "--environment",
        type=str,
        default="strategic",
        metavar="strategic|strategic_grad",
        help="Name of environment to use"
    )
    parser.add_argument(
        "--reward_function",
        type=str,
        default="linear",
        metavar="linear|square|cosine",
        help="Name of reward function to use"
    )
    parser.add_argument(
        "--contexts",
        type=int,
        default=1000,
        help="Number of contexts of the bandit problem"
    )
    parser.add_argument(
        "--arms",
        type=int,
        default=5,
        help="Number of arms/agents"
    )
    parser.add_argument(
        "--context_dim",
        type=int,
        default=2,
        help="Set the dimension of context feature vector"
    )
    parser.add_argument(
        "--arm_dim",
        type=int,
        default=2,
        help="Set the dimension of arm feature vector"
    )
    parser.add_argument(
        "--sigma",
        type=float,
        default=0.1,
        help="Standard deviation of Gaussian noise in reward."
    )

    # Strategic environment parameters: Using noisy estimate
    parser.add_argument(
        "--delta_max",
        type=float,
        default=1.0,
        help="Added fraction in context to get higher reward"
    )
    parser.add_argument(
        "--delta_sigma",
        type=float,
        default=0.2,
        help="Standard deviation of Gaussian for perturbation"
    )
    parser.add_argument(
        "--strategic_nature",
        type=str,
        default="static",
        metavar="static|dynamic|best",
        help="Nature of the strategic behavior"
    )

    # Strategic environment parameters: Using GD
    parser.add_argument(
        "--eta",
        type=float,
        default=0.1,
        help="Cost parameter for strategic arms"
    )
    parser.add_argument(
        "--eta_noise",
        type=float,
        default=0.0,
        help="Noise in cost parameter for strategic arms"
    )

    # Other environment parameters
    parser.add_argument(
        "--poly",
        type=str,
        default="poly_false",
        metavar="poly_true|poly_false",
        help="Use polynomial features or not"
    )
    parser.add_argument(
        "--vary_arms",
        type=str,
        default="vary_arms",
        metavar="vary_arms|fix_arms",
        help="Vary the arms across rounds or not"
    )
    parser.add_argument(
        "--strategic_arm",
        type=str,
        default="strategic_arm",
        metavar="strategic_arm|no_strategic_arm",
        help="Use strategic arms or not"
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=1,
        help="Set the random seed for numpy/Torch"
    )

    # Learner parameters
    parser.add_argument(
        "--learner",
        type=str,
        default="cobra",
        # Matches the learner registry keys used by the script
        # (the old metavar advertised a non-existent 'uniform').
        metavar="random|linear|optgtm|cobra|cobraloo",
        help="Name of the bandit learner to use"
    )
    parser.add_argument(
        "--strategy",
        type=str,
        default="ucb",
        metavar="ucb|ts",
        help="Set the strategy to use: TS or UCB."
    )
    parser.add_argument(
        "--mechanism",
        type=str,
        default="loom",
        metavar="self|loom",
        help="Name of the mechanism to use"
    )
    parser.add_argument(
        # '--lamdba' is the historical (misspelled) flag and is kept for
        # compatibility; '--lambda' is accepted as an alias. Both store
        # into args.lamdba, which is what the rest of the script reads.
        "--lamdba", "--lambda",
        dest="lamdba",
        type=float,
        default=0.01,
        help="Set the lamdba parameter."
    )
    parser.add_argument(
        "--delta",
        type=float,
        default=0.05,
        help="Set the failure probability."
    )
    parser.add_argument(
        "--max_sigma",
        type=float,
        default=0.1,
        help="Set the max noise level."
    )

    # Experiment parameters
    parser.add_argument(
        "--runs",
        type=int,
        default=2,
        help="Set the number of runs"
    )
    parser.add_argument(
        "--plot",
        type=str,
        default="no_plot",
        metavar="plot|no_plot",
        help="Plot the results or not"
    )

    return parser.parse_args(argv)


# Starting the main function
if __name__ == '__main__':
    # Parse the command-line arguments
    args = parse_args()

    # Seed numpy's global RNG so runs are reproducible
    np.random.seed(args.seed)

    # Environments, keyed by the --environment value
    environments = {
        'strategic'         :   StrategicEnv,
        'strategic_grad'    :   StrategicEnvGrad
    }

    # Learners, keyed by the --learner value.
    # 'cobraloo' is the spelling advertised by --learner's help text; the
    # previously used misspelled key 'cobarloo' is kept as an alias so that
    # existing invocations keep working.
    learners = {
        'random'    :   Random,
        'linear'    :   Linear,
        'optgtm'    :   OptGTM,
        'cobra'     :   COBRA,
        'cobraloo'  :   CobraLoo,
        'cobarloo'  :   CobraLoo,
    }

    # Validate both choices up front, before any environment set-up work
    if args.environment not in environments:
        raise RuntimeError(
            "Environment '{}' does not exist; choose one of: {}".format(
                args.environment, ", ".join(environments)
            )
        )
    if args.learner not in learners:
        raise RuntimeError(
            "Learner '{}' does not exist; choose one of: {}".format(
                args.learner, ", ".join(learners)
            )
        )

    # Strategic environment.
    # NOTE(review): both environment classes receive the same positional
    # arguments; --eta / --eta_noise are parsed but never passed here —
    # confirm StrategicEnvGrad does not need them.
    environment = environments[args.environment](
                    args.reward_function,
                    args.contexts,
                    args.arms,
                    args.context_dim,
                    args.arm_dim,
                    args.sigma,
                    args.delta_max,
                    args.delta_sigma,
                    args.strategic_nature,
                    args.poly,
                    args.vary_arms,
                    args.strategic_arm,
                    args.seed
                )

    # ### Contextual bandit learner ###
    # NOTE(review): --mechanism is only recorded in the output file name
    # below; it is not passed to the learner — confirm this is intended.
    learner = learners[args.learner](
                    environment.arms,
                    environment.dim,
                    args.strategy,
                    args.lamdba,
                    args.delta,
                    args.max_sigma
                )

    # Learner configuration, used as part of the output file name
    learner_info = '{}_{}_{}_{}_{}_{}'.format(
                    args.learner,
                    args.strategy,
                    args.mechanism,
                    args.lamdba,
                    args.delta,
                    args.max_sigma
                )

    # ### Interaction between the learner and the environment ###
    # Total wall-clock time over all runs is stored alongside the regret
    start_time = time.time()

    # One list of per-round instantaneous regrets per run
    algo_regret = []
    for _run in tqdm(range(args.runs)):
        regret = []

        # One round per context of the bandit problem
        for _t in range(environment.contexts):
            # Observe the context-arms pair for this round
            context_arms = environment.get_context_arms()

            # The learner picks an arm and reports the set of active arms
            arm, active_arms = learner.select(context_arms)

            # Observe the noisy reward of the chosen arm
            reward = environment.get_noisy_reward(arm)

            # Update the learner's model with the observed reward
            learner.update(reward)

            # Instantaneous regret of the chosen arm among the active arms
            regret.append(environment.get_active_regret(active_arms, arm))

        algo_regret.append(regret)

        # Reset both sides so the next run starts from scratch
        environment.reset()
        learner.reset()

    # Save the regret data; make sure the output directory exists first,
    # otherwise np.savez fails with FileNotFoundError on a fresh checkout.
    data_dir = "data/plots/"
    os.makedirs(data_dir, exist_ok=True)
    file_base_name = environment.problem_name + "_" + learner_info
    file_location = data_dir + file_base_name
    file_to_save = file_location + "_{}.npz".format(args.runs)
    np.savez(file_to_save,
             algo_regret = algo_regret,
             time_taken = time.time() - start_time
            )

    # ### Plotting the regret ###
    if args.plot == "plot":
        import matplotlib.pyplot as plt

        # Set the default font size before any artists are created
        plt.rc('font', size=12)

        # Colors and marker/line shapes
        colors = list("rgbcmkyrb")
        shape = ['--^', '--v', '--*', '--H', '--d', '--+', '--v', '--^']

        # Plot the average cumulative regret with error bars
        horizon, batched_regret, error = cumulative_regret_error(np.array(algo_regret))
        plt.plot(horizon, batched_regret, colors[0] + shape[0], label=args.learner)
        plt.errorbar(horizon, batched_regret, error, color=colors[0])

        # NOTE(review): the figure is written to the current working
        # directory, not next to the .npz under data/plots/ — confirm.
        plot_location = file_base_name + "_{}".format(args.runs)
        plot_to_save = plot_location + "_regret.png"

        # Plot details
        plt.legend(loc="upper right", numpoints=1)
        plt.xlabel("Rounds", fontsize=20)
        plt.ylabel("Cumulative Regret", fontsize=20)
        plt.savefig(plot_to_save, bbox_inches='tight', dpi=600)
        plt.close()

    # Delete the object instances for the environment and the learner
    del environment
    del learner
        
