import torch
import abc
from tqdm import tqdm
from functools import partial

from utils.graph_lib import Absorbing
from utils.guidance_schedules import GuidanceSchedule, ConstantSchedule

class Predictor(abc.ABC):
    """Abstract base class for predictor algorithms.

    Stores the ``graph`` and ``noise`` objects passed at construction so that
    concrete subclasses can use them inside ``update_fn``.
    """

    def __init__(self, graph, noise):
        super().__init__()
        self.graph, self.noise = graph, noise

    @abc.abstractmethod
    def update_fn(self, score_fn, x, t, step_size):
        """Perform one update of the predictor.

        Args:
            score_fn: The score function.
            x: A PyTorch tensor representing the current state.
            t: A PyTorch tensor representing the current time step.
            step_size: The size of the step to take.

        Returns:
            A PyTorch tensor of the next state.
        """


class EulerPredictor(Predictor):
    """Predictor whose update samples from a step-scaled reverse rate."""

    def update_fn(self, score_fn, x, t, step_size):
        """Perform one predictor update.

        Args:
            score_fn: Callable invoked as ``score_fn(x, sigma)``.
            x: A PyTorch tensor representing the current state.
            t: A PyTorch tensor representing the current time step.
            step_size: Multiplier applied to the reverse rate.

        Returns:
            The result of ``self.graph.sample_rate`` for the scaled rate.
        """
        # ``self.noise(t)`` yields two values, named sigma and dsigma here.
        sigma, dsigma = self.noise(t)
        current_score = score_fn(x, sigma)

        # A trailing axis is appended to dsigma so it broadcasts against the rate.
        scale = step_size * dsigma[..., None]
        scaled_rate = scale * self.graph.reverse_rate(x, current_score)
        return self.graph.sample_rate(x, scaled_rate)

def _guided_score(model, graph, x, t, cond, guidance, normalize_guid,
                  force_condition_class):
    """Evaluate ``model`` at time ``t`` and turn its log-probabilities into a score.

    Args:
        model: Callable invoked as ``model(x, sigma_int, cond, guidance,
            return_score=False, force_condition_class=...)``.
        graph: Object providing ``sigma_int(t)``.
        x: Current state tensor.
        t: Time tensor passed to ``graph.sigma_int``.
        cond: Conditioning input forwarded to the model unchanged.
        guidance: Guidance value forwarded to the model unchanged.
        normalize_guid: If true, subtract the log-sum-exp over the last axis
            from every entry except the final one along that axis.
        force_condition_class: Forwarded to the model unchanged.

    Returns:
        ``exp(log_prob - log(exp(sigma_int) - 1))``.
    """
    sigma_int = graph.sigma_int(t)
    log_prob = model(
        x,
        sigma_int,
        cond,
        guidance,
        return_score=False,
        force_condition_class=force_condition_class,
    )
    if normalize_guid:
        # The final index of the last axis is left untouched.
        # NOTE(review): presumably that index is the absorbing/mask token - confirm.
        log_prob[..., :-1] = log_prob[..., :-1] - torch.logsumexp(
            log_prob[..., :-1], dim=-1, keepdim=True
        )
    # Both branches evaluate exp(sigma) - 1; expm1 is used when sigma_int < 0.5.
    # The result is reshaped to (-1, 1, 1) to broadcast against log_prob
    # (assumes log_prob is 3-D with the batch axis first - TODO confirm).
    esigm1_log = (
        torch.where(sigma_int < 0.5, torch.expm1(sigma_int), sigma_int.exp() - 1)
        .log()
        .to(x.dtype)
        .view(-1, 1, 1)
    )
    return (log_prob - esigm1_log).exp()


@torch.no_grad()
def get_pc_sampler(model, shape, cond, steps, device,
        graph : Absorbing,
        guidance_schedule : GuidanceSchedule = ConstantSchedule(1.),
        use_tau_leaping=True,
        normalize_guid=False,
        return_traj=False,
        force_condition_class=None):
    """Draw samples by stepping from time 1 down to ``graph.delta``.

    Starts from ``graph.sample_limit(*shape)``, applies ``graph.update_fn``
    for ``steps`` iterations on an evenly spaced time grid, and finishes with
    one ``graph.denoise`` call at the last grid point.

    Args:
        model: Callable returning log-probabilities when called with
            ``return_score=False``.
        shape: Arguments unpacked into ``graph.sample_limit``.
        cond: Conditioning input forwarded to the model.
        steps: Number of update steps; the time grid has ``steps + 1`` points.
        device: Device on which the state and time tensors live.
        graph: Graph object providing ``sample_limit``, ``delta``,
            ``sigma_int``, ``update_fn`` and ``denoise``.
        guidance_schedule: Callable mapping the current time to the guidance
            value passed to the model.
        use_tau_leaping: Forwarded to ``graph.update_fn`` as ``tau``.
        normalize_guid: If true, renormalise the model output over all but the
            final entry of the last axis before computing the score.
        return_traj: If true, also return the list of visited states.
        force_condition_class: Forwarded to the model on every evaluation,
            including the final denoising one.

    Returns:
        The final state ``x``, or ``(x, traj)`` when ``return_traj`` is true.
        ``traj`` holds the initial state, the state after each update step,
        and the denoised state (``steps + 2`` entries).
    """
    x = graph.sample_limit(*shape).to(device)
    eps = graph.delta
    timesteps = torch.linspace(1, eps, steps + 1, device=device)
    dt = (1 - eps) / steps
    # Column of ones used to expand a scalar time to one entry per batch row.
    ones = torch.ones(x.shape[0], 1, device=device)

    traj = [x] if return_traj else None

    for i in tqdm(range(steps), leave=False):
        t = timesteps[i] * ones
        score = _guided_score(
            model, graph, x, t, cond, guidance_schedule(timesteps[i]),
            normalize_guid, force_condition_class,
        )
        x = graph.update_fn(score, x, t, dt, tau=use_tau_leaping)
        if return_traj:
            traj.append(x)

    # Final denoising step at the last grid point. force_condition_class is
    # forwarded here too, so this evaluation matches the ones in the loop.
    t = timesteps[-1] * ones
    score = _guided_score(
        model, graph, x, t, cond, guidance_schedule(timesteps[-1]),
        normalize_guid, force_condition_class,
    )
    x = graph.denoise(score, x, t)

    if return_traj:
        traj.append(x)
        return x, traj

    return x

def get_sampler(graph, device, guidance_schedule=ConstantSchedule(1.)):
    """Return ``get_pc_sampler`` with graph, device and schedule pre-bound.

    Args:
        graph: Bound as the ``graph`` keyword of ``get_pc_sampler``.
        device: Bound as the ``device`` keyword of ``get_pc_sampler``.
        guidance_schedule: Bound as the ``guidance_schedule`` keyword.

    Returns:
        A ``functools.partial`` wrapping ``get_pc_sampler``; the remaining
        arguments are supplied when the partial is called.
    """
    bound_kwargs = {
        "graph": graph,
        "device": device,
        "guidance_schedule": guidance_schedule,
    }
    return partial(get_pc_sampler, **bound_kwargs)