from random import sample
import numpy as np
import pandas as pd
from utils import *
import argparse, os

# problem size: n buyers, m items
n, m = 100, 300

# Command-line configuration. argparse reports parse errors by raising
# SystemExit (presumably the case when run inside an interactive kernel that
# passes its own arguments); then fall back to the hard-coded parameters below.
# NOTE: the fallback model ('periodic') differs from the CLI default ('markov').
try:
    parser = argparse.ArgumentParser(description='Experiment arguments')
    parser.add_argument('--samplepathindex', '-spi', type=int, default=0)
    parser.add_argument('--datainputmodel', '-dim', type=str, default='markov')
    args = parser.parse_args()
    sample_path_idx = args.samplepathindex
    data_input_model = args.datainputmodel
except SystemExit as exc:
    # `-h/--help` exits with code 0: honor it instead of silently launching a
    # full experiment run with the fallback parameters.
    if exc.code == 0:
        raise
    print('not parsing command line inputs. use given parameters.')
    sample_path_idx = 0
    data_input_model = 'periodic'

# load sample path
import numpy as np
# fixed seed: all subsequent np.random calls (buyer/item subsampling and the
# item distribution s) are reproducible across runs
np.random.seed(2022)

# The archive is expected to hold exactly four arrays, in this order
# (presumably saved positionally as arr_0..arr_3 -- confirm against the
# generator script). Unpacking reads every array eagerly, so the `with` block
# can close the underlying zip file afterwards.
with np.load('instances/all_sample_paths.npz') as sample_path_archive:
    (
        all_sample_paths_iid,
        all_sample_paths_mild_corrupt,
        all_sample_paths_markov,
        all_sample_paths_periodic,
    ) = sample_path_archive.values()

print(
    'sample path idx = {}, data input model = {}'.format(
        sample_path_idx, data_input_model
    )
)

# map each supported arrival model name to its pre-generated sample paths
sample_paths_by_model = {
    'iid': all_sample_paths_iid,
    'mild': all_sample_paths_mild_corrupt,
    'markov': all_sample_paths_markov,
    'periodic': all_sample_paths_periodic,
}
try:
    all_sample_paths = sample_paths_by_model[data_input_model]
except KeyError:
    # fail loudly here instead of with a NameError on the next statement
    raise ValueError(
        'unknown data input model {!r}; expected one of {}'.format(
            data_input_model, sorted(sample_paths_by_model)
        )
    ) from None

# sequence of arriving item indices; entry t-1 is the item j(t) at step t
items_sample_path = all_sample_paths[sample_path_idx]

# Valuations: after the transpose, rows are buyers and columns are items.
# Subsample n buyers, then m items.
# NOTE(review): np.random.choice samples WITH replacement by default, so
# duplicate buyers/items are possible; changing it would change the seeded
# instance.
df = pd.read_csv("data/movielens_1500x1500.csv", header=None)
v = df.to_numpy().T
buyer_rows = np.random.choice(v.shape[0], size=n)
v = v[buyer_rows]
item_cols = np.random.choice(v.shape[1], size=m)
v = v[:, item_cols]

# time horizon grows linearly with the number of buyers
T = 200 * n

# equal budgets that sum to one
B = np.full(n, 1 / n)

# true item distribution: shifted uniform weights, normalized to sum to one
s = np.random.uniform(size=m) + 0.1
s /= s.sum()

# rescale each buyer's valuations so that v[i] @ s == 1
expected_value = v @ s
v = (v.T / expected_value).T

################################################################
# experiment parameters
do_stochastic = True

# beta is clipped to [(1 - delta0) * B, 1 + delta0] when it is updated
delta0 = 0.05

# running state of the online algorithm
beta = np.full(n, 1.0)            # initial beta: minimizer of the regularizer
beta_ave = np.full(n, 0.0)        # running time average of beta
g_ave = np.full(n, 0.0)           # running average of the utility each buyer has won
u_ave_baseline = np.full(n, 0.0)  # running average utility under proportional allocation

# winners_all_t[t-1] is the buyer that wins the item arriving at step t
# (the smallest index among tied maximizers of beta[i] * v[i, j(t)])
winners_all_t = list()

# spending[i]: cumulative spending of buyer i; incremented by beta[i] * v[i, j]
# whenever buyer i wins the item j sampled at time t
spending = np.full(n, 0.0)

# logs recorded at every step t (relative errors in the variables)
inf_norm_to_u_hseq = list()
inf_norm_to_beta_hseq = list()
inf_norm_to_B = list()
# average-norm logs, kept for the (disabled) average-norm logging
ave_one_norm_to_u_hseq = list()
ave_one_norm_to_beta_hseq = list()
ave_one_norm_to_B = list()
inf_norm_u_ave_baseline = list()
ave_norm_u_ave_baseline = list()
realized_prices = list()

# empirical allocation counts, and the proportional allocation in which
# buyer i receives B[i] / m of each item
x_cumulative = np.zeros(shape=(n, m))
x_proportional = np.repeat(B[:, np.newaxis], m, axis=1) / m

# Hindsight item distribution: empirical frequency of each item over the first
# T arrivals. Every item starts with a pseudo-count of 1 / (T // 2), so all
# items "appear" (w.l.o.g.) and none has to be removed.
s_hindsight = np.full(m, 1.0) / (T // 2)
for arrival in range(T):
    s_hindsight[items_sample_path[arrival]] += 1
s_hindsight /= s_hindsight.sum()

# compute hindsight equilibrium for the static market with item supplies
# s_hindsight. compute_me_fin_dim comes from utils; presumably it returns an
# (n, m) allocation and an (m,) price vector, and T is its iteration budget.
# Pass a copy of the same budgets B used online so the benchmark cannot drift
# from B, and so B stays safe even if the helper mutates its input.
x_hseq, p_hseq = compute_me_fin_dim(v, B=B.copy(), s=s_hindsight, T=int(5e3))
u_hseq = np.sum(v * x_hseq, axis=1)  # hindsight utility of each buyer
beta_hseq = B / u_hseq               # hindsight beta_i = B_i / u_i

# Eisenberg-Gale primal objective, its dual, and the resulting duality gap
# (a check on how accurately the hindsight equilibrium was computed)
pobj_hseq = np.sum( B * np.log(u_hseq) )
duality_const = np.sum(B) - np.sum( B * np.log(B) )
dobj_hseq = p_hseq @ s_hindsight - np.sum( B * np.log(beta_hseq) )
dgap_eg_hseq = dobj_hseq - duality_const - pobj_hseq

print('============== hindsight equilibrium ==============')
print('EG duality gap: {:.5f}'.format(dgap_eg_hseq))
print('===================================================')

# Recorded at every step t:
#   - distance of beta^t to beta_hseq
#   - distance of the averaged utility g_ave to u_hseq
#   - distance of average spending to the budgets B
#   - maximum relative buyer regret
# (max_envy / total_envy are not computed yet)

# bpb[i]: running max of v[i, j(tau)] / price(j(tau)) over all tau <= t,
# floored at its initial value of 1
bpb = np.ones(n)
max_rel_buyer_regret = []
ave_rel_buyer_regret = []  # kept for the disabled average-regret log

# print progress about 20 times; the max() guards T < 20, where T // 20 == 0
# would make the modulo below raise ZeroDivisionError
log_interval = max(1, T // 20)

for t in range(1, T+1):
    j = items_sample_path[t-1]  # item arriving at step t

    # the item goes to the buyer with the highest bid beta[i] * v[i, j]
    # (np.argmax returns the smallest index on ties); the winner pays its bid,
    # computed with the current beta^t, before this step's update
    winner = np.argmax(beta * v[:, j])
    winners_all_t.append(winner)
    curr_price = beta[winner] * v[winner, j]
    spending[winner] += curr_price
    realized_prices.append(curr_price)

    # running average of the utility each buyer has won; only the winner's
    # entry can increase. At t = 1 every entry is seeded with 1/n, presumably
    # so that B / g_ave below is finite for buyers that have not won yet.
    g_ave = (t-1) * g_ave / t if t > 1 else np.ones(n) / n
    g_ave[winner] += v[winner, j] / t

    # baseline: proportional allocation, buyer i receives fraction B[i] of the item
    u_ave_baseline = u_ave_baseline * (t-1)/t + v[:, j]*B / t

    # update beta: B / g_ave clipped to [(1 - delta0) * B, 1 + delta0]
    beta = np.maximum( (1-delta0) * B, np.minimum(1 + delta0, B / g_ave) )
    beta_ave = (t-1) * beta_ave / t + beta / t

    x_cumulative[winner, j] += 1

    # update running max bang-per-buck; a zero price (no positive bid) would make
    # this 0/0 = NaN and poison bpb and every later regret value, so skip it
    if curr_price != 0:
        bpb = np.maximum(bpb, v[:, j] / curr_price)

    # relative regret of each buyer: shortfall of the averaged utility g_ave
    # below bpb * B, normalized by the budget
    rel_buyer_reg = np.maximum(bpb * B - g_ave, 0) / B
    max_rel_buyer_regret.append(rel_buyer_reg.max())

    # logging: max relative errors with respect to the hindsight equilibrium
    inf_norm_u_ave_baseline.append(np.max(np.abs((u_ave_baseline-u_hseq)/u_hseq)))
    inf_norm_to_u_hseq.append(np.max(np.abs(g_ave - u_hseq)/u_hseq))
    inf_norm_to_beta_hseq.append(np.max(np.abs(beta - beta_hseq)/beta_hseq))
    inf_norm_to_B.append(np.max(np.abs(B - spending/t)/B))

    if t % log_interval == 0:
        print('t={}, max_beta_error={:.4f}, max_u_error={:.4f},  max_b_error={:.4f}'.format(t, inf_norm_to_beta_hseq[-1], inf_norm_to_u_hseq[-1], inf_norm_to_B[-1]))

# save result: arrays are passed positionally, so they are stored as
# arr_0 ... arr_6 in this order:
#   u_hseq, B, inf_norm_u_ave_baseline, inf_norm_to_beta_hseq,
#   inf_norm_to_u_hseq, inf_norm_to_B, max_rel_buyer_regret
# (the average-norm variants are not saved)
# Create the output directory first so a long run is not lost at the very end.
os.makedirs('results', exist_ok=True)
np.savez_compressed(
    'results/movielens_{}_{}_{}_all_logs_{}.npz'.format(n, m, data_input_model.lower().replace(' ', '_'), sample_path_idx),
    u_hseq, B,
    inf_norm_u_ave_baseline,
    inf_norm_to_beta_hseq,
    inf_norm_to_u_hseq,
    inf_norm_to_B,
    max_rel_buyer_regret,
)

# res = g_ave / u_opt
# print('max and min of g_ave/T divided by u_opt[i]: {:.4f}, {:.4f}'.format(np.min(res), np.max(res)))

# # construct final empirical x
# x = x_cumulative / T # np.sum(x, axis=0) == s_hindsight

# from matplotlib import pyplot as plt

# burn_in = int(T//50)
# plt.plot(np.arange(burn_in, T), ave_rel_buyer_regret[burn_in:T], label='ave relative buyer regret')
# plt.plot(np.arange(burn_in, T), max_rel_buyer_regret[burn_in:T], label='max relative buyer regret')
# plt.plot(np.arange(burn_in, T), ave_one_norm_to_u_hseq[burn_in:T], label=r'ave $\|\|u^t_i - u^{\rm HS}_i\|\|/u^{\rm HS}_i$')
# plt.plot(np.arange(burn_in, T), inf_norm_to_u_hseq[burn_in:T], label=r'max $\|\|u^t_i - u^{\rm HS}_i\|\|/u^{\rm HS}_i$')
# plt.xlabel('t')
# plt.ylabel('value')
# plt.legend()
# # plt.title('Dataset: {}\n ({} Arrivals)'.format(dataset, data_input_model))
# # fname = '-'.join(
# #     [ dataset.lower(), data_input_model.replace(' ', '-').lower(), str(seed) ]
# # )
# # plt.savefig(f'plots/{fname}.pdf')
# plt.show()

# u_time_ave = np.sum(v * x, 1) # it is u_running / T
# b_eq_all = [p_opt.T @ x[i] for i in range(n)]

# np.linalg.norm(u_time_ave - u_opt, 1) / np.linalg.norm(u_opt)

# # final envy: <v[i], x[k]> / B[k]
# umat_budget_scaled = np.array([[v[i].T @ x[k] / B[k] for k in range(n)] for i in range(n)]) # umat[2,3] - v[2].T @ x[3] == 0
# envy_final = np.max(umat_budget_scaled, 1) - np.diag(umat_budget_scaled)
# imax = np.argmax(envy_final)
# print('max envy = {} through buyer i={}'.format(envy_final[imax], imax))
# res = np.array(spending)/T

# # inf_norm_to_u_hseq, inf_norm_to_beta_hseq, duality_gap = np.array(inf_norm_to_u_hseq), np.array(inf_norm_to_beta_hseq), np.array(duality_gap)
# plt.plot(range(1, T+1),  inf_norm_to_u_hseq, label = r'$||\bar{g}^t - u^*||_\infty$', linestyle='solid')
# plt.plot(range(1, T+1),  inf_norm_to_beta_hseq, label = r'$||\beta^t - \beta^*||_\infty$', linestyle='dashed')
# plt.plot(range(1, T+1, log_interval), duality_gap, label = r'${\rm dgap}_t$', linestyle='dashed')
# # plt.plot(range(1, T+1, log_interval), max_envy, label = r'$\|\rho^t\|_\infty$', linestyle='dashdot')
# # plt.plot(range(1, T+1, log_interval), ave_envy, label = r'$\|\rho^t\|_1/n$', linestyle='dashed')
# plt.yscale('log'), plt.xscale('log')
# plt.xlabel('t')
# plt.title('{}, n = {}, m = {}'.format(dataset, n, m))
# plt.legend()
# plt.savefig('g-ave-and-beta-conv.pdf')


########################################################################
# # save results
# import pandas as pd
# import json
# fpath = os.path.join('results', dataset.lower(), prefix + 'sd-'+str(seed))
# print('fpath = {}'.format(fpath))
# os.makedirs(fpath, exist_ok=True)
# # np.savetxt(os.path.join(fpath, 'duality_gap'), duality_gap, fmt='%.4e') 
# np.savetxt(os.path.join(fpath, 'inf_norm_to_beta_hseq.gz'), inf_norm_to_beta_hseq, fmt='%.4e') 
# np.savetxt(os.path.join(fpath, 'ave_one_norm_to_beta_hseq.gz'), ave_one_norm_to_beta_hseq, fmt='%.4e')
# np.savetxt(os.path.join(fpath, 'inf_norm_to_u_hseq.gz'), inf_norm_to_u_hseq, fmt='%.4e')
# np.savetxt(os.path.join(fpath, 'ave_one_norm_to_u_hseq.gz'), ave_one_norm_to_u_hseq, fmt='%.4e')
# np.savetxt(os.path.join(fpath, 'inf_norm_to_B.gz'), inf_norm_to_B, fmt='%.4e')
# np.savetxt(os.path.join(fpath, 'ave_one_norm_to_B.gz'), ave_one_norm_to_B, fmt='%.4e')
# # np.savetxt(os.path.join(fpath, 'ave_envy'), ave_envy, fmt='%.4e')
# # np.savetxt(os.path.join(fpath, 'max_envy'), max_envy, fmt='%.4e')
# meta_data = {'T': T, 'dataset': dataset, 'n': n, 'm': m,  # 'number of duality gap and envy computations (num_logs)': num_logs, 
#             'seed': seed, 'delta0': delta0,
#             'varying_budgets': varying_budgets, 'first_budget_out_time': first_budget_out_time}
# with open(os.path.join(fpath, 'meta_data'), 'w') as mdff:
#     mdff.write(json.dumps(meta_data, indent=4))