"""Compare gradients computed with the adjoint method against analytical solutions."""
import sys

# HACK: drop the first sys.path entry so that `import torchsde` below resolves to the
# installed package rather than a local source checkout. That entry is normally the
# script's directory or the current working directory; under pytest it may be the
# rootdir the runner inserted — confirm for your invocation. Must run before
# `import torchsde`.
sys.path = sys.path[1:]

import itertools
import unittest

import pytest
import torch

import torchsde
from .basic_sde import BasicSDE1, BasicSDE2, BasicSDE3, BasicSDE4
from .problems import Ex1, Ex2, Ex3
from .utils import assert_allclose

# Deterministic RNG and double precision for every test in this module.
torch.manual_seed(1147481649)
torch.set_default_dtype(torch.float64)
dtype = torch.get_default_dtype()
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Solver name -> SDE type it integrates; unpacked as ("method", "sde_type") pairs below.
ito_methods = dict.fromkeys(('euler', 'milstein', 'srk'), 'ito')
stratonovich_methods = dict.fromkeys(('midpoint', 'log_ode'), 'stratonovich')


@pytest.mark.parametrize("problem", [Ex1, Ex2, Ex3])
@pytest.mark.parametrize("method, sde_type", itertools.chain(ito_methods.items(), stratonovich_methods.items()))
@pytest.mark.parametrize("noise_type", ['diagonal', 'scalar', 'additive', 'general'])
@pytest.mark.parametrize('adaptive', (False, True))
def test_adjoint(problem, method, sde_type, noise_type, adaptive):
    """Check that parameter gradients from ``sdeint_adjoint`` match ``analytical_grad``.

    The parametrization grid is a full cross product. Combinations that are not
    exercised are reported as *skipped* (not silently passed), so the test summary
    reflects what actually ran.
    """
    if method == 'euler' and adaptive:
        pytest.skip("adaptive stepping is not exercised with euler")
    if problem is not Ex3 and noise_type == 'additive':
        pytest.skip("additive noise is only tested with Ex3")
    # TODO: remove these skips once we have adjoint implemented for other noise/sde combinations
    if sde_type == 'stratonovich' and noise_type != 'general':
        pytest.skip("adjoint not yet supported for stratonovich SDEs with non-general noise")
    if sde_type == 'ito' and noise_type == 'general':
        pytest.skip("adjoint not yet supported for ito SDEs with general noise")

    d = 1 if noise_type == 'scalar' else 10
    batch_size = 128
    t0, t1 = ts = torch.tensor([0.0, 0.5], device=device)
    dt = 1e-3
    y0 = torch.full((batch_size, d), 0.1, device=device)

    # `problem` is the class under test; `sde` is its instance (avoids shadowing).
    sde = problem(d, sde_type=sde_type, noise_type=noise_type).to(device)

    levy_area_approximation = {
        'euler': 'none',
        'milstein': 'none',
        'srk': 'space-time',
        'midpoint': 'none',
        'log_ode': 'foster'
    }[method]
    # A single Brownian motion is shared by the analytical and the adjoint computation,
    # so both gradients are taken along the same sample path.
    bm = torchsde.BrownianInterval(
        t0=t0, t1=t1, shape=(batch_size, d), dtype=dtype, device=device,
        levy_area_approximation=levy_area_approximation
    )
    with torch.no_grad():
        # Presumably ones here correspond to the sum-then-mean loss below — confirm
        # against the problem's analytical_grad implementation.
        grad_outputs = torch.ones(batch_size, d, device=device)
        alt_grad = sde.analytical_grad(y0, t1, grad_outputs, bm)

    sde.zero_grad()
    # `ts` has two points, so the solution unpacks into the states at t0 and t1.
    _, yt = torchsde.sdeint_adjoint(sde, y0, ts, bm=bm, method=method, dt=dt, adaptive=adaptive)
    loss = yt.sum(dim=1).mean(dim=0)
    loss.backward()
    adj_grad = torch.cat(tuple(p.grad for p in sde.parameters()))
    assert_allclose(alt_grad, adj_grad)


@pytest.mark.parametrize("problem", [BasicSDE1, BasicSDE2, BasicSDE3, BasicSDE4])
@pytest.mark.parametrize("method", ito_methods.keys())
@pytest.mark.parametrize('adaptive', (False, True))
def test_basic(problem, method, adaptive):
    """Check that a forward/backward pass through ``sdeint_adjoint`` keeps ``requires_grad`` flags.

    The number of parameters with ``requires_grad=True`` is counted before and after
    solving and backpropagating; the counts must match. This presumably guards against
    the adjoint machinery toggling the flag on parameters without restoring it.
    """
    if method == 'euler' and adaptive:
        pytest.skip("adaptive stepping is not exercised with euler")

    d = 10
    batch_size = 128
    ts = torch.tensor([0.0, 0.5], device=device)
    dt = 1e-3
    y0 = torch.full((batch_size, d), 0.1, device=device)

    # `problem` is the class under test; `sde` is its instance (avoids shadowing).
    sde = problem(d).to(device)

    num_before = _count_differentiable_params(sde)

    sde.zero_grad()
    _, yt = torchsde.sdeint_adjoint(sde, y0, ts, method=method, dt=dt, adaptive=adaptive)
    loss = yt.sum(dim=1).mean(dim=0)
    loss.backward()

    num_after = _count_differentiable_params(sde)
    assert num_before == num_after


def _count_differentiable_params(module):
    return len([p for p in module.parameters() if p.requires_grad])


if __name__ == '__main__':
    # The tests above are plain pytest functions (no unittest.TestCase subclasses),
    # so unittest.main() would collect nothing. Delegate to pytest and propagate its
    # exit code. Because of the relative imports above, run this as a module
    # (`python -m <package>.<module>`) rather than as a plain script.
    sys.exit(pytest.main([__file__]))
