'''
 *
 *     ICTP: Irreducible Cartesian Tensor Potentials
 *
 *        File:  radial.py
 *
 *     Authors: Deleted for purposes of anonymity 
 *
 *     Proprietor: Deleted for purposes of anonymity --- PROPRIETARY INFORMATION
 * 
 * The software and its source code contain valuable trade secrets and shall be maintained in
 * confidence and treated as confidential information. The software may only be used for 
 * evaluation and/or testing purposes, unless otherwise explicitly stated in the terms of a
 * license agreement or nondisclosure agreement with the proprietor of the software. 
 * Any unauthorized publication, transfer to third parties, or duplication of the object or
 * source code---either totally or in part---is strictly prohibited.
 *
 *     Copyright (c) 2024 Proprietor: Deleted for purposes of anonymity
 *     All Rights Reserved.
 *
 * THE PROPRIETOR DISCLAIMS ALL WARRANTIES, EITHER EXPRESS OR 
 * IMPLIED, INCLUDING BUT NOT LIMITED TO IMPLIED WARRANTIES OF MERCHANTABILITY 
 * AND FITNESS FOR A PARTICULAR PURPOSE AND THE WARRANTY AGAINST LATENT 
 * DEFECTS, WITH RESPECT TO THE PROGRAM AND ANY ACCOMPANYING DOCUMENTATION. 
 * 
 * NO LIABILITY FOR CONSEQUENTIAL DAMAGES:
 * IN NO EVENT SHALL THE PROPRIETOR OR ANY OF ITS SUBSIDIARIES BE 
 * LIABLE FOR ANY DAMAGES WHATSOEVER (INCLUDING, WITHOUT LIMITATION, DAMAGES
 * FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF INFORMATION, OR
 * OTHER PECUNIARY LOSS AND INDIRECT, CONSEQUENTIAL, INCIDENTAL,
 * ECONOMIC OR PUNITIVE DAMAGES) ARISING OUT OF THE USE OF OR INABILITY
 * TO USE THIS PROGRAM, EVEN IF the proprietor HAS BEEN ADVISED OF
 * THE POSSIBILITY OF SUCH DAMAGES.
 * 
 * For purposes of anonymity, the identity of the proprietor is not given herewith. 
 * The identity of the proprietor will be given once the review of the 
 * conference submission is completed. 
 *
 * THIS HEADER MAY NOT BE EXTRACTED OR MODIFIED IN ANY WAY.
 *
'''
import numpy as np

import torch
import torch.nn as nn


class BesselRBF(nn.Module):
    """Computes the Bessel radial basis function.

    Klicpera, J.; Groß, J.; Günnemann, S. Directional Message Passing for Molecular Graphs; ICLR 2020.
    Equation (7)

    Adapted from MACE (https://github.com/ACEsuit/mace/blob/main/mace/modules/radial.py).

    The n-th basis function (n = 1, ..., n_basis) is

        sqrt(2 / r_cutoff) * sin(n * pi * x / r_cutoff) / x

    Args:
        r_cutoff (float): Cutoff radius.
        n_basis (int, optional): Number of basis functions. Defaults to 8.
    """
    def __init__(self,
                 r_cutoff: float,
                 n_basis: int = 8):
        super().__init__()
        dtype = torch.get_default_dtype()
        # Frequencies n * pi / r_cutoff for n = 1, ..., n_basis; shape (n_basis,).
        bessel_weights = np.pi / r_cutoff * torch.linspace(start=1.0, end=n_basis, steps=n_basis, dtype=dtype)

        # Registered as buffers (not parameters): they are saved in the state_dict and
        # move with the module on .to()/.cuda(), but are not optimized.
        self.register_buffer('bessel_weights', bessel_weights)
        self.register_buffer('r_cutoff', torch.tensor(r_cutoff, dtype=dtype))
        self.register_buffer('pre_factor', torch.tensor(np.sqrt(2.0 / r_cutoff), dtype=dtype))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Evaluates the Bessel radial basis on the provided input distances.

        Args:
            x (torch.Tensor): Input distances. The tensor is broadcast against
                ``bessel_weights`` (shape ``(n_basis,)``). An input of shape
                ``(..., 1)`` therefore yields an output of shape ``(..., n_basis)``.

        Returns:
            torch.Tensor: Values of the Bessel radial basis.

        Note:
            The result is divided by ``x``, so a distance of exactly zero
            evaluates to NaN (0 / 0).
        """
        sin_x = torch.sin(self.bessel_weights * x)
        return self.pre_factor * (sin_x / x)

    def __repr__(self):
        # The closing parenthesis was previously missing, which produced an unbalanced repr.
        return f'{self.__class__.__name__}(r_cutoff={self.r_cutoff}, n_basis={len(self.bessel_weights)})'


class PolynomialCutoff(nn.Module):
    """Smooth polynomial envelope that decays to zero at the cutoff radius.

    Klicpera, J.; Groß, J.; Günnemann, S. Directional Message Passing for Molecular Graphs; ICLR 2020.
    Equation (8)

    Adapted from MACE (https://github.com/ACEsuit/mace/blob/main/mace/modules/radial.py).

    With u = x / r_cutoff, the envelope is

        1 - (p + 1)(p + 2) / 2 * u^p + p (p + 2) * u^(p + 1) - p (p + 1) / 2 * u^(p + 2)

    and it is multiplied by zero wherever x >= r_cutoff.

    Args:
        r_cutoff (float): Cutoff radius.
        p (int, optional): Polynomial order. Defaults to 6.
    """
    def __init__(self,
                 r_cutoff: float,
                 p: int = 6):
        super().__init__()
        dtype = torch.get_default_dtype()
        # Both values are stored as default-dtype float buffers, so `p` is a float
        # tensor even though it is passed as an int. Buffers are saved in the
        # state_dict and move with the module on .to()/.cuda().
        self.register_buffer('p', torch.tensor(p, dtype=dtype))
        self.register_buffer('r_cutoff', torch.tensor(r_cutoff, dtype=dtype))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the envelope element-wise to a tensor of distances.

        Args:
            x (torch.Tensor): Input distances.

        Returns:
            torch.Tensor: Envelope values. Entries with x >= r_cutoff are zeroed.
        """
        order = self.p
        ratio = x / self.r_cutoff

        # Polynomial coefficients, written exactly as in Eq. (8).
        coef_lo = (order + 1.0) * (order + 2.0) / 2.0
        coef_mid = order * (order + 2.0)
        coef_hi = (order * (order + 1.0) / 2)

        envelope = (
                1.0
                - coef_lo * torch.pow(ratio, order)
                + coef_mid * torch.pow(ratio, order + 1)
                - coef_hi * torch.pow(ratio, order + 2)
        )
        # The boolean mask is promoted to the envelope's dtype when multiplied.
        within_cutoff = x < self.r_cutoff
        return envelope * within_cutoff

    def __repr__(self):
        name = self.__class__.__name__
        return f'{name}(p={self.p}, r_cutoff={self.r_cutoff})'
