from typing import Any, List
from copy import deepcopy

import numpy as np
from sklearn.utils.validation import check_random_state
import torch
from torch import nn
from torch.distributions import Categorical
from torch.nn.functional import gumbel_softmax

from braindecode.augmentation.base import Output, Compose

from eeg_augment.diff_aug.relax import RelaxGumbelSoftmax, ControlVariate
from eeg_augment.diff_aug.dada import (
    convert_dada_subpolicy,
    DADAPolicy,
    SampledTransform,
)
from eeg_augment.auto_augmentation import sample_subpolicy_and_apply,\
    AugmentationPolicy, ClasswiseSubpolicy
from eeg_augment.diff_aug.diff_transforms import (
    DiffTimeReverse, DiffSignFlip, DiffFTSurrogate, DiffMissingChannels,
    DiffShuffleChannels, DiffGaussianNoise, DiffChannelSymmetry,
    DiffTimeMask, DiffFrequencyShift, DiffRandomXRotation,
    DiffRandomYRotation, DiffRandomZRotation, POSSIBLE_DIFF_TRANSFORMS,
    convert_diff_transform
)


def make_diff_transforms_subset(sfreq, ordered_ch_names, random_state=None):
    """Build a reduced set of DiffTransforms.

    Parameters
    ----------
    sfreq : float
        Sampling frequency of EEG, forwarded to DiffFrequencyShift.
    ordered_ch_names : list
        Ordered channel names, forwarded to DiffChannelSymmetry.
    random_state : int | numpy.random.RandomState | None, optional
        Forwarded unchanged to every transform. Defaults to None.

    Returns
    -------
    torch.nn.ModuleList
        Time reverse, sign flip, FT surrogate, channel symmetry and frequency
        shift transforms, in that order.
    """
    subset = [
        DiffTimeReverse(random_state=random_state),
        DiffSignFlip(random_state=random_state),
        DiffFTSurrogate(random_state=random_state),
        DiffChannelSymmetry(
            ordered_ch_names=ordered_ch_names,
            random_state=random_state,
        ),
        DiffFrequencyShift(sfreq=sfreq, random_state=random_state),
    ]
    return nn.ModuleList(subset)


def make_all_diff_transforms(
    sfreq,
    ordered_ch_names,
    random_state=None,
    init_random_state=None
):
    """Instantiate one object of every DiffTransform class.

    Parameters
    ----------
    sfreq : float
        Sampling frequency of EEG.
    ordered_ch_names : list
        Ordered list of channels corresponding to each row in EEG epochs.
    random_state : int | numpy.random.RandomState | None, optional
        Passed to each DiffTransform object to be used in their stochastic
        forward method. Defaults to None.
    init_random_state : int | numpy.random.RandomState | None, optional
        Used to sample the initial probabilities and magnitudes of each
        DiffTransform object. Defaults to None.

    Returns
    -------
    torch.nn.ModuleList
        DiffTransform objects, which can be passed directly to a
        SubpolicyStage class.
    """
    forward_rng = check_random_state(random_state)
    init_rng = check_random_state(init_random_state)

    # Extra constructor arguments some transforms need; each class only
    # receives the ones listed in its entry of POSSIBLE_DIFF_TRANSFORMS.
    available_kwargs = {
        "sfreq": sfreq,
        "ordered_ch_names": ordered_ch_names
    }

    built = []
    for tf_class, param_names in POSSIBLE_DIFF_TRANSFORMS.items():
        extra_kwargs = {
            name: available_kwargs[name]
            for name in available_kwargs
            if name in param_names
        }
        # Probability is drawn before magnitude to keep the init stream
        # reproducible.
        init_probability = init_rng.uniform()
        init_magnitude = init_rng.uniform()
        built.append(
            tf_class(
                random_state=forward_rng,
                initial_probability=init_probability,
                initial_magnitude=init_magnitude,
                **extra_kwargs
            )
        )
    return nn.ModuleList(built)


class SubpolicyStage(nn.Module):
    """Differentiable module made of a softmax activation on a set of
    DiffTransforms. Main building block of DiffSubpolicy.

    As proposed in [1]_

    Code copied and modified from https://github.com/moskomule/dda/

    Parameters
    ----------
    transforms : torch.nn.ModuleList
        List of transforms to learn weights.
    temperature : float
        Used for setting how steep to do the softmax operator.
    rand_init : bool, optional
        Whether or not to randomly initialize transforms weights. If not,
        weights are initialized equally to 1. Defaults to True.
    grad_est : str, optional
        Defines what gradient estimator to use for the operations weights.
            * None: a softmax of the weights will be used, producing a convex
              combination of all operations in the forward and backward pass.
            * "gumbel": the straight-through Gumbel-Softmax trick is used [2]_
              where a single operation is sampled per batch in the forward and
              gumbel-softmax distribution is optimized in the backward pass.
            * "relax": the behavior is similar to "gumbel" but where the
              gradient is estimated with no bias, using the RELAX estimator
              [3]_.
        Defaults to None.
    random_state : int | numpy.random.RandomState | None, optional
        Used by the ``weights`` property to sample weight perturbations when
        this functionality is enabled (see ``perturbation_on``). Defaults to
        None.
    init_random_state : int | numpy.random.RandomState | None, optional
        Used to initialize transform weights when ``rand_init`` is True.
        Defaults to None.
    device : str | torch.device, optional
        Device on which the RELAX control variate network, the temperature
        and the weights are placed when ``grad_est="relax"``. Defaults to
        "cpu".

    References
    ----------
    .. [1] Hataya R., Zdenek J., Yoshizoe K., Nakayama H. (2020) Faster
    AutoAugment: Learning Augmentation Strategies Using Backpropagation.
    In: Vedaldi A., Bischof H., Brox T., Frahm JM. (eds) Computer Vision –
    ECCV 2020. ECCV 2020. Lecture Notes in Computer Science, vol 12370.
    Springer, Cham. https://doi.org/10.1007/978-3-030-58595-2_1
    .. [2] Jang, Gu, et Poole, (2017) Categorical Reparameterization with
    Gumbel-Softmax. In Proceedings of ICLR.
    .. [3] Grathwohl et al. (2018). Backpropagation through the Void:
    Optimizing control variates for black-box gradient estimation. In
    Proceedings of ICLR.
    """
    def __init__(self,
                 transforms: nn.ModuleList,
                 temperature: float,
                 rand_init: bool = True,
                 grad_est: Any = None,
                 random_state: Any = None,
                 init_random_state: Any = None,
                 device: str or torch.device = "cpu",
                 ):
        super(SubpolicyStage, self).__init__()
        self.transforms = transforms
        n = len(self.transforms)
        self.rng = check_random_state(random_state)
        self.init_rng = check_random_state(init_random_state)
        self.device = device
        if not rand_init:
            self._weights = nn.Parameter(torch.ones(n))
        else:
            # Uniform init in [-bound, bound) with a Glorot-style bound
            # (fan_in=n, fan_out=1).
            bound = np.sqrt(6) / np.sqrt(n + 1)
            self._weights = nn.Parameter(
                torch.as_tensor(self.init_rng.rand(n), dtype=torch.float)
                * 2 * bound - bound
            )
        self.temperature = temperature
        self.grad_est = grad_est
        if self.grad_est == "relax":
            # We need to instantiate a NN-based control variate when using
            # RELAX estimator
            self.variate_net = ControlVariate(
                self._weights.shape[0]
            ).to(self.device)
            self.temperature = torch.as_tensor(
                self.temperature,
                device=self.device,
            )
            self._weights.data = self._weights.to(self.device)
        self._perturb_params = False
        self._pertub_range = 0.1

    def forward(self, X, y) -> Output:
        if self.grad_est in ['relax', 'gumbel']:
            # When using Relax or Gumbel softmax, use the corresponding special
            # sampling method to get a onehot encoded vector
            weights = self.sample_weights()

            # In order to have gradients backprop through weights, a special
            # sampling is needed
            return SampledTransform.apply(weights, self.transforms, X, y)
        else:
            # In the vanilla case, the behavior changes between training and
            # inference
            if self.training:
                # Compute outputs of all possible transforms, stacked along a
                # new leading dimension: (n_ops, *X_out.shape)
                all_ops_out = torch.stack(
                    [op(X, y)[0] for op in self.transforms]
                )
                # Softmax weights reshaped to (n_ops, 1, ..., 1) so they
                # broadcast over every remaining dimension of the outputs,
                # whatever the rank of X.
                n_extra_dims = all_ops_out.dim() - 1
                reshaped_weights = self.weights.view(
                    -1, *([1] * n_extra_dims)
                ).to(X.device)
                return (all_ops_out * reshaped_weights).sum(0), y
            else:
                # At inference, the operation is simply sampled using the
                # weights
                sampled_op_idx = Categorical(self.weights).sample()
                return self.transforms[sampled_op_idx](X, y)

    def sample_weights(self):
        """Sample a single operation using the Gumbel-softmax trick or RELAX
        estimator.

        Returns
        -------
        torch.Tensor | None
            Onehot encoded vector corresponding to the sampled transform.
            None when ``grad_est`` is neither "gumbel" nor "relax".
        """
        if self.grad_est == "gumbel":
            return gumbel_softmax(
                self._weights,
                tau=self.temperature,
                hard=True,  # straight-through behavior
            )
        elif self.grad_est == "relax":
            self.variate_net.zero_grad()
            # TODO: Temperature could be learnt too
            return RelaxGumbelSoftmax.apply(
                self._weights,
                self.temperature,
                self.variate_net,
            )

    def perturbation_on(self, eps=None):
        """Enables probability, magnitude and weights perturbation by sampling
        a delta at each new call of these properties

        Parameters
        ----------
        eps : float | None, optional
            When different from None, sets the range to sample perturbations
            uniformly between [-`eps`, `eps`], by default None.
        """
        self._perturb_params = True
        if eps is not None:
            self._pertub_range = eps
        for tf in self.transforms:
            tf.perturbation_on(eps=eps)

    def perturbation_off(self):
        """Disables probability, magnitude and weights perturbation
        """
        self._perturb_params = False
        for tf in self.transforms:
            tf.perturbation_off()

    @property
    def weights(self):
        """Softmax (with temperature) of the stage logits.

        When perturbation is enabled, a fresh delta drawn uniformly in
        [-eps, eps] from ``self.rng`` is added to each logit before the
        softmax, at every access of this property.
        """
        logits = self._weights
        if self._perturb_params:
            # Symmetric interval: the previous code used high=-eps, which
            # always produced the constant -eps (a no-op after softmax).
            delta = torch.as_tensor(
                self.rng.uniform(
                    low=-self._pertub_range,
                    high=self._pertub_range,
                    size=self._weights.shape[0]
                ),
                device=self._weights.device,
                dtype=torch.float,
            )
            logits = logits + delta
        return logits.div(self.temperature).softmax(dim=0)

    @property
    def all_probabilities(self):
        return np.array([op.probability.item() for op in self.transforms])

    @property
    def all_magnitudes(self):
        return np.array([
            op.magnitude.item() for op in self.transforms
            if op.magnitude is not None
        ])

    @property
    def all_prob_grads(self):
        return np.array([
            op._probability.grad.item() for op in self.transforms
            if op._probability.grad is not None
        ])

    @property
    def all_mag_grads(self):
        return np.array([
            op._magnitude.grad.item() for op in self.transforms
            if op.magnitude is not None and op._magnitude.grad is not None
        ])

    @property
    def all_weights_grads(self):
        return self._weights.grad


class DiffSubpolicy(nn.Module):
    """Differentiable augmentation subpolicy: a chain of SubpolicyStage
    modules applied one after another.

    As proposed in [1]_

    Code copied and modified from https://github.com/moskomule/dda/

    Parameters
    ----------
    subpolicy_stages : list
        SubpolicyStage objects to chain into a subpolicy, in the order in
        which they are applied.

    References
    ----------
    .. [1] Hataya R., Zdenek J., Yoshizoe K., Nakayama H. (2020) Faster
    AutoAugment: Learning Augmentation Strategies Using Backpropagation.
    In: Vedaldi A., Bischof H., Brox T., Frahm JM. (eds) Computer Vision –
    ECCV 2020. ECCV 2020. Lecture Notes in Computer Science, vol 12370.
    Springer, Cham. https://doi.org/10.1007/978-3-030-58595-2_1
    """
    def __init__(
        self,
        subpolicy_stages: List[SubpolicyStage]
    ):
        super(DiffSubpolicy, self).__init__()
        self.stages = nn.ModuleList(subpolicy_stages)

    def forward(self, X, y):
        """Feed (X, y) through every stage, each consuming the previous
        stage's output."""
        for current_stage in self.stages:
            X, y = current_stage(X, y)
        return X, y

    def perturbation_on(self, eps=None):
        """Turn on perturbation of probabilities, magnitudes and weights in
        every stage; a new delta is drawn each time those are read.

        Parameters
        ----------
        eps : float | None, optional
            When not None, perturbations are drawn uniformly in
            [-`eps`, `eps`]. Defaults to None.
        """
        for current_stage in self.stages:
            current_stage.perturbation_on(eps=eps)

    def perturbation_off(self):
        """Turn off perturbation in every stage."""
        for current_stage in self.stages:
            current_stage.perturbation_off()

    def _per_stage(self, attr_name):
        """List ``attr_name`` read from each stage, in stage order."""
        return [getattr(current_stage, attr_name)
                for current_stage in self.stages]

    @property
    def all_weights(self):
        return torch.hstack(self._per_stage("weights"))

    @property
    def all_probabilities(self):
        return np.hstack(self._per_stage("all_probabilities"))

    @property
    def all_magnitudes(self):
        return np.hstack(self._per_stage("all_magnitudes"))

    @property
    def all_prob_grads(self):
        return np.hstack(self._per_stage("all_prob_grads"))

    @property
    def all_mag_grads(self):
        return np.hstack(self._per_stage("all_mag_grads"))

    @property
    def all_weight_grads(self):
        """Concatenated weight gradients of the stages that have one, or
        None when no stage has a gradient yet."""
        stage_grads = [
            current_stage._weights.grad for current_stage in self.stages
        ]
        present = [grad for grad in stage_grads if grad is not None]
        if not present:
            return None
        return torch.hstack(present)


class DiffAugmentationPolicy(nn.Module):
    """Differentiable augmentation policy: a set of DiffSubpolicy objects
    from which one is sampled uniformly at each forward pass.

    As proposed in [1]_

    Parameters
    ----------
    n_subpolicies : int
        Number of subpolicies to sample uniformly from.
    subpolicy_len : int, optional
        Number of consecutive DiffTransforms making up a subpolicy. Defaults to
        2.
    temperature : float, optional
        Used for setting how steep to do the softmax operator in the subpolicy
        stages. Defaults to 0.05.
    ch_names : list, optional
        Ordered list of channels in the dataset. Defaults to None.
    sfreq : float, optional
        Sampling rate in the dataset. Defaults to None.
    batchwise : bool, optional
        Whether to sample subpolicies per batch (True) or per sample (False).
        Defaults to True.
    grad_est : str, optional
        Defines what gradient estimator to use for the operations weights.
            * None: a softmax of the weights will be used, producing a convex
              combination of all operations in the forward and backward pass.
            * "gumbel": the straight-through Gumbel-Softmax trick is used [2]_
              where a single operation is sampled per batch in the forward and
              gumbel-softmax distribution is optimized in the backward pass.
            * "relax": the behavior is similar to "gumbel" but where the
              gradient is estimated with no bias, using the RELAX estimator
              [3]_.
        Defaults to None.
    random_state : int | numpy.random.RandomState, optional
        Used for setting the transforms and sampling the subpolicies. Defaults
        to None.

    References
    ----------
    .. [1] Hataya R., Zdenek J., Yoshizoe K., Nakayama H. (2020) Faster
    AutoAugment: Learning Augmentation Strategies Using Backpropagation.
    In proceedings of ECCV 2020.
    .. [2] Jang, Gu, et Poole, (2017) Categorical Reparameterization with
    Gumbel-Softmax. In Proceedings of ICLR.
    .. [3] Grathwohl et al. (2018). Backpropagation through the Void:
    Optimizing control variates for black-box gradient estimation. In
    Proceedings of ICLR.
    """
    def __init__(
        self,
        n_subpolicies: int,
        subpolicy_len: int = 2,
        temperature: float = 0.05,
        ch_names: list = None,
        sfreq: float = None,
        batchwise: bool = True,
        grad_est: Any = None,
        random_state: Any = None,
    ):
        super(DiffAugmentationPolicy, self).__init__()

        self.rng = check_random_state(random_state)
        self.batchwise = batchwise
        self.grad_est = grad_est

        # Initial values are drawn from a copy of self.rng so that building
        # the modules does not advance self.rng, keeping forward-pass
        # sampling independent from initialisation.
        init_rng = deepcopy(self.rng)

        def build_stage():
            # Transforms draw from init_rng before the stage draws its own
            # initial weights.
            stage_transforms = make_all_diff_transforms(
                sfreq=sfreq,
                ordered_ch_names=ch_names,
                random_state=self.rng,
                init_random_state=init_rng,
            )
            return SubpolicyStage(
                transforms=stage_transforms,
                temperature=temperature,
                grad_est=self.grad_est,
                random_state=self.rng,
                init_random_state=init_rng,
            )

        built_subpolicies = []
        for _ in range(n_subpolicies):
            stages = [build_stage() for _ in range(subpolicy_len)]
            built_subpolicies.append(DiffSubpolicy(subpolicy_stages=stages))
        self.diff_subpolicies = nn.ModuleList(built_subpolicies)

    def forward(self, X, y):
        return sample_subpolicy_and_apply(
            X, y,
            subpolicies=self.diff_subpolicies,
            random_state=self.rng,
            batchwise=self.batchwise,
        )

    def perturbation_on(self, eps=None):
        """Turn on perturbation of probabilities, magnitudes and weights in
        every subpolicy; a new delta is drawn each time those are read.

        Parameters
        ----------
        eps : float | None, optional
            When not None, perturbations are drawn uniformly in
            [-`eps`, `eps`]. Defaults to None.
        """
        for subpolicy in self.diff_subpolicies:
            subpolicy.perturbation_on(eps=eps)

    def perturbation_off(self):
        """Turn off perturbation in every subpolicy."""
        for subpolicy in self.diff_subpolicies:
            subpolicy.perturbation_off()

    def _collect(self, attr_name):
        """List ``attr_name`` read from each subpolicy, in order."""
        return [getattr(subpolicy, attr_name)
                for subpolicy in self.diff_subpolicies]

    @property
    def all_probabilities(self):
        return np.hstack(self._collect("all_probabilities"))

    @property
    def all_magnitudes(self):
        return np.hstack(self._collect("all_magnitudes"))

    @property
    def all_weights(self):
        return torch.hstack(self._collect("all_weights"))

    @property
    def all_prob_grads(self):
        return np.hstack(self._collect("all_prob_grads"))

    @property
    def all_mag_grads(self):
        return np.hstack(self._collect("all_mag_grads"))

    @property
    def all_weight_grads(self):
        """Concatenated weight gradients of all subpolicies having one, or
        None if none has."""
        available = [
            grads for grads in self._collect("all_weight_grads")
            if grads is not None
        ]
        if not available:
            return None
        return torch.hstack(available)


def convert_diff_subpolicy(diff_subpolicy, random_state):
    """Turn a DiffSubpolicy into a regular ``Compose`` of transforms.

    For each stage, the transform with the largest softmax weight is kept
    (read through ``stage.weights``, so perturbed weights are used if
    perturbation is enabled on the stage), a one-line summary of it is
    printed, and it is converted with ``convert_diff_transform``.
    """
    converted = []
    for stage_idx, stage in enumerate(diff_subpolicy.stages):
        best = stage.transforms[stage.weights.argmax()]
        magnitude = best._magnitude
        magnitude = None if magnitude is None else magnitude.item()
        print(
            f"{stage_idx}: {type(best).__name__}",
            f"p={best._probability.item()}",
            f"m={magnitude}",
        )
        converted.append(convert_diff_transform(best, random_state))
    return Compose(converted)


def convert_diff_classwise_subpolicy(diff_class_subpolicy, random_state):
    """Turn a DiffClasswiseSubpolicy into a regular ClasswiseSubpolicy.

    Each per-class DiffSubpolicy is converted with ``convert_diff_subpolicy``;
    the string keys of ``subpolicies_per_class`` are cast back to ``int``.
    """
    converted_per_class = dict()
    per_class = diff_class_subpolicy.subpolicies_per_class
    for class_key, class_diff_subpolicy in per_class.items():
        converted_per_class[int(class_key)] = convert_diff_subpolicy(
            class_diff_subpolicy, random_state
        )
    return ClasswiseSubpolicy(converted_per_class)


def diff_policy_to_standard_policy(diff_policy, random_state):
    """Convert a differentiable policy object into a regular policy object.

    Supported inputs:
      * DADAPolicy: only the ``n_subpolicies`` subpolicies with the largest
        weights are kept and converted with ``convert_dada_subpolicy``.
      * DiffAugmentationPolicy: every subpolicy is converted with
        ``convert_diff_subpolicy``.
      * DiffClasswisePolicy: every subpolicy is converted with
        ``convert_diff_classwise_subpolicy``.

    Parameters
    ----------
    diff_policy : DADAPolicy | DiffAugmentationPolicy | DiffClasswisePolicy
        Policy to convert.
    random_state : int | numpy.random.RandomState | None
        Passed to the converted transforms and to the resulting
        AugmentationPolicy.

    Returns
    -------
    AugmentationPolicy

    Raises
    ------
    ValueError
        If ``diff_policy`` is not one of the supported types.
    """
    if isinstance(diff_policy, DADAPolicy):
        conversion_func = convert_dada_subpolicy
        convert_type = "DADA policy"

        # Ascending sort: the last n_subpolicies indices have the largest
        # weights.
        sorted_subpolicies_idx = diff_policy.weights.sort()[1]
        iterator = [
            diff_policy.diff_subpolicies[i]
            for i in sorted_subpolicies_idx[-diff_policy.n_subpolicies:]
        ]
    elif isinstance(diff_policy, DiffAugmentationPolicy):
        conversion_func = convert_diff_subpolicy
        convert_type = "diff policy"
        iterator = diff_policy.diff_subpolicies
    elif isinstance(diff_policy, DiffClasswisePolicy):
        conversion_func = convert_diff_classwise_subpolicy
        convert_type = "diff classwise policy"
        iterator = diff_policy.diff_subpolicies
    else:
        raise ValueError(
            "diff_policy has to be either a DADAPolicy, a "
            "DiffAugmentationPolicy or a DiffClasswisePolicy, got "
            f"{type(diff_policy).__name__}"
        )
    print(f"Converting {convert_type} to standard one:")
    set_standard_subp = list()
    for s, diff_subpolicy in enumerate(iterator):
        print(f"Subpolicy {s}")
        equivalent_subpolicy = conversion_func(diff_subpolicy, random_state)
        set_standard_subp.append(equivalent_subpolicy)
    return AugmentationPolicy(set_standard_subp, random_state=random_state)


class DiffClasswiseSubpolicy(nn.Module):
    """Differentiable classwise augmentation subpolicy.

    Holds one DiffSubpolicy per class (stored in a ModuleDict keyed by
    ``str(class)``); in the forward pass, the samples of each class are
    augmented by that class's subpolicy.

    Parameters
    ----------
    temperature : float, optional
        Used for setting how steep to do the softmax operator. Defaults to
        0.05.
    ordered_ch_names : list, optional
        Ordered list of channels in the dataset. Defaults to None.
    sfreq : float, optional
        Sampling rate in the dataset. Defaults to None.
    subpolicy_len : int, optional
        Number of consecutive DiffTransforms to learn. Defaults to 2.
    classes : list, optional
        List containing all possible classes from the dataset. Defaults to
        None, in which case ``list(range(5))`` (i.e. [0, 1, 2, 3, 4]) is
        used -- presumably the five sleep stages W, N1, N2, N3, REM.
    grad_est : str, optional
        Defines what gradient estimator to use for the operations weights.
            * None: a softmax of the weights will be used, producing a convex
              combination of all operations in the forward and backward pass.
            * "gumbel": the straight-through Gumbel-Softmax trick is used [1]_
              where a single operation is sampled per batch in the forward and
              gumbel-softmax distribution is optimized in the backward pass.
            * "relax": the behavior is similar to "gumbel" but where the
              gradient is estimated with no bias, using the RELAX estimator
              [2]_.
        Defaults to None.
    random_state : int | numpy.random.RandomState | None, optional
        Passed to each DiffTransform object and SubpolicyStage to be used in
        their stochastic forward pass. Defaults to None.
    init_rng : int | numpy.random.RandomState | None, optional
        Used to sample the initial weights, probabilities and magnitudes of
        each DiffTransform object. Defaults to None.

    References
    ----------
    .. [1] Jang, Gu, et Poole, (2017) Categorical Reparameterization with
    Gumbel-Softmax. In Proceedings of ICLR.
    .. [2] Grathwohl et al. (2018). Backpropagation through the Void:
    Optimizing control variates for black-box gradient estimation. In
    Proceedings of ICLR.
    """
    def __init__(
        self,
        temperature: float = 0.05,
        ordered_ch_names: list = None,
        sfreq: float = None,
        subpolicy_len: int = 2,
        classes: List[Any] = None,
        grad_est: Any = None,
        random_state: Any = None,
        init_rng: Any = None
    ):
        super(DiffClasswiseSubpolicy, self).__init__()

        self.rng = check_random_state(random_state)
        init_rng = check_random_state(init_rng)

        if classes is None:
            classes = list(range(5))
        self.classes = classes
        self.grad_est = grad_est
        # ModuleDict keys must be strings, hence str(c).
        self.subpolicies_per_class = nn.ModuleDict({
            str(c): DiffSubpolicy(
                subpolicy_stages=[
                    SubpolicyStage(
                        transforms=make_all_diff_transforms(
                            sfreq=sfreq,
                            ordered_ch_names=ordered_ch_names,
                            random_state=self.rng,
                            init_random_state=init_rng,
                        ),
                        temperature=temperature,
                        grad_est=self.grad_est,
                        random_state=self.rng,
                        init_random_state=init_rng,
                    )
                    for _ in range(subpolicy_len)
                ],
            ) for c in classes
        })

    def forward(self, X, y):
        """Augment the samples of each class with that class's subpolicy.

        Samples whose label is not in ``self.classes`` are returned
        unchanged. Labels are returned as-is.
        """
        tr_X = X.clone()
        for c in self.classes:
            mask = y == c
            # Tensor.any() reduces in one op instead of iterating the mask
            # element by element in Python.
            if mask.any():
                tr_X[mask, ...], _ = self.subpolicies_per_class[str(c)](
                    X[mask, ...], y[mask]
                )
        return tr_X, y

    def perturbation_on(self, eps=None):
        """Enables probability, magnitude and weights perturbation by sampling
        a delta at each new call of these properties

        Parameters
        ----------
        eps : float | None, optional
            When different from None, sets the range to sample perturbations
            uniformly between [-`eps`, `eps`], by default None.
        """
        for subpol in self.subpolicies_per_class.values():
            subpol.perturbation_on(eps=eps)

    def perturbation_off(self):
        """Disables probability, magnitude and weights perturbation
        """
        for subpol in self.subpolicies_per_class.values():
            subpol.perturbation_off()

    @property
    def all_probabilities(self):
        return np.hstack([
                    subpolicy.all_probabilities
                    for subpolicy in self.subpolicies_per_class.values()
                ])

    @property
    def all_magnitudes(self):
        return np.hstack([
                    subpolicy.all_magnitudes
                    for subpolicy in self.subpolicies_per_class.values()
                ])

    @property
    def all_weights(self):
        return torch.hstack([
                    subpolicy.all_weights
                    for subpolicy in self.subpolicies_per_class.values()
                ])

    @property
    def all_prob_grads(self):
        return np.hstack([
                    subpolicy.all_prob_grads
                    for subpolicy in self.subpolicies_per_class.values()
                ])

    @property
    def all_mag_grads(self):
        return np.hstack([
                    subpolicy.all_mag_grads
                    for subpolicy in self.subpolicies_per_class.values()
                ])

    @property
    def all_weight_grads(self):
        """Concatenated weight gradients of the per-class subpolicies that
        have one, or None if none has."""
        # Read each subpolicy's property once (it builds a new tensor).
        weights_grads = [
            grads for grads in (
                subpolicy.all_weight_grads
                for subpolicy in self.subpolicies_per_class.values()
            )
            if grads is not None
        ]
        if len(weights_grads) == 0:
            return None
        return torch.hstack(weights_grads)


class DiffClasswisePolicy(nn.Module):
    """Differentiable classwise augmentation policy: a set of
    DiffClasswiseSubpolicy objects from which one is sampled uniformly at
    each forward pass.

    Parameters
    ----------
    n_subpolicies : int
        Number of classwise subpolicies to sample uniformly from.
    subpolicy_len : int, optional
        Number of consecutive DiffTransforms making up a subpolicy. Defaults to
        2.
    classes : list, optional
        List containing all possible classes from the dataset. Defaults to
        None, in which case each DiffClasswiseSubpolicy falls back to its own
        default classes; ``self.classes`` is read back from the first
        subpolicy.
    temperature : float, optional
        Used for setting how steep to do the softmax operator in the subpolicy
        stages. Defaults to 0.05.
    ch_names : list, optional
        Ordered list of channels in the dataset. Defaults to None.
    sfreq : float, optional
        Sampling rate in the dataset. Defaults to None.
    batchwise : bool, optional
        Whether to sample subpolicies per batch (True) or per sample (False).
        Defaults to True.
    grad_est : str | None, optional
        Defines what gradient estimator to use for the operations weights.
        When None, a softmax of the weights will be used, producing a convex
        combination of all operations in the forward and backward pass. When
        "gumbel", the straight-through Gumbel-Softmax trick is used [1]_ where
        a single operation is sampled per batch in the forward and
        gumbel-softmax distribution is optimized in the backward pass. If
        "relax", the behavior is similar to "gumbel" but where the gradient is
        estimated with no bias, using the RELAX estimator [2]_. Defaults to
        None.
    random_state : int | numpy.random.RandomState, optional
        Used for setting the transforms and sampling the subpolicies. Defaults
        to None.

    References
    ----------
    .. [1] Jang, Gu, et Poole, (2017) Categorical Reparameterization with
    Gumbel-Softmax. In Proceedings of ICLR.
    .. [2] Grathwohl et al. (2018). Backpropagation through the Void:
    Optimizing control variates for black-box gradient estimation. In
    Proceedings of ICLR.
    """
    def __init__(
        self,
        n_subpolicies: int,
        subpolicy_len: int = 2,
        classes: List[Any] = None,
        temperature: float = 0.05,
        ch_names: list = None,
        sfreq: float = None,
        batchwise: bool = True,
        grad_est: Any = None,
        random_state: Any = None,
    ):
        super(DiffClasswisePolicy, self).__init__()

        self.rng = check_random_state(random_state)
        self.batchwise = batchwise
        self.grad_est = grad_est

        # Initial values are drawn from a copy so that building the modules
        # does not advance self.rng.
        init_rng = deepcopy(self.rng)

        self.diff_subpolicies = nn.ModuleList([
            DiffClasswiseSubpolicy(
                temperature=temperature,
                ordered_ch_names=ch_names,
                sfreq=sfreq,
                subpolicy_len=subpolicy_len,
                classes=classes,
                grad_est=self.grad_est,
                random_state=self.rng,
                init_rng=init_rng,
            )
            for _ in range(n_subpolicies)
        ])
        # Resolves the default classes chosen by the subpolicies.
        self.classes = self.diff_subpolicies[0].classes

    def forward(self, X, y):
        return sample_subpolicy_and_apply(
            X, y,
            subpolicies=self.diff_subpolicies,
            random_state=self.rng,
            batchwise=self.batchwise,
        )

    def perturbation_on(self, eps=None):
        """Enables probability, magnitude and weights perturbation by sampling
        a delta at each new call of these properties

        Parameters
        ----------
        eps : float | None, optional
            When different from None, sets the range to sample perturbations
            uniformly between [-`eps`, `eps`], by default None.
        """
        for subpol in self.diff_subpolicies:
            subpol.perturbation_on(eps=eps)

    def perturbation_off(self):
        """Disables probability, magnitude and weights perturbation
        """
        for subpol in self.diff_subpolicies:
            subpol.perturbation_off()

    @property
    def all_probabilities(self):
        return np.hstack([
                    subpolicy.all_probabilities
                    for subpolicy in self.diff_subpolicies
                ])

    @property
    def all_magnitudes(self):
        return np.hstack([
                    subpolicy.all_magnitudes
                    for subpolicy in self.diff_subpolicies
                ])

    @property
    def all_weights(self):
        return torch.hstack([
                    subpolicy.all_weights
                    for subpolicy in self.diff_subpolicies
                ])

    @property
    def all_prob_grads(self):
        return np.hstack([
                    subpolicy.all_prob_grads
                    for subpolicy in self.diff_subpolicies
                ])

    @property
    def all_mag_grads(self):
        return np.hstack([
                    subpolicy.all_mag_grads
                    for subpolicy in self.diff_subpolicies
                ])

    @property
    def all_weight_grads(self):
        """Concatenated weight gradients of all subpolicies having one, or
        None if none has."""
        # Read each subpolicy's property once (it builds a new tensor).
        weights_grads = [
            grads for grads in (
                subpolicy.all_weight_grads
                for subpolicy in self.diff_subpolicies
            )
            if grads is not None
        ]
        if len(weights_grads) == 0:
            return None
        return torch.hstack(weights_grads)
