from .base import Algo
import torch
import tqdm
import numpy as np
from .sampling_utils import get_pc_sampler
from models.SEDD import utils as mutils

class SVDD(Algo):
    '''
        Implementation of Derivative-Free Guidance with soft value-based decoding.
        https://arxiv.org/abs/2408.08252

        At every reverse step, ``num_particles`` candidate next states are drawn
        from the unconditional model for each sample. Each candidate is scored
        with ``forward_op.log_likelihood`` evaluated on an estimate of its clean
        sample, and one candidate per sample is kept: the argmax when
        ``alpha == 0``, otherwise a draw with probability proportional to
        ``exp(log_likelihood / alpha)``.
    '''

    def __init__(self, net, forward_op, num_steps=200, ode_steps=1,
                 eps=1e-5, alpha=0, num_particles=20, mc=1, device='cuda'):
        '''
            Args:
                net: wrapper exposing ``model``, ``graph`` and ``noise``.
                forward_op: object providing ``log_likelihood(x0hat, y)``.
                num_steps: number of reverse steps from t=1 down to t=eps.
                ode_steps: forwarded to ``get_pc_sampler`` for the inner
                    unconditional sampler (presumably its step count — confirm).
                eps: final time of the reverse schedule.
                alpha: selection temperature; 0 means greedy (argmax) selection.
                num_particles: candidates drawn per sample at each step.
                mc: number of inner-sampler rollouts whose log-likelihoods are
                    averaged per candidate; ``mc > 1`` uses ``sample_zeta_mc``.
                device: device for the time grid and the inner sampler.
        '''
        super().__init__(net=net, forward_op=forward_op)
        self.model = self.net.model
        self.graph = self.net.graph
        self.noise = self.net.noise

        self.device = device
        self.num_steps = num_steps
        self.timesteps = torch.linspace(1, eps, num_steps + 1, device=device)
        # Uniform spacing of self.timesteps.
        self.dt = (1 - eps) / num_steps
        self.num_particles = num_particles
        self.alpha = alpha
        self.ode_steps = ode_steps
        self.mc = mc
        # NOTE(review): the inner sampler's shape is hard-coded to (1, 1024);
        # presumably the 1024 must match self.model.length — confirm.
        self.uncond_sampler = get_pc_sampler(self.graph, self.noise, (1,1024), 'euler', self.ode_steps , device=device)

    @torch.no_grad()
    def inference(self, observation=None, num_samples=1, verbose=True):
        '''
            SVDD sampler. Behaves as SVDD-PM when ``self.mc <= 1`` (one inner
            rollout per candidate), and averages ``self.mc`` rollouts otherwise.

            Args:
                observation: per-sample observations, indexed as ``observation[i]``
                    by the selection step.
                num_samples: number of independent samples to generate.
                verbose: show a tqdm progress bar.

            Returns:
                Tensor of shape (num_samples, -1) holding the final states.
        '''
        steps = tqdm.trange(self.num_steps) if verbose else range(self.num_steps)
        xt = self.graph.sample_limit(num_samples, self.model.length).to(self.device)

        sampling_score_fn = mutils.get_score_fn(self.model, train=False, sampling=True)
        # self.mc is fixed for the whole run, so choose the selector once.
        select = self.sample_zeta_mc if self.mc > 1 else self.sample_zeta
        for i in steps:
            # 1. Replicate each sample num_particles times (particle-major order)
            #    and take one unconditional reverse step to time t_{i+1}.
            particles = torch.cat([xt] * self.num_particles, dim=0)
            t = self.timesteps[i] * torch.ones(particles.shape[0], 1, device=particles.device)
            sigma, dsigma = self.noise(t)
            score = sampling_score_fn(particles, sigma)
            rev_rate = self.dt * dsigma[..., None] * self.graph.reverse_rate(particles, score)
            # Undo the particle-major layout:
            # (num_particles, num_samples, L) -> (num_samples, num_particles, L).
            x_next = self.graph.sample_rate(particles, rev_rate).view(
                self.num_particles, num_samples, -1).permute(1, 0, 2)

            # 2. Score candidates and pick one particle index per sample.
            zeta = select(x_next, self.forward_op, observation, self.timesteps[i + 1])

            # 3. Keep the selected candidate for each sample; build the row
            #    index on the same device as the gathered tensor.
            rows = torch.arange(num_samples, device=x_next.device)
            xt = x_next[rows, zeta.to(x_next.device)]
        return xt

    def _estimate_x0(self, xt_flat, t, batch_size, num_particles):
        '''
            Run the inner unconditional sampler from time ``t`` on the flattened
            candidates and reshape the result to (batch_size, num_particles, -1).
        '''
        return self.uncond_sampler(self.net, xt_flat, t).view(batch_size, num_particles, -1)

    def _select_particle(self, log_likelihood):
        '''
            Pick one candidate index from per-candidate log-likelihoods: the
            argmax when ``alpha == 0``, otherwise a draw with probability
            proportional to ``exp(log_likelihood / alpha)``.
        '''
        if self.alpha == 0:
            return torch.argmax(log_likelihood)
        # Subtract the max before exponentiating for numerical stability;
        # torch.multinomial normalizes the weights itself.
        weights = torch.exp((log_likelihood - log_likelihood.max()) / self.alpha)
        return torch.multinomial(weights, 1).squeeze()

    @torch.no_grad()
    def sample_zeta(self, xt, op, y, t):
        '''
            Posterior mean estimation (SVDD-PM): score each candidate by the
            log-likelihood of a single inner-sampler estimate of its clean sample.

            Args:
                xt: candidates of shape (batch_size, num_particles, ...).
                op: forward operator providing ``log_likelihood(x0hat, y)``.
                y: observations, indexed per sample as ``y[i]``.
                t: time from which the inner sampler starts.

            Returns:
                Index tensor of shape (batch_size,) with the chosen particle per
                sample (assuming ``op.log_likelihood`` returns one value per
                particle — confirm).
        '''
        batch_size, num_particles = xt.shape[0], xt.shape[1]
        x0hat = self._estimate_x0(xt.flatten(0, 1), t, batch_size, num_particles)
        return torch.stack([
            self._select_particle(op.log_likelihood(x0hat[i], y[i]))
            for i in range(batch_size)
        ])

    @torch.no_grad()
    def sample_zeta_mc(self, xt, op, y, t):
        '''
            Monte Carlo estimation: run the inner sampler ``self.mc`` times and
            score each candidate by its mean log-likelihood over the rollouts.

            Args and return value are as for ``sample_zeta``.
        '''
        batch_size, num_particles = xt.shape[0], xt.shape[1]
        xt_flat = xt.flatten(0, 1)
        totals = [0] * batch_size
        for _ in range(self.mc):
            x0hat = self._estimate_x0(xt_flat, t, batch_size, num_particles)
            for j in range(batch_size):
                totals[j] += op.log_likelihood(x0hat[j], y[j])
        return torch.stack([
            self._select_particle(totals[i] / self.mc)
            for i in range(batch_size)
        ])

class SVDD_latent(SVDD):
    '''
        SVDD variant for a discrete diffusion model over latent tokens.

        Clean-sample estimates from the inner sampler are passed through
        ``self.net.decode`` before ``op.log_likelihood`` is evaluated. This
        applies to both the single-rollout and the Monte Carlo (``mc > 1``)
        selection paths. The final latent sample is also decoded before it is
        returned.
    '''

    def _estimate_x0(self, xt_flat, t, batch_size, num_particles):
        '''
            Run the inner sampler from time ``t``, decode the result, and reshape
            it to (batch_size, num_particles, *decoded_shape).
        '''
        x0hat = self.net.decode(self.uncond_sampler(self.net, xt_flat, t))
        return x0hat.view(batch_size, num_particles, *x0hat.shape[1:])

    def _select_particle(self, log_likelihood):
        '''
            Pick one candidate index: the argmax when ``alpha == 0``, otherwise a
            draw with probability proportional to ``exp(log_likelihood / alpha)``.
        '''
        if self.alpha == 0:
            return torch.argmax(log_likelihood)
        # Subtract the max before exponentiating for numerical stability.
        weights = torch.exp((log_likelihood - log_likelihood.max()) / self.alpha)
        return torch.multinomial(weights, 1).squeeze()

    @torch.no_grad()
    def sample_zeta(self, xt, op, y, t):
        '''
            Score each latent candidate by the log-likelihood of its decoded
            single-rollout clean-sample estimate, and pick one index per sample.

            Args:
                xt: candidates of shape (batch_size, num_particles, ...).
                op: forward operator providing ``log_likelihood(x0hat, y)``.
                y: observations, indexed per sample as ``y[i]``.
                t: time from which the inner sampler starts.

            Returns:
                Index tensor of shape (batch_size,).
        '''
        batch_size, num_particles = xt.shape[0], xt.shape[1]
        x0hat = self._estimate_x0(xt.flatten(0, 1), t, batch_size, num_particles)
        return torch.stack([
            self._select_particle(op.log_likelihood(x0hat[i], y[i]))
            for i in range(batch_size)
        ])

    @torch.no_grad()
    def sample_zeta_mc(self, xt, op, y, t):
        '''
            Monte Carlo version of ``sample_zeta``. It averages log-likelihoods of
            ``self.mc`` decoded inner-sampler rollouts per candidate. Previously
            this method was inherited and scored undecoded latents.
        '''
        batch_size, num_particles = xt.shape[0], xt.shape[1]
        xt_flat = xt.flatten(0, 1)
        totals = [0] * batch_size
        for _ in range(self.mc):
            x0hat = self._estimate_x0(xt_flat, t, batch_size, num_particles)
            for j in range(batch_size):
                totals[j] += op.log_likelihood(x0hat[j], y[j])
        return torch.stack([
            self._select_particle(totals[i] / self.mc)
            for i in range(batch_size)
        ])

    @torch.no_grad()
    def inference(self, observation=None, num_samples=1, verbose=True):
        '''
            Run SVDD in latent space, then decode the final latents with
            ``self.net.decode``.
        '''
        z = super().inference(observation, num_samples, verbose)
        return self.net.decode(z)