import networkx as nx
import numpy as np
from sklearn.cluster import DBSCAN
try:
    import hdbscan
    HDBSCAN_AVAILABLE = True
except ImportError:
    HDBSCAN_AVAILABLE = False
    print("Warning: hdbscan not available, falling back to DBSCAN")
from ...core.utils import edge_probability, is_power2, isInvertible
from ...core.Base import LinUCB_IND

class Cluster:
    """Container for the pooled statistics of a group of users.

    Attributes:
        users: The users belonging to this cluster, stored as given.
        S: Square matrix supplied by the caller; it must be invertible
            because it is inverted at construction time.
        b: Vector supplied by the caller.
        N: Counter supplied by the caller, stored as given.
        Sinv: Inverse of ``S`` (``np.linalg.inv(S)``).
        theta: The product ``Sinv @ b``.

    Raises:
        numpy.linalg.LinAlgError: If ``S`` is singular or not square.
    """

    def __init__(self, users, S, b, N):
        # Derive the inverse and the estimate up front from the raw arguments,
        # then store everything on the instance.
        inverse = np.linalg.inv(S)
        estimate = np.matmul(inverse, b)

        self.users, self.S, self.b, self.N = users, S, b, N
        self.Sinv = inverse
        self.theta = estimate

class OffDBSCAN_improve(LinUCB_IND):
    """Clustered bandit that groups users by density clustering of ``self.theta``.

    The rows of ``self.theta`` (one per user) are clustered with HDBSCAN when
    it is requested and importable, otherwise with scikit-learn's DBSCAN.
    Each resulting cluster pools its members' statistics into a ``Cluster``
    object, and ``test_recommend`` selects items using the statistics of the
    cluster the user is assigned to.

    NOTE(review): ``self.S``, ``self.b``, ``self.N``, ``self.ucb_lambda``,
    ``self.nu``, ``self.d`` and ``_select_item_ucb`` are not defined in this
    class; they are presumably provided by ``LinUCB_IND`` -- confirm there.
    """

    def __init__(self, nu, d, T, ni, eps, min_samples, use_hdbscan=True):
        """Set up clustering parameters and a single all-users cluster.

        Args:
            nu: Number of users (one row of ``self.theta`` per user).
            d: Dimension of each user's parameter vector.
            T: Forwarded unchanged to ``LinUCB_IND.__init__``.
            ni: Forwarded unchanged to ``LinUCB_IND.__init__``.
            eps: ``eps`` passed to DBSCAN; not used on the HDBSCAN path.
            min_samples: Lower bound for the ``min_samples`` clustering
                parameter; the effective value is
                ``max(min_samples, nu // 1000)``.
            use_hdbscan: Prefer HDBSCAN over DBSCAN. Silently falls back to
                DBSCAN when the ``hdbscan`` package could not be imported.
        """
        super(OffDBSCAN_improve, self).__init__(nu, d, T, ni)
        self.eps = eps

        # HDBSCAN is only used when it is both requested and importable.
        self.use_hdbscan = use_hdbscan and HDBSCAN_AVAILABLE
        # Both size parameters scale with the population (nu // 1000).
        self.min_samples = max(min_samples, nu // 1000)
        self.min_cluster_size = max(5, nu // 1000)
        self.theta = np.zeros((nu, d))
        # Until the first update() every user shares one cluster with id 0.
        self.clusters = {
            0: Cluster(
                users=list(range(nu)),
                S=self.ucb_lambda * np.eye(d),
                b=np.zeros(d),
                N=0,
            )
        }
        self.cluster_inds = np.zeros(nu, dtype=int)

    def cluster_theta(self):
        """Cluster the rows of ``self.theta`` and return the groups.

        Returns:
            dict: Maps each non-noise cluster label to the list of user
            indices carrying that label. Users labelled ``-1`` (noise) are
            left out. If no user received a non-noise label, a single
            cluster ``{0: [0, ..., nu - 1]}`` is returned instead.
        """
        if self.use_hdbscan:
            # min_cluster_size: minimum cluster size
            # min_samples: minimum number of neighbors for a core point
            clusterer = hdbscan.HDBSCAN(
                min_cluster_size=self.min_cluster_size,
                min_samples=self.min_samples,
            )
        else:
            clusterer = DBSCAN(eps=self.eps, min_samples=self.min_samples, metric='euclidean')
        labels = clusterer.fit_predict(self.theta)

        # Group user indices by label, skipping noise points (label -1).
        clusters = {}
        for user, label in enumerate(labels):
            if label != -1:
                clusters.setdefault(int(label), []).append(user)

        # If no valid clusters, put all users into a single cluster.
        if not clusters:
            clusters[0] = list(range(self.nu))

        return clusters

    def test_recommend(self, i, items, t):
        """Select an item for user ``i`` using that user's cluster statistics.

        Args:
            i: User index into ``self.cluster_inds``.
            items: Candidate items, forwarded to ``_select_item_ucb``.
            t: Forwarded to ``_select_item_ucb``.

        Returns:
            Whatever ``_select_item_ucb`` returns for the cluster's
            ``S``, ``Sinv``, ``theta`` and ``N``.
        """
        c = int(self.cluster_inds[i])
        cluster = self.clusters[c]
        return self._select_item_ucb(cluster.S, cluster.Sinv, cluster.theta, items, cluster.N, t)

    def _pooled_cluster(self, users, reg):
        """Build a ``Cluster`` whose statistics pool those of ``users``."""
        # Each per-user S is taken to include ``reg`` once; strip it per user
        # and add it back a single time for the pooled matrix.
        pooled_S = sum(self.S[k] - reg for k in users) + reg
        pooled_b = sum(self.b[k] for k in users)
        pooled_N = sum(self.N[k] for k in users)
        # Cluster computes Sinv and theta itself, so no inversion is done here.
        return Cluster(users=users, S=pooled_S, b=pooled_b, N=pooled_N)

    def collaborative_filtering(self, clusters):
        """Rebuild ``self.clusters`` and ``self.cluster_inds`` from ``clusters``.

        Every user ends up assigned to an existing cluster: users that do not
        appear in any entry of ``clusters`` (e.g. noise points dropped by
        ``cluster_theta``) each receive their own singleton cluster built
        from that user's individual statistics. Without this, such users
        would keep a stale index from a previous clustering, which may refer
        to an unrelated cluster or to an id that no longer exists.

        Args:
            clusters: dict mapping integer cluster id to a list of user
                indices.
        """
        reg = self.ucb_lambda * np.eye(self.d)

        groups = dict(clusters)
        covered = set()
        for cluster_users in groups.values():
            covered.update(cluster_users)

        # Give every uncovered user a fresh id that cannot collide with the
        # ids supplied by the caller.
        next_id = int(max(groups, default=-1)) + 1
        for user in range(self.nu):
            if user not in covered:
                groups[next_id] = [user]
                next_id += 1

        rebuilt = {}
        for cluster_id, cluster_users in groups.items():
            rebuilt[cluster_id] = self._pooled_cluster(cluster_users, reg)
            # Update user-to-cluster indices.
            for user in cluster_users:
                self.cluster_inds[user] = cluster_id
        self.clusters = rebuilt

    def update(self):
        """Re-cluster users and refresh the pooled per-cluster statistics."""
        clusters = self.cluster_theta()
        self.collaborative_filtering(clusters)



