import numpy as np
import math
from pyrsistent import b
import torch
import torch.nn.functional as F
import wandb

from slot_attention.helpers.cos_dist import get_cos_dist
from slot_attention.visualization.attention_matrix import attention_matrix_to_image

def k_means_plus_plus_init(params, X, n_clusters, dist_func='exp', eps=1e-6, **kwargs):
    """Select ``n_clusters`` seed points per batch element via k-means++ sampling.

    A scaled dot-product similarity between all pairs of points is turned into
    a distance matrix (chosen by ``dist_func``). The first seed is drawn
    uniformly at random. Every further seed is drawn with probability
    proportional to its distance to the nearest seed chosen so far.

    Args:
        params: config object. ``params.slatn_pp_init_sim_center`` and
            ``params.slatn_pp_init_sim_scale`` toggle per-sample centering and
            standardisation of the similarity matrix.
        X: tensor of shape (batch_size, n_points, n_dims).
        n_clusters: number of seeds to select; must not exceed ``n_points``.
        dist_func: one of
            ``'exp'`` (``exp(-sim) + eps``),
            ``'exp2'`` (``exp(-2 * sim) + eps``),
            ``'linear'`` (``1 - sim`` min-max normalised within each sample), or
            ``'cos'`` (``get_cos_dist`` applied to L2-normalised ``X``).
        eps: small constant. It is added to the exponential distances, used
            as the lower bound of the ``'linear'`` normalisation range, and
            forwarded to ``get_cos_dist``.
        **kwargs: optional ``vis_carrier``. When given, intermediate matrices
            are pushed to it via ``add_qk_masks``, and a histogram of the
            distance matrix is logged with ``wandb.log``.

    Returns:
        float32 tensor of shape (batch_size, n_points, n_clusters), where
        entry ``[b, p, k]`` is 1 iff point ``p`` was chosen as seed ``k`` for
        batch element ``b``. All other entries are 0.

    Raises:
        AssertionError: if ``n_points < n_clusters``.
        ValueError: if ``dist_func`` is not one of the supported names.
    """
    batch_size, n_points, n_dims = X.size()
    assert n_points >= n_clusters, "n_points must be >= n_clusters"
    # Hoisted once. Placed on X's device so fancy indexing needs no host->device copy.
    batch_idx = torch.arange(batch_size, device=X.device)

    # Scaled dot-product similarity between all point pairs: B x N x N.
    sim = torch.einsum('bnd,bmd->bnm', X, X)
    sim *= (1.0 / math.sqrt(n_dims))

    # Optional per-sample normalisation of the similarity matrix.
    if params.slatn_pp_init_sim_center:
        sim_mean = torch.mean(sim, dim=(1, 2), keepdim=True)
        sim = sim - sim_mean
    if params.slatn_pp_init_sim_scale:
        sim_std = torch.std(sim, dim=(1, 2), keepdim=True)
        sim = sim / sim_std

    vis_carrier = kwargs.get('vis_carrier', None)
    if vis_carrier is not None:
        vis_carrier.add_qk_masks(name='Cos similarity', mask=sim[0].detach().cpu().numpy())

    if dist_func == 'exp':
        dist = torch.exp(-sim) + eps
    elif dist_func == 'exp2':
        dist = torch.exp(-2 * sim) + eps
    elif dist_func == 'linear':
        # Min-max normalise each column *within its own sample*, then invert
        # so that high similarity means small distance. Reducing over dim=0
        # (the batch axis) mixed different samples together. For
        # batch_size == 1 it also produced 0 / 0 = NaN everywhere.
        sim_min = torch.min(sim, dim=1, keepdim=True).values
        sim_range = torch.max(sim, dim=1, keepdim=True).values - sim_min
        dist = 1 - (sim - sim_min) / sim_range.clamp_min(eps)
    elif dist_func == 'cos':
        X = X / X.norm(dim=-1, keepdim=True)
        dist = get_cos_dist(X, X, eps)
    else:
        raise ValueError(f"Unknown dist_func: {dist_func}")
    # A point's distance to itself is 0. Once a point is chosen, its entry in
    # min_d becomes 0, so multinomial sampling can never pick it again.
    torch.diagonal(dist, offset=0, dim1=-2, dim2=-1).zero_()

    if vis_carrier is not None:
        wandb.log({'dist matrix histogram': wandb.Histogram(dist[0].detach().cpu().numpy().flatten())})
        vis_carrier.add_qk_masks(name='Dist matrix', mask=dist[0].detach().cpu().numpy())
        min_d_list = []

    centroid_selection_mat = torch.zeros(
        batch_size, n_points, n_clusters, dtype=torch.float32, device=X.device
    )  # B x N x K

    # Distance from every point to its nearest seed so far (B x N).
    # It starts at a large sentinel so the first real row overrides it.
    min_d = torch.full((batch_size, n_points), 1e6, dtype=dist.dtype, device=X.device)
    if vis_carrier is not None:
        min_d_list.append(min_d[0].clone().detach())

    # First seed: uniformly random per batch element.
    next_ctr_idcs = torch.randint(0, n_points, (batch_size,), device=X.device)  # B
    centroid_selection_mat[batch_idx, next_ctr_idcs, 0] = 1
    for i in range(1, n_clusters):
        min_d = torch.minimum(min_d, dist[batch_idx, next_ctr_idcs])
        if vis_carrier is not None:
            min_d_list.append(min_d[0].clone().detach())
        # Next seed is sampled proportionally to its distance to the nearest
        # seed. squeeze(1) keeps shape (B,) even when batch_size == 1.
        next_ctr_idcs = torch.multinomial(min_d, 1).squeeze(1)
        centroid_selection_mat[batch_idx, next_ctr_idcs, i] = 1

    if vis_carrier is not None:
        min_d_stack = torch.stack(min_d_list, dim=0)
        vis_carrier.add_qk_masks(name='min_d', mask=min_d_stack.detach().cpu().numpy())

    return centroid_selection_mat
