import torch
import torch.nn as nn
from torch.autograd.functional import jacobian
# try:
    # from pykeops.torch import LazyTensor
    # KEOPS_ENABLED = True
# except ModuleNotFoundError:
    # KEOPS_ENABLED = False
KEOPS_ENABLED = False

from .base import BaseDivergence


class BregmanDivergence(BaseDivergence):
    """Bregman divergence D_phi(x, y) = phi(x) - phi(y) - <grad phi(y), x - y>.

    ``phi`` is evaluated on batches of rows and its gradient is obtained with
    ``torch.autograd.functional.jacobian``. Negative values are clamped to 0;
    with ``take_sqrt=True`` the square root of the divergence is returned.

    NOTE(review): the gradient computation assumes phi yields one value per
    row of its input and treats rows independently — confirm for new phis.
    """

    def __init__(self, phi, take_sqrt=False, **kwargs):
        super().__init__(phi, **kwargs)
        self.take_sqrt = take_sqrt
        # Static functions rather than lambdas: a lambda attribute makes the
        # module unpicklable (e.g. torch.save of a model holding this module).
        self.post_fn = self._safe_sqrt if take_sqrt else self._identity

    @staticmethod
    def _identity(t):
        """Return ``t`` unchanged (post-processing when take_sqrt is False)."""
        return t

    @staticmethod
    def _safe_sqrt(t):
        """Element-wise sqrt whose gradient at 0 is 0 instead of inf.

        torch.sqrt has an infinite derivative at 0. Combined with a zero
        upstream derivative (a clamped entry, or x_i == y_i), that yields NaN
        gradients. Forward values match torch.sqrt (up to the sign of zero).
        """
        is_zero = t == 0
        # Feed a harmless 1 into sqrt at zero entries so its backward stays
        # finite; the outer where discards that branch.
        safe_t = torch.where(is_zero, torch.ones_like(t), t)
        return torch.where(is_zero, torch.zeros_like(t), torch.sqrt(safe_t))

    def _phi_gradient(self, z, create_graph=True):
        """Return grad phi evaluated at the rows of ``z`` (same shape as ``z``).

        Differentiates the scalar sum of phi over the batch, which yields
        per-row gradients when phi treats rows independently.
        ``create_graph=True`` keeps the result differentiable.
        """
        return jacobian(
            lambda w: self.phi(w).flatten().sum(), z, create_graph=create_graph
        )

    def compute_mat(self, x, y):
        """Return the matrix of divergences D(x_i, y_j), passed through post_fn.

        Assumes x is (n, d) and y is (m, d); the result is then (n, m).
        """
        phi_x = self.phi(x).flatten().unsqueeze(1)
        phi_y = self.phi(y).flatten().unsqueeze(0)
        x_min_y = x.unsqueeze(1) - y.unsqueeze(0)

        grad_phi = self._phi_gradient(y)
        grad_div = (x_min_y * grad_phi.unsqueeze(0)).sum(-1)
        breg_div = torch.clamp(phi_x - phi_y - grad_div, 0)
        return self.post_fn(breg_div)

    def pairwise_distance(self, x, y):
        """Return the row-wise divergences D(x_i, y_i), passed through post_fn.

        Assumes x and y are both (n, d); the result is then (n,).
        """
        phi_x = self.phi(x).flatten()
        phi_y = self.phi(y).flatten()

        grad_phi = self._phi_gradient(y)
        grad_div = ((x - y) * grad_phi).sum(dim=1)
        breg_div = torch.clamp(phi_x - phi_y - grad_div, 0)
        return self.post_fn(breg_div)

    def compute_full_mat(self, x, y):
        """Return the full divergence matrix between x and y.

        Without KeOps this defers to ``batch_compute_mat``, which is defined
        outside this class (presumably on BaseDivergence).
        """
        if KEOPS_ENABLED:
            # NOTE(review): unreachable while KEOPS_ENABLED is hard-coded to
            # False; the LazyTensor import at the top of the file is commented
            # out and must be restored before enabling this path.
            pk_x = LazyTensor(x[:, None, :])
            pk_y = LazyTensor(y[None, :, :])
            x_min_y = pk_x - pk_y

            phi_x = self.phi(x).flatten()
            phi_y = self.phi(y).flatten()
            pk_phi_x = LazyTensor(phi_x.unsqueeze(1), axis=0)
            pk_phi_y = LazyTensor(phi_y.unsqueeze(1), axis=1)

            grad_phi = self._phi_gradient(y, create_graph=False)
            grad_div = (x_min_y * grad_phi.unsqueeze(0)).sum(-1)

            # relu() == clamp(min=0), matching the dense code paths.
            breg_div = (pk_phi_x - pk_phi_y - grad_div).relu()
            # The dense post_fn relies on torch.where, which cannot consume a
            # LazyTensor, so use LazyTensor's own sqrt here.
            return breg_div.sqrt() if self.take_sqrt else breg_div
        return self.batch_compute_mat(x, y)

    def extract_components(self, x, y):
        """Return phi(x), phi(y) and <x - y, grad phi(y)> as NumPy arrays.

        These are the three row-wise terms of the divergence. Every output is
        detached, so no differentiable graph is built for the gradient.
        """
        phi_x = self.phi(x).flatten()
        phi_y = self.phi(y).flatten()

        grad_phi = self._phi_gradient(y, create_graph=False)
        grad_div = ((x - y) * grad_phi).sum(dim=1)
        return (
            phi_x.detach().cpu().numpy(),
            phi_y.detach().cpu().numpy(),
            grad_div.detach().cpu().numpy(),
        )

    
class GSBregmanDivergence(BaseDivergence):
    """Symmetrised Bregman divergence with extra squared-distance terms.

    For each row pair (x_i, y_i), ``pairwise_distance`` returns

        sqrt(D(x, y) + D(y, x)
             + alpha / 2 * ||x - y||^2
             + beta / 2 * ||grad phi(y) - grad phi(x)||^2)

    where D is the Bregman divergence of ``phi`` clamped at 0. The matrix
    variants are not implemented.
    """

    def __init__(self, phi, alpha=1.0, beta=1.0, **kwargs):
        super().__init__(phi, **kwargs)
        self.alpha = alpha  # weight of the squared input-distance term
        self.beta = beta  # weight of the squared gradient-distance term

    @staticmethod
    def _safe_sqrt(t):
        """Element-wise sqrt whose gradient at 0 is 0 instead of inf.

        torch.sqrt has an infinite derivative at 0, which turns into NaN
        gradients when the argument is exactly 0. Forward values match
        torch.sqrt (up to the sign of zero).
        """
        is_zero = t == 0
        # Feed a harmless 1 into sqrt at zero entries so its backward stays
        # finite; the outer where discards that branch.
        safe_t = torch.where(is_zero, torch.ones_like(t), t)
        return torch.where(is_zero, torch.zeros_like(t), torch.sqrt(safe_t))

    def _phi_gradient(self, z):
        """Return differentiable grad phi at the rows of ``z`` (same shape as ``z``).

        NOTE(review): assumes phi treats rows independently, so the gradient
        of the batch sum is the per-row gradient.
        """
        return jacobian(lambda w: self.phi(w).flatten().sum(), z, create_graph=True)

    def compute_mat(self, x, y):
        raise NotImplementedError(
            "GSBregmanDivergence only implements pairwise_distance"
        )

    def pairwise_distance(self, x, y):
        """Return the row-wise GS divergence between x_i and y_i.

        Assumes x and y are both (n, d); the result is then (n,).
        """
        phi_x = self.phi(x).flatten()
        phi_y = self.phi(y).flatten()

        # Forward divergence D(x, y).
        grad_phi_y = self._phi_gradient(y)
        grad_div = ((x - y) * grad_phi_y).sum(dim=1)
        breg_div = torch.clamp(phi_x - phi_y - grad_div, 0)

        # Reverse divergence D(y, x).
        grad_phi_x = self._phi_gradient(x)
        grad_div_rev = ((y - x) * grad_phi_x).sum(dim=1)
        breg_div_rev = torch.clamp(phi_y - phi_x - grad_div_rev, 0)

        # Squared norms as sums of squares (no sqrt followed by a square).
        xy_diff_norm = (self.alpha / 2) * (x - y).pow(2).sum(dim=1)
        grad_diff_norm = (self.beta / 2) * (grad_phi_y - grad_phi_x).pow(2).sum(dim=1)

        gs_breg_div = breg_div + breg_div_rev + xy_diff_norm + grad_diff_norm
        # Every term is >= 0 and the sum is exactly 0 when x_i == y_i, where a
        # plain torch.sqrt would produce NaN gradients.
        return self._safe_sqrt(gs_breg_div)

    def compute_full_mat(self, x, y):
        raise NotImplementedError(
            "GSBregmanDivergence only implements pairwise_distance"
        )


class DummyBregmanDivergence(BaseDivergence):
    """Bregman divergence that uses the closed-form gradient ``alpha * y``.

    Instead of differentiating ``phi``, the gradient of phi at y is taken to be
    ``phi.alpha * y``. Here ``alpha`` is a one-element learnable parameter,
    initialised to 0.5, that the constructor registers on ``phi``. Results are
    clamped at 0 and never square-rooted.

    NOTE(review): this equals the true Bregman divergence only if phi behaves
    like alpha / 2 * ||z||^2 — confirm against the phi used with this class.
    """

    def __init__(self, phi, **kwargs):
        super().__init__(phi, **kwargs)
        alpha = torch.nn.Parameter(torch.tensor([0.5]))
        # The coefficient is attached to phi itself, not to this module.
        self.phi.register_parameter(name='alpha', param=alpha)

    def compute_mat(self, x, y):
        """Return the clamped divergence matrix D(x_i, y_j).

        Assumes x is (n, d) and y is (m, d); the result is then (n, m).
        """
        col_phi = self.phi(x).flatten()[:, None]
        row_phi = self.phi(y).flatten()[None, :]
        pair_diff = x.unsqueeze(1) - y.unsqueeze(0)
        slope = (self.phi.alpha * y).unsqueeze(0)
        linear_term = torch.sum(pair_diff * slope, dim=-1)
        return (col_phi - row_phi - linear_term).clamp(0)

    def pairwise_distance(self, query_emb, ref_emb):
        """Return the clamped row-wise divergences D(query_i, ref_i).

        Assumes both inputs are (n, d); the result is then (n,).
        """
        query_phi = self.phi(query_emb).flatten()
        ref_phi = self.phi(ref_emb).flatten()
        linear_term = torch.sum(
            (query_emb - ref_emb) * (self.phi.alpha * ref_emb), dim=1
        )
        return (query_phi - ref_phi - linear_term).clamp(0)

    # def compute_full_mat(self, query_emb, ref_emb):
    #     x = query_emb
    #     y = ref_emb
        
    #     pk_x = LazyTensor(x[:, None, :])
    #     pk_y = LazyTensor(y[None, :, :])
    #     x_min_y = pk_x - pk_y

    #     phi_x = self.phi(x).flatten()
    #     phi_y = self.phi(y).flatten()
    #     pk_phi_x = LazyTensor(phi_x.unsqueeze(1), axis=0)
    #     pk_phi_y = LazyTensor(phi_y.unsqueeze(1), axis=1)

    #     grad_phi = self.phi.alpha * y
    #     grad_div = (x_min_y * grad_phi.unsqueeze(0)).sum(-1)
    #     breg_div = torch.clamp(pk_phi_x - pk_phi_y - grad_div, 0)
    #     return breg_div

