import torch
import torch.nn.functional as F

import utils.const as C

def get_centroids(doc_embeddings, normalize=True):
    """Return the mean vector over dim 1 plus distance-to-mean statistics.

    Args:
        doc_embeddings: 3-D tensor; the mean is taken over dim 1 and
            distances are measured along dim 2. Presumably laid out as
            (docs, embeddings_per_doc, features) -- TODO confirm with callers.
        normalize: when True, the returned centroids are L2-normalised along
            the last dim. The distance statistics are always measured
            against the un-normalised mean.

    Returns:
        ``(centroids, (mean_error, min_error, max_error))``. ``centroids`` is
        ``doc_embeddings`` with dim 1 averaged away; each error tensor holds,
        per index along dim 0, the mean / min / max Euclidean distance from
        the dim-1 vectors to their (un-normalised) mean.
    """
    mean_vec = doc_embeddings.mean(dim=1)

    # Euclidean distance of every vector to its own group's raw mean.
    offsets = doc_embeddings - mean_vec.unsqueeze(1)
    dist = torch.linalg.vector_norm(offsets, ord=2, dim=2)
    spread = (
        dist.mean(dim=1),
        dist.min(dim=1).values,
        dist.max(dim=1).values,
    )

    # Normalise only after the spread has been measured.
    if not normalize:
        return mean_vec, spread
    return F.normalize(mean_vec, p=2, dim=-1), spread

def generate_score_fields(points, doc_embeddings, *, batch_size=None, device=None):
    """For every (document, point) pair, return the unit vector from the point
    to that document's nearest embedding.

    Args:
        points: 2-D tensor of shape (N, D).
        doc_embeddings: 3-D tensor of shape (B, M, D); presumably B documents
            with M embeddings each -- TODO confirm with callers. Must be on a
            device compatible with ``points`` for the subtraction below.
        batch_size: number of points processed per chunk. Each chunk
            materialises a (B, batch_size, M, D) difference tensor, so this
            bounds peak memory. Defaults to ``C.BATCH_SIZE``.
        device: device of the returned tensor. Defaults to ``C.DEVICE``.

    Returns:
        float16 tensor of shape (B, N, D). ``scores[b, n]`` is
        ``doc_embeddings[b, m] - points[n]`` L2-normalised, where ``m``
        minimises the squared Euclidean distance to ``points[n]`` (argmin
        picks the first index on ties). A point coinciding with its nearest
        embedding yields a zero vector, since ``F.normalize`` clamps the
        norm by eps.
    """
    if batch_size is None:
        batch_size = C.BATCH_SIZE
    if device is None:
        device = C.DEVICE

    N, D = points.shape
    B, _, _ = doc_embeddings.shape

    scores = torch.empty(B, N, D, dtype=torch.float16, device=device)

    # Loop-invariant pieces, built once. Index tensors live on the same
    # device as the tensor they index.
    tp_exp = doc_embeddings.unsqueeze(1)  # (B, 1, M, D)
    idx_device = doc_embeddings.device
    batch_idx = torch.arange(B, device=idx_device)[:, None]  # (B, 1)

    # Chunk over the N points (not the B documents): both ``points`` and the
    # second axis of ``scores`` are sliced by [start:end].
    for start in range(0, N, batch_size):
        # Clamp so the last, possibly partial, chunk indexes only real rows.
        end = min(start + batch_size, N)
        rp_exp = points[start:end].unsqueeze(0).unsqueeze(2)  # (1, n, 1, D)
        diffs = tp_exp - rp_exp  # (B, n, M, D)

        # Nearest embedding per (document, point) by squared distance.
        nn_indices = diffs.pow(2).sum(dim=-1).argmin(dim=-1)  # (B, n)

        point_idx = torch.arange(end - start, device=idx_device)[None, :]  # (1, n)
        closest_diffs = diffs[batch_idx, point_idx, nn_indices]  # (B, n, D)

        # Assignment casts to the float16 storage of ``scores``.
        scores[:, start:end, :] = F.normalize(closest_diffs, p=2, dim=-1)
    return scores