import torch
from utils.pca import compute_u

def generate_samples(
    score_network: torch.nn.Module,
    nsamples: int,
    u,
    *,
    dim: int = 32 * 32,
    beta_min: float = 0.1,
    beta_max: float = 20,
) -> torch.Tensor:
    """Draw samples by integrating the reverse-time VP-SDE with Euler-Maruyama.

    Starts from standard Gaussian noise at t=1 and steps backwards to t=0 on a
    uniform grid of 1000 time points (999 steps). The noise schedule is linear,
    beta(t) = beta_min + (beta_max - beta_min) * t. The forward SDE has drift
    f(x, t) = -0.5 * beta(t) * x and diffusion g(t) = sqrt(beta(t)), so the
    reverse drift is f - g^2 * score.

    Args:
        score_network: Called as ``score_network(x_t, t_batch, u)``, where
            ``x_t`` has shape ``(nsamples, dim)`` and ``t_batch`` has shape
            ``(nsamples, 1)``. Its output is used as the score estimate and
            must broadcast against ``x_t``. It is evaluated with autograd
            disabled.
        nsamples: Number of samples to draw.
        u: Conditioning tensor passed through to ``score_network``. It is
            first moved to the network's device.
        dim: Flattened sample dimension (default 32 * 32).
        beta_min: Value of beta at t=0.
        beta_max: Value of beta at t=1.

    Returns:
        Tensor of shape ``(nsamples, dim)`` on the network's device.
    """
    device = next(score_network.parameters()).device
    u = u.to(device)

    def beta(t):
        return beta_min + (beta_max - beta_min) * t

    # The score is only used as a constant inside the update, so there is no
    # reason to record an autograd graph for each of the 999 network calls.
    with torch.no_grad():
        x_t = torch.randn((nsamples, dim), device=device)  # (nsamples, dim)
        time_pts = torch.linspace(1, 0, 1000, device=device)  # (ntime_pts,)
        for t, t_next in zip(time_pts[:-1], time_pts[1:]):
            dt = t_next - t  # negative: we integrate backwards in time

            # drift and diffusion of the reverse-time SDE
            fxt = -0.5 * beta(t) * x_t
            gt = beta(t) ** 0.5
            score = score_network(x_t, t.expand(x_t.shape[0], 1), u)
            drift = fxt - gt * gt * score

            # Euler-Maruyama step; the Brownian increment scales with sqrt(|dt|)
            x_t = x_t + drift * dt + gt * torch.randn_like(x_t) * torch.abs(dt) ** 0.5
    return x_t


def generate_samples_z(dim, score_network: torch.nn.Module, nsamples: int, u=None, beta_min=0.1, beta_max=20) -> torch.Tensor:
    """Sample ``dim``-dimensional vectors via the reverse-time VP-SDE.

    Integrates from t=1 (standard Gaussian noise) down to t=0 on a uniform
    grid of 1000 time points (999 Euler-Maruyama steps). The noise schedule is
    linear, beta(t) = beta_min + (beta_max - beta_min) * t. The forward SDE has
    drift f(x, t) = -0.5 * beta(t) * x and diffusion g(t) = sqrt(beta(t)), so
    the reverse drift is f - g^2 * score.

    Args:
        dim: Dimension of each sample.
        score_network: Called as ``score_network(x_t, t_batch, u)``, where
            ``x_t`` has shape ``(nsamples, dim)`` and ``t_batch`` has shape
            ``(nsamples, 1)``. It is evaluated with autograd disabled.
        nsamples: Number of samples to draw.
        u: Optional conditioning passed through to ``score_network``. When
            not None, it is moved to the network's device.
        beta_min: Value of beta at t=0.
        beta_max: Value of beta at t=1.

    Returns:
        Tensor of shape ``(nsamples, dim)`` on the network's device.
    """
    device = next(score_network.parameters()).device
    if u is not None:
        u = u.to(device)

    def beta(t):
        return beta_min + (beta_max - beta_min) * t

    # Scores are consumed as constants; avoid building 999 autograd graphs.
    with torch.no_grad():
        x_t = torch.randn((nsamples, dim), device=device)  # (nsamples, dim)
        time_pts = torch.linspace(1, 0, 1000, device=device)  # (ntime_pts,)
        for t, t_next in zip(time_pts[:-1], time_pts[1:]):
            dt = t_next - t  # negative: we integrate backwards in time

            # drift and diffusion of the reverse-time SDE
            fxt = -0.5 * beta(t) * x_t
            gt = beta(t) ** 0.5
            score = score_network(x_t, t.expand(x_t.shape[0], 1), u)
            drift = fxt - gt * gt * score

            # Euler-Maruyama step; the Brownian increment scales with sqrt(|dt|)
            x_t = x_t + drift * dt + gt * torch.randn_like(x_t) * torch.abs(dt) ** 0.5
    return x_t

def generate_training_samples(x: torch.Tensor, model: torch.nn.Module, nsamples: int) -> torch.Tensor:
    """Sample latents conditioned on ``x`` via the reverse-time VP-SDE and decode them.

    Steps:
    1. ``x`` is passed through ``model.vae``, and the fourth returned value is
       taken as the latent ``z``.
    2. ``z`` is reshaped to ``(model.num_examples, model.num_samples, -1)``
       and reduced by ``compute_u`` to the conditioning ``u`` for
       ``model.scorenet``.
    3. ``nsamples`` latents of size ``z.shape[-1]`` are drawn by integrating
       the reverse-time SDE from t=1 to t=0. This uses 1000 time points,
       Euler-Maruyama, and the linear beta schedule 0.1 -> 20.
    4. The latents are decoded with ``model.vae.decoder``.

    Args:
        x: Input batch for ``model.vae``. The latent it produces must be
            reshapeable to ``(num_examples, num_samples, -1)``.
        model: Must expose ``vae`` (returning a 4-tuple and having a
            ``decoder``), ``scorenet``, ``num_examples`` and ``num_samples``.
        nsamples: Number of latents to draw and decode.

    Returns:
        The decoder output for the ``(nsamples, dim)`` latents, placed on the
        model's device.
    """
    device = next(model.parameters()).device
    _, _, _, z = model.vae(x)
    z = z.reshape(model.num_examples, model.num_samples, -1)
    # NOTE(review): compute_u comes from utils.pca; presumably a PCA-based
    # summary of z used as conditioning — confirm against its definition.
    u = compute_u(z)
    dim = z.shape[-1]

    beta_min, beta_max = 0.1, 20

    # The SDE integration only feeds a detached latent to the decoder, so no
    # autograd graph is needed for the 999 score-network calls.
    with torch.no_grad():
        x_t = torch.randn((nsamples, dim), device=device)  # (nsamples, dim)
        time_pts = torch.linspace(1, 0, 1000, device=device)  # (ntime_pts,)
        for t, t_next in zip(time_pts[:-1], time_pts[1:]):
            dt = t_next - t  # negative: we integrate backwards in time
            beta_t = beta_min + (beta_max - beta_min) * t

            # drift and diffusion of the reverse-time SDE
            fxt = -0.5 * beta_t * x_t
            gt = beta_t ** 0.5
            score = model.scorenet(x_t, t.expand(x_t.shape[0], 1), u)
            drift = fxt - gt * gt * score

            # Euler-Maruyama step; the Brownian increment scales with sqrt(|dt|)
            x_t = x_t + drift * dt + gt * torch.randn_like(x_t) * torch.abs(dt) ** 0.5

    # The decoder runs with autograd enabled, so the returned samples still
    # carry gradients with respect to the decoder parameters.
    samples = model.vae.decoder(x_t)
    # Use the model's device rather than `.cuda()`: `.cuda()` fails on CPU-only
    # hosts and moves results off a non-default GPU.
    return samples.to(device)

def generate_testing_samples(
    model: torch.nn.Module,
    u: torch.Tensor,
    *,
    latent_shape: tuple = (16, 4, 4),
) -> torch.Tensor:
    """Draw one decoded sample per row of the conditioning tensor ``u``.

    Latents of size ``model.z_dim`` are produced by integrating the
    reverse-time SDE from t=1 to t=0. This uses 1000 time points,
    Euler-Maruyama, and the linear beta schedule 0.1 -> 20, with
    ``model.scorenet(x_t, t_batch, u)`` as the score. The latents are then
    reshaped to ``(nsamples, *latent_shape)`` and passed through
    ``model.vae.decoder``.

    Args:
        model: Must expose ``z_dim``, ``scorenet`` and ``vae.decoder``.
        u: Conditioning tensor. Its first dimension sets the number of
            samples. It is moved to the model's device.
        latent_shape: Per-sample shape the decoder expects. Its product must
            equal ``model.z_dim``; the default (16, 4, 4) needs z_dim == 256.

    Returns:
        The decoder output, placed on the model's device.
    """
    device = next(model.parameters()).device
    u = u.to(device)
    nsamples = u.shape[0]
    beta_min, beta_max = 0.1, 20

    # Scores are used as constants; skip autograd for the 999 network calls.
    with torch.no_grad():
        x_t = torch.randn((nsamples, model.z_dim), device=device)  # (nsamples, z_dim)
        time_pts = torch.linspace(1, 0, 1000, device=device)  # (ntime_pts,)
        for t, t_next in zip(time_pts[:-1], time_pts[1:]):
            dt = t_next - t  # negative: we integrate backwards in time
            beta_t = beta_min + (beta_max - beta_min) * t

            # drift and diffusion of the reverse-time SDE
            fxt = -0.5 * beta_t * x_t
            gt = beta_t ** 0.5
            score = model.scorenet(x_t, t.expand(x_t.shape[0], 1), u)
            drift = fxt - gt * gt * score

            # Euler-Maruyama step; the Brownian increment scales with sqrt(|dt|)
            x_t = x_t + drift * dt + gt * torch.randn_like(x_t) * torch.abs(dt) ** 0.5

    # (nsamples, z_dim) -> (nsamples, *latent_shape) for the decoder
    x_t = x_t.reshape(nsamples, *latent_shape)
    samples = model.vae.decoder(x_t)
    # Use the model's device rather than `.cuda()`: `.cuda()` fails on CPU-only
    # hosts and moves results off a non-default GPU.
    return samples.to(device)