# !/usr/bin/env python
# coding: utf-8

# Importing python packages
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as ss

import warnings
# NOTE(review): this silences *every* warning, process-wide, as soon as this
# module is imported -- including warnings raised by callers' own code.
# Consider scoping it with `warnings.catch_warnings()` around the specific
# calls that need it instead.
warnings.filterwarnings("ignore")


# ### Plotting ###
# Batched average cumulative regret and its confidence interval
# (cumulative regret error)
def cumulative_regret_error(regret, quantile=0.95, n_batches=25):
    """Batch per-run cumulative regret and compute its mean and error bars.

    The horizon of ``samples`` rounds is cut into batches of
    ``max(1, samples // n_batches)`` rounds (plus one shorter trailing batch
    when the split is uneven). For every run the cumulative regret is read
    off at each batch boundary. The runs are then averaged, and a Student-t
    error half-width is computed at every boundary.

    Args:
        regret: 2-D array-like of per-round regret, shape ``(runs, samples)``;
            row ``r`` holds run ``r``.
        quantile: Student-t quantile used for the half-width. The default
            0.95 gives a two-sided 90% interval; use 0.975 for 95%.
        n_batches: Target number of batches the horizon is split into.

    Returns:
        ``(time_horizon, mean_regret, conf_regret)``:
        ``time_horizon`` is a list of round indices from 0 to ``samples``;
        ``mean_regret`` is a numpy array of the mean cumulative regret at those
        rounds; ``conf_regret`` is a list of error half-widths (NaN when there
        is a single run, since the t-distribution then has 0 degrees of
        freedom).
    """
    regret = np.asarray(regret)
    runs, samples = regret.shape

    # Integer batch size. The former ``samples / 25`` was a float: whenever
    # ``samples`` was not a multiple of 25 the round counter never equalled
    # it, so batch boundaries were silently dropped and the returned arrays
    # no longer matched ``time_horizon``. ``max(1, ...)`` also avoids a zero
    # step (previously an infinite loop for an empty horizon).
    batch = max(1, samples // n_batches)

    # Batch boundaries: 0, batch, 2*batch, ..., plus ``samples`` if uneven.
    time_horizon = list(range(0, samples + 1, batch))
    if time_horizon[-1] != samples:
        time_horizon.append(samples)

    # Cumulative regret of each run at each boundary: boundary t is the sum
    # of the first t rounds (index t - 1 of the cumsum); boundary 0 is 0.
    cumulative = np.cumsum(regret, axis=1)
    boundary_idx = np.asarray(time_horizon[1:], dtype=int) - 1
    batched = np.zeros((runs, len(time_horizon)))
    batched[:, 1:] = cumulative[:, boundary_idx]

    mean_regret = np.mean(batched, axis=0)

    # t-interval half-width per boundary, with runs - 1 degrees of freedom.
    conf_regret = list(ss.t.ppf(quantile, runs - 1) * ss.sem(batched, axis=0))
    return time_horizon, mean_regret, conf_regret


# Regret Plotting
def cumulative_regret_plotting(regret, cases, file_name, plot_location, y_lim=100, max_runs=1000):
    """Plot batched mean cumulative regret with error bars and save it.

    For each case, prints its label together with the final mean cumulative
    regret and final error half-width, then draws an error-bar curve.

    Args:
        regret: array-like indexed as ``regret[run, case, round]``; only the
            first ``max_runs`` runs are used.
        cases: legend labels; one curve is drawn per entry.
        file_name: output path passed to ``plt.savefig``.
        plot_location: legend location passed to ``plt.legend(loc=...)``.
        y_lim: upper y-axis limit (the lower limit is 0).
        max_runs: maximum number of runs taken from ``regret``.
    """
    colors = list("gbcmrykb")
    shape = ['--^', '--v', '--H', '--d', '--+', '--*', '--v', '--^']
    # Convert once rather than once per case.
    regret = np.asarray(regret)

    for c, case in enumerate(cases):
        horizon, batched_regret, error = cumulative_regret_error(regret[:max_runs, c])
        print(case, batched_regret[-1], error[-1])
        # Cycle styles so more than eight cases no longer raise IndexError.
        color = colors[c % len(colors)]
        marker = shape[c % len(shape)]
        plt.errorbar(horizon, batched_regret, error, color=color)
        plt.plot(horizon, batched_regret, color + marker, label=case)

    # Plot details
    plt.rc('font', size=12)                     # controls default text sizes
    plt.legend(loc=plot_location, numpoints=1)  # Location of the legend
    plt.xlabel("Rounds", fontsize=20)
    plt.ylabel("Cumulative Regret", fontsize=20)
    plt.ylim(0, y_lim)

    # Save and release the figure so the next plot starts clean.
    plt.savefig(file_name, bbox_inches='tight', dpi=600)
    plt.close()
    

# Regret plotting without a y-axis limit
def cumulative_regret_plotting_no_ylimit(regret, cases, file_name, plot_location, max_runs=1000):
    """Plot batched mean cumulative regret with error bars, auto-scaled y-axis.

    Same as ``cumulative_regret_plotting`` except that nothing is printed and
    no y-axis limit is applied.

    Args:
        regret: array-like indexed as ``regret[run, case, round]``; only the
            first ``max_runs`` runs are used.
        cases: legend labels; one curve is drawn per entry.
        file_name: output path passed to ``plt.savefig``.
        plot_location: legend location passed to ``plt.legend(loc=...)``.
        max_runs: maximum number of runs taken from ``regret``.
    """
    colors = list("gbcmrykb")
    shape = ['--^', '--v', '--H', '--d', '--+', '--*', '--v', '--^']
    # Convert once rather than once per case.
    regret = np.asarray(regret)

    for c, case in enumerate(cases):
        horizon, batched_regret, error = cumulative_regret_error(regret[:max_runs, c])
        # Cycle styles so more than eight cases no longer raise IndexError.
        color = colors[c % len(colors)]
        marker = shape[c % len(shape)]
        plt.errorbar(horizon, batched_regret, error, color=color)
        plt.plot(horizon, batched_regret, color + marker, label=case)

    # Plot details
    plt.rc('font', size=12)                     # controls default text sizes
    plt.legend(loc=plot_location, numpoints=1)  # Location of the legend
    plt.xlabel("Rounds", fontsize=20)
    plt.ylabel("Cumulative Regret", fontsize=20)

    # Save and release the figure so the next plot starts clean.
    plt.savefig(file_name, bbox_inches='tight', dpi=600)
    plt.close()
   

# Regret Plotting with reverse storage
def cumulative_regret_plotting_reverse(regret, cases, file_name, plot_location, y_lim=100):
    """Plot batched mean cumulative regret for case-major regret storage.

    Unlike ``cumulative_regret_plotting``, ``regret`` is indexed case first,
    and all runs are used.

    Args:
        regret: array-like indexed as ``regret[case, run, round]``.
        cases: legend labels; one curve is drawn per entry.
        file_name: output path passed to ``plt.savefig``.
        plot_location: legend location passed to ``plt.legend(loc=...)``.
        y_lim: accepted for signature compatibility with
            ``cumulative_regret_plotting`` but currently IGNORED; the y-axis
            is auto-scaled.
    """
    colors = list("gbcmrykb")
    shape = ['--^', '--v', '--H', '--d', '--+', '--*', '--v', '--^']
    # Convert once rather than once per case.
    regret = np.asarray(regret)

    for c, case in enumerate(cases):
        horizon, batched_regret, error = cumulative_regret_error(regret[c, :])
        # Cycle styles so more than eight cases no longer raise IndexError.
        color = colors[c % len(colors)]
        marker = shape[c % len(shape)]
        plt.errorbar(horizon, batched_regret, error, color=color)
        plt.plot(horizon, batched_regret, color + marker, label=case)

    # Plot details
    plt.rc('font', size=12)                     # controls default text sizes
    plt.legend(loc=plot_location, numpoints=1)  # Location of the legend
    plt.xlabel("Rounds", fontsize=20)
    plt.ylabel("Cumulative Regret", fontsize=20)

    # Save and release the figure so the next plot starts clean.
    plt.savefig(file_name, bbox_inches='tight', dpi=600)
    plt.close()

   
# Average plotting
def average_plotting(data, cases, file_name, plot_location, runs, y_label="Cumulative Regret", error_band=False):
    """Plot the per-round mean of ``data`` for every case and save the figure.

    Args:
        data: array-like indexed as ``data[run, case, round]``; every round is
            plotted, at x positions ``1 .. rounds``.
        cases: legend labels; one dashed line is drawn per entry.
        file_name: output path passed to ``plt.savefig``.
        plot_location: legend location passed to ``plt.legend(loc=...)``.
        runs: number of runs; used only for the optional error band.
        y_label: y-axis label.
        error_band: when True, shade ``mean ± 1.96 * std / sqrt(runs)`` around
            each curve. Off by default, which preserves the original output.
    """
    colors = list("gbcmrykb")
    line_style = '--'                           # dashed line, no markers
    # Convert once rather than twice per case.
    data = np.asarray(data)

    # One x position per round, starting at 1.
    x_axis = list(range(1, len(data[0][0]) + 1))
    for c, case in enumerate(cases):
        per_run = data[:, c]                    # (runs, rounds) for this case
        mean = np.mean(per_run, axis=0)
        # Cycle colors so more than eight cases no longer raise IndexError.
        color = colors[c % len(colors)]

        plt.plot(x_axis, mean, color + line_style, label=case)
        if error_band:
            std_err = 1.96 * (np.std(per_run, axis=0) / np.sqrt(runs))
            plt.fill_between(x_axis, mean - std_err, mean + std_err, color=color, alpha=.1)

    # Plot details
    plt.rc('font', size=10)                     # controls default text sizes
    plt.legend(loc=plot_location, numpoints=1)  # Location of the legend
    plt.xlabel("Rounds", fontsize=15)
    plt.ylabel(y_label, fontsize=15)

    # Save and release the figure so the next plot starts clean.
    plt.savefig(file_name, bbox_inches='tight', dpi=600)
    plt.close()
 
 
# Average plotting with x-axis values
def average_plotting_x_axis(data, x_axis, cases, file_name, plot_location, runs, x_label="Rounds", y_label="Cumulative Regret"):
    """Plot per-case means with error bars against explicit x values, then save.

    Args:
        data: array-like indexed as ``data[case, run, point]``; the last axis
            must have ``len(x_axis)`` entries.
        x_axis: x coordinates shared by every case.
        cases: legend labels; one curve is drawn per entry.
        file_name: output path passed to ``plt.savefig``.
        plot_location: legend location passed to ``plt.legend(loc=...)``.
        runs: number of runs, used to scale the error bars
            (``1.96 * std / sqrt(runs)``, normal-approximation half-width).
        x_label: x-axis label.
        y_label: y-axis label.
    """
    colors = list("gbcmrykb")
    line_style = '--+'                          # dashed line with '+' markers
    # Convert once rather than twice per case.
    data = np.asarray(data)

    for c, case in enumerate(cases):
        per_run = data[c]                       # (runs, points) for this case
        mean = np.mean(per_run, axis=0)
        std_err = 1.96 * (np.std(per_run, axis=0) / np.sqrt(runs))
        # Cycle colors so more than eight cases no longer raise IndexError.
        color = colors[c % len(colors)]

        plt.plot(x_axis, mean, color + line_style, label=case)
        plt.errorbar(x_axis, mean, std_err, color=color)

    # Plot details
    plt.rc('font', size=10)                     # controls default text sizes
    plt.legend(loc=plot_location, numpoints=1)  # Location of the legend
    plt.xlabel(x_label, fontsize=15)
    plt.ylabel(y_label, fontsize=15)

    # Save and release the figure so the next plot starts clean.
    plt.savefig(file_name, bbox_inches='tight', dpi=600)
    plt.close()
 