import numpy as np
import torch as th
from functools import partial

import torchsde

# pylint: disable=unused-argument


def loss_wrapper(cfg):
    """Build the training-loss closure for an SDE-driven model.

    The returned closure integrates ``trainer.model`` with torchsde from an
    all-zero initial state over the time grid ``[0, cfg.model.t_end]`` and
    combines three terms evaluated on the terminal state.

    Args:
        cfg: Config object. Reads ``cfg.model.dt``, ``cfg.model.t_end``,
            ``cfg.model.sde.sde_type`` (``"ito"`` or ``"stratonovich"``)
            and ``cfg.data.shape``.

    Returns:
        A function ``_loss_fn(trainer, feed_dict, is_train)`` returning
        ``(loss, {}, {"reg_loss": ..., "disc_loss": ..., "quad_loss": ...})``.

    Raises:
        NotImplementedError: if ``cfg.model.sde.sde_type`` is neither
            ``"ito"`` nor ``"stratonovich"``. This is checked before any
            tensor is placed on the GPU.
    """
    dt, t_end = cfg.model.dt, cfg.model.t_end
    dim = np.prod(cfg.data.shape)
    sde_type = cfg.model.sde.sde_type

    # Select the integrator before touching CUDA so a config typo fails fast
    # with a clear message instead of a GPU-initialisation error.
    if sde_type == "ito":
        sdeint = partial(torchsde.sdeint, method="euler")
    elif sde_type == "stratonovich":
        # Adjoint solve paired with the reversible Heun scheme and its
        # matching adjoint method.
        sdeint = partial(torchsde.sdeint_adjoint, method="reversible_heun",
            adjoint_method="adjoint_reversible_heun")
    else:
        raise NotImplementedError(
            f"Unsupported sde_type {sde_type!r}; "
            "expected 'ito' or 'stratonovich'."
        )

    # Integration time grid: start and end points only.
    ts = th.tensor([0.0, t_end]).cuda()  # pylint: disable=not-callable
    # log normaliser of N(0, t_end * I) in `dim` dimensions:
    # (dim / 2) * log(2 * pi * t_end).
    normal_const = dim / 2 * np.log(2 * np.pi) + 0.5 * dim * np.log(t_end)

    def _loss_fn(trainer, feed_dict, is_train):
        """Integrate the model SDE and return the loss and its components.

        Args:
            trainer: Reads ``trainer.model`` (must expose ``ndim`` and
                ``nreg`` and be accepted by torchsde as the SDE),
                ``trainer.train_set.get_disc`` and ``trainer.bm``.
            feed_dict: Only ``feed_dict.shape[0]`` is used, as the batch
                size (presumably a batch tensor -- confirm with caller).
            is_train: Unused; kept for interface compatibility.

        Returns:
            ``(loss, {}, metrics)`` where ``metrics`` maps ``"reg_loss"``,
            ``"disc_loss"`` and ``"quad_loss"`` to scalar tensors.
        """
        # pylint: disable=invalid-name
        model = trainer.model
        disc_fn = trainer.train_set.get_disc
        batch_size = feed_dict.shape[0]
        # Allocate directly on the device of the time grid so the initial
        # state and `ts` are guaranteed to live on the same device.
        y0 = th.zeros((batch_size, model.ndim + model.nreg), device=ts.device)
        ys = sdeint(model, y0, ts, dt=dt, bm=trainer.bm)
        y1 = ys[-1]  # terminal state, shape (batch, ndim + nreg)

        # TODO: we assume only one reg loss for now, stored in the last column.
        reg_loss = y1[:, -1].mean()
        # Remaining columns are the state; NaN/inf are replaced by
        # torch.nan_to_num before evaluating the terminal terms.
        state = th.nan_to_num(y1[:, :-1])
        disc_loss = disc_fn(state).mean()
        # Mean log-density of N(0, t_end * I) evaluated at the terminal state.
        quad_loss = -0.5 * state.pow(2).sum(dim=-1).mean() / t_end - normal_const
        loss = reg_loss + disc_loss + quad_loss
        return (
            loss,
            {},
            {
                "reg_loss": reg_loss,
                "disc_loss": disc_loss,
                "quad_loss": quad_loss,
            },
        )

    return _loss_fn
