import torch
from torch import nn
from abc import ABC, abstractmethod

class Exptest(nn.Module):
    """
    Gaussian bumps placed on the quantile bins of ``data``.

    ``data`` is cast to float32 and its quantiles at ``0, 1/num_centers, ..., 1``
    define ``num_centers`` bins. Bump ``m`` is centered on the midpoint of bin
    ``m``, and its standard deviation is half the bin width.

    Buffers:
        quantiles:    the ``num_centers + 1`` bin edges.
        centers_list: the ``num_centers`` bin midpoints.
        sigmas_list:  the ``num_centers`` half bin widths.
    """

    def __init__(self, data, num_centers):
        super().__init__()
        dtype = torch.float32
        levels = torch.linspace(0, 1, num_centers + 1, dtype=dtype, device=data.device)
        edges = torch.quantile(data.to(dtype), levels)
        self.register_buffer('quantiles', edges)
        lower, upper = self.quantiles[:-1], self.quantiles[1:]
        self.register_buffer('centers_list', 0.5 * (upper + lower))
        self.register_buffer('sigmas_list', torch.diff(self.quantiles) / 2)

    @staticmethod
    def _bump(x_r):
        """Unit Gaussian profile exp(-x_r^2 / 2), applied elementwise."""
        return torch.exp(-x_r**2 / 2)

    def rescale(self, x):
        """
        Standardize ``x`` against every center: (x - c_m) / sigma_m.

        A center axis is inserted at dim 1. Assumes x is (N, C, W, H), which
        gives an (N, M, C, W, H) result (TODO confirm with callers).
        """
        centers = self.centers_list.reshape(-1, 1, 1, 1)
        sigmas = self.sigmas_list.reshape(-1, 1, 1, 1)
        return (x.unsqueeze(1) - centers) / sigmas

    def __call__(self, x):
        return self._bump(self.rescale(x))

    def derivative(self, x):
        """d/dx of the bump: -x_r * exp(-x_r^2 / 2) / sigma_m."""
        x_r = self.rescale(x)
        sigmas = self.sigmas_list.reshape(-1, 1, 1, 1)
        return -x_r * self._bump(x_r) / sigmas

    def laplacian(self, x):
        """d^2/dx^2 of the bump: (x_r^2 - 1) * exp(-x_r^2 / 2) / sigma_m^2."""
        x_r = self.rescale(x)
        sigmas = self.sigmas_list.reshape(-1, 1, 1, 1)
        return (x_r**2 - 1) * self._bump(x_r) / sigmas**2


class ActivationFamily(nn.Module, ABC):
    """
    Abstract family of M elementary activations indexed by ``centers_list``.

    Subclasses implement the elementwise profile (``_function``), its first and
    second derivatives (``_derivative``, ``_laplacian``), and ``apply``, which
    evaluates one of those profiles for every center:
    (N, C, W, H) -> (N, M, C, W, H).

    NOTE(review): ``apply`` shadows ``nn.Module.apply(fn)``. That method recurses
    into child modules with a single argument, so calling ``.apply(init_fn)``
    on a parent module that contains an ActivationFamily will fail with a
    TypeError here. Consider renaming it (together with all subclasses).
    Overriding ``__call__`` likewise bypasses nn.Module forward hooks.
    """

    def __init__(self, centers_list):
        super().__init__()
        # Registered as a buffer so the centers follow the module through
        # .to(device) / dtype casts and are saved in state_dict.
        self.register_buffer('centers_list', centers_list)
        self.num_functions = len(centers_list)

    @abstractmethod
    def _function(self, x):
        """Elementary profile f, applied elementwise."""
        raise NotImplementedError

    @abstractmethod
    def _derivative(self, x):
        """First derivative f', applied elementwise."""
        raise NotImplementedError

    @abstractmethod
    def _laplacian(self, x):
        """Second derivative f'', applied elementwise."""
        raise NotImplementedError

    @abstractmethod
    def apply(self, function, x):
        """
        Evaluate ``function`` (one of ``_function``, ``_derivative``,
        ``_laplacian``) for every center: (N, C, W, H) -> (N, M, C, W, H).

        The parameter order (function, x) matches the calls below and the
        subclass implementations.
        """
        raise NotImplementedError

    def __call__(self, x):
        """
        (N, C, W, H) -> (N, M, C, W, H)
        """
        return self.apply(self._function, x)  # (N, M, C, W, H)

    def derivative(self, x):
        """
        (N, C, W, H) -> (N, M, C, W, H)
        """
        return self.apply(self._derivative, x)  # (N, M, C, W, H)

    def laplacian(self, x):
        """
        (N, C, W, H) -> (N, M, C, W, H)
        """
        return self.apply(self._laplacian, x)  # (N, M, C, W, H)

class QuarticWindowedActivation(ActivationFamily):
    """
    Single activation t -> t^4, restricted to the open window (-window, window).

    The family has one center (at 0). ``apply`` evaluates the profile directly
    on the input and inserts a singleton center axis at dim 1. Inside the
    window the derivatives are those of t^4. Outside the window the function
    and its derivatives are 0; the jump at the window edge is not accounted for.
    """

    def __init__(self, window):
        super().__init__(torch.zeros(1))  # only one center
        self.window = window

    def w(self, x):
        """
        Indicator of |x| < window, i.e. of the open interval (-window, window).

        Returns a bool tensor; torch promotes it to x's dtype when multiplied.
        """
        return torch.abs(x) < self.window

    def _function(self, t):
        return (t**4) * self.w(t)

    def _derivative(self, t):
        return 4 * (t**3) * self.w(t)

    def _laplacian(self, t):
        return 12 * (t**2) * self.w(t)

    def apply(self, function, x):
        # No centering: insert the (singleton) center axis after the batch axis.
        return function(x)[:, None]

class GaussianActivations(ActivationFamily):
    """
    Gaussian bumps of a common width around each center.

    The elementary profile is applied to t = x - c_m, with s = ``window_std``:
        f(t)   = exp(-t^2 / (2 s^2))
        f'(t)  = -(t / s^2) f(t)
        f''(t) = ((t^2 / s^2) - 1) / s^2 * f(t)
    Calling the module, ``derivative`` and ``laplacian`` return f, f' and f''
    evaluated for every center.
    """

    def __init__(self, centers_list, window_std):
        super().__init__(centers_list)
        self.window_std = window_std

    def apply(self, function, x):
        # (M,) -> (M, 1, 1, 1), broadcast against x with a new center axis at dim 1.
        offsets = self.centers_list.reshape(-1, 1, 1, 1)
        return function(x.unsqueeze(1) - offsets)

    def _function(self, t):
        variance = self.window_std ** 2
        return torch.exp(-t**2 / (2 * variance))

    def _derivative(self, t):
        variance = self.window_std ** 2
        return -(t / variance) * self._function(t)

    def _laplacian(self, t):
        """Second derivative f''(t)."""
        variance = self.window_std ** 2
        return ((t ** 2 / variance) - 1) / variance * self._function(t)


class SplineActivations(ActivationFamily):
    """
    Non-symmetric splines adapted to a non-regular grid of centers.

    The elementary profile is the degree-5 polynomial
        g(u) = -6|u|^5 + 15|u|^4 - 10|u|^3 + 1   for |u| < 1,   0 otherwise,
    which satisfies g(0) = 1, g(1) = 0 and g'(0) = g'(1) = g''(0) = g''(1) = 0.

    For center c_m the input is rescaled as u = (x - c_m) / h, where h is the
    distance to the left neighbour when x <= c_m and to the right neighbour
    when x > c_m. The outermost centers mirror their single adjacent spacing.
    ``derivative`` and ``laplacian`` return derivatives with respect to x, so
    they include the chain-rule factors 1/h and 1/h^2.
    """

    def __init__(self, centers_list):
        if len(centers_list) < 2:
            raise ValueError("SplineActivations needs at least two centers to define the spline widths")
        super().__init__(centers_list)
        # Mirror the outermost spacings so the first and last bumps get a finite
        # width on their open side. reshape(1) keeps the dtype and device of
        # centers_list (torch.tensor([...]) would not).
        prepend = (2 * centers_list[0] - centers_list[1]).reshape(1)
        append = (2 * centers_list[-1] - centers_list[-2]).reshape(1)
        # diff[m] = c_m - c_{m-1}: diff[:-1] are the left widths and diff[1:] the
        # right widths. A non-persistent buffer follows .to(device)/dtype casts
        # with the module without adding a key to state_dict.
        self.register_buffer(
            'diff',
            torch.diff(centers_list, prepend=prepend, append=append),
            persistent=False,
        )

    def _rescale(self, x):
        """Return (u, h): u = (x - c_m) / h and the per-element width h, both (N, M, C, W, H)."""
        x_centered = x[:, None] - self.centers_list[:, None, None, None]  # (N, M, C, W, H)
        left = self.diff[:-1, None, None, None]
        right = self.diff[1:, None, None, None]
        widths = torch.where(x_centered > 0, right, left)
        return x_centered / widths, widths

    def apply(self, function, x):
        u, _ = self._rescale(x)
        return function(u)  # (N, M, C, W, H)

    def derivative(self, x):
        """d/dx of the activation: g'(u) / h (chain rule through u = (x - c_m) / h)."""
        u, widths = self._rescale(x)
        return self._derivative(u) / widths

    def laplacian(self, x):
        """d^2/dx^2 of the activation: g''(u) / h^2."""
        u, widths = self._rescale(x)
        return self._laplacian(u) / widths**2

    def _function(self, x):
        xx = torch.abs(x)
        s = xx < 1
        t = -6 * xx**5 + 15 * xx**4 - 10 * xx**3 + 1
        return s * t

    def _derivative(self, x):
        # g is even, so its derivative is sign(x) * g'(|x|).
        xx = torch.abs(x)
        r = torch.sign(x)
        s = xx < 1
        t = -30 * xx**4 + 60 * xx**3 - 30 * xx**2
        return r * s * t

    def _laplacian(self, x):
        xx = torch.abs(x)
        s = xx < 1
        t = -120 * xx**3 + 180 * xx**2 - 60 * xx
        return s * t


def compute_quantile_centers(data, num_centers):
    """
    Compute the midpoints of the ``num_centers`` equal-mass bins of ``data``.

    The bin edges are the quantiles of the (flattened) ``data`` at
    ``0, 1/num_centers, ..., 1``. Each center is the midpoint of two
    consecutive edges.

    Args:
        data: tensor of samples. ``torch.quantile`` only accepts float32/float64
            input, so any other dtype is cast to float32 first.
        num_centers: number of bins; must be >= 1.

    Returns:
        1-D tensor of shape ``(num_centers,)``, on ``data``'s device and in its
        (possibly cast) dtype.

    Raises:
        ValueError: if ``num_centers < 1``.
    """
    if num_centers < 1:
        raise ValueError(f"num_centers must be >= 1, got {num_centers}")
    if data.dtype not in (torch.float32, torch.float64):
        data = data.to(torch.float32)
    # torch.quantile requires q to share the input's dtype and device.
    levels = torch.linspace(0, 1, num_centers + 1, dtype=data.dtype, device=data.device)
    quantiles = torch.quantile(data, levels)
    return 0.5 * (quantiles[1:] + quantiles[:-1])

class AdaptativeSplineActivations(SplineActivations):
    """
    Spline activations whose centers are the midpoints of the ``num_centers``
    quantile bins of ``data``, as computed by ``compute_quantile_centers``.
    """

    def __init__(self, data, num_centers):
        super().__init__(compute_quantile_centers(data, num_centers))


# old implementation

class SmoothActivationFunction:
    """
    Legacy elementwise Gaussian profile for the scalar potential ansatz:

        f(t) = exp(-t^2 / (2 s^2)),   s = window_std

    ``__call__`` returns f(t), ``derivative`` returns f'(t), and ``laplacian``
    returns f''(t), each applied elementwise to the tensor ``t``.
    """

    def __init__(self, window_std):
        self.window_std = window_std

    def _variance(self):
        """Squared window width s^2."""
        return self.window_std ** 2

    def __call__(self, t):
        """Evaluate f(t) = exp(-t^2 / (2 s^2))."""
        return torch.exp(-t**2 / (2 * self._variance()))

    def derivative(self, t):
        """First derivative f'(t) = -(t / s^2) f(t)."""
        return -(t / self._variance()) * self(t)

    def laplacian(self, t):
        """Second derivative f''(t) = (t^2 / s^2 - 1) / s^2 * f(t)."""
        variance = self._variance()
        return ((t ** 2 / variance) - 1) / variance * self(t)