#!/usr/bin/env python
# coding: utf-8

# Importing python packages
import numpy as np
from tqdm import tqdm
import argparse
import os
import time
import sys

# To ignore warnings
import warnings
warnings.filterwarnings("ignore")

# Dueling bandit environments
from env import CombDuelingEnv

# Dueling bandit learners
# from combinatorial_learners import LinearConfDB, NeuralInitDB, NeuralDB, RandomSearch
from learner import LinearConfDB, NeuralInitDB, NeuralDB, RandomSearch




# Plotting the results
from plotting_function import cumulative_regret_plotting

# Append the current working directory to the module search path.
# NOTE(review): this runs after the project imports above (env, learner,
# plotting_function), so it cannot influence how those modules were resolved;
# it only affects imports performed later at runtime.
cwd = os.getcwd()
sys.path.append(cwd)


# Input arguments
def parse_args(argv=None):
    """Parse the command-line options of the dueling bandit experiment.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is what the script relies on.

    Returns:
        The ``argparse.Namespace`` holding every option defined below.
    """
    parser = argparse.ArgumentParser(description="Neural Dueling Bandit")
    parser.add_argument(
        "--db_value_function",
        type=str,
        default="square",
        metavar="square|cosine",
        help="Name of the value function used by the environment"
    )
    parser.add_argument(
        "--reward_function",
        type=str,
        default="add",
        metavar="add",
        help="Name of the reward function used by the environment"
    )
    parser.add_argument(
        "--dim",
        type=int,
        default=5,
        help="Set the dimension of context in the bandit problem"
    )
    parser.add_argument(
        "--total_arms",
        type=int,
        default=5,
        help="Set the number of arms"
    )
    parser.add_argument(
        "--super_arms",
        type=int,
        default=2,
        # NOTE(review): wording inferred from the script passing this value as
        # `each_round_arms` to the learners - confirm against the environment.
        help="Set the size of a super arm (arms played in each round)"
    )
    parser.add_argument(
        "--size",
        type=int,
        default=100,
        help="Set the size of the bandit problem"
    )
    parser.add_argument(
        "--noise",
        type=float,
        default=1.0,
        help="Set the noise level for the arms"
    )
    parser.add_argument(
        "--suboptimality_gap",
        type=float,
        default=0.0,
        help="Set the optimality gap for the arms"
    )
    parser.add_argument(
        "--learner",
        type=str,
        default="linear",
        # The script body accepts all four of these learner names.
        metavar="neural|linear|random|neuralinit",
        help="Name of learner's type to use"
    )
    parser.add_argument(
        "--strategy",
        type=str,
        default="ucb",
        metavar="ucb|ts",
        help="Set the strategy to use: TS or UCB."
    )
    parser.add_argument(
        "--diagonalize",
        # `type=bool` would turn every non-empty string (even "False") into
        # True, so an explicit converter is used instead. A bare
        # `--diagonalize` (no value) means True.
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
        metavar="True|False",
        help="Use diagonalize for the inverse of gram matrix or not"
    )
    parser.add_argument(
        "--lamdba",
        type=float,
        default=1.0,
        help="Set the lamdba parameter."
    )
    parser.add_argument(
        "--nu",
        type=float,
        default=1.0,
        help="Set the parameter nu."
    )
    parser.add_argument(
        "--learner_update",
        type=int,
        default=20,
        metavar="10|20|50|100",
        help="Set the update frequency of the learner"
    )
    parser.add_argument(
        "--hidden",
        type=int,
        default=50,
        metavar="32|100",
        help="Set the network hidden size"
    )
    parser.add_argument(
        "--layers",
        type=int,
        default=1,
        metavar="2|1",
        help="Set the number of hidden layers"
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=1,
        help="Set the random seed for numpy/Torch"
    )
    parser.add_argument(
        "--plots",
        # Same converter as --diagonalize so that `--plots False` really
        # disables plotting.
        type=_str2bool,
        nargs="?",
        const=True,
        default=True,
        metavar="True|False",
        help="Plot the results or not"
    )
    parser.add_argument(
        "--sample_superarm",
        type=int,
        default=50,
        help="Set the number of samples of superarm"
    )
    parser.add_argument(
        "--runs",
        type=int,
        default=2,
        help="Set the number of runs"
    )

    return parser.parse_args(argv)


def _str2bool(value):
    """Convert a command-line token to a bool.

    Accepts true/false, t/f, yes/no, y/n and 1/0 in any letter case. An empty
    string maps to False, matching what ``bool("")`` returned before.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    if token in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (True|False), got {!r}".format(value)
    )


# Experiment driver: build the environment and the learner, run the
# interaction loop, then save (and optionally plot) the regret.
def _build_learner(args, db_env):
    """Create the learner selected by ``args.learner``.

    Args:
        args: Parsed command-line options (see ``parse_args``).
        db_env: The environment; its ``dim`` and ``super_arms`` attributes are
            forwarded to the learner constructors.

    Returns:
        A ``(learner, learner_info, cases)`` tuple where ``learner_info`` is
        the underscore-joined configuration string used in output file names
        and ``cases`` is the one-element list of plot legend labels.

    Raises:
        RuntimeError: if ``args.learner`` is not one of ``neural``,
            ``linear``, ``random`` or ``neuralinit``.
    """
    # NOTE(review): args.layers is never forwarded, and args.hidden is only
    # passed to NeuralDB - confirm whether NeuralInitDB should receive it too.
    if args.learner == 'neural':
        learner = NeuralDB(input_dim=db_env.dim, each_round_arms=db_env.super_arms,
                           sample_superarm=args.sample_superarm, lamdba=args.lamdba, nu=args.nu,
                           strategy=args.strategy, diagonalize=args.diagonalize,
                           learner_update=args.learner_update, hidden_size=args.hidden)
        learner_info = '{}_{}_{}_{}_{}_{}_{}_{}'.format(
            args.learner,
            args.strategy,
            args.super_arms,
            args.diagonalize,
            args.lamdba,
            args.nu,
            args.learner_update,
            args.sample_superarm
        )
        # Any strategy other than 'ucb' is labelled as TS.
        cases = ['NeuralCDB-UCB'] if args.strategy == 'ucb' else ['NeuralCDB-TS']

    elif args.learner == 'linear':
        learner = LinearConfDB(db_env.dim, db_env.super_arms, args.lamdba, args.nu, args.strategy,
                               learner_update=args.learner_update)
        learner_info = '{}_{}_{}_{}_{}_{}'.format(
            args.learner,
            args.strategy,
            args.super_arms,
            args.lamdba,
            args.nu,
            args.learner_update
        )
        cases = ['LinCDB-UCB'] if args.strategy == 'ucb' else ['LinCDB-TS']

    elif args.learner == "random":
        learner = RandomSearch(db_env.dim, db_env.super_arms)
        learner_info = '{}_{}'.format(
            args.learner,
            args.super_arms,
        )
        cases = ['Random-Search']

    elif args.learner == 'neuralinit':
        learner = NeuralInitDB(input_dim=db_env.dim, each_round_arms=db_env.super_arms,
                               sample_superarm=args.sample_superarm, lamdba=args.lamdba, nu=args.nu,
                               strategy=args.strategy, diagonalize=args.diagonalize,
                               learner_update=args.learner_update)
        learner_info = '{}_{}_{}_{}_{}_{}_{}_{}'.format(
            args.learner,
            args.strategy,
            args.super_arms,
            args.diagonalize,
            args.lamdba,
            args.nu,
            args.learner_update,
            args.sample_superarm
        )
        cases = ['NeuralInitDB-UCB'] if args.strategy == 'ucb' else ['NeuralInitDB-TS']

    else:
        raise RuntimeError('Learner not exist')

    return learner, learner_info, cases


def _run_experiment(db_env, learner, runs):
    """Run ``runs`` independent learner/environment interactions.

    Both the environment and the learner are reset at the start of every run,
    and each run lasts ``db_env.size`` rounds.

    Returns:
        ``(algo_average_regret, algo_weak_regret)``: two lists holding, for
        each run, the list of per-round average and weak regrets.
    """
    algo_average_regret = []
    algo_weak_regret = []
    for _ in tqdm(range(runs)):
        db_env.reset()
        learner.reset()

        average_regret = []
        weak_regret = []
        for t in range(db_env.size):
            # Get the context-arms pair and the learner's two arm sets
            context_arms = db_env.get_context_arms()
            arm_set1, arm_set2 = learner.select(context_arms)

            # Preference feedback for every pair of aligned arms
            feedback_list = [
                db_env.get_feedback(arm1, arm2)
                for arm1, arm2 in zip(arm_set1, arm_set2)
            ]

            # The learner is only updated when the two arm sets differ.
            # NOTE(review): assumes the arm sets are plain sequences
            # (lists/tuples); numpy arrays would make this comparison
            # ambiguous - TODO confirm against learner.select.
            if arm_set1 != arm_set2:
                learner.update(feedback_list)

            # Regret of this round
            rt_avg, rt_weak = db_env.get_regret(arm_set1, arm_set2)
            average_regret.append(rt_avg)
            weak_regret.append(rt_weak)

            # Same text as print('t', t, rt_avg, rt_weak), but routed through
            # tqdm.write so the line does not overlap the progress bar.
            tqdm.write(' '.join(str(item) for item in ('t', t, rt_avg, rt_weak)))

        algo_average_regret.append(average_regret)
        algo_weak_regret.append(weak_regret)

    return algo_average_regret, algo_weak_regret


def main():
    """Run the experiment described by the command-line arguments."""
    # Parsing the input arguments
    args = parse_args()

    # Setting the random seed for numpy
    # NOTE(review): the --seed help text mentions Torch, but only numpy is
    # seeded here; the seed is also handed to the environment below.
    np.random.seed(args.seed)

    # Dueling bandit environment
    db_env = CombDuelingEnv(args.db_value_function,
                            args.reward_function,
                            args.dim,
                            args.total_arms,
                            args.super_arms,
                            args.size,
                            args.noise,
                            args.suboptimality_gap,
                            args.seed
                            )

    # Dueling bandit learner (built after the environment, as it reads
    # db_env.dim and db_env.super_arms)
    learner, learner_info, cases = _build_learner(args, db_env)

    # ### Interaction between the learner and the environment ###
    print('start')
    start_time = time.time()
    algo_average_regret, algo_weak_regret = _run_experiment(db_env, learner, args.runs)

    # Save the regret data
    file_location = "data/plots/" + db_env.problem_name + "_" + learner_info
    file_to_save = file_location + "_{}.npz".format(args.runs)
    os.makedirs(os.path.dirname(file_to_save), exist_ok=True)

    np.savez(file_to_save,
             average_regret=algo_average_regret,
             weak_regret=algo_weak_regret,
             time_taken=time.time() - start_time
             )

    # ### Plotting the regret ###
    if args.plots:
        # Plot location
        plot_location = file_location + "_{}".format(args.runs)

        # Average regret plotting
        plot_to_save = plot_location + "_average.png"
        cumulative_regret_plotting(algo_average_regret, cases, plot_to_save, 'lower right')

        # Weak regret plotting
        plot_to_save = plot_location + "_weak.png"
        cumulative_regret_plotting(algo_weak_regret, cases, plot_to_save, 'lower right')

    # The environment and the learner are locals of this function, so they are
    # released on return; no explicit `del` is needed.


if __name__ == '__main__':
    main()

