#!/usr/bin/env python
# coding: utf-8

# ### README
# Main `.py` to be run for Experiment 1 of the main paper, after running other `.py` files. Check `README` file for more details on this.
# 
# _Side note_: Each of the other 5 `.py` files corresponds to a setting in the representative set; each setting is run for 60 independent runs. This file combines them and plots. This ensures that we have 300 runs overall, while ensuring that each setting gets exactly 60 runs. This closely simulates 300 independent runs where, in each run, a setting is chosen with 20% probability.

# In[1]:


import numpy as np
import matplotlib.pyplot as plt


# In[2]:


import pickle


# In[3]:


def get_rewards_for_algo(rewards_dict_for_plot, algo):
    """Collect the reward entries that belong to one algorithm.

    Args:
        rewards_dict_for_plot: Mapping whose keys are indexable and carry the
            algorithm identifier at position 2 (``key[2]``).
        algo: Algorithm identifier to select.

    Returns:
        ``np.ndarray`` built from the selected values, in the mapping's
        iteration order (an empty array when nothing matches).
    """
    selected = [
        value
        for key, value in rewards_dict_for_plot.items()
        if key[2] == algo
    ]
    return np.array(selected)


# In[4]:


def get_regrets_for_algo(regrets_dict_for_plot, algo):
    """Return the regret entries recorded for ``algo`` as a NumPy array.

    Args:
        regrets_dict_for_plot: Mapping whose keys are indexable; the element
            at index 2 of each key is compared against ``algo``.
        algo: Algorithm identifier to select.

    Returns:
        ``np.ndarray`` of the values whose key matches, ordered as the mapping
        iterates (an empty array when no key matches).
    """
    matching_keys = (key for key in regrets_dict_for_plot if key[2] == algo)
    return np.array([regrets_dict_for_plot[key] for key in matching_keys])


# In[5]:


## Plotting settings


# In[6]:


# Multiplier applied to the standard error (std / sqrt(number of runs)) to get
# the half-width of the plotted error bars; 1.96 is the z-value of a two-sided
# 95% normal confidence interval.
conf_width = 1.96


# In[7]:


# Parallel lists, zipped together when plotting: the algorithm identifiers as
# they appear in the keys of the stored regret dictionaries, and the line
# color and legend label used for each. 'v2' (labelled 'Unc_CCB (Ours)') is
# drawn with a distinct marker and a heavier line.
algos = ['A', 'B', 'C', 'D', 'E', 'F', 'v2']
colors = ['black', 'steelblue', 'forestgreen', 'orange', 'olive', 'slateblue', 'red']
labels = ['NonCausal_Std_TS', 'Std_TS', 'Std_UniExp', 'TargInt_TS', 'TargInt_UniExp', 'TargInt_TS_UniExp', 'Unc_CCB (Ours)']


# In[8]:


import pathlib
from pathlib import Path

# Result directories are the sub-directories of the working directory whose
# name starts with 'V' (presumably one per experimental setting -- verify
# against the scripts that write them). Starting with 'V' already rules out
# hidden ('.') and private ('_') directories.
cwd = Path.cwd()
dirs = [entry for entry in cwd.iterdir() if entry.stem[0] == 'V' and entry.is_dir()]
print(*dirs, sep='\n')


# In[9]:


# Variation names are the directory stems, in the same order as `dirs`.
variations = [folder.stem for folder in dirs]

variations


# ## Variation settings

# In[10]:


# Map each variation (directory stem) to the 'var_ranges' entry stored under
# 'params' in one of that directory's result pickles. Only a single file per
# directory is read, so all result files of a variation are assumed to share
# the same 'var_ranges' -- TODO confirm.
var_settings = {}

for d in dirs:

    # Sorted so that the file consulted does not depend on the (arbitrary)
    # order in which the file system lists the directory.
    files = sorted(e for e in d.iterdir() if e.is_file() and e.suffix == '.pickle')
    if not files:
        # Fail with a message that names the offending directory instead of an
        # opaque IndexError on files[0].
        raise FileNotFoundError("No '.pickle' result file found in {}".format(d))

    # NOTE: pickle.load can execute arbitrary code; only point this script at
    # result files you generated yourself.
    with open(files[0], 'rb') as handle:
        pickled = pickle.load(handle)
    var_settings[d.stem] = pickled['params']['var_ranges']



# ## Plotting - Across T

# In[11]:


from pathlib import Path
import pathlib


# In[12]:


def _mean_and_error(samples, n_runs):
    """Return ``(mean, conf_width * std / sqrt(n_runs))`` for *samples*.

    An empty sample set or a non-positive run count yields ``(nan, nan)``
    instead of NumPy empty-slice / divide-by-zero warnings.
    """
    if samples.size == 0 or n_runs <= 0:
        return np.nan, np.nan
    return samples.mean(), conf_width * np.std(samples) / np.sqrt(n_runs)


def _pool(arrays):
    """Flatten and concatenate *arrays* into one float array (empty if none)."""
    return np.concatenate([np.array([])] + [np.ravel(a) for a in arrays])


def _load_experiments(variation, alpha):
    """Load the experiments of one variation whose config alpha equals *alpha*.

    Returns a list of dicts with keys ``'T'`` (``config['train_T']``),
    ``'train_runs'`` (``config['train_runs']``) and ``'regrets'`` (per
    algorithm, the regret array averaged over axis 1), sorted by ``'T'``.
    """
    experiments = []
    # Only '.pickle' files are read, matching how var_settings is loaded, so a
    # stray non-pickle file in the directory cannot crash the run. Sorting the
    # paths makes the order of equal-T experiments deterministic.
    for path in sorted((Path.cwd() / variation).iterdir()):
        if not path.is_file() or path.suffix != '.pickle':
            continue
        # NOTE: pickle.load can execute arbitrary code; only use trusted files.
        with open(path, 'rb') as handle:
            pickled = pickle.load(handle)

        config = pickled['config']
        if config['alpha'] != alpha:
            continue

        experiments.append({
            'T': config['train_T'],
            'train_runs': config['train_runs'],
            'regrets': {
                algo: get_regrets_for_algo(pickled['regrets'], algo).mean(axis=1)
                for algo in algos
            },
        })

    # Sort on T only: sorting (T, ndarray) tuples would try to compare the
    # arrays whenever two experiments share the same T and raise ValueError.
    experiments.sort(key=lambda experiment: experiment['T'])
    return experiments


def _assign_to_buckets(experiments, cutoffs):
    """Group *experiments* into the half-open buckets ``(cutoffs[i], cutoffs[i+1]]``.

    Each experiment is placed by its own ``'T'``; experiments whose ``'T'``
    falls outside every bucket are left out.
    """
    buckets = [[] for _ in cutoffs[:-1]]
    for experiment in experiments:
        for bucket, low, high in zip(buckets, cutoffs[:-1], cutoffs[1:]):
            if low < experiment['T'] <= high:
                bucket.append(experiment)
                break
    return buckets


def plot_across_T(alpha=1/2):
    """Plot normalised regret against the number of training rounds T.

    For every entry of the module-level ``variations`` the matching result
    pickles are loaded, grouped into buckets of T (expressed as a fraction of
    ``range_C1 * range_x`` taken from ``var_settings``), and normalised by the
    largest per-bucket mean regret over all algorithms of that variation.  The
    normalised samples of all variations are then pooled per algorithm and
    bucket and drawn as mean with error bars of half-width
    ``conf_width * std / sqrt(total train runs)``.  The x position of a bucket
    is its upper cutoff.  The figure is saved as ``'experiment1_plot'`` and
    shown.

    Args:
        alpha: Only experiments whose ``config['alpha']`` equals this value
            are used.

    Returns:
        None.  If some variation has no experiment with the requested alpha a
        warning is issued and nothing is plotted.
    """
    import warnings

    Ts_bucket_cutoffs = [0, 1/8, 1/4, 3/8, 1/2, 5/8]
    max_frac_xctar = 1
    n_buckets = len(Ts_bucket_cutoffs) - 1

    bucketed_regrets = {}        # variation -> algo -> [normalised samples per bucket]
    train_runs_for_buckets = {}  # variation -> [summed train runs per bucket]

    for variation in variations:
        experiments = _load_experiments(variation, alpha)
        if not experiments:
            warnings.warn(
                "plot_across_T: no experiment with alpha={} found for variation "
                "'{}'; nothing plotted".format(alpha, variation)
            )
            return

        # Bucket edges for this variation, in absolute numbers of rounds.
        max_T = var_settings[variation]['range_C1'] * var_settings[variation]['range_x'] * max_frac_xctar
        cutoffs = [frac * max_T for frac in Ts_bucket_cutoffs]
        buckets = _assign_to_buckets(experiments, cutoffs)

        train_runs_for_buckets[variation] = [
            sum(experiment['train_runs'] for experiment in bucket)
            for bucket in buckets
        ]
        pooled = {
            algo: [_pool(experiment['regrets'][algo] for experiment in bucket) for bucket in buckets]
            for algo in algos
        }

        # Normalise by the largest bucket mean over all algorithms (at least
        # 0). Empty buckets are skipped so they cannot inject NaN into max().
        min_regret = 0
        max_regret = max(
            [0] + [samples.mean() for per_bucket in pooled.values() for samples in per_bucket if samples.size]
        )
        span = max_regret - min_regret
        if span == 0:
            # Nothing to scale by; leave the values as they are rather than
            # dividing by zero.
            span = 1

        bucketed_regrets[variation] = {
            algo: [(samples - min_regret) / span for samples in per_bucket]
            for algo, per_bucket in pooled.items()
        }

    # Combine across variations
    combined_regrets_mean_of_means = {}
    combined_regrets_stderror = {}
    for algo in algos:
        combined_regrets_mean_of_means[algo] = []
        combined_regrets_stderror[algo] = []
        for t in range(n_buckets):
            samples = _pool(bucketed_regrets[variation][algo][t] for variation in variations)
            train_runs_combined = sum(train_runs_for_buckets[variation][t] for variation in variations)
            mean, error = _mean_and_error(samples, train_runs_combined)
            combined_regrets_mean_of_means[algo].append(mean)
            combined_regrets_stderror[algo].append(error)

    # Plot (the figure is created only now, so the early return above does not
    # leave an empty figure behind)
    fig, ax = plt.subplots(figsize=(16,12), nrows=1, ncols=1)

    for algo, color, label in zip(algos, colors, labels):
        ax.errorbar(
            Ts_bucket_cutoffs[1:],
            combined_regrets_mean_of_means[algo],
            yerr = combined_regrets_stderror[algo],
            fmt='o',
            color=color,
            label=label,
            elinewidth = 0.5,
            capsize=3,
            linestyle='--',
            marker='^' if algo=='v2' else 'o',
            markersize=12 if algo=='v2' else 10,
            linewidth=2.5 if algo=='v2' else 1.75
        )

    ax.set_xlabel('Number of training rounds $T$ (as fraction of $N_XN_{C^{tar}}$)', fontsize=24)
    ax.set_ylabel('Regret (normalized to [0, 1])', fontsize=24)
    ax.tick_params(axis="x", labelsize=16)
    ax.tick_params(axis="y", labelsize=16)

    ax.legend(fontsize=24, loc='upper right')

    plt.savefig('experiment1_plot', bbox_inches='tight')
    plt.show()


# In[13]:


# Produce the figure with the default alpha (1/2); the plot is written to
# 'experiment1_plot' and shown.
plot_across_T()

