from typing import Optional, Tuple, Any
from numbers import Real
from copy import deepcopy

import numpy as np
import torch
from torch import Tensor, nn
from torch.distributions import RelaxedBernoulli, Bernoulli
from torch.autograd import Function

from braindecode.augmentation.base import Transform, Output
from braindecode.augmentation.functionals import time_reverse, sign_flip,\
    fft_surrogate, channel_dropout, add_gaussian_noise,\
    permute_channels, _sample_mask_start, _relaxed_mask_time,\
    _pick_channels_randomly, _make_permutation_matrix, random_bandstop,\
    freq_shift, random_rotation, get_standard_10_20_positions
from braindecode.augmentation.transforms import TimeReverse
from braindecode.augmentation.transforms import SignFlip
from braindecode.augmentation.transforms import FTSurrogate
from braindecode.augmentation.transforms import MissingChannels
from braindecode.augmentation.transforms import ShuffleChannels
from braindecode.augmentation.transforms import GaussianNoise
from braindecode.augmentation.transforms import ChannelSymmetry
from braindecode.augmentation.transforms import TimeMask
from braindecode.augmentation.transforms import FrequencyShift
from braindecode.augmentation.transforms import RandomXRotation
from braindecode.augmentation.transforms import RandomYRotation
from braindecode.augmentation.transforms import RandomZRotation


# Maps the class name of a differentiable transform to the standard transform
# class imported from braindecode.augmentation.transforms.
# NOTE(review): presumably used to swap a learned Diff* transform for its
# non-differentiable counterpart -- confirm against the callers.
# NOTE(review): DiffBandstopFilter (defined below) has no entry here -- confirm
# whether that omission is intentional.
DIFF_TO_STANDARD_MAP = {
    "DiffTimeReverse": TimeReverse,
    "DiffSignFlip": SignFlip,
    "DiffFTSurrogate": FTSurrogate,
    "DiffMissingChannels": MissingChannels,
    "DiffShuffleChannels": ShuffleChannels,
    "DiffGaussianNoise": GaussianNoise,
    "DiffChannelSymmetry": ChannelSymmetry,
    "DiffTimeMask": TimeMask,
    "DiffFrequencyShift": FrequencyShift,
    "DiffRandomXRotation": RandomXRotation,
    "DiffRandomYRotation": RandomYRotation,
    "DiffRandomZRotation": RandomZRotation,
}


class DiffTransform(Transform):
    """ Basic class used for implementing differentiable data augmentations
    (where probability and magnitude can be learned with gradient descent)

    The probability and (when ``has_mag`` is True) the magnitude are stored as
    ``nn.Parameter`` objects and exposed, clamped to [0, 1], through the
    ``probability`` and ``magnitude`` properties.

    As proposed in [1]_

    Code copied and modified from https://github.com/moskomule/dda/

    Parameters
    ----------
    operation : callable
        A function taking arrays X, y (sample features and
        target resp.) and other required arguments, and returning the
        transformed X and y. During training it is called with the extra
        keyword arguments ``random_state`` and ``magnitude``.
    initial_probability : float | None, optional
        Initial value for probability. Float between 0 and 1 defining the
        uniform probability of applying the operation. When set to None, the
        initial probability will be drawn from a uniform distribution. Set to
        None by default.
    initial_magnitude : float | None, optional
        Initial value for the magnitude. Defines the strength of the
        transformation applied between 0 and 1 and depends on the nature of the
        transformation and on its range. Some transformations don't have any
        magnitude. It can be equivalent to another argument of object
        with more meaning. In case both are passed, magnitude will override the
        latter. Defaults to None (uniformly sampled).
    mag_range : tuple of two floats | None, optional
        Valid range of the argument mapped by `magnitude` (e.g. standard
        deviation, number of sample, etc.):
        ```
        argument = magnitude * mag_range[1] + (1 - magnitude) * mag_range[0].
        ```
        If `magnitude` is None it is ignored. Defaults to None.
    temperature : float, optional
        Temperature parameter of the RelaxedBernoulli distribution used to
        decide whether to apply the operation to the input or not. Defaults to
        0.05.
    random_state: int, optional
        Seed to be used to instantiate numpy random number generator instance.
        Used to decide whether or not to transform given the probability
        argument. Also used for initializing the magnitudes and probabilities
        (independently of the forward) Defaults to None.
    *args:
        Arguments to be passed to operation.
    **kwargs:
        Keyword arguments to be passed to operation.

    References
    ----------
    .. [1] Hataya R., Zdenek J., Yoshizoe K., Nakayama H. (2020) Faster
    AutoAugment: Learning Augmentation Strategies Using Backpropagation.
    In: Vedaldi A., Bischof H., Brox T., Frahm JM. (eds) Computer Vision –
    ECCV 2020. ECCV 2020. Lecture Notes in Computer Science, vol 12370.
    Springer, Cham. https://doi.org/10.1007/978-3-030-58595-2_1
    """
    # Subclasses set this to False when the operation has no magnitude; no
    # magnitude parameter is then created.
    has_mag = True

    def __init__(self, operation, initial_probability=None,
                 initial_magnitude=None, mag_range=None, temperature=0.05,
                 random_state=None, *args, **kwargs):
        # NOTE(review): self.rng, self.operation, self.args, self.kwargs and
        # self.mag_range are read below but never assigned in this class;
        # presumably the parent constructor sets them -- confirm.
        super().__init__(operation=operation, probability=initial_probability,
                         magnitude=initial_magnitude, mag_range=mag_range,
                         random_state=random_state, *args, **kwargs)

        # Important to separate the init RNG from the forward one, so that
        # one can init a new policy and load a checkpoint without modifying
        # the forward rng state
        self.init_rng = deepcopy(self.rng)
        if initial_probability is None:
            initial_probability = self.init_rng.uniform()
        if initial_magnitude is None:
            # Drawn even when has_mag is False, so the init RNG is consumed
            # the same way for every transform.
            initial_magnitude = self.init_rng.uniform()
        self.initial_probability = initial_probability
        self.initial_magnitude = initial_magnitude
        self.temperature = temperature
        self._perturb_params = False
        self._pertub_range = 0.1
        self._probability = nn.Parameter(
            torch.empty(1).fill_(initial_probability)
        )
        if self.has_mag:
            self._magnitude = nn.Parameter(
                torch.empty(1).fill_(initial_magnitude)
            )
        else:
            self._magnitude = None

    def forward(self, X: Tensor, y: Tensor) -> Output:
        """Apply the operation to a batch.

        In training mode the operation is applied to a copy of the whole
        batch and blended with the untouched input using a relaxed Bernoulli
        mask, so that gradients can reach the probability (and magnitude).
        In evaluation mode the call is delegated to the parent class.

        Parameters
        ----------
        X : torch.Tensor
            Batch of inputs; its first dimension is used as the batch size.
        y : torch.Tensor
            Corresponding targets.

        Returns
        -------
        Output
            The blended inputs and the targets returned by the operation.
        """
        if self.training:
            mask = self._get_mask(X.shape[0]).to(X.device)
            magnitude = self.magnitude
            if magnitude is not None:
                magnitude = magnitude.to(X.device)
            tr_X, tr_y = self.operation(
                X.clone(), y.clone(), *self.args,
                random_state=self.rng, magnitude=magnitude, **self.kwargs
            )
            # Soft selection between transformed and original samples.
            return mask * tr_X + (1 - mask) * X, tr_y
        else:
            return super().forward(X, y)

    def _get_mask(self, batch_size=None) -> torch.Tensor:
        """Sample one mask value per batch element.

        The sample shape is (batch_size, 1); since the probability parameter
        itself has one element, the returned tensor has shape
        (batch_size, 1, 1).

        In training mode the mask is a reparametrized RelaxedBernoulli sample
        (differentiable w.r.t. the probability); otherwise it is a hard
        Bernoulli sample.
        """
        size = (batch_size, 1)
        if self.training:
            # NOTE: the sampling below goes through torch's global random
            # state, not through self.rng, so reproducibility requires
            # seeding torch as well. Re-implementing the relaxed Bernoulli
            # on top of self.rng would be an alternative.
            mask = RelaxedBernoulli(
                temperature=self.temperature, probs=self.probability
            ).rsample(size)
            return mask
        else:
            return Bernoulli(probs=self.probability).sample(size)

    def perturbation_on(self, eps=None):
        """Enables probability and magnitude perturbation by sampling a delta
        at each new call of both properties

        Parameters
        ----------
        eps : float | None, optional
            When different from None, sets the range so that perturbations
            are sampled uniformly in [-`eps`, `eps`], by default None (the
            current range is kept).
        """
        self._perturb_params = True
        if eps is not None:
            self._pertub_range = eps

    def perturbation_off(self):
        """Disables probability and magnitude perturbation
        """
        self._perturb_params = False

    def _perturb_and_clamp(self, param: torch.Tensor) -> torch.Tensor:
        """Return `param` clamped to [0, 1], after adding a uniform delta in
        [-range, +range] when perturbation is enabled.
        """
        if self._perturb_params:
            # The upper bound must be +range: with both bounds set to -range
            # the "random" delta would be a constant negative shift.
            param = param + self.rng.uniform(
                low=-self._pertub_range,
                high=self._pertub_range,
            )
        # An alternative to clamping could be a sigmoid such as
        # torch.sigmoid((param - 0.5) * 5), but it has its drawbacks too.
        return param.clamp(0, 1)

    @property
    def probability(self) -> torch.Tensor:
        """Current (possibly perturbed) probability, clamped to [0, 1]."""
        return self._perturb_and_clamp(self._probability)

    @property
    def magnitude(self) -> Optional[torch.Tensor]:
        """Current (possibly perturbed) magnitude, clamped to [0, 1], or None
        when the transform has no magnitude.
        """
        if self._magnitude is None:
            return None
        return self._perturb_and_clamp(self._magnitude)


class _STE(Function):
    """ Straight-through gradient estimator

    Copied from https://github.com/moskomule/dda/
    """

    @staticmethod
    def forward(ctx,
                input_forward: torch.Tensor,
                input_backward: torch.Tensor) -> torch.Tensor:
        ctx.shape = input_backward.shape
        return input_forward

    @staticmethod
    def backward(ctx,
                 grad_in: torch.Tensor) -> Tuple[None, torch.Tensor]:
        return None, grad_in.sum_to_size(ctx.shape)


def ste(input_forward: torch.Tensor,
        input_backward: torch.Tensor) -> torch.Tensor:
    """ Apply the straight-through gradient estimator

    Useful to let gradients flow through a non-differentiable function: the
    returned tensor is a clone of `input_forward`, while gradients are
    routed to `input_backward` by `_STE`.

    Copied from https://github.com/moskomule/dda/
    """
    straight_through = _STE.apply(input_forward, input_backward)
    return straight_through.clone()


class DiffTimeReverse(DiffTransform):
    """ Flip the time axis of each feature sample with a given probability

    Parameters
    ----------
    initial_probability : float | None, optional
        Initial value for probability. Float between 0 and 1 defining the
        uniform probability of applying the operation. Defaults to None, in
        which case the parent class draws it at random.
    initial_magnitude : object, optional
        Always ignored, exists for compatibility.
    mag_range : object, optional
        Always ignored, exists for compatibility.
    temperature : float, optional
        Temperature parameter of the RelaxedBernoulli distribution used to
        decide whether to apply the operation to the input or not. Defaults to
        0.05.
    random_state: int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Used to decide whether or not to transform given the probability
        argument. Defaults to None.
    """

    # This transform has no magnitude parameter.
    has_mag = False

    def __init__(
        self,
        initial_probability=None,
        initial_magnitude=None,
        mag_range=None,
        temperature=0.05,
        random_state=None
    ):
        # initial_magnitude and mag_range are deliberately not forwarded.
        parent_kwargs = dict(
            operation=time_reverse,
            initial_probability=initial_probability,
            temperature=temperature,
            random_state=random_state,
        )
        super().__init__(**parent_kwargs)


class DiffSignFlip(DiffTransform):
    """ Flip the sign axis of each feature sample with a given probability

    Parameters
    ----------
    initial_probability : float | None, optional
        Initial value for probability. Float between 0 and 1 defining the
        uniform probability of applying the operation. Defaults to None, in
        which case the parent class draws it at random.
    initial_magnitude : object, optional
        Always ignored, exists for compatibility.
    mag_range : object, optional
        Always ignored, exists for compatibility.
    temperature : float, optional
        Temperature parameter of the RelaxedBernoulli distribution used to
        decide whether to apply the operation to the input or not. Defaults to
        0.05.
    random_state: int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Used to decide whether or not to transform given the probability
        argument. Defaults to None.
    """

    # This transform has no magnitude parameter.
    has_mag = False

    def __init__(
        self,
        initial_probability=None,
        initial_magnitude=None,
        mag_range=None,
        temperature=0.05,
        random_state=None
    ):
        # initial_magnitude and mag_range are deliberately not forwarded.
        parent_kwargs = dict(
            operation=sign_flip,
            initial_probability=initial_probability,
            temperature=temperature,
            random_state=random_state,
        )
        super().__init__(**parent_kwargs)


class DiffFTSurrogate(DiffTransform):
    """ FT surrogate augmentation of a single EEG channel, as proposed in [1]_

    Parameters
    ----------
    initial_probability : float | None, optional
        Initial value for probability. Float between 0 and 1 defining the
        uniform probability of applying the operation. Defaults to None, in
        which case the parent class draws it at random.
    initial_magnitude : float | None, optional
        Initial value for the magnitude. Defines the strength of the
        transformation applied between 0 and 1 and depends on the nature of the
        transformation and on its range. Defaults to None, in which case the
        parent class draws it at random.
    mag_range : object, optional
        Always ignored, exists for compatibility.
    temperature : float, optional
        Temperature parameter of the RelaxedBernoulli distribution used to
        decide whether to apply the operation to the input or not. Defaults to
        0.05.
    random_state: int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Used to decide whether or not to transform given the probability
        argument. Defaults to None.

    References
    ----------
    .. [1] Schwabedal, J. T., Snyder, J. C., Cakmak, A., Nemati, S., &
       Clifford, G. D. (2018). Addressing Class Imbalance in Classification
       Problems of Noisy Signals by using Fourier Transform Surrogates. arXiv
       preprint arXiv:1806.08675.
    """

    def __init__(
        self,
        initial_probability=None,
        initial_magnitude=None,
        mag_range=None,
        temperature=0.05,
        random_state=None
    ):
        # mag_range is deliberately not forwarded.
        parent_kwargs = dict(
            operation=fft_surrogate,
            initial_probability=initial_probability,
            initial_magnitude=initial_magnitude,
            temperature=temperature,
            random_state=random_state,
        )
        super().__init__(**parent_kwargs)


class DiffMissingChannels(DiffTransform):
    """ Randomly set channels to flat signal

    Part of the CMSAugment policy proposed in [1]_

    Parameters
    ----------
    initial_probability : float | None, optional
        Initial value for probability. Float between 0 and 1 defining the
        uniform probability of applying the operation. Defaults to None, in
        which case the parent class draws it at random.
    initial_magnitude : float | None, optional
        Initial value for the magnitude, between 0 and 1. It is handed to
        `channel_dropout` as its third positional argument. Defaults to None,
        in which case the parent class draws it at random.
    mag_range : object, optional
        Always ignored, exists for compatibility.
    temperature : float, optional
        Temperature parameter of the RelaxedBernoulli distribution used to
        decide whether to apply the operation to the input or not. Defaults to
        0.05.
    random_state: int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Used to decide whether or not to transform given the probability
        argument and to sample channels to erase. Defaults to None.

    References
    ----------
    .. [1] Saeed, A., Grangier, D., Pietquin, O., & Zeghidour, N. (2020).
       Learning from Heterogeneous EEG Signals with Differentiable Channel
       Reordering. arXiv preprint arXiv:2010.13694.
    """

    def __init__(
        self,
        initial_probability=None,
        initial_magnitude=None,
        mag_range=None,
        temperature=0.05,
        random_state=None
    ):
        # mag_range is deliberately not forwarded.
        super().__init__(
            operation=self._channel_dropout,
            initial_probability=initial_probability,
            initial_magnitude=initial_magnitude,
            temperature=temperature,
            random_state=random_state
        )

    def _channel_dropout(self, X, y, magnitude, random_state, *args, **kwargs):
        """Adapter calling `channel_dropout` with the magnitude and random
        state in positional form.
        """
        # Keyword arguments must be forwarded with ** (a single * would
        # unpack the dictionary keys as extra positional arguments).
        return channel_dropout(X, y, magnitude, random_state, *args, **kwargs)


class DiffShuffleChannels(DiffTransform):
    """ Randomly shuffle channels in EEG data matrix

    Part of the CMSAugment policy proposed in [1]_

    Parameters
    ----------
    initial_probability : float | None, optional
        Initial value for probability. Float between 0 and 1 defining the
        uniform probability of applying the operation. Defaults to None, in
        which case the parent class draws it at random.
    initial_magnitude : float | None, optional
        Initial value for the magnitude, between 0 and 1. `1 - magnitude` is
        what gets passed to the random channel picking. Defaults to None, in
        which case the parent class draws it at random.
    mag_range : object, optional
        Always ignored, exists for compatibility.
    temperature : float, optional
        Temperature parameter of the RelaxedBernoulli distribution used to
        decide whether to apply the operation to the input or not. Defaults to
        0.05.
    random_state: int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Used to decide whether or not to transform given the probability
        argument. Defaults to None.

    References
    ----------
    .. [1] Saeed, A., Grangier, D., Pietquin, O., & Zeghidour, N. (2020).
       Learning from Heterogeneous EEG Signals with Differentiable Channel
       Reordering. arXiv preprint arXiv:2010.13694.
    """

    def __init__(
        self,
        initial_probability=None,
        initial_magnitude=None,
        mag_range=None,
        temperature=0.05,
        random_state=None
    ):
        # mag_range is deliberately not forwarded.
        parent_kwargs = dict(
            operation=self._channel_shuffle,
            initial_probability=initial_probability,
            initial_magnitude=initial_magnitude,
            temperature=temperature,
            random_state=random_state,
        )
        super().__init__(**parent_kwargs)

    def _channel_shuffle(self, X, y, magnitude, random_state, *args, **kwargs):
        """Multiply X by per-sample permutation matrices built from a random
        channel mask; `ste` routes the gradient to `magnitude`.
        """
        channel_mask = _pick_channels_randomly(X, 1-magnitude, random_state)
        hard_permutations = _make_permutation_matrix(
            X, channel_mask, random_state
        )
        batch_permutations = ste(hard_permutations, magnitude)
        shuffled_X = torch.matmul(batch_permutations, X)
        return shuffled_X, y


class DiffGaussianNoise(DiffTransform):
    """Randomly add white noise to all channels

    Suggested e.g. in [1]_, [2]_ and [3]_

    Parameters
    ----------
    initial_probability : float | None, optional
        Initial value for probability. Float between 0 and 1 defining the
        uniform probability of applying the operation. Defaults to None, in
        which case the parent class draws it at random.
    initial_magnitude : float | None, optional
        Initial value for magnitude. Float between 0 and 1 encoding the
        standard deviation to use for the additive noise:
        ```
        std = magnitude * mag_range[1] + (1 - magnitude) * mag_range[0]
        ```
        Defaults to None, in which case the parent class draws it at random.
    mag_range : tuple of two floats | None, optional
        Std range when set using the magnitude (see `magnitude`).
        If omitted, the range (0, 0.2) will be used.
    random_state: int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Defaults to None.
    temperature : float, optional
        Temperature parameter of the RelaxedBernoulli distribution used to
        decide whether to apply the operation to the input or not. Defaults to
        0.05. Listed last (unlike in the sibling transforms) so that existing
        positional calls keep working.

    References
    ----------
    .. [1] Wang, F., Zhong, S. H., Peng, J., Jiang, J., & Liu, Y. (2018). Data
       augmentation for eeg-based emotion recognition with deep convolutional
       neural networks. In International Conference on Multimedia Modeling
       (pp. 82-93).
    .. [2] Cheng, J. Y., Goh, H., Dogrusoz, K., Tuzel, O., & Azemi, E. (2020).
       Subject-aware contrastive learning for biosignals. arXiv preprint
       arXiv:2007.04871.
    .. [3] Mohsenvand, M. N., Izadi, M. R., & Maes, P. (2020). Contrastive
       Representation Learning for Electroencephalogram Classification. In
       Machine Learning for Health (pp. 238-253). PMLR.

    """

    def __init__(
        self,
        initial_probability=None,
        initial_magnitude=None,
        mag_range=(0, 0.2),
        random_state=None,
        temperature=0.05,
    ):
        # temperature was documented but missing from the signature; it is
        # now accepted and forwarded like in the other Diff* transforms.
        super().__init__(
            operation=self._add_gaussian_noise,
            initial_probability=initial_probability,
            initial_magnitude=initial_magnitude,
            mag_range=mag_range,
            temperature=temperature,
            random_state=random_state,
        )

    def _add_gaussian_noise(self, X, y, magnitude, *args, **kwargs):
        """Map the magnitude onto `mag_range` to get the noise std and call
        `add_gaussian_noise` with it.
        """
        min_val, max_val = self.mag_range
        # Linear interpolation keeps std differentiable w.r.t. magnitude.
        std = magnitude * max_val + (1 - magnitude) * min_val
        return add_gaussian_noise(X, y, std=std, *args, **kwargs)


class DiffChannelSymmetry(DiffTransform):
    """Permute EEG channels inverting left and right-side sensors

    Suggested e.g. in [1]_

    Parameters
    ----------
    ordered_ch_names : list
        Ordered list of strings containing the names (in 10-20
        nomenclature) of the EEG channels that will be transformed. The
        first name should correspond the data in the first row of X, the
        second name in the second row and so on.
    initial_probability : float | None, optional
        Initial value for probability. Float between 0 and 1 defining the
        uniform probability of applying the operation. Defaults to None, in
        which case the parent class draws it at random.
    initial_magnitude : object, optional
        Always ignored, exists for compatibility.
    mag_range : object, optional
        Always ignored, exists for compatibility.
    temperature : float, optional
        Temperature parameter of the RelaxedBernoulli distribution used to
        decide whether to apply the operation to the input or not. Defaults to
        0.05.
    random_state: int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Used to decide whether or not to transform given the probability
        argument. Defaults to None.

    References
    ----------
    .. [1] Deiss, O., Biswal, S., Jin, J., Sun, H., Westover, M. B., & Sun, J.
       (2018). HAMLET: interpretable human and machine co-learning technique.
       arXiv preprint arXiv:1803.09702.

    """

    # This transform has no magnitude parameter.
    has_mag = False

    def __init__(
        self,
        ordered_ch_names,
        initial_probability=None,
        initial_magnitude=None,
        mag_range=None,
        temperature=0.05,
        random_state=None
    ):
        assert (
            isinstance(ordered_ch_names, list) and
            all(isinstance(ch, str) for ch in ordered_ch_names)
        ), "ordered_ch_names should be a list of str."
        self.ordered_ch_names = ordered_ch_names

        # permutation[i] is the index of the channel mirrored from channel i,
        # or i itself when no mirrored channel is found in the list.
        permutation = list()
        for idx, ch_name in enumerate(ordered_ch_names):
            new_position = idx
            # Find digits in channel name (assuming 10-20 system)
            d = ''.join(list(filter(str.isdigit, ch_name)))
            if len(d) > 0:
                d = int(d)
                if d % 2 == 0:  # even/right electrodes
                    sym = d - 1
                else:  # odd/left electrodes
                    sym = d + 1
                new_channel = ch_name.replace(str(d), str(sym))
                if new_channel in ordered_ch_names:
                    new_position = ordered_ch_names.index(new_channel)
            permutation.append(new_position)

        # temperature used to be accepted but dropped here, which silently
        # forced the parent's default value; it is now forwarded.
        super().__init__(
            operation=permute_channels,
            initial_probability=initial_probability,
            temperature=temperature,
            permutation=permutation,
            random_state=random_state,
        )


class DiffTimeMask(DiffTransform):
    """Replace part of all channels by zeros

    Suggested e.g. in [1]_ and [2]_
    Similar to the time variant of SpecAugment for speech signals [3]_

    Parameters
    ----------
    initial_probability : float | None, optional
        Initial value for probability. Float between 0 and 1 defining the
        uniform probability of applying the operation. Defaults to None, in
        which case the parent class draws it at random.
    initial_magnitude : float | None, optional
        Initial value for the magnitude. Float between 0 and 1 encoding the
        length of the mask within `mag_range` (kept continuous, no rounding
        is applied in this class):
        ```
        mask_len_samples = magnitude * mag_range[1] +
            (1 - magnitude) * mag_range[0]
        ```
        Defaults to None, in which case the parent class draws it at random.
    mag_range : tuple of two floats | None, optional
        Range of possible values for `mask_len_samples` settable using the
        magnitude (see `magnitude`). If omitted, the range (0, 100) samples
        will be used.
    temperature : float, optional
        Temperature parameter of the RelaxedBernoulli distribution used to
        decide whether to apply the operation to the input or not. Defaults to
        0.05.
    random_state: int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Defaults to None.

    References
    ----------
    .. [1] Cheng, J. Y., Goh, H., Dogrusoz, K., Tuzel, O., & Azemi, E. (2020).
       Subject-aware contrastive learning for biosignals. arXiv preprint
       arXiv:2007.04871.
    .. [2] Mohsenvand, M. N., Izadi, M. R., & Maes, P. (2020). Contrastive
       Representation Learning for Electroencephalogram Classification. In
       Machine Learning for Health (pp. 238-253). PMLR.
    .. [3] Park, D.S., Chan, W., Zhang, Y., Chiu, C., Zoph, B., Cubuk, E.D.,
       Le, Q.V. (2019) SpecAugment: A Simple Data Augmentation Method for
       Automatic Speech Recognition. Proc. Interspeech 2019, 2613-2617

    """

    def __init__(
        self,
        initial_probability=None,
        initial_magnitude=None,
        mag_range=(0, 100),
        temperature=0.05,
        random_state=None
    ):
        parent_kwargs = dict(
            operation=self._mask_time,
            initial_probability=initial_probability,
            initial_magnitude=initial_magnitude,
            mag_range=mag_range,
            temperature=temperature,
            random_state=random_state,
        )
        super().__init__(**parent_kwargs)

    def _mask_time(self, X, y, magnitude, random_state, *args, **kwargs):
        """Map the magnitude onto `mag_range` to get the mask length, sample
        one mask start per sample and apply the relaxed time mask.
        """
        lower, upper = self.mag_range
        mask_len_samples = magnitude * upper + (1 - magnitude) * lower
        starts = _sample_mask_start(X, mask_len_samples, random_state)
        return _relaxed_mask_time(X, starts, mask_len_samples), y


class DiffBandstopFilter(DiffTransform):
    """Applies a stopband filter with desired bandwidth at a randomly selected
    frequency position between 0 and `max_freq`.

    Suggested e.g. in [1]_ and [2]_
    Similar to the frequency variant of SpecAugment for speech signals [3]_

    Parameters
    ----------
    initial_probability : float | None, optional
        Initial value for probability. Float between 0 and 1 defining the
        uniform probability of applying the operation. Defaults to None, in
        which case the parent class draws it at random.
    initial_magnitude : float | None, optional
        Initial value for the magnitude. Float between 0 and 1 encoding the
        bandwidth of the filter, i.e. distance between the low and high cut
        frequencies:
        ```
        bandwidth = magnitude * mag_range[1] + (1 - magnitude) * mag_range[0]
        ```
        Defaults to None, in which case the parent class draws it at random.
    mag_range : tuple of two floats | None, optional
        Range of possible values for `bandwidth` settable using the magnitude
        (see `magnitude`). If omitted, the range (0, 2 Hz) will be used.
    temperature : float, optional
        Temperature parameter of the RelaxedBernoulli distribution used to
        decide whether to apply the operation to the input or not. Defaults to
        0.05.
    sfreq : float, optional
        Sampling frequency of the signals to be filtered. Defaults to 100 Hz.
    max_freq : float | None, optional
        Maximal admissible frequency. The low cut frequency will be sampled so
        that the corresponding high cut frequency + transition are below
        `max_freq`. Defaults to 50 Hz. When `None`, or when larger than the
        Nyquist frequency (`sfreq / 2`), the Nyquist frequency is used.
    random_state: int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Defaults to None.

    References
    ----------
    .. [1] Cheng, J. Y., Goh, H., Dogrusoz, K., Tuzel, O., & Azemi, E. (2020).
       Subject-aware contrastive learning for biosignals. arXiv preprint
       arXiv:2007.04871.
    .. [2] Mohsenvand, M. N., Izadi, M. R., & Maes, P. (2020). Contrastive
       Representation Learning for Electroencephalogram Classification. In
       Machine Learning for Health (pp. 238-253). PMLR.
    .. [3] Park, D.S., Chan, W., Zhang, Y., Chiu, C., Zoph, B., Cubuk, E.D.,
       Le, Q.V. (2019) SpecAugment: A Simple Data Augmentation Method for
       Automatic Speech Recognition. Proc. Interspeech 2019, 2613-2617
    """

    def __init__(
        self,
        initial_probability=None,
        initial_magnitude=None,
        mag_range=(0, 2),
        temperature=0.05,
        sfreq=100,
        max_freq=50,
        random_state=None
    ):
        assert isinstance(sfreq, Real) and sfreq > 0,\
            "sfreq should be a positive float."
        # None is a documented value (meaning "use Nyquist") and must not be
        # rejected by the validation.
        assert max_freq is None or (
            isinstance(max_freq, Real) and max_freq > 0
        ), "max_freq should be a positive float."
        nyq = sfreq / 2
        if max_freq is None or max_freq > nyq:
            max_freq = nyq

        self.sfreq = sfreq
        self.max_freq = max_freq

        # max_freq and sfreq are not parameters of DiffTransform.__init__, so
        # they fall into its **kwargs ("keyword arguments to be passed to
        # operation").
        super().__init__(
            operation=self._random_bandstop,
            initial_probability=initial_probability,
            initial_magnitude=initial_magnitude,
            mag_range=mag_range,
            temperature=temperature,
            max_freq=self.max_freq,
            sfreq=self.sfreq,
            random_state=random_state,
        )

    def _random_bandstop(self, X, y, magnitude, *args, **kwargs):
        """Map the magnitude onto `mag_range` to get the bandwidth, filter X
        and return the `(transformed_X, y)` pair expected from an operation.
        """
        min_val, max_val = self.mag_range
        bandwidth = magnitude * max_val + (1 - magnitude) * min_val
        # NOTE(review): assumes random_bandstop returns (X, y) like the other
        # functionals used in this file -- confirm. The straight-through
        # estimator is applied to the filtered signal only, so that a pair
        # is still returned to DiffTransform.forward.
        filtered_X, out_y = random_bandstop(
            X, y, *args, bandwidth=bandwidth, **kwargs
        )
        return ste(filtered_X, magnitude), out_y


class DiffFrequencyShift(DiffTransform):
    """Add a random shift in the frequency domain to all channels.

    Parameters
    ----------
    initial_probability : float, optional
        Initial value for probability. Float between 0 and 1 defining the
        uniform probability of applying the operation. Set to 0.5 by default.
    initial_magnitude : float, optional
        Initial value for the magnitude. Float between 0 and 1 encoding the
        `max_shift` parameter:
        ```
        max_shift = magnitude * mag_range[1] + (1 - magnitude) * mag_range[0]
        ```
        Random frequency shifts will be sampled uniformly in the interval
        `[0, max_shift]`. Defaults to 0.5.
    mag_range : tuple of two floats | None, optional
        Range of possible values for `max_shift` settable using the magnitude
        (see `initial_magnitude`). If omitted, the range (0 Hz, 5 Hz) will be
        used.
    temperature : float, optional
        Temperature parameter of the RelaxedBernoulli distribution used to
        decide whether to apply the operation to the input or not. Defaults to
        0.05.
    sfreq : float, optional
        Sampling frequency of the signals to be transformed. Defaults to
        100 Hz.
    random_state: int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Defaults to None.
    """

    def __init__(
        self,
        initial_probability=None,
        initial_magnitude=None,
        mag_range=(0, 5),
        temperature=0.05,
        sfreq=100,
        random_state=None
    ):
        assert isinstance(sfreq, Real) and sfreq > 0, (
            "sfreq should be a positive float."
        )
        self.sfreq = sfreq

        # Everything the parent constructor needs, in one place.
        parent_kwargs = dict(
            operation=self._freq_shift,
            initial_probability=initial_probability,
            initial_magnitude=initial_magnitude,
            mag_range=mag_range,
            temperature=temperature,
            sfreq=self.sfreq,
            random_state=random_state,
        )
        super().__init__(**parent_kwargs)

    def _freq_shift(self, X, y, magnitude, *args, **kwargs):
        """Shift ``X`` in frequency by at most the magnitude-encoded amount.

        ``max_shift`` is the linear interpolation between the two bounds of
        ``self.mag_range`` and is passed on to ``freq_shift``.
        """
        low, high = self.mag_range
        max_shift = magnitude * high + (1 - magnitude) * low
        shifted = freq_shift(X, y, *args, max_shift=max_shift, **kwargs)
        return shifted


class DiffRandomSensorsRotation(DiffTransform):
    """Interpolates EEG signals over sensors rotated around the desired axis
    with an angle sampled uniformly between 0 and `max_degree`.

    Suggested in [1]_

    Parameters
    ----------
    sensors_positions_matrix : numpy.ndarray | list | torch.Tensor
        Matrix giving the positions of each sensor in a 3D cartesian
        coordinate system. Should have shape (3, n_channels), where
        n_channels is the number of channels. Arrays and lists are converted
        to a tensor. Standard 10-20 positions can be obtained from `mne`
        through:
        ```
        ten_twenty_montage = mne.channels.make_standard_montage(
            'standard_1020'
        ).get_positions()['ch_pos']
        ```
    initial_probability : float, optional
        Initial value for probability. Float between 0 and 1 defining the
        uniform probability of applying the operation. Set to 0.5 by default.
    initial_magnitude : float, optional
        Initial value for the magnitude. Float between 0 and 1 encoding the
        `max_degree` parameter:
        ```
        max_degree = magnitude * mag_range[1] + (1 - magnitude) * mag_range[0]
        ```
        Defaults to 0.5.
    mag_range : tuple of two floats | None, optional
        Range of possible values for `max_degree` settable using the magnitude
        (see `initial_magnitude`). If omitted, the range (0, 30 degrees) will
        be used.
    temperature : float, optional
        Temperature parameter of the RelaxedBernoulli distribution used to
        decide whether to apply the operation to the input or not. Defaults to
        0.05.
    axis : 'x' | 'y' | 'z', optional
        Axis around which to rotate. Defaults to 'z'.
    spherical_splines : bool, optional
        Whether to use spherical splines for the interpolation or not. When
        `False`, standard scipy.interpolate.Rbf (with quadratic kernel) will be
        used (as in the original paper). Defaults to True.
    random_state: int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Defaults to None.

    References
    ----------
    .. [1] Krell, M. M., & Kim, S. K. (2017). Rotational data augmentation for
       electroencephalographic data. In 2017 39th Annual International
       Conference of the IEEE Engineering in Medicine and Biology Society
       (EMBC) (pp. 471-474).
    """

    def __init__(
        self,
        sensors_positions_matrix,
        initial_probability=None,
        initial_magnitude=None,
        mag_range=(0, 30),
        temperature=0.05,
        axis='z',
        spherical_splines=True,
        random_state=None
    ):
        # Accept array-likes for convenience; everything downstream works on
        # a tensor.
        if isinstance(sensors_positions_matrix, (np.ndarray, list)):
            sensors_positions_matrix = torch.as_tensor(
                sensors_positions_matrix
            )
        assert isinstance(sensors_positions_matrix, torch.Tensor),\
            "sensors_positions should be an Tensor"
        assert isinstance(axis, str) and axis in ['x', 'y', 'z'],\
            "axis can be either x, y or z."
        assert sensors_positions_matrix.shape[0] == 3,\
            "sensors_positions_matrix shape should be 3 x n_channels."
        assert isinstance(spherical_splines, bool),\
            "spherical_splines should be a boolean"
        self.sensors_positions_matrix = sensors_positions_matrix
        self.axis = axis
        self.spherical_splines = spherical_splines

        super().__init__(
            operation=self._random_rotation,
            initial_probability=initial_probability,
            initial_magnitude=initial_magnitude,
            mag_range=mag_range,
            # Forward the user-supplied temperature; it used to be accepted
            # by this constructor but never handed to the base class.
            temperature=temperature,
            axis=self.axis,
            sensors_positions_matrix=self.sensors_positions_matrix,
            spherical_splines=self.spherical_splines,
            random_state=random_state
        )

    def _random_rotation(self, X, y, magnitude, *args, **kwargs):
        """Rotate the sensors by at most the magnitude-encoded angle.

        ``max_degrees`` is the linear interpolation between the two bounds
        of ``self.mag_range`` and is passed on to ``random_rotation``, whose
        ``(X, y)`` output is returned unchanged.
        """
        min_val, max_val = self.mag_range
        max_degrees = magnitude * max_val + (1 - magnitude) * min_val
        rotated_X, y = random_rotation(
            X, y, *args, max_degrees=max_degrees, **kwargs)
        return rotated_X, y


class DiffRandomZRotation(DiffRandomSensorsRotation):
    """Interpolates EEG signals over sensors rotated around the Z axis
    with an angle sampled uniformly between 0 and `max_degree`.

    Suggested in [1]_

    Parameters
    ----------
    ordered_ch_names : list
        List of strings representing the channels of the montage considered.
        Has to be in standard 10-20 style. The order has to be consistent with
        the order of channels in the input matrices that will be fed to the
        transform. This list will be used to compute approximate sensors
        positions from a standard 10-20 montage.
    initial_probability : float, optional
        Initial value for probability. Float between 0 and 1 defining the
        uniform probability of applying the operation. Set to 0.5 by default.
    initial_magnitude : float, optional
        Initial value for the magnitude. Float between 0 and 1 encoding the
        `max_degree` parameter:
        ```
        max_degree = magnitude * mag_range[1] + (1 - magnitude) * mag_range[0]
        ```
        Defaults to 0.5.
    mag_range : tuple of two floats | None, optional
        Range of possible values for `max_degree` settable using the magnitude
        (see `initial_magnitude`). If omitted, the range (0, 30 degrees) will
        be used.
    temperature : float, optional
        Temperature parameter of the RelaxedBernoulli distribution used to
        decide whether to apply the operation to the input or not. Defaults to
        0.05.
    spherical_splines : bool, optional
        Whether to use spherical splines for the interpolation or not. When
        `False`, standard scipy.interpolate.Rbf (with quadratic kernel) will be
        used (as in the original paper). Defaults to True.
    random_state: int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Defaults to None.

    References
    ----------
    .. [1] Krell, M. M., & Kim, S. K. (2017). Rotational data augmentation for
       electroencephalographic data. In 2017 39th Annual International
       Conference of the IEEE Engineering in Medicine and Biology Society
       (EMBC) (pp. 471-474).
    """

    def __init__(
        self,
        ordered_ch_names,
        initial_probability=None,
        initial_magnitude=None,
        mag_range=(0, 30),
        temperature=0.05,
        spherical_splines=True,
        random_state=None
    ):
        # Kept on the instance so the channel names can be read back later
        # (e.g. by convert_diff_transform).
        self.ordered_ch_names = ordered_ch_names
        sensors_positions_matrix = torch.as_tensor(
            get_standard_10_20_positions(ordered_ch_names=ordered_ch_names)
        )
        super().__init__(
            initial_probability=initial_probability,
            initial_magnitude=initial_magnitude,
            mag_range=mag_range,
            # Pass the temperature on instead of silently dropping it here.
            temperature=temperature,
            axis='z',
            sensors_positions_matrix=sensors_positions_matrix,
            spherical_splines=spherical_splines,
            random_state=random_state
        )


class DiffRandomYRotation(DiffRandomSensorsRotation):
    """Interpolates EEG signals over sensors rotated around the Y axis
    with an angle sampled uniformly between 0 and `max_degree`.

    Suggested in [1]_

    Parameters
    ----------
    ordered_ch_names : list
        List of strings representing the channels of the montage considered.
        Has to be in standard 10-20 style. The order has to be consistent with
        the order of channels in the input matrices that will be fed to the
        transform. This list will be used to compute approximate sensors
        positions from a standard 10-20 montage.
    initial_probability : float, optional
        Initial value for probability. Float between 0 and 1 defining the
        uniform probability of applying the operation. Set to 0.5 by default.
    initial_magnitude : float, optional
        Initial value for the magnitude. Float between 0 and 1 encoding the
        `max_degree` parameter:
        ```
        max_degree = magnitude * mag_range[1] + (1 - magnitude) * mag_range[0]
        ```
        Defaults to 0.5.
    mag_range : tuple of two floats | None, optional
        Range of possible values for `max_degree` settable using the magnitude
        (see `initial_magnitude`). If omitted, the range (0, 30 degrees) will
        be used.
    temperature : float, optional
        Temperature parameter of the RelaxedBernoulli distribution used to
        decide whether to apply the operation to the input or not. Defaults to
        0.05.
    spherical_splines : bool, optional
        Whether to use spherical splines for the interpolation or not. When
        `False`, standard scipy.interpolate.Rbf (with quadratic kernel) will be
        used (as in the original paper). Defaults to True.
    random_state: int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Defaults to None.

    References
    ----------
    .. [1] Krell, M. M., & Kim, S. K. (2017). Rotational data augmentation for
       electroencephalographic data. In 2017 39th Annual International
       Conference of the IEEE Engineering in Medicine and Biology Society
       (EMBC) (pp. 471-474).
    """

    def __init__(
        self,
        ordered_ch_names,
        initial_probability=None,
        initial_magnitude=None,
        mag_range=(0, 30),
        temperature=0.05,
        spherical_splines=True,
        random_state=None
    ):
        # Kept on the instance so the channel names can be read back later
        # (e.g. by convert_diff_transform).
        self.ordered_ch_names = ordered_ch_names
        sensors_positions_matrix = torch.as_tensor(
            get_standard_10_20_positions(ordered_ch_names=ordered_ch_names)
        )
        super().__init__(
            initial_probability=initial_probability,
            initial_magnitude=initial_magnitude,
            mag_range=mag_range,
            # Pass the temperature on instead of silently dropping it here.
            temperature=temperature,
            axis='y',
            sensors_positions_matrix=sensors_positions_matrix,
            spherical_splines=spherical_splines,
            random_state=random_state
        )


class DiffRandomXRotation(DiffRandomSensorsRotation):
    """Interpolates EEG signals over sensors rotated around the X axis
    with an angle sampled uniformly between 0 and `max_degree`.

    Suggested in [1]_

    Parameters
    ----------
    ordered_ch_names : list
        List of strings representing the channels of the montage considered.
        Has to be in standard 10-20 style. The order has to be consistent with
        the order of channels in the input matrices that will be fed to the
        transform. This list will be used to compute approximate sensors
        positions from a standard 10-20 montage.
    initial_probability : float, optional
        Initial value for probability. Float between 0 and 1 defining the
        uniform probability of applying the operation. Set to 0.5 by default.
    initial_magnitude : float, optional
        Initial value for the magnitude. Float between 0 and 1 encoding the
        `max_degree` parameter:
        ```
        max_degree = magnitude * mag_range[1] + (1 - magnitude) * mag_range[0]
        ```
        Defaults to 0.5.
    mag_range : tuple of two floats | None, optional
        Range of possible values for `max_degree` settable using the magnitude
        (see `initial_magnitude`). If omitted, the range (0, 30 degrees) will
        be used.
    temperature : float, optional
        Temperature parameter of the RelaxedBernoulli distribution used to
        decide whether to apply the operation to the input or not. Defaults to
        0.05.
    spherical_splines : bool, optional
        Whether to use spherical splines for the interpolation or not. When
        `False`, standard scipy.interpolate.Rbf (with quadratic kernel) will be
        used (as in the original paper). Defaults to True.
    random_state: int | numpy.random.Generator, optional
        Seed to be used to instantiate numpy random number generator instance.
        Defaults to None.

    References
    ----------
    .. [1] Krell, M. M., & Kim, S. K. (2017). Rotational data augmentation for
       electroencephalographic data. In 2017 39th Annual International
       Conference of the IEEE Engineering in Medicine and Biology Society
       (EMBC) (pp. 471-474).
    """

    def __init__(
        self,
        ordered_ch_names,
        initial_probability=None,
        initial_magnitude=None,
        mag_range=(0, 30),
        temperature=0.05,
        spherical_splines=True,
        random_state=None
    ):
        # Kept on the instance so the channel names can be read back later
        # (e.g. by convert_diff_transform).
        self.ordered_ch_names = ordered_ch_names
        sensors_positions_matrix = torch.as_tensor(
            get_standard_10_20_positions(ordered_ch_names=ordered_ch_names)
        )
        super().__init__(
            initial_probability=initial_probability,
            initial_magnitude=initial_magnitude,
            mag_range=mag_range,
            # Pass the temperature on instead of silently dropping it here.
            temperature=temperature,
            axis='x',
            sensors_positions_matrix=sensors_positions_matrix,
            spherical_splines=spherical_splines,
            random_state=random_state
        )


def convert_diff_transform(diff_transform, random_state):
    """Convert a differentiable transform into its regular counterpart.

    The target class is looked up in ``DIFF_TO_STANDARD_MAP`` using the class
    name of ``diff_transform``. It is instantiated with the current
    probability and magnitude of ``diff_transform`` (extracted as Python
    scalars), the given ``random_state`` and, when ``diff_transform`` has
    them, its ``ordered_ch_names`` and ``sfreq`` attributes.

    Parameters
    ----------
    diff_transform : DiffTransform
        Differentiable transform to convert.
    random_state : object
        Passed as-is to the constructor of the regular transform.

    Returns
    -------
    Instance of the class mapped to ``type(diff_transform).__name__``.

    Raises
    ------
    KeyError
        If the class name of ``diff_transform`` is not a key of
        ``DIFF_TO_STANDARD_MAP``.
    """
    target_cls = DIFF_TO_STANDARD_MAP[type(diff_transform).__name__]

    # Only forward the optional settings this particular transform carries.
    extra_kwargs = {
        attr: getattr(diff_transform, attr)
        for attr in ('ordered_ch_names', 'sfreq')
        if hasattr(diff_transform, attr)
    }

    # The magnitude may be None; otherwise turn it into a plain scalar.
    magnitude = diff_transform.magnitude
    if magnitude is not None:
        magnitude = magnitude.clone().detach().item()
    probability = diff_transform.probability.clone().detach().item()

    return target_cls(
        probability=probability,
        magnitude=magnitude,
        random_state=random_state,
        **extra_kwargs
    )


# Registry of differentiable transform classes. Each class maps to a list of
# names ("sfreq", "ordered_ch_names"); presumably these are the extra
# constructor keyword arguments that transform requires.
# NOTE(review): confirm against the code that consumes this mapping.
POSSIBLE_DIFF_TRANSFORMS = {
    DiffTimeReverse: [],
    DiffSignFlip: [],
    DiffFTSurrogate: [],
    DiffMissingChannels: [],
    DiffShuffleChannels: [],
    DiffGaussianNoise: [],
    DiffChannelSymmetry: ["ordered_ch_names"],
    DiffTimeMask: [],
    DiffFrequencyShift: ["sfreq"],
    DiffRandomXRotation: ["ordered_ch_names"],
    DiffRandomYRotation: ["ordered_ch_names"],
    DiffRandomZRotation: ["ordered_ch_names"],
}
