'''

Part of this code was adapted from https://github.com/muhrin/mrs-tutorial

'''


import math
import torch
import numpy as np
import scipy as sp
import scipy.special


def FixedCosineRadialModel(max_radius, number_of_basis, min_radius=0.):
    """Build a bank of fixed cos^2 bump functions on evenly spaced centers.

    Centers are ``c_k = torch.linspace(min_radius, max_radius, number_of_basis)[k]``
    and ``step`` is the spacing between neighbouring centers. With
    ``x = (r - c_k) / step`` basis function ``k`` equals ``cos(pi * x / 2) ** 2``
    for ``|x| < 1`` and 0 (up to floating-point round-off) otherwise, i.e. it is
    1 at its own center and 0 at the neighbouring centers.

    Args:
        max_radius: position of the last center.
        number_of_basis: number of centers; must be >= 2 (``radii[1]`` is read).
        min_radius: position of the first center.

    Returns:
        ``radial_function(r)`` mapping a tensor ``r`` of any shape to a tensor of
        shape ``r.shape + (number_of_basis,)``.
    """
    radii = torch.linspace(min_radius, max_radius, number_of_basis)
    # Uniform center spacing; also the half-width of every bump.
    step = radii[1] - radii[0]

    def radial_function(r):
        # Centers laid out along a new trailing axis so that r[..., None] - c_k
        # broadcasts; moved to r's device so non-CPU inputs work.
        centers = radii.to(r.device).reshape(*([1] * r.dim()), number_of_basis)
        # The relu/sub/neg chain maps x = (r - c_k) / step to an angle that gives
        # cos^2(pi * x / 2) for |x| < 1 and cos^2(pi/2) or cos^2(3*pi/2) (~0) outside.
        return (r.unsqueeze(-1) - centers).div(step).add(1).relu().sub(2).neg().relu().add(1).mul(
            math.pi / 2).cos().pow(2)

    return radial_function

def ZernickeRadialModel(r_max, number_of_basis, l_max):
    """Return a callable evaluating 3D Zernike radial polynomials R_nl(r / r_max).

    (The historical "Zernicke" spelling is kept for backward compatibility.)

    For each l in 0..l_max and n in 0..number_of_basis-1, with x = r / r_max:

        R_nl(x) = (-1)**((n-l)/2) * sqrt(2n + 3) * binom(int((n+l+3)/2 - 1), (n-l)/2)
                  * x**l * 2F1(-(n-l)/2, (n+l+3)/2; l + 3/2; x**2)

    and R_nl = 0 whenever n < l or n - l is odd.

    Args:
        r_max: scale applied to the input (x = r / r_max); inputs are not clamped.
        number_of_basis: number of radial orders n tried for every l.
        l_max: largest angular order l.

    Returns:
        ``radial_function(r)`` taking a 1-D tensor (or array-like) of N radii and
        returning a tensor of shape (N, (l_max + 1) * number_of_basis) where column
        ``l * number_of_basis + n`` holds R_nl. It is computed with NumPy/SciPy on
        the CPU and carries no autograd history. The dtype is complex128 whenever a
        zero column is present (zeros are stored as complex), float64 otherwise.
    """
    # Dimension of the Zernike polynomials.
    D = 3.

    def radial_function(r):
        # SciPy needs a host NumPy array: detach from autograd and move to the CPU
        # (a plain .numpy() fails for grad-tracking or CUDA tensors).
        if isinstance(r, torch.Tensor):
            r = r.detach().cpu().numpy()
        else:
            r = np.asarray(r)
        x = r / r_max
        x_sq = x * x
        num_points = r.shape[0]

        columns = []
        for l in range(l_max + 1):
            for n in range(number_of_basis):
                # R_nl is identically zero unless n >= l and n - l is even. Emitting
                # the n < l case explicitly also avoids 0 * inf = nan: there binom(...)
                # is 0 while hyp2f1 diverges at x == 1.
                if n < l or (n - l) % 2 == 1:
                    columns.append(np.full(num_points, 0. + 0j))
                    continue

                A = np.power(-1, (n - l) / 2.)
                B = np.sqrt(2. * n + D)
                # NOTE(review): int() truncates the half-integer upper argument
                # (n+l+1)/2, e.g. (n, l) = (2, 0) uses binom(1, 1) = 1 instead of
                # binom(1.5, 1) = 1.5, so R_20 has squared norm 4/9 (not 1) under the
                # r^2 weight on [0, 1]. Left unchanged because fixing it rescales
                # outputs -- confirm intent.
                C = sp.special.binom(int((n + l + D) / 2. - 1), int((n - l) / 2.))
                E = sp.special.hyp2f1(-(n - l) / 2., (n + l + D) / 2., l + D / 2., x_sq)
                F = np.power(x, l)

                columns.append(A * B * C * E * F)

        # One column per (l, n) pair -> shape (num_points, number of columns).
        return torch.tensor(np.stack(columns, axis=1))

    return radial_function

class ZernickeRadialFunctions:
    """Callable 3D Zernike radial basis R_nl(r / rcut), evaluated with SciPy.

    (The historical "Zernicke" spelling is kept for backward compatibility.)

    Args:
        rcut: cutoff radius; inputs are clamped to it and scaled by it.
        number_of_basis: number of radial orders n = 0 .. number_of_basis - 1
            considered for every l.
        lmax: largest angular order l.
        complex_sph: stored on the instance; not used by this class.
        record_zeros: if True, emit a zero column for every (l, n) pair where
            R_nl vanishes (n < l or odd n - l), so each l contributes exactly
            ``number_of_basis`` columns; if False those columns are omitted.

    Attributes:
        multiplicities: number of output columns contributed by each l = 0..lmax.
        radius_depends_on_l: always True.
    """

    def __init__(self, rcut, number_of_basis, lmax, complex_sph=False, record_zeros=False):
        self.rcut = rcut
        self.number_of_basis = number_of_basis
        self.lmax = lmax
        self.complex_sph = complex_sph
        self.record_zeros = record_zeros
        self.radius_depends_on_l = True
        if record_zeros:
            self.multiplicities = [number_of_basis] * (lmax + 1)
        else:
            # Non-vanishing orders for a given l are n = l, l + 2, ... < number_of_basis.
            self.multiplicities = [len(range(l, number_of_basis, 2)) for l in range(lmax + 1)]

    def __call__(self, r):
        """Evaluate the basis at radii ``r``.

        Args:
            r: 1-D tensor (or array-like) of N radii. Values above ``rcut`` are
                clamped to ``rcut``; the caller's tensor is left untouched.

        Returns:
            float32 CPU tensor of shape (N, sum(self.multiplicities)), columns
            ordered by l, then n. Computed with NumPy/SciPy, so it carries no
            autograd history.
        """
        # SciPy needs a host NumPy array: detach from autograd and move to the CPU.
        if isinstance(r, torch.Tensor):
            r = r.detach().cpu().numpy()
        else:
            r = np.asarray(r)

        # Clamp out of place: the array from Tensor.numpy() shares memory with the
        # tensor, so assigning into it would rewrite the caller's input.
        r = np.minimum(r, self.rcut)

        # Dimension of the Zernike polynomials.
        D = 3.
        x = r / self.rcut
        x_sq = x * x
        num_points = r.shape[0]

        columns = []
        for l in range(self.lmax + 1):
            for n in range(self.number_of_basis):
                if (n - l) % 2 == 1 or n < l:
                    if self.record_zeros:
                        columns.append(np.full(num_points, 0.0))
                    continue

                A = np.power(-1, (n - l) / 2.)
                B = np.sqrt(2. * n + D)
                # NOTE(review): int() truncates the half-integer upper argument
                # (n+l+1)/2, e.g. (n, l) = (2, 0) uses binom(1, 1) = 1 instead of
                # binom(1.5, 1) = 1.5, so R_20 has squared norm 4/9 (not 1) under the
                # r^2 weight on [0, 1]. Left unchanged because fixing it rescales
                # outputs -- confirm intent.
                C = sp.special.binom(int((n + l + D) / 2. - 1), int((n - l) / 2.))
                E = sp.special.hyp2f1(-(n - l) / 2., (n + l + D) / 2., l + D / 2., x_sq)
                F = np.power(x, l)

                columns.append(A * B * C * E * F)

        return torch.tensor(np.stack(columns, axis=1)).type(torch.float)


class CosineFunctions:
    """Cosine radial basis cos(k * pi * r / radialCutoff), k = 0 .. nRadialFunctions - 1."""

    def __init__(self, nRadialFunctions, radialCutoff):
        self.nRadialFunctions = nRadialFunctions
        self.radialCutoff = radialCutoff
        # Angular frequency k * pi / radialCutoff for each basis function k.
        self.factors = (math.pi / radialCutoff) * torch.arange(0, nRadialFunctions)

    def __call__(self, r):
        """Evaluate every basis function at ``r``.

        ``r`` may have any shape; the result has shape ``r.shape + (nRadialFunctions,)``.
        For a 1-D ``r`` this is the outer product r x factors passed through cos.
        """
        # Follow r's device so non-CPU inputs work.
        factors = self.factors.to(r.device)
        return torch.cos(r.unsqueeze(-1) * factors)


class FadeAtCutoff:
    """Multiply every function of a radial basis by (radialCutoff - r) ** 2.

    The factor has a double root at ``r == radialCutoff``, so (for a smooth
    ``radialModel``) each faded function and its first derivative go to 0 at the
    cutoff.

    NOTE(review): the factor grows again for r > radialCutoff; callers presumably
    only pass r <= radialCutoff -- confirm.
    """

    def __init__(self, radialModel, radialCutoff):
        self.radialModel = radialModel
        self.radialCutoff = radialCutoff

    def __call__(self, r):
        """Return ``radialModel(r) * (radialCutoff - r)[..., None] ** 2``.

        Assumes ``radialModel(r)`` has shape ``r.shape + (nRadialFunctions,)``
        (e.g. (N, nRadialFunctions) for 1-D ``r``) -- the fade is broadcast over
        the trailing basis axis.
        """
        f = self.radialModel(r)
        fade = (self.radialCutoff - r) * (self.radialCutoff - r)
        # Broadcasting instead of an explicit torch.ones() matrix, which was always
        # allocated on the default device regardless of where r lives.
        return f * fade.unsqueeze(-1)


class OrthonormalRadialFunctions:
    """Radial basis orthonormalized with a spherical volume weight, then tabulated.

    On construction ``radialModel`` is sampled at ``num_samples`` evenly spaced
    radii in [0, rcut], and its first ``num_radials`` columns are orthonormalized
    with modified Gram-Schmidt under the inner product
    ``<a, b> = integral of a(r) b(r) 4*pi*r^2 dr`` (trapezoidal rule).
    Calling the instance linearly interpolates the tabulated functions, which is
    differentiable with respect to ``r``.
    """

    def __init__(self, num_radials, radialModel, rcut, num_samples, radius_depends_on_l=False):
        """
        Args:
            num_radials: number of columns of ``radialModel``'s output to orthonormalize.
            radialModel: callable mapping a 1-D tensor of radii to a tensor of shape
                (len(r), K) with K >= num_radials. Its returned tensor is not modified.
            rcut: upper end of the sampled radius interval.
            num_samples: number of tabulation points (must be >= 2).
            radius_depends_on_l: stored on the instance.
        """
        self.nRadialFunctions = num_radials
        self.radialCutoff = rcut
        self.radius_depends_on_l = radius_depends_on_l
        self.complex_sph = False

        self.radialSamples = torch.linspace(0, rcut, num_samples)
        self.radialStep = self.radialSamples[1] - self.radialSamples[0]
        # Volume weight 4*pi*r^2 used by innerProduct.
        self.areaSamples = 4 * math.pi * self.radialSamples * self.radialSamples

        nonOrthogonalSamples = radialModel(self.radialSamples)
        self.fSamples = torch.zeros_like(nonOrthogonalSamples)

        for i in range(num_radials):
            # clone(): ui is updated in place below and must not write back into the
            # tensor returned by radialModel (it may be cached or shared).
            ui = nonOrthogonalSamples[:, i].clone()
            # Modified Gram-Schmidt: remove each already-orthonormalized direction
            # from the running residual ui.
            for j in range(i):
                uj = self.fSamples[:, j]
                ui -= self.innerProduct(uj, ui) / self.innerProduct(uj, uj) * uj
            self.fSamples[:, i] = ui / self.norm(ui)

    def innerProduct(self, a, b):
        """Weighted inner product of two tabulated functions (trapezoidal rule)."""
        return torch.trapz(a * b * self.areaSamples, self.radialSamples)

    def norm(self, a):
        """Norm induced by ``innerProduct``."""
        return torch.sqrt(self.innerProduct(a, a))

    def __call__(self, r):
        """Linearly interpolate the orthonormal functions at radii ``r``.

        Returns a tensor of shape ``r.shape + (nRadialFunctions,)`` (assuming the
        tabulated samples have exactly ``nRadialFunctions`` columns; a 0-d ``r``
        yields shape (1, nRadialFunctions) because the index clamp broadcasts
        against 1-element tensors). Radii outside [0, rcut] are linearly
        extrapolated from the first/last sample interval.
        """
        rNormalized = r / self.radialStep
        # Left sample index of the interval containing r, clamped so that index + 1
        # stays valid. (.long() truncates; for the non-negative values that survive
        # the clamp this equals floor.)
        indicesLow = torch.min(torch.max(rNormalized.long(), torch.tensor([0], dtype=torch.long)),
                               torch.tensor([len(self.radialSamples) - 2], dtype=torch.long))

        # Fractional position within the interval, repeated along the basis axis.
        weight = torch.unsqueeze(rNormalized - indicesLow, -1)
        weight = weight.expand(list(weight.shape[:-1]) + [self.nRadialFunctions])

        lowSamples = self.fSamples[indicesLow, :]
        highSamples = self.fSamples[indicesLow + 1, :]
        return lowSamples * (1 - weight) + highSamples * weight


if __name__ == "__main__":
    # Demo / smoke test: evaluate the radial models, orthonormalize the faded cosine
    # basis, and backpropagate through the interpolated last basis function.
    radialCutoff = 3.5
    nRadialFunctions = 14

    fixedCosineRadialModel = FixedCosineRadialModel(radialCutoff, nRadialFunctions)
    cosineModel = CosineFunctions(nRadialFunctions, radialCutoff)
    cosineModelFaded = FadeAtCutoff(cosineModel, radialCutoff)

    # Evaluate each model once on a coarse grid (exercises their forward passes).
    r = torch.linspace(0, radialCutoff, 15)
    y1 = fixedCosineRadialModel(r)
    y2 = cosineModel(r)
    y3 = cosineModelFaded(r)

    onRadialModel = OrthonormalRadialFunctions(nRadialFunctions, cosineModelFaded, radialCutoff, 100)
    r = torch.linspace(0, radialCutoff, 51, requires_grad=True)
    y = onRadialModel(r)
    # One-hot upstream gradient on the last basis function, so r.grad[i] becomes
    # d y[i, -1] / d r[i].
    g = torch.zeros_like(y)
    g[:, -1] += 1
    # backward() returns None; the gradient is accumulated into r.grad.
    y.backward(gradient=g)

    torch.set_printoptions(linewidth=10000)
    print(y)
    print(r.grad)

    print(y.shape)
