import torch
import numpy as np

from gen_neg_toy.evaluation.evidence import compute_elbo
from gen_neg_toy.sampling import edm_sampler
from gen_neg_toy.sde_lib import EDMSDE
from gen_neg_toy.utils import splitit


def _draw_samples_core(model, n_samples, device, **sampler_kwargs):
    """Draw a single batch of ``n_samples`` samples with the EDM sampler.

    Standard-normal latents of shape ``[n_samples, model.dim]`` are created on
    ``device``. They are passed to ``edm_sampler`` together with an (empty)
    ``model_kwargs`` dict. The sampler output is then handed to
    ``model.update_with_observations`` with that same dict unpacked as
    keyword arguments.

    Args:
        model: Model exposing ``dim`` and ``update_with_observations``.
        n_samples: Number of samples in this batch.
        device: Device on which the latents are allocated.
        **sampler_kwargs: Extra keyword arguments forwarded to ``edm_sampler``.

    Returns:
        ``(samples, nfe)``: the post-processed samples and the
        function-evaluation count reported by ``edm_sampler``.
    """
    noise = torch.randn([n_samples, model.dim], device=device)
    # One dict object is shared by the sampler and the post-processing step,
    # so anything the sampler puts into it is also seen by the model call.
    extra_kwargs = {}
    raw_samples, nfe = edm_sampler(
        model,
        latents=noise,
        model_kwargs=extra_kwargs,
        **sampler_kwargs,
    )
    return model.update_with_observations(raw_samples, **extra_kwargs), nfe


def draw_samples(model, n_samples, device, max_batch_size=20000, **sampler_kwargs):
    """Draw ``n_samples`` samples in batches of at most ``max_batch_size``.

    The requested count is partitioned by ``splitit``. Each partition is
    sampled with ``_draw_samples_core``, and the per-batch results are
    concatenated along the first dimension.

    Args:
        model: Model passed through to ``_draw_samples_core``.
        n_samples: Total number of samples requested.
        device: Device on which latents are allocated.
        max_batch_size: Upper bound on the size of each sampling batch.
        **sampler_kwargs: Extra keyword arguments forwarded to the sampler.

    Returns:
        ``(samples, nfe)``: the concatenated samples and the sum of the
        per-batch function-evaluation counts.
    """
    chunks = []
    total_nfe = 0
    for chunk_size in splitit(n_samples, max_batch_size):
        chunk, chunk_nfe = _draw_samples_core(
            model, chunk_size, device, **sampler_kwargs
        )
        chunks.append(chunk)
        total_nfe += chunk_nfe
    return torch.cat(chunks), total_nfe


def elbo(model, data, device, **compute_elbo_kwargs):
    """Compute the ELBO of one batch under the EDM SDE.

    Args:
        model: Model evaluated by ``compute_elbo``.
        data: A 2-element ``(x0, validity)`` pair. Only ``x0`` is used here.
            ``validity`` is unpacked but not forwarded (NOTE(review):
            confirm this is intentional).
        device: Device ``x0`` is moved to, after casting it to float32.
        **compute_elbo_kwargs: Extra keyword arguments for ``compute_elbo``.

    Returns:
        Whatever ``compute_elbo`` returns for this batch.
    """
    x0, _validity = data
    x0 = x0.float().to(device)
    return compute_elbo(
        x0,
        model,
        sde=EDMSDE(),
        model_kwargs={},
        **compute_elbo_kwargs,
    )


def elbo_from_dataloader(model, dataloader, device, **compute_elbo_kwargs):
    """Return the sample-weighted mean ELBO over all batches of ``dataloader``.

    Each batch's ELBO is weighted by that batch's size, so the result is a
    per-sample average even when the final batch is smaller. This assumes
    ``elbo`` returns a single scalar per batch (presumably the batch mean);
    TODO confirm against ``compute_elbo``.

    Args:
        model: Model evaluated by ``elbo``.
        dataloader: Iterable yielding ``(x0, validity)`` tensor pairs.
        device: Device each batch is moved to inside ``elbo``.
        **compute_elbo_kwargs: Extra keyword arguments for ``compute_elbo``.

    Returns:
        The weighted mean ELBO as a NumPy float64 scalar.

    Raises:
        ValueError: If ``dataloader`` yields no samples, since the mean is
            undefined. (Previously this silently returned 0.0.)
    """
    weighted_sum = 0.0
    cnt = 0
    for x0, validity in dataloader:
        # Prepare data
        x0 = x0.float()
        # Assumes validity labels are {0, 1}; this maps them to {-1, 1}.
        validity = validity.float() * 2 - 1

        batch_size = x0.shape[0]
        res = elbo(model, (x0, validity), device, **compute_elbo_kwargs)
        # float() copies the scalar back to the host right away. This works
        # for CUDA tensors, which np.array cannot convert, and it does not
        # keep every batch's tensor (and any autograd graph) alive.
        weighted_sum += float(res) * batch_size
        cnt += batch_size

    if cnt == 0:
        raise ValueError("dataloader yielded no samples; ELBO is undefined")
    return np.float64(weighted_sum / cnt)