#!/usr/bin/env python
# coding: utf-8

# Importing python packages
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as ss
import argparse
import os
import sys


# Make modules located in the current working directory importable.
# NOTE(review): this appends os.getcwd() (the directory the script is
# launched from), not the directory containing this file. Nothing visible in
# this file imports a local module, so this may be vestigial; confirm before
# removing.
cwd = os.getcwd()
sys.path.append(cwd)


# ### Plotting ###
# Cumulative regret averaged over runs, with its confidence interval
def cumulative_regret_error(regret, num_batches=20, quantile=0.95):
    """Summarise per-round regret of several runs as cumulative regret at checkpoints.

    Parameters
    ----------
    regret : array_like, shape (runs, samples)
        Per-round regret, one row per independent run.
    num_batches : int, optional
        Approximate number of checkpoints along the horizon. Checkpoints are
        spaced ``max(1, samples // num_batches)`` rounds apart. If that
        spacing does not divide ``samples``, one extra checkpoint at
        ``samples`` is appended.
    quantile : float, optional
        Student-t quantile used for the interval half-width. The default
        0.95 yields a one-sided 95% / two-sided 90% interval.

    Returns
    -------
    time_horizon : list of int
        Checkpoint rounds, starting at 0 and ending at ``samples``.
    regret : numpy.ndarray
        Cumulative regret at each checkpoint, averaged over runs.
    conf_regret : list
        Interval half-width at each checkpoint:
        ``t.ppf(quantile, runs - 1) * SEM``. The value is NaN when only one
        run is supplied, because there is then no variance estimate.
    """
    data = np.asarray(regret, dtype=float)
    runs, samples = data.shape

    # Integer spacing is essential. A float step (samples / 20) never equals
    # an integer round counter when samples is not a multiple of 20, which
    # left the per-run series shorter than the time horizon.
    step = max(1, samples // num_batches)
    time_horizon = list(range(0, samples + 1, step))
    if time_horizon[-1] != samples:
        time_horizon.append(samples)

    # Checkpoint t covers rounds 0..t-1, i.e. column t-1 of the running sum.
    # The checkpoint at round 0 is 0 by definition.
    cumulative = np.cumsum(data, axis=1)
    checkpoints = np.zeros((runs, len(time_horizon)))
    column_idx = np.asarray(time_horizon[1:], dtype=int) - 1
    checkpoints[:, 1:] = cumulative[:, column_idx]

    mean_regret = checkpoints.mean(axis=0)

    # Half-width of the interval at every checkpoint (sem uses ddof=1).
    t_value = ss.t.ppf(quantile, runs - 1)
    conf_regret = list(t_value * ss.sem(checkpoints, axis=0))
    return time_horizon, mean_regret, conf_regret


# Input arguments
def parse_args():
    """Parse the plotting script's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace with ``filename_prefix`` (str),
        ``filename_postfix`` (str) and ``y_limit`` (int).
    """
    # (flag, add_argument keyword options), registered in this order.
    option_specs = (
        ("--filename_prefix", dict(
            type=str,
            default="strategic_square_1000_5_2_2_0.1_1.0_0.2_static_poly_true_vary_arms_strategic_arm_101",
        )),
        ("--filename_postfix", dict(
            type=str,
            default="loom_0.01_0.05_0.1_50",
        )),
        ("--y_limit", dict(
            type=int,
            default=4000,
            help="Y-axis limit for the plots",
        )),
    )

    cli = argparse.ArgumentParser(description="Plotting the results")
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()


# Starting the main function
if __name__ == '__main__':
    # Parse the command-line options.
    args = parse_args()

    # Y-axis limit.
    # NOTE(review): this only takes effect if the plt.ylim call below is
    # re-enabled; otherwise the y-axis is auto-scaled.
    y_limit = args.y_limit

    print("File Prefix: {}".format(args.filename_prefix))
    print("File Postfix: {}".format(args.filename_postfix))

    # ### Plotting Regret ###
    # Each plot type is the key of an array stored in the .npz result files.
    plot_types = ['algo_regret']
    file_location = "data/plots/"
    output_dir = "results"

    # Algorithms to compare (plotted in this order)
    algos = [
        # 'Random',
        'LinUCB',
        'LinTS',
        'OptGTM',
        'COBRA (UCB)',
        'COBRA (TS)'
    ]

    # Algorithm name -> tag inserted between the prefix and postfix of its result file
    algos_dict = {
        'Random'        :   '_random_ucb_',
        'LinUCB'        :   '_linear_ucb_',
        'LinTS'         :   '_linear_ts_',
        'OptGTM'        :   '_optgtm_ucb_',
        'COBRA (UCB)'   :   '_cobra_ucb_',
        'COBRA (TS)'    :   '_cobra_ts_'
    }

    # Per-algorithm line colour and marker style, indexed by position in `algos`
    colors = list("rgbcmkyrb")
    shape = ['--^', '--v', '--*', '--H', '--d', '--+', '--v', '--^']

    # savefig does not create missing directories.
    os.makedirs(output_dir, exist_ok=True)

    for plot_type in plot_types:
        for a, algo in enumerate(algos):
            data_file = file_location + args.filename_prefix + algos_dict[algo] + args.filename_postfix

            # NpzFile keeps the archive open until closed; indexing it loads the
            # array into memory, so it is safe to use after the `with` block.
            with np.load(data_file + ".npz") as all_data:
                plot_data = np.array(all_data[plot_type])

            # Error bars plus a marked line for the mean cumulative regret
            horizon, batched_regret, error = cumulative_regret_error(plot_data)
            plt.errorbar(horizon, batched_regret, error, color=colors[a])
            plt.plot(horizon, batched_regret, colors[a] + shape[a], label=algo)

        # Output file: results/<prefix>_compare<N>_<postfix>_<plot_type>.png
        plot_file = args.filename_prefix + "_compare{}_".format(len(algos)) + args.filename_postfix
        file_to_save = os.path.join(output_dir, plot_file + "_{}.png".format(plot_type))

        # Plot details
        plt.rc('font', size=12)                     # controls default text sizes
        plt.legend(loc="upper left", numpoints=1)   # Location of the legend
        plt.xlabel("Rounds", fontsize=20)
        y_label = "Cumulative Regret" if plot_type == 'algo_regret' else "Regret"
        plt.ylabel(y_label, fontsize=20)

        # Optional tweaks (disabled):
        # plt.ylim(0, y_limit)
        # plt.title("Comparison of Algorithms")
        # plt.yscale('log')

        # Save and release the figure so the next plot type starts fresh
        plt.savefig(file_to_save, bbox_inches='tight', dpi=600)
        plt.close()