from pina.model.layers import ContinuousConvBlock
import torch


def prod(iterable):
    """Return the product of the elements of ``iterable``.

    Delegates to :func:`math.prod`; an empty iterable yields ``1``.

    :param iterable: Iterable of numbers, e.g. a ``torch.Size``.
    :return: Product of all elements.
    """
    import math  # local import keeps this test helper self-contained

    return math.prod(iterable)


def make_grid(x):
    """Convert a batch of images into per-channel point clouds.

    Every spatial location of ``x`` becomes one point. The first columns of a
    point hold its integer grid indices along each spatial axis, in ``'ij'``
    order with the last axis varying fastest. The final column holds the
    pixel value of that channel.

    :param torch.Tensor x: Tensor of shape ``(batch, channels, d_1, ..., d_k)``.
    :return: Tensor of shape ``(batch, channels, d_1 * ... * d_k, k + 1)``
        with torch's default floating dtype, on the same device as ``x``.
    :rtype: torch.Tensor
    """
    batch_size, channels = x.shape[0], x.shape[1]
    spatial = x.shape[2:]

    # One index axis per spatial dimension. Passing ``indexing="ij"``
    # explicitly keeps meshgrid's historical default and avoids the
    # deprecation warning torch emits when the argument is omitted.
    axes = [torch.arange(size, device=x.device) for size in spatial]
    mesh = torch.meshgrid(*axes, indexing="ij")

    # (n_points, k) table of grid indices, shared by every image and channel.
    grid = torch.stack([axis.reshape(-1) for axis in mesh], dim=-1)
    n_points = grid.shape[0]

    # torch.empty uses the default dtype (as the old torch.zeros allocation
    # did), so indices and pixel values are cast to it on assignment.
    points = torch.empty(batch_size, channels, n_points, len(spatial) + 1,
                         device=x.device)
    points[..., :-1] = grid  # broadcast over batch and channel dimensions
    points[..., -1] = x.reshape(batch_size, channels, n_points)
    return points


class MLP(torch.nn.Module):
    """Small fully connected network used as the kernel model in these tests.

    Maps 2 input features to 1 output through two hidden layers of width 8,
    each followed by a ReLU.
    """

    def __init__(self) -> None:
        super().__init__()
        width = 8
        layers = [
            torch.nn.Linear(2, width),
            torch.nn.ReLU(),
            torch.nn.Linear(width, width),
            torch.nn.ReLU(),
            torch.nn.Linear(width, 1),
        ]
        self.model = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network to ``x``; its last dimension must have size 2."""
        return self.model(x)


# INPUTS
channel_input = 2
channel_output = 6
batch = 2
N = 10  # size of the first spatial axis of the raw random input
dim = [3, 3]  # third positional argument of ContinuousConvBlock
stride = {"domain": [10, 10],
          "start": [0, 0],
          "jumps": [3, 3],
          "direction": [1, 1.]}
dim_filter = len(dim)
dim_input = (batch, channel_input, N, dim_filter)
dim_output = (batch, channel_output, 4, dim_filter)  # currently unused by the tests below
x = torch.rand(dim_input)
# make_grid reshapes x to (batch, channel_input, N * dim_filter, dim_filter + 1)
x = make_grid(x)


def test_constructor():
    """ContinuousConvBlock can be constructed with a kernel class and with ``model=None``."""
    positional = (channel_input, channel_output, dim, stride)

    # kernel model supplied as a class
    ContinuousConvBlock(*positional, model=MLP)

    # no kernel model supplied
    ContinuousConvBlock(*positional, model=None)


def test_forward():
    """Run a forward pass on ``x`` without and with ``optimize=True``."""
    positional = (channel_input, channel_output, dim, stride)

    # first: plain forward; second: forward with optimization enabled
    for extra_kwargs in ({}, {"optimize": True}):
        block = ContinuousConvBlock(*positional, model=MLP, **extra_kwargs)
        block(x)


def test_transpose():
    """Feed a forward output into ``transpose`` of a channel-mirrored block."""
    # forward block: channel_input -> channel_output
    forward_block = ContinuousConvBlock(channel_input, channel_output,
                                        dim, stride, model=MLP)
    # mirrored block: channel_output -> channel_input
    mirrored_block = ContinuousConvBlock(channel_output, channel_input,
                                         dim, stride, model=MLP)

    forward_out = forward_block(x)
    # keep only the last entry along the final axis of the forward output
    mirrored_block.transpose(forward_out[..., -1], x)

    # NOTE(review): the optimized, non-overlapping transpose case below is
    # disabled; the reason is not recorded in this file.
    # stride_no_overlap = {"domain": [10, 10],
    #                      "start": [0, 0],
    #                      "jumps": dim,
    #                      "direction": [1, 1.]}
    #
    # # simple transpose with optimization
    # conv = ContinuousConvBlock(channel_input,
    #                            channel_output,
    #                            dim,
    #                            stride_no_overlap,
    #                            model=MLP,
    #                            optimize=True,
    #                            no_overlap=True)
    #
    # integrals = conv(x)
    # conv.transpose(integrals[..., -1], x)