import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import matplotlib.pyplot as plt
from . import monotone_1d_spline as spline1d

from abc import ABC, abstractmethod

class VectorField(ABC):
    """
    Abstract differentiable vector field on R^2.

    Subclasses provide the field values (``__call__``) and the Jacobian of the
    field. This base class derives the divergence from the Jacobian and
    integrates the flow of the field with an explicit Euler scheme.
    """

    @abstractmethod
    def __call__(self, X: torch.Tensor) -> torch.Tensor:
        """
        Evaluate the vector field at (x, y)
        :param X: (x,y) coordinates, shape (batch_size, 2)
        :return: field components, shape (batch_size, 2)
        """
        pass

    def divergence(self, X: torch.Tensor) -> torch.Tensor:
        """
        Evaluate the divergence of the vector field at (x, y)
        :param X: (x,y) coordinate, shape (batch_size, 2)
        :return: div(f), shape (batch_size)
        """
        jac = self.jacobian(X)
        # The divergence is the trace of the Jacobian.
        return jac[..., 0, 0] + jac[..., 1, 1]

    @abstractmethod
    def jacobian(self, X: torch.Tensor) -> torch.Tensor:
        """
        Evaluate the Jacobian of the vector field at (x, y)
        :param X: (x,y) coordinate, shape (batch_size, 2)
        :return: Jacobian matrix, shape (batch_size, 2, 2)
        """
        pass

    @staticmethod
    def _eulerSteps(t_span: tuple[float, float], step: float) -> tuple[int, float]:
        """
        Split a time span into equally sized Euler steps.
        :param t_span: (t0, t1); t1 may be smaller than t0 (backward integration)
        :param step: maximal absolute step length, must be positive
        :return: (number of steps, signed step length); the steps add up to
                 exactly t1 - t0, and the number of steps is 0 for an empty span
        """
        if step <= 0:
            raise ValueError("step must be positive")
        t0, t1 = t_span
        span = float(t1) - float(t0)
        n_steps = int(abs(span) / step)
        # Round up so that a single step never exceeds the requested step length.
        if n_steps * step < abs(span):
            n_steps += 1
        h = span / n_steps if n_steps > 0 else 0.0
        return n_steps, h

    @staticmethod
    def _stateDtype(y0: torch.Tensor) -> torch.dtype:
        """Floating point dtype used for the accumulators of the integrators."""
        return y0.dtype if y0.is_floating_point() else torch.get_default_dtype()

    def eulerIntegratorWithJacobian(self, y0: torch.Tensor, t_span: tuple[float, float],
                                    step: float = 0.001) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Integrates the system forward in time using the Euler method and
        accumulates the Jacobian of the resulting deformation.
        :param y0: initial positions, shape (*, 2)
        :param t_span: time span (t0, t1); t1 < t0 integrates backward in time
        :param step: maximal absolute Euler step length
        :return: (yt, jac): (final positions, shape (*, 2), Jacobian of the deformation, shape(*, 2, 2))
        """
        n_steps, h = self._eulerSteps(t_span, step)

        # Build the accumulator in the dtype of the input; a float32 accumulator
        # cannot be matrix-multiplied with the Jacobian of a float64 input.
        identity = torch.eye(2, dtype=self._stateDtype(y0), device=y0.device)
        jac = identity.expand(*y0.shape[:-1], 2, 2).clone()

        y = y0
        for _ in range(n_steps):
            # Jacobian of one Euler step y -> y + h V(y), chained onto the
            # Jacobian accumulated so far.
            dj = identity + h * self.jacobian(y)
            jac = dj @ jac.to(dj.dtype)
            y = y + h * self(y)
        return y, jac

    def eulerIntegrator(self, y0: torch.Tensor, t_span: tuple[float, float],
                        step: float = 0.001) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Integrates the system forward in time using the Euler method and
        accumulates the determinant of the Jacobian via the divergence.
        :param y0: initial positions, shape (*, 2)
        :param t_span: time span (t0, t1); t1 < t0 integrates backward in time
        :param step: maximal absolute Euler step length
        :return: (yt, det): (final positions, shape (*, 2), determinant of the Jacobian of the deformation)
        """
        n_steps, h = self._eulerSteps(t_span, step)

        log_det = torch.zeros(y0.shape[:-1], dtype=self._stateDtype(y0), device=y0.device)
        y = y0
        for _ in range(n_steps):
            # d/dt log det(D phi_t) = div V(phi_t)
            log_det = log_det + h * self.divergence(y)
            y = y + h * self(y)
        # exp() is already non-negative, no abs() required.
        return y, torch.exp(log_det)

    def integrate(self, X: torch.Tensor, t: float, computeJacobian=True) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Integrate the vector field forward in time, starting at time 0
        :param X: initial positions, shape (batch_size, 2)
        :param t: time
        :param computeJacobian: if True the second return value is the Jacobian of
                                the deformation, otherwise its determinant
        :return: (final positions, shape (batch_size, 2), Jacobian of shape
                 (batch_size, 2, 2) or determinant of shape (batch_size))
        """
        d = X.shape[-1]
        assert d == 2
        if computeJacobian:
            return self.eulerIntegratorWithJacobian(X, (0, t))
        return self.eulerIntegrator(X, (0, t))


class CubicBumpVectorField(VectorField):
    """
    Vector field defined as the gradient of a cubic bump function
    f(x, y) = (h² - (x² + y²))³ / h^6,
    set to zero outside of the open disc x² + y² < h².
    """
    def __init__(self, h):
        """
        :param h: radius of the bump
        """
        super().__init__()
        self.h = h

    def __call__(self, X):
        """
        Evaluate the gradient of the bump
        :param X: (x,y) coordinates, shape (*, 2)
        :return: gradient, shape (*, 2); zero outside of the bump
        """
        px, py = X[..., 0], X[..., 1]
        radius = self.h
        sq_norm = px**2 + py**2
        inside = sq_norm < radius**2

        # (r² - h²)² is shared by both gradient components
        weight = (sq_norm - radius**2)**2
        grad_x = torch.nan_to_num(-(6*px*weight) / radius**6)
        grad_y = torch.nan_to_num(-(6*py*weight) / radius**6)

        return torch.stack((inside * grad_x, inside * grad_y), dim=-1)

    def divergence(self, X):
        """
        Closed form of the divergence (trace of the Jacobian)
        :param X: (x,y) coordinates, shape (*, 2)
        :return: divergence, shape (*); zero outside of the bump
        """
        px, py = X[..., 0], X[..., 1]
        radius = self.h
        sq_norm = px**2 + py**2
        inside = sq_norm < radius**2

        numerator = -12 * (radius**4 - ((4*radius**2) * (sq_norm)) + 3*(sq_norm)**2)
        denominator = radius**6

        return inside * torch.nan_to_num(numerator / denominator)

    def jacobian(self, X):
        """
        Jacobian of the gradient field (the Hessian of the bump)
        :param X: (x,y) coordinates, shape (*, 2)
        :return: Jacobian, shape (*, 2, 2); zero outside of the bump
        """
        px, py = X[..., 0], X[..., 1]
        radius = self.h
        sq_norm = px**2 + py**2
        inside = sq_norm < radius**2

        gap = sq_norm - radius**2
        scale = radius**6
        isotropic = (6 * gap**2) / scale

        d_xx = -(24 * px**2 * gap) / scale - isotropic
        d_yy = -(24 * py**2 * gap) / scale - isotropic
        d_xy = -(24 * px * py * gap) / scale
        d_yx = -(24 * px * py * gap) / scale

        # The mask is 0/1 valued, so applying it once is sufficient.
        top_row = torch.stack((inside * d_xx, inside * d_xy), dim=-1)
        bottom_row = torch.stack((inside * d_yx, inside * d_yy), dim=-1)
        return torch.stack((top_row, bottom_row), dim=-2)


class Deformation(ABC):
    """
    Abstract class for deformations of the plane, parameterized by a tensor theta
    """
    @abstractmethod
    def __call__(self, X: torch.Tensor, theta: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Deform the input tensor
        :param X: input tensor, shape (*, 2)
        :param theta: parameter of deformation
        :return: (deformed tensor, jacobian matrix), shapes ((*, 2), (*, 2, 2))
        """
        pass

    @abstractmethod
    def dfdt(self, X: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        """
        Compute the derivative of the deformation with respect to theta
        :param X: input tensor, shape (*, 2)
        :param theta: parameter of deformation
        :return: derivative of the deformation with respect to theta, shape (*, 2, dim(theta))
        """
        pass

    @abstractmethod
    def jacobian(self, X: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        """
        Compute the Jacobian of the deformation with respect to X
        :param X: input tensor, shape (*, 2)
        :param theta: parameter of deformation
        :return: Jacobian matrix, shape (*, 2, 2)
        """
        pass

    @abstractmethod
    def numParameters(self) -> int:
        """
        :return: number of parameters of the deformation
        """
        pass

    @abstractmethod
    def getNeutralParameter(self) -> torch.Tensor:
        """
        :return: neutral parameter of the deformation
        """
        pass

    def inverse(self, X: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        """
        Computes the inverse of the deformation. Uses Newton method if not overridden
        (5 iterations, started at X itself).
        :param X: input tensor, shape (*, 2)
        :param theta: parameter of deformation
        :return: inverse of the deformation, shape (*, 2)
        """
        x = X.clone()
        target = X.unsqueeze(-1)

        for _ in range(5):
            value, jac = self(x, theta)
            residual = value.unsqueeze(-1) - target
            # Newton update x <- x - J^{-1} (f(x) - X). Solving the linear system
            # avoids forming the explicit inverse of the Jacobian.
            x = x - torch.linalg.solve(jac, residual).squeeze(-1)
        return x

class VectorFieldDeformation(Deformation):
    """
    Deformation defined by flow along a vector field.
    The (scalar) parameter theta is the integration time.
    """
    def __init__(self, V: VectorField):
        """
        :param V: vector field whose flow defines the deformation
        """
        super().__init__()
        self.V = V

    def __call__(self, X: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        """
        Flow X along the vector field for time theta
        :param X: input tensor, shape (*, 2)
        :param theta: one-element tensor holding the integration time
        :return: pair returned by ``V.integrate`` called without ``computeJacobian``
                 (positions and, with the default of the base class in this file,
                 the Jacobian of the flow rather than its determinant)
        """
        flowed, second = self.V.integrate(X, theta.item())
        return flowed, second

    def jacobian(self, X: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        """
        :param X: input tensor, shape (*, 2)
        :param theta: one-element tensor holding the integration time
        :return: Jacobian of the flow as computed by ``V.integrate``
        """
        result = self.V.integrate(X, theta.item(), computeJacobian=True)
        _, flow_jacobian = result
        return flow_jacobian

    def dfdt(self, X: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        """
        :param X: input tensor, shape (*, 2)
        :param theta: unused
        :return: the vector field evaluated at X
        """
        # NOTE(review): theta is ignored here; the time derivative of a flow is the
        # field at the *flowed* point, which coincides with V(X) only at theta = 0.
        # Confirm that callers only use this at the neutral parameter.
        velocity = self.V(X)
        return velocity

    def numParameters(self) -> int:
        """
        :return: number of parameters of the deformation (the integration time)
        """
        return 1

    def getNeutralParameter(self) -> torch.Tensor:
        """
        :return: integration time 0
        """
        return torch.Tensor([0.0])
    
    def inverse(self, X: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        """
        Invert the deformation by flowing for time -theta
        :param X: input tensor, shape (*, 2)
        :param theta: one-element tensor holding the integration time
        :return: positions after flowing for time -theta
        """
        backward = self(X, -theta)
        return backward[0]

class CubicDeformation(Deformation):
    """
    Radial deformation X -> X * (1 + theta * (h - |X|)²) inside of the disc
    |X| < h and the identity outside of it.
    """

    def __init__(self, h):
        """
        :param h: radius of the deformed disc
        """
        super().__init__()
        self.h = h

    def _profile(self, r: torch.Tensor) -> torch.Tensor:
        """(h - r)² in expanded form, the radial profile shared by all methods."""
        return self.h**2 - 2*self.h * r + r**2

    def __call__(self, X: torch.Tensor, theta: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Deform the input tensor
        :param X: input tensor, shape (*, 2)
        :param theta: strength of the deformation, theta = 0 is the identity
        :return: (deformed tensor, jacobian matrix), shapes ((*, 2), (*, 2, 2))
        """
        r = torch.norm(X, dim=-1)

        mask = r < self.h

        gain = mask * theta * self._profile(r)

        deformed = torch.stack((gain * X[..., 0] + X[..., 0], gain * X[..., 1] + X[..., 1]), dim=-1)

        return deformed, self.jacobian(X, theta)

    def jacobian(self, X: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        """
        Jacobian of the deformation with respect to X
        :param X: input tensor, shape (*, 2)
        :param theta: strength of the deformation
        :return: Jacobian matrix, shape (*, 2, 2); the identity outside of the disc
        """
        x = X[..., 0]
        y = X[..., 1]
        r = torch.norm(X, dim=-1)

        mask = r < self.h

        # The radial terms contain x*x/r, x*y/r, y*y/r which all vanish at the
        # origin. Dividing by 1 instead of 0 there avoids the 0/0 = NaN.
        r_safe = torch.where(r > 0, r, torch.ones_like(r))

        isotropic = theta * self._profile(r) + 1

        j11 = isotropic + theta * x * (2*x - (2*self.h * x) / r_safe)
        j22 = isotropic + theta * y * (2*y - (2*self.h * y) / r_safe)
        j12 = theta * x * (2*y - (2*self.h * y) / r_safe)
        j21 = theta * y * (2*x - (2*self.h * x) / r_safe)

        j11 = torch.where(mask, j11, torch.ones_like(j11))
        j22 = torch.where(mask, j22, torch.ones_like(j22))
        j12 = torch.where(mask, j12, torch.zeros_like(j12))
        j21 = torch.where(mask, j21, torch.zeros_like(j21))

        return torch.stack((torch.stack((j11, j12), dim=-1), torch.stack((j21, j22), dim=-1)), dim=-2)

    def dfdt(self, X: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        """
        Derivative of the deformation with respect to theta
        :param X: input tensor, shape (*, 2)
        :param theta: strength of the deformation (the derivative does not depend on it)
        :return: derivative, shape (*, 2); zero outside of the disc
        """
        r = torch.norm(X, dim=-1)

        weight = (r < self.h) * self._profile(r)

        dfdt = torch.zeros_like(X)
        dfdt[..., 0] = weight * X[..., 0]
        dfdt[..., 1] = weight * X[..., 1]

        return dfdt

    def numParameters(self) -> int:
        """
        :return: number of parameters of the deformation
        """
        return 1

    def getNeutralParameter(self) -> torch.Tensor:
        """
        :return: parameter of the identity deformation (theta = 0 removes the
                 radial term entirely)
        """
        return torch.tensor([0.0])

class PowerDeformation(Deformation):
    """
    Radial deformation r -> r**theta inside of the unit disc (direction
    preserved) and the identity outside of it.
    """

    def __init__(self, h):
        """
        :param h: currently ignored, the radius is fixed to 1
        """
        super().__init__()
        # NOTE(review): the argument h is not used. r -> r**theta only fixes
        # r = 1, so any other radius would make the map discontinuous at the
        # boundary of the disc; confirm that ignoring h is intentional.
        self.h = 1

    @staticmethod
    def _nonZero(values: torch.Tensor) -> torch.Tensor:
        """Replace non-positive entries by 1 so they can be used as a divisor."""
        return torch.where(values > 0, values, torch.ones_like(values))

    def __call__(self, X: torch.Tensor, theta: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Deform the input tensor
        :param X: input tensor, shape (*, 2)
        :param theta: exponent of the radial map, theta = 1 is the identity
        :return: (deformed tensor, jacobian matrix), shapes ((*, 2), (*, 2, 2))
        """
        r = torch.norm(X, dim=-1)

        mask = r < self.h

        r_t = mask * r ** theta

        # X / r is the unit direction; at the origin X is zero anyway, so the
        # division by r is replaced by a division by 1 to avoid 0/0 = NaN.
        r_safe = self._nonZero(r)

        deformed = torch.stack((r_t * X[..., 0] / r_safe + ~mask*X[..., 0],
                                r_t * X[..., 1] / r_safe + ~mask*X[..., 1]), dim=-1)

        return deformed, self.jacobian(X, theta)

    def jacobian(self, X: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        """
        Jacobian of X -> X * r**(theta - 1):
        r**(theta - 1) * I + (theta - 1) * r**(theta - 3) * X X^T
        :param X: input tensor, shape (*, 2)
        :param theta: exponent of the radial map
        :return: Jacobian matrix, shape (*, 2, 2); the identity outside of the disc.
                 For theta < 1 the Jacobian is unbounded at the origin.
        """
        x = X[..., 0]
        y = X[..., 1]
        r = torch.norm(X, dim=-1)
        r2 = r**2

        mask = r < self.h

        isotropic = r2**(theta/2 - 1/2)
        # The rank-one part is multiplied by products of x and y which vanish at
        # the origin; evaluating the power at 1 there avoids 0 * inf = NaN.
        radial = (theta - 1) * self._nonZero(r2)**(theta/2 - 3/2)

        j11 = radial * x**2 + isotropic
        j22 = radial * y**2 + isotropic
        j12 = radial * x * y
        j21 = radial * x * y

        j11 = torch.where(mask, j11, torch.ones_like(j11))
        j22 = torch.where(mask, j22, torch.ones_like(j22))
        j12 = torch.where(mask, j12, torch.zeros_like(j12))
        j21 = torch.where(mask, j21, torch.zeros_like(j21))

        return torch.stack((torch.stack((j11, j12), dim=-1), torch.stack((j21, j22), dim=-1)), dim=-2)

    def dfdt(self, X: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        """
        Derivative of the deformation with respect to theta,
        X * r**(theta - 1) * log(r)
        :param X: input tensor, shape (*, 2)
        :param theta: exponent of the radial map
        :return: derivative, shape (*, 2); zero outside of the disc and at the origin
        """
        dfdt = torch.zeros_like(X)

        r = torch.norm(X, dim=-1)
        # log(0) = -inf would turn the (vanishing) value at the origin into NaN.
        r2 = self._nonZero(r**2)

        mask = r < self.h

        dfdt[..., 0] = mask * 0.5 * X[..., 0]*r2**((theta - 1)/2)*torch.log(r2)
        dfdt[..., 1] = mask * 0.5 * X[..., 1]*r2**((theta - 1)/2)*torch.log(r2)

        return dfdt

    def numParameters(self) -> int:
        """
        :return: number of parameters of the deformation
        """
        return 1

    def getNeutralParameter(self) -> torch.Tensor:
        """
        :return: parameter of the identity deformation (r**1 = r)
        """
        return torch.tensor([1.0])

class HalfNormalTunableSigmoidDeformation(Deformation):
    """
    From https://math.stackexchange.com/questions/459872/adjustable-sigmoid-curve-s-curve-from-0-0-to-1-1
    Also see https://dhemery.github.io/DHE-Modules/technical/sigmoid/
    Here we actually only need the upper half of the sigmoid

    The radius r = |X| is mapped to (theta*r - r) / (2*theta*r - theta - 1) inside
    of the disc r < h, i.e. X -> X * (theta - 1) / (2*theta*r - theta - 1).
    Outside of the disc the deformation is the identity.
    """

    def __init__(self, h) -> None:
        """
        :param h: radius of the deformed disc
        """
        super().__init__()
        self.h = h

    def __call__(self, X: torch.Tensor, theta: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Deform the input tensor
        :param X: input tensor, shape (*, 2)
        :param theta: tuning parameter of the sigmoid, theta = 0 is the identity
        :return: (deformed tensor, jacobian matrix), shapes ((*, 2), (*, 2, 2))
        """
        r = torch.norm(X, dim=-1)

        mask = r < self.h

        # (theta*r - r) / (2*theta*r - theta - 1) / r with the common factor r
        # cancelled, which removes the 0/0 = NaN at the origin.
        gain = mask * (theta - 1) / (2*theta*r - theta - 1)

        deformed = torch.stack((gain * X[..., 0] + ~mask*X[..., 0], gain * X[..., 1] + ~mask*X[..., 1]), dim=-1)

        return deformed, self.jacobian(X, theta)

    def jacobian(self, X: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        """
        Jacobian of X -> X * g(r) with g(r) = (theta - 1) / (2*theta*r - theta - 1):
        g(r) * I + g'(r) / r * X X^T,  g'(r) = -2*theta*(theta - 1) / (2*theta*r - theta - 1)²
        :param X: input tensor, shape (*, 2)
        :param theta: tuning parameter of the sigmoid
        :return: Jacobian matrix, shape (*, 2, 2); the identity outside of the disc
        """
        r = torch.norm(X, dim=-1)

        x = X[..., 0]
        y = X[..., 1]
        t = theta

        mask = r < self.h

        # X X^T / r vanishes at the origin; divide by 1 there instead of 0.
        r_safe = torch.where(r > 0, r, torch.ones_like(r))

        den = 2*t*r - t - 1
        gain = (t - 1) / den
        radial = -2*t*(t - 1) / (r_safe * den**2)

        j11 = gain + radial * x**2
        j22 = gain + radial * y**2
        j12 = radial * x * y
        j21 = radial * x * y

        j11 = torch.where(mask, j11, torch.ones_like(j11))
        j22 = torch.where(mask, j22, torch.ones_like(j22))
        j12 = torch.where(mask, j12, torch.zeros_like(j12))
        j21 = torch.where(mask, j21, torch.zeros_like(j21))

        return torch.stack((torch.stack((j11, j12), dim=-1), torch.stack((j21, j22), dim=-1)), dim=-2)

    def dfdt(self, X: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        """
        Derivative of the deformation with respect to theta,
        X * 2*(r - 1) / (2*theta*r - theta - 1)²
        :param X: input tensor, shape (*, 2)
        :param theta: tuning parameter of the sigmoid
        :return: derivative, shape (*, 2); zero outside of the disc
        """
        r = torch.norm(X, dim=-1)

        x = X[..., 0]
        y = X[..., 1]
        t = theta

        dfdt = torch.zeros_like(X)

        mask = r < self.h

        common = 2*(r - 1) / (2*t*r - t - 1)**2

        dfdt[..., 0] = common * x
        dfdt[..., 1] = common * y

        dfdt[~mask] = 0

        return dfdt

    def numParameters(self) -> int:
        """
        :return: number of parameters of the deformation
        """
        return 1

    def getNeutralParameter(self) -> torch.Tensor:
        """
        :return: parameter of the identity deformation; for theta = 0 the sigmoid
                 reduces to r -> r (theta = 1 would map the whole disc to the origin)
        """
        return torch.tensor([0.0])

class AnisotropicHalfNormalTunableSigmoidDeformation(Deformation):
    """
    From https://math.stackexchange.com/questions/459872/adjustable-sigmoid-curve-s-curve-from-0-0-to-1-1
    Also see https://dhemery.github.io/DHE-Modules/technical/sigmoid/
    Here we actually only need the upper half of the sigmoid

    Inside of the disc r = |X| < h the two coordinates are scaled with
    individual tuning parameters theta = (s, t):
    x -> x * (s - 1) / (2*s*r - s - 1),  y -> y * (t - 1) / (2*t*r - t - 1).
    Outside of the disc the deformation is the identity.
    """

    def __init__(self, h) -> None:
        """
        :param h: radius of the deformed disc
        """
        super().__init__()
        self.h = h

    def __call__(self, X: torch.Tensor, theta: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Deform the input tensor
        :param X: input tensor, shape (*, 2)
        :param theta: tuning parameters (s, t), (0, 0) is the identity
        :return: (deformed tensor, jacobian matrix), shapes ((*, 2), (*, 2, 2))
        """
        r = torch.norm(X, dim=-1)

        mask = r < self.h
        s = theta[0]
        t = theta[1]

        gain_x = mask * (s - 1) / (2*s*r - s - 1)
        gain_y = mask * (t - 1) / (2*t*r - t - 1)

        deformed = torch.stack((gain_x * X[..., 0] + ~mask*X[..., 0], gain_y * X[..., 1] + ~mask*X[..., 1]), dim=-1)

        return deformed, self.jacobian(X, theta)

    def jacobian(self, X: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        """
        Jacobian of the deformation with respect to X
        :param X: input tensor, shape (*, 2)
        :param theta: tuning parameters (s, t)
        :return: Jacobian matrix, shape (*, 2, 2); the identity outside of the disc.
                 Non-finite entries are replaced by the entries of the identity.
        """
        r = torch.norm(X, dim=-1)

        x = X[..., 0]
        y = X[..., 1]
        s = theta[0]
        t = theta[1]

        mask = r < self.h

        # All divisions by r are multiplied by x*x, x*y or y*y, which vanish at
        # the origin; divide by 1 there. This form also avoids high powers of r
        # that underflow close to the origin.
        r_safe = torch.where(r > 0, r, torch.ones_like(r))

        den_s = 2*s*r - s - 1
        den_t = 2*t*r - t - 1

        j11 = torch.nan_to_num((s - 1)/den_s - 2*s*(s - 1)*x**2/(r_safe*den_s**2), nan=1.0, posinf=1.0, neginf=1.0)
        j22 = torch.nan_to_num((t - 1)/den_t - 2*t*(t - 1)*y**2/(r_safe*den_t**2), nan=1.0, posinf=1.0, neginf=1.0)
        j12 = torch.nan_to_num(-2*s*x*y*(s - 1)/(r_safe*den_s**2), nan=0.0, posinf=0.0, neginf=0.0)
        j21 = torch.nan_to_num(-2*t*x*y*(t - 1)/(r_safe*den_t**2), nan=0.0, posinf=0.0, neginf=0.0)

        j11 = torch.where(mask, j11, torch.ones_like(j11))
        j22 = torch.where(mask, j22, torch.ones_like(j22))
        j12 = torch.where(mask, j12, torch.zeros_like(j12))
        j21 = torch.where(mask, j21, torch.zeros_like(j21))

        return torch.stack((torch.stack((j11, j12), dim=-1), torch.stack((j21, j22), dim=-1)), dim=-2)

    def dfdt(self, X: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        """
        Derivative of the deformation with respect to theta = (s, t):
        d f_x / d s = 2*x*(r - 1) / (2*s*r - s - 1)²,
        d f_y / d t = 2*y*(r - 1) / (2*t*r - t - 1)²,
        the mixed derivatives vanish.
        :param X: input tensor, shape (*, 2)
        :param theta: tuning parameters (s, t)
        :return: derivative, shape (*, 2, 2); zero outside of the disc
        """
        r = torch.norm(X, dim=-1)

        x = X[..., 0]
        y = X[..., 1]

        s = theta[0]
        t = theta[1]

        dfdt = torch.zeros((*X.shape, self.numParameters()), dtype=X.dtype, device=X.device)

        mask = r < self.h

        dfdt[..., 0, 0] = torch.nan_to_num(2*x*(r - 1)/(2*s*r - s - 1)**2)
        dfdt[..., 1, 1] = torch.nan_to_num(2*y*(r - 1)/(2*t*r - t - 1)**2)

        dfdt[~mask] = 0

        return dfdt

    def numParameters(self) -> int:
        """
        :return: number of parameters of the deformation
        """
        return 2

    def getNeutralParameter(self) -> torch.Tensor:
        """
        :return: parameter of the identity deformation
        """
        return torch.tensor([0.0, 0.0])

class AnisotropicEllipticHalfNormalTunableSigmoidDeformation(Deformation):
    """
    From https://math.stackexchange.com/questions/459872/adjustable-sigmoid-curve-s-curve-from-0-0-to-1-1
    Also see https://dhemery.github.io/DHE-Modules/technical/sigmoid/
    Here we actually only need the upper half of the sigmoid

    theta = (s, t, p, q). With the elliptic radius r = sqrt(p*x² + q*y²) the
    deformation inside of r < h is
    x -> x * (s - 1) / (2*s*r - s - 1),  y -> y * (t - 1) / (2*t*r - t - 1)
    and the identity outside.
    """

    def __init__(self, h) -> None:
        """
        :param h: bound on the elliptic radius of the deformed region
        """
        super().__init__()
        self.h = h

    @staticmethod
    def _ellipticRadius(X: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        """sqrt(p*x² + q*y²) with p = theta[2], q = theta[3]."""
        p = theta[2]
        q = theta[3]
        return torch.sqrt(p*X[..., 0]**2 + q*X[..., 1]**2)

    @staticmethod
    def _nonZero(r: torch.Tensor) -> torch.Tensor:
        """Replace non-positive radii by 1 so they can be used as a divisor."""
        return torch.where(r > 0, r, torch.ones_like(r))

    def __call__(self, X: torch.Tensor, theta: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Deform the input tensor
        :param X: input tensor, shape (*, 2)
        :param theta: parameters (s, t, p, q), (0, 0, 1, 1) is the identity
        :return: (deformed tensor, jacobian matrix), shapes ((*, 2), (*, 2, 2))
        """
        r = self._ellipticRadius(X, theta)

        mask = r < self.h
        s = theta[0]
        t = theta[1]

        # (s*r - r) / (2*s*r - s - 1) / r with the common factor r cancelled,
        # which removes the 0/0 = NaN at the origin.
        gain_x = mask * (s - 1) / (2*s*r - s - 1)
        gain_y = mask * (t - 1) / (2*t*r - t - 1)

        deformed = torch.stack((gain_x * X[..., 0] + ~mask*X[..., 0], gain_y * X[..., 1] + ~mask*X[..., 1]), dim=-1)

        return deformed, self.jacobian(X, theta)

    def jacobian(self, X: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        """
        Jacobian of the deformation with respect to X
        :param X: input tensor, shape (*, 2)
        :param theta: parameters (s, t, p, q)
        :return: Jacobian matrix, shape (*, 2, 2); the identity outside of the
                 deformed region. Non-finite entries are replaced by the entries
                 of the identity.
        """
        p = theta[2]
        q = theta[3]
        r = self._ellipticRadius(X, theta)

        x = X[..., 0]
        y = X[..., 1]
        s = theta[0]
        t = theta[1]

        mask = r < self.h

        # dr/dx = p*x/r and dr/dy = q*y/r are always multiplied by x or y, so the
        # products vanish at the origin and the division can use 1 there. This
        # form also avoids high powers of r that underflow close to the origin.
        r_safe = self._nonZero(r)

        den_s = 2*s*r - s - 1
        den_t = 2*t*r - t - 1

        j11 = torch.nan_to_num((s - 1)/den_s - 2*p*s*(s - 1)*x**2/(r_safe*den_s**2), nan=1.0, posinf=1.0, neginf=1.0)
        j12 = torch.nan_to_num(-2*q*s*x*y*(s - 1)/(r_safe*den_s**2), nan=0.0, posinf=0.0, neginf=0.0)
        j21 = torch.nan_to_num(-2*p*t*x*y*(t - 1)/(r_safe*den_t**2), nan=0.0, posinf=0.0, neginf=0.0)
        j22 = torch.nan_to_num((t - 1)/den_t - 2*q*t*(t - 1)*y**2/(r_safe*den_t**2), nan=1.0, posinf=1.0, neginf=1.0)

        j11 = torch.where(mask, j11, torch.ones_like(j11))
        j22 = torch.where(mask, j22, torch.ones_like(j22))
        j12 = torch.where(mask, j12, torch.zeros_like(j12))
        j21 = torch.where(mask, j21, torch.zeros_like(j21))

        return torch.stack((torch.stack((j11, j12), dim=-1), torch.stack((j21, j22), dim=-1)), dim=-2)

    def dfdt(self, X: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        """
        Derivative of the deformation with respect to theta = (s, t, p, q)
        :param X: input tensor, shape (*, 2)
        :param theta: parameters (s, t, p, q)
        :return: derivative, shape (*, 2, 4), last index ordered as (s, t, p, q);
                 zero outside of the deformed region
        """
        r = self._ellipticRadius(X, theta)
        r_safe = self._nonZero(r)

        x = X[..., 0]
        y = X[..., 1]

        s = theta[0]
        t = theta[1]

        dfdt = torch.zeros((*X.shape, self.numParameters()), dtype=X.dtype, device=X.device)

        mask = r < self.h

        den_s = 2*s*r - s - 1
        den_t = 2*t*r - t - 1

        # d/dp and d/dq act through dr/dp = x²/(2r) and dr/dq = y²/(2r) on the
        # denominators: d f_x = -2*s*(s - 1)*x / den_s² * dr (f_y analogously).
        radial_x = -s*(s - 1)*x/(r_safe*den_s**2)
        radial_y = -t*(t - 1)*y/(r_safe*den_t**2)

        dfdt[..., 0, 0] = 2*x*(r - 1)/den_s**2 # df1ds
        dfdt[..., 0, 1] = 0 # df1dt
        dfdt[..., 0, 2] = radial_x * x**2 # df1dp
        dfdt[..., 0, 3] = radial_x * y**2 # df1dq

        dfdt[..., 1, 0] = 0 # df2ds
        dfdt[..., 1, 1] = 2*y*(r - 1)/den_t**2 # df2dt (second component, hence y)
        dfdt[..., 1, 2] = radial_y * x**2 # df2dp
        dfdt[..., 1, 3] = radial_y * y**2 # df2dq

        dfdt[~mask] = 0

        return dfdt

    def numParameters(self) -> int:
        """
        :return: number of parameters of the deformation
        """
        return 4

    def getNeutralParameter(self) -> torch.Tensor:
        """
        :return: parameter of the identity deformation
        """
        return torch.tensor([0.0, 0.0, 1.0, 1.0])
    
class IndependentAnisotropicHalfNormalTunableSigmoidDeformation(Deformation):
    """
    From https://math.stackexchange.com/questions/459872/adjustable-sigmoid-curve-s-curve-from-0-0-to-1-1
    Also see https://dhemery.github.io/DHE-Modules/technical/sigmoid/
    Here we actually only need the upper half of the sigmoid

    The two coordinates are deformed independently with theta = (s, t):
    x -> x * (s - 1) / (2*s*|x| - s - 1) for |x| < h,
    y -> y * (t - 1) / (2*t*|y| - t - 1) for |y| < h,
    and the identity otherwise.
    """

    def __init__(self, h) -> None:
        """
        :param h: half width of the deformed interval of each coordinate
        """
        super().__init__()
        self.h = h

    def __call__(self, X: torch.Tensor, theta: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Deform the input tensor
        :param X: input tensor, shape (*, 2)
        :param theta: tuning parameters (s, t), (0, 0) is the identity
        :return: (deformed tensor, jacobian matrix), shapes ((*, 2), (*, 2, 2))
        """
        r_x = torch.abs(X[..., 0])
        r_y = torch.abs(X[..., 1])

        s = theta[0]
        t = theta[1]

        mask_x = r_x < self.h
        mask_y = r_y < self.h

        gain_x = mask_x * (s - 1) / (2*s*r_x - s - 1)
        gain_y = mask_y * (t - 1) / (2*t*r_y - t - 1)

        deformed = torch.stack((gain_x * X[..., 0] + ~mask_x*X[..., 0], gain_y * X[..., 1] + ~mask_y*X[..., 1]), dim=-1)

        return deformed, self.jacobian(X, theta)

    def jacobian(self, X: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        """
        Jacobian of the deformation with respect to X. It is diagonal with
        d f_x / d x = (s - 1) / den - 2*s*(s - 1)*|x| / den²,  den = 2*s*|x| - s - 1
        (and analogously for y with t).
        :param X: input tensor, shape (*, 2)
        :param theta: tuning parameters (s, t)
        :return: Jacobian matrix, shape (*, 2, 2). Non-finite diagonal entries
                 are replaced by 1.
        """
        r_x = torch.abs(X[..., 0])
        r_y = torch.abs(X[..., 1])

        s = theta[0]
        t = theta[1]

        mask_x = r_x < self.h
        mask_y = r_y < self.h

        den_s = 2*s*r_x - s - 1
        den_t = 2*t*r_y - t - 1

        # Written without divisions by x or |x|, so the derivative on the axes
        # is the actual value (s - 1) / (-s - 1) instead of a NaN placeholder.
        j11 = torch.nan_to_num((s - 1)/den_s - 2*s*(s - 1)*r_x/den_s**2, nan=1, posinf=1, neginf=1)
        j22 = torch.nan_to_num((t - 1)/den_t - 2*t*(t - 1)*r_y/den_t**2, nan=1, posinf=1, neginf=1)
        j21 = torch.zeros_like(j11)
        j12 = torch.zeros_like(j11)

        j11 = torch.where(mask_x, j11, torch.ones_like(j11))
        j22 = torch.where(mask_y, j22, torch.ones_like(j22))

        return torch.stack((torch.stack((j11, j12), dim=-1), torch.stack((j21, j22), dim=-1)), dim=-2)

    def dfdt(self, X: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        """
        Derivative of the deformation with respect to theta = (s, t):
        d f_x / d s = 2*x*(|x| - 1) / (2*s*|x| - s - 1)²,
        d f_y / d t = 2*y*(|y| - 1) / (2*t*|y| - t - 1)²,
        the mixed derivatives vanish.
        :param X: input tensor, shape (*, 2)
        :param theta: tuning parameters (s, t)
        :return: derivative, shape (*, 2, 2)
        """
        x = X[..., 0]
        y = X[..., 1]
        r_x = torch.abs(x)
        r_y = torch.abs(y)

        s = theta[0]
        t = theta[1]

        dfdt = torch.zeros((*X.shape, self.numParameters()), dtype=X.dtype, device=X.device)

        mask_x = r_x < self.h
        mask_y = r_y < self.h

        dfdt[..., 0, 0] = mask_x * torch.nan_to_num(2*x*(r_x - 1)/(2*s*r_x - s - 1)**2)
        dfdt[..., 1, 1] = mask_y * torch.nan_to_num(2*y*(r_y - 1)/(2*t*r_y - t - 1)**2)

        return dfdt

    def numParameters(self) -> int:
        """
        :return: number of parameters of the deformation
        """
        return 2

    def getNeutralParameter(self) -> torch.Tensor:
        """
        :return: parameter of the identity deformation
        """
        return torch.tensor([0.0, 0.0])
    
    def inverse(self, X: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        """
        Invert the deformation by applying it with the negated parameter
        :param X: input tensor, shape (*, 2)
        :param theta: tuning parameters (s, t)
        :return: result of the deformation with parameter -theta, shape (*, 2)
        """
        return self(X, -theta)[0]
    
class Monotone_Linear_Spline_Deformation(Deformation):
    """
    class for deformations via two independent monotone linear splines,
    one per coordinate. Inputs are mapped affinely via 0.5*X + 0.5 before the
    splines are evaluated and the result is mapped back via 2*Y - 1.
    """
    def __init__(self, sizeTheta) -> None:
        """
        :param sizeTheta: total number of parameters; the first half belongs to the
                          spline of the first coordinate, the second half to the
                          spline of the second coordinate
        """
        super().__init__()
        assert sizeTheta%4==0
        # The size of theta determines the number of kinks in each of the two
        # monotone splines. Within each spline it parameterizes the points of the
        # kinks as well as their values. Thus, it has to be divisible by 4.
        self.sizeTheta = sizeTheta

    def _halfSize(self) -> int:
        """Number of parameters of each of the two splines."""
        return int(self.sizeTheta/2)

    @staticmethod
    def _toUnit(X: torch.Tensor):
        """
        Apply 0.5*X + 0.5 and flatten the leading dimensions.
        :return: (points of shape (N, 2), leading dimensions of X)
        """
        unit = 0.5*X + 0.5
        return unit.view(-1,2), unit.shape[:-1]

    @staticmethod
    def _diagonalJacobian(flat: torch.Tensor, slope_x, slope_y, leading) -> torch.Tensor:
        """Assemble the diagonal Jacobian from the two spline derivatives."""
        jac = torch.zeros(flat.shape[0],2,2, device=flat.device)
        jac[:,0,0] = slope_x
        jac[:,1,1] = slope_y
        return jac.view(*leading,2,2)

    def __call__(self, X: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        """
        Deform the input tensor
        :param X: input tensor, shape (*, 2)
        :param theta: parameter of deformation
        :return: (deformed tensor, jacobian matrix, shape ((*, 2), (*, 2, 2))
        """
        assert self.sizeTheta==theta.shape[0]

        flat, leading = self._toUnit(X)
        split = self._halfSize()

        value_x,slope_x = spline1d.evaluate(flat[:,0], theta[0:split])
        value_y,slope_y = spline1d.evaluate(flat[:,1], theta[split:])

        values = torch.cat((value_x.view(-1,1),value_y.view(-1,1)), dim=1)
        values = values.view(*leading,2)

        # The input scaling (0.5) and the output scaling (2) cancel in the
        # Jacobian, so the spline derivatives are used unchanged.
        jac = self._diagonalJacobian(flat, slope_x, slope_y, leading)

        return 2*values-1, jac

    def dfdt(self, X: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        """
        Compute the derivative of the deformation with respect to theta
        :param X: input tensor, shape (*, 2)
        :param theta: parameter of deformation
        :return: derivative of the deformation with respect to theta, shape (*, 2, dim(theta))
        """
        assert self.sizeTheta==theta.shape[0]

        flat, leading = self._toUnit(X)
        split = self._halfSize()

        # NOTE(review): __call__ rescales the spline values by 2 (2*Y - 1) while
        # the parameter derivatives are used unscaled here. Whether a factor 2
        # is missing depends on what spline1d.dfdt returns - confirm.
        grad_x = spline1d.dfdt(flat[:,0], theta[0:split])
        grad_y = spline1d.dfdt(flat[:,1], theta[split:])

        # Each coordinate only depends on its own half of theta.
        der = torch.zeros(flat.shape[0], 2,self.sizeTheta, device=flat.device)
        der[:,0,0:split]=grad_x
        der[:,1,split:]=grad_y

        return der.view(*leading,2,self.sizeTheta)

    def jacobian(self, X: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        """
        Compute the Jacobian of the deformation with respect to X
        :param X: input tensor, shape (*, 2)
        :param theta: parameter of deformation
        :return: diagonal Jacobian matrix, shape (*, 2, 2)
        """
        assert self.sizeTheta==theta.shape[0]

        flat, leading = self._toUnit(X)
        split = self._halfSize()

        # Only the derivatives are needed, the spline values are discarded.
        _,slope_x = spline1d.evaluate(flat[:,0], theta[0:split])
        _,slope_y = spline1d.evaluate(flat[:,1], theta[split:])

        return self._diagonalJacobian(flat, slope_x, slope_y, leading)

    def numParameters(self) -> int:
        """
        :return: number of parameters of the deformation
        """
        return self.sizeTheta

    def getNeutralParameter(self) -> torch.Tensor:
        """
        :return: neutral parameter of the deformation (all zeros, without gradient tracking)
        """
        return torch.zeros(self.sizeTheta)

    def inverse(self, X: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        """
        Computes the inverse of the deformation by inverting the two splines
        with spline1d.evaluateInverse (no Newton iteration)
        :param X: input tensor, shape (*, 2)
        :param theta: parameter of deformation
        :return: inverse of the deformation
        """
        flat, leading = self._toUnit(X)
        split = self._halfSize()

        # The derivatives returned by evaluateInverse are not needed here.
        origin_x,_ = spline1d.evaluateInverse(flat[:,0], theta[0:split])
        origin_y,_ = spline1d.evaluateInverse(flat[:,1], theta[split:])

        origin = torch.cat((origin_x.view(-1,1),origin_y.view(-1,1)), dim=1)
        origin = 2*origin-1
        return origin.view(*leading,2)