from functools import partial

import numpy as np
import torch

from gen_neg_toy.utils import expand_tensor_dims_as


@torch.no_grad()
def compute_elbo(
    batch, net, sde, guidance_fn=None, num_steps=18, model_kwargs=None
):
    """Estimate an ELBO-style quantity for ``batch`` via a Riemann sum over time.

    Time steps come from ``sde.get_sigma_steps`` (rounded by ``net.round_sigma``
    and mapped back to time with ``sde.sigma_inv``). For each consecutive pair
    of steps, one noisy sample ``(x0 + sigma(t) * eps) * scale(t)`` is drawn,
    denoised by the network, and the weighted squared denoising error plus a
    schedule-dependent term is accumulated with weight ``|t_next - t_cur|``
    (left endpoint). A prior term ``sum(x0 / sigma_max) ** 2 / 2`` is added at
    the end. The result is stochastic (fresh ``torch.randn_like`` noise per step).

    Args:
        batch: data tensor; the last dimension is the feature dimension that
            errors are summed over (presumably shape ``(batch, dim)`` — TODO
            confirm).
        net: denoiser exposing ``forward(x, sigma, **model_kwargs)``,
            ``round_sigma`` and ``sigma_min`` / ``sigma_max``.
        sde: schedule exposing ``sigma``, ``sigma_deriv``, ``scale``,
            ``scale_deriv``, ``sigma_inv``, ``get_sigma_steps`` and
            ``sigma_min`` / ``sigma_max``.
        guidance_fn: optional wrapper applied to ``net.forward`` before use.
        num_steps: number of noise levels requested from ``sde.get_sigma_steps``.
        model_kwargs: extra keyword arguments forwarded to every network call;
            ``None`` means none. If it contains ``"obs_mask"``, the squared
            error is weighted by ``1 - obs_mask`` (entries with mask 1 are
            excluded). Boolean masks are accepted.

    Returns:
        Python float: the batch mean of the per-sample estimate.

    Raises:
        AssertionError: if the SDE's sigma range is not inside the network's.
    """
    # Fresh dict per call instead of a shared mutable default argument.
    model_kwargs = {} if model_kwargs is None else model_kwargs

    net_forward = guidance_fn(net.forward) if guidance_fn is not None else net.forward
    net_forward = partial(net_forward, **model_kwargs)
    # The SDE's noise range must lie within what the network supports.
    assert sde.sigma_min >= net.sigma_min, f'Network supports sigma_min={net.sigma_min}, but SDE requires sigma_min={sde.sigma_min}'
    assert sde.sigma_max <= net.sigma_max, f'Network supports sigma_max={net.sigma_max}, but SDE requires sigma_max={sde.sigma_max}'

    x0 = batch
    dim = x0.shape[-1]

    # Loop-invariant per-entry weight on the squared error. Converting to
    # float64 lets boolean masks work (``1 - bool_tensor`` raises in torch)
    # and matches the float64 error computed below.
    err_weight = None
    if "obs_mask" in model_kwargs:
        obs_mask = torch.as_tensor(
            model_kwargs["obs_mask"], dtype=torch.float64, device=x0.device
        )
        err_weight = 1 - obs_mask

    # Noise levels from the SDE, rounded to levels the network supports,
    # then converted back to time.
    sigma_steps = sde.get_sigma_steps(num_steps, device=x0.device)
    t_steps = sde.sigma_inv(net.round_sigma(sigma_steps))
    # Reverse the order; presumably get_sigma_steps returns descending noise
    # levels, so t_steps becomes ascending — TODO confirm.
    t_steps = t_steps.flip([0])

    res = 0
    # Left-endpoint Riemann sum: the last entry of t_steps only serves as an
    # interval endpoint and is never evaluated.
    for t_cur, t_next in zip(t_steps[:-1], t_steps[1:]):
        h = t_next - t_cur
        sigma_t = sde.sigma(t_cur)
        scale_t = sde.scale(t_cur)
        dsigma_t = sde.sigma_deriv(t_cur)
        dscale_t = sde.scale_deriv(t_cur)

        # One noisy sample per data point at this time step.
        x_cur = (x0 + torch.randn_like(x0) * sigma_t) * scale_t
        denoised = net_forward(x_cur / scale_t, sigma_t * x_cur.new_ones(len(x_cur))).to(torch.float64)

        elbo_coef = dsigma_t / (sigma_t ** 3)
        err = (denoised - x0).square()
        if err_weight is not None:
            err *= err_weight
        term_a = err.sum(dim=-1) * elbo_coef
        # NOTE(review): only the scale term is multiplied by ``dim``; if
        # ``dim * (sigma'/sigma + scale'/scale)`` was intended this is an
        # operator-precedence slip. With obs_mask, ``dim`` still counts the
        # observed entries. Confirm before changing (alters the result).
        term_b = dsigma_t / sigma_t + (dscale_t / scale_t) * dim
        # NOTE(review): the weighted error is subtracted here while the prior
        # term below is added; a negative-log-likelihood bound would add both
        # (an ELBO would subtract both). Confirm the intended sign convention.
        res += (term_b - term_a) * h.abs()

    # Prior term at sigma_max; no Gaussian log-normalizer constant is included.
    item = ((x0 / sde.sigma_max) ** 2 / 2).sum(dim=-1)
    res += item

    return res.mean().item()