from .base import Algo
import torch
import tqdm
import numpy as np
from .sampling_utils import get_pc_sampler
from models.SEDD import utils as mutils
from time import time
from torchvision.utils import make_grid, save_image

class SMC(Algo):
    '''
        Implementation of Sequential Monte Carlo for discrete diffusion.

        Runs ``num_particles`` reverse-diffusion chains per requested sample.
        After every reverse step, each particle is reweighted by the change in
        its tempered Monte-Carlo log-likelihood estimate. The particles belonging
        to each sample are then resampled multinomially.
    '''

    def __init__(self, net, forward_op, num_steps=200, ode_steps=1,
                 eps=1e-5, alpha=0, num_particles=20, mc=1, device='cuda'):
        '''
            Args:
                net: wrapper exposing ``model``, ``graph`` and ``noise``.
                forward_op: object whose ``log_likelihood(x0hat, y)`` scores
                    clean-sample estimates against an observation.
                num_steps: number of reverse-diffusion steps.
                ode_steps: step count forwarded to ``get_pc_sampler`` for the
                    unconditional Euler sampler used to estimate x0.
                eps: final time; the time grid runs linearly from 1 to eps.
                alpha: temperature; log-likelihoods are divided by it, so it
                    must be nonzero before ``inference`` is called.
                num_particles: particles per requested sample.
                mc: number of Monte-Carlo draws averaged per log-likelihood.
                device: torch device for the time grid and samples.
        '''
        super().__init__(net=net, forward_op=forward_op)
        self.model = self.net.model
        self.graph = self.net.graph
        self.noise = self.net.noise

        self.device = device
        self.num_steps = num_steps
        # num_steps + 1 grid points -> num_steps intervals of width dt.
        self.timesteps = torch.linspace(1, eps, num_steps + 1, device=device)
        self.dt = (1 - eps) / num_steps
        self.num_particles = num_particles
        self.alpha = alpha
        self.ode_steps = ode_steps
        self.mc = mc
        # NOTE(review): (1, 1024) is hard-coded; presumably the shape argument
        # expected by get_pc_sampler - confirm it should track model.length.
        self.uncond_sampler = get_pc_sampler(self.graph, self.noise, (1, 1024), 'euler', self.ode_steps, device=device)

    @torch.no_grad()
    def inference(self, observation=None, num_samples=1, verbose=True):
        '''
            Run the SMC sampler.

            Args:
                observation: indexable per sample; ``observation[b]`` is passed
                    to ``forward_op.log_likelihood``. Required.
                num_samples: number of independent samples to draw.
                verbose: show a tqdm progress bar over the reverse steps.

            Returns:
                Tensor of shape ``[num_samples, -1]``: for each sample, the
                first particle after the final resampling step.

            Raises:
                ValueError: if ``observation`` is None or ``alpha`` is zero.
        '''
        if observation is None:
            raise ValueError("SMC.inference requires an observation to weight particles.")
        if self.alpha == 0:
            raise ValueError("SMC requires a nonzero alpha (log-likelihoods are divided by it).")

        steps = tqdm.trange(self.num_steps) if verbose else range(self.num_steps)
        xt = self.graph.sample_limit(num_samples * self.num_particles, self.model.length).to(self.device)

        score_fn = mutils.get_score_fn(self.model, train=False, sampling=True)
        # Tempered log-likelihood of each particle at the current time.
        log_likelihood = torch.zeros([num_samples, self.num_particles], device=xt.device)
        # Row index for gathering resampled particles: batch_indices[b, p] == b.
        batch_indices = torch.arange(num_samples, device=xt.device).unsqueeze(1).expand(-1, self.num_particles)
        for i in steps:
            # 1. Propose: one unconditional reverse step for every particle.
            t = self.timesteps[i] * torch.ones(xt.shape[0], 1, device=xt.device)
            sigma, dsigma = self.noise(t)
            score = score_fn(xt, sigma)
            rev_rate = self.dt * dsigma[..., None] * self.graph.reverse_rate(xt, score)
            x_next = self.graph.sample_rate(xt, rev_rate).reshape(num_samples, self.num_particles, -1)

            # 2. Weight: new over previous tempered likelihood, in log space,
            #    normalized per sample with a numerically stable softmax.
            log_likelihood_new = self.compute_loglikelihood(x_next, self.forward_op, observation, self.timesteps[i + 1])
            log_likelihood_new = log_likelihood_new / self.alpha
            weights = torch.softmax(log_likelihood_new - log_likelihood, dim=1)

            # 3. Resample within each sample, carrying the likelihoods along.
            choices = torch.multinomial(weights, num_samples=self.num_particles, replacement=True)  # [num_samples, num_particles]
            xt = x_next[batch_indices, choices].flatten(0, 1)
            log_likelihood = log_likelihood_new[batch_indices, choices]
        return xt.view(num_samples, self.num_particles, -1)[:, 0]

    @torch.no_grad()
    def compute_loglikelihood(self, xt, op, y, t):
        '''
            Monte-Carlo estimate of the log-likelihood of each particle.

            For each of ``self.mc`` draws, ``self.uncond_sampler`` maps the
            particles at time ``t`` to clean-sample estimates ``x0hat``. Each
            estimate is scored with ``op.log_likelihood``, and the scores are
            averaged over the draws.

            Args:
                xt: particles, shape ``[batch_size, num_particles, ...]``.
                op: object providing ``log_likelihood(x0hat, y)``.
                y: observations, indexable by batch element.
                t: time at which ``xt`` lives.

            Returns:
                Per-batch averaged log-likelihoods stacked along dim 0. The
                expected shape is ``[batch_size, num_particles]``; this depends
                on what ``op.log_likelihood`` returns (TODO confirm).

            Raises:
                ValueError: if ``self.mc`` is less than 1.
        '''
        if self.mc < 1:
            raise ValueError(f"mc must be >= 1, got {self.mc}.")
        batch_size, num_particles = xt.shape[0], xt.shape[1]
        flat = xt.flatten(0, 1)
        estimate = None
        for _ in range(self.mc):
            x0hat = self.uncond_sampler(self.net, flat, t).view(batch_size, num_particles, -1)
            draw = torch.stack(
                [op.log_likelihood(x0hat[b], y[b]) / self.mc for b in range(batch_size)],
                dim=0,
            )
            estimate = draw if estimate is None else estimate + draw
        return estimate