""" 
Gaussian volume overlap scoring functions -- Shape-only (i.e., not color)

Batched and non-batched functionalities

Reference math:
https://doi.org/10.1002/(SICI)1096-987X(19961115)17:14<1653::AID-JCC7>3.0.CO;2-K
https://doi.org/10.1021/j100011a016
"""
from typing import Union
import numpy as np
import torch


def VAB_2nd_order(centers_1: torch.Tensor,
                  centers_2: torch.Tensor,
                  alpha: float
                  ) -> torch.Tensor:
    """
    2nd order volume overlap of AB.
    Torch implemenation with batched and single instance functionality.
    """
    # Single instance
    if len(centers_1.shape) == 2:
        R2 = (torch.cdist(centers_1, centers_2, compute_mode='use_mm_for_euclid_dist')**2.0).T

        VAB_second_order = torch.sum(np.pi**(1.5) * torch.exp(-(alpha / 2) * R2) / ((2*alpha)**(1.5)))

    # Batched
    elif len(centers_1.shape) == 3:
        R2 = (torch.cdist(centers_1, centers_2)**2.0).permute(0,2,1)
    
        VAB_second_order = torch.sum(torch.sum(np.pi**(1.5) *
                                            torch.exp(-(alpha / 2) * R2) /
                                            ((2*alpha)**(1.5)),
                                            dim = 2),
                                    dim = 1)
    return VAB_second_order


def shape_tanimoto(centers_1: torch.Tensor,
                   centers_2: torch.Tensor,
                   alpha: float
                   ) -> torch.Tensor:
    """ Compute Tanimoto shape similarity """
    VAA = VAB_2nd_order(centers_1, centers_1, alpha)
    VBB = VAB_2nd_order(centers_2, centers_2, alpha)
    VAB = VAB_2nd_order(centers_1, centers_2, alpha)
    return VAB / (VAA + VBB - VAB)


def get_ROCS(centers_1: Union[torch.Tensor, np.ndarray],
             centers_2: Union[torch.Tensor, np.ndarray],
             alpha:float = 0.81
             ) -> torch.Tensor:
    """
    Volumentric shape similarity with tunable "alpha" Gaussian width parameter.
    Handles both batched and single instances. PyTorch implementation.

    Parameters
    ----------
    centers_1 : Union[torch.Tensor, np.ndarray] (batch, N, 3) or (N, 3)
        Coordinates of each point of the first point cloud.
    centers_2 : Union[torch.Tensor, np.ndarray] (batch, N, 3) or (N, 3)
        Coordinates of each point of the second point cloud.
    alpha : float (default=0.81)
        Gaussian width parameter. Lower value corresponds to wider Gaussian (longer tail).
    
    Returns
    -------
    torch.Tensor : (batch, 1) or (1,)
    """
    if isinstance(centers_1, np.ndarray):
        centers_1 = torch.Tensor(centers_1)
    if isinstance(centers_2, np.ndarray):
        centers_2 = torch.Tensor(centers_2)
    # initialize prefactor and alpha matrices
    tanimoto = shape_tanimoto(centers_1, centers_2, alpha)
    return tanimoto


##################################################################
##################### Older implementations ######################
##################################################################

def VAB_2nd_order_batched(centers_1: torch.Tensor,
                          centers_2: torch.Tensor,
                          alphas_1: torch.Tensor,
                          alphas_2: torch.Tensor,
                          prefactors_1: torch.Tensor,
                          prefactors_2: torch.Tensor
                          ) -> torch.Tensor:
    """ 
    Calculate the 2nd order volume overlap of AB -- batched functionality
    
    Parameters
    ----------
        centers_1 : (torch.Tensor) (batch_size, num_atoms_1, 3)
            Coordinates of atoms in molecule 1

        centers_2 : (torch.Tensor) (batch_size, num_atoms_2, 3)
            Coordinates of atoms in molecule 2

        alphas_1 : (torch.Tensor) (batch_size, num_atoms_1)
            Alpha values for atoms in molecule 1

        alphas_2 : (torch.Tensor) (batch_size, num_atoms_2)
            Alpha values for atoms in molecule 2

        prefactors_1 : (torch.Tensor) (batch_size, num_atoms_1)
            Prefactor values for atoms in molecule 1

        prefactors_2 : (torch.Tensor) (batch_size, num_atoms_2)
            Prefactor values for atoms in molecule 2
    
    Returns
    -------
    torch.Tensor (batch_size,)
        Representing the 2nd order volume overlap of AB for each batch
    """
    R2 = (torch.cdist(centers_1, centers_2)**2.0).permute(0,2,1)

    prefactor1_prod_prefactor2 = (prefactors_1.unsqueeze(1) * prefactors_2.unsqueeze(2))
    
    alpha1_prod_alpha2 = (alphas_1.unsqueeze(1) * alphas_2.unsqueeze(2))
    alpha1_sum_alpha2 = (alphas_1.unsqueeze(1) + alphas_2.unsqueeze(2))
    VAB_second_order = torch.sum(torch.sum(np.pi**(1.5) *
                                        prefactor1_prod_prefactor2 *
                                        torch.exp(-(alpha1_prod_alpha2 / alpha1_sum_alpha2) * R2) /
                                        (alpha1_sum_alpha2**(1.5)),
                                        dim = 2),
                                dim = 1)
    return VAB_second_order


def shape_tanimoto_batched(centers_1: torch.Tensor,
                           centers_2: torch.Tensor,
                           alphas_1: torch.Tensor,
                           alphas_2: torch.Tensor,
                           prefactors_1: torch.Tensor,
                           prefactors_2: torch.Tensor
                           ) -> torch.Tensor:
    """
    Calculate the Tanimoto shape similarity between two batches of molecules.
    """
    VAA = VAB_2nd_order_batched(centers_1, centers_1, alphas_1, alphas_1, prefactors_1, prefactors_1)
    VBB = VAB_2nd_order_batched(centers_2, centers_2, alphas_2, alphas_2, prefactors_2, prefactors_2)
    VAB = VAB_2nd_order_batched(centers_1, centers_2, alphas_1, alphas_2, prefactors_1, prefactors_2)
    return VAB / (VAA + VBB - VAB)


def get_ROCS_batch(centers_1:torch.Tensor,
                   centers_2:torch.Tensor,
                   prefactor:float = 0.8,
                   alpha:float = 0.81) -> torch.Tensor:

    # initialize prefactor and alpha matrices
    prefactors_1 = (torch.ones(centers_1.shape[0]) * prefactor).unsqueeze(0)
    prefactors_2 = (torch.ones(centers_2.shape[0]) * prefactor).unsqueeze(0)
    alphas_1 = (torch.ones(prefactors_1.shape[0]) * alpha).unsqueeze(0)
    alphas_2 = (torch.ones(prefactors_2.shape[0]) * alpha).unsqueeze(0)

    tanimoto_score = shape_tanimoto_batched(centers_1, centers_2, alphas_1, alphas_2, prefactors_1, prefactors_2)
    return tanimoto_score


def VAB_2nd_order_full(centers_1, centers_2, alphas_1, alphas_2, prefactors_1, prefactors_2) -> torch.Tensor:
    """ 2nd order volume overlap of AB """
    R2 = (torch.cdist(centers_1, centers_2, compute_mode='use_mm_for_euclid_dist')**2.0).T
    prefactor1_prod_prefactor2 = prefactors_1 * prefactors_2.unsqueeze(1)
    alpha1_prod_alpha2 = alphas_1 * alphas_2.unsqueeze(1)
    alpha1_sum_alpha2 = alphas_1 + alphas_2.unsqueeze(1)

    VAB_second_order = torch.sum(np.pi**(1.5) * prefactor1_prod_prefactor2 * torch.exp(-(alpha1_prod_alpha2 / alpha1_sum_alpha2) * R2) / (alpha1_sum_alpha2**(1.5)))
    return VAB_second_order


def shape_tanimoto_full(centers_1, centers_2, alphas_1, alphas_2, prefactors_1, prefactors_2) -> torch.Tensor:
    """ Compute Tanimoto shape similarity """
    VAA = VAB_2nd_order_full(centers_1, centers_1, alphas_1, alphas_1, prefactors_1, prefactors_1)
    VBB = VAB_2nd_order_full(centers_2, centers_2, alphas_2, alphas_2, prefactors_2, prefactors_2)
    VAB = VAB_2nd_order_full(centers_1, centers_2, alphas_1, alphas_2, prefactors_1, prefactors_2)
    return VAB / (VAA + VBB - VAB)


def get_ROCS_full(centers_1:torch.Tensor,
             centers_2:torch.Tensor,
             prefactor:float = 0.8,
             alpha:float = 0.81
             ) -> torch.Tensor:

    # initialize prefactor and alpha matrices
    prefactors_1 = torch.ones(centers_1.shape[0]) * prefactor
    prefactors_2 = torch.ones(centers_2.shape[0]) * prefactor
    alphas_1 = torch.ones(prefactors_1.shape[0]) * alpha
    alphas_2 = torch.ones(prefactors_2.shape[0]) * alpha

    tanimoto = shape_tanimoto_full(centers_1, centers_2, alphas_1, alphas_2, prefactors_1, prefactors_2)
    return tanimoto
