import numpy as np
import pandas as pd
from utils import *
import argparse, os

# True: a buyer is excluded from an auction once its bid would push its
# cumulative spending past B[i] * T (see the main loop).
budget_cap = False

# Parse command-line options. If argv cannot be parsed (argparse reports the
# error and raises SystemExit), fall back to the hard-coded defaults below.
try:
    parser = argparse.ArgumentParser(description='Experiment arguments')
    parser.add_argument('--seed', '-sd', type=int, default=2333)
    # NOTE(review): --T is parsed but not used; T is recomputed from the number
    # of buyers after the dataset is loaded.
    parser.add_argument('--T', '-t', type=int, default=int(5e5))
    parser.add_argument('--dataset', '-ds', default='MovieLens')
    # parser.add_argument('--num_logs', '-nl', type=int, default=2000)
    args = parser.parse_args()
    seed = args.seed
    dataset = args.dataset
    # num_logs = args.num_logs
except SystemExit:
    # Only catch argparse's exit, not arbitrary errors (a bare except would
    # also hide KeyboardInterrupt and genuine bugs).
    print('not parsing command line inputs. use given parameters.')
    T, seed, dataset = int(5e5), 23333, 'MovieLens'

################################################################
# # randomly generate some data
# n, m = 50, 100
# B = np.random.uniform(size=n) + 0.2
# # B = np.ones(n)
# B = B / np.sum(B) # normalize B
# # v = np.abs(np.random.normal(0, 1, (n, m)))
# v = np.random.exponential(size = (n, m)) # v = np.array([ [1, 2], [1.5, 1.5], [2, 1] ])
# # v = (v.T / np.sum(v, 1)).T # normalize v

################################################################
# load the valuation matrix v (one row per buyer, one column per item) from a dataset
# dataset = 'MovieLens' # random
varying_budgets = False  # only recorded in meta_data; budgets below are uniform

# np.random.seed(1) # this seed is only for subsampling rows and columns
if dataset == 'MovieLens':
    df = pd.read_csv("../../data/movielens_1500x1500_lr0.1_wd1e-05_dim20_rmse0.88640326.csv", header=None)
elif dataset == 'Jokes':
    df = pd.read_csv("../../data/jokes_7200.csv")
elif dataset == 'Household':
    df = pd.read_csv("../../data/household_items_understood.csv")
else:
    # any other name would leave `v` undefined and fail later with a NameError
    raise ValueError("unknown dataset {!r}; expected 'MovieLens', 'Jokes' or 'Household'".format(dataset))
v = df.to_numpy()

n, m = v.shape  # n buyers, m items
# horizon scales with the number of buyers
T = n * 100 if budget_cap else n * 110
print('seed = {}, dataset = {}, n = {}, m = {}, budget_cap = {}'.format(seed, dataset, n, m, budget_cap))
# set budgets
np.random.seed(seed)
# B = np.random.uniform(size=n) + 0.2 if varying_budgets else np.ones(n)
B = np.ones(n) / n  # uniform budgets, sum(B) == 1

################################################################
# Normalization:
# 1. supply of each item is 1/m (so total supply of all items is 1)
# 2. buyer i gets utility 1 from all items
# 3. sum(B) == 1
# Divide every buyer's row by its row total, then scale by m so each row sums to m.
row_totals = np.sum(v, axis=1)
v = m * (v.T / row_totals).T

print('load offline eq prices and allocations...')

# Offline equilibrium allocation x and prices p.
# NOTE(review): this reads from results/<dataset>/ using `dataset` as given, while
# outputs are written under dataset.lower(); on case-sensitive filesystems those
# are different directories — confirm this is intended.
fpath = os.path.join('results', dataset, 'offline-eq')
x_opt = np.loadtxt(os.path.join(fpath, 'x'))
p_opt = np.loadtxt(os.path.join(fpath, 'p'))
u_opt = (v * x_opt).sum(axis=1)  # equilibrium utilities <v_i, x*_i>
beta_opt = B / u_opt  # equilibrium pacing multipliers B_i / u*_i
# EG primal and dual objective values at the offline solution; dobj - pobj is the duality gap
pobj = eg_primal_obj_val(x_opt, v, B)
dobj = eg_dual_obj_val(beta_opt, v, B)
print('offline (x,p) has dobj = {:.4e}, EG duality gap = {:.4e}'.format(dobj, dobj - pobj))

################################################################
# experiment parameters
do_stochastic = True  # True: one uniformly sampled item per round; False: all items every round
prefix = 'budget-cap-' if budget_cap else ''  # prefix of the output directory name

# NOTE(review): v has already been normalized so each row sums to m, so this
# exact-equality check can only trigger when m == 1.
if np.sum(v[0]) == 1:
    v = m * v # rescale, now supply of each item is 1/m, not 1

delta0 = 0.05  # beta is clipped to [(1 - delta0) * B, 1 + delta0] in the main loop
beta = np.ones(n)        # current pacing multipliers
beta_ave = np.zeros(n)   # running average of beta
g_ave = np.zeros(n)      # running average of the (sampled) subgradient / utilities
# u_running = np.zeros(n) # to collect utilities along the way
# record stuff across time
# (np.int / np.bool were removed in NumPy 1.24; the builtin types are equivalent)
items_all_t = np.zeros(T, dtype=int) # j(t) sampled uniformly at random from {0, 1, ..., m-1}
winners_all_t = np.zeros(T, dtype=int) # i(t) = min of argmax over i of beta[i] * v[i, j(t)]
# spending[i] := cumulative spending of buyer i
# it gets incremented by beta[t,i] * v[i,j] if j = j(t) is sampled at time t and i = i(t) is the winner

# some are logged in every t (error norms in variables)
# some are only periodically (dgap and envy gap)
spending = np.zeros(n)
inf_norm_to_u_eq, inf_norm_to_beta_eq, inf_norm_to_B = [], [], []
ave_one_norm_to_u_eq, ave_one_norm_to_beta_eq, ave_one_norm_to_B = [], [], []
# log_interval = int(T//num_logs)
# duality_gap, max_envy, ave_envy = [], [], []

x_cumulative = np.zeros((n, m))  # allocation counts; divided by T after the loop
x_proportional = (B * np.ones(shape=(n, m)).T).T / m

# x_cumulative = x_proportional
first_budget_out = False
# Round at which some buyer is first excluded by the budget cap. Initialized here
# so meta_data can always reference it (None when the cap never binds, e.g.
# budget_cap = False); previously it was only assigned inside the loop.
first_budget_out_time = None
# Main loop. Each round: allocate item(s) to the highest bid beta[i] * v[i, j],
# fold the winners' values into the running average g_ave, then set
# beta = B / g_ave clipped to [(1 - delta0) * B, 1 + delta0].
log_every = max(1, T // 20)  # progress-print period; max() avoids modulo by zero when T < 20
for t in range(1, T+1):
    if do_stochastic:
        # sample an item
        j = np.random.choice(m)
        items_all_t[t-1] = j
        # remove buyers whose bid would push cumulative spending past B * T
        if budget_cap:
            has_budget = spending + beta * v[:, j] <= B * T
            if np.sum(has_budget) < n and not first_budget_out:
                first_budget_out_time, first_budget_out = t, True
        else:
            has_budget = np.ones(n, dtype=bool)
        # decay the running average; at t = 1 seed it with 1/n so B / g_ave stays finite
        g_ave = (t-1) * g_ave / t if t > 1 else np.ones(n) / n
        eligible = np.flatnonzero(has_budget)
        if eligible.size > 0:
            # lex. smallest highest bidder among eligible buyers. argmax indexes
            # into `eligible`, so map it back to the global buyer index.
            winner = eligible[np.argmax(beta[eligible] * v[eligible, j])]
            winners_all_t[t-1] = winner
            spending[winner] += beta[winner] * v[winner, j] # option 1: use beta(t) to compute prices
            # u_running[winner] += v[winner, j] # winner gets its v[winner, j]
            # only the winner's entry of the sampled subgradient is nonzero; no 1/m factor
            # here because item j was sampled with probability 1/m
            g_ave[winner] += v[winner, j] / t
            x_cumulative[winner, j] += 1
        else:
            # no buyer can afford the item: leave it unallocated
            winners_all_t[t-1] = -1
    else: # find the full subgradient: every item is allocated each round
        winners = np.argmax(beta * v.T, 1) # winners[j] wins item j
        g_ave = (t-1) * g_ave / t
        for j, winner in enumerate(winners):
            # u_running[winner] += v[winner, j]/m # winners collect their rewards: same as g_ave
            g_ave[winner] += (v[winner, j]/m) / t
        # record every item's allocation (each has supply 1/m), not just the last loop item
        x_cumulative[winners, np.arange(m)] += 1 / m
    # update beta
    beta = np.maximum((1-delta0) * B, np.minimum(1 + delta0, B / g_ave)) # spending[winner] += beta[winner] * v[winner, j] # option 2: use beta(t+1) to compute prices
    beta_ave = (t-1) * beta_ave / t + beta / t
    # logging: errors relative to the offline equilibrium (u_opt, beta_opt) and budgets B
    inf_norm_to_u_eq.append(np.max(np.abs(g_ave - u_opt)/u_opt)) # relative to each u_opt
    inf_norm_to_beta_eq.append(np.max(np.abs(beta - beta_opt)/beta_opt))
    inf_norm_to_B.append(np.max(np.abs(B - spending/t)/B))
    ave_one_norm_to_u_eq.append(np.mean(np.abs(g_ave - u_opt)/u_opt))
    ave_one_norm_to_beta_eq.append(np.mean(np.abs(beta - beta_opt)/beta_opt))
    ave_one_norm_to_B.append(np.mean(np.abs(B - spending/t)/B))
    if t % log_every == 0:
        print('t = {}, max_beta_error = {:.4f}, max_u_error = {:.4f},  max_b_error = {:.4f}'.format(t, inf_norm_to_beta_eq[-1], inf_norm_to_u_eq[-1], inf_norm_to_B[-1]))
    
# Per-buyer ratio of the running-average utility g_ave to the equilibrium
# utility u_opt; values close to 1 mean the online utilities match equilibrium.
res = g_ave / u_opt
# arguments ordered to match the "max and min" label (they were swapped)
print('max and min of g_ave/T divided by u_opt[i]: {:.4f}, {:.4f}'.format(np.max(res), np.min(res)))

# construct final empirical x (allocation counts averaged over the T rounds)
x = x_cumulative / T

# u_time_ave = np.sum(v * x, 1) # it is u_running / T
# b_eq_all = [p_opt.T @ x[i] for i in range(n)]

# np.linalg.norm(u_time_ave - u_opt, 1) / np.linalg.norm(u_opt)

# # final envy: <v[i], x[k]> / B[k]
# umat_budget_scaled = np.array([[v[i].T @ x[k] / B[k] for k in range(n)] for i in range(n)]) # umat[2,3] - v[2].T @ x[3] == 0
# envy_final = np.max(umat_budget_scaled, 1) - np.diag(umat_budget_scaled)
# imax = np.argmax(envy_final)
# print('max envy = {} through buyer i={}'.format(envy_final[imax], imax))
# res = np.array(spending)/T

# Plotting setup: load pyplot and apply the seaborn theme globally.
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme()

# # inf_norm_to_u_eq, inf_norm_to_beta_eq, duality_gap = np.array(inf_norm_to_u_eq), np.array(inf_norm_to_beta_eq), np.array(duality_gap)
# plt.plot(range(1, T+1),  inf_norm_to_u_eq, label = r'$||\bar{g}^t - u^*||_\infty$', linestyle='solid')
# plt.plot(range(1, T+1),  inf_norm_to_beta_eq, label = r'$||\beta^t - \beta^*||_\infty$', linestyle='dashed')
# plt.plot(range(1, T+1, log_interval), duality_gap, label = r'${\rm dgap}_t$', linestyle='dashed')
# # plt.plot(range(1, T+1, log_interval), max_envy, label = r'$\|\rho^t\|_\infty$', linestyle='dashdot')
# # plt.plot(range(1, T+1, log_interval), ave_envy, label = r'$\|\rho^t\|_1/n$', linestyle='dashed')
# plt.yscale('log'), plt.xscale('log')
# plt.xlabel('t')
# plt.title('{}, n = {}, m = {}'.format(dataset, n, m))
# plt.legend()
# plt.savefig('g-ave-and-beta-conv.pdf')

# save results: per-round error traces (gzipped text) plus a JSON metadata file
import pandas as pd
import json
fpath = os.path.join('results', dataset.lower(), '{}sd-{}'.format(prefix, seed))
print('fpath = {}'.format(fpath))
os.makedirs(fpath, exist_ok=True)
# np.savetxt(os.path.join(fpath, 'duality_gap'), duality_gap, fmt='%.4e')
# file name -> per-round series, written in this order
traces_to_save = {
    'inf_norm_to_beta_eq.gz': inf_norm_to_beta_eq,
    'ave_one_norm_to_beta_eq.gz': ave_one_norm_to_beta_eq,
    'inf_norm_to_u_eq.gz': inf_norm_to_u_eq,
    'ave_one_norm_to_u_eq.gz': ave_one_norm_to_u_eq,
    'inf_norm_to_B.gz': inf_norm_to_B,
    'ave_one_norm_to_B.gz': ave_one_norm_to_B,
}
for trace_name, trace in traces_to_save.items():
    np.savetxt(os.path.join(fpath, trace_name), trace, fmt='%.4e')
# np.savetxt(os.path.join(fpath, 'ave_envy'), ave_envy, fmt='%.4e')
# np.savetxt(os.path.join(fpath, 'max_envy'), max_envy, fmt='%.4e')
# NOTE(review): first_budget_out_time is assigned inside the main loop only when
# the budget cap binds; it must be initialized before the loop for this to work.
meta_data = {
    'T': T,
    'dataset': dataset,
    'n': n,
    'm': m,
    # 'number of duality gap and envy computations (num_logs)': num_logs,
    'seed': seed,
    'delta0': delta0,
    'varying_budgets': varying_budgets,
    'first_budget_out_time': first_budget_out_time,
}
with open(os.path.join(fpath, 'meta_data'), 'w') as mdff:
    json.dump(meta_data, mdff, indent=4)