import pickle
import yaml

import matplotlib.pyplot as plt
import numpy as np

import argparse

# Command-line interface: one positional argument, `filename`, which is
# opened further down as the YAML configuration for this run.
parser = argparse.ArgumentParser(
    prog='Plotting for MEVDRONet',
    description='Computes adversarial risk for some data.',
)
parser.add_argument('filename')
args = parser.parse_args()

# Read the experiment configuration (YAML) named on the command line.
with open(args.filename, 'r') as f:
    params = yaml.safe_load(f)

# Every value below is copied straight out of the config mapping; a missing
# required key raises KeyError here, before any plotting work is done.
d = params['d']

n_epochs = params['n_epochs']
n_lam = params['n_lam']

width = params['width']
n_data = params['n_data']
block_size = params['block_size']

# With a positive block size, n_data is integer-divided by the block size.
if block_size > 0:
    n_data = n_data // block_size
    # NOTE(review): for block_size == 100 the quotient is discarded and
    # n_data is forced to 20; the reason is not visible in this file.
    if block_size == 100:
        n_data = 20

n_max = params['n_max']
rate = params['rate']

use_softmax = params['use_softmax']
experiment = params['experiment']

risk = params['risk']

n_eps = params['n_eps']
n_runs = params['n_runs']

eps_max = params['eps_max']

data_file = params['data_file']
cost_norm = params['cost_norm']

# `eps_coef` is the only optional key: it falls back to None when absent.
# dict.get replaces a bare `except:` that would also have swallowed
# unrelated errors (including KeyboardInterrupt).
eps_coef = params.get('eps_coef')

synthetic_rate = params['synthetic_rate'] #0

import os

# Path prefix (ending in '/') under which the statistics pickle for this
# risk / eps_max / n_max / block_size combination is looked up.
save_path_pp = data_file + '_pp_{risk}_gen_eps{eps}_data{n}_blocksize{block}/'.format(
    risk=risk,
    eps=eps_max,
    n=n_max,
    block=block_size,
)

# Keys used to pull the individual risk arrays out of the loaded statistics.
true_key, pop_key, adv_key = 'true_risk', 'p0_risk', 'adv_risk'

# Load the pre-computed risk statistics for this configuration.
# (An earlier variant read 'stats0_{n_data}.p' instead.)
# NOTE: pickle.load can execute arbitrary code contained in the file; only
# point this script at stats files from a trusted source.
stats_path = save_path_pp + 'stats{}_100.p'.format(eps_max)
try:
    with open(stats_path, 'rb') as f:
        data = pickle.load(f)
except OSError as err:
    # Everything below needs these statistics. Stop with a clear message
    # instead of continuing and failing later with a NameError on
    # `pop` / `true` / `x`. Other errors (bad pickle, missing key) are left
    # to propagate with their own traceback.
    raise SystemExit('EVD stats not found at {}: {}'.format(stats_path, err))

true = data[true_key]
pop = data[pop_key]
adv = data[adv_key]
# Entries that are not below 1e2 are multiplied by zero; the rest are kept.
adv = adv * (adv < 1e2)
# First column of the stacked 'losses' entries; used as the x-axis values.
x = np.stack(data['losses'])[:, 0]

fig, ax = plt.subplots()

# --- P_0 curve: |mean(pop) - mean(true)| over the last axis, shown in log ---
error_pop = np.abs(pop.mean(-1) - true.mean(-1))

# Horizontal axis: x divided by the last-axis mean of the true risk.
# (Dividing by pop.mean(-1) instead was an earlier alternative.)
x_domain = x / true.mean(-1)

log_error_pop = np.log(error_pop)
# Half-width of the shaded band: nanstd of the error divided by the error.
# NOTE(review): np.nanstd is applied to the already mean-reduced error, not
# to `pop` itself -- confirm this is the intended spread.
band_pop = np.nanstd(error_pop, -1) / error_pop
error_pop_lower = log_error_pop - band_pop
error_pop_upper = log_error_pop + band_pop
ax.plot(x_domain, log_error_pop, marker='o', label=r'$P_0$: Non-DRO EVD Risk')
ax.fill_between(x_domain, error_pop_lower, error_pop_upper, alpha=0.3)

# --- P_star curve: same construction, using the NaN-aware mean of `adv` ---
error_adv = np.abs(np.nanmean(adv.numpy(), -1) - true.numpy().mean(-1))
log_error_adv = np.log(error_adv)
band_adv = np.nanstd(error_adv, -1) / error_adv
error_adv_lower = log_error_adv - band_adv
error_adv_upper = log_error_adv + band_adv
ax.plot(x_domain, log_error_adv, marker='*', label=r'$P_\star$: DRO EVD Risk')
ax.fill_between(x_domain, error_adv_lower, error_adv_upper, alpha=0.3)




# plt.gca().spines['top'].set_visible(False)
# plt.gca().spines['right'].set_visible(False)

# Integer tick range spanning the rounded extremes of the adversarial error.
# The lower bound previously used np.max as well (copy-paste slip), which
# collapsed the range to a single value.
# NOTE: y_ticks is computed but not applied to the axis anywhere below.
y_domain_max = np.around(np.max(error_adv))
y_domain_min = np.around(np.min(error_adv))
y_ticks = np.arange(y_domain_min, y_domain_max + 1, dtype=int)

# Hide the top and left spines; pin the right and bottom spines to the axes
# edges and move the y-axis label and ticks to the right-hand side.
ax.spines['top'].set_color('none')
ax.spines['right'].set_position(('axes', 1.0))
ax.spines['left'].set_color('none')
ax.spines['bottom'].set_position(('axes', 0.0))
ax.yaxis.set_label_position("right")
ax.yaxis.tick_right()

plt.xlabel(r'$\delta$ Normalized by True Risk', fontsize = 15)
plt.ylabel(r'$Log \vert \mathbb{E}[\ell(X_{P_{true}})] - \mathbb{E}[\ell(X_{P_{model}})] \vert $', fontsize = 15)
# `\n` must stay a real newline, so the string cannot be raw; the backslash
# of `\delta` is escaped explicitly to avoid an invalid-escape warning.
plt.title("Expected Risk Evaluated Over\nIncreasing Uncertainty ($\\delta$)", fontsize = 15)
plt.legend(loc="lower right", fontsize = 15)
plt.tight_layout()

# The figure is written to the current working directory as a PDF.
plt.savefig('raw_comparison_{}_N_{}_{}.pdf'.format(data_file, n_data, risk))