import torch
import numpy as np


class BatchTransform:
    """
    Base class for transforms that act on whole batches of trial tensors.

    Subclasses override :meth:`__call__`; ideally the implementation uses pytorch operations so that it can be
    integrated into the execution graph.
    """

    def __init__(self, only_trial_data=True):
        """
        Batch transforms are operations that are performed on trial tensors after being accumulated into batches via
        the :meth:`__call__` method.

        Parameters
        ----------
        only_trial_data: bool
                         Stored as `self.only_trial_data`. When False, :meth:`__call__` is expected to receive the
                         batches of all other loaded tensors in addition to the trial tensor batch.
        """
        self.only_trial_data = only_trial_data

    def __str__(self):
        # The transform's display name is simply its (sub)class name.
        return type(self).__name__

    def __call__(self, *x, training=False):
        """
        Modify a batch of tensors. Must be overridden by subclasses.

        Parameters
        ----------
        x : torch.Tensor, tuple
            A batch of trial instance tensors. If initialized with `only_trial_data=False`, this also includes the
            batches of all other loaded tensors.
        training: bool
                  Whether this is a training batch, allowing alternate behaviour during evaluation.

        Returns
        -------
        x : torch.Tensor, tuple
            The modified trial tensor batch, or all tensors if not `only_trial_data`.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError


class RandomTemporalCrop(BatchTransform):

    def __init__(self, max_crop_frac=0.25, temporal_axis=1):
        """
        Uniformly crops the time-dimension of a batch during training.

        Let ``T`` be the batch length along `temporal_axis`. A crop length is drawn uniformly from
        ``[max(1, int((1 - max_crop_frac) * T)), T - 1]``. A start offset is then drawn uniformly from every position
        at which a window of that length fits. The same window is applied to the whole batch.

        Parameters
        ----------
        max_crop_frac: float
                       The maximum fraction to crop off of the trial. Must satisfy ``0 < max_crop_frac < 1``.
        temporal_axis: int
                       The axis of the batch tensor along which to crop.
        """
        super(RandomTemporalCrop, self).__init__(only_trial_data=True)
        assert 0 < max_crop_frac < 1
        self.max_crop_frac = max_crop_frac
        self.temporal_axis = temporal_axis

    def __call__(self, x, training=False):
        """
        Crop `x` along `temporal_axis` when `training` is True.

        Returns `x` unchanged when not training, or when the trial is too short (fewer than 2 samples) to crop.
        """
        if not training:
            return x

        trial_len = x.shape[self.temporal_axis]
        # Never produce an empty crop, even for very short trials or large crop fractions.
        min_len = max(1, int((1 - self.max_crop_frac) * trial_len))
        if min_len >= trial_len:
            return x
        # randint's upper bound is exclusive, so at least one sample is always cropped off.
        crop_len = int(np.random.randint(min_len, trial_len))
        # +1 so the window is allowed to end on the final sample.
        offset = int(np.random.randint(0, trial_len - crop_len + 1))

        # Slice along the configured axis rather than assuming axis 1.
        index = [slice(None)] * len(x.shape)
        index[self.temporal_axis] = slice(offset, offset + crop_len)
        return x[tuple(index)]


class RandomTemporalEndCrop(BatchTransform):

    def __init__(self, end_crop_frac=0.25, crop_weights=None, temporal_axis=1):
        """
        Crops the end of the time dimension of an entire batch during training.

        Parameters
        ----------
        end_crop_frac: float
                       If this is specified (and `crop_weights` is not), a crop end is selected uniformly from the
                       last ``int(end_crop_frac * T)`` indices, where ``T`` is the batch length along `temporal_axis`.
                       Must satisfy ``0 <= end_crop_frac <= 1``.
        crop_weights: list, array-like
                      If specified, this should be a list of un-normalized weights used to weight the selection of the
                      last `len(crop_weights)` indices to crop to.
        temporal_axis: int
                       The axis of the batch tensor along which to crop.
        """
        super(RandomTemporalEndCrop, self).__init__(only_trial_data=True)
        self.end_crop_frac = end_crop_frac
        # np.array(None) is a 0-d object array (not None), so only convert weights that were actually supplied.
        self.crop_weights = None if crop_weights is None else np.array(crop_weights)
        self.temporal_axis = temporal_axis

    def __call__(self, x, training=False):
        """
        Truncate `x` along `temporal_axis` to a randomly chosen end index when `training` is True.

        The end index is drawn from the last ``len(weights)`` positions ``[T - len(weights), T - 1]``, with
        probabilities proportional to the weights. Returns `x` unchanged when not training or when there are no
        candidate positions.
        """
        if not training:
            return x

        trial_len = x.shape[self.temporal_axis]
        if self.crop_weights is None:
            assert 0 <= self.end_crop_frac <= 1
            # Computed per call rather than cached on self: batches may differ in length.
            weights = np.ones(int(trial_len * self.end_crop_frac))
        else:
            weights = self.crop_weights

        if len(weights) == 0:
            # No candidate crop positions (e.g. end_crop_frac * T < 1); leave the batch untouched.
            return x

        no_crop_len = trial_len - len(weights)
        assert no_crop_len >= 0
        candidates = np.arange(no_crop_len, trial_len)
        crop_location = int(np.random.choice(candidates, p=weights / weights.sum()))

        # Slice along the configured axis rather than assuming axis 1.
        index = [slice(None)] * len(x.shape)
        index[self.temporal_axis] = slice(None, crop_location)
        return x[tuple(index)]
