

import math
from torch import norm
import torch
import wandb
from slot_attention.helpers.cos_dist import get_cos_dist
from slot_attention.visualization.attention_matrix import attention_matrix_to_image


def _soft_assign(X, cluster_centers, n_dims, eps):
    """Soft-assign points to clusters by scaled dot-product similarity.

    Args:
        X: B x N x D points.
        cluster_centers: B x K x D centers.
        n_dims: feature dimension D, used for the 1/sqrt(D) scaling.
        eps: added to the per-cluster normalizer to avoid division by zero.

    Returns:
        (sim, assign_clusters_to_points), both B x N x K:
        ``sim`` is the scaled similarity; the assignment is a softmax over
        clusters, then renormalized over points so each cluster's weights sum
        to ~1 (making the center update a weighted mean).
    """
    sim = torch.einsum('bnd,bkd->bnk', X, cluster_centers)  # B x N x K
    sim = sim * (1.0 / math.sqrt(n_dims))
    assign_points_to_clusters = torch.softmax(sim, dim=-1)  # normalize over clusters
    assign_clusters_to_points = assign_points_to_clusters / (
        assign_points_to_clusters.sum(dim=-2, keepdim=True) + eps
    )  # normalize over points
    return sim, assign_clusters_to_points


def soft_k_means(X, cluster_centers, n_iterations=3, eps=1e-6, **kwargs):
    """Run soft k-means with dot-product similarity and softmax assignments.

    Args:
        X: B x N x D points.
        cluster_centers: B x K x D initial centers.
        n_iterations: number of assignment + update steps. With 0 (or fewer)
            iterations the centers are returned unchanged and the assignment
            is computed against them.
        eps: numerical guard for the per-cluster normalizer and the
            dissimilarity measure.
        **kwargs: optional ``vis_carrier``; if given, its ``add_qk_masks`` is
            called with the first batch element's assignment matrix.

    Returns:
        (cluster_centers, assign_clusters_to_points, dissim_measure):
        B x K x D updated centers, B x N x K assignment weights, and a B-vector
        dissimilarity score. Note the assignment and the score are computed
        from the centers at the *start* of the last iteration, i.e. before
        the final center update.
    """
    _, _, n_dims = X.size()  # also enforces a 3-D input

    sim = assign_clusters_to_points = None
    for _ in range(n_iterations):
        sim, assign_clusters_to_points = _soft_assign(X, cluster_centers, n_dims, eps)
        cluster_centers = torch.einsum('bnd,bnk->bkd', X, assign_clusters_to_points)  # B x K x D

    if sim is None:
        # No iterations ran: previously this crashed with UnboundLocalError.
        # Report the assignment against the unchanged initial centers.
        sim, assign_clusters_to_points = _soft_assign(X, cluster_centers, n_dims, eps)

    # Per point, exp(-2 * sim) to its most similar center (a monotone proxy for
    # squared distance, not a true distance), summed over points.
    squared_dissim = torch.exp(-2 * sim) + eps  # B x N x K
    min_squared_dissim = torch.min(squared_dissim, dim=-1)[0]  # B x N
    dissim_measure = torch.sum(min_squared_dissim, dim=-1)  # B

    vis_carrier = kwargs.get('vis_carrier', None)
    if vis_carrier is not None:
        vis_carrier.add_qk_masks(name='Cluster assignment', mask=assign_clusters_to_points[0].detach().cpu().numpy())

    return cluster_centers, assign_clusters_to_points, dissim_measure


def soft_k_means_with_arcos_dist(X, cluster_centers, n_iterations=3, eps=1e-6):
    """Soft k-means on the unit sphere using ``get_cos_dist`` as the affinity.

    Args:
        X: B x N x D points; L2-normalized before clustering.
        cluster_centers: B x K x D initial centers; L2-normalized at the start
            of every iteration.
        n_iterations: number of assignment + update steps. With 0 (or fewer)
            iterations the centers are returned unchanged and the assignment
            is computed against them.
        eps: numerical guard for normalizations; also forwarded to
            ``get_cos_dist``.

    Returns:
        (cluster_centers, assign_clusters_to_points): the last updated centers
        (a weighted sum of the normalized points, NOT re-normalized) and the
        B x N x K assignment used for that update.
    """
    # TODO: wrong code: the updates should be based on similarity, not distance
    # NOTE(review): the assignment weight below is proportional to the value
    # returned by get_cos_dist, so if that is a distance, points weight their
    # *farthest* clusters most. Left unchanged pending a decision on the
    # intended similarity; get_cos_dist's exact definition is not visible here.
    if X.dim() != 3 or cluster_centers.dim() != 3:
        raise ValueError("X and cluster_centers must be 3-D (B x N x D, B x K x D)")

    # Cluster on the unit sphere. clamp(min=eps) keeps zero-norm points from
    # producing NaNs while leaving every row with norm >= eps unchanged.
    X = X / X.norm(dim=-1, keepdim=True).clamp(min=eps)

    def _assign(centers):
        # Normalize centers, then turn affinities into weights normalized over
        # clusters (last dim) and renormalized over points (dim -2).
        centers = centers / (centers.norm(dim=-1, keepdim=True) + eps)
        dist = get_cos_dist(X, centers, eps)  # B x N x K
        assign_points_to_clusters = dist / (dist.sum(dim=-1, keepdim=True) + eps)  # normalize over clusters
        return assign_points_to_clusters / (assign_points_to_clusters.sum(dim=-2, keepdim=True) + eps)  # normalize over points

    assign_clusters_to_points = None
    for _ in range(n_iterations):
        assign_clusters_to_points = _assign(cluster_centers)
        cluster_centers = torch.einsum('bnd,bnk->bkd', X, assign_clusters_to_points)

    if assign_clusters_to_points is None:
        # No iterations ran: previously this crashed with UnboundLocalError.
        assign_clusters_to_points = _assign(cluster_centers)

    return cluster_centers, assign_clusters_to_points