import networkx as nx
import numpy as np
# Compatibility shim: pyclustering references ``numpy.warnings``, an alias that
# recent NumPy releases (the original note says numpy>=2.0) no longer expose.
# Re-attach the stdlib ``warnings`` module under that name. This must run
# before the pyclustering imports further down in this import block.
try:
    import warnings as _warnings  # noqa
    # Only install the alias when NumPy does not already provide one.
    if not hasattr(np, "warnings"):
        np.warnings = _warnings  # type: ignore[attr-defined]
except Exception:
    # Deliberate best-effort: if the alias cannot be installed, carry on and
    # let the rest of the module import as normal.
    pass

from ...core.utils import edge_probability, is_power2, isInvertible
from ...core.Base import LinUCB_IND
from pyclustering.cluster.xmeans import xmeans
from pyclustering.cluster.center_initializer import kmeans_plusplus_initializer

class Cluster:
    """Pooled regression statistics for one group of users.

    ``S``, ``b`` and ``N`` are stored exactly as passed (no copies are made).
    Two derived quantities are computed once, at construction:

    * ``Sinv``  -- ``np.linalg.inv(S)``; ``S`` must be square and invertible.
    * ``theta`` -- ``Sinv @ b``, the estimate used when recommending.
    """

    def __init__(self, users, S, b, N):
        # Raw inputs, kept by reference.
        self.users, self.S, self.b, self.N = users, S, b, N
        # Invert once and reuse the result for the parameter estimate.
        inverse = np.linalg.inv(S)
        self.Sinv = inverse
        self.theta = inverse @ b

class OffXMeans_improve(LinUCB_IND):
    """LinUCB variant that groups users with X-means and pools their statistics.

    Each user has a row in ``self.theta``. Presumably the ``LinUCB_IND`` base
    class or the training loop fills these rows in; that is not visible here.
    :meth:`update` clusters those rows with X-means and pools the per-user
    ``S``/``b``/``N`` statistics inside each cluster. :meth:`test_recommend`
    then recommends from the pooled statistics of the user's cluster.
    """

    def __init__(self, nu, d, T, ni):
        super().__init__(nu, d, T, ni)
        # One parameter row per user, initialised to zero.
        self.theta = np.zeros((nu, d))
        # Until the first update() every user shares a single cluster that
        # holds only the ridge prior (S = lambda * I, b = 0, N = 0).
        self.clusters = {
            0: Cluster(
                users=list(range(nu)),
                S=self.ucb_lambda * np.eye(d),
                b=np.zeros(d),
                N=0,
            )
        }
        # cluster_inds[i] is the key in self.clusters of user i's cluster.
        self.cluster_inds = np.zeros(nu, dtype=int)

    def cluster_theta(self, initial_k=2):
        """Cluster the per-user rows of ``self.theta`` with X-means.

        Args:
            initial_k: number of k-means++ seed centers handed to X-means
                (default 2, matching the previous hard-coded value). It is
                clamped to the number of users, because the initializer
                rejects more centers than data points.

        Returns:
            dict mapping a cluster id (0, 1, ...) to the list of user indices
            that X-means assigned to that cluster.
        """
        n_users = len(self.theta)
        initial_centers = kmeans_plusplus_initializer(
            self.theta, min(initial_k, n_users)
        ).initialize()
        # ccore=True requests pyclustering's C++ implementation.
        xmeans_instance = xmeans(self.theta, initial_centers, ccore=True)
        xmeans_instance.process()
        return dict(enumerate(xmeans_instance.get_clusters()))

    def test_recommend(self, i, items, t):
        """Recommend for user ``i`` using the pooled statistics of its cluster.

        Args:
            i: user index.
            items: candidate items, forwarded unchanged to ``_select_item_ucb``.
            t: current round, forwarded unchanged to ``_select_item_ucb``.

        Returns:
            Whatever the base-class ``_select_item_ucb`` returns.
        """
        cluster = self.clusters[int(self.cluster_inds[i])]
        return self._select_item_ucb(
            cluster.S, cluster.Sinv, cluster.theta, items, cluster.N, t
        )

    def collaborative_filtering(self, clusters):
        """Pool per-user statistics within each cluster and rebuild cluster state.

        Args:
            clusters: dict mapping a cluster id to a re-iterable collection of
                user indices, as returned by :meth:`cluster_theta`.

        Side effects:
            Replaces ``self.clusters`` with one :class:`Cluster` per entry and
            sets ``self.cluster_inds[user]`` for every listed user.
        """
        ridge = self.ucb_lambda * np.eye(self.d)
        zero_S = np.zeros((self.d, self.d))
        zero_b = np.zeros(self.d)
        new_clusters = {}
        for cluster_id, cluster_users in clusters.items():
            # Each user's S[k] is assumed to carry its own lambda * I prior.
            # Strip it per user and add it back once, so the pooled matrix is
            # regularised exactly once. The explicit zero start values keep
            # S and b as arrays, rather than the int 0, if a cluster is empty.
            cluster_S = sum((self.S[k] - ridge for k in cluster_users), zero_S) + ridge
            cluster_b = sum((self.b[k] for k in cluster_users), zero_b)
            cluster_N = sum(self.N[k] for k in cluster_users)

            # Cluster computes Sinv and theta itself; inverting here as well
            # would be redundant work.
            new_clusters[cluster_id] = Cluster(
                users=cluster_users,
                S=cluster_S,
                b=cluster_b,
                N=cluster_N,
            )

            for user in cluster_users:
                self.cluster_inds[user] = cluster_id

        # Swap in the rebuilt mapping only after every cluster was built.
        self.clusters = new_clusters

    def update(self):
        """Re-cluster users by their current ``theta`` and rebuild pooled statistics."""
        self.collaborative_filtering(self.cluster_theta())


