import torch
import torch.nn as nn
import torch.nn.functional as F


def get_digit_loss(
    logits,
    targets,
    vocab_emb,
    temperature: float = 0.1,
):
    """Return a KL-divergence loss against distance-based soft targets.

    For every target index a soft distribution over the vocabulary is built
    as ``softmax(-d / temperature)``, where ``d`` is the squared L2 distance
    between the target's embedding and every row of ``vocab_emb``. The loss
    is ``KL(target_dist || softmax(logits))`` with ``batchmean`` reduction,
    i.e. summed over the vocabulary and divided by the number of rows.

    The soft targets are computed under ``torch.no_grad()``, so gradients
    flow only through ``logits``.

    Args:
        logits: Unnormalised scores whose last dimension indexes the
            vocabulary. Tensors with more than two dimensions (e.g. B x L x V)
            are flattened to (rows, V).
        targets: Integer indices into ``vocab_emb``, one per ``logits`` row.
            Flattened together with ``logits`` when ``logits`` has more than
            two dimensions.
        vocab_emb: Embedding table of shape (V, C).
        temperature: Softmax temperature for the soft targets; must be > 0.
            Smaller values concentrate the mass on the nearest codes.

    Returns:
        A scalar tensor holding the loss.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """
    # A zero temperature divides by zero and a negative one inverts the
    # distribution (farthest codes get the most mass); both are silent bugs.
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    if logits.dim() > 2:
        logits = logits.reshape(-1, logits.shape[-1])
        targets = targets.reshape(-1)

    with torch.no_grad():
        target_emb = F.embedding(targets, vocab_emb)
        # Squared L2 distance via ||t - v||^2 = ||t||^2 + ||v||^2 - 2 * t.v,
        # which avoids materialising a (rows, V, C) difference tensor.
        target_sq = target_emb.pow(2).sum(-1, keepdim=True)
        vocab_sq = vocab_emb.pow(2).sum(-1).unsqueeze(0)
        cross = torch.einsum("xc,vc->xv", target_emb, vocab_emb)
        sq_dist = target_sq + vocab_sq - 2 * cross
        target_dist = F.softmax(-sq_dist / temperature, dim=-1)

    # "batchmean" already reduces to a scalar, so no further mean is needed.
    return F.kl_div(
        F.log_softmax(logits, dim=-1),
        target_dist,
        reduction="batchmean",
    )


def get_digit_base_loss(logits, targets):
    """Placeholder for a baseline loss; not implemented.

    Raises:
        NotImplementedError: Always, regardless of the arguments.
    """
    raise NotImplementedError()


if __name__ == "__main__":
    # Smoke test: load the codebook, L2-normalise it, and evaluate the loss
    # on random logits/targets.
    # Fall back to CPU so the script also runs on machines without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    vocab_emb_path = "../../../../data/vq/vq_ds16_c2i_vocab.pt"
    # NOTE(review): weights_only=False unpickles arbitrary Python objects.
    # Only load trusted files; switch to weights_only=True if the file is
    # confirmed to contain a plain tensor.
    vocab_emb = torch.load(vocab_emb_path, map_location="cpu", weights_only=False)
    vocab_emb = F.normalize(vocab_emb.to(device), p=2, dim=-1)

    vocab_size = vocab_emb.shape[0]
    logits = torch.randn(2, 7, vocab_size, device=device)
    targets = torch.randint(vocab_size, (2, 7), device=device)
    loss = get_digit_loss(logits, targets, vocab_emb, 0.06)
    print(loss * 0.1)
