import random
import sys
import time

import numpy as np

from helper_FORM import *

# Parameters forwarded to solve_fair_recommendation_problem as `delta_item` /
# `delta_user` (presumably the allowed slack on the item-side and user-side
# fairness constraints -- confirm against helper_FORM).
delta_item, delta_user = 0.2, 0.2

# Index of the trial being run; it is printed at start-up and offsets the RNG seed.
idx_trial = 0

#################################################################
###################### DEFINE THE INSTANCE ######################
#################################################################

# Instance dimensions:
#   N -- first axis of the (N, M) estimate arrays built later (items),
#   M -- number of user types sampled by np.random.choice(M, ...),
#   K -- passed to the helper routines as the assortment-size argument.
N, M, K = 30, 5, 3

# Which item-side outcome the helpers evaluate; forwarded as `obj=` / `item_obj=`.
item_objective = "visibility"

# Ground-truth instance, read from .npy files in the working directory.
# revenue: handed unchanged to the helper routines (shape not visible here --
# presumably one entry per item; confirm against helper_FORM).
revenue = np.load("amazon_revenue.npy")
# prob_arrival: used as the sampling distribution over the M user types.
prob_arrival = np.load("amazon_prob_arrival.npy")
# popularity_weight: later perturbed with (N, M) noise, so it must broadcast
# against an (N, M) array (presumably item x user-type weights).
popularity_weight = np.load("amazon_popularity_weight.npy")

#################################################################
###################### START EXPERIMENTS ########################
#################################################################

time_start = time.time()

# Horizon: number of user arrivals simulated in this trial.
T = 2000

# Per-round metrics of the policy in force, appended once per round by the main loop.
expected_revenues_lst = []
item_outcomes_lst = []
user_outcomes_lst = []

print ("Start trial {}".format(idx_trial))

# Seed BOTH random number generators. The script draws from numpy
# (np.random.choice / np.random.random below) and from the stdlib `random`
# module (random.random / random.choices in the main loop). Seeding only numpy
# left the exploration coin flips and the assortment draws unreproducible.
np.random.seed(200 + idx_trial)
random.seed(200 + idx_trial)

# Pre-sampled sequence of arriving user types, one per round.
user_seq = np.random.choice(M, size = T, p = prob_arrival)

# Warm-start estimate: true weights plus uniform noise in [-0.1, 0.1), encoded
# as if every (item, user type) pair had already been observed for 200 epochs.
# NOTE(review): if a true weight is below 0.1 the noisy estimate (and hence the
# initial w_hat) can be negative -- confirm this is acceptable for the solvers.
popularity_weight_est = popularity_weight + (np.random.random((N,M)) * 0.2 - 0.1)
num_of_epochs = np.ones((N,M)) * 200
purchase_count_total = popularity_weight_est * 200
purchase_count_total = purchase_count_total.astype(int)  # truncates toward zero
w_hat = purchase_count_total/num_of_epochs

# Arrival-rate estimate starts uniform over the M user types.
p_hat = np.ones(M) * 1/M

# Not referenced again in the visible part of the script; kept in case later code uses it.
expected_revenues = []

# Round indices 1..T, used to build the per-round schedules below.
ts = np.arange(1,T+1)
m_t = np.maximum(1,ts/np.log(T) - np.sqrt(ts * np.log(T)/2))

# confidence bound established for y-estimates and p-estimates
# (ell_yt is computed here but not read again in the visible part of the script)
ell_yt = 2*np.log(T)/np.sqrt(m_t)
ell_pt = 5*np.sqrt(np.log(T)/ts)

# confidence bound we maintain for our constrained optimization problem
# eta_t[t] is passed to the solver as `eta`; epsilon_t[t] * N is the per-round
# exploration probability used by the main loop.
eta_t = np.log(T) * ell_pt /(10**7)
epsilon_t = np.minimum(1/N/M, np.power(N*M,-2/3) * np.power(ts,-1/20))

#################################################

# Round counter for the main loop.
t_current = 0

# Initial policy, indexed by user type in the main loop (each entry is used as
# a mapping assortment -> probability there).
x_t = random_assortments(popularity_weight, K)

# Assortment shown in the current round.
offered_assortment = None

# Assortment currently locked in for each user type until its epoch ends.
current_assortments = [None] * M

# flag[u] is True when user type u needs a freshly drawn assortment on its next arrival.
flag = [True] * M

# Number of arrivals observed so far for each user type.
arrival_count_total = np.zeros(M)

# Main online loop: one iteration per arriving user, for T rounds in total.
while t_current < T:

    # Log how the policy currently in force performs on the TRUE instance
    # (true weights and arrival probabilities) before this round changes anything.
    expected_revenues_lst.append(
        calculate_expected_revenue(x_t, popularity_weight, prob_arrival, revenue)
    )
    item_outcomes_lst.append(
        compute_item_outcomes(x_t, popularity_weight, prob_arrival, revenue, obj=item_objective)
    )
    user_outcomes_lst.append(compute_user_outcomes(x_t, popularity_weight))

    if t_current % 100 == 0:
        print ("Round {}".format(t_current))

    user_type = user_seq[t_current]

    if not flag[user_type]:
        # The epoch of this user type is still running: keep showing the same assortment.
        offered_assortment = current_assortments[user_type]
    else:
        # A new epoch starts for this user type, so its assortment is re-drawn.
        explore = random.random() < epsilon_t[t_current] * N
        if explore:
            # Exploration: only this user type's entry of the policy is replaced.
            x_t[user_type] = random_sizeK_assortments(w_hat, K)
        else:
            # Exploitation: re-solve the fair recommendation problem on the
            # current estimates (w_hat, p_hat); the result replaces the whole policy.
            # (compute_item_KS_fair(w_hat, p_hat, revenue, K, obj=item_objective)
            # is the alternative item-fair benchmark left disabled in this script.)
            item_fair_policy = compute_item_maxmin_fair(w_hat, p_hat, revenue, K, obj=item_objective)
            item_fair_targets = compute_item_outcomes(
                item_fair_policy, w_hat, p_hat, revenue, obj=item_objective
            )

            user_fair_policy = compute_user_fair(w_hat, K)
            user_fair_targets = compute_user_outcomes(user_fair_policy, w_hat)

            # NOTE(review): the returned status is never inspected -- confirm the
            # solver always hands back a usable policy.
            solver_status, x_t, solver_revenue = solve_fair_recommendation_problem(
                w_hat,
                p_hat,
                revenue,
                item_fair_targets,
                user_fair_targets,
                K,
                delta_item=delta_item,
                delta_user=delta_user,
                eta=eta_t[t_current],
                item_obj=item_objective,
            )

        # Sample one assortment from this user type's distribution over assortments.
        type_policy = x_t[user_type]
        candidate_assortments = list(type_policy.keys())
        candidate_weights = list(type_policy.values())
        offered_assortment = random.choices(
            candidate_assortments, weights=candidate_weights, k=1
        )[0]

        # Lock the assortment in for this user type until the epoch ends.
        current_assortments[user_type] = offered_assortment
        flag[user_type] = False

        # Every item of the offered assortment takes part in one more epoch.
        num_of_epochs[offered_assortment, user_type] += 1

    # Simulate the user's choice under the true weights; -1 encodes "no purchase".
    purchased_item = draw_purchase_decision(offered_assortment, popularity_weight, user_type)

    if purchased_item != -1:
        # A purchase keeps the epoch alive; just count it.
        purchase_count_total[purchased_item, user_type] += 1
    else:
        # No purchase ends the epoch: request a fresh assortment on the next
        # arrival of this type and refresh the weight estimates of the shown items.
        flag[user_type] = True
        for shown_item in offered_assortment:
            w_hat[shown_item, user_type] = (
                purchase_count_total[shown_item, user_type]
                / num_of_epochs[shown_item, user_type]
            )

    # Advance the round, then refresh the empirical arrival-rate estimate
    # (arrivals per type divided by the number of rounds completed so far).
    arrival_count_total[user_type] += 1
    t_current += 1
    p_hat = arrival_count_total/t_current

# Freeze the per-round logs into numpy arrays.
expected_revenues_lst, item_outcomes_lst, user_outcomes_lst = (
    np.array(series)
    for series in (expected_revenues_lst, item_outcomes_lst, user_outcomes_lst)
)

# Report the wall-clock duration of the trial (setup + main loop + conversion).
time_end = time.time()
elapsed = time_end - time_start
print ("Total time taken: {}".format(elapsed))

# Persist each log to its own .npy file in the working directory.
for out_path, series in (
    ("expected_revenue_lst.npy", expected_revenues_lst),
    ("item_outcomes_lst.npy", item_outcomes_lst),
    ("user_outcomes_lst.npy", user_outcomes_lst),
):
    np.save(out_path, series)
