from .base import Algo
import torch
import tqdm
import numpy as np
from .sampling_utils import get_pc_sampler

class DPS(Algo):
    """Posterior sampler that tilts the model's reverse rates by a likelihood ratio.

    At every reverse step, each current sequence's single-position substitutions
    ("neighbors") are denoised with a one-step sampler and scored with
    ``forward_op.log_likelihood`` against the observation. The resulting
    log-likelihood ratio, scaled by ``alpha``, multiplies the exponentiated
    ``net.score`` before the reverse rate is formed.
    """

    def __init__(self, net, forward_op, num_steps=200, eps=1e-5, alpha=1, device='cuda'):
        super().__init__(net=net, forward_op=forward_op)
        self.graph = self.net.graph
        self.noise = self.net.noise
        # NOTE(review): batch dims (1, 1024) are hard-coded; presumably they should
        # track the sequence length (e.g. self.net.length) — confirm.
        self.one_step_sampler = get_pc_sampler(self.graph, self.noise, (1, 1024), 'euler', 1, device=device)
        self.device = device
        self.num_steps = num_steps
        self.alpha = alpha
        # num_steps + 1 evenly spaced times from 1 down to eps; inference() visits
        # the first num_steps of them.
        self.timesteps = torch.linspace(1, eps, num_steps + 1, device=device)
        self.dt = (1 - eps) / num_steps

    def get_cond_score(self, x, op, y, t, batch_size=512):
        """Return the per-position, per-token log-likelihood ratio for ``x``.

        For every row ``x[n]``, every position ``i`` and every token ``d`` in
        ``range(self.graph.dim)``, the neighbor obtained by setting ``x[n, i] = d``
        is denoised with ``self.one_step_sampler`` at time ``t``. The denoised
        neighbor is then scored with ``op.log_likelihood``.

        Args:
            x: Token tensor of shape ``[N, L]``.
            op: Object exposing ``log_likelihood(x, y)``; assumed to return one
                value per row of ``x``.
            y: Observation. If ``y.dim() >= 2`` and ``y.shape[0] == N``, row ``n``
                of ``y`` is paired with the neighbors of ``x[n]``. Otherwise it is
                tiled with ``y.repeat(L * D, 1)``.
            t: Diffusion time forwarded to the one-step sampler.
            batch_size: Number of neighbors passed to the sampler per call.

        Returns:
            ``(log_ratio, x0hat)``. ``log_ratio`` has shape ``[N, L, D]``, and
            ``log_ratio[n, i, d]`` is
            ``op.log_likelihood(denoised neighbor (n, i, d)) - op.log_likelihood(x[n])``.
            ``x0hat`` is the sampler output for all neighbors, concatenated along
            dim 0 in ``(n, i, d)`` order.
        """
        N = x.shape[0]
        L = x.shape[-1]
        D = self.graph.dim
        # neighbors[n, i, d, j] = d if j == i else x[n, j]: all single-site
        # substitutions, built by broadcasting on x's own device.
        on_site = torch.eye(L, dtype=torch.bool, device=x.device).view(1, L, 1, L)
        tokens = torch.arange(D, device=x.device, dtype=x.dtype).view(1, 1, D, 1)
        neighbors = torch.where(on_site, tokens, x.view(N, 1, 1, L))
        neighbors = neighbors.reshape(N * L * D, L)

        x0hat = torch.cat(
            [
                self.one_step_sampler(self.net, neighbors[start:start + batch_size], t)
                for start in range(0, neighbors.shape[0], batch_size)
            ],
            dim=0,
        )

        # x0hat rows are ordered (n, i, d), so a batched observation must be
        # repeated per sample (repeat_interleave), not tiled (repeat), to stay aligned.
        if y.dim() >= 2 and y.shape[0] == N:
            y_rep = y.repeat_interleave(L * D, dim=0)
        else:
            y_rep = y.repeat(L * D, 1)

        # Rows are (n, i, d) with d innermost, so the natural shape is [N, L, D].
        log_ratio = (
            op.log_likelihood(x0hat, y_rep).view(N, L, D)
            - op.log_likelihood(x, y).view(N, 1, 1)
        )
        return log_ratio, x0hat

    def inference(self, observation=None, num_samples=1, verbose=True):
        """Draw ``num_samples`` sequences by likelihood-tilted reverse diffusion.

        Args:
            observation: Observation ``y`` passed to ``forward_op.log_likelihood``.
                NOTE(review): despite the default, ``None`` is not handled, because
                ``get_cond_score`` calls tensor methods on it.
            num_samples: Number of sequences sampled in parallel.
            verbose: Wrap the step loop in a tqdm progress bar.

        Returns:
            The state ``xt`` after ``self.num_steps`` reverse steps.
        """
        steps = tqdm.trange(self.num_steps) if verbose else range(self.num_steps)
        xt = self.graph.sample_limit(num_samples, self.net.length).to(self.device)
        for i in steps:
            t_i = self.timesteps[i]
            t = t_i * torch.ones(xt.shape[0], 1, device=self.device)
            sigma, dsigma = self.noise(t)
            # Likelihood ratio is computed before the model score, as in the original order.
            log_likelihood, _ = self.get_cond_score(xt, self.forward_op, observation, t_i)
            score = self.net.score(xt, sigma).exp()
            score = score * torch.exp(log_likelihood * self.alpha)
            rev_rate = self.dt * dsigma[..., None] * self.graph.reverse_rate(xt, score)
            xt = self.graph.sample_rate(xt, rev_rate)
        return xt

