# Reproduced from https://github.com/montefiore-ai/hypothesis

import math
from typing import Callable

import torch.nn as nn
from torch import Tensor

from .lotka_volterra import LotkaVolterra


# Each observation is a flat tensor holding _SERIES_LENGTH time steps of
# _NB_SERIES interleaved values (see _Prepare.forward).
_SERIES_LENGTH = 1001
_NB_SERIES = 2
# CNN hyper-parameters, shared by the network and its reported output size.
_NB_CHANNELS = 16
_NB_CONV_LAYERS = 10
_SHRINK_EVERY = 2


def _pool_stride(layer: int) -> int:
    """Return the max-pool stride of conv block ``layer``.

    Every ``_SHRINK_EVERY``-th block (starting with the first) uses stride 2
    and halves the sequence; the other blocks use stride 1 and keep its length.
    """
    return 2 if layer % _SHRINK_EVERY == 0 else 1


def _output_length(length: int = _SERIES_LENGTH) -> int:
    """Return the sequence length left after all ``_NB_CONV_LAYERS`` conv blocks."""
    for layer in range(_NB_CONV_LAYERS):
        # Conv1d(kernel_size=3, padding=1) preserves the length, while
        # MaxPool1d(kernel_size=3, stride=s, padding=1) maps L to
        # floor((L - 1) / s) + 1.
        length = (length - 1) // _pool_stride(layer) + 1
    return length


class _Prepare(nn.Module):
    """Turn flat observations into channels-first series for ``nn.Conv1d``.

    Defined at module level (not inside a method) so that models containing
    it can be pickled, e.g. by ``torch.save(model)``.
    """

    def forward(self, x: Tensor) -> Tensor:
        """Reshape ``x`` to ``(batch, _NB_SERIES, _SERIES_LENGTH)``.

        The values are read as ``_SERIES_LENGTH`` time steps of ``_NB_SERIES``
        interleaved values. The series axis is then moved in front of the time
        axis, as ``nn.Conv1d`` expects. ``reshape`` (rather than ``view``) also
        accepts non-contiguous inputs.
        """
        x = x.reshape(-1, _SERIES_LENGTH, _NB_SERIES)
        return x.permute(0, 2, 1)


def _build_embedding(
    embedding_dim: int, observable_shape: tuple[int, ...]
) -> nn.Module:
    """Build the 1-D CNN embedding network.

    Args:
        embedding_dim: Accepted to match the embedding-builder signature; not
            used. The network always outputs
            ``_NB_CHANNELS * _output_length()`` features per sample.
        observable_shape: Accepted to match the builder signature; not used.

    Returns:
        An ``nn.Sequential`` with these stages:
        - reshape the input;
        - a 1x1 convolution lifting ``_NB_SERIES`` inputs to ``_NB_CHANNELS``
          channels;
        - ``_NB_CONV_LAYERS`` blocks of Conv1d/SELU/MaxPool1d;
        - a final flatten.
    """
    layers: list[nn.Module] = [
        _Prepare(),
        nn.Conv1d(in_channels=_NB_SERIES, out_channels=_NB_CHANNELS, kernel_size=1),
    ]
    for layer in range(_NB_CONV_LAYERS):
        layers += [
            nn.Conv1d(
                in_channels=_NB_CHANNELS,
                out_channels=_NB_CHANNELS,
                kernel_size=3,
                padding=1,
            ),
            nn.SELU(),
            nn.MaxPool1d(3, stride=_pool_stride(layer), padding=1),
        ]
    layers.append(nn.Flatten())
    return nn.Sequential(*layers)


class LotkaVolterraNoRescale(LotkaVolterra):
    """``LotkaVolterra`` variant that embeds observations with a 1-D CNN.

    NOTE(review): the name suggests observations are used without rescaling.
    That is decided outside this class; confirm in ``LotkaVolterra``.
    """

    def __init__(self, data_path, benchmark_name: str = "lotka_volterra_no_rescale"):
        super().__init__(data_path, benchmark_name)

    def get_embedding_dim(self) -> int:
        """Return the embedding size: channels times remaining length (16 * 32 = 512)."""
        # Derived from the same constants the network is built from, so the
        # reported size cannot drift from the actual output size.
        return _NB_CHANNELS * _output_length()

    def get_embedding_build(self) -> Callable:
        """Return the builder ``(embedding_dim, observable_shape) -> nn.Module``."""
        return _build_embedding