from typing import List

import torch
from torch import nn, Tensor
from torchsde import SDEIto


class MlpSde(SDEIto):
    """Ito SDE with an MLP drift and a single learned scalar diffusion.

    The drift ``f`` is a fully connected network with ReLU activations
    between its linear layers; it ignores the time argument. The diffusion
    ``g`` is the same positive scalar ``sigma = exp(log_sigma)`` for every
    element of the state, returned with the shape of the state as required
    by the ``"diagonal"`` noise type.
    """

    def __init__(self, layers: List[int]):
        """Build the drift network and the diffusion parameter.

        Args:
            layers: Sizes of the successive layers of the drift network,
                input size first and output size last. Needs at least two
                entries, and the first and last must match because the
                drift has to have the same size as the state.

        Raises:
            ValueError: If ``layers`` has fewer than two entries or its
                first and last entries differ.
        """
        super().__init__(noise_type="diagonal")
        # Explicit raises rather than ``assert``: asserts are stripped when
        # Python runs with -O, which would let a bad configuration through.
        if len(layers) < 2:
            raise ValueError(
                "layers should contain at least an input and an output size"
            )
        if layers[-1] != layers[0]:
            raise ValueError("input and output sizes should be the same")

        # Stored in log space so that sigma stays strictly positive while the
        # parameter itself is unconstrained; starts at sigma = exp(0) = 1.
        self.log_sigma = nn.Parameter(torch.zeros(()))

        # Every layer except the last is followed by a ReLU; the final linear
        # layer is left without an activation so the drift is unbounded.
        hidden = [
            layer
            for li, lo in zip(layers[:-2], layers[1:-1])
            for layer in (nn.Linear(li, lo), nn.ReLU())
        ]
        self.mlp = nn.Sequential(*hidden, nn.Linear(layers[-2], layers[-1]))

    @property
    def sigma(self) -> Tensor:
        """Diffusion scale, ``exp(log_sigma)``, as a zero-dimensional tensor."""
        return self.log_sigma.exp()

    def f(self, t, x: Tensor) -> Tensor:
        """Drift term: the MLP applied to the state ``x``.

        Args:
            t: Current time; unused, the drift does not depend on time.
            x: State tensor whose last dimension is ``layers[0]``.

        Returns:
            The output of the MLP for ``x``.
        """
        return self.mlp(x)

    def g(self, t, x: Tensor) -> Tensor:
        """Diffusion term: ``sigma`` broadcast to the shape of ``x``.

        Args:
            t: Current time; unused.
            x: State tensor; only its shape, dtype and device are used.

        Returns:
            A tensor shaped like ``x`` with every element equal to ``sigma``.
        """
        return torch.ones_like(x) * self.sigma
