import numpy as np
from scipy.sparse.linalg import eigsh

class UncertaintyQuantifier:
    """Score per-node uncertainty on a weighted graph of diffusion models.

    The graph is only accessed through a small networkx-style interface:
    ``graph.nodes`` (iterable, sized, and indexable to a per-node attribute
    dict holding ``'output'``), ``graph.neighbors(n)``, and ``graph[u][v]``
    (an edge attribute dict holding ``'weight'``). A ``networkx.Graph``
    satisfies this, but networkx itself is not required by this class.
    """

    def __init__(self, graph):
        """
        Initialize the uncertainty quantifier with a graph.

        Args:
            graph (networkx.Graph): The graph of diffusion models. Nodes carry
                an ``'output'`` attribute and edges a ``'weight'`` attribute.
        """
        self.graph = graph

    def intrinsic_uncertainty(self, node, *, alpha=1.0, beta=0.5, gamma=0.2):
        """
        Compute intrinsic uncertainty for a node.

        The score is ``alpha * variance + beta * entropy + gamma * cross_moment``
        computed over the node's ``'output'`` values.

        Args:
            node: Label of the node.
            alpha (float): Weight of the variance term.
            beta (float): Weight of the entropy term.
            gamma (float): Weight of the cross-moment term.

        Returns:
            float: Intrinsic uncertainty score.
        """
        # asarray so plain Python lists work: ``list + 1e-10`` would raise TypeError.
        output = np.asarray(self.graph.nodes[node]['output'], dtype=float)
        variance = np.var(output)
        # NOTE(review): presumably ``output`` is a probability vector; negative
        # entries would make the log NaN -- confirm with producers of 'output'.
        entropy = -np.sum(output * np.log(output + 1e-10))  # epsilon avoids log(0)
        centered = output - np.mean(output)
        # mean(outer(c, c)) == mean(c) ** 2, computed without the O(n^2)
        # temporary. NOTE(review): mean(c) is 0 up to rounding, so this term is
        # effectively always ~0 -- confirm which statistic was intended.
        cross_moment = np.mean(centered) ** 2
        return alpha * variance + beta * entropy + gamma * cross_moment

    def propagated_uncertainty(self, node, max_iter=100, tol=1e-6, *, lam=0.3):
        """
        Compute propagated uncertainty for a node using iterative updates.

        Every node's score starts at its intrinsic uncertainty and is then
        repeatedly blended with the weighted mean of its neighbours' scores:
        ``U <- (1 - lam) * U + lam * weighted_mean(U[neighbours])``.
        Updates within one sweep read only the previous sweep's values.

        Args:
            node: Label of the node whose score is returned.
            max_iter (int): Maximum number of sweeps.
            tol (float): Absolute convergence tolerance (``np.allclose`` atol).
            lam (float): Blend factor toward the neighbour mean.

        Returns:
            float: Propagated uncertainty score (latest iterate).
        """
        U_total = {n: self.intrinsic_uncertainty(n) for n in self.graph.nodes}
        for _ in range(max_iter):
            U_new = U_total.copy()
            for n in self.graph.nodes:
                neighbors = list(self.graph.neighbors(n))
                if not neighbors:
                    continue
                weights = [self.graph[n][m]['weight'] for m in neighbors]
                total_weight = sum(weights)
                if total_weight == 0:
                    # No usable neighbour mass; keep the node's current score
                    # instead of dividing by zero.
                    continue
                U_prop = sum(w * U_total[m] for w, m in zip(weights, neighbors)) / total_weight
                U_new[n] = (1.0 - lam) * U_total[n] + lam * U_prop
            converged = np.allclose(list(U_total.values()), list(U_new.values()), atol=tol)
            U_total = U_new
            if converged:
                break
        return U_total[node]

    def spectral_uncertainty_propagation(self):
        """
        Compute uncertainty using spectral graph theory (SGUP).

        Builds the weighted graph Laplacian ``L = D - W`` and returns the
        eigenvector of its smallest eigenvalue. For non-negative weights that
        eigenvalue is 0. If the graph has several connected components, the
        eigenspace is degenerate and any vector in it may be returned. The
        sign is fixed so the entries sum to a non-negative value.

        Assumes an undirected graph (symmetric ``W``); ``np.linalg.eigh``
        reads only one triangle. Missing edge weights default to 1.

        Returns:
            dict: Mapping from node label to its uncertainty score; empty for
            an empty graph.
        """
        nodes = list(self.graph.nodes)
        n = len(nodes)
        if n == 0:
            return {}
        index = {v: i for i, v in enumerate(nodes)}
        W = np.zeros((n, n))
        for u in nodes:
            for v, attrs in self.graph[u].items():
                W[index[u], index[v]] = attrs.get('weight', 1.0)
        L = np.diag(W.sum(axis=1)) - W
        # Dense symmetric solver: exact and deterministic, unlike ARPACK 'SM'
        # on a singular matrix. Eigenvalues come back in ascending order.
        _, eigenvectors = np.linalg.eigh(L)
        U_smooth = eigenvectors[:, 0]
        if U_smooth.sum() < 0:
            # Eigenvector sign is arbitrary; normalise it for reproducibility.
            U_smooth = -U_smooth
        return {v: U_smooth[index[v]] for v in nodes}

    def path_specific_uncertainty(self, node, decay_rate=0.1, *, max_length=3, gamma=0.5):
        """
        Compute path-specific uncertainty influence (PUI) for a node.

        Every simple path from ``node`` with at most ``max_length`` edges
        contributes ``prod(edge weights) * exp(-decay_rate * length) *
        intrinsic(end)``, where ``end`` is the path's last vertex. The result
        is ``intrinsic(node) + gamma * sum(contributions)``.

        Args:
            node: Label of the node.
            decay_rate (float): Decay constant for path length attenuation.
            max_length (int): Longest path, in edges, that is considered.
            gamma (float): Weight of the accumulated path influence.

        Returns:
            float: Path-specific uncertainty score.
        """
        intrinsic = {}

        def _intrinsic(v):
            # Memoise: the same endpoint is reached by many paths.
            if v not in intrinsic:
                intrinsic[v] = self.intrinsic_uncertainty(v)
            return intrinsic[v]

        U_path = 0.0
        # DFS over simple paths: (vertex, weight product, visited, edge count).
        stack = [(node, 1.0, frozenset([node]), 0)]
        while stack:
            current, weight, visited, length = stack.pop()
            if length >= max_length:
                continue
            for nxt in self.graph.neighbors(current):
                if nxt in visited:
                    continue  # keep paths simple (also skips self-loops)
                path_weight = weight * self.graph[current][nxt]['weight']
                U_path += path_weight * np.exp(-decay_rate * (length + 1)) * _intrinsic(nxt)
                stack.append((nxt, path_weight, visited | {nxt}, length + 1))
        return self.intrinsic_uncertainty(node) + gamma * U_path