#!/usr/bin/env python
# coding: utf-8

# ### README
# This code plots graphs for Experiment 2 after `experiment2.py` has been run.
# 
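# Before running, make sure the current working directory contains a folder called `output_for_plot` holding the `.pickle` files generated by `experiment2.py`. Every file in that folder is loaded with `pickle`, so keep nothing else there.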

# ## Imports and utility functions

# In[1]:


import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import pathlib

import pickle


# In[2]:


def get_rewards_for_algo(rewards_dict_for_plot, algo):
    ''' Collect the reward entries belonging to one algorithm.

    Every key of ``rewards_dict_for_plot`` is indexed at position 2 and compared
    with ``algo`` (presumably keys are tuples whose third element names the
    algorithm -- confirm against experiment2.py). The values of matching keys are
    stacked, in the dict's iteration order, into a NumPy array; an empty array is
    returned when nothing matches.
    '''
    matching = [value for key, value in rewards_dict_for_plot.items() if key[2] == algo]
    return np.array(matching)


# In[3]:


def get_regrets_for_algo(regrets_dict_for_plot, algo):
    ''' Return, as a NumPy array, the regret entries recorded for ``algo``.

    A value is selected when element 2 of its key equals ``algo``; selected
    values keep the dict's iteration order. With no match the result is an
    empty array.
    '''
    picked = (entry for key, entry in regrets_dict_for_plot.items() if key[2] == algo)
    return np.array(list(picked))


# ## Plotting - Across T

# In[4]:


# Multiplier applied to the standard error when drawing error bars:
# 1.96 corresponds to a 95% confidence interval under a normal approximation.
conf_width = 1.96


# In[5]:


# One row per plotted algorithm: (identifier matched against index 2 of the
# result-dict keys, plot colour, legend label). zip() transposes the rows into
# the three parallel lists used by the plotting code.
algos, colors, labels = map(list, zip(
    ('A', 'black', 'NonCausal_Std_TS'),
    ('B', 'steelblue', 'Std_TS'),
    ('C', 'forestgreen', 'Std_UniExp'),
    ('D', 'orange', 'TargInt_TS'),
    ('E', 'olive', 'TargInt_UniExp'),
    ('F', 'slateblue', 'TargInt_TS_UniExp'),
    ('v2', 'red', 'Unc_CCB (Ours)'),
))


# In[ ]:





# In[6]:


def _collect_regret_summaries(files, alpha):
    ''' Load result pickles matching ``alpha`` and summarise regrets per algorithm.

    Parameters
    ----------
    files : list of pathlib.Path
        Result files written by ``experiment2.py``; each is unpickled.
    alpha : float
        Only files whose ``config['alpha']`` equals this value are kept.

    Returns
    -------
    Ts : list
        ``config['train_T']`` of every kept file, in file order.
    means : dict
        Maps each algorithm in ``algos`` to a list (aligned with ``Ts``) of the
        mean over all of that algorithm's regret entries in the file.
    errors : dict
        Same layout as ``means``; each value is ``conf_width`` times the standard
        deviation of the per-row means (``mean(axis=1)``) divided by
        ``sqrt(config['train_runs'])``.
    '''
    Ts = []
    means = {algo: [] for algo in algos}
    errors = {algo: [] for algo in algos}
    for f in files:
        # NOTE: pickle.load can execute arbitrary code -- only point this at
        # result files you generated yourself.
        with open(f, 'rb') as handle:
            pickled = pickle.load(handle)

        config = pickled['config']
        if config['alpha'] != alpha:
            continue

        Ts.append(config['train_T'])
        for algo in algos:
            # Extract once and reuse for both statistics.
            regrets = get_regrets_for_algo(pickled['regrets'], algo)
            means[algo].append(regrets.mean())
            # Presumably one row per training run -- TODO confirm against experiment2.py.
            errors[algo].append(
                conf_width * np.std(regrets.mean(axis=1)) / np.sqrt(config['train_runs'])
            )
    return Ts, means, errors


def plot_across_T(alpha=1/2):
    ''' Main function to plot regrets of all algorithms as a function of T.

    Reads every file in ``./output_for_plot`` and keeps those whose
    ``config['alpha']`` equals ``alpha``. For each algorithm in ``algos``, it
    plots the mean regret against ``config['train_T']``, with error bars of
    ``conf_width`` standard errors. The figure is saved as ``1d-random-params``
    in the current working directory. No extension is given, so matplotlib's
    default save format is used.

    Prints an error message and returns without plotting when the folder is
    missing or empty, or when no file matches ``alpha``.

    Parameters
    ----------
    alpha : float
        Only result files whose ``config['alpha']`` equals this value are used.
    '''
    p = Path.cwd() / 'output_for_plot'
    files = [e for e in p.iterdir() if e.is_file()] if p.is_dir() else []
    print('Number of files = ', len(files))
    if not files:
        print('Error: No result files found in', p)
        return
    print('Sample filename = ', files[0])

    Ts, regrets_mean_of_means, regrets_stderror = _collect_regret_summaries(files, alpha)
    if not Ts:
        print('Error: No data with specified parameters')
        return

    # Order every series with ONE shared permutation. Sorting each (T, value)
    # list separately breaks T-ties by value, which can attach an error bar to
    # the wrong mean when two files share the same T.
    order = sorted(range(len(Ts)), key=lambda i: Ts[i])
    sorted_Ts = [Ts[i] for i in order]

    # Create the figure only once there is data to draw.
    fig, ax = plt.subplots(figsize=(16, 12), nrows=1, ncols=1)
    for algo, color, label in zip(algos, colors, labels):
        highlight = algo == 'v2'  # our method ('Unc_CCB (Ours)') is drawn more prominently
        # No fmt string: the explicit marker/linestyle keywords below take
        # precedence anyway, and passing both makes matplotlib emit a warning.
        ax.errorbar(
            sorted_Ts,
            [regrets_mean_of_means[algo][i] for i in order],
            yerr=[regrets_stderror[algo][i] for i in order],
            color=color,
            label=label,
            elinewidth=0.5,
            capsize=3,
            linestyle='--',
            marker='^' if highlight else 'o',
            markersize=12 if highlight else 10,
            linewidth=2 if highlight else 1.5,
        )

    print('Ts=', sorted_Ts)

    ax.set_xlabel('Number of training rounds $T$', fontsize=24)
    ax.set_ylabel('Regret', fontsize=24)
    ax.tick_params(axis="x", labelsize=16)
    ax.tick_params(axis="y", labelsize=16)

    ax.set_ylim(bottom=0)

    ax.legend(fontsize=24)

    fig.savefig('1d-random-params', bbox_inches='tight')


# In[7]:


# Plot (and save) the alpha = 1/2 configuration, which is also the function's default.
plot_across_T(alpha=1/2)


# In[ ]:




