
from itertools import permutations
import torch
from torch import nn
from scipy.optimize import linear_sum_assignment
import numpy as np
import torch.nn.functional as F

class PITLossWrapper(nn.Module):
    r"""Permutation invariant loss wrapper.

    Args:
        loss_func: function with signature (est_targets, targets, **kwargs).
        pit_from (str): Determines how PIT is applied.

            * ``'pw_mtx'`` (pairwise matrix): `loss_func` computes pairwise
              losses and returns a torch.Tensor of shape
              :math:`(batch, n\_src, n\_src)`. Each element
              :math:`(batch, i, j)` corresponds to the loss between
              :math:`est\_targets[:, i]` and :math:`targets[:, j]`, i.e. the
              same layout that :meth:`~PITLossWrapper.get_pw_losses` builds.
            * ``'pw_pt'`` (pairwise point): `loss_func` computes the loss for
              a batch of single source and single estimates (tensors won't
              have the source axis). Output shape : :math:`(batch)`.
              See :meth:`~PITLossWrapper.get_pw_losses`.
            * ``'perm_avg'`` (permutation average): `loss_func` computes the
              average loss for a given permutations of the sources and
              estimates. Output shape : :math:`(batch)`.
              See :meth:`~PITLossWrapper.best_perm_from_perm_avg_loss`.

            In terms of efficiency, ``'perm_avg'`` is the least efficient.

        perm_reduce (Callable): torch function to reduce permutation losses.
            Defaults to None (equivalent to mean). Signature of the func
            (pwl_set, **kwargs) : :math:`(B, n\_src!, n\_src) --> (B, n\_src!)`.
            `perm_reduce` can receive **kwargs during forward using the
            `reduce_kwargs` argument (dict). If those argument are static,
            consider defining a small function or using `functools.partial`.
            Only used in `'pw_mtx'` and `'pw_pt'` `pit_from` modes.

    Raises:
        ValueError: if ``pit_from`` is not one of the supported modes.

    For each of these modes, the best permutation and reordering will be
    automatically computed. When either ``'pw_mtx'`` or ``'pw_pt'`` is used,
    no ``perm_reduce`` is given and the number of sources is larger than
    three, the hungarian algorithm is used to find the best permutation.

    Examples
        >>> import torch
        >>> from asteroid.losses import pairwise_neg_sisdr
        >>> sources = torch.randn(10, 3, 16000)
        >>> est_sources = torch.randn(10, 3, 16000)
        >>> # Compute PIT loss based on pairwise losses
        >>> loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')
        >>> loss_val = loss_func(est_sources, sources)
        >>>
        >>> # Using reduce
        >>> def reduce(perm_loss, src):
        ...     weighted = perm_loss * src.norm(dim=-1, keepdim=True)
        ...     return torch.mean(weighted, dim=-1)
        >>>
        >>> loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx',
        ...                            perm_reduce=reduce)
        >>> reduce_kwargs = {'src': sources}
        >>> loss_val = loss_func(est_sources, sources,
        ...                      reduce_kwargs=reduce_kwargs)
    """

    # Modes accepted by ``pit_from``.
    _PIT_MODES = ("pw_mtx", "pw_pt", "perm_avg")

    def __init__(self, loss_func, pit_from="pw_mtx", perm_reduce=None):
        super().__init__()
        self.loss_func = loss_func
        self.pit_from = pit_from
        self.perm_reduce = perm_reduce
        if self.pit_from not in self._PIT_MODES:
            raise ValueError(
                "Unsupported loss function type for now. Expected "
                "one of [`pw_mtx`, `pw_pt`, `perm_avg`]"
            )

    def forward(self, est_targets, targets, return_est=False, reduce_kwargs=None, **kwargs):
        r"""Find the best permutation and return the loss.

        Args:
            est_targets: torch.Tensor. Expected shape $(batch, nsrc, ...)$.
                The batch of target estimates.
            targets: torch.Tensor. Expected shape $(batch, nsrc, ...)$.
                The batch of training targets
            return_est: Boolean. Whether to return the reordered targets
                estimates (To compute metrics or to save example).
            reduce_kwargs (dict or None): kwargs that will be passed to the
                pairwise losses reduce function (`perm_reduce`).
            **kwargs: additional keyword argument that will be passed to the
                loss function.

        Returns:
            - Best permutation loss for each batch sample, average over
              the batch.
            - The reordered targets estimates if ``return_est`` is True.
              :class:`torch.Tensor` of shape $(batch, nsrc, ...)$.

        Raises:
            ValueError: if ``self.pit_from`` was changed to an unsupported
                mode after construction.
        """
        if self.pit_from == "perm_avg":
            # Cannot get pairwise losses from this type of loss.
            # Find best permutation directly.
            min_loss, batch_indices = self.best_perm_from_perm_avg_loss(
                self.loss_func, est_targets, targets, **kwargs
            )
        else:
            if self.pit_from == "pw_mtx":
                # Loss function already returns pairwise losses
                pw_losses = self.loss_func(est_targets, targets, **kwargs)
            elif self.pit_from == "pw_pt":
                # Compute pairwise losses with a for loop.
                pw_losses = self.get_pw_losses(self.loss_func, est_targets, targets, **kwargs)
            else:
                # Only reachable if the attribute was mutated after __init__;
                # fail loudly instead of silently returning None as a "loss".
                raise ValueError(
                    "Unsupported loss function type for now. Expected "
                    "one of [`pw_mtx`, `pw_pt`, `perm_avg`]"
                )

            assert pw_losses.ndim == 3, (
                "Something went wrong with the loss " "function, please read the docs."
            )
            assert pw_losses.shape[0] == targets.shape[0], "PIT loss needs same batch dim as input"

            reduce_kwargs = reduce_kwargs if reduce_kwargs is not None else dict()
            min_loss, batch_indices = self.find_best_perm(
                pw_losses, perm_reduce=self.perm_reduce, **reduce_kwargs
            )

        # Take the mean over the batch
        mean_loss = torch.mean(min_loss)
        if not return_est:
            return mean_loss
        reordered = self.reorder_source(est_targets, batch_indices)
        return mean_loss, reordered

    @staticmethod
    def get_pw_losses(loss_func, est_targets, targets, **kwargs):
        r"""Get pair-wise losses between the training targets and its estimate
        for a given loss function.

        Args:
            loss_func: function with signature (est_targets, targets, **kwargs)
                The loss function to get pair-wise losses from.
            est_targets: torch.Tensor. Expected shape $(batch, nsrc, ...)$.
                The batch of target estimates.
            targets: torch.Tensor. Expected shape $(batch, nsrc, ...)$.
                The batch of training targets.
            **kwargs: additional keyword argument that will be passed to the
                loss function.

        Returns:
            torch.Tensor of size $(batch, nsrc, nsrc)$, losses computed for
            all (estimate, target) pairs. Element $(b, i, j)$ is the loss
            between ``est_targets[b, i]`` and ``targets[b, j]``.

        This function can be called on a loss function which returns a tensor
        of size :math:`(batch)`. There are more efficient ways to compute pair-wise
        losses using broadcasting.
        """
        batch_size, n_src, *_ = targets.shape
        pair_wise_losses = targets.new_empty(batch_size, n_src, n_src)
        # transpose(0, 1) puts the source axis first so that iterating yields
        # one (batch, ...) tensor per source.
        for est_idx, est_src in enumerate(est_targets.transpose(0, 1)):
            for target_idx, target_src in enumerate(targets.transpose(0, 1)):
                pair_wise_losses[:, est_idx, target_idx] = loss_func(est_src, target_src, **kwargs)
        return pair_wise_losses

    @staticmethod
    def best_perm_from_perm_avg_loss(loss_func, est_targets, targets, **kwargs):
        r"""Find best permutation from loss function with source axis.

        Args:
            loss_func: function with signature $(est_targets, targets, **kwargs)$
                The loss function batch losses from.
            est_targets: torch.Tensor. Expected shape $(batch, nsrc, *)$.
                The batch of target estimates.
            targets: torch.Tensor. Expected shape $(batch, nsrc, *)$.
                The batch of training targets.
            **kwargs: additional keyword argument that will be passed to the
                loss function.

        Returns:
            - :class:`torch.Tensor`:
                The loss corresponding to the best permutation of size $(batch,)$.

            - :class:`torch.Tensor`:
                The indices of the best permutations.
        """
        n_src = targets.shape[1]
        # Build the permutation table on the estimates' device so that the
        # returned indices can be fed to `reorder_source` without a
        # CPU/accelerator mismatch.
        perms = torch.tensor(
            list(permutations(range(n_src))), dtype=torch.long, device=est_targets.device
        )
        loss_set = torch.stack(
            [loss_func(est_targets[:, perm], targets, **kwargs) for perm in perms], dim=1
        )
        # Indexes and values of min losses for each batch element
        min_loss, min_loss_idx = torch.min(loss_set, dim=1)
        # Permutation indices for each batch.
        batch_indices = perms[min_loss_idx.to(perms.device)]
        return min_loss, batch_indices

    @staticmethod
    def find_best_perm(pair_wise_losses, perm_reduce=None, **kwargs):
        r"""Find the best permutation, given the pair-wise losses.

        Dispatch between factorial method if number of sources is small (<= 3)
        and hungarian method for more sources. If ``perm_reduce`` is not None,
        the factorial method is always used.

        Args:
            pair_wise_losses (:class:`torch.Tensor`):
                Tensor of shape :math:`(batch, n\_src, n\_src)`. Pairwise losses.
            perm_reduce (Callable): torch function to reduce permutation losses.
                Defaults to None (equivalent to mean). Signature of the func
                (pwl_set, **kwargs) : :math:`(B, n\_src!, n\_src) -> (B, n\_src!)`
            **kwargs: additional keyword argument that will be passed to the
                permutation reduce function.

        Returns:
            - :class:`torch.Tensor`:
              The loss corresponding to the best permutation of size $(batch,)$.

            - :class:`torch.Tensor`:
              The indices of the best permutations.
        """
        n_src = pair_wise_losses.shape[-1]
        if perm_reduce is not None or n_src <= 3:
            return PITLossWrapper.find_best_perm_factorial(
                pair_wise_losses, perm_reduce=perm_reduce, **kwargs
            )
        return PITLossWrapper.find_best_perm_hungarian(pair_wise_losses)

    @staticmethod
    def reorder_source(source, batch_indices):
        r"""Reorder sources according to the best permutation.

        Args:
            source (torch.Tensor): Tensor of shape :math:`(batch, n_src, time)`
            batch_indices (torch.Tensor): Tensor of shape :math:`(batch, n_src)`.
                Contains optimal permutation indices for each batch.

        Returns:
            :class:`torch.Tensor`: Reordered sources.
        """
        # index_select requires the index tensor on the same device as the data.
        batch_indices = batch_indices.to(source.device)
        reordered_sources = torch.stack(
            [torch.index_select(s, 0, b) for s, b in zip(source, batch_indices)]
        )
        return reordered_sources

    @staticmethod
    def find_best_perm_factorial(pair_wise_losses, perm_reduce=None, **kwargs):
        r"""Find the best permutation given the pair-wise losses by looping
        through all the permutations.

        Args:
            pair_wise_losses (:class:`torch.Tensor`):
                Tensor of shape :math:`(batch, n_src, n_src)`. Pairwise losses.
            perm_reduce (Callable): torch function to reduce permutation losses.
                Defaults to None (equivalent to mean). Signature of the func
                (pwl_set, **kwargs) : :math:`(B, n\_src!, n\_src) -> (B, n\_src!)`
            **kwargs: additional keyword argument that will be passed to the
                permutation reduce function.

        Returns:
            - :class:`torch.Tensor`:
              The loss corresponding to the best permutation of size $(batch,)$.

            - :class:`torch.Tensor`:
              The indices of the best permutations.

        MIT Copyright (c) 2018 Kaituo XU.
        See `Original code
        <https://github.com/kaituoxu/Conv-TasNet/blob/master>`__ and `License
        <https://github.com/kaituoxu/Conv-TasNet/blob/master/LICENSE>`__.
        """
        n_src = pair_wise_losses.shape[-1]
        # After transposition, dim 1 corresp. to sources and dim 2 to estimates
        pwl = pair_wise_losses.transpose(-1, -2)
        # [n_src!, n_src] table of every permutation, on the losses' device.
        perms = pwl.new_tensor(list(permutations(range(n_src))), dtype=torch.long)
        # Loss mean of each permutation
        if perm_reduce is None:
            # one-hot, [n_src!, n_src, n_src]
            idx = torch.unsqueeze(perms, 2)
            perms_one_hot = pwl.new_zeros((*perms.size(), n_src)).scatter_(2, idx, 1)
            # Sum the selected entries of each permutation, then average.
            loss_set = torch.einsum("bij,pij->bp", pwl, perms_one_hot) / n_src
        else:
            # [batch, n_src!, n_src] : Pairwise losses for each permutation.
            pwl_set = pwl[:, torch.arange(n_src, device=pwl.device), perms]
            # Apply reduce [batch, n_src!, n_src] --> [batch, n_src!]
            loss_set = perm_reduce(pwl_set, **kwargs)
        # Indexes and values of min losses for each batch element
        min_loss, min_loss_idx = torch.min(loss_set, dim=1)

        # Permutation indices for each batch.
        batch_indices = perms[min_loss_idx]
        return min_loss, batch_indices

    @staticmethod
    def find_best_perm_hungarian(pair_wise_losses: torch.Tensor):
        """
        Find the best permutation given the pair-wise losses, using the Hungarian algorithm.

        Args:
            pair_wise_losses (:class:`torch.Tensor`):
                Tensor of shape (batch, n_src, n_src). Pairwise losses.

        Returns:
            - :class:`torch.Tensor`:
              The loss corresponding to the best permutation of size (batch,).

            - :class:`torch.Tensor`:
              The indices of the best permutations.
        """
        # After transposition, dim 1 corresp. to sources and dim 2 to estimates
        pwl = pair_wise_losses.transpose(-1, -2)
        # Just bring the numbers to cpu(), not the graph
        cost_matrices = pwl.detach().cpu().numpy()
        # Loop over batch + row indices are always ordered for square matrices,
        # so only the column assignment is needed. Stack into one ndarray
        # before converting: building a tensor from a list of arrays is slow.
        assignments = np.stack([linear_sum_assignment(cost)[1] for cost in cost_matrices])
        batch_indices = torch.as_tensor(assignments, dtype=torch.long, device=pwl.device)
        # Gather from the graph-attached tensor so gradients still flow.
        min_loss = torch.gather(pwl, 2, batch_indices[..., None]).mean([-1, -2])
        return min_loss, batch_indices


class PITReorder(PITLossWrapper):
    """Permutation invariant reorderer.

    Runs the parent's permutation search and discards the loss, returning
    only the reordered estimates.
    See :py:class:`asteroid.losses.PITLossWrapper`.
    """

    def forward(self, est_targets, targets, reduce_kwargs=None, **kwargs):
        """Return ``est_targets`` reordered along the source axis.

        Args:
            est_targets: estimates, forwarded to the parent ``forward``.
            targets: references, forwarded to the parent ``forward``.
            reduce_kwargs (dict or None): forwarded to the parent ``forward``.
            **kwargs: forwarded to the parent ``forward``.

        Returns:
            The second element of the parent's ``(loss, reordered)`` result.
        """
        loss_and_estimates = super().forward(
            est_targets,
            targets,
            return_est=True,
            reduce_kwargs=reduce_kwargs,
            **kwargs,
        )
        # The parent returns (mean_loss, reordered) when return_est is True.
        return loss_and_estimates[1]


class LambdaOverlapAdd(torch.nn.Module):
    """Chunked processing with lambda transform on segments (not scriptable).

    Segment the input signal into consecutive windows of ``window_size``
    samples, give each window ``in_margin`` samples of surrounding signal as
    extra context, apply a function (a neural network for example) to every
    padded window, drop the context part of the output and put the windows
    back together.

    The hop size is fixed to ``window_size`` (consecutive windows do not
    overlap); context is provided through ``in_margin`` instead.

    The class exposes ``sample_rate`` and ``_separate`` so that it can act as
    a drop-in for the wrapped network in separation utilities that rely on
    those members.

    Args:
        nnet (callable): Function to apply to each segment. It receives a
            tensor of shape (batch, channels, time) and must return a mapping
            whose entry ``key`` (see :meth:`forward`) is a tensor of shape
            (batch, n_src, time). The trimming logic assumes that the output
            has the same time length as the input.
        n_src (Optional[int]): Number of sources in the output of nnet.
            If None, the number of sources is determined by the network's output,
            but some correctness checks cannot be performed.
        window_size (int): Size of segmenting window. Must be even.
        in_margin (int): Number of context samples added on each side of a
            window before calling ``nnet``; the matching output samples are
            discarded. ``0`` disables the context.
        window (str): Name of the window (see scipy.signal.get_window) used
            for the synthesis. A falsy value (``None`` or ``""``) disables
            windowing. ``"hanning"`` is treated as ``"hann"``.
        reorder_chunks (bool): Whether to reorder each consecutive segment.
            This might be useful when `nnet` is permutation invariant, as
            source assignements might change output channel from one segment
            to the next (in classic speech separation for example).
            Reordering is performed based on the correlation between
            the overlapped part of consecutive segment, so it is skipped
            whenever consecutive segments do not overlap
            (``hop_size >= window_size``).
        enable_grad (bool): Whether :meth:`forward` runs with autograd enabled.
        device: Unused; kept for backward compatibility of the signature.

     Examples
        >>> import torch
        >>> nnet = lambda chunk: {"wav": chunk}  # any callable returning a mapping
        >>> continuous_nnet = LambdaOverlapAdd(
        ...     nnet=nnet,
        ...     n_src=1,
        ...     window_size=64000,
        ...     in_margin=8000,
        ...     window=None,
        ...     reorder_chunks=False,
        ...     enable_grad=False,
        ... )
        >>> # Process wav tensor:
        >>> wav = torch.randn(1, 1, 500000)
        >>> out_wavs = continuous_nnet.forward(wav)
    """

    def __init__(
        self,
        nnet,
        n_src,
        window_size,
        in_margin,
        window="hanning",
        reorder_chunks=True,
        enable_grad=False,
        device = torch.device("cpu")
    ):
        super().__init__()
        assert window_size % 2 == 0, "Window size must be even"

        self.nnet = nnet
        self.window_size = window_size
        # No overlap between consecutive windows: the hop equals the window.
        self.hop_size = window_size
        self.n_src = n_src
        self.in_channels = getattr(nnet, "in_channels", None)
        self.in_margin = in_margin
        if window:
            from scipy.signal import get_window  # for torch.hub

            # "hann" is the canonical SciPy name; mapping the "hanning" alias
            # onto it keeps the default usable regardless of whether the
            # installed SciPy still accepts the alias.
            window_name = "hann" if window == "hanning" else window
            window = get_window(window_name, self.window_size).astype("float32")
            window = torch.from_numpy(window)
            # Match dtype/device of the network's STFT weights when the
            # network exposes them; otherwise keep the float32 CPU window.
            reference = self._reference_weight(nnet)
            if reference is not None:
                window = window.type_as(reference)
            self.use_window = True
        else:
            # No synthesis window: register an empty buffer so that
            # `self.window` exists without calling tensor methods on None.
            window = None
            self.use_window = False

        self.register_buffer("window", window)
        self.reorder_chunks = reorder_chunks
        self.enable_grad = enable_grad

    @staticmethod
    def _reference_weight(nnet):
        """Return ``nnet.f_helper.stft.conv_real.weight`` if present, else None."""
        try:
            return nnet.f_helper.stft.conv_real.weight
        except AttributeError:
            return None

    def ola_forward(self, x, key='wav'):
        """Heart of the class: segment signal, apply func, recombine.

        Args:
            x (:class:`torch.Tensor`): signal of shape (batch, channels, samples).
            key: key used to pick the output tensor from what ``nnet`` returns.

        Returns:
            :class:`torch.Tensor`: tensor of shape (batch, n_src, samples).
        """
        assert x.ndim == 3

        batch, channels, n_frames = x.size()
        win = self.window_size
        margin = self.in_margin

        # Pad the signal on the right to a whole number of windows.
        last_frame_samples = n_frames % win
        if last_frame_samples != 0:
            x = F.pad(x, (0, win - last_frame_samples))

        # Each chunk holds `margin` samples of left context followed by one
        # window; the symmetric zero padding provides the context of chunk 0.
        unfolded = F.unfold(
            x.unsqueeze(-1),
            kernel_size=(win + margin, 1),
            padding=(margin, 0),
            stride=(self.hop_size, 1),
        )
        n_chunks = unfolded.shape[-1]
        # unfold flattens (channels, kernel) into one axis: split channel out.
        unfolded = unfolded.reshape(batch, channels, win + margin, n_chunks)
        # Right context of a chunk = first `margin` window samples of the next
        # chunk (zeros for the last chunk).
        right_context = unfolded.new_zeros(batch, channels, margin, n_chunks)
        right_context[..., :-1] = unfolded[..., margin:2 * margin, 1:]
        unfolded = torch.cat([unfolded, right_context], dim=2)

        out = []
        last_idx = n_chunks - 1
        for frame_idx in range(n_chunks):  # for loop to spare memory
            chunk = unfolded[..., frame_idx]
            # Slice ends are written as explicit lengths rather than `:-margin`
            # so that a zero margin keeps everything instead of nothing.
            if frame_idx == 0:
                # First chunk: no real left context, feed window + right context.
                frame = self.nnet(chunk[..., margin:])[key]
                frame = frame[:, :, :max(frame.shape[-1] - margin, 0)]
            elif frame_idx == last_idx and last_frame_samples != 0:
                # Last, partially filled chunk: feed left context + real
                # samples only, then zero-pad the output to a full window.
                frame = self.nnet(chunk[..., :margin + last_frame_samples])[key]
                frame = frame[:, :, margin:]
                frame = F.pad(frame, (0, win - last_frame_samples))
            elif frame_idx == last_idx:
                # Last, full chunk: no right context available.
                frame = self.nnet(chunk[..., :max(chunk.shape[-1] - margin, 0)])[key]
                frame = frame[:, :, margin:]
            else:
                # Interior chunk: context on both sides.
                frame = self.nnet(chunk)[key]
                frame = frame[:, :, margin:max(frame.shape[-1] - margin, 0)]

            # user must handle multichannel by reshaping to batch
            if frame_idx == 0:
                assert frame.ndim == 3, "nnet should return (batch, n_src, time)"
                if self.n_src is not None:
                    assert frame.shape[1] == self.n_src, "nnet should return (batch, n_src, time)"
                n_src = frame.shape[1]
            frame = frame.reshape(batch * n_src, -1)

            # Reordering correlates the overlapped part of consecutive
            # segments; without overlap there is nothing to correlate.
            has_overlap = self.hop_size < self.window_size
            if frame_idx != 0 and self.reorder_chunks and has_overlap:
                # we determine best perm based on xcorr with previous sources
                frame = _reorder_sources(frame, out[-1], n_src, self.window_size, self.hop_size)

            if self.use_window:
                frame = frame * self.window
            else:
                frame = frame / (self.window_size / self.hop_size)
            out.append(frame)

        out = torch.stack(out).reshape(n_chunks, batch * n_src, self.window_size)
        # (n_chunks, batch * n_src, win) -> (batch * n_src, win, n_chunks)
        out = out.permute(1, 2, 0)

        total_len = out.size()[-1] * out.size()[-2]
        # Number of sliding blocks `fold` expects for this output length.
        n_blocks = (total_len - self.window_size) // self.hop_size + 1
        out = out[..., :n_blocks]
        out = torch.nn.functional.fold(
            out,
            (total_len, 1),
            kernel_size=(self.window_size, 1),
            padding=(0, 0),
            stride=(self.hop_size, 1),
        )
        out = out.squeeze(-1).reshape(batch, n_src, -1)
        # Drop the right padding added to reach a whole number of windows.
        out = out[..., :n_frames]
        return out

    def forward(self, x, key="wav"):
        """Forward module: segment signal, apply func, recombine.

        Args:
            x (:class:`torch.Tensor`): waveform signal of shape (batch, channels, time).
            key: key used to pick the output tensor from what ``nnet`` returns.

        Returns:
            :class:`torch.Tensor`: The output of the lambda OLA.
        """
        with torch.autograd.set_grad_enabled(self.enable_grad):
            return self.ola_forward(x, key=key)

    # Implement `asteroid.separate.Separatable` (separation support)

    @property
    def sample_rate(self):
        """Sample rate reported by the wrapped ``nnet``."""
        return self.nnet.sample_rate

    def _separate(self, wav, *args, **kwargs):
        """Separation entry point; delegates to :meth:`forward`."""
        return self.forward(wav, *args, **kwargs)


def _reorder_sources(
    current: torch.FloatTensor,
    previous: torch.FloatTensor,
    n_src: int,
    window_size: int,
    hop_size: int,
):
    """
     Reorder sources in current chunk to maximize correlation with previous chunk.
     Used for Continuous Source Separation. Standard dsp correlation is used
     for reordering.


    Args:
        current (:class:`torch.Tensor`): current chunk, tensor
                                        of shape (batch, n_src, window_size)
        previous (:class:`torch.Tensor`): previous chunk, tensor
                                        of shape (batch, n_src, window_size)
        n_src (:class:`int`): number of sources.
        window_size (:class:`int`): window_size, equal to last dimension of
                                    both current and previous.
        hop_size (:class:`int`): hop_size between current and previous tensors.

    """
    batch, frames = current.size()
    current = current.reshape(-1, n_src, frames)
    previous = previous.reshape(-1, n_src, frames)

    overlap_f = window_size - hop_size

    def reorder_func(x, y):
        x = x[..., :overlap_f]
        y = y[..., -overlap_f:]
        # Mean normalization
        x = x - x.mean(-1, keepdim=True)
        y = y - y.mean(-1, keepdim=True)
        # Negative mean Correlation
        return -torch.sum(x.unsqueeze(1) * y.unsqueeze(2), dim=-1)

    # We maximize correlation-like between previous and current.
    pit = PITReorder(reorder_func)
    current = pit(current, previous)
    return current.reshape(batch, frames)


class DualPathProcessing(nn.Module):
    """
    Perform Dual-Path processing via overlap-add as in DPRNN [1].

    Args:
        chunk_size (int): Size of segmenting window.
        hop_size (int): segmentation hop size.

    References
        [1] Yi Luo, Zhuo Chen and Takuya Yoshioka. "Dual-path RNN: efficient
        long sequence modeling for time-domain single-channel speech separation"
        https://arxiv.org/abs/1910.06379
    """

    def __init__(self, chunk_size, hop_size):
        super().__init__()
        self.chunk_size = chunk_size
        self.hop_size = hop_size
        # Length of the last tensor passed to `unfold`; default for `fold`.
        self.n_orig_frames = None

    def unfold(self, x):
        r"""
        Unfold the feature tensor from $(batch, channels, time)$ to
        $(batch, channels, chunksize, nchunks)$.

        Args:
            x (:class:`torch.Tensor`): feature tensor of shape $(batch, channels, time)$.

        Returns:
            :class:`torch.Tensor`: spliced feature tensor of shape
            $(batch, channels, chunksize, nchunks)$.

        .. note:: `unfold` caches the original length of the input, which
            :meth:`fold` uses as its default output length.

        """
        # Check the rank before unpacking so the assertion is what fires.
        assert x.ndim == 3
        # x is (batch, chan, frames)
        batch, chan, frames = x.size()
        self.n_orig_frames = frames
        unfolded = torch.nn.functional.unfold(
            x.unsqueeze(-1),
            kernel_size=(self.chunk_size, 1),
            padding=(self.chunk_size, 0),
            stride=(self.hop_size, 1),
        )

        return unfolded.reshape(
            batch, chan, self.chunk_size, -1
        )  # (batch, chan, chunk_size, n_chunks)

    def fold(self, x, output_size=None):
        r"""
        Folds back the spliced feature tensor.
        Input shape $(batch, channels, chunksize, nchunks)$ to original shape
        $(batch, channels, time)$ using overlap-add.

        Args:
            x (:class:`torch.Tensor`): spliced feature tensor of shape
                $(batch, channels, chunksize, nchunks)$.
            output_size (int, optional): sequence length of original feature tensor.
                If None, the original length cached by the previous call of
                :meth:`unfold` will be used.

        Returns:
            :class:`torch.Tensor`:  feature tensor of shape $(batch, channels, time)$.

        Raises:
            ValueError: if ``output_size`` is None and :meth:`unfold` has not
                been called yet.

        """
        output_size = output_size if output_size is not None else self.n_orig_frames
        if output_size is None:
            raise ValueError("output_size must be given when unfold() has not been called.")
        # x is (batch, chan, chunk_size, n_chunks)
        batch, chan, chunk_size, n_chunks = x.size()
        to_unfold = x.reshape(batch, chan * self.chunk_size, n_chunks)
        x = torch.nn.functional.fold(
            to_unfold,
            (output_size, 1),
            kernel_size=(self.chunk_size, 1),
            padding=(self.chunk_size, 0),
            stride=(self.hop_size, 1),
        )

        # Overlap-add sums chunk_size / hop_size chunks per sample: normalize.
        # force float div for torch jit
        x /= float(self.chunk_size) / self.hop_size

        # Reshape with the length actually folded to, not the cached one,
        # so an explicit `output_size` is honoured.
        return x.reshape(batch, chan, output_size)

    @staticmethod
    def intra_process(x, module):
        r"""Performs intra-chunk processing.

        Args:
            x (:class:`torch.Tensor`): spliced feature tensor of shape
                (batch, channels, chunk_size, n_chunks).
            module (:class:`torch.nn.Module`): module one wish to apply to each chunk
                of the spliced feature tensor.

        Returns:
            :class:`torch.Tensor`: processed spliced feature tensor of shape
            $(batch, channels, chunksize, nchunks)$.

        .. note:: the module should have the channel first convention and accept
            a 3D tensor of shape $(batch, channels, time)$.
        """

        # x is (batch, channels, chunk_size, n_chunks)
        batch, channels, chunk_size, n_chunks = x.size()
        # we reshape to (batch * n_chunks, channels, chunk_size)
        x = x.transpose(1, -1).reshape(batch * n_chunks, chunk_size, channels).transpose(1, -1)
        x = module(x)
        # back to (batch, channels, chunk_size, n_chunks)
        x = x.reshape(batch, n_chunks, channels, chunk_size).transpose(1, -1).transpose(1, 2)
        return x

    @staticmethod
    def inter_process(x, module):
        r"""Performs inter-chunk processing.

        Args:
            x (:class:`torch.Tensor`): spliced feature tensor of shape
                $(batch, channels, chunksize, nchunks)$.
            module (:class:`torch.nn.Module`): module one wish to apply between
                each chunk of the spliced feature tensor.


        Returns:
            x (:class:`torch.Tensor`): processed spliced feature tensor of shape
            $(batch, channels, chunksize, nchunks)$.

        .. note:: the module should have the channel first convention and accept
            a 3D tensor of shape $(batch, channels, time)$.
        """

        batch, channels, chunk_size, n_chunks = x.size()
        # we reshape to (batch * chunk_size, channels, n_chunks)
        x = x.transpose(1, 2).reshape(batch * chunk_size, channels, n_chunks)
        x = module(x)
        # back to (batch, channels, chunk_size, n_chunks)
        x = x.reshape(batch, chunk_size, channels, n_chunks).transpose(1, 2)
        return x
