import numpy as np
from sklearn.cluster import AgglomerativeClustering
import torch


def pairwise_angles(sources):
    """Return the matrix of pairwise cosine similarities between sources.

    Despite the name, the entries are cosines rather than angles: entry
    ``[i, j]`` is ``sum(s_i * s_j) / (norm(s_i) * norm(s_j) + 1e-12)``, where
    ``s_k = flatten(sources[k])``. The ``1e-12`` term keeps the division
    finite when a source has zero norm.

    Args:
        sources: Sequence (must support ``len``) of objects accepted by
            ``flatten``. NOTE(review): ``flatten`` is not defined in the
            visible part of this file; it is assumed to return a tensor
            per source, all of the same length -- confirm.

    Returns:
        A ``len(sources) x len(sources)`` NumPy array, converted from a
        tensor created with ``torch.zeros`` (torch's default dtype).
    """
    # Flatten every source and take its norm exactly once; the original
    # redid both inside the double loop, i.e. O(n^2) times instead of O(n).
    flats = [flatten(source) for source in sources]
    norms = [torch.norm(flat) for flat in flats]

    angles = torch.zeros([len(flats), len(flats)])
    for i, (s1, n1) in enumerate(zip(flats, norms)):
        for j, (s2, n2) in enumerate(zip(flats, norms)):
            angles[i, j] = torch.sum(s1 * s2) / (n1 * n2 + 1e-12)
    return angles.numpy()

def compute_pairwise_similarities(clients):
    """Return the pairwise similarity matrix of the clients' ``dW`` attributes.

    Collects ``client.dW`` for every client, in iteration order, and hands
    the resulting list to ``pairwise_angles``.
    """
    updates = []
    for client in clients:
        updates.append(client.dW)
    return pairwise_angles(updates)

def cluster_clients(S):
    """Split items into two groups by complete-linkage clustering on ``-S``.

    The similarity matrix is negated so that it can be handed to
    scikit-learn as a precomputed distance matrix (higher similarity means
    smaller "distance").

    Args:
        S: Square array of pairwise similarities, e.g. the output of
            ``compute_pairwise_similarities``.

    Returns:
        Tuple ``(c1, c2)`` of 1-D integer index arrays: the positions
        labelled 0 and the positions labelled 1 by the clustering.
    """
    # scikit-learn renamed ``affinity`` to ``metric`` (deprecated in 1.2,
    # removed in 1.4). Try the current spelling first and fall back to the
    # old keyword so the function works on both old and new releases.
    try:
        model = AgglomerativeClustering(
            n_clusters=2, metric="precomputed", linkage="complete"
        )
    except TypeError:
        model = AgglomerativeClustering(
            n_clusters=2, affinity="precomputed", linkage="complete"
        )
    labels = model.fit(-S).labels_

    c1 = np.argwhere(labels == 0).flatten()
    c2 = np.argwhere(labels == 1).flatten()
    return c1, c2

def compute_max_update_norm(self, cluster):
    """Return the largest norm of ``flatten(client.dW)`` over ``cluster``.

    ``self`` is not used by the body; it is kept so existing callers that
    pass it keep working.
    """
    update_norms = []
    for client in cluster:
        flat_update = flatten(client.dW)
        update_norms.append(torch.norm(flat_update).item())
    return np.max(update_norms)

    
def compute_mean_update_norm(self, cluster):
    """Return the norm of the mean of ``flatten(client.dW)`` over ``cluster``.

    The flattened updates are stacked along a new leading dimension,
    averaged over that dimension, and the norm of the averaged tensor is
    returned as a Python number. ``self`` is not used by the body.
    """
    flat_updates = [flatten(client.dW) for client in cluster]
    stacked = torch.stack(flat_updates)
    mean_update = torch.mean(stacked, dim=0)
    return torch.norm(mean_update).item()