import numpy as np
from matplotlib import pyplot as plt
from numpy.lib.function_base import average
import pandas as pd
import seaborn as sns
import os, argparse

# Read the instance size from the command line: -n buyers, -m items,
# -rk rank of the random low-rank valuation matrix.
# parse_known_args() is used instead of parse_args(): parse_args() reports
# unrecognized arguments through parser.error(), which raises SystemExit --
# not an Exception -- so the fallback below could never handle arguments
# injected by a host process (e.g. Jupyter's "-f kernel.json").
# Malformed values for the known options (e.g. "-n abc") still exit with a
# usage message, as a CLI should.
try:
    parser = argparse.ArgumentParser(description='Experiment arguments')
    parser.add_argument('--n', '-n', type=int, default=100)
    parser.add_argument('--m', '-m', type=int, default=1000)
    parser.add_argument('--rank', '-rk', type=int, default=3)
    args, unknown_args = parser.parse_known_args()
    if unknown_args:
        print('ignoring unrecognized command line arguments: {}'.format(unknown_args))
    n, m, d = args.n, args.m, args.rank
except Exception as ee:
    # Last-resort fallback for environments where argv cannot be parsed at all.
    print(ee)
    print('not parsing command line inputs. use given parameters.')
    n, m, d = 100, 1000, 3

print('plotting Random low-rank n = {}, m = {}, d (rank) = {}'.format(n, m, d))

# Rebuild the random low-rank valuation matrix v used by the experiment runs.
# Keep the seed and the draw order (exponential first, then normal) unchanged,
# otherwise v will not match the instance the saved results were computed on.
np.random.seed(1)
alpha = np.random.exponential(size=(n, d)) + 0.05
theta = np.abs(np.random.normal(size=(m, d))) + 1
# Scale each row of alpha so that each row of alpha @ theta.T sums to 1.
alpha = alpha / (alpha @ theta.sum(axis=0))[:, np.newaxis]
v = alpha @ theta.T
n, m = v.shape
# Rescale so that every row of v sums to m.
v = m * (v / v.sum(axis=1, keepdims=True))
B = np.ones(n) / n  # equal weights 1/n per buyer (presumably the budgets B)

# Offline equilibrium allocation written by the experiment runs.
instance_dir = os.path.join('results', 'random-n-{}-m-{}-d-{}'.format(n, m, d))
x_opt = np.loadtxt(os.path.join(instance_dir, 'offline-eq', 'x'))
u_opt = (v * x_opt).sum(axis=1)

# Relative gap between the proportional-share utilities and u_opt.
# NOTE(review): rows of v sum to m, so an allocation x_ij = 1/n would give
# u_i = m/n under the u_opt formula above; 1/n matches only if x_opt already
# folds in a 1/m per-item supply -- confirm against the experiment code.
u_proportional = np.full(n, 1 / n)
ps_relative_gap = np.abs(u_proportional - u_opt) / u_opt
inf_norm_to_u_eq_baseline = ps_relative_gap.max()
ave_norm_to_u_eq_baseline = ps_relative_gap.mean()

# Output directory for the PDFs written below; reused if it already exists.
os.makedirs('plots', exist_ok=True)
# Apply seaborn's default theme to every figure drawn afterwards.
sns.set_theme()

# average across seeds
from matplotlib import pyplot as plt
import seaborn as sns
# sns.set_theme()
import os, json

# Load each seed's saved error traces and stack them into one array per metric,
# one row per seed, with names of the form <metric>_all_seeds.
# NOTE(review): seeds 1..10 are hard-coded and must match the runs on disk.
metric_names = (
    'inf_norm_to_beta_eq',
    'inf_norm_to_u_eq',
    'inf_norm_to_B',
    'ave_one_norm_to_beta_eq',
    'ave_one_norm_to_u_eq',
    'ave_one_norm_to_B',
)
traces_by_metric = {name: [] for name in metric_names}

for seed in range(1, 11):
    fpath = os.path.join('results', 'random-n-{}-m-{}-d-{}'.format(n, m, d), 'sd-{}'.format(seed))
    with open(os.path.join(fpath, 'meta_data'), 'r') as ff:
        meta_data = json.load(ff)
    T = meta_data['T']
    for name in metric_names:
        traces_by_metric[name].append(np.loadtxt(os.path.join(fpath, name + '.gz')))

# np.stack requires every seed's trace to have the same shape and raises a
# ValueError otherwise, so mismatched runs are caught here rather than later.
inf_norm_to_beta_eq_all_seeds = np.stack(traces_by_metric['inf_norm_to_beta_eq'])
inf_norm_to_u_eq_all_seeds = np.stack(traces_by_metric['inf_norm_to_u_eq'])
inf_norm_to_B_all_seeds = np.stack(traces_by_metric['inf_norm_to_B'])
ave_one_norm_to_beta_eq_all_seeds = np.stack(traces_by_metric['ave_one_norm_to_beta_eq'])
ave_one_norm_to_u_eq_all_seeds = np.stack(traces_by_metric['ave_one_norm_to_u_eq'])
ave_one_norm_to_B_all_seeds = np.stack(traces_by_metric['ave_one_norm_to_B'])

# Plot window: trace index i corresponds to iteration t = i + 1.
t0 = 5*n
# NOTE(review): this overrides meta_data['T'] loaded above; every trace must
# hold at least n * 100 entries or the column indexing below raises IndexError.
T = n * 100
skip_size = 1
plot_cols = np.arange(t0, T, skip_size)  # trace indices that are plotted
plot_ts = plot_cols + 1                  # matching iteration numbers (x axis)
# Exact number of plotted points; (T - t0) // skip_size undercounts whenever
# skip_size does not divide T - t0, breaking the baseline curve's length.
num_dp = len(plot_cols)

dataset = 'Random'


def _plot_relative_errors(curves, baseline, baseline_label, title_suffix, file_prefix):
    """Draw one figure of per-iteration mean +/- standard error across seeds and save it.

    Reads the script-level plot settings (n, m, d, t0, T, plot_cols, plot_ts,
    num_dp, dataset).

    curves: sequence of (all_seeds, label, linestyle, errorevery_divisor);
        all_seeds has one row per seed and one column per iteration.
    baseline: constant drawn as a horizontal reference curve.
    baseline_label: legend label for the baseline curve.
    title_suffix: text shown in parentheses at the end of the title.
    file_prefix: leading part of the PDF file name written under 'plots'.
    """
    for all_seeds, label, linestyle, every_divisor in curves:
        window = all_seeds[:, plot_cols]
        # Standard error of the mean over seeds (np.std with its default ddof=0).
        std_err = (1 / np.sqrt(all_seeds.shape[0])) * np.std(window, axis=0)
        # errorevery must be a positive integer, so never let it drop to 0.
        plt.errorbar(plot_ts, np.mean(window, axis=0), std_err, label=label,
                     linestyle=linestyle, errorevery=max(1, num_dp // every_divisor))
    plt.plot(plot_ts, np.full(num_dp, baseline), label=baseline_label,
             linestyle=(0, (3, 5, 1, 5, 1, 5)))
    # Dotted vertical guides at every multiple of 10n inside [t0, T].
    guide_step = n * 10
    first_guide = -(-t0 // guide_step) * guide_step  # smallest multiple >= t0
    for pt in range(first_guide, T + 1, guide_step):
        plt.axvline(pt, linewidth=1.0, linestyle='dotted')
    plt.xticks(range(0, T + 1, T // 5))
    plt.xlabel('t')
    plt.title('{}, n = {}, m = {}, d = {} ({})'.format(dataset, n, m, d, title_suffix))
    plt.legend()
    plt.savefig(os.path.join('plots', '{}-{}-n-{}-m-{}-d-{}-mean-and-se.pdf'.format(file_prefix, dataset, n, m, d)))
    plt.clf()


###### max relative errors ######
_plot_relative_errors(
    [
        (inf_norm_to_beta_eq_all_seeds, r'$||(\beta^t - \beta^*)/\beta^*||_\infty$', 'solid', 10),
        (inf_norm_to_u_eq_all_seeds, r'$||(\bar{g}^t - u^*)/u^*||_\infty$', 'dotted', 8),
        (inf_norm_to_B_all_seeds, r'$||(\bar{b}^t - B)/B||_\infty$', 'dashdot', 6),
    ],
    inf_norm_to_u_eq_baseline,
    r'$||(u^{\rm PS} - u^*)/u^*||_\infty$',
    'Max Relative Errors',
    'max-relative-error',
)

###### ave relative errors ######
_plot_relative_errors(
    [
        (ave_one_norm_to_beta_eq_all_seeds, r'$||(\beta^t - \beta^*)/\beta^*||_1/n$', 'solid', 10),
        (ave_one_norm_to_u_eq_all_seeds, r'$||(\bar{g}^t - u^*)/u^*||_1/n$', 'dotted', 8),
        (ave_one_norm_to_B_all_seeds, r'$||(\bar{b}^t - B)/B||_1/n$', 'dashdot', 6),
    ],
    ave_norm_to_u_eq_baseline,
    r'$||(u^{\rm PS} - u^*)/u^*||_1/n$',
    'Ave. Relative Errors',
    'ave-relative-error',
)