import numpy as np
from .utils import isInvertible
import random
from ..configs.default import offline_learn_method

class Base:
    """Common driver for the linear contextual-bandit recommenders below.

    ``run`` splits the ``T`` rounds into an offline-learning phase
    (``offline_learn_T`` rounds) and a test phase (the remaining ``test_T``
    rounds).  During learning, items are chosen by ``offline_learn_recommend``
    according to ``offline_learn_method`` and feedback is recorded with
    ``store_info``.  During testing, items are chosen by ``test_recommend``
    and recorded with ``store_info_test``.

    Subclasses must provide ``S``, ``b``, ``Sinv`` and ``theta``.  These are
    either one shared model (1-D ``theta`` of length ``d``) or one model per
    user (``theta`` of shape ``(nu, d)``, with ``S``/``Sinv``/``b`` indexed
    by user).  Subclasses must also provide ``store_info_test``.  Per-user
    models used with the ``'LinUCB_ind'`` method additionally need a
    per-user count array ``N``.
    """

    def __init__(self, nu, d, T, ni, offline_ratio = 0.5):
        """
        Args:
            nu: number of users (user indices are ``range(nu)``).
            d: dimension of an item feature vector.
            T: total number of rounds.
            ni: number of candidate items.
            offline_ratio: fraction of ``T`` used for offline learning.
        """
        self.nu = nu
        self.d = d
        self.T = T
        self.offline_learn_T = int(T * offline_ratio)
        # Derive the test length from the learning length so the two phases
        # always partition range(T).  Truncating T * (1 - ratio) separately
        # can drop a round (T=7, ratio=0.5 gave 3 + 3), which broke the
        # slice copies at the end of run().
        self.test_T = T - self.offline_learn_T
        self.ni = ni
        self.items = np.zeros((ni, d))
        self.ucb_lambda = 0.5
        # Per-round rewards over the whole horizon.  The test-phase slice is
        # copied into test_rewards / best_test_rewards at the end of run().
        self.rewards = np.zeros(self.T)
        self.best_rewards = np.zeros(self.T)
        self.test_rewards = np.zeros(self.test_T)
        self.best_test_rewards = np.zeros(self.test_T)
        self.epsilon = 0.1
        self.offline_learn_method = offline_learn_method

    def _has_shared_model(self):
        """Return True when ``theta`` is one shared vector rather than per-user rows.

        Checking the rank (instead of ``len(theta) == d``) keeps the two
        layouts distinguishable when ``nu == d``.
        """
        return np.ndim(self.theta) == 1

    def _beta(self, N, t):
        """Confidence multiplier for the UCB bonus, given count ``N`` and round ``t``."""
        # log(t) is -inf at t == 0 (the first learning round).  That turned
        # the bonus into NaN and made argmax silently return item 0, so the
        # round index is clamped to at least 1.
        return np.sqrt(self.d * np.log(1 + N / self.d) + 4 * np.log(np.maximum(t, 1)) + np.log(2)) + 1

    def _select_item_ucb(self, S, Sinv, theta, items, N, t):
        """Return the index of the item maximising estimate + exploration bonus.

        The bonus for item x is ``beta * x^T Sinv x``.
        NOTE(review): textbook LinUCB uses sqrt(x^T Sinv x); confirm the
        missing square root is intentional.  ``S`` is accepted but unused.
        """
        return np.argmax(np.dot(items, theta) + self._beta(N, t) * (np.matmul(items, Sinv) * items).sum(axis = 1))

    def _select_worst_case(self, S, Sinv, theta, items, N, t):
        """Currently identical to ``_select_item_ucb``.

        NOTE(review): the name suggests a pessimistic (lower-bound) rule;
        confirm whether a different criterion was intended.
        """
        return self._select_item_ucb(S, Sinv, theta, items, N, t)

    def _select_item_empirical(self, theta, items):
        """Return the index of the item with the highest estimated reward (greedy)."""
        return np.argmax(np.dot(items, theta))

    def offline_learn_recommend(self, t = 0, i = 0):
        """Pick an item index for user ``i`` at learning round ``t``.

        The strategy is selected by ``self.offline_learn_method``:
        ``'random'``, ``'LinUCB_ind'``, ``'empirical'`` (greedy) or
        ``'epsilon_greedy'``.  Shared and per-user models are both supported.

        Raises:
            ValueError: if ``offline_learn_method`` is not one of the above.
        """
        method = self.offline_learn_method
        if method == 'random':
            return np.random.randint(0, self.ni)
        if method == 'epsilon_greedy':
            # Explore uniformly with probability epsilon, otherwise act greedily.
            if np.random.rand() < self.epsilon:
                return np.random.randint(0, self.ni)
        elif method not in ('LinUCB_ind', 'empirical'):
            raise ValueError(f"unknown offline_learn_method: {method!r}")

        shared = self._has_shared_model()
        if method == 'LinUCB_ind':
            if shared:
                # A shared model has no per-user count, so the round index
                # doubles as the observation count.
                return self._select_item_ucb(self.S, self.Sinv, self.theta, self.items, t, t)
            return self._select_item_ucb(self.S[i], self.Sinv[i], self.theta[i], self.items, self.N[i], t)

        # 'empirical', or the exploit branch of 'epsilon_greedy'.
        theta = self.theta if shared else self.theta[i]
        return self._select_item_empirical(theta, self.items)

    def test_recommend(self, i, items, t):
        """Hook: return the item index for user ``i`` in the test phase."""
        return

    def store_info(self, i, x, y, t, r, br):
        """Hook: record one learning-phase interaction."""
        return

    def update(self):
        """Hook called once between the learning and test phases."""
        return

    def _update_inverse(self, S, b, Sinv_old, x, t):
        """Return ``(inv(S), inv(S) @ b)``.

        ``Sinv_old``, ``x`` and ``t`` are accepted for interface
        compatibility but unused; the inverse is recomputed from scratch.
        """
        Sinv = np.linalg.inv(S)
        theta = np.matmul(Sinv, b)
        return Sinv, theta

    def run(self, envir):
        """Run the learning phase, refit the models, then run the test phase.

        ``envir`` must provide ``generate_users``, ``get_items``,
        ``feedback`` and ``test_feedback``.  Returns 0.
        """
        # Offline learning phase.
        for t in range(self.offline_learn_T):
            self.I = envir.generate_users()
            for i in self.I:
                self.items = envir.get_items()
                recommended = self.offline_learn_recommend(t=t, i=i)
                x = self.items[recommended]
                y, r, br = envir.feedback(i=i, k=recommended)
                self.store_info(i=i, x=x, y=y, t=t, r=r, br=br)

        # Refit every model from its accumulated statistics.
        if self._has_shared_model():
            self.Sinv, self.theta = self._update_inverse(self.S, self.b, 0, 0, 0)
        else:
            for user_index in range(self.nu):
                self.Sinv[user_index], self.theta[user_index] = self._update_inverse(self.S[user_index], self.b[user_index], 0, 0, 0)

        self.update()
        # NOTE: test_users is never read, but the shuffle advances NumPy's
        # global RNG, so it is kept to preserve seeded results.
        test_users = list(range(self.nu))
        np.random.shuffle(test_users)

        # Test phase.  NOTE(review): test_recommend receives offline_learn_T
        # rather than the current round t, so UCB confidence terms stay at
        # their end-of-learning value; confirm this is intended.
        for t in range(self.offline_learn_T, self.T):
            self.I = envir.generate_users()
            for test_user in self.I:
                self.items = envir.get_items()
                recommended = self.test_recommend(test_user, self.items, self.offline_learn_T)
                y, r, br = envir.test_feedback(i=test_user, k=recommended)
                self.store_info_test(t=t, r=r, br=br)
        self.test_rewards[:] = self.rewards[self.offline_learn_T:]
        self.best_test_rewards[:] = self.best_rewards[self.offline_learn_T:]
        return 0

class Baseline(Base):
    """LinUCB with one shared model that learns for all ``T`` rounds.

    Constructed with ``offline_ratio=1``, so the whole horizon is the
    learning phase and there is no separate test phase.  ``run`` reports the
    learning rewards as the "test" rewards.
    """

    def __init__(self, nu, d, T, ni):
        super(Baseline, self).__init__(nu, d, T, ni, offline_ratio=1)
        # Shared ridge-regression state with an identity prior.
        self.S = np.eye(d)
        self.b = np.zeros(d)
        self.Sinv = np.eye(d)
        self.theta = np.zeros(d)
        # Base sized these for the (empty) test phase.  They are resized to
        # span all T rounds because run() reports the whole horizon.
        self.test_rewards = np.zeros(self.T)
        self.best_test_rewards = np.zeros(self.T)

    def test_recommend(self, i, items, t):
        """Pick the UCB-maximising item; ``i`` is ignored (shared model)."""
        return self._select_item_ucb(self.S, self.Sinv, self.theta, items, t, t)

    def store_info(self, i, x, y, t, r, br):
        """Record one interaction and refit the shared model."""
        self.rewards[t] += r
        self.best_rewards[t] += br
        self.S += np.outer(x, x)
        self.b += y * x
        self.Sinv, self.theta = self._update_inverse(self.S, self.b, self.Sinv, x, t)

    def store_info_test(self, t, r, br):
        """Record a reward for round ``t`` without updating the model."""
        # Accumulate, as store_info does, so a round with several users
        # reports the total rather than only the last user's reward.
        self.rewards[t] += r
        self.best_rewards[t] += br

    def run(self, envir):
        """Learn online for all ``T`` rounds and expose the rewards as test rewards.

        Returns 0.
        """
        for t in range(self.offline_learn_T):
            self.I = envir.generate_users()
            for i in self.I:
                self.items = envir.get_items()
                recommended = self.offline_learn_recommend(t=t, i=i)
                x = self.items[recommended]
                y, r, br = envir.feedback(i=i, k=recommended)
                self.store_info(i=i, x=x, y=y, t=t, r=r, br=br)
        self.test_rewards[:] = self.rewards[:]
        self.best_test_rewards[:] = self.best_rewards[:]
        return 0


class LinUCB(Base):
    """LinUCB with a single model shared by all users.

    Uses ``Base``'s default learning/test split.  The ridge state starts
    from an identity prior (``S = I``), independent of ``ucb_lambda``.
    """

    def __init__(self, nu, d, T, ni):
        super(LinUCB, self).__init__(nu, d, T, ni)
        self.S = np.eye(d)
        self.b = np.zeros(d)
        self.Sinv = np.eye(d)
        self.theta = np.zeros(d)

    def test_recommend(self, i, items, t):
        """Pick the UCB-maximising item; ``i`` is ignored (shared model)."""
        return self._select_item_ucb(self.S, self.Sinv, self.theta, items, t, t)

    def store_info(self, i, x, y, t, r, br):
        """Record a learning-phase interaction and refit the shared model."""
        self.rewards[t] += r
        self.best_rewards[t] += br
        self.S += np.outer(x, x)
        self.b += y * x
        self.Sinv, self.theta = self._update_inverse(self.S, self.b, self.Sinv, x, t)

    def store_info_test(self, t, r, br):
        """Record a test-phase reward for round ``t``; the model is not updated."""
        # Accumulate, as store_info does, so a round with several users
        # reports the total rather than only the last user's reward.
        self.rewards[t] += r
        self.best_rewards[t] += br

class LinUCB_IND(Base):
    """Independent LinUCB: one ridge-regression model per user, no sharing."""

    def __init__(self, nu, d, T, ni):
        super(LinUCB_IND, self).__init__(nu, d, T, ni)
        # Per-user prior S_i = lambda * I, stacked to shape (nu, d, d).
        self.S = np.repeat(self.ucb_lambda * np.eye(d)[np.newaxis, :, :], nu, axis=0)
        self.b = np.zeros((nu, d))
        # Sinv must start as the inverse of S, i.e. I / lambda.  It was
        # previously lambda * I, which is only correct for lambda == 1.
        self.Sinv = np.repeat((np.eye(d) / self.ucb_lambda)[np.newaxis, :, :], nu, axis=0)
        self.theta = np.zeros((nu, d))
        # Number of learning-phase observations recorded for each user.
        self.N = np.zeros(nu)

    def test_recommend(self, i, items, t):
        """Pick the UCB-maximising item for user ``i`` using that user's model."""
        return self._select_item_ucb(self.S[i], self.Sinv[i], self.theta[i], items, self.N[i], t)

    def store_info(self, i, x, y, t, r, br):
        """Record a learning-phase interaction for user ``i`` and refit that model."""
        self.rewards[t] += r
        self.best_rewards[t] += br
        self.S[i] += np.outer(x, x)
        self.b[i] += y * x
        self.N[i] += 1
        self.Sinv[i], self.theta[i] = self._update_inverse(self.S[i], self.b[i], self.Sinv[i], x, t)

    def store_info_test(self, t, r, br):
        """Record a test-phase reward for round ``t``; models are not updated."""
        # Accumulate, as store_info does, so a round with several users
        # reports the total rather than only the last user's reward.
        self.rewards[t] += r
        self.best_rewards[t] += br

class LinUCB_Neighbor(LinUCB_IND):
    """Per-user LinUCB whose test-time model pools statistics from neighbouring users.

    ``store_info_Neighbor`` sums each listed user's ridge statistics with
    those of its neighbours.  Test-time recommendations then use a
    pessimistic lower-confidence-bound rule on the pooled model.
    """

    def __init__(self, nu, d, T, ni, alpha, delta, C1, lcb_scale=0.1):
        """
        Args:
            nu, d, T, ni: as for ``Base``.
            alpha, C1: stored as attributes; not read by this class.
            delta: confidence parameter used by ``lcb_beta``.
            lcb_scale: multiplier on the LCB width in ``_select_item_lcb``
                (previously hard-coded to 0.1).
        """
        super(LinUCB_Neighbor, self).__init__(nu, d, T, ni)
        self.S_Neighbor = np.repeat(self.ucb_lambda * np.eye(d)[np.newaxis, :, :], nu, axis=0)  # (nu, d, d)
        self.b_Neighbor = np.zeros((nu, d))  # (nu, d)
        # Inverse of the lambda * I prior above (was the identity, which is
        # only correct for lambda == 1).
        self.Sinv_Neighbor = np.repeat((np.eye(d) / self.ucb_lambda)[np.newaxis, :, :], nu, axis=0)  # (nu, d, d)
        self.theta_Neighbor = np.zeros((nu, d))  # (nu, d)
        self.N_Neighbor = np.zeros(nu)
        self.alpha = alpha
        self.delta = delta
        self.C1 = C1
        self.lcb_scale = lcb_scale

    def store_info_Neighbor(self, neighbor_lists):
        """Pool each listed user's statistics with its neighbours' and refit.

        Args:
            neighbor_lists: iterating it yields user indices ``i``, and
                ``neighbor_lists[i]`` yields that user's neighbour indices.
                Presumably a dict ``user -> iterable of users``; confirm
                against the caller.

        Pooled statistics are rebuilt from the prior on every call, so
        calling this more than once does not double-count data.  The pooled
        estimate also overwrites the user's own ``theta[i]``.
        """
        prior = self.ucb_lambda * np.eye(self.d)
        for i in neighbor_lists:
            S_pool = prior.copy()
            b_pool = np.zeros(self.d)
            n_pool = 0.0
            for j in (i, *neighbor_lists[i]):
                # Each S[j] already contains the lambda * I prior; strip it
                # so the pooled matrix carries the prior exactly once.
                S_pool += self.S[j] - prior
                b_pool += self.b[j]
                n_pool += self.N[j]
            self.S_Neighbor[i] = S_pool
            self.b_Neighbor[i] = b_pool
            self.N_Neighbor[i] = n_pool

            self.Sinv_Neighbor[i], self.theta_Neighbor[i] = self._update_inverse(
                self.S_Neighbor[i], self.b_Neighbor[i], self.Sinv_Neighbor[i], None, self.N_Neighbor[i]
            )
            self.theta[i] = self.theta_Neighbor[i]

    def lcb_beta(self, N):
        """Confidence radius for a model built from ``N`` pooled observations."""
        return np.sqrt(self.d * np.log(1 + N / (self.d*self.ucb_lambda)) + 2 * np.log(2*self.nu/self.delta)) + np.sqrt(self.ucb_lambda)

    def _select_item_lcb(self, S, Sinv, theta, items, N):
        """Return the index of the item maximising estimate minus a scaled width.

        The width for item x is ``lcb_scale * lcb_beta(N) * x^T Sinv x``.
        NOTE(review): no square root is taken, mirroring the UCB rule in
        Base; confirm this is intended.  ``S`` is accepted but unused.
        """
        return np.argmax(np.dot(items, theta) - self.lcb_scale * self.lcb_beta(N) * (np.matmul(items, Sinv) * items).sum(axis = 1))

    def test_recommend(self, i, items, t):
        """Recommend for user ``i`` with the LCB rule on the pooled model; ``t`` is unused."""
        return self._select_item_lcb(self.S_Neighbor[i], self.Sinv_Neighbor[i], self.theta_Neighbor[i], items, self.N_Neighbor[i])


