#!/usr/bin/env python
# coding: utf-8

# Importing python packages
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as ss


# ### Plotting ###
# Getting Average regret and Confidence interval
def cumulative_regret_error(regret):
    """Summarise cumulative regret over several runs at evenly spaced checkpoints.

    Parameters
    ----------
    regret : array-like, shape (runs, samples)
        Per-round regret; ``regret[r][s]`` is the regret of run ``r`` at
        round ``s``.

    Returns
    -------
    time_horizon : list of int
        Checkpoint positions ``[0, batch, 2 * batch, ..., samples]`` with
        ``batch = max(1, samples // 20)``. A final, shorter step is added
        when ``samples`` is not a multiple of ``batch``.
    regret : numpy.ndarray
        Cumulative regret at each checkpoint, averaged over the runs. The
        first entry is always 0.
    conf_regret : list of float
        Error-bar half-width at each checkpoint: the 0.95 quantile of
        Student's t with ``runs - 1`` degrees of freedom multiplied by the
        standard error of the mean across runs. Not defined (NaN) when there
        is a single run.

    Raises
    ------
    ValueError
        If ``regret`` is not two-dimensional.
    """
    regret = np.asarray(regret)
    if regret.ndim != 2:
        raise ValueError("regret must be a 2-D array of shape (runs, samples)")
    runs, samples = regret.shape

    # The batch size must be a whole number of rounds. True division gives a
    # fractional size whenever `samples` is not a multiple of 20, and a zero
    # size for an empty horizon; neither can ever be reached by a round count.
    batch = max(1, samples // 20)

    # Checkpoints: every `batch` rounds, plus the end of the horizon.
    time_horizon = list(range(0, samples + 1, batch))
    if time_horizon[-1] != samples:
        time_horizon.append(samples)

    # Cumulative regret of every run, read off at the checkpoints. Column 0
    # stays at zero (no regret before the first round); checkpoint `c` covers
    # rounds 0..c-1, hence the `- 1` when indexing the running sum.
    cumulative = np.cumsum(regret, axis=1)
    checkpoints = np.array(time_horizon[1:], dtype=int)
    batched_regret = np.zeros((runs, len(time_horizon)))
    batched_regret[:, 1:] = cumulative[:, checkpoints - 1]

    mean_regret = batched_regret.mean(axis=0)

    # Confidence interval: t quantile times the standard error across runs.
    t_quantile = ss.t.ppf(0.95, runs - 1)
    conf_regret = list(t_quantile * ss.sem(batched_regret, axis=0))
    return time_horizon, mean_regret, conf_regret


# # ### Plotting Regret ###
# Which saved results to load: regret kinds, input folder and problem instance
regret_types = ['average', 'weak']
file_location = "data/plots/"
problem = "cosine10"      # "linear", "levy", "cosine", "square", "ackley"
problem_instance = "{}_5_5_1000_0.1_0.0_1".format(problem)

# Learner settings; joined with "_" they form the tail of each data file name
lamdba = 1.0
nu = 1.0
learner_udpate = 1
runs = 20
learner_info = "_".join(map(str, (lamdba, nu, learner_udpate, runs)))

# Lambda and nu for NeuralDB (only used when neural_params is True)
neural_params = False
neural_ucb_lambda = 1.0
neural_ucb_nu = 1.0
neural_ts_lambda = 1.0
neural_ts_nu = 1.0

# Algorithms to compare, and the label each one gets in the legend
algos = ['LinDB-UCB', 'LinDB-TS', 'CoLSTIM', 'NeuralDB-UCB', 'NeuralDB-TS']
algos_in_plots = ['LinDB-UCB', 'LinDB-TS', 'CoLSTIM', 'NDB-UCB', 'NDB-TS']

# Per-algorithm line colour and line/marker format, matched by position
colors = ['g', 'b', 'c', 'm', 'r', 'y', 'k', 'b']
shape = ['--^', '--v', '--+', '--H', '--d', '--*', '--v', '--^']

# Fetching data from the files
# Middle part of the data file name for every supported learner.
data_file_infix = {
    'LinDB-UCB': "_linear_ucb_",
    'CoLSTIM': "_colstim_ucb_",
    'LinDB-TS': "_linear_ts_",
    'NeuralDB-UCB': "_neural_ucb_False_",
    'NeuralDB-TS': "_neural_ts_False_",
}
# (lambda, nu) used instead of the shared values when neural_params is True.
neural_overrides = {
    'NeuralDB-UCB': (neural_ucb_lambda, neural_ucb_nu),
    'NeuralDB-TS': (neural_ts_lambda, neural_ts_nu),
}

# Set the default text size once, before any figure exists, so that every
# figure is built under the same defaults.
plt.rc('font', size=12)

# One figure per regret type, with one curve per algorithm.
for regret_type in regret_types:
    for a, algo in enumerate(algos):
        if algo not in data_file_infix:
            raise RuntimeError('Learner not exist')

        # Build the suffix in a local variable: rebinding `learner_info` here
        # would leak one learner's hyper-parameters into every learner and
        # regret type processed afterwards.
        algo_learner_info = learner_info
        if neural_params and algo in neural_overrides:
            algo_lambda, algo_nu = neural_overrides[algo]
            algo_learner_info = "{}_{}_{}_{}".format(
                algo_lambda, algo_nu, learner_udpate, runs)
        data_file = problem_instance + data_file_infix[algo] + algo_learner_info

        # Load data; the context manager closes the archive once the array
        # has been read out of it.
        with np.load(file_location + data_file + ".npz") as all_data:
            algo_regret = all_data['{}_regret'.format(regret_type)]

        # Error bars first, then the dashed marker line that carries the
        # legend label.
        horizon, batched_regret, error = cumulative_regret_error(np.array(algo_regret))
        plt.errorbar(horizon, batched_regret, error, color=colors[a])
        plt.plot(horizon, batched_regret, colors[a] + shape[a], label=algos_in_plots[a])

    # Output file name; no directory component, so the figure is written
    # relative to the current working directory.
    file_to_save = "" + problem_instance + "_compare_{}.png".format(regret_type)

    # Plot details
    plt.legend(loc="upper left", numpoints=1)  # Location of the legend
    plt.ylabel("Cumulative Regret ({})".format(regret_type), fontsize=18)

    # Saving plot, then closing the figure so the next regret type starts
    # from an empty one.
    plt.savefig(file_to_save, bbox_inches='tight', dpi=600)
    plt.close()
    
    
    