import torch

def get_tensors_from_ids(x, center_ids):
    '''
        Select rows of `x` along dim 2 at the positions given by `center_ids`.

        x.shape = (batch_size, n_heads, seq_len, dim)
        center_ids.shape = (batch_size, n_heads, num_clusters)
        returns shape (batch_size, n_heads, num_clusters, dim)

        The index is expanded to x's batch/head sizes, so `center_ids` may use
        size-1 leading dims that broadcast against x.
    '''
    n_batch, n_head, _, feat_dim = x.shape
    _, _, n_sel = center_ids.shape
    # Repeat each index across the feature axis so gather copies whole rows.
    row_index = center_ids[..., None].expand(n_batch, n_head, n_sel, feat_dim)
    return x.gather(2, row_index)


def kcenter_greedy(x, num_centers, norm_ord=2, return_radius=False, dist='euclidean', seed_select='greedy', **kwargs):
    '''
        Pick `num_centers` rows of `x` per (batch, head) by farthest-first seeding.

        x.shape = (batch_size, n_heads, seq_len, dim)

        The first center is drawn uniformly with torch.randint (global RNG).
        Each later center is chosen from every point's distance to its nearest
        already-chosen center:
          - seed_select='greedy': the point with the largest such distance
            (already-chosen centers are never re-selected while others remain);
          - seed_select='kpp': a point sampled by torch.multinomial with weights
            equal to those distances (L2-normalized), so chosen centers get weight 0.

        dist='euclidean' uses torch.cdist. dist in ('angle', 'inner',
        'inner-product', 'innerproduct') uses the negated raw inner product
        -(x @ c); x is NOT normalized here.
        NOTE(review): with seed_select='kpp' these inner-product "distances" can be
        negative, which torch.multinomial rejects.

        norm_ord, return_radius and **kwargs are accepted for call-site
        compatibility but are currently unused.

        Returns:
            x_centers:  (batch_size, n_heads, num_centers, dim), selected rows of x.
            center_ids: (batch_size, n_heads, num_centers) int64 indices along dim 2.

        Raises:
            ValueError: if `dist` or `seed_select` is not a supported option.
    '''
    if dist not in ('euclidean', 'angle', 'inner', 'inner-product', 'innerproduct'):
        raise ValueError(f"unsupported dist {dist!r}")
    if seed_select not in ('greedy', 'kpp'):
        raise ValueError(f"unsupported seed_select {seed_select!r}")

    batch_size, n_heads, seq_len, dim = x.shape
    if num_centers <= 0:
        # Nothing to select: return correctly-shaped empty results instead of failing in torch.cat.
        return (x.new_empty((batch_size, n_heads, 0, dim)),
                torch.empty((batch_size, n_heads, 0), dtype=torch.long, device=x.device))

    # Distance from every point to its nearest chosen center so far (float32 regardless of x.dtype).
    min_dist = torch.full((batch_size, n_heads, seq_len), float('inf'), dtype=torch.float32, device=x.device)
    center_ids = None
    x_centers = []
    for kk in range(num_centers):
        if kk == 0:
            new_ind = torch.randint(seq_len, size=(batch_size, n_heads, 1), device=x.device)
        elif seed_select == 'greedy':
            # Mask chosen centers with -inf for the argmax. The zeros stored for them
            # below are not enough when distances can be negative (inner-product
            # mode): 0 would then be the maximum and a center would be picked again.
            masked = min_dist.scatter(2, center_ids, float('-inf'))
            new_ind = masked.argmax(-1, keepdim=True)
        else:  # 'kpp'
            probs = torch.nn.functional.normalize(min_dist, p=2.0, dim=-1)
            new_ind = probs.reshape(-1, seq_len).multinomial(num_samples=1, replacement=True).reshape(batch_size, n_heads, 1)

        center_ids = new_ind if center_ids is None else torch.cat((center_ids, new_ind), dim=-1)

        # (batch_size, n_heads, 1, dim): the row of x at new_ind for every (batch, head).
        x_center = torch.gather(x, 2, new_ind.unsqueeze(-1).expand(batch_size, n_heads, 1, dim))

        if dist == 'euclidean':
            new_min_dist = torch.cdist(x, x_center).squeeze(-1).float()
        else:
            new_min_dist = -(x @ x_center.transpose(-1, -2)).squeeze(-1).float()

        min_dist = torch.min(min_dist, new_min_dist)
        # Force chosen centers to distance 0 (cleans cdist round-off and gives them zero kpp weight).
        min_dist.scatter_(2, center_ids, 0.0)
        x_centers.append(x_center)

    return torch.cat(x_centers, dim=2), center_ids
    