import networkx as nx
import numpy as np
from ...core.Base import LinUCB_IND

class Cluster:
    def __init__(self, users, S, b, N):
        self.users = users
        self.S = S
        self.b = b
        self.N = N
        self.Sinv = np.linalg.inv(self.S)
        self.theta = np.matmul(self.Sinv, self.b)

class CLUB(LinUCB_IND):
    def __init__(self, nu, d, T, ni, alpha, edge_probability = 1):
        super(CLUB, self).__init__(nu, d, T, ni)
        self.alpha = alpha
        self.G = nx.gnp_random_graph(nu, edge_probability)
        self.clusters = {0:Cluster(users=range(nu), S=np.eye(d), b=np.zeros(d), N=0)}
        self.cluster_inds = np.zeros(nu)
        self.num_clusters = np.zeros(T)

    def test_recommend(self, i, items, t):
        cluster = self.clusters[self.cluster_inds[i]]
        return self._select_item_ucb(cluster.S, cluster.Sinv, cluster.theta, items, cluster.N, t)

    def store_info(self, i, x, y, t, r, br):
        super(CLUB, self).store_info(i, x, y, t, r, br)
        c = self.cluster_inds[i]
        self.clusters[c].S += np.outer(x, x)
        self.clusters[c].b += y * x
        self.clusters[c].N += 1

    def _if_split(self, theta, N1, N2):
        def _factT(T):
            return np.sqrt((1 + np.log(1 + T)) / (1 + T))
        return np.linalg.norm(theta) >  alpha * (_factT(N1) + _factT(N2))

    def update(self):
        update_clusters = False
        for i in range(self.nu):
            c = self.cluster_inds[i]
            A = [a for a in self.G.neighbors(i)]
            for j in A:
                if self.N[i] and self.N[j] and self._if_split(self.theta[i] - self.theta[j], self.N[i], self.N[j]):
                    self.G.remove_edge(i, j)
                    update_clusters = True
        if update_clusters:
            C = set()
            for i in self.I: # suppose there is only one user per round
                C = nx.node_connected_component(self.G, i)
                if len(C) < len(self.clusters[c].users):
                    remain_users = set(self.clusters[c].users)
                    self.clusters[c] = Cluster(list(C), S=sum([self.S[k]-self.ucb_lambda*np.eye(self.d) for k in C])+self.ucb_lambda*np.eye(self.d), b=sum([self.b[k] for k in C]), N=sum([self.N[k] for k in C]))

                    remain_users = remain_users - set(C)
                    c = max(self.clusters) + 1
                    while len(remain_users) > 0:
                        j = np.random.choice(list(remain_users))
                        C = nx.node_connected_component(self.G, j)
                        self.clusters[c] = Cluster(list(C), S=sum([self.S[k]-self.ucb_lambda*np.eye(self.d) for k in C])+self.ucb_lambda*np.eye(self.d), b=sum([self.b[k] for k in C]), N=sum([self.N[k] for k in C]))
                        for j in C:
                            self.cluster_inds[j] = c
                        c += 1
                        remain_users = remain_users - set(C)
            for c in range(len(self.clusters)):
                self.clusters[c].Sinv, self.clusters[c].theta = self._update_inverse(self.clusters[c].S, self.clusters[c].b, self.clusters[c].Sinv, .0, self.clusters[c].N)
            for i in range(self.nu):
                self.theta[i] = self.clusters[self.cluster_inds[i]].theta
        # print(self.num_clusters)
        print(len(self.clusters))
