#Stochastic choice
import wandb
import numpy as np
import torch
from tqdm import tqdm
# Latent-factor matrices produced elsewhere (presumably one row per user /
# per item -- TODO confirm against the training script), plus their lengths.
user_vectors, item_vectors = (
    np.load(path) for path in ('uservectors.npy', 'itemvectors.npy')
)
n_users, n_items = len(user_vectors), len(item_vectors)

def get_recommendations(user_vector, already_rated, k=10, item_matrix=None):
    """Return the ids of the top-``k`` items the user has not rated yet.

    Every item is scored with the dot product of its factor row and
    ``user_vector``; items are ranked by descending score, and ties keep
    ascending item-id order (stable sort).  Ids found in ``already_rated``
    are skipped.

    Args:
        user_vector: 1-D tensor of the user's latent factors.
        already_rated: iterable of item ids (row indices) to exclude.
        k: maximum number of ids to return (default 10).
        item_matrix: optional item-factor matrix (array or tensor, one row
            per item); defaults to the module-level ``item_vectors``.

    Returns:
        A list of at most ``k`` int item ids, best first.  It is shorter
        than ``k`` when fewer unrated items exist.
    """
    if item_matrix is None:
        item_matrix = item_vectors
    user = torch.as_tensor(user_vector)
    # One matrix-vector product instead of a Python loop over all items.
    scores = torch.as_tensor(item_matrix, dtype=user.dtype) @ user
    rated = set(already_rated)  # O(1) membership instead of a list scan
    order = torch.sort(scores, descending=True, stable=True).indices.tolist()
    return [item for item in order if item not in rated][:k]

def get_recommendation_scores(user_vector, already_rated, k=10, item_matrix=None):
    """Return the top-``k`` unrated item ids together with their scores.

    Items are scored with the dot product of their factor row and
    ``user_vector`` and ranked by descending score; ties keep ascending
    item-id order.  Ids found in ``already_rated`` are skipped.

    Args:
        user_vector: 1-D tensor of the user's latent factors; gradients
            flow from the returned scores back into it.
        already_rated: iterable of item ids (row indices) to exclude.
        k: maximum number of items to return (default 10).
        item_matrix: optional item-factor matrix (array or tensor, one row
            per item); defaults to the module-level ``item_vectors``.

    Returns:
        ``(item_ids, scores)``: a list of at most ``k`` int ids, best first,
        and a 1-D float64 tensor holding exactly the matching scores.  Both
        have the same length.  There is no placeholder padding when fewer
        than ``k`` unrated items exist, so every score is a real prediction.
    """
    if item_matrix is None:
        item_matrix = item_vectors
    user = torch.as_tensor(user_vector)
    scores = torch.as_tensor(item_matrix, dtype=user.dtype) @ user
    rated = set(already_rated)
    order = torch.sort(scores, descending=True, stable=True).indices.tolist()
    top_items = [item for item in order if item not in rated][:k]
    # Index the differentiable score vector so autograd still reaches user_vector.
    top_scores = scores[torch.tensor(top_items, dtype=torch.long)].to(torch.float64)
    return top_items, top_scores

# UserType: simulated user behaviour -- which recommended item is chosen
# (item_choice) and the rating it is given (item_rating).
import random
class UserType:
    """Simulated user that picks one recommended item and rates it.

    The behaviour named by ``user_type`` is applied once, at construction,
    and its outcome is stored on the instance:

    * ``item_choice`` -- position in ``rec_scores`` (the recommendation
      list), not an item id;
    * ``item_rating`` -- the rating given to that item (5 or 1).

    Supported ``user_type`` values:

    * ``"Enjoyer"`` / ``"Hater"`` -- always position 0, rating 5 / 1.
    * ``"Random Enjoyer"`` / ``"Random Hater"`` -- uniformly random position
      (``random`` module), rating 5 / 1.
    * ``"Choice Enjoyer"`` -- position sampled from ``softmax(beta * scores)``
      with ``torch.multinomial``; rating is the float64 tensor ``5``.
    * ``"Choice Hater"`` -- position sampled from the same softmax with
      ``np.random.choice``; rating 1.
    * ``"Popular Enjoyer"`` / ``"Niche Enjoyer"`` -- softmax choice as for
      ``"Choice Hater"``; the rating then depends on
      ``popularities[item_choice]`` versus ``popularity_threshold``.

    Args:
        user_type: one of the names above; anything else raises ValueError.
        rec_scores: scores of the recommended items (list, ndarray or tensor).
        popularities: indexable by ``item_choice``; read only by the
            popularity-based types.
        beta: softmax inverse temperature (0 gives a uniform choice).
    """

    # Popularity cut-off used by the "Popular Enjoyer" / "Niche Enjoyer" types.
    popularity_threshold = 500

    def __init__(self, user_type, rec_scores, popularities, beta=0):
        self.user_type = user_type
        self.rec_scores = rec_scores
        self.popularities = popularities
        self.beta = beta
        self.set_user_type_values()

    def _random_index(self):
        """Return a uniformly random position in ``rec_scores``."""
        return random.randint(0, len(self.rec_scores) - 1)

    def _softmax_probabilities(self):
        """Return ``softmax(beta * rec_scores)`` as a float64 numpy array.

        Tensors are detached first, because ``np.array`` refuses
        grad-tracking tensors.  The max logit is subtracted before ``exp``
        so that a large ``beta * score`` cannot overflow to inf/NaN.
        """
        scores = self.rec_scores
        if isinstance(scores, torch.Tensor):
            scores = scores.detach().cpu().numpy()
        logits = self.beta * np.asarray(scores, dtype=np.float64)
        logits = logits - logits.max()
        probabilities = np.exp(logits)
        return probabilities / probabilities.sum()

    def _softmax_index(self):
        """Sample a position from the softmax using ``np.random.choice``."""
        probabilities = self._softmax_probabilities()
        return int(np.random.choice(len(probabilities), p=probabilities))

    def set_user_type_values(self):
        """Apply the ``user_type`` behaviour, setting ``item_choice``/``item_rating``.

        Raises:
            ValueError: if ``user_type`` is not a supported name.
        """
        user_type = self.user_type
        if user_type == "Enjoyer":
            self.item_choice = 0
            self.item_rating = 5
        elif user_type == "Hater":
            self.item_choice = 0
            self.item_rating = 1
        elif user_type == "Random Enjoyer":
            self.item_choice = self._random_index()
            self.item_rating = 5
        elif user_type == "Random Hater":
            self.item_choice = self._random_index()
            self.item_rating = 1
        elif user_type == "Choice Enjoyer":
            # as_tensor also accepts plain lists, which torch.mul would reject;
            # torch.softmax is numerically stable on its own.
            scores = torch.as_tensor(self.rec_scores, dtype=torch.float64)
            probabilities = torch.softmax(scores * self.beta, dim=0)
            self.item_choice = torch.multinomial(probabilities, 1).item()
            self.item_rating = torch.tensor(5, dtype=torch.float64)
        elif user_type == "Choice Hater":
            self.item_choice = self._softmax_index()
            self.item_rating = 1
        elif user_type == "Popular Enjoyer":
            self.item_choice = self._softmax_index()
            # NOTE(review): rates 5 when popularity is *below* the threshold.
            # That looks inverted for a "popular" enjoyer unless popularities
            # are ranks (lower = more popular) -- confirm with the data.
            if self.popularities[self.item_choice] < self.popularity_threshold:
                self.item_rating = 5
            else:
                self.item_rating = 1
        elif user_type == "Niche Enjoyer":
            self.item_choice = self._softmax_index()
            # NOTE(review): mirror image of "Popular Enjoyer"; same
            # threshold-direction question applies.
            if self.popularities[self.item_choice] < self.popularity_threshold:
                self.item_rating = 1
            else:
                self.item_rating = 5
        else:
            raise ValueError("Invalid user type. Please choose a valid user type.")
        
def update_user_tensor(user_vector, items, ratings, item_matrix=None):
    """Fit a user factor vector to observed ratings by least squares.

    Solves the normal equations ``(Q^T Q) p = Q^T r``.  Row ``i`` of ``Q``
    is the factor vector of ``items[i]``, and ``r`` is ``ratings``.  The
    result stays differentiable with respect to ``ratings``.

    Args:
        user_vector: unused; kept for call-site compatibility (the fit does
            not start from the previous vector).
        items: sequence of item ids (row indices into the item matrix).
        ratings: 1-D tensor of ratings aligned with ``items``; cast to
            float64 so float32 inputs no longer fail on dtype mismatch.
        item_matrix: optional item-factor matrix (one row per item);
            defaults to the module-level ``item_vectors``.

    Returns:
        1-D float64 tensor ``p``, one entry per latent dimension.

    Raises:
        ValueError: if ``items`` is empty.
        torch.linalg.LinAlgError: (a RuntimeError subclass) if ``Q^T Q`` is
            singular, e.g. fewer rated items than latent dimensions.
    """
    if item_matrix is None:
        item_matrix = item_vectors
    items = list(items)
    if not items:
        raise ValueError("update_user_tensor needs at least one rated item")
    # Stack into one ndarray first; torch.tensor on a list of ndarrays is slow.
    rows = np.asarray([np.asarray(item_matrix[item]) for item in items], dtype=np.float64)
    Q = torch.from_numpy(rows)
    r = torch.as_tensor(ratings).to(torch.float64)
    # solve() is more accurate than forming the explicit inverse, and is differentiable.
    return torch.linalg.solve(Q.T @ Q, Q.T @ r)

import pickle
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load these artefacts if they come from a trusted source.
with open('items_rated', 'rb') as file:
    data = pickle.load(file)

# User 5 is the user whose history is being studied.  Take the last
# (item, rating) pair of that user's entry (presumably the most recent one).
chosen_item = data[5][-1][0]
chosen_rating = data[5][-1][1]

with open('ui_rating_dict', 'rb') as file:
    user_item_rating_dict = pickle.load(file)

already_rated = list(user_item_rating_dict[5].keys())
ratings = user_item_rating_dict[5]

# Convert every rating to a float64 scalar tensor in place.  `ratings` is the
# same dict object as user_item_rating_dict[5], so that entry changes too.
for item_id, value in ratings.items():
    ratings[item_id] = torch.tensor(value, dtype=torch.float64)

rating_tensor = torch.tensor(list(ratings.values()))

ratings_dict = ratings

# wandb.init(project="reachability-single-stochastic", name="run_1_lr_0.08")
# user_action = torch.tensor(chosen_rating, requires_grad=True, dtype=torch.float64)
# optimizer = torch.optim.Adam([user_action], lr=0.08)
# item_to_be_reached = 19
# min_score = float('inf')
# for epoch in range(1, 20):
#     user_action_clamped = user_action.clamp(1, 5) 
#     rating_tensor = torch.tensor(list(ratings_dict.values()))
#     user_vector_initial = torch.from_numpy(user_vectors[5])
#     user_vector = user_vector_initial
#     time_max = 5
#     already_rated = list(user_item_rating_dict[5].keys())
#     #print(already_rated)
#     #already_rated = already_rated[:-1]
#     ratings_old = rating_tensor#user_item_rating_dict[5]
#     # del ratings[chosen_item]
#     ratings_old[-1] = user_action_clamped#ratings[chosen_item] = user_action
#     # ratings_old_temp = torch.cat((ratings_old[:-1], user_action), dim=0)
#     n = len(ratings_old)
#     zeros_to_add = torch.zeros(time_max)
#     ratings = torch.cat((ratings_old, zeros_to_add), dim=0)
#     #print(ratings)
#     #already_rated.append(chosen_item)
#     user_vector = update_user_tensor(user_vector, already_rated, ratings[:n])
#     for timestep in range(1, time_max+1):
#         recommendations, recommendation_scores = get_recommendation_scores(user_vector, already_rated)
#         #print(recommendation_scores)
#         min_score = min(recommendation_scores)
#         usr_type = UserType("Choice Enjoyer", recommendation_scores, [], beta=0.8)
#         choice_of_item = recommendations[usr_type.item_choice]
#         ratings[n+timestep-1] = torch.tensor(usr_type.item_rating)
#         already_rated.append(choice_of_item)
#         user_vector = update_user_tensor(user_vector, already_rated, ratings[:n+timestep])
#     item_rating = -torch.matmul(user_vector, torch.from_numpy(item_vectors[item_to_be_reached]))
#     #print(abs(item_rating))
#     wandb.log({
#         'abs_item_rating': abs(item_rating.item()),
#         'user_action': user_action.tolist(),
#         'user_action_clamped': user_action_clamped.tolist()
#     })
#     #print(min_score-abs(item_rating))
#     if abs(item_rating)>min_score:
#         print("Item Reached")
#         break
#     #item_rating.requires_grad_()
#     #user_action.retain_grad()
#     # print(user_action.is_leaf)
#     item_rating.backward()
#     optimizer.step()
#     # print(user_action.grad)
#     #print(user_action)
#     optimizer.zero_grad()
#     #user_action = user_action.clamp(1, 5) 

# Stochastic choice with averaging: each epoch's objective is averaged over
# several sampled rollouts.
# import wandb

wandb.init(project="reachability-single-stochastic_averaging", name="run_1_lr_0.08")
user_action = torch.tensor(chosen_rating, requires_grad=True, dtype=torch.float64)
optimizer = torch.optim.Adam([user_action], lr=0.08)
item_to_be_reached = 19
min_score = float('inf')
n_samples = 20  # sampled rollouts averaged into each epoch's objective
# NOTE(review): with time_max = 0 the rollout loop below never runs, so all
# samples are identical and min_score stays inf, which means the
# "Item Reached" check can never fire.  Raise it to simulate recommendation
# rounds.
time_max = 0

# Loop-invariant inputs, built once instead of once per sample.
rating_tensor = torch.tensor(list(ratings_dict.values()))
user_vector_initial = torch.from_numpy(user_vectors[5])
target_item_vector = torch.from_numpy(item_vectors[item_to_be_reached])

for epoch in tqdm(range(1, 50)):
    sample_objectives = []
    for int_var in range(n_samples):
        user_action_clamped = user_action.clamp(1, 5)
        already_rated = list(user_item_rating_dict[5].keys())
        # Fresh copy per sample, because the last rating is overwritten in place
        # with the optimised user action.
        ratings_old = rating_tensor.clone()
        ratings_old[-1] = user_action_clamped
        n = len(ratings_old)
        # Extra slots for the ratings produced during the simulated rounds.
        ratings = torch.cat((ratings_old, torch.zeros(time_max)), dim=0)
        user_vector = update_user_tensor(user_vector_initial, already_rated, ratings[:n])
        for timestep in range(1, time_max + 1):
            recommendations, recommendation_scores = get_recommendation_scores(user_vector, already_rated)
            min_score = min(recommendation_scores)
            usr_type = UserType("Choice Enjoyer", recommendation_scores, [], beta=0.8)
            choice_of_item = recommendations[usr_type.item_choice]
            # item_rating may already be a tensor; as_tensor avoids a warned copy.
            ratings[n + timestep - 1] = torch.as_tensor(usr_type.item_rating)
            already_rated.append(choice_of_item)
            user_vector = update_user_tensor(user_vector, already_rated, ratings[:n + timestep])
        # Negated target-item score: minimising it raises the predicted score.
        item_rating = -torch.matmul(user_vector, target_item_vector)
        sample_objectives.append(item_rating)
    # Average in float64; the former float32 buffer truncated the objective.
    fin_rating = torch.stack(sample_objectives).mean()
    wandb.log({
        'abs_item_rating': abs(fin_rating.item()),
        'user_action': user_action.tolist(),
        'user_action_clamped': user_action_clamped.tolist()
    })
    if abs(fin_rating) > min_score:
        print("Item Reached")
        break
    fin_rating.backward()
    optimizer.step()
    optimizer.zero_grad()


# #Deterministic choice
# wandb.init(project="reachability-single-deterministic", name="run_1_lr_0.08")
# user_action = torch.tensor(chosen_rating, requires_grad=True, dtype=torch.float64)
# optimizer = torch.optim.Adam([user_action], lr=0.08)
# item_to_be_reached = 8
# min_score = float('inf')
# for epoch in range(1, 50):
#     user_action_clamped = user_action.clamp(1, 5) 
#     rating_tensor = torch.tensor(list(ratings_dict.values()))
#     user_vector_initial = torch.from_numpy(user_vectors[5])
#     user_vector = user_vector_initial
#     time_max = 5
#     already_rated = list(user_item_rating_dict[5].keys())
#     #print(already_rated)
#     #already_rated = already_rated[:-1]
#     ratings_old = rating_tensor#user_item_rating_dict[5]
#     # del ratings[chosen_item]
#     ratings_old[-1] = user_action_clamped#ratings[chosen_item] = user_action
#     # ratings_old_temp = torch.cat((ratings_old[:-1], user_action), dim=0)
#     n = len(ratings_old)
#     zeros_to_add = torch.zeros(time_max)
#     ratings = torch.cat((ratings_old, zeros_to_add), dim=0)
#     #print(ratings)
#     #already_rated.append(chosen_item)
#     user_vector = update_user_tensor(user_vector, already_rated, ratings[:n])
#     for timestep in range(1, time_max+1):
#         recommendations, recommendation_scores = get_recommendation_scores(user_vector, already_rated)
#         #print(recommendation_scores)
#         min_score = min(recommendation_scores)
#         # usr_type = UserType("Choice Enjoyer", recommendation_scores, [], beta=0.8)
#         choice_of_item = recommendations[0]
#         ratings[n+timestep-1] = 5
#         already_rated.append(choice_of_item)
#         user_vector = update_user_tensor(user_vector, already_rated, ratings[:n+timestep])
#     item_rating = -torch.matmul(user_vector, torch.from_numpy(item_vectors[item_to_be_reached]))
#     print(abs(item_rating))
#     wandb.log({
#         'abs_item_rating': abs(item_rating.item()),
#         'user_action': user_action.tolist(),
#         'user_action_clamped': user_action_clamped.tolist()
#     })
#     #print(min_score-abs(item_rating))
#     if abs(item_rating)>min_score:
#         print("Item Reached")
#         break
#     #item_rating.requires_grad_()
#     #user_action.retain_grad()
#     # print(user_action.is_leaf)
#     item_rating.backward()
#     optimizer.step()
#     # print(user_action.grad)
#     #print(user_action)
#     optimizer.zero_grad()
#     #user_action = user_action.clamp(1, 5) 

#5 is the user we are looking at
# item_and_rating =  data[5][-5:]
# chosen_items = [i[0] for i in item_and_rating]
# chosen_ratings = [i[1] for i in item_and_rating]

# wandb.init(project="reachability-multi-history", name="run_1_lr_0.05")

# #Stochastic choice
# user_action = torch.tensor(chosen_ratings, requires_grad=True)
# optimizer = torch.optim.Adam([user_action], lr=0.5)
# item_to_be_reached = 13
# for epoch in range(1, 50):
#     user_action_clamped = user_action.clamp(1, 5)
#     rating_tensor = torch.tensor(list(ratings_dict.values()))
#     user_vector_initial = torch.from_numpy(user_vectors[5])
#     user_vector = user_vector_initial
#     time_max = 5
#     already_rated = list(user_item_rating_dict[5].keys())
#     already_rated = already_rated[:-5]
#     #print(already_rated)
#     #already_rated = already_rated[:-1]
#     ratings_old = rating_tensor[:-5]
#     n = len(ratings_old)
#     zeros_to_add = torch.zeros(time_max)
#     ratings = torch.cat((ratings_old, zeros_to_add), dim=0)
#     # del ratings[chosen_item]
#     for timestep in tqdm(range(0,time_max)):
#         curr_item = chosen_items[timestep]
#         ratings[n+timestep] = user_action_clamped[timestep]
#         already_rated.append(curr_item)
#         user_vector = update_user_tensor(user_vector, already_rated, ratings[:n+timestep+1])
#     # ratings[chosen_item] = user_action.item()
#     #print(ratings)
#     #already_rated.append(chosen_item)
#     # user_vector = torch.from_numpy(update_user_vector(user_vector, already_rated, np.array(list(ratings.values()))))
#     # for timestep in tqdm(range(1, time_max+1)):
#     #     recommendations, recommendation_scores = get_recommendation_scores(user_vector, already_rated)
#     #     usr_type = UserType("Choice Enjoyer", recommendation_scores, temp_popularity=[])
#     #     choice_of_item = recommendations[usr_type.item_choice]
#     #     ratings[choice_of_item] = usr_type.item_rating
#     #     already_rated.append(choice_of_item)
#     #     user_vector = torch.from_numpy(update_user_vector(user_vector, already_rated, np.array(list(ratings.values()))))
#     item_rating = -torch.matmul(user_vector, torch.from_numpy(item_vectors[item_to_be_reached]))
#     print(abs(item_rating))
#     wandb.log({
#         'abs_item_rating': abs(item_rating.item()),
#         'user_action_1': user_action.tolist()[0],
#         'user_action_2': user_action.tolist()[1],
#         'user_action_3': user_action.tolist()[2],
#         'user_action_4': user_action.tolist()[3],
#         'user_action_5': user_action.tolist()[4],
#         'user_action_clamped_1': user_action_clamped.tolist()[0],
#         'user_action_clamped_2': user_action_clamped.tolist()[1],
#         'user_action_clamped_3': user_action_clamped.tolist()[2],
#         'user_action_clamped_4': user_action_clamped.tolist()[3],
#         'user_action_clamped_5': user_action_clamped.tolist()[4]
#     })
#     recommendations, recommendation_scores = get_recommendation_scores(user_vector, already_rated)
#     if min(recommendation_scores)<item_rating:
#         print("Item Reached")
#         break
#     # item_rating.requires_grad_()
#     item_rating.backward()
#     optimizer.step()
#     # print(user_action.grad)
#     # print(user_action)
#     optimizer.zero_grad()
#     #user_action = user_action.clamp(1, 5) 
# wandb.finish()
# wandb.init(project="reachability-multi-future", name="run_3_lr_0.1")
# user_action = torch.tensor(chosen_ratings, requires_grad=True)
# optimizer = torch.optim.Adam([user_action], lr=0.1)
# item_to_be_reached = 10
# for epoch in range(1, 50):
#     user_action_clamped = user_action.clamp(1, 5)
#     rating_tensor = torch.tensor(list(ratings_dict.values()))
#     user_vector_initial = torch.from_numpy(user_vectors[5])
#     user_vector = user_vector_initial
#     time_max = 5
#     already_rated = list(user_item_rating_dict[5].keys())
#     already_rated = already_rated[:-5]
#     #print(already_rated)
#     #already_rated = already_rated[:-1]
#     ratings_old = rating_tensor[:-5]
#     n = len(ratings_old)
#     zeros_to_add = torch.zeros(time_max)
#     ratings = torch.cat((ratings_old, zeros_to_add), dim=0)
#     # del ratings[chosen_item]
#     item_rec_list=[]
#     for timestep in tqdm(range(0,time_max)):
#         recommendations, recommendation_scores = get_recommendation_scores(user_vector, already_rated)
#         curr_item = recommendations[0]
#         item_rec_list.append(curr_item)
#         ratings[n+timestep] = user_action_clamped[timestep]
#         already_rated.append(curr_item)
#         user_vector = update_user_tensor(user_vector, already_rated, ratings[:n+timestep+1])
#     print(item_rec_list)
#     # ratings[chosen_item] = user_action.item()
#     #print(ratings)
#     #already_rated.append(chosen_item)
#     # user_vector = torch.from_numpy(update_user_vector(user_vector, already_rated, np.array(list(ratings.values()))))
#     # for timestep in tqdm(range(1, time_max+1)):
#     #     recommendations, recommendation_scores = get_recommendation_scores(user_vector, already_rated)
#     #     usr_type = UserType("Choice Enjoyer", recommendation_scores, temp_popularity=[])
#     #     choice_of_item = recommendations[usr_type.item_choice]
#     #     ratings[choice_of_item] = usr_type.item_rating
#     #     already_rated.append(choice_of_item)
#     #     user_vector = torch.from_numpy(update_user_vector(user_vector, already_rated, np.array(list(ratings.values()))))
#     item_rating = -torch.matmul(user_vector, torch.from_numpy(item_vectors[item_to_be_reached]))
#     #print(abs(item_rating))
#     wandb.log({
#         'abs_item_rating': abs(item_rating.item()),
#         'user_action_1': user_action.tolist()[0],
#         'user_action_2': user_action.tolist()[1],
#         'user_action_3': user_action.tolist()[2],
#         'user_action_4': user_action.tolist()[3],
#         'user_action_5': user_action.tolist()[4],
#         'user_action_clamped_1': user_action_clamped.tolist()[0],
#         'user_action_clamped_2': user_action_clamped.tolist()[1],
#         'user_action_clamped_3': user_action_clamped.tolist()[2],
#         'user_action_clamped_4': user_action_clamped.tolist()[3],
#         'user_action_clamped_5': user_action_clamped.tolist()[4]
#     })
#     recommendations, recommendation_scores = get_recommendation_scores(user_vector, already_rated)
#     if min(recommendation_scores)<item_rating:
#         print("Item Reached")
#         break
#     # item_rating.requires_grad_()
#     item_rating.backward()
#     optimizer.step()
#     # print(user_action.grad)
#     # print(user_action)
#     optimizer.zero_grad()
#     #user_action = user_action.clamp(1, 5) 
# # wandb.finish()