import numpy as np
import torch
import random
from typing import Optional, Tuple
import numpy
from typing import *
import copy
import math
import logging
from models.GNOT.utils import MultipleTensors
logger = logging.getLogger(__name__)

def fourier_shift(u: torch.Tensor, eps: float=0., dim: int=-1, order: int=0) -> torch.Tensor:
    """
    Shift (and optionally differentiate) ``u`` along ``dim`` via the FFT shift theorem.

    The signal is treated as periodic and ``eps`` is measured in units of the full
    length of ``dim`` (``eps = 1 / n`` is a shift by one sample).

    Args:
        u (torch.Tensor): real input tensor, usually of shape [batch, t, x]
        eps (float or torch.Tensor): shift parameter; a tensor must broadcast against
            the Fourier-mode axis (e.g. shape [t, 1] for one shift per time step)
        dim (int): negative index of the dimension used for shifting
        order (int): number of times the spectrum is multiplied by (-2*pi*i*omega)
    Returns:
        torch.Tensor: Fourier shifted input, same length along ``dim`` as ``u``
    Raises:
        AssertionError: if ``dim`` is not negative
    """
    assert dim < 0
    n = u.shape[dim]
    u_hat = torch.fft.rfft(u, dim=dim, norm='ortho')
    # Fourier mode indices, created on u's device so CUDA inputs work as well
    omega = torch.arange(n // 2 + 1, device=u.device)
    if n % 2 == 0:
        # Even n: the Nyquist mode's frequency is set to 0, so it is neither
        # phase-shifted nor kept by a derivative (order > 0 multiplies it by 0).
        omega[-1] *= 0
    # Applying Fourier shift according to shift theorem
    fs = torch.exp(- 2 * np.pi * 1j * omega * eps)
    # For order>0 a derivative is taken.
    # NOTE(review): with torch.fft's sign convention the derivative w.r.t. the
    # normalised coordinate (index / n) multiplies by +2*pi*i*omega, so this yields
    # (-1)**order times that derivative. heat_to_burgers relies on order=1 — confirm
    # the intended convention before changing it.
    fs = (- 2 * np.pi * 1j * omega) ** order * fs
    # Append singleton axes so fs lines up with `dim` when dim != -1
    for _ in range(-dim - 1):
        fs = fs[..., None]
    return torch.fft.irfft(fs * u_hat, n=n, dim=dim, norm='ortho')

def linear_shift(u: torch.Tensor, eps: float=0., dim:int=-1) -> torch.Tensor:
    """
    Sub-sample shift by linear interpolation between two whole-sample shifts.

    ``eps`` is measured in units of the full length of ``dim`` (as in
    ``fourier_shift``). The shift of ``eps * n`` samples is split into an integer
    part ``q`` and a fraction ``r``. The result is
    ``(1 - r) * shift(q / n) + r * shift((q + 1) / n)``.

    Args:
        u (torch.Tensor): input tensor, usually of shape [batch, t, x]
        eps (float or torch.Tensor): shift parameter; plain python floats are accepted
        dim (int): negative index of the dimension used for shifting
    Returns:
        torch.Tensor: linearly shifted input
    """
    n = u.shape[dim]
    # torch.div(..., rounding_mode=...) requires a tensor input; this lets the
    # documented float eps (including the default 0.) work too.
    eps = torch.as_tensor(eps)
    # Shift to the left and to the right and interpolate linearly
    q, r = torch.div(eps * n, 1, rounding_mode='floor'), (eps * n) % 1
    q_left, q_right = q / n, (q + 1) / n
    u_left = fourier_shift(u, eps=q_left, dim=dim)
    u_right = fourier_shift(u, eps=q_right, dim=dim)
    return (1 - r) * u_left + r * u_right

def to_coords(x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """
    Transforms the coordinates to a tensor X of shape [time, space, 2].
    Args:
        x: 1D spatial coordinates (length N)
        t: 1D temporal coordinates (length T)
    Returns:
        torch.Tensor: X of shape [T, N, 2] where
                      X[..., 0] is the space coordinate (X[i, j, 0] == x[j])
                      X[..., 1] is the time coordinate  (X[i, j, 1] == t[i])
    """
    # indexing='xy' puts the second argument (t) on axis 0 and x on axis 1, giving
    # the [time, space] layout directly. Passing it explicitly also avoids torch's
    # meshgrid "indexing argument" deprecation warning.
    x_, t_ = torch.meshgrid(x, t, indexing='xy')
    return torch.stack((x_, t_), -1)


def translate(X, shift):
    """Return ``X + shift`` (out of place; normal broadcasting rules apply)."""
    moved = X + shift
    return moved

def scale(X, scale):
    """Return ``X * scale`` (out of place; normal broadcasting rules apply)."""
    factor = scale
    scaled = X * factor
    return scaled

def rotate(X, theta):
    """
    Rotate 2D points counterclockwise about the origin.

    Args:
        X (torch.Tensor): points with the two coordinates in the last axis, shape [..., 2]
        theta (float or torch.Tensor): rotation angle in radians
    Returns:
        torch.Tensor: rotated points, same shape, dtype and device as ``X``
    """
    # Build the matrix in X's dtype/device so float64 and CUDA inputs work and a
    # python-float theta does not produce a float64 matrix for float32 points.
    angle = torch.as_tensor(theta, dtype=X.dtype, device=X.device)
    c, s = torch.cos(angle), torch.sin(angle)
    matrix = torch.stack((torch.stack((c, -s)), torch.stack((s, c))))
    return X @ matrix.T

def reflect(X, o):
    """Negate every coordinate of ``X`` when ``o > 0``; otherwise return ``X`` itself."""
    return -X if o > 0 else X

class BasicTransform:
    """
    Base class for random augmentations of a ``(graph, u_p, inputs_f)`` sample.

    Subclasses implement ``_apply_transform(graph, u_p, inputs_f)`` and return the
    transformed triple; ``__call__`` decides whether to apply it.
    """
    def __init__(self, always_apply=False, p=1.0):
        """
        Args:
            always_apply (bool): apply the transform on every call, ignoring ``p``
            p (float): otherwise the transform is applied only when
                ``random.random() > p``, i.e. with probability ``1 - p``. With the
                default ``p=1.0`` it is therefore never applied unless
                ``always_apply`` is set.
                NOTE(review): this is the inverse of the common "probability of
                applying" convention — confirm before relying on it.
        """
        self.p = p
        self.always_apply = always_apply

    def _apply_transform(self, graph, u_p, inputs_f):
        """Return the transformed ``(graph, u_p, inputs_f)``; must be overridden."""
        raise NotImplementedError

    def __call__(self, graph, u_p, inputs_f):
        # Draw unconditionally so the RNG stream does not depend on always_apply.
        p = random.random()
        if self.always_apply or (p > self.p):
            return self._apply_transform(graph, u_p, inputs_f)
        return graph, u_p, inputs_f
    
class ComposedTransform(BasicTransform):
    """Apply several transforms in sequence to a ``(graph, u_p, inputs_f)`` sample."""
    def __init__(self, transforms, *args, **kwargs):
        """
        Args:
            transforms: iterable of callables, each taking and returning the
                triple ``(graph, u_p, inputs_f)`` (e.g. BasicTransform subclasses)
            *args, **kwargs: forwarded to BasicTransform (``always_apply``, ``p``)
        """
        super(ComposedTransform, self).__init__(*args, **kwargs)
        self.transforms = transforms

    def _apply_transform(self, graph, u_p, inputs_f):
        # BasicTransform.__call__ passes the three parts separately, and every
        # sub-transform expects them the same way.
        for t in self.transforms:
            graph, u_p, inputs_f = t(graph, u_p, inputs_f)
        return graph, u_p, inputs_f

class Darcy2dTransform(BasicTransform):
    """
    Random rotation augmentation for 2D Darcy samples.

    One angle theta ~ Uniform(0, 2*pi) is drawn per call and applied, about the
    origin, to ``graph.ndata["x"]`` and to columns 0-1 of ``inputs_f.x[0]`` (qf)
    and ``inputs_f.x[1]`` (uD). No value is rescaled.
    """
    def __init__(self,
                 max_space_scale=[0.05, 1.0], 
                 max_value_scale=[0.05, 1.0],
                 always_apply=False,
                 p: float = 1.0):
        super().__init__(always_apply, p)
        # Stored for parity with the other transforms; not used by this one.
        self.max_space_scale = max_space_scale
        self.max_value_scale = max_value_scale

    def _apply_transform(self, graph, u_p, inputs_f):
        # Deep copies keep the caller's sample untouched by the in-place writes below.
        graph, u_p, inputs_f = [copy.deepcopy(item) for item in (graph, u_p, inputs_f)]

        angle = torch.tensor(random.uniform(0, 2 * math.pi))
        xy_cols = [0, 1]

        graph.ndata["x"] = rotate(graph.ndata["x"], angle)
        # index 0: qf, index 1: uD -- only their coordinate columns are rotated
        for k in (0, 1):
            points = inputs_f.x[k]
            points[:, xy_cols] = rotate(points[:, xy_cols], angle)

        return graph, u_p, inputs_f

class Laplace2dnTransform(BasicTransform):
    """
    Random rotate/scale augmentation for 2D Laplace samples with mixed boundary data.

    ``inputs_f.x[0]`` holds boundary rows whose last column flags the type
    (0.0 = Dirichlet, 1.0 = Neumann). Per call: s ~ U(*max_space_scale),
    a ~ U(*max_value_scale), theta ~ U(0, 2*pi). The returned ``inputs_f`` contains
    one tensor: the Dirichlet rows followed by the Neumann rows. Rows with any other
    flag are dropped.
    """
    def __init__(self,
                 max_space_scale=[0.05, 1.0], 
                 max_value_scale=[0.05, 1.0],
                 always_apply=False,
                 p: float = 1.0):
        super().__init__(always_apply, p)
        self.max_space_scale = max_space_scale
        self.max_value_scale = max_value_scale

    def _apply_transform(self, graph, u_p, inputs_f):
        graph, u_p, inputs_f = [copy.deepcopy(item) for item in (graph, u_p, inputs_f)]

        # Boolean-mask indexing copies, so both parts are edited independently below.
        boundary = inputs_f.x[0]
        flag = boundary[:, -1]
        dirichlet = boundary[flag == 0.0]
        neumann = boundary[flag == 1.0]

        s = torch.tensor(random.uniform(*self.max_space_scale))
        a = torch.tensor(random.uniform(*self.max_value_scale))
        theta = torch.tensor(random.uniform(0, 2 * math.pi))

        def move(points):
            # rotate about the origin, then scale coordinates by s
            return scale(rotate(points, theta), s)

        xy = [0, 1]
        graph.ndata["x"] = move(graph.ndata["x"])
        dirichlet[:, xy] = move(dirichlet[:, xy])
        neumann[:, xy] = move(neumann[:, xy])

        if len(neumann) > 0:
            # With Neumann rows present, Dirichlet values and the solution are also
            # multiplied by s (presumably to keep the gradient data valid under
            # x -> s*x — confirm).
            dirichlet[:, 2] = scale(dirichlet[:, 2], s)
            graph.ndata["y"] = scale(graph.ndata["y"], s)

        # value scale
        graph.ndata["y"] = scale(graph.ndata["y"], a)
        dirichlet[:, 2] = scale(dirichlet[:, 2], a)
        neumann[:, 2] = scale(neumann[:, 2], a)

        inputs_f = MultipleTensors([torch.cat([dirichlet, neumann])])
        return graph, u_p, inputs_f

class Laplace2dTransform(BasicTransform):
    """
    Random rotate/scale augmentation for 2D Laplace samples with Dirichlet data.

    Per call: s ~ U(*max_space_scale), a ~ U(*max_value_scale), theta ~ U(0, 2*pi).
    Node coordinates and the coordinate columns of the Dirichlet rows are rotated by
    theta and scaled by s. Solution values and Dirichlet values are scaled by a.
    The returned ``inputs_f`` holds only the rows of ``inputs_f.x[0]`` whose last
    column is 0.0; every other row or input tensor is dropped.
    """
    def __init__(self,
                 max_space_scale=[0.05, 1.0], 
                 max_value_scale=[0.05, 1.0],
                 always_apply=False,
                 p: float = 1.0):
        super().__init__(always_apply, p)
        self.max_space_scale = max_space_scale
        self.max_value_scale = max_value_scale

    def _apply_transform(self, graph, u_p, inputs_f):
        graph, u_p, inputs_f = [copy.deepcopy(item) for item in (graph, u_p, inputs_f)]

        rows = inputs_f.x[0]
        dirichlet = rows[rows[:, -1] == 0.0]  # mask indexing returns a copy

        s = torch.tensor(random.uniform(*self.max_space_scale))
        a = torch.tensor(random.uniform(*self.max_value_scale))
        theta = torch.tensor(random.uniform(0, 2 * math.pi))

        xy = [0, 1]
        graph.ndata["x"] = scale(rotate(graph.ndata["x"], theta), s)
        dirichlet[:, xy] = scale(rotate(dirichlet[:, xy], theta), s)

        graph.ndata["y"] = scale(graph.ndata["y"], a)
        dirichlet[:, 2] = scale(dirichlet[:, 2], a)

        return graph, u_p, MultipleTensors([dirichlet])


class Laplace2DSTransform(BasicTransform):
    """
    Random centre-rotate-scale augmentation for 2D Laplace samples with a
    Dirichlet input (``inputs_f.x[0]``, "uD") and a Neumann input
    (``inputs_f.x[1]``, "g").

    Coordinates are translated so that the bounding-box centre of the mesh nodes
    sits at the origin. They are then rotated by theta ~ U(0, 2*pi), scaled by
    s ~ U(*max_space_scale), and translated so that centre lands at (0.5, 0.5).
    The values of g are divided by s. Solution, uD and g values are then
    multiplied by a ~ U(*max_value_scale).
    """
    def __init__(self,
                 max_shift=[[0.0, 0.0], [0.0, 0.0]],
                 max_space_scale=[0.05, 1.0], 
                 max_value_scale=[0.05, 1.0],
                 always_apply=False,
                 p: float = 1.0):
        super(Laplace2DSTransform, self).__init__(always_apply, p)
        # NOTE(review): max_shift is stored for interface compatibility but is not
        # used; the translation is always the deterministic centring below.
        self.max_shift = max_shift
        self.max_space_scale = max_space_scale
        self.max_value_scale = max_value_scale

    def _apply_transform(self, graph, u_p, inputs_f):
        graph = copy.deepcopy(graph)
        u_p = copy.deepcopy(u_p)
        inputs_f = copy.deepcopy(inputs_f)

        nodes = graph.ndata["x"]
        rx = [nodes[:, 0].min(), nodes[:, 0].max()]
        ry = [nodes[:, 1].min(), nodes[:, 1].max()]
        shift = torch.tensor([-(rx[0] + rx[1]) * 0.5, -(ry[0] + ry[1]) * 0.5])
        shift_back = torch.tensor([0.5, 0.5])
        space_scale = torch.tensor(random.uniform(*self.max_space_scale))
        value_scale = torch.tensor(random.uniform(*self.max_value_scale))
        theta = torch.tensor(random.uniform(0, 2*math.pi))

        def place(points):
            # centre at the origin -> rotate(X, theta) about the origin (= the
            # bounding-box centre after centring) -> scale -> centre to (0.5, 0.5)
            return translate(scale(rotate(translate(points, shift), theta), space_scale), shift_back)

        xy = [0, 1]
        graph.ndata["x"] = place(graph.ndata["x"])
        # uD
        inputs_f.x[0][:, xy] = place(inputs_f.x[0][:, xy])
        # g (presumably a normal derivative, hence the 1/s factor when x -> s*x)
        inputs_f.x[1][:, xy] = place(inputs_f.x[1][:, xy])
        inputs_f.x[1][:, 2] = scale(inputs_f.x[1][:, 2], 1.0 / space_scale)

        # value scale: the equation is linear, so multiplying u by a multiplies both
        # uD and g by a (as Laplace2dnTransform does for its Neumann values)
        graph.ndata["y"] = scale(graph.ndata["y"], value_scale)
        inputs_f.x[0][:, 2] = scale(inputs_f.x[0][:, 2], value_scale)
        inputs_f.x[1][:, 2] = scale(inputs_f.x[1][:, 2], value_scale)

        return graph, u_p, inputs_f

class Heat2dTransform(BasicTransform):
    """
    Random rotate/scale augmentation for 2D heat-equation samples.

    ``inputs_f.x[0]`` is the initial condition and ``inputs_f.x[1]`` the boundary
    data (last column 0.0 marks Dirichlet rows). Per call: s ~ U(*max_space_scale),
    a ~ U(*max_value_scale), theta ~ U(0, 2*pi). Coordinates are rotated and scaled,
    ``u_p`` is multiplied by s**2, and values (columns 2 onwards) are multiplied
    by a. Only the Dirichlet boundary rows are kept in the returned ``inputs_f``.
    """
    def __init__(self,
                 max_space_scale=[0.05, 1.0], 
                 max_value_scale=[0.05, 1.0],
                 always_apply=False,
                 p: float = 1.0):
        super().__init__(always_apply, p)
        self.max_space_scale = max_space_scale
        self.max_value_scale = max_value_scale

    def _apply_transform(self, graph, u_p, inputs_f):
        graph, u_p, inputs_f = [copy.deepcopy(item) for item in (graph, u_p, inputs_f)]

        initial = inputs_f.x[0]
        boundary = inputs_f.x[1]
        dirichlet = boundary[boundary[:, -1] == 0.0]

        s = torch.tensor(random.uniform(*self.max_space_scale))
        a = torch.tensor(random.uniform(*self.max_value_scale))
        theta = torch.tensor(random.uniform(0, 2 * math.pi))

        xy = [0, 1]
        graph.ndata["x"] = scale(rotate(graph.ndata["x"], theta), s)
        for pts in (dirichlet, initial):
            pts[:, xy] = scale(rotate(pts[:, xy], theta), s)
        # presumably a diffusivity: under x -> s*x it scales with s**2 — confirm
        u_p = scale(u_p, s * s)

        graph.ndata["y"] = scale(graph.ndata["y"], a)
        for pts in (dirichlet, initial):
            pts[:, 2:] = scale(pts[:, 2:], a)

        return graph, u_p, MultipleTensors([initial, dirichlet])

class Helmholtz2dTransform(BasicTransform):
    """
    Random rotate/scale augmentation for 2D Helmholtz samples with Dirichlet data.

    Per call: s ~ U(*max_space_scale), a ~ U(*max_value_scale), theta ~ U(0, 2*pi).
    Coordinates are rotated by theta and scaled by s, ``u_p`` is multiplied by s,
    and solution / Dirichlet values are multiplied by a. Only rows of
    ``inputs_f.x[0]`` whose last column is 0.0 are kept.
    """
    def __init__(self,
                 max_space_scale=[0.05, 1.0], 
                 max_value_scale=[0.05, 1.0],
                 always_apply=False,
                 p: float = 1.0):
        super().__init__(always_apply, p)
        self.max_space_scale = max_space_scale
        self.max_value_scale = max_value_scale

    def _apply_transform(self, graph, u_p, inputs_f):
        graph, u_p, inputs_f = [copy.deepcopy(item) for item in (graph, u_p, inputs_f)]

        rows = inputs_f.x[0]
        dirichlet = rows[rows[:, -1] == 0.0]  # mask indexing returns a copy

        s = torch.tensor(random.uniform(*self.max_space_scale))
        a = torch.tensor(random.uniform(*self.max_value_scale))
        theta = torch.tensor(random.uniform(0, 2 * math.pi))

        xy = [0, 1]
        graph.ndata["x"] = scale(rotate(graph.ndata["x"], theta), s)
        dirichlet[:, xy] = scale(rotate(dirichlet[:, xy], theta), s)
        # NOTE(review): u_p is multiplied by s. If u_p is a wavenumber k, invariance
        # of the Helmholtz equation under x -> s*x would need k / s — confirm what
        # u_p holds.
        u_p = scale(u_p, s)

        graph.ndata["y"] = scale(graph.ndata["y"], a)
        dirichlet[:, 2] = scale(dirichlet[:, 2], a)

        return graph, u_p, MultipleTensors([dirichlet])


class SpaceTranslate:
    def __init__(self, 
                 max_shift:Union[float, Sequence[float]],
                 always_apply = False,
                  p: float = 1.0,
                  space_coord = [0, 1]):
        """
        Instantiate a random translation of point coordinates.
        Shifts are drawn from Uniform(-max_shift/2, max_shift/2) and added to the
        columns ``space_coord`` of the input.
        Args:
            max_shift (float or sequence of float): width of the shift distribution.
                A scalar gives the same shift on every coordinate column; a sequence
                (one entry per column in ``space_coord``) draws each column
                independently.
            always_apply (bool): shift on every call, ignoring ``p``
            p (float): otherwise the shift is applied when ``random.random() >= p``,
                i.e. with probability ``1 - p`` (NOTE(review): inverse of the usual
                "probability of applying" convention — confirm)
            space_coord (list of int): column indices holding spatial coordinates
        """
        self.max_shift = max_shift
        # Copy so instances never share (or mutate) the default list.
        self.space_coord = list(space_coord)
        self.space_dim = len(self.space_coord)
        self.always_apply = always_apply
        self.p = p

    def _apply_transform(self, u, eps):
        """Return a shifted copy of ``u`` (rows = points, columns = features)."""
        if eps is None:
            max_shift = torch.as_tensor(self.max_shift, dtype=torch.get_default_dtype())
            # scalar max_shift -> a single 0-dim draw; sequence -> one draw per column
            eps = max_shift * (torch.rand(max_shift.shape) - 0.5)
        else:
            eps = eps * torch.ones(())

        # Work on a copy so the caller's data is not modified in place.
        if isinstance(u, torch.Tensor):
            u = u.clone()
            eps = eps.to(u.device)
        else:
            u = numpy.array(u, copy=True)

        u[:, self.space_coord] = u[:, self.space_coord] + eps
        return u

    def apply(self, 
              u: numpy.array, 
              eps: Optional[float]=None,
              ):
        """
        Randomly translate the spatial coordinate columns of a point cloud.
        Args:
            u (torch.Tensor or numpy.ndarray): 2D array of points, shape [num_points, features]
            eps (float): shift to add; drawn at random when None
        Returns:
            a shifted copy of ``u``, or ``u`` itself when the shift is not applied
        """
        # Draw unconditionally so the RNG stream does not depend on always_apply.
        p = random.random()
        if self.always_apply or p >= self.p:
            u = self._apply_transform(u, eps)

        return u

    def __call__(self, u, eps: Optional[float]=None, shift: Optional[str]=None):
        """Alias for ``apply`` using the ``g(u, eps=..., shift=...)`` generator convention; ``shift`` is ignored."""
        return self.apply(u, eps)


class Scale:
    def __init__(self, max_scale: float=1.):
        """
        Instantiate scale generator.
        The scale parameter is drawn as eps ~ Uniform(-max_scale/2, max_scale/2).
        Args:
            max_scale (float): width of the scale-parameter distribution
        """
        self.max_scale = max_scale

    def apply(self, sample: torch.Tensor, eps: Optional[float]=None, shift: str='fourier') -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Scaling transformation: u -> exp(2*eps) * u, space -> exp(-eps) * space,
        time -> exp(-3*eps) * time.
        Args:
            sample: pair [u, X] where X[..., 0] is the space and X[..., 1] the time coordinate
            eps (float): scale parameter; drawn at random when None
            shift (str): not used, only for consistency w.r.t. other generators
        Returns:
            Tuple[torch.Tensor, torch.Tensor]: scaled u and coordinates of shape [..., 2]
        """
        u, X = sample

        if eps is None:
            eps = self.max_scale * (torch.rand(()) - 0.5)
        else:
            eps = eps * torch.ones(())

        # Out-of-place: in-place updates of X would corrupt tensors shared with the
        # caller (e.g. when transform_batch is applied).
        space = X[..., 0] * torch.exp(-eps)
        time = X[..., 1] * torch.exp(-3 * eps)
        # stack (not cat) keeps the coordinate axis: [..., 2] instead of [..., 2 * N]
        return (torch.exp(2 * eps) * u, torch.stack((space, time), -1))

    def __call__(self, sample: torch.Tensor, eps: Optional[float]=None, shift: str='fourier') -> Tuple[torch.Tensor, torch.Tensor]:
        """Alias for ``apply`` using the ``g(sample, eps=..., shift=...)`` generator convention."""
        return self.apply(sample, eps=eps, shift=shift)


class Galileo:
    def __init__(self, max_velocity: float=1) -> None:
        """
        Instantiate Galileo generator.
        Boost velocities are drawn from Uniform(-max_velocity, max_velocity), in the
        units implied by the coordinates in X (space / time).
        Args:
            max_velocity (float): maximum absolute boost velocity
        """
        self.max_velocity = max_velocity

    def __call__(self, sample: torch.Tensor, eps: Optional[float]=None, shift: str='fourier') -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Galilean boost.
        Time slice k of u is shifted along the last (spatial) axis by
        d_k = -(eps * t_k) / L (in units of the domain length L = dx * N), and eps
        is subtracted from its values.
        Args:
            sample: pair [u, X]; u has time on axis -2 and space on axis -1, and X
                is the [time, space, 2] grid (X[..., 0] space, X[..., 1] time), as
                built by to_coords
            eps (float): boost velocity; drawn at random when None
            shift (str): 'fourier' (spectral) or 'linear' interpolation of the shift
        Returns:
            Tuple[torch.Tensor, torch.Tensor]: boosted u and the unchanged X
        Raises:
            ValueError: if X is not 3-dimensional or ``shift`` is not recognised
        """
        shifters = {'fourier': fourier_shift, 'linear': linear_shift}
        if shift not in shifters:
            raise ValueError(f"shift must be 'fourier' or 'linear', got {shift!r}")

        u, X = sample

        T = u.shape[-2]
        N = u.shape[-1]
        if len(X.shape) != 3:
            raise ValueError("X.shape should be (time, space, 2) ")
        dx = X[0, 1, 0] - X[0, 0, 0]
        dt = X[1, 0, 1] - X[0, 0, 1]
        t = dt * torch.arange(T)
        L = dx * N

        if eps is None:
            eps = 2 * self.max_velocity * (torch.rand(()) - 0.5)
        else:
            eps = eps * torch.ones(())
        # shift in units of the domain length, one per time step
        d = -(eps * t) / L

        return (shifters[shift](u, eps=d[:, None], dim=-1) - eps, X)


def heat_to_burgers(psi: torch.Tensor, nu: float, L: float) -> torch.Tensor:
    """
    Cole-Hopf-type transformation of a Heat-equation trajectory into Burgers' equation.

    Each row of ``psi`` (last axis = space) is divided by its range ``max - min``.
    The spatial derivative is then taken spectrally with
    ``fourier_shift(psi, order=1)`` and ``-(psi_x / psi) / (2 * nu)`` is returned.
    Args:
        psi (torch.Tensor): Heat-equation trajectory with space on the last axis;
            it is used as a divisor, so presumably it must be non-zero (e.g.
            positive) everywhere — TODO confirm
        nu (float): diffusion coefficient
        L (float): length of spatial domain (currently unused)
    Returns:
        torch.Tensor: Cole-Hopf transformed trajectory, same shape as ``psi``
    """
    psi_range = torch.amax(psi, dim=-1, keepdim=True) - torch.amin(psi, dim=-1, keepdim=True)
    # A spatially constant row has zero range; dividing by it would make the row
    # inf and the FFT output NaN. psi_x / psi is unchanged by a per-row constant
    # factor (fourier_shift is linear), so such rows are simply left unscaled.
    psi_range = torch.where(psi_range == 0, torch.ones_like(psi_range), psi_range)
    psi = psi / psi_range
    psix = fourier_shift(psi, order=1)
    # NOTE(review): the textbook Cole-Hopf map is u = -2 * nu * psi_x / psi with
    # psi_x in physical units (a 1/L factor). This divides by 2 * nu, ignores L, and
    # fourier_shift(order=1) uses a (-2*pi*i*omega) multiplier — confirm the intended
    # convention.
    return -(psix / psi) / (2 * nu)


class Subalgebra:
    def __init__(self, nu: float, alpha: float=0.5):
        """
        Instantiate Burgers A_alpha generator.
        CAVEAT: samples have to be solutions to the Heat equation!!!!
        Args:
            nu: dynamic viscosity coefficient
            alpha: default convex interpolation coefficient in [0, 1]
        """
        self.nu = nu
        self.alpha = alpha

    def __call__(self, heat_sample1: torch.Tensor, heat_sample2: torch.Tensor=None, alpha: float=None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Infinite subalgebra.
        Optionally mixes two trajectories of the Heat equation as
        (1 - alpha) * psi1 + alpha * psi2 and returns the Cole-Hopf transformed
        trajectory of the Burgers' equation.
        Args:
            heat_sample1: pair [psi, X] for the first Heat trajectory; X is the
                [time, space, 2] grid with X[..., 0] the space coordinate
            heat_sample2: optional pair [psi, X] for the second Heat trajectory
            alpha (float): mixing parameter; defaults to ``self.alpha``
        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Burgers trajectory u and the unchanged X
        """
        psi1, X = heat_sample1
        # Spatial domain length, computed as in Galileo (spacing along axis 1 of the
        # space coordinate times the number of spatial points).
        dx = X[0, 1, 0] - X[0, 0, 0]
        L = dx * X.shape[1]

        if heat_sample2 is None:
            return heat_to_burgers(psi1, self.nu, L), X

        # There is no `max_alpha` attribute; the configured default is `self.alpha`.
        if alpha is None:
            alpha = self.alpha
        psi2, _ = heat_sample2
        mixed = (1 - alpha) * psi1 + alpha * psi2
        return heat_to_burgers(mixed, self.nu, L), X


class KdV_augmentation:
    def __init__(self, max_x_shift: float, max_velocity: float, max_scale):
        """
        Instantiate KdV data augmentation.
        Args:
            max_x_shift (float): parameter of sub-pixel space translation
            max_velocity (float): parameter of Galilean transformation
            max_scale (float): parameter of scaling transformation
        """
        self.generators = [
            SpaceTranslate(max_x_shift),
            Galileo(max_velocity),
            Scale(max_scale),
        ]

    def __call__(self, u: torch.Tensor, eps: Optional[torch.Tensor]=None, shift: str='fourier') -> torch.Tensor:
        """
        KdV data augmentation: feed the sample through every generator in order.
        Each generator is invoked as ``g(u, shift=shift)``, or as
        ``g(u, eps=eps[i], shift=shift)`` when ``eps`` is given.
        Args:
            u: input of the form [u, X]
            eps (Optional[torch.Tensor]): one parameter per generator, indexed in generator order
            shift (str): forwarded unchanged to each generator
        Returns:
            the output of the last generator
        """
        for idx, generator in enumerate(self.generators):
            if eps is None:
                u = generator(u, shift=shift)
            else:
                u = generator(u, eps=eps[idx], shift=shift)
        return u

class KS_augmentation:
    def __init__(self, max_x_shift: float, max_velocity: float):
        """
        Instantiate KS data augmentation.
        Args:
            max_x_shift (float): parameter of sub-pixel space translation
            max_velocity (float): parameter of Galilean transformation
        """
        translation = SpaceTranslate(max_x_shift)
        boost = Galileo(max_velocity)
        self.generators = [translation, boost]

    def __call__(self, u: torch.Tensor, shift: str='fourier') -> torch.Tensor:
        """
        KS data augmentation: apply each generator as ``g(u, shift=shift)`` in order.
        Args:
            u: input of the form [u, X]
            shift (str): forwarded unchanged to each generator
        Returns:
            the output of the last generator
        """
        result = u
        for generator in self.generators:
            result = generator(result, shift=shift)
        return result

class Heat_augmentation:
    def __init__(self, max_x_shift: float):
        """
        Instantiate Heat data augmentation (other generators might still be added).
        Args:
            max_x_shift (float): parameter of sub-pixel space translation
        """
        translation = SpaceTranslate(max_x_shift)
        self.generators = [translation]
        # TODO add other generators

    def __call__(self, u: torch.Tensor) -> torch.Tensor:
        """
        Heat data augmentation: apply each generator as ``g(u)`` in order.
        Args:
            u: input of the form [u, X]
        Returns:
            the output of the last generator (``u`` itself if there are none)
        """
        result = u
        for generator in self.generators:
            result = generator(result)
        return result
