import torch
from typing import Optional, Tuple, List
import time

@torch.no_grad()
def group_experts_by_clustering(
    model: str,
    num_groups: int,
    cluster: str,
    linkage: str,
    hierarchical_stopping_metric: str,
    num_experts: int,
    experts: torch.Tensor,
    experts2: Optional[torch.Tensor] = None,
    experts3: Optional[torch.Tensor] = None,
    init_center: Optional[torch.Tensor] = None,
    w1: float = 1.0,
    w2: float = 1.0,
    w3: float = 1.0,
):
    """Group experts and pick one representative ("dominant") expert per group.

    Args:
        model: Model identifier. Not used here (kept for interface compatibility).
        num_groups: Target number of groups. For ``'hierarchical-dynamic'`` it is the
            upper bound on the number of groups considered.
        cluster: Algorithm name: ``'hierarchical'``, ``'hierarchical-dynamic'``,
            ``'kmeans'`` or ``'Graph_Partitioning'``.
        linkage: Linkage method forwarded to the hierarchical algorithms.
        hierarchical_stopping_metric: Model-selection metric for
            ``'hierarchical-dynamic'``.
        num_experts: Not used here (kept for interface compatibility).
        experts: Primary expert features, one row per expert (assumed 2-D — confirm
            with callers). Cast to float32 before clustering.
        experts2, experts3: Optional extra feature sets, cast to float32 and forwarded
            to the kmeans / graph-partitioning algorithms.
        init_center: Optional initial center indices, forwarded to the kmeans /
            graph-partitioning algorithms.
        w1, w2, w3: Feature-set weights, forwarded to the kmeans / graph-partitioning
            algorithms.

    Returns:
        ``(dom_experts, labels)``: representative expert indices and per-expert labels.

    Raises:
        ValueError: If ``cluster`` names an unsupported algorithm.
    """
    experts = experts.to(torch.float)
    experts2 = experts2.to(torch.float) if experts2 is not None else None
    experts3 = experts3.to(torch.float) if experts3 is not None else None

    if cluster == "hierarchical":
        labels, dom_experts = hierarchical_clustering(experts, num_groups, linkage)
    elif cluster == 'hierarchical-dynamic':
        # num_groups is the maximum; the search is allowed to go down to one group.
        labels, dom_experts = hierarchical_clustering_dynamic(
            experts, linkage, hierarchical_stopping_metric, num_groups, 1
        )
    elif cluster in ('kmeans', 'Graph_Partitioning'):
        algorithm = (
            kmeans_plus_plus_clustering if cluster == 'kmeans' else Graph_Partitioning_clustering
        )
        dom_experts, labels = algorithm(
            experts=experts,
            num_groups=num_groups,
            experts2=experts2,
            experts3=experts3,
            init_center=init_center,
            w1=w1,
            w2=w2,
            w3=w3,
        )
    else:
        raise ValueError(
            "Please set cluster to 'hierarchical', 'hierarchical-dynamic', 'kmeans' or "
            f"'Graph_Partitioning', but the input is `{cluster}`"
        )
    print(f"group: {labels}, dom: {dom_experts}")
    return dom_experts, labels

@torch.no_grad()
def compute_silhouette_score(tensor_list, cluster_labels):
    """Return the mean silhouette coefficient of a clustering.

    Uses the standard definition ``s(i) = (b(i) - a(i)) / max(a(i), b(i))``:

    * ``a(i)`` is the mean Euclidean distance from sample ``i`` to the other members
      of its own cluster;
    * ``b(i)`` is the smallest mean distance from ``i`` to the members of any other
      cluster.

    Samples in a singleton cluster get ``s(i) = 0`` (Rousseeuw's convention, also used
    by scikit-learn). A zero denominator (coincident points) also gives 0 instead of
    NaN.

    Args:
        tensor_list: Samples stacked along dim 0. All remaining dims are flattened, so
            the distance equals ``torch.norm(tensor_list[i] - tensor_list[j])``.
        cluster_labels: Tensor with one label per sample.

    Returns:
        0-dim float32 CPU tensor with the mean score. It is NaN when there are fewer
        than two clusters, because the silhouette is undefined there.
    """
    labels = cluster_labels.reshape(-1).cpu()
    unique_labels, cluster_of = torch.unique(labels, return_inverse=True)
    num_clusters = unique_labels.numel()
    if num_clusters < 2:
        return torch.tensor(float('nan'))

    num_samples = tensor_list.shape[0]
    flat = tensor_list.reshape(num_samples, -1)
    # Exact (non-matmul) Euclidean distances, matching torch.norm(x_i - x_j).
    distances = torch.cdist(flat, flat, p=2, compute_mode='donot_use_mm_for_euclid_dist')
    distances = distances.float().cpu()

    # sums[i, c] = total distance from sample i to every member of cluster c.
    sums = torch.zeros(num_samples, num_clusters)
    sums.index_add_(1, cluster_of, distances)
    sizes = torch.bincount(cluster_of, minlength=num_clusters).float()

    rows = torch.arange(num_samples)
    own_size = sizes[cluster_of]
    # The own-cluster sum includes d(i, i) = 0, so average over (size - 1) others.
    a = sums[rows, cluster_of] / (own_size - 1).clamp(min=1)

    mean_to_cluster = sums / sizes
    mean_to_cluster[rows, cluster_of] = float('inf')  # exclude the own cluster from b(i)
    b = mean_to_cluster.min(dim=1).values

    denom = torch.maximum(a, b)
    scores = torch.where(denom > 0, (b - a) / denom, torch.zeros_like(denom))
    scores = torch.where(own_size > 1, scores, torch.zeros_like(scores))
    return scores.mean()


def safe_average(tensor):
    """Average the entries of ``tensor`` that are not +inf or -inf.

    Returns ``float('inf')`` when no such entry exists (including an empty tensor);
    otherwise returns a 0-dim tensor. NaN entries are not filtered and propagate.
    """
    kept = tensor[torch.isinf(tensor).logical_not()]
    return kept.mean() if kept.numel() else float('inf')

@torch.no_grad()
def compute_distance(pair_distances, clusters, method='average', X=None):
    """Choose the pair of clusters to merge next under average or Ward linkage.

    Args:
        pair_distances: (n_samples, n_samples) point-to-point distance matrix. Only
            entries between members of different clusters are read (``'average'``).
        clusters: 1-D tensor holding the current cluster label of every sample.
        method: ``'average'`` uses the mean distance over all cross-cluster member
            pairs. ``'ward'`` uses the merge cost
            ``n_i * n_j / (n_i + n_j) * ||c_i - c_j||^2``, the increase in the
            within-cluster sum of squares.
        X: Sample features (rows aligned with ``clusters``); required for ``'ward'``.

    Returns:
        ``(label_i, label_j)``: the labels (0-dim tensors taken from
        ``torch.unique(clusters)``) of the two clusters with the smallest linkage
        distance.

    Raises:
        NotImplementedError: For an unsupported ``method``.
        ValueError: For ``method='ward'`` without ``X``.
    """
    if method not in ('average', 'ward'):
        raise NotImplementedError("Unsupported linkage method: {}".format(method))
    if method == 'ward' and X is None:
        raise ValueError("ward linkage requires the sample matrix X")

    cluster_labels = torch.unique(clusters)
    num_clusters = len(cluster_labels)
    # Only off-diagonal entries are filled; the +inf diagonal keeps a cluster from
    # being paired with itself.
    distances = torch.full((num_clusters, num_clusters), float('inf'))

    if method == 'average':
        masks = [(clusters == label).to(pair_distances.device) for label in cluster_labels]
        for i in range(num_clusters):
            for j in range(i + 1, num_clusters):
                link = pair_distances[masks[i]][:, masks[j]].mean()
                distances[i, j] = distances[j, i] = link
    else:
        masks = [(clusters == label).to(X.device) for label in cluster_labels]
        sizes = [mask.sum().item() for mask in masks]
        centers = [X[mask].mean(dim=0) for mask in masks]
        for i in range(num_clusters):
            for j in range(i + 1, num_clusters):
                ni, nj = sizes[i], sizes[j]
                # Ward's criterion uses the *squared* centroid distance.
                squared_gap = torch.sum((centers[i] - centers[j]) ** 2)
                cost = (ni * nj) / (ni + nj) * squared_gap
                distances[i, j] = distances[j, i] = cost

    idx = torch.argmin(distances)
    return cluster_labels[idx // num_clusters], cluster_labels[idx % num_clusters]

@torch.no_grad()
def pairwise_distances(X, method='single'):
    """Return the (n, n) Euclidean distance matrix between the rows of ``X``.

    The diagonal is set according to the linkage:

    * ``+inf`` for ``'single'`` / ``'average'``, so an argmin never pairs a sample
      with itself;
    * ``0`` for ``'complete'``;
    * left at its computed value (0) for any other method.

    Distances are computed from the row differences directly instead of the
    ``|x|^2 - 2<x, y> + |y|^2`` expansion. The expansion cancels catastrophically
    when row norms are large relative to the gaps between rows, e.g. for
    high-dimensional weight vectors.
    """
    distances = torch.cdist(X, X, p=2, compute_mode='donot_use_mm_for_euclid_dist')
    if method == 'single' or method == 'average':
        distances.fill_diagonal_(float('inf'))
    elif method == 'complete':
        distances.fill_diagonal_(0.0)
    return distances

@torch.no_grad()
def linkage_step(distances, pair_distances, clusters=None, method='single', X=None):
    """Perform one agglomerative merge decision.

    For ``'single'`` and ``'complete'`` linkage, ``distances`` is an (n, n) matrix of
    cluster-to-cluster distances indexed by each cluster's representative row. The
    closest pair ``(i, j)`` is merged:

    * row/column ``i`` becomes the merged cluster, updated in place with the
      element-wise min (single) or max (complete) of rows ``i`` and ``j``;
    * row/column ``j`` is retired by setting it to ``+inf``.

    Both linkages merge the *closest* pair; they differ only in how the merged
    cluster's distances are derived.

    For any other method the choice is delegated to ``compute_distance`` and
    ``distances`` is returned unchanged. In that case ``i`` and ``j`` are cluster
    labels, not row indices.

    Returns:
        ``(i, j, distances)`` with ``i <= j``.
    """
    if method not in ('single', 'complete'):
        i, j = compute_distance(pair_distances, clusters, method, X)
        if i > j:
            i, j = j, i
        return i, j, distances

    # Self-pairs must never be selected. The diagonal may arrive as 0
    # (pairwise_distances uses 0 for 'complete'), so force it to +inf here.
    distances.fill_diagonal_(float('inf'))
    n = distances.shape[0]
    i, j = divmod(torch.argmin(distances).item(), n)
    if i > j:
        i, j = j, i

    combine = torch.minimum if method == 'single' else torch.maximum
    merged = combine(distances[i], distances[j])
    merged[i] = float('inf')
    distances[i, :] = merged
    distances[:, i] = merged
    # Retire j: +inf keeps it out of every future argmin and, under both min and max,
    # retired rows stay +inf.
    distances[j, :] = float('inf')
    distances[:, j] = float('inf')
    return i, j, distances

@torch.no_grad()
def hierarchical_clustering(X, n_clusters, method='single'):
    """Agglomeratively cluster the rows of ``X`` into exactly ``n_clusters`` groups.

    Args:
        X: Samples, one per row (2-D), compared with Euclidean distance.
        n_clusters: Number of clusters to stop at; must lie in ``[1, X.shape[0]]``.
        method: Linkage passed to ``pairwise_distances`` / ``linkage_step``.

    Returns:
        ``(clusters, center_indices)``:

        * ``clusters`` is a 1-D long tensor of labels ``0..n_clusters-1``, numbered
          in order of first appearance;
        * ``center_indices`` gives, for each label, the index of the member closest
          to that cluster's mean.

    Raises:
        ValueError: If ``n_clusters`` is outside ``[1, X.shape[0]]``. Values below 1
            would make the merge loop spin forever, and larger values would leave
            clusters empty.
    """
    n_samples = X.shape[0]
    if not 1 <= n_clusters <= n_samples:
        raise ValueError(
            f"n_clusters must be between 1 and the number of samples ({n_samples}), got {n_clusters}"
        )
    print("hierarchical clustering - {} to {} clusters".format(method, n_clusters))

    distances = pairwise_distances(X, method)
    pair_distances = distances.clone()
    # Every sample starts in its own cluster, labelled by its row index.
    clusters = torch.arange(n_samples)

    while len(torch.unique(clusters)) > n_clusters:
        i, j, distances = linkage_step(distances, pair_distances, clusters, method, X)
        print(f"clusters: {len(torch.unique(clusters))}, merge ({i}, {j})")
        # Relabel every member of j's cluster with i's label (read as plain ints so
        # no tensor view aliases the values being overwritten).
        clusters[clusters == clusters[j].item()] = clusters[i].item()

    # Renumber labels 0..k-1 in order of first appearance.
    first_seen = {}
    for pos, label in enumerate(clusters.tolist()):
        clusters[pos] = first_seen.setdefault(label, len(first_seen))

    center_indices = []
    for k in range(n_clusters):
        member_idx = torch.where(clusters == k)[0]
        members = X[member_idx]
        gaps = torch.cdist(members, members.mean(dim=0, keepdim=True), p=2)
        center_indices.append(member_idx[torch.argmin(gaps).item()].item())
    return clusters, center_indices


@torch.no_grad()
def hierarchical_clustering_dynamic(X, linkage='single', stopping_metric='silhouette', max_clusters=8, min_clusters=2):
    """Agglomerative clustering that chooses the number of clusters by a metric.

    Merges clusters one pair at a time down to ``min_clusters``. Every intermediate
    partition with at most ``max_clusters`` clusters is scored:

    * ``'silhouette'`` keeps the partition with the highest
      ``compute_silhouette_score``. Only partitions with >= 2 clusters are scored,
      and NaN scores never win.
    * ``'inertia'`` keeps the partition with the lowest within-cluster sum of squared
      distances to the cluster means. Merging never decreases this quantity, so the
      earliest scored partition is normally the one kept.

    Returns:
        ``(best_clusters, center_indices)``:

        * ``best_clusters`` is a 1-D long tensor of labels renumbered ``0..k-1`` in
          order of first appearance;
        * ``center_indices`` gives, per label, the index of the member closest to
          the cluster mean.

        If no partition was scored (e.g. ``X`` has at most ``min_clusters`` rows),
        the final partition is returned.

    Raises:
        ValueError: For an unknown ``stopping_metric`` or ``min_clusters < 1``. The
            latter would make the merge loop spin forever.
    """
    if stopping_metric not in ('silhouette', 'inertia'):
        raise ValueError(f"Unsupported stopping_metric: {stopping_metric!r}")
    if min_clusters < 1:
        raise ValueError(f"min_clusters must be >= 1, got {min_clusters}")

    def _inertia(labels):
        """Within-cluster sum of squared distances to each cluster's mean."""
        total = 0.0
        for label in torch.unique(labels):
            members = X[(labels == label).to(X.device)]
            total += torch.sum((members - members.mean(dim=0)) ** 2).item()
        return total

    higher_is_better = stopping_metric == 'silhouette'
    n_samples = X.shape[0]

    distances = pairwise_distances(X, linkage)
    pair_distances = distances.clone()
    clusters = torch.arange(n_samples)
    best_score = -float('inf') if higher_is_better else float('inf')
    best_clusters = None

    while len(torch.unique(clusters)) > min_clusters:
        i, j, distances = linkage_step(distances, pair_distances, clusters, linkage, X)
        # Relabel every member of j's cluster with i's label.
        clusters[clusters == clusters[j].item()] = clusters[i].item()

        num_clusters = len(torch.unique(clusters))
        if num_clusters > max_clusters:
            continue
        if higher_is_better:
            if num_clusters < 2:
                continue  # silhouette is undefined for a single cluster
            score = compute_silhouette_score(X, clusters)
            improved = bool(score > best_score)
        else:
            score = _inertia(clusters)
            improved = score < best_score
        if improved:
            best_score = score
            best_clusters = clusters.clone()
            print(f"Update score to {score}, {best_clusters}")

    if best_clusters is None:
        best_clusters = clusters.clone()

    # Renumber labels 0..k-1 in order of first appearance.
    first_seen = {}
    for pos, label in enumerate(best_clusters.tolist()):
        best_clusters[pos] = first_seen.setdefault(label, len(first_seen))

    center_indices = []
    for k in range(len(first_seen)):
        member_idx = torch.where(best_clusters == k)[0]
        members = X[member_idx]
        gaps = torch.cdist(members, members.mean(dim=0, keepdim=True), p=2)
        center_indices.append(member_idx[torch.argmin(gaps).item()].item())
    return best_clusters, center_indices

@torch.no_grad()
def kmeans_plus_plus_clustering(
    experts: torch.Tensor,
    num_groups: int,
    experts2: Optional[torch.Tensor] = None,
    experts3: Optional[torch.Tensor] = None,
    init_center: Optional[torch.Tensor] = None,
    w1: float = 1.0,
    w2: float = 1.0,
    w3: float = 1.0,
    max_iter: int = 100,
    tol: float = 1e-4,
) -> Tuple[List[int], torch.Tensor]:
    """Weighted multi-view k-means with k-means++ seeding.

    Each expert is described by up to three feature matrices ("views"): ``experts``,
    ``experts2`` and ``experts3``, each with one row per expert. In every iteration:

    1. The Euclidean distance from each expert to each center is computed per view,
       divided by that view's feature width, and column-standardized.
    2. The views are combined as ``(w1*d1 + w2*d2 + w3*d3) / (w1 + w2 + w3)``;
       absent views contribute 0.
    3. Each expert is assigned to the nearest combined center.
    4. Centers are recomputed per view as cluster means.

    Args:
        experts: Primary features, one row per expert.
        num_groups: Number of clusters.
        experts2, experts3: Optional extra views.
        init_center: Optional initial center indices; k-means++ seeding otherwise.
        w1, w2, w3: View weights.
        max_iter: Maximum number of iterations (must be >= 1).
        tol: Stop once no center coordinate moves by ``tol`` or more.

    Returns:
        ``(center_indices, assignments)``:

        * ``center_indices``: per cluster, the member closest (by weighted raw
          distance) to its final center, or a random expert if the cluster is empty;
        * ``assignments``: per-expert cluster indices.

    Raises:
        ValueError: If ``max_iter < 1``.
    """
    if max_iter < 1:
        raise ValueError(f"max_iter must be >= 1, got {max_iter}")

    def _standardize(x):
        # Column-wise z-score, then shift so the global minimum becomes 0.
        x = (x - x.mean(dim=0)) / (x.std(dim=0) + 1e-6)
        min_value = x.min()
        return x - min_value

    def kmeans_plus_plus_init(experts, num_groups):
        """Pick ``num_groups`` seed indices by k-means++ (D^2) sampling."""
        num_experts = experts.size(0)
        first_center_idx = torch.randint(0, num_experts, (1,)).item()
        center_indices = [first_center_idx]
        # Squared distance from every expert to its nearest seed so far.
        nearest_sq = torch.cdist(experts, experts[first_center_idx].unsqueeze(0)).squeeze(1) ** 2
        for _ in range(1, num_groups):
            total = nearest_sq.sum()
            if total > 0:
                probabilities = nearest_sq / total
            else:
                # Every expert coincides with a seed, so all D^2 weights are 0.
                # Sample uniformly instead of handing NaNs to torch.multinomial.
                probabilities = torch.full_like(nearest_sq, 1.0 / num_experts)
            next_center_idx = torch.multinomial(probabilities, 1).item()
            center_indices.append(next_center_idx)
            new_sq = torch.cdist(experts, experts[next_center_idx].unsqueeze(0)).squeeze(1) ** 2
            nearest_sq = torch.minimum(nearest_sq, new_sq)
        return torch.tensor(center_indices)

    indices = init_center if init_center is not None else kmeans_plus_plus_init(experts, num_groups)

    num_experts = experts.shape[0]
    # (features, weight) for every view that is present.
    views = [(experts, w1)]
    if experts2 is not None:
        views.append((experts2, w2))
    if experts3 is not None:
        views.append((experts3, w3))
    total_weight = w1 + w2 + w3

    def _recompute_centers(features, assignments, refill):
        """Per-cluster mean of ``features``; empty clusters take the row in ``refill``."""
        rows = [
            features[refill[k]] if k in refill else features[assignments == k].mean(dim=0)
            for k in range(num_groups)
        ]
        return torch.stack(rows)

    centers = [features[indices] for features, _ in views]
    for _ in range(max_iter):
        combined = sum(
            weight * _standardize(torch.cdist(features, view_centers) / features.shape[1])
            for (features, weight), view_centers in zip(views, centers)
        ) / total_weight
        assignments = torch.argmin(combined, dim=1)
        del combined

        # Re-seed each empty cluster with ONE random expert shared by all views, so
        # the cluster's centers stay consistent across views.
        refill = {
            k: torch.randint(0, num_experts, (1,)).item()
            for k in range(num_groups)
            if not bool((assignments == k).any())
        }
        new_centers = [_recompute_centers(features, assignments, refill) for features, _ in views]

        max_diff = max(
            torch.max(torch.abs(new - old)).item() for new, old in zip(new_centers, centers)
        )
        if max_diff < tol:
            break
        centers = new_centers

    present_weight = sum(weight for _, weight in views)
    center_indices = []
    for k in range(num_groups):
        member_idx = torch.where(assignments == k)[0]
        if member_idx.numel() == 0:
            center_indices.append(torch.randint(0, num_experts, (1,)).item())
            continue
        # Weighted raw (unstandardized) distance from each member to its final centers.
        score = sum(
            weight * torch.cdist(features[member_idx], view_centers[k].unsqueeze(0))
            for (features, weight), view_centers in zip(views, new_centers)
        ) / present_weight
        closest = torch.argmin(score, dim=0).item()
        center_indices.append(member_idx[closest].item())

    return center_indices, assignments

@torch.no_grad()
def Graph_Partitioning_clustering(
    experts: torch.Tensor,
    num_groups: int,
    experts2: Optional[torch.Tensor] = None,
    experts3: Optional[torch.Tensor] = None,
    init_center: Optional[torch.Tensor] = None,
    w1: float = 1.0,
    w2: float = 1.0,
    w3: float = 1.0,
    max_iter: int = 100,
    tol: float = 1e-4,
) -> Tuple[List[int], torch.Tensor]:
    """Partition experts into ``num_groups`` groups by greedy local search.

    The partition cost is ``compute_total_intra_cluster_cost`` over a distance
    matrix built from ``experts`` only: Euclidean distance divided by the feature
    width, then column-standardized.

    Each sweep tries moving every expert to every other group and applies the best
    strictly-improving move. Sweeps stop when the cost changes by less than ``tol``
    or after ``max_iter`` sweeps. A sweep with no improving move is followed by a few
    random label swaps as a perturbation.

    ``experts2``, ``experts3``, ``w1``, ``w2`` and ``w3`` are accepted for interface
    parity with ``kmeans_plus_plus_clustering`` but are not used. ``init_center`` is
    used only when its length equals ``num_groups``; otherwise labels start random.

    Returns:
        ``(center_indices, assignments)``:

        * ``center_indices``: per group, the member with the smallest mean
          (standardized) distance to the group's members, or a random expert for an
          empty group;
        * ``assignments``: per-expert group labels.
    """

    def _standardize(x):
        # Column-wise z-score, then shift so the global minimum becomes 0.
        x = (x - x.mean(dim=0)) / (x.std(dim=0) + 1e-6)
        min_value = x.min()
        return x - min_value

    num_experts = experts.size(0)
    device = experts.device
    s1 = experts.shape[1]

    raw_distance_matrix = torch.cdist(experts, experts, p=2)
    distance_matrix = _standardize(raw_distance_matrix / s1)

    if init_center is not None and len(init_center) == num_groups:
        # Seed by assigning every expert to its nearest initial center.
        centers = experts[init_center]
        center_distances = _standardize(torch.cdist(experts, centers) / s1)
        assignments = torch.argmin(center_distances, dim=1)
    else:
        assignments = torch.randint(0, num_groups, (num_experts,), device=device)

    # Give every group at least one member. The donor comes from a group with more
    # than one member, so no other group gets emptied in the process.
    for group in range(num_groups):
        if bool((assignments == group).any()):
            continue
        sizes = torch.bincount(assignments, minlength=num_groups)
        donors = torch.where(sizes[assignments] > 1)[0]
        if donors.numel() == 0:
            break  # fewer experts than groups: some groups must stay empty
        pick = torch.randint(0, donors.numel(), (1,), device=device)[0]
        assignments[donors[pick]] = group

    prev_cost = float('inf')

    for iteration in range(max_iter):
        improved = False

        for expert_idx in range(num_experts):
            current_partition = assignments[expert_idx].item()
            # Moving a group's only member out would empty that group.
            if (assignments == current_partition).sum() <= 1:
                continue

            current_cost = compute_total_intra_cluster_cost(distance_matrix, assignments)
            best_partition, best_cost = current_partition, current_cost
            for new_partition in range(num_groups):
                if new_partition == current_partition:
                    continue
                assignments[expert_idx] = new_partition
                new_cost = compute_total_intra_cluster_cost(distance_matrix, assignments)
                if new_cost < best_cost:
                    best_cost, best_partition = new_cost, new_partition
                    improved = True
            # Keep the expert in the cheapest group found (its own if nothing improved).
            assignments[expert_idx] = best_partition

        current_cost = compute_total_intra_cluster_cost(distance_matrix, assignments)
        if abs(prev_cost - current_cost) < tol:
            break
        prev_cost = current_cost

        if not improved and iteration < max_iter - 1:
            num_swaps = max(1, num_experts // (10 * num_groups))
            for _ in range(num_swaps):
                expert1 = torch.randint(0, num_experts, (1,), device=device)[0].item()
                expert2 = torch.randint(0, num_experts, (1,), device=device)[0].item()
                if expert1 != expert2:
                    # List indexing makes the right-hand side a copy. Tuple-swapping
                    # t[a], t[b] = t[b], t[a] on a tensor reads through views and
                    # copies one label over the other instead of swapping them.
                    assignments[[expert1, expert2]] = assignments[[expert2, expert1]]

    center_indices = []
    for partition in range(num_groups):
        partition_members = torch.where(assignments == partition)[0]
        if len(partition_members) == 0:
            center_indices.append(torch.randint(0, num_experts, (1,), device=device)[0].item())
            continue
        # Member whose row of distance_matrix has the smallest mean over the group
        # (argmin keeps the first member on ties).
        mean_distance = distance_matrix[partition_members][:, partition_members].mean(dim=1)
        center_indices.append(partition_members[torch.argmin(mean_distance).item()].item())

    return center_indices, assignments


def compute_total_intra_cluster_cost(distance_matrix: torch.Tensor, assignments: torch.Tensor) -> float:
    """Sum ``distance_matrix[p, q]`` over all pairs ``p < q`` that share a label.

    Only the upper triangle (``p < q``) is read, which matters when the matrix is
    not symmetric. Labels are expected to be non-negative integers (presumably
    ``0..num_groups-1`` — confirm with callers).

    Args:
        distance_matrix: (n, n) matrix of pairwise costs.
        assignments: Length-n tensor of group labels.

    Returns:
        Total intra-group cost as a Python float. It is 0.0 when no group has two
        members or ``assignments`` is empty.
    """
    num_items = assignments.numel()
    if num_items == 0:
        return 0.0
    labels = assignments.reshape(-1)
    same_group = labels.unsqueeze(0) == labels.unsqueeze(1)
    index = torch.arange(num_items, device=labels.device)
    upper = index.unsqueeze(1) < index.unsqueeze(0)  # mask[p, q] is True iff p < q
    # A single reduction and a single host sync, instead of one per group.
    return distance_matrix[same_group & upper].sum().item()


def spectral_graph_partitioning_init(experts: torch.Tensor, num_groups: int) -> torch.Tensor:
    """Cluster experts with a spectral embedding followed by a short k-means.

    Steps:

    1. Build a Gaussian affinity from width-normalized, column-standardized
       Euclidean distances, using the median distance as bandwidth.
    2. Take the ``num_groups`` eigenvectors of the unnormalized graph Laplacian with
       the smallest eigenvalues.
    3. Run 50 Lloyd iterations on the rows of that embedding.

    Returns:
        Per-expert group labels in ``[0, num_groups)``.

    Raises:
        ValueError: If ``num_groups`` is not in ``[1, experts.size(0)]``.
    """
    def _standardize(x):
        # Column-wise z-score, then shift so the global minimum becomes 0.
        x = (x - x.mean(dim=0)) / (x.std(dim=0) + 1e-6)
        min_value = x.min()
        return x - min_value

    num_experts = experts.size(0)
    if not 1 <= num_groups <= num_experts:
        raise ValueError(
            f"num_groups must be between 1 and the number of experts ({num_experts}), got {num_groups}"
        )
    s1 = experts.shape[1]

    raw_distance_matrix = torch.cdist(experts, experts, p=2)
    distance_matrix = _standardize(raw_distance_matrix / s1)

    sigma = distance_matrix.median()
    if sigma <= 0:
        # A zero bandwidth would give 0/0 = NaN affinities; fall back to unit bandwidth.
        sigma = torch.ones_like(sigma)
    similarity_matrix = torch.exp(-distance_matrix**2 / (2 * sigma**2))
    # Column standardization makes the affinity asymmetric, but torch.linalg.eigh
    # only reads one triangle of its (assumed symmetric) input, so symmetrize first.
    similarity_matrix = 0.5 * (similarity_matrix + similarity_matrix.T)

    degree_matrix = torch.diag(similarity_matrix.sum(dim=1))
    laplacian = degree_matrix - similarity_matrix

    # eigh returns eigenvalues in ascending order: keep the num_groups smallest.
    _, eigenvectors = torch.linalg.eigh(laplacian)
    feature_vectors = eigenvectors[:, :num_groups]

    centers = feature_vectors[torch.randperm(num_experts)[:num_groups]]

    for _ in range(50):
        raw_distances = torch.cdist(feature_vectors, centers)
        distances = _standardize(raw_distances / feature_vectors.shape[1])
        assignments = torch.argmin(distances, dim=1)

        new_centers = []
        for k in range(num_groups):
            cluster_points = feature_vectors[assignments == k]
            if len(cluster_points) > 0:
                new_centers.append(cluster_points.mean(dim=0))
            else:
                new_centers.append(centers[k])  # keep the previous center for an empty cluster
        centers = torch.stack(new_centers)

    return assignments
