import torch


class SnippetDistance:
    """Score one query snippet against a batch of candidate snippets.

    Both inputs are encoded with ``model``. Each time step's feature vector is
    L2-normalised, the normalised features are mean-pooled over time, and the
    pooled query is compared to each pooled candidate with a dot product.
    Despite the class name, the result is a similarity: larger means more
    alike.
    """

    def __init__(self, model, TYPE):
        """
        Args:
            model: Callable invoked as ``model(x, lengths)``, where ``lengths``
                is a 1-D tensor of per-sequence lengths. ``__call__`` indexes
                its output time-major, as ``[T, batch, f]``.
            TYPE: Stored on ``self.type``; not read anywhere in this class.
        """
        self.type = TYPE
        self.model = model

    def __call__(self, a, b):
        """Return the similarity of snippet ``a`` to every snippet in ``b``.

        Args:
            a: Query tensor of shape ``[1, T1, 51]``.
            b: Candidate batch of shape ``[B, T2, 51]``.

        Returns:
            Tensor of shape ``[B]``. Entry ``i`` is the dot product between
            the time-mean of ``a``'s normalised features and the time-mean of
            ``b[i]``'s normalised features.
        """
        T1 = a.size(1)
        B, T2 = b.size(0), b.size(1)
        res1 = self.model(a, torch.tensor([T1]))
        res2 = self.model(b, torch.tensor([T2] * B))
        # Normalise out of place. ``res1[:, 0]`` is a view, and ``/=`` would
        # overwrite the model's output (and, for models that return a view of
        # their input, the caller's ``a``/``b``) or break autograd.
        # ``normalize`` also clamps the norm, so an all-zero frame becomes a
        # zero vector instead of NaN.
        forward_feat = torch.nn.functional.normalize(res1[:, 0], p=2, dim=-1)  # [T1, f]
        reverse_feat = torch.nn.functional.normalize(res2, p=2, dim=-1)  # [T2, B, f]
        # Pooling gives [B, f] @ [f] -> [B].
        return reverse_feat.mean(dim=0) @ forward_feat.mean(dim=0)
