import networkx as nx
import numpy as np
from ...core.Base import LinUCB_Neighbor

class Cluster:
    """Per-cluster statistics together with the estimate derived from them.

    The constructor stores its arguments unchanged and additionally caches
    the inverse of ``S`` and the product of that inverse with ``b``.

    Attributes:
        users: List/array of users belonging to the cluster.
        S: Square, invertible matrix (presumably the regularised Gram
            matrix of the cluster -- confirm against callers).
        b: Vector multiplied by ``inv(S)`` (presumably the reward-weighted
            feature sum -- confirm against callers).
        N: Stored as given; not used inside this class (presumably an
            interaction count -- confirm against callers).
        Sinv: ``np.linalg.inv(S)``.
        theta: ``np.matmul(Sinv, b)``.
    """

    def __init__(self, users, S, b, N):
        # Invert once up front so the derived estimate below can reuse it.
        inverse = np.linalg.inv(S)

        self.users = users
        self.S = S
        self.b = b
        self.N = N
        self.Sinv = inverse
        self.theta = np.matmul(inverse, b)


class OffNCLUB(LinUCB_Neighbor):
    """Graph-based user clustering built on top of ``LinUCB_Neighbor``.

    The user graph starts with one node per user and no edges. ``update``
    links two users whenever the distance between their parameter vectors,
    inflated by ``alpha`` times the sum of their confidence widths, stays
    strictly below ``gamma``.

    NOTE(review): ``self.d``, ``self.ucb_lambda``, ``self.alpha``,
    ``self.delta``, ``self.C1``, ``self.N``, ``self.theta`` and
    ``store_info_Neighbor`` are not defined in this class; they are
    presumably provided by ``LinUCB_Neighbor`` -- confirm against the base.
    """

    def __init__(self, nu, d, T, ni, alpha, delta, C1, gamma):
        """Initialise the base learner and an edgeless graph over ``nu`` users.

        Args:
            nu: Number of users; graph nodes are ``0 .. nu - 1``.
            d, T, ni, alpha, delta, C1: Forwarded unchanged to
                ``LinUCB_Neighbor.__init__``.
            gamma: Threshold used when deciding whether to connect two users.
        """
        super().__init__(nu, d, T, ni, alpha, delta, C1)
        self.nu = nu
        self.gamma = gamma

        # Null graph: every user is a node, edges are added later.
        user_graph = nx.Graph()
        user_graph.add_nodes_from(range(nu))
        self.G = user_graph

    def compute_CI_u(self, N1, delta, C1):
        """Return the confidence width for a user with ``N1`` observations.

        Computes::

            (sqrt(d * log(1 + N1 / (lambda * d)) + 2 * log(2 * nu / delta))
             + sqrt(lambda)) / sqrt(lambda + C1 * N1)

        where ``lambda`` is ``self.ucb_lambda``.

        Args:
            N1: Observation count plugged into the formula.
            delta: Value used inside the ``log(2 * nu / delta)`` term.
            C1: Multiplier of ``N1`` in the denominator.
        """
        lam = self.ucb_lambda
        dimension_term = self.d * np.log(1 + N1 / (lam * self.d))
        confidence_term = 2 * np.log(2 * self.nu / delta)
        width = np.sqrt(dimension_term + confidence_term) + np.sqrt(lam)
        scale = np.sqrt(lam + C1 * N1)
        return width / scale

    def _if_connect(self, theta, N1, N2):
        """Return whether ``||theta||`` is below the connection threshold.

        The threshold is ``gamma - alpha * (CI(N1) + CI(N2))`` with ``CI``
        given by ``compute_CI_u``. ``theta`` is presumably the difference of
        two users' parameter vectors -- confirm against callers.
        """
        distance = np.linalg.norm(theta)
        width_first = self.compute_CI_u(N1, self.delta, self.C1)
        width_second = self.compute_CI_u(N2, self.delta, self.C1)
        threshold = self.gamma - self.alpha * (width_first + width_second)
        return distance < threshold

    def precompute_CI(self):
        """Store every user's confidence width in ``self.CI``.

        ``self.CI[u]`` is ``compute_CI_u(self.N[u], self.delta, self.C1)``
        for ``u`` in ``range(self.nu)``.
        """
        widths = []
        for user in range(self.nu):
            widths.append(self.compute_CI_u(self.N[user], self.delta, self.C1))
        self.CI = np.array(widths)

    def connect_edges(self):
        """Add an edge for every user pair passing the test; return neighbours.

        A pair ``(i, j)`` with ``i < j`` is connected when
        ``||theta_i - theta_j|| + alpha * (CI_i + CI_j) < gamma``. Edges are
        only ever added to ``self.G`` here, never removed.

        Requires ``self.CI`` to be populated (see ``precompute_CI``).
        ``self.theta`` is indexed as a 2-D array; presumably one row per user
        -- confirm against the base class.

        Returns:
            Dict mapping each graph node to the list of its neighbours.
        """
        theta = self.theta
        # Pairwise differences via broadcasting, reduced over the last axis.
        pairwise_gap = np.linalg.norm(theta[:, None, :] - theta[None, :, :], axis=2)
        pairwise_width = self.CI[:, None] + self.CI[None, :]
        score = pairwise_gap + self.alpha * pairwise_width

        # Strict upper triangle: each unordered pair once, no self-pairs.
        strictly_upper = np.triu(np.ones_like(score), k=1) == 1
        edge_pairs = np.argwhere((score < self.gamma) & strictly_upper)
        self.G.add_edges_from(edge_pairs)

        neighbor_lists = {}
        for node in self.G.nodes:
            neighbor_lists[node] = list(self.G.neighbors(node))
        return neighbor_lists

    def update(self):
        """Recompute widths, connect qualifying users, and store neighbours.

        The neighbour lists are handed to ``store_info_Neighbor``.
        """
        self.precompute_CI()
        self.store_info_Neighbor(self.connect_edges())
