import torch
import numpy as np
from sklearn.cluster import AffinityPropagation
from collections import defaultdict
import matplotlib.pyplot as plt


def similarity_pair_layer_batch_cka(data1, data2, bs=2048):
    """Compute the batch-averaged linear CKA between every layer pair of two models.

    Samples are processed in consecutive chunks of ``bs``; a CKA value is
    computed per chunk and per layer pair, and the per-chunk values are then
    averaged.

    NOTE(review): every chunk carries the same weight in the average, so a
    short final chunk counts as much as a full one.

    Args:
        data1: Mapping with a ``'feat'`` entry that maps layer names to either
            a tensor or a list of tensors (lists are concatenated along dim 0).
            Presumably dim 0 is the sample axis -- confirm against the caller.
        data2: Same structure as ``data1`` for the second model. It is sliced
            with the sample ranges derived from ``data1``.
        bs: Number of samples per chunk; must be positive.

    Returns:
        ``numpy.ndarray`` of shape ``(num_layers_1, num_layers_2)`` holding the
        mean CKA of each layer pair.

    Raises:
        ValueError: If ``bs`` is not positive or ``data1`` has no layers.
    """
    if bs <= 0:
        raise ValueError(f'bs must be a positive integer, got {bs}')

    def _stack_layers(features):
        # Layers stored as a list of tensors are merged into a single tensor.
        if not isinstance(features, dict):
            return features
        return {
            name: torch.cat(chunks, dim=0) if isinstance(chunks, list) else chunks
            for name, chunks in features.items()
        }

    layers1 = _stack_layers(data1['feat'])
    layers2 = _stack_layers(data2['feat'])
    if len(layers1) == 0:
        raise ValueError("data1['feat'] must contain at least one layer")

    # The sample count is taken from the first layer of the first model.
    num_sample = next(iter(layers1.values())).shape[0]
    num_batch = int(np.ceil(num_sample / bs))

    # Use the GPU when one is present instead of failing on CPU-only hosts.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    cka_map = torch.zeros((num_batch, len(layers1), len(layers2)), device=device)

    def _batch_on_device(values, start, end):
        chunk = values[start:end]
        if not isinstance(values, torch.Tensor):
            chunk = torch.tensor(chunk)
        return chunk.to(device)

    for b_id in range(num_batch):
        start = b_id * bs
        end = min(start + bs, num_sample)
        for i, v1 in enumerate(layers1.values()):
            # Slice/transfer the first model's chunk once per layer rather
            # than once per layer pair.
            tensor1 = _batch_on_device(v1, start, end)
            for j, v2 in enumerate(layers2.values()):
                tensor2 = _batch_on_device(v2, start, end)
                cka_map[b_id, i, j] = cka_linear_torch(tensor1, tensor2)

    return cka_map.mean(0).detach().cpu().numpy()

def compute_cka_matrix(feature_matrices):
    """Build the symmetric matrix of pairwise linear CKA values.

    Args:
        feature_matrices: Sequence of per-layer activation tensors; each pair
            is passed to ``cka_linear_torch``.

    Returns:
        ``numpy.ndarray`` of shape ``(n, n)`` where entry ``[i, j]`` is the
        CKA between ``feature_matrices[i]`` and ``feature_matrices[j]``.
    """
    n_layers = len(feature_matrices)
    cka_matrix = np.zeros((n_layers, n_layers))

    for i in range(n_layers):
        # CKA is symmetric, so only the upper triangle is computed.
        for j in range(i, n_layers):
            # float() pulls the scalar to the host; writing a CUDA (or
            # grad-tracking) tensor straight into a NumPy array fails.
            cka_val = float(cka_linear_torch(feature_matrices[i], feature_matrices[j]))
            cka_matrix[i, j] = cka_val
            cka_matrix[j, i] = cka_val

    return cka_matrix

def cka_linear_torch(x1, x2):
    """Compute linear CKA between two activation tensors.

    Each input is flattened to one row per example, turned into a linear-kernel
    Gram matrix, and the two Gram matrices are compared with ``_cka``.

    Args:
        x1: Activation tensor whose first dimension indexes examples.
        x2: Activation tensor for the same examples as ``x1``.

    Returns:
        The CKA similarity as returned by ``_cka``.
    """
    gram1 = gram_linear(rearrange_activations(x1))
    gram2 = gram_linear(rearrange_activations(x2))
    # Move to the GPU only when one exists; otherwise stay on the input's
    # device instead of crashing on CPU-only machines.
    if torch.cuda.is_available():
        gram1 = gram1.cuda()
        gram2 = gram2.cuda()
    return _cka(gram1, gram2)

def rearrange_activations(activations):
    """Collapse every dimension after the first into a single axis.

    ``reshape`` is used rather than ``view`` so that inputs with a
    non-contiguous memory layout are accepted as well.

    Args:
        activations: Array-like exposing ``shape`` and ``reshape``.

    Returns:
        A 2-D result of shape ``(activations.shape[0], -1)``.
    """
    num_rows = activations.shape[0]
    return activations.reshape(num_rows, -1)

def gram_linear(x):
    """Return the linear-kernel Gram matrix of ``x`` (``x`` times its transpose).

    Args:
        x: 2-D tensor; ``torch.mm`` only accepts matrices.

    Returns:
        Square tensor with one row and one column per row of ``x``.
    """
    x_transposed = x.T
    return torch.mm(x, x_transposed)

def center_gram(gram, unbiased=False):
    """Center a symmetric Gram matrix without modifying the input.

    Args:
        gram: Square, symmetric tensor (checked with ``torch.allclose``).
        unbiased: If ``False``, apply ordinary double centering (subtract row
            and column means, add back the grand mean). If ``True``, apply the
            unbiased (U-centering) variant used for the debiased HSIC
            estimator: the diagonal is zeroed and the means are corrected for
            the excluded diagonal.

    Returns:
        A new tensor with the centered Gram matrix, same dtype as ``gram``.

    Raises:
        ValueError: If ``gram`` is not symmetric, or if ``unbiased`` is set and
            there are fewer than 3 samples (the estimator divides by ``n - 2``).
    """
    if not torch.allclose(gram, gram.T):
        raise ValueError('Input must be a symmetric matrix.')

    # Work on a copy so the caller's Gram matrix is left untouched.
    gram = gram.clone()

    if unbiased:
        n = gram.shape[0]
        if n < 3:
            raise ValueError('Unbiased centering requires at least 3 samples.')
        gram.fill_diagonal_(0)
        # Means are accumulated in float64 to limit rounding error.
        means = torch.sum(gram, dim=0, dtype=torch.float64) / (n - 2)
        means -= torch.sum(means) / (2 * (n - 1))
        gram -= means.unsqueeze(1)
        gram -= means.unsqueeze(0)
        # The subtraction re-populates the diagonal; it must stay zero.
        gram.fill_diagonal_(0)
    else:
        means = torch.mean(gram, dim=0, dtype=torch.float64)
        # Halving the grand mean here makes the two subtractions below add it
        # back exactly once.
        means -= torch.mean(means) / 2
        gram -= means.unsqueeze(1)
        gram -= means.unsqueeze(0)
    return gram

def _cka(gram_x, gram_y, debiased=False):
    """Compute CKA from two Gram matrices.

    Both matrices are centered with ``center_gram``; the result is the dot
    product of the flattened centered matrices divided by the product of their
    norms.

    Args:
        gram_x: Square Gram matrix of the first representation.
        gram_y: Square Gram matrix of the second representation, same size.
        debiased: Forwarded to ``center_gram`` as ``unbiased``.

    Returns:
        0-dim tensor with the similarity. When either centered matrix has zero
        norm the similarity is undefined and a zero tensor is returned.
    """
    gram_x = center_gram(gram_x, unbiased=debiased)
    gram_y = center_gram(gram_y, unbiased=debiased)

    scaled_hsic = torch.dot(gram_x.reshape(-1), gram_y.reshape(-1))
    normalization_x = torch.norm(gram_x)
    normalization_y = torch.norm(gram_y)

    if normalization_x == 0.0 or normalization_y == 0.0:
        # Keep the return type a tensor in the degenerate case as well, so
        # callers can treat every result uniformly.
        return torch.zeros_like(scaled_hsic)
    return scaled_hsic / (normalization_x * normalization_y)

def process_layer_features(features):
    """Flatten layer features to 2-D and merge lists into one tensor.

    Args:
        features: Either a tensor or a list. Tensors with more than two
            dimensions are flattened to ``(size(0), -1)``. For a list, entries
            that are not tensors are skipped and the remaining tensors are
            concatenated along dim 0.

    Returns:
        The flattened/concatenated tensor, or ``None`` when ``features`` is
        neither a tensor nor a list, or when the list holds no tensors.
    """
    def _flatten(tensor):
        # reshape, unlike view, also works on non-contiguous tensors
        # (e.g. the result of permute/transpose).
        if tensor.dim() > 2:
            return tensor.reshape(tensor.size(0), -1)
        return tensor

    if isinstance(features, torch.Tensor):
        return _flatten(features)

    if isinstance(features, list):
        flattened = [_flatten(item) for item in features if isinstance(item, torch.Tensor)]
        return torch.cat(flattened, dim=0) if flattened else None

    return None


def _single_cluster(layer_keys):
    """Return the fallback grouping: all layers in cluster 0, led by the first key."""
    return {0: layer_keys}, {0: layer_keys[0]}


def cross_model_layer_clustering(client_layer_dict, damping, max_iter, bs, plot_matrix=False):
    """Cluster layers by pairwise CKA similarity using Affinity Propagation.

    A symmetric similarity matrix is filled with the CKA of every layer pair
    (diagonal fixed to 1.0), NaN/inf entries are sanitised, and the 0.75
    quantile of the off-diagonal values is used as the preference.

    Args:
        client_layer_dict: Mapping from layer key to that layer's features.
            Keys are indexed as ``key[0]`` / ``key[1]`` when building names --
            presumably ``(client_id, layer_id)`` tuples; confirm with caller.
        damping: Damping factor passed to ``AffinityPropagation``.
        max_iter: Iteration limit passed to ``AffinityPropagation``.
        bs: Batch size forwarded to ``similarity_pair_layer_batch_cka``.
        plot_matrix: When true, call ``visualize_similarity_matrix`` on the
            result (assumed to be defined elsewhere in this module).

    Returns:
        Tuple ``(cluster_dict, center_layer_info, similarity_matrix)`` where
        ``cluster_dict`` maps cluster id to its member layer keys and
        ``center_layer_info`` maps cluster id to its exemplar layer key. If
        clustering raises or finds no exemplars, all layers are returned as a
        single cluster centered on the first key.
    """
    layer_keys = list(client_layer_dict.keys())
    num_layers = len(layer_keys)

    # With fewer than two layers there are no pairs to compare and no
    # off-diagonal values to derive a preference from.
    if num_layers == 0:
        return {}, {}, np.zeros((0, 0))
    if num_layers == 1:
        clusters, centers = _single_cluster(layer_keys)
        return clusters, centers, np.ones((1, 1))

    # Build each layer's payload once instead of once per pair.
    layer_data = [
        {
            'feat': {'layer_0': client_layer_dict[key]},
            'model_name': f'client_{key[0]}_layer_{key[1]}',
        }
        for key in layer_keys
    ]

    similarity_matrix = np.eye(num_layers)
    for i in range(num_layers):
        for j in range(i + 1, num_layers):
            cka_value = similarity_pair_layer_batch_cka(layer_data[i], layer_data[j], bs=bs)
            similarity_matrix[i, j] = cka_value[0, 0]
            similarity_matrix[j, i] = cka_value[0, 0]

    similarity_matrix = np.nan_to_num(similarity_matrix, nan=0.0, posinf=1.0, neginf=0.0)

    flat_sim = similarity_matrix[np.triu_indices_from(similarity_matrix, k=1)]
    preference_value = np.quantile(flat_sim, 0.75)

    print(f"Similarity matrix shape: {similarity_matrix.shape}")
    print(f"Similarity value range: [{similarity_matrix.min():.4f}, {similarity_matrix.max():.4f}]")
    print(f"Preference value: {preference_value:.4f}")

    try:
        affprop = AffinityPropagation(
            affinity="precomputed",
            preference=preference_value,
            damping=damping,
            random_state=42,
            convergence_iter=100,
            max_iter=max_iter,
            verbose=True
        ).fit(similarity_matrix)
    except Exception as e:
        # Best-effort: a clustering failure degrades to one cluster.
        print(f"AP clustering failed: {e}")
        default_cluster, default_centers = _single_cluster(layer_keys)
        return default_cluster, default_centers, similarity_matrix

    cluster_centers_indices = affprop.cluster_centers_indices_
    cluster_labels = affprop.labels_

    if len(cluster_centers_indices) == 0:
        # Without exemplars every label is -1 and all layers would be
        # dropped; use the same fallback as the failure path.
        print("AP clustering failed: no exemplars found")
        default_cluster, default_centers = _single_cluster(layer_keys)
        return default_cluster, default_centers, similarity_matrix

    print(f"Number of clusters: {len(cluster_centers_indices)}")

    center_layer_info = {}
    cluster_dict = {}
    for cluster_id, center_idx in enumerate(cluster_centers_indices):
        center_layer_info[cluster_id] = layer_keys[center_idx]
        cluster_dict[cluster_id] = []

    for idx, label in enumerate(cluster_labels):
        if label != -1:
            cluster_dict[label].append(layer_keys[idx])

    print("\n=== Cross-Model Layer Clustering Results ===")
    for cluster_id, center_key in center_layer_info.items():
        members = cluster_dict[cluster_id]
        print(f"Cluster {cluster_id}: Centered at {center_key}")
        print(f"  Members ({len(members)} layers):")
        for member in members:
            print(f"    - {member}")
        print()

    if plot_matrix:
        # A plotting problem must not throw away a valid clustering.
        try:
            visualize_similarity_matrix(similarity_matrix, cluster_labels, layer_keys)
        except Exception as e:
            print(f"Similarity matrix visualization failed: {e}")

    return cluster_dict, center_layer_info, similarity_matrix

