 # !/usr/bin/env python
# coding: utf-8

# Importing python packages
import os

import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as ss


# ### Plotting ###
# Average plotting
def average_plotting(data, cases, file_name, plot_location, y_label):
    """Plot one mean curve per case with a shaded error band and save the figure.

    Args:
        data: Array-like of numeric results.
            * 2-D input: averaged over its first axis (rows); the resulting
              curve is drawn once for every entry of ``cases`` (this is the
              behaviour the function always had for 2-D input).
            * 3-D input: ``data[c]`` is the 2-D block belonging to
              ``cases[c]`` and is averaged over its own first axis.
        cases: Legend labels, one per curve.
        file_name: Path handed to ``plt.savefig``.
        plot_location: Legend location handed to ``plt.legend(loc=...)``.
        y_label: Label for the y-axis.

    Raises:
        ValueError: If ``data`` is neither 2-D nor 3-D.
    """
    colors = list("gbcmrykb")

    data = np.asarray(data)
    if data.ndim not in (2, 3):
        raise ValueError(
            "data must be 2-D or 3-D, got {} dimension(s)".format(data.ndim))
    per_case = data.ndim == 3

    # Setting up x-axis: one tick per entry along the last axis
    x_axis = np.arange(1, data.shape[-1] + 1)

    for c, case in enumerate(cases):
        series = data[c] if per_case else data

        # Computing mean and the half-width of the band (1.96 standard
        # errors).  The number of averaged rows is taken from the data
        # itself rather than assumed to be 10.
        runs = series.shape[0]
        mean = np.mean(series, axis=0)
        std_err = 1.96 * (np.std(series, axis=0) / np.sqrt(runs))

        # Cycle through the palette so more cases than colors cannot
        # raise an IndexError.
        color = colors[c % len(colors)]

        # Plotting
        plt.plot(x_axis, mean, color, label=case)
        plt.fill_between(x_axis, (mean - std_err), (mean + std_err),
                         color=color, alpha=.1)

    # Plot details
    plt.rc('font', size=10)                     # controls default text sizes
    plt.legend(loc=plot_location, numpoints=1)  # Location of the legend
    plt.xlabel("Rounds", fontsize=15)
    plt.ylabel(y_label, fontsize=15)

    # Saving plot
    plt.savefig(file_name, bbox_inches='tight', dpi=600)
    plt.close()


# Getting Average regret and Confidence interval
def cumulative_regret_error(regret):
    """Return the mean cumulative regret at batch checkpoints with an error bar.

    The horizon is cut into roughly 20 equal batches (at least one sample
    per batch).  For every run the cumulative sum of its regret is read off
    at the end of each batch, plus once more at the final sample when the
    number of samples is not a multiple of the batch size.  A leading
    checkpoint at time 0 with cumulative regret 0 is always included.

    Args:
        regret: 2-D array-like; ``regret[r][s]`` is the regret of run ``r``
            at sample ``s``.

    Returns:
        A tuple ``(time_horizon, regret, conf_regret)`` of equal length:
            time_horizon: list of int checkpoints, starting at 0 and ending
                at the number of samples.
            regret: ``numpy.ndarray`` with the cumulative regret at each
                checkpoint, averaged over runs.
            conf_regret: list with, per checkpoint, the Student-t 0.95
                quantile (``runs - 1`` degrees of freedom) times the
                standard error of the mean across runs.

    Raises:
        ValueError: If ``regret`` is not 2-D.
    """
    regret = np.asarray(regret, dtype=float)
    if regret.ndim != 2:
        raise ValueError(
            "regret must be 2-D (runs x samples), got {} dimension(s)".format(
                regret.ndim))
    runs, samples = regret.shape

    # Integer batch size.  True division produced a fractional batch whenever
    # samples was not a multiple of 20, so the per-sample counter never
    # matched it and the outputs ended up with different lengths.
    batch = max(1, samples // 20)

    # Time horizon: 0, batch, 2*batch, ... and always the final sample
    time_horizon = list(range(0, samples + 1, batch))
    if time_horizon[-1] != samples:
        time_horizon.append(samples)

    # Cumulative regret of every run, sampled at the checkpoints.
    # Checkpoint t covers samples 0..t-1, hence the "- 1" on the index.
    checkpoints = np.asarray(time_horizon[1:], dtype=int)
    batched_regret = np.zeros((runs, len(time_horizon)))
    if checkpoints.size:
        batched_regret[:, 1:] = np.cumsum(regret, axis=1)[:, checkpoints - 1]

    mean_regret = batched_regret.mean(axis=0)

    # Confidence interval (one-sided 0.95 t-quantile times standard error)
    freedom_degree = runs - 1
    t_quantile = ss.t.ppf(0.95, freedom_degree)
    conf_regret = list(t_quantile * ss.sem(batched_regret, axis=0))
    return time_horizon, mean_regret, conf_regret


# # ### Plotting Regret ###
# Plot and problem details
# Each entry of plot_types is used as a key into the loaded .npz results and
# as a suffix of the saved figure's file name.
plot_types = ['rmse', 'subg', 'mae', 'average_regret', 'weak_regret']
file_location = "data/plots/"   # directory the .npz result files are read from
# Problem names noted for this script: "linear", "levy", "cosine", "square", "ackley"
# Previously selected set: problems = ["cosine", "square"]
problems = ["square20", "cosine10"]

# The four values below are joined into the problem part of every file name.
contexts = 1000
arms = 10
dim = 5
seed = 1


# Learner details
lamdba = 1.0
nu = 1.0
rounds = 200
learner_udpate = 10
runs = 20
# Learner part of every file name: the five settings joined by underscores.
learner_info = "_".join(
    str(setting) for setting in (lamdba, nu, rounds, learner_udpate, runs)
)

# Algorithms to compare (legend labels); comment entries in or out to choose
# which result files are loaded and plotted.
algos = [
    'Random',
    'AE-Borda',
    'APO',
    'AE-Borda (NN)',
    'Neural-ADB (APO)',
    'Neural-ADB (UCB)',
    'Neural-ADB (TS)',
    # 'Neural-ADBIG (UCB)',
    # 'Neural-ADBIG (TS)',
    # 'Neural-ADB (NG + APO)',
    # 'Neural-ADB (NG + UCB)',
    # 'Neural-ADBIG (NG + UCB)',
    # 'Neural-ADB (NG + TS)',
    # 'Neural-ADBIG (NG + TS)'
]

# y-axis label shown for each plot type
y_axis_labels = dict(
    rmse='RMSE',
    subg='Worst Sub-optimality Gap',
    mae='Average Sub-optimality Gap',
    average_regret='Average Regret',
    weak_regret='Weak Regret',
)

# Plotting the average regret
# Per-algorithm line colour and line/marker format, indexed by position in algos
colors = ['r', 'g', 'b', 'c', 'm', 'k', 'y', 'r', 'b']
shape = ['--^', '--v', '--*', '--H', '--d', '--+', '--v', '--^']
x_axis = list(range(1, rounds + 1))   # 1, 2, ..., rounds

# Fetching data from the files
# Maps each algorithm's legend label to the tag that identifies it inside the
# name of its results file: <problem_instance><tag><learner_info>.npz
algo_file_tags = {
    'Random': "_Random_ucb_",
    'AE-Borda': "_AEBorda_ucb_",
    'APO': "_APO_ucb_",
    'AE-Borda (NN)': "_AEBordaNNGrad_ucb_",
    'Neural-ADB (APO)': "_NeuralAPOGrad_ucb_",
    'Neural-ADB (UCB)': "_NeuralADBGrad_ucb_",
    'Neural-ADBIG (UCB)': "_NeuralADBIGGrad_ucb_",
    'Neural-ADB (TS)': "_NeuralADBGrad_ts_",
    'Neural-ADBIG (TS)': "_NeuralADBIGGrad_ts_",
    'Neural-ADB (NG + APO)': "_NeuralAPO_ucb_",
    'Neural-ADB (NG + UCB)': "_NeuralADB_ucb_",
    'Neural-ADBIG (NG + UCB)': "_NeuralADBIG_ucb_",
    'Neural-ADB (NG + TS)': "_NeuralADB_ts_",
    'Neural-ADBIG (NG + TS)': "_NeuralADBIG_ts_",
}

# One figure per (problem, plot type), with one curve per algorithm.
for problem in problems:
    problem_instance = problem + "_{}_{}_{}_{}".format(contexts, arms, dim, seed)
    for plot_type in plot_types:
        for a, algo in enumerate(algos):
            print("Plotting: ", algo)
            if algo not in algo_file_tags:
                raise RuntimeError('Learner not exist')
            data_file = problem_instance + algo_file_tags[algo] + learner_info

            # Load data.  np.load returns an NpzFile that keeps the archive
            # open; the with-block closes it once the array has been read.
            with np.load(file_location + data_file + ".npz") as all_data:
                plot_data = np.array(all_data[plot_type])

            # Cycle through the style lists so enabling more algorithms than
            # there are colors/markers cannot raise an IndexError.
            color = colors[a % len(colors)]
            marker = shape[a % len(shape)]

            if plot_type in ('average_regret', 'weak_regret'):
                # Regret plots: batched cumulative regret with error bars
                horizon, batched_regret, error = cumulative_regret_error(plot_data)
                plt.errorbar(horizon, batched_regret, error, color=color)
                plt.plot(horizon, batched_regret, color + marker, label=algo)

            else:
                # Computing mean and standard error (1.96 standard errors,
                # using the configured number of runs)
                mean = np.mean(plot_data, axis=0)
                std_err = 1.96 * (np.std(plot_data, axis=0) / np.sqrt(runs))

                # Plotting
                plt.plot(x_axis, mean, color, label=algo)
                plt.fill_between(x_axis, (mean - std_err), (mean + std_err),
                                 color=color, alpha=.1)

        # Output file.  Create the target directory first so savefig does not
        # fail when it is missing.
        output_dir = "results/"
        os.makedirs(output_dir, exist_ok=True)
        plot_file = problem_instance + "_compare{}_".format(len(algos)) + learner_info
        file_to_save = output_dir + plot_file + "_{}.png".format(plot_type)

        # Plot details
        plt.rc('font', size=12)                     # controls default text sizes
        plt.legend(loc="upper right", numpoints=1)  # Location of the legend
        plt.xlabel("Rounds", fontsize=20)
        y_label = y_axis_labels[plot_type]
        plt.ylabel(y_label, fontsize=20)
        plt.title("Comparison of Algorithms")
        # plt.axis([0, samples, -20, samples])

        # if plot_type == 'subg':
        #     plt.yscale('log')

        # Saving plot
        plt.savefig(file_to_save, bbox_inches='tight', dpi=600)
        plt.close()

