# -*- coding: utf-8 -*-
# final version

import datetime

import numpy as np

import env.Environment as Envi
from algorithms import Base


#use one theta for all users
class LinUCB:
    """LinUCB runner that pools all users' statistics into one shared estimate.

    Every round the ``V`` and ``b`` statistics returned by each user's
    ``get_info()`` are summed, a single ridge estimate is computed from the
    pooled statistics, and the item with the highest upper-confidence score
    is recommended.

    Args:
        n: number of users in a server.
        d: feature dimension.
        T: total number of rounds to play.
        lambda_: coefficient of the identity matrix added to the pooled ``V``
            before inversion (keyword-only, defaults to the previous
            hard-coded value 1).
        alpha_: weight of the exploration bonus (keyword-only, defaults to
            the previous hard-coded value 0.5).
    """

    def __init__(self, n, d, T, *, lambda_=1, alpha_=0.5):
        self.usernum = n  # the number of users in a server
        self.rounds = T  # the number of all rounds
        self.d = d  # dimension
        self.regret = np.zeros(self.rounds)
        self.reward = np.zeros(self.rounds)
        self.best_reward = np.zeros(self.rounds)
        self.users = {}  # all users' information, keyed by user index
        self.init_users()  # initialize every user

        self.lambda_ = lambda_
        self.alpha_ = alpha_

        # Timing information, filled in by run().
        self.starttime = None
        self.endtime = None
        self.runtime = None

    def init_users(self):
        """Create one ``Base.User`` per user index in ``range(self.usernum)``."""
        for user_index in range(self.usernum):
            self.users[user_index] = Base.User(self.d, user_index, self.rounds)

    def _select_item(self, items, V, b):
        """Return the index of the item with the highest UCB score.

        Args:
            items: candidate feature vectors, one per row
                (assumed shape ``(num_items, d)`` -- confirm against the
                environment's ``get_items``).
            V: ``d x d`` statistic matrix to regularize and invert.
            b: length-``d`` statistic vector.

        Returns:
            The ``np.argmax`` index of the best-scoring row of ``items``.
        """
        M = np.eye(self.d) * self.lambda_ + V
        Minv = np.linalg.inv(M)
        theta_hat = np.dot(Minv, b)
        # Exploration bonus is the quadratic form x^T M^{-1} x of each item
        # (no square root is taken), scaled by alpha_.
        bonus = (np.matmul(items, Minv) * items).sum(axis=1)
        return np.argmax(np.dot(items, theta_hat) + self.alpha_ * bonus)

    # recommend the item to the user
    def recommend(self, user_index, items):
        """Pick an item using statistics pooled over all users.

        Args:
            user_index: index of the arriving user. Unused here because the
                estimate is shared; kept so subclasses can override with the
                same signature.
            items: candidate feature vectors, one per row.

        Returns:
            Index (into ``items``) of the recommended item.
        """
        V_t = np.zeros((self.d, self.d))
        b_t = np.zeros(self.d)
        for i in range(self.usernum):
            V, b, _ = self.users[i].get_info()
            V_t += V
            b_t += b
        return self._select_item(items, V_t, b_t)

    def start_time(self):
        """Return the datetime at which run() started (None before run())."""
        return self.starttime

    def run_time(self):
        """Return the duration of the last run() (None before run())."""
        return self.runtime

    def run(self, envir):
        """Play ``self.rounds`` rounds against ``envir`` and collect results.

        ``envir`` must provide ``generate_users()`` (index of the arriving
        user), ``get_items()`` (candidate items, indexable by the recommended
        index) and ``feedback(items=..., i=..., k=...)`` returning a 3-tuple
        that is stored as (reward, instant reward, best reward).

        Returns:
            dict with keys ``"runtime"``, ``"regret"``, ``"theta"``,
            ``"reward"``, ``"items_all_rounds"`` and ``"reward_all_rounds"``.
        """
        reward_all_rounds = []  # to save feedback in each round
        items_all_rounds = []  # to save the recommended item in each round
        self.starttime = datetime.datetime.now()
        for t in range(self.rounds):
            round_no = t + 1
            if round_no % 5000 == 0:
                print(round_no)  # coarse progress indicator
            user_index = envir.generate_users()  # random user arrives
            items = envir.get_items()
            r_item_index = self.recommend(user_index, items)
            selected_item = items[r_item_index]
            items_all_rounds.append(selected_item)
            # get feedback
            self.reward[t], instant_rwd, self.best_reward[t] = envir.feedback(
                items=items, i=user_index, k=r_item_index
            )
            reward_all_rounds.append(instant_rwd)
            # update the user's information
            self.users[user_index].store_info(
                selected_item, instant_rwd, t, self.reward[t], self.best_reward[t]
            )
            self.regret[t] = self.best_reward[t] - self.reward[t]

        # Save every user's final theta once, after the last round, instead of
        # testing for the last round on every iteration.
        theta_hat = {
            user_idx: self.users[user_idx].theta
            for user_idx in range(self.usernum)
        }

        self.endtime = datetime.datetime.now()
        self.runtime = self.endtime - self.starttime

        return {
            "runtime": self.runtime,
            "regret": self.regret,
            "theta": theta_hat,
            "reward": self.reward,
            "items_all_rounds": items_all_rounds,
            "reward_all_rounds": reward_all_rounds,
        }


# maintain one theta for one user
class LinUCB_Ind(LinUCB):
    """LinUCB variant that estimates from the arriving user's statistics only.

    Same constructor arguments as ``LinUCB``; the default exploration weight
    is 6x the pooled default (0.5 * 6).
    """

    def __init__(self, n, d, T, *, lambda_=1, alpha_=0.5 * 6):
        super().__init__(n, d, T, lambda_=lambda_, alpha_=alpha_)

    # recommend the item to the user
    def recommend(self, user_index, items):
        """Pick an item using only ``self.users[user_index]``'s statistics.

        Args:
            user_index: index of the arriving user.
            items: candidate feature vectors, one per row.

        Returns:
            Index (into ``items``) of the recommended item.
        """
        V_t, b_t, _ = self.users[user_index].get_info()
        return self._select_item(items, V_t, b_t)
