""" 
Electrostatic potential similarity scoring functions.
JAX VERSIONS
"""

from shepherd_score.score.constants import COULOMB_SCALING, LAM_SCALING
from jax import jit
import jax.numpy as jnp
from shepherd_score.score.gaussian_volume_overlap_jax import jax_sq_cdist, get_ROCS_jax, jax_cdist


def VAB_2nd_order_esp_jax(centers_1: jnp.ndarray,
                          centers_2: jnp.ndarray,
                          charges_1: jnp.ndarray,
                          charges_2: jnp.ndarray,
                          alpha: float,
                          lam: float
                          ) -> jnp.ndarray:
    """
    Second-order Gaussian volume overlap between point sets A and B, where
    every pairwise term is additionally damped by an electrostatic factor.

    Parameters
    ----------
    centers_1 : jnp.ndarray
        Coordinates of the points of set A. Passed to ``jax_sq_cdist``.
    centers_2 : jnp.ndarray
        Coordinates of the points of set B. Passed to ``jax_sq_cdist``.
    charges_1 : jnp.ndarray
        Electrostatic values of set A. Also passed to ``jax_sq_cdist``, so it
        must be in whatever layout that helper accepts (``get_ROCS_esp_jax``
        supplies column vectors).
    charges_2 : jnp.ndarray
        Electrostatic values of set B, same layout as ``charges_1``.
    alpha : float
        Gaussian width parameter.
    lam : float
        Electrostatic weighting parameter; divides the charge term, so smaller
        values penalize charge differences more strongly.

    Returns
    -------
    jnp.ndarray
        Scalar: the sum of all pairwise overlap terms.
    """
    # Presumably pairwise squared distances (going by the helper's name) --
    # once in coordinate space, once in "charge space".
    sq_dists = jax_sq_cdist(centers_1, centers_2)
    sq_charge_diffs = jax_sq_cdist(charges_1, charges_2)

    # Spatial Gaussian factor and electrostatic damping factor per pair.
    shape_term = jnp.exp(-(alpha / 2) * sq_dists)
    esp_term = jnp.exp(-sq_charge_diffs / lam)

    # (pi / (2*alpha))^(3/2) * shape * esp, multiplied out term by term.
    pairwise = jnp.pi**(1.5) * shape_term / ((2*alpha)**(1.5)) * esp_term
    return jnp.sum(pairwise)


def shape_tanimoto_esp_jax(centers_1: jnp.ndarray,
                           centers_2: jnp.ndarray,
                           charges_1: jnp.ndarray,
                           charges_2: jnp.ndarray,
                           alpha: float,
                           lam: float
                           ) -> jnp.ndarray:
    """
    Tanimoto similarity built from electrostatics-weighted volume overlaps.

    Computes the two self-overlaps (A with A, B with B) and the cross-overlap
    (A with B) via ``VAB_2nd_order_esp_jax`` and returns
    ``VAB / (VAA + VBB - VAB)``.

    Parameters
    ----------
    centers_1, centers_2 : jnp.ndarray
        Coordinates of the points of molecule 1 and molecule 2.
    charges_1, charges_2 : jnp.ndarray
        Electrostatic values associated with ``centers_1`` / ``centers_2``.
    alpha : float
        Gaussian width parameter, forwarded unchanged.
    lam : float
        Electrostatic weighting parameter, forwarded unchanged.

    Returns
    -------
    jnp.ndarray
        Scalar Tanimoto similarity.
    """
    def overlap(ctrs_a, ctrs_b, chg_a, chg_b):
        # Same alpha / lam for all three overlap evaluations.
        return VAB_2nd_order_esp_jax(ctrs_a, ctrs_b, chg_a, chg_b, alpha, lam)

    self_overlap_1 = overlap(centers_1, centers_1, charges_1, charges_1)
    self_overlap_2 = overlap(centers_2, centers_2, charges_2, charges_2)
    cross_overlap = overlap(centers_1, centers_2, charges_1, charges_2)

    denominator = self_overlap_1 + self_overlap_2 - cross_overlap
    return cross_overlap / denominator

@jit
def get_ROCS_esp_jax(centers_1: jnp.ndarray,
                     centers_2: jnp.ndarray,
                     charges_1: jnp.ndarray,
                     charges_2: jnp.ndarray,
                     alpha: float = 0.81,
                     lam: float = 0.1
                     ) -> jnp.ndarray:
    """
    Jitted JAX function.
    Electrostatic similarity: a Tanimoto score over Gaussian volume overlaps in
    which each pairwise overlap term is weighted by electrostatic agreement.

    Parameters
    ----------
    centers_1 : jnp.ndarray
        Coordinates of the points representing molecule 1.
    centers_2 : jnp.ndarray
        Coordinates of the points representing molecule 2.
    charges_1 : jnp.ndarray
        Electrostatic value at each point of molecule 1. A 1-D array is
        reshaped to a column vector before use; any other rank is passed
        through untouched.
    charges_2 : jnp.ndarray
        Electrostatic value at each point of molecule 2. Handled like
        ``charges_1``.
    alpha : float (default = 0.81)
        Parameter controlling the width of the Gaussians.
    lam : float (default = 0.1)
        Parameter controlling the influence of electrostatics.

    Returns
    -------
    jnp.ndarray
        Scalar Tanimoto similarity of the electrostatics-weighted overlaps.
    """
    # Promote flat charge vectors to column vectors so they can go through the
    # same pairwise helper that is used for the coordinates.
    if charges_1.ndim == 1:
        charges_1 = jnp.reshape(charges_1, (-1, 1))
    if charges_2.ndim == 1:
        charges_2 = jnp.reshape(charges_2, (-1, 1))

    return shape_tanimoto_esp_jax(centers_1, centers_2,
                                  charges_1, charges_2,
                                  alpha, lam)


def _esp_comparison_jax(points_1: jnp.ndarray,
                        centers_w_H_2: jnp.ndarray, # EXPECTS HYDROGENS INCLUDED
                        partial_charges_2: jnp.ndarray,
                        points_charges_1: jnp.ndarray,
                        radii_2: jnp.ndarray,
                        probe_radius: float = 1.0,
                        lam: float = 0.001
                        ) -> jnp.ndarray:
    """
    Helper for the electrostatic potential (ESP) component of the ShaEP score.

    At each surface/observer point of molecule 1, the ESP generated by
    molecule 2 is computed from its partial charges and compared against the
    precalculated ESP of molecule 1. Observer points that fall inside
    molecule 2's volume (vdW radius + probe_radius of any atom) contribute
    nothing to the score.

    Parameters
    ----------
    points_1 : jnp.ndarray (N_surf, 3)
        Surface points of molecule 1 at which the ESPs are compared.

    centers_w_H_2 : jnp.ndarray (M + m_H, 3)
        Atom coordinates (including hydrogens) of molecule 2. Used both for
        the ESP calculation and for the masking.

    partial_charges_2 : jnp.ndarray (M + m_H,)
        Partial charges corresponding to centers_w_H_2.

    points_charges_1 : jnp.ndarray (N_surf,)
        Precalculated ESPs of molecule 1 at points_1.

    radii_2 : jnp.ndarray (M + m_H,)
        Radius of each atom in centers_w_H_2. Used for the masking.

    probe_radius : float (default = 1.0)
        Added to each atomic radius when masking. Surfaces are assumed to be
        generated with vdW radius plus a 1.2 angstrom probe (vdW radius of
        hydrogen); 1.0 is used here as a tolerance.

    lam : float (default = 0.001)
        Electrostatic potential weighting parameter (smaller = higher weight).
        It is multiplied by LAM_SCALING before use. The default was chosen
        empirically from the distribution of per-point scores.

    Returns
    -------
    jnp.ndarray
        Scalar sum of the per-point ESP agreement. Each unmasked point
        contributes a value in (0, 1] (assuming the scaled lam is positive),
        so the score lies in [0, N_surf]. It drops with ESP differences and
        with masking of poorly aligned surface points.
    """
    scaled_lam = LAM_SCALING * lam

    # Distance of every observer point to every atom of molecule 2
    # (presumably (N_surf, M + m_H), given how it is reduced below).
    dists = jax_cdist(points_1, centers_w_H_2)

    # Keep an observer point only if it lies outside ALL inflated atoms.
    is_outside = jnp.all(dists >= radii_2 + probe_radius, axis=1)
    mask = jnp.where(is_outside, 1., 0.)

    # Coulomb potential of molecule 2 evaluated at molecule 1's surface points.
    esp_from_2 = jnp.dot(partial_charges_2, 1 / dists.T) * COULOMB_SCALING

    # Gaussian-like agreement between the two potentials, point by point.
    sq_esp_diff = jnp.square(points_charges_1 - esp_from_2)
    return jnp.sum(mask * jnp.exp(-sq_esp_diff/scaled_lam))

@jit
def esp_combo_score_jax(centers_w_H_1: jnp.ndarray,
                        centers_w_H_2: jnp.ndarray,
                        centers_1: jnp.ndarray,
                        centers_2: jnp.ndarray,
                        points_1: jnp.ndarray,
                        points_2: jnp.ndarray,
                        partial_charges_1: jnp.ndarray,
                        partial_charges_2: jnp.ndarray,
                        point_charges_1: jnp.ndarray,
                        point_charges_2: jnp.ndarray,
                        radii_1: jnp.ndarray,
                        radii_2: jnp.ndarray,
                        alpha: float,
                        lam: float=0.001,
                        probe_radius: float=1.0,
                        esp_weight: float=0.5
                        ) -> jnp.ndarray:
    """
    Jitted JAX function.
    Similarity score as defined by ShaEP: a weighted blend of electrostatic
    similarity and shape similarity.

    Parameters
    ----------
    centers_w_H_1 : jnp.ndarray (N + n_H, 3)
        Coordinates of atom centers INCLUDING hydrogens of molecule 1.
        Used for computing electrostatic potential.
        Same for centers_w_H_2 except (M + m_H, 3).

    centers_1 : jnp.ndarray (N, 3) or (n_surf, 3)
        Coordinates of points for molecule 1 used to compute shape similarity.
        Use atom centers for volumetric similarity. Use surface points for
        surface similarity.
        Same for centers_2 except (M, 3) or (m_surf, 3).

    points_1 : jnp.ndarray (n_surf, 3)
        Coordinates of surface points for molecule 1.
        Same for points_2 except (m_surf, 3).

    partial_charges_1 : jnp.ndarray (N + n_H,)
        Partial charges corresponding to the atoms in centers_w_H_1.
        Same for partial_charges_2 except (M + m_H,).

    point_charges_1 : jnp.ndarray (n_surf,)
        The electrostatic potential calculated at each surface point (points_1).
        Same for point_charges_2 except (m_surf,).

    radii_1 : jnp.ndarray (N + n_H,)
        vdW radii corresponding to the atoms in centers_w_H_1 (angstroms).
        Same for radii_2 except (M + m_H,).

    alpha : float
        Gaussian width parameter used for shape similarity.

    lam : float (default = 0.001)
        Electrostatic potential weighting parameter (smaller = higher weight).
        0.001 was chosen as default based on empirical observations of the
        distribution of scores generated by _esp_comparison_jax before
        summation.

    probe_radius : float (default = 1.0)
        Surface points found within vdW radii + probe radius will be masked
        out. Surface generation uses a probe radius of 1.2 (radius of
        hydrogen), so a slightly lower radius is used here to be more tolerant.

    esp_weight : float (default = 0.5)
        Weight placed on electrostatic similarity relative to shape similarity.
        0 = only shape similarity
        1 = only electrostatic similarity

    Returns
    -------
    jnp.ndarray
        Scalar similarity score; higher is more similar. Expected range is
        [0, 1] for esp_weight in [0, 1], assuming the shape similarity
        returned by get_ROCS_jax is itself in [0, 1].
    """
    # Shape term: independent of the electrostatics.
    shape_sim = get_ROCS_jax(centers_1, centers_2, alpha)

    # ESP agreement, evaluated symmetrically: molecule 2's field seen on
    # molecule 1's surface, and molecule 1's field seen on molecule 2's
    # surface. Both calls expect atom centers that include hydrogens.
    esp_on_surface_1 = _esp_comparison_jax(
        points_1, centers_w_H_2, partial_charges_2, point_charges_1,
        radii_2, probe_radius, lam,
    )
    esp_on_surface_2 = _esp_comparison_jax(
        points_2, centers_w_H_1, partial_charges_1, point_charges_2,
        radii_1, probe_radius, lam,
    )

    # Normalize by the total number of observer points so that the
    # electrostatic term is a per-point average.
    n_observer_points = len(points_1) + len(points_2)
    esp_sim = (esp_on_surface_1 + esp_on_surface_2) / n_observer_points

    return esp_weight*esp_sim + (1-esp_weight)*shape_sim
