import numpy as np
# import matplotlib
from matplotlib import pyplot as plt
from numpy.core.fromnumeric import size
from numpy.lib.function_base import average
import pandas as pd
import seaborn as sns
import os, argparse

# font = {'size'   : 40}
# matplotlib.rc('font', **font)
# Global matplotlib font size.
# NOTE(review): sns.set_theme() is called a few lines below and also writes
# matplotlib rcParams -- confirm this font size actually reaches the figures.
plt.rcParams['font.size'] = 100

print('plot inf-dim instance 10 seeds...')

# n is used below for the plotting window (t0, T), the constant 1/n baseline,
# the figure titles and the output file names.
n = 100

# Offline equilibrium, stored as a plain-text array.
offline_eq_file = os.path.join('results', 'inf-dim', 'offline-eq', 'u')
u_opt = np.loadtxt(offline_eq_file)

sns.set_theme()

# Output directory for the saved figures.
os.makedirs('plots', exist_ok=True)

# average across seeds
from matplotlib import pyplot as plt
import seaborn as sns
# sns.set_theme()
import os, json

# Per-seed error curves, keyed by the file each curve is stored in.
# For every seed the files are read in exactly this order.
curves_by_file = {
    'inf_norm_to_beta_eq.gz': [],
    'ave_one_norm_to_beta_eq.gz': [],
    'inf_norm_to_u_eq.gz': [],
    'ave_one_norm_to_u_eq.gz': [],
    'inf_norm_to_B.gz': [],
    'ave_one_norm_to_B.gz': [],
}

for seed in range(1, 11):
    fpath = os.path.join('results', 'inf-dim', 'sd-{}'.format(seed))
    with open(os.path.join(fpath, 'meta_data'), 'r') as ff:
        meta_data = json.load(ff)
    # Value recorded in this run's meta data; T is reassigned further down
    # the script before it is used for plotting.
    T = meta_data['T']
    for fname, curves in curves_by_file.items():
        curves.append(np.loadtxt(os.path.join(fpath, fname)))

# Stack the per-seed curves into numpy arrays (first axis indexes the seed).
inf_norm_to_beta_eq_all_seeds = np.array(curves_by_file['inf_norm_to_beta_eq.gz'])
inf_norm_to_u_eq_all_seeds = np.array(curves_by_file['inf_norm_to_u_eq.gz'])
inf_norm_to_B_all_seeds = np.array(curves_by_file['inf_norm_to_B.gz'])
ave_one_norm_to_beta_eq_all_seeds = np.array(curves_by_file['ave_one_norm_to_beta_eq.gz'])
ave_one_norm_to_u_eq_all_seeds = np.array(curves_by_file['ave_one_norm_to_u_eq.gz'])
ave_one_norm_to_B_all_seeds = np.array(curves_by_file['ave_one_norm_to_B.gz'])

# Plotting window: columns t0 .. T-1 of the per-seed arrays are shown
# (x-axis values t0+1 .. T), sampled every skip_size steps.
t0 = 5*n
T = n * 100
skip_size = 1
# Number of plotted points. Taken from the range itself so that it always
# equals the number of x-values, including when skip_size does not divide
# T - t0 (floor division would come up one short in that case).
num_dp = len(range(t0, T, skip_size))

# Constant baseline value 1/n, compared elementwise with the offline
# equilibrium u_opt as a relative error.
u_proportional = 1/n
baseline_rel_err = np.abs(u_proportional - u_opt) / u_opt
inf_norm_to_u_eq_baseline = np.max(baseline_rel_err)
ave_one_norm_to_u_eq_baseline = np.average(baseline_rel_err)

###### max relative errors ######
fig = plt.figure(figsize=(6, 4))

# x-axis values are the 1-based steps t0+1 .. T; the matching 0-based columns
# of the per-seed arrays are t0 .. T-1.
max_err_steps = range(t0+1, T+1, skip_size)
max_err_cols = range(t0, T, skip_size)
max_err_num_points = len(max_err_steps)

# (per-seed curves, legend label, line style, number of error bars to aim for)
# The inf_norm_to_B_all_seeds series is not drawn in this figure.
max_err_series = (
    (inf_norm_to_beta_eq_all_seeds, r'$||(\beta^t - \beta^*)/\beta^*||_\infty$', 'solid', 10),
    (inf_norm_to_u_eq_all_seeds, r'$||(\bar{u}^t - u^*)/u^*||_\infty$', 'dashed', 8),
)
for curves, label, linestyle, num_bars in max_err_series:
    window = curves[:, max_err_cols]
    # Standard error of the mean across seeds: the seed count is read from
    # the data instead of being hard-coded. np.std uses its default ddof=0.
    std_err = (1/np.sqrt(window.shape[0])) * np.std(window, axis=0)
    plt.errorbar(
        max_err_steps,
        np.mean(window, axis=0),
        std_err,
        label=label,
        linestyle=linestyle,
        # errorevery must be at least 1, even for very short windows.
        errorevery=max(1, max_err_num_points // num_bars),
    )

# Horizontal reference line at the constant baseline error.
plt.plot(max_err_steps, np.ones(max_err_num_points) * inf_norm_to_u_eq_baseline, label=r'$||(u^{\rm PS} - u^*)/u^*||_\infty$', linestyle=(0, (3, 5, 1, 5, 1, 5)))

# Dotted vertical guides at every multiple of 10*n inside [t0, T].
for pt in range(t0, T+1):
    if pt % (n*10) == 0:
        plt.axvline(pt, linewidth=1.0, linestyle='dotted')

plt.xticks(range(0, T+1, T//5))
plt.title(r'Inf-Dim $n={}$, $\Theta = [0,1]$ (Max Relative Errors)'.format(n), fontsize=15)
plt.legend(prop={'size': 15}, loc='center right')
plt.savefig(os.path.join('plots', 'max-relative-error-inf-dim-n-{}-mean-and-se.pdf'.format(n)))
# Clear the figure so later plotting starts from an empty canvas.
plt.clf()

###### ave relative errors ######
# Drawn on the current pyplot figure; no new figure is created here.

# x-axis values are the 1-based steps t0+1 .. T; the matching 0-based columns
# of the per-seed arrays are t0 .. T-1.
ave_err_steps = range(t0+1, T+1, skip_size)
ave_err_cols = range(t0, T, skip_size)
ave_err_num_points = len(ave_err_steps)

# (per-seed curves, legend label, line style, number of error bars to aim for)
# The ave_one_norm_to_B_all_seeds series is not drawn in this figure.
ave_err_series = (
    (ave_one_norm_to_beta_eq_all_seeds, r'$||(\beta^t - \beta^*)/\beta^*||_1/n$', 'solid', 10),
    (ave_one_norm_to_u_eq_all_seeds, r'$||(\bar{u}^t - u^*)/u^*||_1/n$', 'dashed', 8),
)
for curves, label, linestyle, num_bars in ave_err_series:
    window = curves[:, ave_err_cols]
    # Standard error of the mean across seeds: the seed count is read from
    # the data instead of being hard-coded. np.std uses its default ddof=0.
    std_err = (1/np.sqrt(window.shape[0])) * np.std(window, axis=0)
    plt.errorbar(
        ave_err_steps,
        np.mean(window, axis=0),
        std_err,
        label=label,
        linestyle=linestyle,
        # errorevery must be at least 1, even for very short windows.
        errorevery=max(1, ave_err_num_points // num_bars),
    )

# Horizontal reference line at the constant baseline error.
plt.plot(ave_err_steps, np.ones(ave_err_num_points) * ave_one_norm_to_u_eq_baseline, label=r'$||(u^{\rm PS} - u^*)/u^*||_1/n$', linestyle=(0, (3, 5, 1, 5, 1, 5)))

# Dotted vertical guides at every multiple of 10*n inside [t0, T].
for pt in range(t0, T+1):
    if pt % (n*10) == 0:
        plt.axvline(pt, linewidth=1.0, linestyle='dotted')

plt.xticks(range(0, T+1, T//5))
plt.title(r'Inf-Dim $n={}$, $\Theta = [0,1]$ (Average Relative Errors)'.format(n), fontsize=15)
plt.legend(prop={'size': 15}, loc='center right')
plt.savefig(os.path.join('plots', 'ave-relative-error-inf-dim-n-{}-mean-and-se.pdf'.format(n)))
plt.clf()