import networkx as nx
import numpy as np
from ...core.Base import LinUCB_Neighbor

class Cluster:
    """Container for a group of users and the statistics attached to it.

    The constructor keeps ``users``, ``S``, ``b`` and ``N`` exactly as given,
    and additionally caches:

    * ``Sinv``  -- ``numpy.linalg.inv(S)``
    * ``theta`` -- the matrix product ``Sinv @ b``
    """

    def __init__(self, users, S, b, N):
        # Raw inputs, stored unchanged (assigned left to right).
        self.users, self.S, self.b, self.N = users, S, b, N
        # Derived quantities; raises numpy.linalg.LinAlgError if S is singular.
        inverse = np.linalg.inv(self.S)
        self.Sinv = inverse
        self.theta = inverse @ self.b

class OffCLUB(LinUCB_Neighbor):
    """Neighbor-based bandit that prunes a random graph over its nodes.

    A G(n, p) random graph ``self.G`` is built over ``nu`` nodes. On each
    ``update`` the graph is pruned: an edge (i, j) is removed when both nodes
    have a non-zero count in ``self.N`` and the norm of
    ``self.theta[i] - self.theta[j]`` exceeds ``alpha`` times the sum of the
    two nodes' confidence widths. The remaining neighbor lists are then
    passed to ``store_info_Neighbor`` (provided by the base class).
    """

    def __init__(self, nu, d, T, ni, alpha, delta, C1, edge_probability=1, seed=None):
        """Initialise the base class and build the initial random graph.

        Args:
            nu: number of graph nodes; node ids are ``0 .. nu - 1`` and are
                used to index ``self.N`` and ``self.theta``.
            d, T, ni, alpha, delta, C1: forwarded unchanged to
                ``LinUCB_Neighbor.__init__``.
            edge_probability: probability of each edge in the initial
                G(nu, p) graph. The default of 1 gives a complete graph.
            seed: optional seed / random state forwarded to
                ``networkx.gnp_random_graph`` so the initial graph is
                reproducible. ``None`` (default) keeps the unseeded behavior.
        """
        super().__init__(nu, d, T, ni, alpha, delta, C1)
        self.G = nx.gnp_random_graph(nu, edge_probability, seed=seed)

    def compute_CI_u(self, N1, delta, C1):
        """Return the confidence width for a node whose count is ``N1``.

        width = (sqrt(d * log(1 + N1 / (lam * d)) + 2 * log(2 * nu / delta))
                 + sqrt(lam)) / sqrt(lam + C1 * N1)

        where ``lam`` is ``self.ucb_lambda`` and ``d`` / ``nu`` come from the
        instance. ``lam`` is set by the base class and is presumably the
        regularization parameter (TODO confirm). The expression is built from
        numpy ufuncs, so ``N1`` may also be an array.
        """
        lam = self.ucb_lambda
        numerator = np.sqrt(
            self.d * np.log(1 + N1 / (lam * self.d))
            + 2 * np.log(2 * self.nu / delta)
        ) + np.sqrt(lam)
        denominator = np.sqrt(lam + C1 * N1)
        return numerator / denominator

    def _if_split(self, theta, N1, N2):
        """Return True if ``theta`` (a difference of estimates) is "too large".

        The test is ``||theta|| > alpha * (width(N1) + width(N2))``, using
        ``self.alpha``, ``self.delta`` and ``self.C1`` and the widths from
        ``compute_CI_u``.
        """
        threshold = self.alpha * (
            self.compute_CI_u(N1, self.delta, self.C1)
            + self.compute_CI_u(N2, self.delta, self.C1)
        )
        return np.linalg.norm(theta) > threshold

    def update_graph(self):
        """Remove edges whose endpoints should be split; return neighbor lists.

        Only edges whose endpoints both have a truthy (non-zero) count in
        ``self.N`` are tested. Edges are collected first and removed
        afterwards, because the edge view must not change while it is being
        iterated.

        Returns:
            dict mapping every node of ``self.G`` to a list of its remaining
            neighbors.
        """
        edges_to_remove = [
            (i, j)
            for i, j in self.G.edges
            if self.N[i]
            and self.N[j]
            and self._if_split(self.theta[i] - self.theta[j], self.N[i], self.N[j])
        ]
        self.G.remove_edges_from(edges_to_remove)
        return {node: list(self.G.neighbors(node)) for node in self.G.nodes}

    def update(self):
        """Prune the graph and pass the neighbor lists to ``store_info_Neighbor``."""
        neighbor_lists = self.update_graph()
        self.store_info_Neighbor(neighbor_lists)