"""Test problems for SDE solvers, covering different noise types.

Ex1, Ex2, and Ex3 have analytical solutions and are taken from

Rackauckas, Christopher, and Qing Nie. "Adaptive methods for stochastic
differential equations via natural embeddings and rejection sampling with memory."
Discrete and Continuous Dynamical Systems - Series B 22.7 (2017): 2731.

Ex4 is constructed to test schemes for SDEs with general noise. Ex5 is a small
neural SDE (MLP drift and diffusion), also with general noise. Ex4 and Ex5 do
not provide analytical solutions.
"""

import torch
from torch import nn

from torchsde import BaseSDE
from torchsde.settings import NOISE_TYPES, SDE_TYPES


def _scalar(g):
    """Adapt an elementwise diffusion ``g`` to the scalar-noise layout.

    The returned callable has the same ``(t, y)`` signature as ``g``. It appends
    a trailing axis of size 1 to ``g(t, y)``, so every state coordinate is
    driven by the same single Brownian motion.
    """
    def with_noise_axis(t, y):
        # Indexing with None at the end is equivalent to unsqueeze(-1).
        return g(t, y)[..., None]

    return with_noise_axis


def _general(g):
    """Adapt an elementwise diffusion ``g`` to the general-noise layout.

    The returned callable has the same ``(t, y)`` signature as ``g``. It turns
    the last axis of ``g(t, y)`` into a diagonal matrix, so the output gains one
    trailing axis and coordinate ``k`` is driven only by Brownian motion ``k``.
    """
    def as_diagonal_matrix(t, y):
        # Tensor.diag_embed places the last axis on the diagonal of a new square matrix.
        return g(t, y).diag_embed()

    return as_diagonal_matrix


class Ex1(BaseSDE):
    """Geometric Brownian motion with per-coordinate drift and volatility.

    Ito form: ``dY = mu * Y dt + sigma * Y dW``, with exact solution
    ``Y_t = Y_0 * exp((mu - sigma**2 / 2) * t + sigma * W_t)``.

    Args:
        d: Number of state coordinates (length of ``mu`` and ``sigma``).
        sde_type: ``SDE_TYPES.ito`` selects the Ito drift; any other value
            selects the Stratonovich drift ``(mu - sigma**2 / 2) * Y``.
        noise_type: ``NOISE_TYPES.diagonal``, ``NOISE_TYPES.scalar`` or
            ``NOISE_TYPES.general``; the diffusion output is reshaped to match.
            Any other value raises ``KeyError``.
    """

    def __init__(self, d=10, sde_type=SDE_TYPES.ito, noise_type=NOISE_TYPES.diagonal):
        super().__init__(sde_type=sde_type, noise_type=noise_type)
        if sde_type == SDE_TYPES.ito:
            self.f = self.f_ito
        else:
            self.f = self.f_stratonovich
        # Wrap the elementwise diffusion so its output layout matches the noise type.
        diagonal_g = self.g
        noise_adapters = {
            NOISE_TYPES.diagonal: diagonal_g,
            NOISE_TYPES.scalar: _scalar(diagonal_g),
            NOISE_TYPES.general: _general(diagonal_g),
        }
        self.g = noise_adapters[noise_type]
        self._nfe = 0

        # Non-exploding initialization: sigma lies in (0, 1) and mu < -sigma**2,
        # so 2 * mu + sigma**2 < 0 and the second moment of Y_t decays over time.
        sigma = torch.randn(d).sigmoid()
        mu = -sigma.pow(2) - torch.randn(d).sigmoid()
        # Keep mu registered before sigma, matching the (dmu, dsigma) order of analytical_grad.
        self.mu = nn.Parameter(mu, requires_grad=True)
        self.sigma = nn.Parameter(sigma, requires_grad=True)

    def f_ito(self, t, y):
        """Ito drift ``mu * y``; increments the evaluation counter."""
        self._nfe += 1
        return self.mu * y

    def f_stratonovich(self, t, y):
        """Stratonovich drift ``mu * y - sigma**2 / 2 * y``; increments the evaluation counter."""
        self._nfe += 1
        drift = self.mu * y
        correction = .5 * (self.sigma ** 2) * y
        return drift - correction

    def g(self, t, y):
        """Elementwise diffusion ``sigma * y``; increments the evaluation counter."""
        self._nfe += 1
        return self.sigma * y

    def _closed_form(self, y0, t, w_t):
        # Exact solution at time t, given the Brownian value W_t.
        return y0 * torch.exp((self.mu - self.sigma ** 2. / 2.) * t + self.sigma * w_t)

    def analytical_grad(self, y0, t, grad_output, bm):
        """Return the exact gradient of the solution with respect to ``(mu, sigma)``.

        Forms ``grad_output * dY_t/dmu`` and ``grad_output * dY_t/dsigma``,
        averages each over dim 0 (presumably the batch dimension), and
        concatenates them as ``[dmu, dsigma]``.
        """
        with torch.no_grad():
            y_t = self._closed_form(y0, t, bm(t))
            weighted = grad_output * y_t
            # dY_t/dmu = Y_t * t
            dmu = (weighted * t).mean(0)
            # dY_t/dsigma = Y_t * (W_t - sigma * t)
            dsigma = (weighted * (-self.sigma * t + bm(t))).mean(0)
        return torch.cat((dmu, dsigma), dim=0)

    def analytical_sample(self, y0, ts, bm):
        """Return the exact solution at every time in ``ts``, stacked along dim 0.

        ``ts[0]`` must be 0; this is checked with ``assert``.
        """
        assert ts[0] == 0
        path = []
        with torch.no_grad():
            for t in ts:
                path.append(self._closed_form(y0, t, bm(t)))
        return torch.stack(path, dim=0)

    @property
    def nfe(self):
        """Number of drift/diffusion evaluations performed so far."""
        return self._nfe


class Ex2(BaseSDE):
    """SDE whose exact solution is ``Y_t = atan(p * W_t + tan(Y_0))``.

    Ito form: ``dY = -p**2 sin(Y) cos(Y)**3 dt + p cos(Y)**2 dW``. In
    Stratonovich form the drift is identically zero.

    Args:
        d: Number of state coordinates (length of ``p``).
        sde_type: ``SDE_TYPES.ito`` selects the Ito drift; any other value
            selects the (zero) Stratonovich drift.
        noise_type: ``NOISE_TYPES.diagonal``, ``NOISE_TYPES.scalar`` or
            ``NOISE_TYPES.general``; the diffusion output is reshaped to match.
            Any other value raises ``KeyError``.
    """

    def __init__(self, d=10, sde_type=SDE_TYPES.ito, noise_type=NOISE_TYPES.diagonal):
        super().__init__(sde_type=sde_type, noise_type=noise_type)
        if sde_type == SDE_TYPES.ito:
            self.f = self.f_ito
        else:
            self.f = self.f_stratonovich
        # Wrap the elementwise diffusion so its output layout matches the noise type.
        diagonal_g = self.g
        noise_adapters = {
            NOISE_TYPES.diagonal: diagonal_g,
            NOISE_TYPES.scalar: _scalar(diagonal_g),
            NOISE_TYPES.general: _general(diagonal_g),
        }
        self.g = noise_adapters[noise_type]
        self._nfe = 0
        self.p = nn.Parameter(torch.randn(d).sigmoid(), requires_grad=True)

    def f_ito(self, t, y):
        """Ito drift ``-p**2 sin(y) cos(y)**3``; increments the evaluation counter."""
        self._nfe += 1
        rate = -self.p ** 2.
        return rate * torch.sin(y) * torch.cos(y) ** 3.

    def f_stratonovich(self, t, y):
        """Stratonovich drift, which is zero; increments the evaluation counter."""
        self._nfe += 1
        return torch.zeros_like(y)

    def g(self, t, y):
        """Elementwise diffusion ``p * cos(y)**2``; increments the evaluation counter."""
        self._nfe += 1
        return self.p * y.cos().pow(2)

    def analytical_grad(self, y0, t, grad_output, bm):
        """Return the exact gradient of the solution with respect to ``p``.

        Uses ``dY_t/dp = W_t / (1 + (p * W_t + tan(Y_0))**2)``, weights it by
        ``grad_output``, and averages over dim 0 (presumably the batch dimension).
        """
        with torch.no_grad():
            w_t = bm(t)
            inner = self.p * w_t + torch.tan(y0)
            denominator = 1. + inner ** 2.
            dp = (grad_output * w_t / denominator).mean(0)
        return dp

    def analytical_sample(self, y0, ts, bm):
        """Return the exact solution at every time in ``ts``, stacked along dim 0.

        ``ts[0]`` must be 0; this is checked with ``assert``.
        """
        assert ts[0] == 0
        path = []
        with torch.no_grad():
            tan_y0 = torch.tan(y0)
            for t in ts:
                path.append(torch.atan(self.p * bm(t) + tan_y0))
        return torch.stack(path, dim=0)

    @property
    def nfe(self):
        """Number of drift/diffusion evaluations performed so far."""
        return self._nfe


# TODO: Make this a test problem for additive noise settings with decoupled m and d.
class Ex3(BaseSDE):
    """Linear SDE with state-independent noise and a closed-form solution.

    ``dY = (b / sqrt(1 + t) - Y / (2 + 2t)) dt + a * b / sqrt(1 + t) dW``, with
    ``Y_t = (Y_0 + b * (t + a * W_t)) / sqrt(1 + t)``. The diffusion does not
    depend on ``Y``, so the Ito and Stratonovich drifts coincide and
    ``sde_type`` does not change ``f``.

    Args:
        d: Number of state coordinates (length of ``a`` and ``b``).
        sde_type: Passed through to ``BaseSDE``.
        noise_type: ``NOISE_TYPES.diagonal``, ``NOISE_TYPES.scalar``,
            ``NOISE_TYPES.additive`` or ``NOISE_TYPES.general``. Additive and
            general both use a diagonal diffusion matrix. Any other value
            raises ``KeyError``.
    """

    def __init__(self, d=10, sde_type=SDE_TYPES.ito, noise_type=NOISE_TYPES.diagonal):
        super().__init__(sde_type=sde_type, noise_type=noise_type)
        # Wrap the elementwise diffusion so its output layout matches the noise type.
        diagonal_g = self.g
        noise_adapters = {
            NOISE_TYPES.diagonal: diagonal_g,
            NOISE_TYPES.scalar: _scalar(diagonal_g),
            NOISE_TYPES.additive: _general(diagonal_g),
            NOISE_TYPES.general: _general(diagonal_g),
        }
        self.g = noise_adapters[noise_type]
        self._nfe = 0
        # Initialization order (a, then b) fixes both RNG consumption and parameter order.
        self.a = nn.Parameter(torch.randn(d).sigmoid(), requires_grad=True)
        self.b = nn.Parameter(torch.randn(d).sigmoid(), requires_grad=True)

    def f(self, t, y):
        """Drift ``b / sqrt(1 + t) - y / (2 + 2t)``; increments the evaluation counter."""
        self._nfe += 1
        root = torch.sqrt(1. + t)
        return self.b / root - y / (2. + 2. * t)

    def g(self, t, y):
        """Diffusion ``a * b / sqrt(1 + t)``, broadcast to the shape of ``y``.

        Increments the evaluation counter.
        """
        self._nfe += 1
        noise_scale = self.a * self.b / torch.sqrt(1. + t)
        # Adding a zero tensor only broadcasts the result to y's shape.
        return noise_scale + torch.zeros_like(y)

    def analytical_grad(self, y0, t, grad_output, bm):
        """Return the exact gradient of the solution with respect to ``(a, b)``.

        Forms ``grad_output * dY_t/da`` and ``grad_output * dY_t/db``, averages
        each over dim 0 (presumably the batch dimension), and concatenates them
        as ``[da, db]``.
        """
        with torch.no_grad():
            w_t = bm(t)
            root = torch.sqrt(1. + t)
            # dY_t/da = b * W_t / sqrt(1 + t)
            da = (grad_output * self.b * w_t / root).mean(0)
            # dY_t/db = (t + a * W_t) / sqrt(1 + t)
            db = (grad_output * (t + self.a * w_t) / root).mean(0)
        return torch.cat((da, db), dim=0)

    def analytical_sample(self, y0, ts, bm):
        """Return the exact solution at every time in ``ts``, stacked along dim 0.

        ``ts[0]`` must be 0; this is checked with ``assert``.
        """
        assert ts[0] == 0
        path = []
        with torch.no_grad():
            for t in ts:
                root = torch.sqrt(1 + t)
                path.append(y0 / root + self.b * (t + self.a * bm(t)) / root)
        return torch.stack(path, dim=0)

    @property
    def nfe(self):
        """Number of drift/diffusion evaluations performed so far."""
        return self._nfe


def _column_wise_func(y, t, i):
    """Return the ``i``-th diffusion column used by ``Ex4``.

    The result is ``0.2 * cos(i * y + 0.1 * t) + 0.1 * cos(sum(y))``, where the
    sum runs over the last axis and is broadcast back.

    The shared ``sum(y)`` term makes every entry depend on every coordinate of
    ``y``, so the column has non-trivial mixed partial derivatives.
    """
    per_coordinate = torch.cos(y * i + t * 0.1) * 0.2
    coupling = y.sum(dim=-1, keepdim=True).cos() * 0.1
    return per_coordinate + coupling


class Ex4(BaseSDE):
    """Synthetic SDE with general noise: ``d`` state coordinates, ``m`` Brownian motions.

    The drift is ``sin(y) + t``. Column ``i`` of the diffusion matrix comes from
    ``_column_wise_func(y, t, i)``, which couples all state coordinates. No
    analytical solution is provided.

    Args:
        d: State dimension (stored as ``self.d``).
        m: Number of Brownian motions, i.e. diffusion columns.
        sde_type: Passed through to ``BaseSDE``.
    """

    def __init__(self, d, m, sde_type=SDE_TYPES.ito):
        super().__init__(sde_type=sde_type, noise_type=NOISE_TYPES.general)
        self.d = d
        self.m = m

    def f(self, t, y):
        """Drift ``sin(y) + t``."""
        return y.sin() + t

    def g(self, t, y):
        """Diffusion matrix: the ``m`` columns are stacked along a new last axis."""
        columns = [_column_wise_func(y, t, col) for col in range(self.m)]
        return torch.stack(columns, dim=-1)


class Ex5(BaseSDE):
    """Small neural SDE with general noise and MLP drift and diffusion.

    Both networks receive the current time prepended to the state. The diffusion
    network's sigmoid output is reshaped into a ``(d, m)`` matrix per row.

    Args:
        d: State dimension.
        m: Number of Brownian motions.
        sde_type: Passed through to ``BaseSDE``.
    """

    def __init__(self, d, m, sde_type=SDE_TYPES.ito):
        super().__init__(sde_type=sde_type, noise_type=NOISE_TYPES.general)
        self.d = d
        self.m = m

        width = 3
        # The drift network is built before the diffusion network. This order fixes
        # both the RNG draws used for initialization and the parameter order.
        self.f_net = nn.Sequential(
            nn.Linear(d + 1, width),
            nn.Softplus(),
            nn.Linear(width, d),
        )
        self.g_net = nn.Sequential(
            nn.Linear(d + 1, width),
            nn.Softplus(),
            nn.Linear(width, d * m),
            nn.Sigmoid(),
        )

    @staticmethod
    def _prepend_time(t, y):
        # Broadcast t to a single column shaped like y[:, :1] (assumes y is
        # (batch, d) and t is a tensor), then concatenate it in front of y.
        return torch.cat((t.expand_as(y[:, :1]), y), dim=1)

    def f(self, t, y):
        """Drift given by the MLP ``f_net`` applied to ``[t, y]``."""
        return self.f_net(self._prepend_time(t, y))

    def g(self, t, y):
        """Diffusion given by ``g_net`` applied to ``[t, y]``, reshaped to ``(-1, d, m)``."""
        flat = self.g_net(self._prepend_time(t, y))
        return flat.reshape(-1, self.d, self.m)
