from utils import compute_me_fin_dim
import numpy as np

# ---- experiment configuration ----
seed = 2022  # fixed seed so every run draws the same random numbers
np.random.seed(seed)
n = 100  # number of buyers (length of the budget vector B)
m = 300  # number of items (support size of the distribution s)
T = n * 200  # horizon: number of item arrivals per sample path
repeat = 10  # number of independent sample paths

# Load the valuation matrix. The archive must contain exactly one array
# (the single-element unpacking fails otherwise). Using the archive as a
# context manager closes the underlying file once the array is read.
with np.load('instances/v_mat_movielens.npz') as archive:
    v, = archive.values()
s = np.random.uniform(size=m) # fixed distribution
s = s / np.sum(s)  # normalize to a probability vector over the m items
v = (v.T/(v@s)).T  # rescale each row of v so that (v @ s)[i] == 1
B = np.ones(n)/n  # equal budgets summing to one

# Reference equilibrium of the underlying (unperturbed) instance (v, B, s).
# NOTE(review): x_uleq / p_uleq are presumably allocations and prices as
# returned by compute_me_fin_dim -- confirm against utils.
x_uleq, p_uleq = compute_me_fin_dim(v, B, s, max_iter=5000)
# per-buyer utility: row-wise sum of value times allocation
u_uleq = (v * x_uleq).sum(axis=1)
# per-buyer budget-to-utility ratio, the target that beta is compared against
beta_uleq = B / u_uleq

def sample_all_arrivals_iid_finite_item_set(s, T=1000, repeat=10, m=None):
    ''' sample i.i.d. item arrivals over a finite item set

        s:      probability vector over the items, or None for the
                uniform distribution
        T:      horizon (number of arrivals per sample path)
        repeat: number of independent sample paths
        m:      number of items; only needed when s is None (when s is
                given, the item count is len(s) and m is ignored)

        returns an integer array of shape (repeat, T) whose entries are
        item indices in {0, ..., num_items-1}

        raises ValueError if both s and m are None '''
    if s is None:
        # the uniform case cannot infer the item count from s
        if m is None:
            raise ValueError('m must be given when s is None')
        return np.random.choice(m, size=(repeat, T))
    return np.random.choice(len(s), size=(repeat, T), p=s)

def generate_perturb_instance(s, chunksize=m//100):
    ''' build a perturbed (non-stationary) instance from the base distribution s

        The result has one row per time step (the module-level horizon T is
        used). In each row, a window of `chunksize` consecutive item indices
        (wrapping around modulo len(s)) has its total probability mass moved
        onto the first index of the window; the window then advances by
        `chunksize` for the next time step.

        larger chunksize ==> more perturbation '''
    num_items = len(s)
    # independent copy of s for every time step
    perturbed = np.array([s] * T)
    offset = 0
    for row in perturbed:
        window = [k % num_items for k in range(offset, offset + chunksize)]
        mass = row[window].sum()
        # pile the window's mass on its first item, clear the others
        row[window[0]] = mass
        row[window[1:]] = 0
        offset += chunksize
    return perturbed

def sample_all_arrivals(instance, repeat=10):
    ''' draw item arrivals from a time-varying (non-iid) instance

        instance: 2-d array; row t is the item distribution at time t
        repeat:   number of independent sample paths

        returns an array with one row per sample path and one column per
        time step '''
    _, num_items = instance.shape
    draws_per_step = []
    for dist_t in instance:
        # `repeat` independent draws from the distribution of this time step
        draws_per_step.append(
            np.random.choice(num_items, size=repeat, p=dist_t)
        )
    # stacked as (time, repeat); transpose so each row is one sample path
    return np.transpose(np.array(draws_per_step))

def pace(v, B, items_sample_path, u_uleq, beta_uleq, delta0=0.05, verbose=True):
    ''' run pace on one sample path and track the errors

        v:                 (n, m) array of valuations, v[i, j] is the value
                           of buyer i for item j
        B:                 length-n array of budgets
        items_sample_path: sequence of item indices, one per time step
        u_uleq, beta_uleq: length-n reference utilities / multipliers that
                           the iterates are compared against
        delta0:            beta is clipped to [(1-delta0) * B, 1 + delta0]
        verbose:           if True, print the current errors about ten times
                           over the run

        returns (inf_norm_to_u_uleq, inf_norm_to_beta_uleq): two lists with
        one entry per time step holding the max relative error of the
        running average utility and of beta, respectively '''

    n, _ = v.shape
    T = len(items_sample_path)
    # at least 1, so that horizons shorter than 10 do not divide by zero
    log_every = max(T // 10, 1)

    beta = np.ones(n) # initial beta: minimizer of the regularizer
    g_ave = np.zeros(n)
    inf_norm_to_u_uleq, inf_norm_to_beta_uleq = [], []

    for t, j in enumerate(items_sample_path, start=1):
        # the item goes to the buyer with the largest beta-scaled value
        # (np.argmax picks the smallest index on ties)
        winner = np.argmax(beta * v[:, j])

        # running average of realized utilities; seeded with 1/n at t == 1
        g_ave = (t-1) * g_ave / t if t > 1 else np.ones(n) / n
        g_ave[winner] += v[winner, j] / t

        # update beta: budget / average utility, clipped to the allowed range
        beta = np.maximum((1-delta0) * B, np.minimum(1 + delta0, B / g_ave))

        # logging
        inf_norm_to_u_uleq.append(np.max(np.abs(g_ave - u_uleq)/u_uleq))
        inf_norm_to_beta_uleq.append(np.max(np.abs(beta - beta_uleq)/beta_uleq))

        if verbose and t % log_every == 0:
            print(
                't={}, max_beta_error={:.4f}, max_u_error={:.4f}'.format(
                    t, inf_norm_to_beta_uleq[-1], inf_norm_to_u_uleq[-1]
                )
            )

    return inf_norm_to_u_uleq, inf_norm_to_beta_uleq

# two perturbed instances; a larger chunk merges more items per time step
small_perturb_instance = generate_perturb_instance( s, chunksize=m//10 )
large_perturb_instance = generate_perturb_instance( s, chunksize=m//5 )

# TV distance to the base distribution s, averaged over the T time steps
delta_small = 0.5 * sum( np.linalg.norm(s-d, 1) for d in small_perturb_instance ) / T
delta_large = 0.5 * sum( np.linalg.norm(s-d, 1) for d in large_perturb_instance ) / T

# `repeat` sample paths per setting (driven by the module-level constant
# rather than a separately hard-coded 10)
iid_sample_paths = sample_all_arrivals_iid_finite_item_set(s, T=T, repeat=repeat)
small_perturb_sample_paths = sample_all_arrivals(small_perturb_instance, repeat=repeat)
large_perturb_sample_paths = sample_all_arrivals(large_perturb_instance, repeat=repeat)

# run pace on these sample paths
# results[sd] maps each setting name to the (utility errors, beta errors)
# pair returned by pace for sample path number sd. The loop follows the
# number of sampled paths instead of a hard-coded count.
results = []
all_sample_paths = zip(
    iid_sample_paths, small_perturb_sample_paths, large_perturb_sample_paths
)
for sd, (iid_path, small_path, large_path) in enumerate(all_sample_paths):
    print('================== sd = {} =================='.format(sd))
    results.append(
        {
            'iid': pace( v, B, iid_path, u_uleq, beta_uleq ),
            'small_perturb': pace( v, B, small_path, u_uleq, beta_uleq ),
            'large_perturb': pace( v, B, large_path, u_uleq, beta_uleq )
        }
    )

# aggregate & plot
from matplotlib import pyplot as plt
import seaborn as sns
plt.clf()
sns.set_theme()
fig = plt.figure(figsize=(6, 4))
t0 = int(T//50)  # skip the first T/50 steps in the plot
skip_size = max(int(T//2000), 5)  # plot every skip_size-th time step
num_dp = (T - t0) // skip_size  # approximate number of plotted points
num_runs = len(results)  # number of sample paths actually run
sample_idx = np.arange(t0, T, skip_size)  # 0-based positions in the error lists

# plot beta convergence
# one (num_runs, T) array of beta errors per setting; index 1 selects the
# beta-error list from the pair returned by pace
all_data_arrays = tuple(
    np.array([run[key][1] for run in results])
    for key in ('iid', 'small_perturb', 'large_perturb')
)
all_labels = (
     'i.i.d.', 
     r'small perturb ($\delta={:.2f}$)'.format(delta_small), 
     r'large perturb ($\delta={:.2f}$)'.format(delta_large), 
)

for data_array, label in zip(all_data_arrays, all_labels):
    sampled = data_array[:, sample_idx]
    plt.errorbar(
        sample_idx + 1,  # time steps are 1-based
        np.mean(sampled, axis=0),
        # standard error of the mean across runs
        (1/np.sqrt(num_runs)) * np.std(sampled, axis=0),
        errorevery=num_dp//10,
        linestyle='solid',
        label=label,
    )
# vertical dotted guide lines at every multiple of n*20 within [t0, T]
guide_period = n * 20
first_guide = -(-t0 // guide_period) * guide_period  # smallest multiple >= t0
for pt in range(first_guide, T + 1, guide_period):
    plt.axvline(pt, linewidth=1.0, linestyle='dotted')
plt.xticks(range(0, T+1, T//5))
plt.xlabel('t')
plt.title(r'max$_i$ $|\beta_i^t - \beta^*_i| / \beta^*_i $')
plt.legend( prop={'size': 12}, loc='center right' )

plt.savefig('../plots/iid_vs_small_vs_large_perturb.pdf')
