import torch
import torch.nn as nn
from tqdm import tqdm
# Optional KeOps backend -- currently switched off on purpose.
# To re-enable it, uncomment the guarded import below AND delete the
# hard-coded assignment that follows it (otherwise that assignment overrides
# whatever the try/except decided). EuclideanDistance.compute_full_mat uses
# `LazyTensor` directly, so the name must be in module scope whenever the
# flag is True.
# try:
#     from pykeops.torch import LazyTensor
#     KEOPS_ENABLED = True
# except ModuleNotFoundError:
#     KEOPS_ENABLED = False
KEOPS_ENABLED = False

from .base import BaseDivergence


class EuclideanDistance(BaseDivergence):
    """Euclidean (L2) distance between embedding vectors.

    Args:
        squared: if True, distances returned by ``compute_mat`` and
            ``pairwise_distance`` are squared before being returned.
        **kwargs: forwarded unchanged to ``BaseDivergence.__init__``.
    """

    def __init__(self, squared=False, **kwargs):
        super().__init__(**kwargs)
        self.squared = squared

    # Implemented as a method rather than a lambda stored on the instance:
    # pickle cannot serialize lambdas, so a lambda attribute breaks e.g.
    # ``torch.save`` of a whole model holding this object. Reading
    # ``self.squared`` at call time also keeps this in agreement with the
    # KeOps path in ``compute_full_mat``, which reads the flag the same way.
    def post_fn(self, x):
        """Square ``x`` element-wise if ``self.squared`` is set, else return it as-is."""
        if self.squared:
            return torch.pow(x, 2)
        return x

    def compute_mat(self, x, y):
        """Return the matrix of distances between every row of ``x`` and every row of ``y``.

        Args:
            x: 2-D tensor of shape (M, P).
            y: 2-D tensor of shape (N, P).

        Returns:
            (M, N) tensor whose entry [i, j] is the distance from ``x[i]`` to ``y[j]``.
        """
        # Broadcast (M, 1, P) - (1, N, P) -> (M, N, P), then reduce the last axis.
        x3d = x[:, None, :]
        y3d = y[None, :, :]
        D_ij = torch.norm(x3d - y3d, dim=-1)
        return self.post_fn(D_ij)

    def pairwise_distance(self, x, y):
        """Return the distance between ``x`` and ``y`` row by row (no cross product).

        ``x`` and ``y`` must be broadcast-compatible; the last axis is reduced.
        """
        pdist = torch.norm(x - y, dim=-1)
        return self.post_fn(pdist)

    def compute_full_mat(self, x, y):
        """Compute the full distance matrix, through KeOps when it is enabled.

        With ``KEOPS_ENABLED`` the result is a KeOps ``LazyTensor`` expression
        rather than a dense tensor. Otherwise the work is delegated to
        ``self.batch_compute_mat`` (presumably inherited from ``BaseDivergence``).
        """
        if KEOPS_ENABLED:
            # NOTE(review): ``LazyTensor`` is only defined if the pykeops import
            # at the top of the module is restored; the flag is hard-coded to
            # False there, so this branch currently never runs.
            Q_i = LazyTensor(x[:, None, :])
            R_j = LazyTensor(y[None, :, :])

            sq_dist = ((Q_i - R_j) ** 2).sum(-1)
            if self.squared:
                return sq_dist
            return sq_dist.sqrt()
        return self.batch_compute_mat(x, y)


class MahalanobisDistance(BaseDivergence):
    """Distance ``||W (x - y)||_2`` where ``W`` is a learned, bias-free linear map.

    Args:
        in_features: size of the last axis of the input embeddings.
        out_features: size of the space ``W`` projects into.
        squared: if True, distances returned by ``compute_mat`` and
            ``pairwise_distance`` are squared before being returned.
        **kwargs: forwarded unchanged to ``BaseDivergence.__init__``.
    """

    def __init__(self, in_features, out_features, squared=False, **kwargs):
        super().__init__(**kwargs)
        self.squared = squared
        self.layer = nn.Linear(in_features, out_features, bias=False)

    # A method instead of a lambda attribute: pickle cannot serialize lambdas,
    # so storing one on the instance breaks e.g. ``torch.save`` of a whole
    # model. The ``self.post_fn(x)`` call syntax is unchanged.
    def post_fn(self, x):
        """Square ``x`` element-wise if ``self.squared`` is set, else return it as-is."""
        if self.squared:
            return torch.pow(x, 2)
        return x

    def compute_mat(self, x, y):
        """Return the (M, N) matrix of projected distances between rows of ``x`` and ``y``.

        Args:
            x: 2-D tensor of shape (M, in_features).
            y: 2-D tensor of shape (N, in_features).

        Returns:
            (M, N) tensor whose entry [i, j] is ``||W (x[i] - y[j])||`` (squared if requested).
        """
        # Differencing before projecting keeps identical rows at exactly zero
        # distance (W @ 0 == 0). Projecting x and y separately would be cheaper
        # but can leave small non-zero values from rounding. Cost: an
        # (M, N, in_features) difference tensor and an (M, N, out_features)
        # projection are materialised.
        x3d = x[:, None, :]
        y3d = y[None, :, :]
        D_ij = torch.norm(self.layer(x3d - y3d), dim=-1)
        return self.post_fn(D_ij)

    def pairwise_distance(self, x, y):
        """Return ``||W (x - y)||`` row by row; ``x`` and ``y`` must be broadcast-compatible."""
        pdist = torch.norm(self.layer(x - y), dim=-1)
        return self.post_fn(pdist)

    def compute_full_mat(self, x, y):
        """Delegate to ``self.batch_compute_mat`` (presumably inherited from ``BaseDivergence``)."""
        return self.batch_compute_mat(x, y)


class DeepnormDivergence(BaseDivergence):
    """Divergence whose value for each pair of vectors is produced by ``net``.

    ``net`` is invoked as ``net(left, right)`` on two batches with matching
    row counts. Presumably it returns one value per row pair -- confirm
    against the network implementations used with this class.

    Args:
        net: callable (typically a module) that scores pairs of vectors.
        **kwargs: forwarded unchanged to ``BaseDivergence.__init__``.
    """

    def __init__(self, net, **kwargs):
        super().__init__(**kwargs)
        self.net = net

    def compute_mat(self, query_emb, ref_emb):
        """Score every (query row, reference row) pair with ``net``.

        Args:
            query_emb: 2-D tensor of shape (M, P).
            ref_emb: 2-D tensor with N rows (its last axis must also be P).

        Returns:
            (M, N) tensor; entry [i, j] is net's output for ``query_emb[i]``
            paired with ``ref_emb[j]``.
        """
        n_query, dim = query_emb.shape
        n_ref, _ = ref_emb.shape
        n_pairs = n_query * n_ref

        # Flattened pair index k = i * n_ref + j:
        #   left[k]  == query_emb[i]  (each query row repeated n_ref times in a row)
        #   right[k] == ref_emb[j]    (the reference block tiled n_query times)
        left = query_emb.repeat(1, n_ref).view(n_pairs, dim)
        right = ref_emb.repeat(n_query, 1).view(n_pairs, dim)

        scores = self.net(left, right)
        # Copy before reshaping so the returned matrix owns its storage.
        return scores.clone().view(n_query, n_ref)

    def pairwise_distance(self, query_emb, ref_emb):
        """Score ``query_emb`` against ``ref_emb`` row by row with ``net``."""
        return self.net(query_emb, ref_emb)

    def compute_full_mat(self, x, y):
        """Delegate to ``self.batch_compute_mat`` (presumably inherited from ``BaseDivergence``)."""
        return self.batch_compute_mat(x, y)


class WidenormDivergence(BaseDivergence):
    """Divergence whose value for each pair of vectors is produced by ``net``.

    ``net`` is called as ``net(a, b)`` on two batches with matching row
    counts. Presumably it yields one value per row pair -- verify against the
    network implementations paired with this class.

    Args:
        net: callable (typically a module) that scores pairs of vectors.
        **kwargs: forwarded unchanged to ``BaseDivergence.__init__``.
    """

    def __init__(self, net, **kwargs):
        super().__init__(**kwargs)
        self.net = net

    def compute_mat(self, query_emb, ref_emb):
        """Evaluate ``net`` on all (query, reference) row combinations.

        Args:
            query_emb: 2-D tensor of shape (M, P).
            ref_emb: 2-D tensor with N rows (its last axis must also be P).

        Returns:
            (M, N) tensor; entry [i, j] is net's output for ``query_emb[i]``
            paired with ``ref_emb[j]``.
        """
        rows, width = query_emb.shape
        cols, _ = ref_emb.shape
        total = rows * cols

        # Row k = i * cols + j of the flattened batches pairs query i with ref j:
        # tiling each query row `cols` times along the feature axis and then
        # folding gives the "outer" operand, stacking `rows` copies of the
        # reference block gives the "inner" one.
        a = query_emb.repeat(1, cols).view(total, width)
        b = ref_emb.repeat(rows, 1).view(total, width)

        flat = self.net(a, b)
        # Copy before reshaping so the returned matrix owns its storage.
        return flat.clone().view(rows, cols)

    def pairwise_distance(self, query_emb, ref_emb):
        """Apply ``net`` to corresponding rows of ``query_emb`` and ``ref_emb``."""
        return self.net(query_emb, ref_emb)

    def compute_full_mat(self, x, y):
        """Delegate to ``self.batch_compute_mat`` (presumably inherited from ``BaseDivergence``)."""
        return self.batch_compute_mat(x, y)
