import torch


class Stride(object):

    def __init__(self, dict):
        """Stride class for continous convolution

        :param param: type of continuous convolution
        :type param: string
        """

        self._dict_stride = dict
        self._stride_continuous = None
        self._stride_discrete = self._create_stride_discrete(dict)

    def _create_stride_discrete(self, my_dict):
        """Creating the list for applying the filter

        :param my_dict: Dictionary with the following arguments:
            domain size, starting position of the filter, jump size
            for the filter and direction of the filter
        :type my_dict: dict
        :raises IndexError: Values in the dict must have all same length
        :raises ValueError: Domain values must be greater than 0
        :raises ValueError: Direction must be either equal to 1, -1 or 0
        :raises IndexError: Direction and jumps must have zero in the same
            index
        :return: list of positions for the filter
        :rtype: list
        :Example:


                >>> stride = {"domain": [4, 4],
                            "start": [-4, 2],
                            "jump": [2, 2],
                            "direction": [1, 1],
                            }
                >>> create_stride(stride)
                [[-4.0, 2.0], [-4.0, 4.0], [-2.0, 2.0], [-2.0, 4.0]]
        """

        # we must check boundaries of the input as well

        domain, start, jumps, direction = my_dict.values()

        # checking

        if not all([len(s) == len(domain) for s in my_dict.values()]):
            raise IndexError("values in the dict must have all same length")

        if not all(v >= 0 for v in domain):
            raise ValueError("domain values must be greater than 0")

        if not all(v == 1 or v == -1 or v == 0 for v in direction):
            raise ValueError("direction must be either equal to 1, -1 or 0")

        seq_jumps = [i for i, e in enumerate(jumps) if e == 0]
        seq_direction = [i for i, e in enumerate(direction) if e == 0]

        if seq_direction != seq_jumps:
            raise IndexError(
                "direction and jumps must have zero in the same index")

        if seq_jumps:
            for i in seq_jumps:
                jumps[i] = domain[i]
                direction[i] = 1

        # creating the stride grid
        values_mesh = [torch.arange(0, i, step).float()
                       for i, step in zip(domain, jumps)]

        values_mesh = [single * dim for single,
                       dim in zip(values_mesh, direction)]

        mesh = torch.meshgrid(values_mesh)
        coordinates_mesh = [x.reshape(-1, 1) for x in mesh]

        stride = torch.cat(coordinates_mesh, dim=1) + torch.tensor(start)

        return stride
