import numpy as np
import torch
from loguru import logger
from graph_utils.cluster import Cluster

def squared_wasserstein_1d(X, Y, use_gpu=True):
    """Return the summed per-column mean squared gap between sorted samples.

    Each input is sorted independently along axis 0, both are truncated to
    the shorter length, the squared differences are averaged over axis 0 and
    the result is summed over whatever axes remain.

    Args:
        X: Samples stacked along axis 0 (torch tensor or array-like).
        Y: Samples stacked along axis 0 (torch tensor or array-like).
        use_gpu: Selects the backend. When true the computation runs in torch
            on whatever device ``X`` already lives on (no transfer to CUDA is
            performed here) and a 0-dim tensor is returned. When false it runs
            in NumPy and a NumPy scalar is returned.

    Returns:
        The scalar distance, as a 0-dim torch tensor or a NumPy scalar.
    """
    if use_gpu:
        X_t = torch.as_tensor(X)
        # Keep both operands on one device so the subtraction below is valid.
        Y_t = torch.as_tensor(Y, device=X_t.device)
        X_sorted, _ = torch.sort(X_t, dim=0)
        Y_sorted, _ = torch.sort(Y_t, dim=0)
        n = min(X_sorted.shape[0], Y_sorted.shape[0])
        diff = X_sorted[:n] - Y_sorted[:n]
        return (diff ** 2).mean(dim=0).sum()

    X_arr = X.detach().cpu().numpy() if isinstance(X, torch.Tensor) else np.asarray(X)
    Y_arr = Y.detach().cpu().numpy() if isinstance(Y, torch.Tensor) else np.asarray(Y)
    # np.sort returns only the sorted array (no indices), unlike torch.sort.
    X_sorted = np.sort(X_arr, axis=0)
    Y_sorted = np.sort(Y_arr, axis=0)
    n = min(X_sorted.shape[0], Y_sorted.shape[0])
    diff = X_sorted[:n] - Y_sorted[:n]
    # NumPy reductions take ``axis``; ``dim`` is the torch spelling.
    return (diff ** 2).mean(axis=0).sum()

def compute_graph_laplacian(affinity_matrix, normalized=True, use_gpu=True):
    """Build the graph Laplacian of an affinity matrix.

    With ``normalized=True`` and every row sum strictly positive, the result
    is ``I - D^(-1/2) A D^(-1/2)``. Otherwise (including when any row sum is
    zero or negative) it is the unnormalized ``D - A``, where ``D`` is the
    diagonal matrix of row sums.

    Args:
        affinity_matrix: Square 2-D affinity matrix, as a torch tensor or
            anything NumPy can convert to an array.
        normalized: Request the symmetric normalized Laplacian.
        use_gpu: Compute with torch (on CUDA when available, else on CPU).
            Non-tensor input is converted to float32 on this path. When
            false, compute with NumPy in the input's own dtype.

    Returns:
        The Laplacian as a NumPy array on both paths.

    Raises:
        Exception: Whatever the NumPy path raises is logged with basic
            diagnostics and re-raised unchanged.
    """
    # torch is imported at module level, so no ImportError fallback is needed.
    if use_gpu:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        if isinstance(affinity_matrix, torch.Tensor):
            affinity_tensor = affinity_matrix.detach().to(device)
        else:
            affinity_tensor = torch.tensor(affinity_matrix, dtype=torch.float32, device=device)

        degrees = torch.sum(affinity_tensor, dim=1)

        if normalized and torch.all(degrees > 0):
            # Clamp keeps tiny positive degrees from blowing up the inverse root.
            inv_sqrt = 1.0 / torch.sqrt(torch.clamp(degrees, min=1e-8))
            # Row/column scaling by broadcasting equals D^(-1/2) A D^(-1/2)
            # without materialising two dense diagonal matrices.
            scaled = inv_sqrt[:, None] * affinity_tensor * inv_sqrt[None, :]
            laplacian = torch.eye(len(degrees), device=device) - scaled
        else:
            laplacian = torch.diag(degrees) - affinity_tensor

        return laplacian.cpu().numpy()

    # NumPy path: accept tensors here too so both paths take the same inputs.
    if isinstance(affinity_matrix, torch.Tensor):
        affinity = affinity_matrix.detach().cpu().numpy()
    else:
        affinity = np.asarray(affinity_matrix)

    degrees = None
    try:
        degrees = np.sum(affinity, axis=1)

        if normalized and np.all(degrees > 0):
            inv_sqrt = 1.0 / np.sqrt(np.clip(degrees, 1e-8, None))
            scaled = inv_sqrt[:, None] * affinity * inv_sqrt[None, :]
            return np.eye(len(degrees)) - scaled

        return np.diag(degrees) - affinity

    except Exception as e:
        logger.error(f"Error in compute_graph_laplacian: {e}")
        logger.error(f"affinity_matrix shape: {affinity.shape}")
        # min/max of an empty array would raise and mask the real error.
        if affinity.size:
            logger.error(f"affinity_matrix min/max: {np.min(affinity)}, {np.max(affinity)}")
        if degrees is not None and np.size(degrees):
            logger.error(f"degrees min/max: {np.min(degrees)}, {np.max(degrees)}")
        raise

def compute_eigens(D, k_eigenvals=5, use_gpu=True, eigen_values=True):
    """Return the k smallest eigenvalues or their eigenvectors for a graph.

    ``D`` is handed to ``compute_graph_laplacian`` first, so it is treated as
    an affinity matrix; the symmetric eigendecomposition is taken of the
    resulting Laplacian.

    Args:
        D: Square matrix forwarded to ``compute_graph_laplacian``.
        k_eigenvals: Number of eigenpairs requested; capped at ``D.shape[0]``.
        use_gpu: Decompose with torch (CUDA when available) instead of NumPy.
        eigen_values: Return eigenvalues when true, eigenvectors otherwise.

    Returns:
        On the torch path: a float32 tensor of eigenvalues (left on the
        compute device), or a NumPy array whose columns are eigenvectors.
        On the NumPy path: a NumPy array in either case.
    """
    # Never ask for more eigenpairs than the matrix has rows.
    k = min(k_eigenvals, int(D.shape[0]))
    laplacian = compute_graph_laplacian(D, use_gpu=use_gpu)

    if use_gpu:
        try:
            import torch
            target = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

            if isinstance(laplacian, torch.Tensor):
                lap_t = laplacian.clone().detach().to(target, dtype=torch.float32)
            else:
                lap_t = torch.tensor(laplacian, dtype=torch.float32, device=target)

            values, vectors = torch.linalg.eigh(lap_t)

            # Reorder explicitly by ascending eigenvalue before truncating.
            order = torch.argsort(values)
            if eigen_values:
                return values[order][:k]
            return vectors[:, order][:, :k].cpu().numpy()
        except ImportError:
            logger.warning("PyTorch not available, falling back to CPU implementation")

    values, vectors = np.linalg.eigh(laplacian)
    return values[:k] if eigen_values else vectors[:, :k]

def compute_spectral_distance(eigen1, eigen2, use_gpu=True, eigen_values=True, debug=False):
    """Measure the distance between two spectra.

    With ``eigen_values=True`` the inputs are 1-D eigenvalue sequences and the
    distance is the L2 norm of their difference. If the lengths differ, the
    shorter one is padded with the larger of the two final eigenvalues.

    With ``eigen_values=False`` the inputs are eigenvector matrices and the
    distance is the sorted, truncated mean-squared gap: each input is sorted
    along axis 0, both are cut to the shorter length, squared differences are
    averaged over axis 0 and summed. The torch path delegates this to
    ``squared_wasserstein_1d``; the NumPy path computes it inline.

    Args:
        eigen1: First spectrum (torch tensor or array-like).
        eigen2: Second spectrum (torch tensor or array-like).
        use_gpu: Compute with torch (CUDA when available) instead of NumPy.
            Inputs are converted to the chosen backend either way.
        eigen_values: Whether the inputs are eigenvalues or eigenvectors.
        debug: Accepted for call compatibility; not used.

    Returns:
        A dict with ``'spectral_distance'`` (a Python float) and
        ``'eigen1'`` / ``'eigen2'`` as nested lists. On the torch eigenvalue
        path these lists reflect any padding that was applied; on the NumPy
        path they are the inputs as given.
    """
    # torch is imported at module level, so no ImportError fallback is needed.
    if use_gpu:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Callers may hand over NumPy arrays (e.g. eigenvector matrices), so
        # normalise to tensors on one device before doing tensor maths.
        t1 = torch.as_tensor(eigen1, device=device)
        t2 = torch.as_tensor(eigen2, device=device)

        if eigen_values:
            if len(t1) != len(t2):
                max_len = max(len(t1), len(t2))
                fill = max(t1[-1].item() if len(t1) > 0 else 0,
                           t2[-1].item() if len(t2) > 0 else 0)
                if len(t1) < max_len:
                    pad1 = torch.full((max_len - len(t1),), fill, dtype=t1.dtype, device=device)
                    t1 = torch.cat([t1, pad1])
                if len(t2) < max_len:
                    pad2 = torch.full((max_len - len(t2),), fill, dtype=t2.dtype, device=device)
                    t2 = torch.cat([t2, pad2])
            spectral_dist = torch.norm(t1 - t2).item()
        else:
            # Convert the 0-dim tensor so both branches yield a plain float.
            spectral_dist = float(squared_wasserstein_1d(t1, t2, use_gpu))

        return {
            'spectral_distance': spectral_dist,
            'eigen1': t1.tolist(),
            'eigen2': t2.tolist()
        }

    # NumPy path
    a1 = eigen1.detach().cpu().numpy() if isinstance(eigen1, torch.Tensor) else np.asarray(eigen1)
    a2 = eigen2.detach().cpu().numpy() if isinstance(eigen2, torch.Tensor) else np.asarray(eigen2)

    if not eigen_values:
        # Same sort / truncate / mean-of-squares measure as the torch branch.
        n = min(a1.shape[0], a2.shape[0])
        gap = np.sort(a1, axis=0)[:n] - np.sort(a2, axis=0)[:n]
        spectral_dist = float((gap ** 2).mean(axis=0).sum())
    elif len(a1) == len(a2):
        spectral_dist = float(np.linalg.norm(a1 - a2))
    else:
        max_len = max(len(a1), len(a2))
        fill = max(a1[-1] if len(a1) > 0 else 0,
                   a2[-1] if len(a2) > 0 else 0)
        padded1 = np.pad(a1, (0, max_len - len(a1)),
                         'constant', constant_values=fill)
        padded2 = np.pad(a2, (0, max_len - len(a2)),
                         'constant', constant_values=fill)
        spectral_dist = float(np.linalg.norm(padded1 - padded2))

    return {
        'spectral_distance': spectral_dist,
        'eigen1': a1.tolist(),
        'eigen2': a2.tolist()
    }
    

def compute_cluster_spectral_results(cluster_affinity_dict1, cluster_affinity_dict2,
                                   label1_prefix='emb1_label', label2_prefix='emb2_label',
                                   k_eigenvals=5, use_gpu=False, eigen_values=True, debug=False):
    """Compute the spectral distance for every cross-embedding cluster pair.

    The spectrum of each cluster is computed once, then every cluster of the
    first mapping is compared with every cluster of the second.

    Args:
        cluster_affinity_dict1: Mapping of cluster label to affinity matrix
            (anything with ``.shape``) for the first embedding.
        cluster_affinity_dict2: Same, for the second embedding.
        label1_prefix: Result key under which the first label is stored.
        label2_prefix: Result key under which the second label is stored.
        k_eigenvals: Forwarded to ``compute_eigens``.
        use_gpu: Forwarded to ``compute_eigens`` and
            ``compute_spectral_distance``.
        eigen_values: Forwarded to both helpers.
        debug: Forwarded to ``compute_spectral_distance``.

    Returns:
        A list of dicts, one per (label1, label2) pair in mapping iteration
        order, holding both labels, ``'spectral_distance'``, ``'emb1_size'``
        and ``'emb2_size'`` (row counts of the two affinity matrices).
    """
    def _spectra(affinity_by_label):
        # compute_eigens builds the Laplacian itself, so it must receive the
        # raw affinity matrix; pre-computing a Laplacian here would make it
        # take the Laplacian of a Laplacian.
        return {
            label: compute_eigens(affinity, k_eigenvals, use_gpu=use_gpu, eigen_values=eigen_values)
            for label, affinity in affinity_by_label.items()
        }

    eigens_1 = _spectra(cluster_affinity_dict1)
    eigens_2 = _spectra(cluster_affinity_dict2)

    spectral_results = []
    for label1, spectrum1 in eigens_1.items():
        for label2, spectrum2 in eigens_2.items():
            spectral_result = compute_spectral_distance(
                spectrum1, spectrum2,
                use_gpu=use_gpu, eigen_values=eigen_values, debug=debug)

            if spectral_result is None:
                continue

            spectral_results.append({
                label1_prefix: label1,
                label2_prefix: label2,
                'spectral_distance': spectral_result['spectral_distance'],
                'emb1_size': cluster_affinity_dict1[label1].shape[0],
                'emb2_size': cluster_affinity_dict2[label2].shape[0],
            })

    return spectral_results

def cluster_matching(emb1, emb2, emb1_ori_ind, emb2_ori_ind, n_clu1, n_clu2, args):
    """Cluster two embeddings and rank cluster pairs by spectral distance.

    Builds a ``Cluster`` for each embedding, obtains per-cluster affinity
    matrices, scores every cross pair with
    ``compute_cluster_spectral_results`` and sorts ascending by distance.
    Pairs whose original-index sets overlap are printed with their overlap
    ratios and rank.

    Args:
        emb1: First embedding, passed to ``Cluster``.
        emb2: Second embedding, passed to ``Cluster``.
        emb1_ori_ind: Original indices for ``emb1``, passed to ``Cluster``.
        emb2_ori_ind: Original indices for ``emb2``, passed to ``Cluster``.
        n_clu1: Cluster count for the first embedding.
        n_clu2: Cluster count for the second embedding.
        args: Options object; reads ``cluster_method``, ``graph_method``,
            ``knn_k``, ``distance_metric``, ``k_eigenvals``, ``eigen_values``,
            ``use_gpu`` and ``debug``.

    Returns:
        ``(spectral_results, cluster1, cluster2)`` where ``spectral_results``
        is the distance-sorted list of pair dicts.
    """
    cluster1 = Cluster(emb1, emb1_ori_ind, n_clu1, args.cluster_method, graph_method=args.graph_method, knn_k=args.knn_k)
    cluster2 = Cluster(emb2, emb2_ori_ind, n_clu2, args.cluster_method, graph_method=args.graph_method, knn_k=args.knn_k)

    emb1_clu_affinity = cluster1.get_affinity_matrix(args.distance_metric)
    emb2_clu_affinity = cluster2.get_affinity_matrix(args.distance_metric)

    spectral_results = compute_cluster_spectral_results(emb1_clu_affinity, emb2_clu_affinity, k_eigenvals=args.k_eigenvals, eigen_values=args.eigen_values, use_gpu=args.use_gpu, debug=args.debug)
    spectral_results.sort(key=lambda x: x['spectral_distance'])

    print("Optimal Cluster Mapping (ranked by spectral distance):")
    print("-" * 50)

    # Rank is the 1-based position in the distance-sorted list.
    for rank, result in enumerate(spectral_results, start=1):
        label1 = result['emb1_label']
        label2 = result['emb2_label']
        ind1_original = cluster1.get_ori_ind(label1)
        ind2_original = cluster2.get_ori_ind(label2)

        # One intersection serves both the count and the printed indices.
        intersecting_indices = np.intersect1d(ind1_original, ind2_original)
        intersection_count = len(intersecting_indices)
        if intersection_count == 0:
            continue

        # A non-empty intersection implies both index sets are non-empty.
        intersection_ratio_emb1 = intersection_count / len(ind1_original)
        intersection_ratio_emb2 = intersection_count / len(ind2_original)
        print(f"Emb1 cluster {label1} <-> Emb2 cluster {label2}: Intersection ratio (emb1): {intersection_ratio_emb1:.3f}, Intersection ratio (emb2): {intersection_ratio_emb2:.3f}, distance: {result['spectral_distance']}, rank: {rank}")
        print(f"  Intersecting indices: {intersecting_indices[:10]}{'...' if len(intersecting_indices) > 10 else ''}")

    return spectral_results, cluster1, cluster2