import abc
import torch
import torch.nn.functional as F
from catsample import sample_categorical

from model import utils as mutils

# Registry mapping predictor names to predictor classes.
_PREDICTORS = {}


def register_predictor(cls=None, *, name=None):
    """Register a predictor class under ``name`` (default: the class name).

    Works both bare (``@register_predictor``) and with arguments
    (``@register_predictor(name="euler")``). The class is returned unchanged.

    Raises:
        ValueError: if a predictor is already registered under that name.
    """

    def _register(cls):
        local_name = cls.__name__ if name is None else name
        if local_name in _PREDICTORS:
            raise ValueError(
                f'Already registered predictor with name: {local_name}')
        _PREDICTORS[local_name] = cls
        return cls

    if cls is None:
        return _register
    return _register(cls)


def get_predictor(name):
    """Return the predictor class registered under ``name``.

    Raises:
        KeyError: if no predictor with that name is registered; the message
            lists the registered names.
    """
    try:
        return _PREDICTORS[name]
    except KeyError:
        raise KeyError(
            f'Unknown predictor {name!r}; registered: {sorted(_PREDICTORS)}'
        ) from None



class Predictor(abc.ABC):
    """Abstract base for a single predictor step of the sampler.

    Concrete predictors keep the diffusion ``graph`` and the ``noise``
    schedule they were built with and implement :meth:`update_fn`.
    """

    def __init__(self, graph, noise):
        super().__init__()
        self.graph, self.noise = graph, noise

    @abc.abstractmethod
    def update_fn(self, score_fn, x, t, step_size):
        """Take one predictor step from state ``x`` at time ``t``.

        Args:
            score_fn: callable evaluating the model's score.
            x: PyTorch tensor holding the current state.
            t: PyTorch tensor holding the current time.
            step_size: size of the time step to take.

        Returns:
            A PyTorch tensor holding the next state.
        """


@register_predictor(name="euler")
class EulerPredictor(Predictor):
    """Euler predictor: samples from the reverse rate scaled by ``step_size * dsigma``."""

    def update_fn(self, score_fn, x, t, step_size):
        sigma, dsigma = self.noise(t)
        model_score = score_fn(x, sigma)

        # Broadcast dsigma over the trailing (category) axis of the rate.
        scale = step_size * dsigma[..., None]
        scaled_rate = scale * self.graph.reverse_rate(x, model_score)
        return self.graph.sample_rate(x, scaled_rate)

@register_predictor(name="none")
class NonePredictor(Predictor):
    """Identity predictor: returns the state unchanged."""

    def update_fn(self, score_fn, x, t, step_size):
        # Every argument other than ``x`` is intentionally ignored.
        return x


@register_predictor(name="analytic")
class AnalyticPredictor(Predictor):
    """Analytic predictor over the sigma decrement between ``t`` and ``t - step_size``.

    Weights the graph's staggered score by its transposed transition for
    ``dsigma = sigma(t) - sigma(t - step_size)`` and samples the result.
    """

    def update_fn(self, score_fn, x, t, step_size):
        sigma_now = self.noise(t)[0]
        sigma_delta = sigma_now - self.noise(t - step_size)[0]

        staggered = self.graph.staggered_score(score_fn(x, sigma_now), sigma_delta)
        transition = self.graph.transp_transition(x, sigma_delta)
        return sample_categorical(staggered * transition, type=torch.float32)

    
class Denoiser:
    """Final noise-removal step applied after the predictor loop."""

    def __init__(self, graph, noise):
        self.graph = graph
        self.noise = noise

    def update_fn(self, score_fn, x, t):
        """Sample a denoised state from ``x`` at time ``t``.

        Uses ``sigma = noise(t)[0]`` both to evaluate the score and as the
        transition amount.
        """
        sigma = self.noise(t)[0]
        staggered = self.graph.staggered_score(score_fn(x, sigma), sigma)
        probs = staggered * self.graph.transp_transition(x, sigma)

        if self.graph.absorb:
            # Drop the final category so it can never be sampled (presumably
            # the absorbing / mask state).
            probs = probs[..., :-1]

        return sample_categorical(probs)
                       

def get_sampling_fn(config, graph, noise, batch_dims, eps, device):
    """Build a sampler from the ``config.sampling`` section.

    Reads ``predictor``, ``steps`` and ``noise_removal`` from
    ``config.sampling`` and forwards them to :func:`get_pc_sampler`.
    """
    sampling_cfg = config.sampling
    return get_pc_sampler(
        graph=graph,
        noise=noise,
        batch_dims=batch_dims,
        predictor=sampling_cfg.predictor,
        steps=sampling_cfg.steps,
        denoise=sampling_cfg.noise_removal,
        eps=eps,
        device=device,
    )
    
# Running total (plain int) over all calls of the cached samplers built by
# get_pc_sampler: at each step, the number of sequences whose tokens changed.
total_update_cnt = 0


def get_pc_sampler(graph, noise, batch_dims, predictor, steps, denoise=True,
                   eps=1e-5, device=torch.device('cpu'), proj_fun=lambda x: x,
                   order=None):
    """Build a sampling closure for ``predictor``.

    Args:
        graph: diffusion graph; must provide ``sample_limit`` and ``dim``
            (plus whatever the chosen predictor / denoiser call on it).
        noise: noise schedule; ``noise(t)`` returns ``(sigma, dsigma)``.
        batch_dims: shape passed to ``graph.sample_limit``. The cached and
            autoregressive samplers assume ``(batch, seq_len)``.
        predictor: ``'ar'``, ``'ar_forward'``, ``'ar_backward'``,
            ``'ar_random'``, ``'cached_analytic'``, ``'cached_euler'``, or
            the name of a registered predictor.
        steps: number of sampling steps. The autoregressive samplers fill
            one position per step, so ``steps`` must equal ``seq_len``.
        denoise: run a final denoising step (not used by the AR samplers).
        eps: final time of the reverse process (time runs from 1 to eps).
        device: device on which samples are created.
        proj_fun: projection applied to the state (not used by AR samplers).
        order: position order for ``predictor='ar'``; defaults to
            ``0 .. seq_len - 1``. Ignored by the other predictors.

    Returns:
        A function ``sampler(model)`` returning a batch of samples.

    Raises:
        ValueError: if an autoregressive predictor is requested with
            ``steps != seq_len``.
    """

    def _make_ar_sampler(get_order, prob_dtype):
        """Return a sampler that fills one position per step in ``get_order()`` order."""

        @torch.no_grad()
        def ar_sampler(model):
            model.eval()
            x = graph.sample_limit(*batch_dims).to(device)
            # Evaluated per call (after drawing the initial state) so that
            # 'ar_random' gets a fresh permutation every time.
            positions = get_order()
            for i in range(steps):
                p_condition = model.get_log_condition(x).exp()
                pos = positions[i]
                # The last category is excluded so it is never sampled.
                x[:, pos] = sample_categorical(p_condition[:, pos, :-1].to(prob_dtype))
            return x

        return ar_sampler

    if predictor in ('ar', 'ar_forward', 'ar_backward', 'ar_random'):
        seq_len = batch_dims[-1]
        if steps != seq_len:
            raise ValueError(
                f"predictor {predictor!r} fills one position per step, so "
                f"steps ({steps}) must equal the sequence length ({seq_len})")
        if predictor == 'ar_random':
            # NOTE(review): float16 here vs bfloat16 for the fixed orders;
            # kept as in the original implementation.
            return _make_ar_sampler(lambda: torch.randperm(seq_len), torch.float16)
        if predictor == 'ar_backward':
            fixed_order = torch.arange(seq_len - 1, -1, -1)
        elif predictor == 'ar_forward' or order is None:
            fixed_order = torch.arange(0, seq_len)
        else:
            fixed_order = order
        return _make_ar_sampler(lambda: fixed_order, torch.bfloat16)

    projector = proj_fun
    denoiser = Denoiser(graph, noise)

    def _new_condition_cache():
        """Allocate the (batch, seq_len, graph.dim) cache of conditionals on ``device``."""
        shape = (batch_dims[0], batch_dims[1], graph.dim)
        try:
            return torch.zeros(shape, dtype=torch.bfloat16, device=device)
        except (RuntimeError, TypeError):
            # Fall back to float16 where bfloat16 is unavailable.
            return torch.zeros(shape, dtype=torch.float16, device=device)

    def _analytic_step_probs(t, step_size):
        """Exact (move, stay) probabilities for a masked position over [t - step_size, t]."""
        lambda_curr = (-noise(t)[0]).exp()
        lambda_next = (-noise(t - step_size)[0]).exp()
        move = (lambda_next - lambda_curr) / (1 - lambda_curr)
        stay = (1 - lambda_next) / (1 - lambda_curr)
        return move, stay

    def _euler_step_probs(t, step_size):
        """First-order (Euler) (move, stay) probabilities for a masked position."""
        sigma, dsigma = noise(t)
        lambda_curr = (-sigma).exp()
        # The reverse rate carries the dsigma factor, matching the
        # ``step_size * dsigma`` scaling used by EulerPredictor.
        move = step_size * dsigma * lambda_curr / (1 - lambda_curr)
        return move, 1 - move

    def _make_cached_sampler(step_probs, sample):
        """Return a masked-diffusion sampler that caches model conditionals.

        The model is re-evaluated only for sequences whose tokens changed in
        the previous step; other rows reuse their cached conditionals.
        ``step_probs(t, step_size)`` gives the (move, stay) probabilities and
        ``sample(probs)`` draws the new tokens for the masked positions.
        """

        @torch.no_grad()
        def cached_sampler(model):
            global total_update_cnt
            model.eval()
            x = projector(graph.sample_limit(*batch_dims).to(device))
            timesteps = torch.linspace(1, eps, steps + 1, device=device)
            step_size = (1 - eps) / steps
            p_condition = _new_condition_cache()
            # Every row needs a model evaluation on the first step.
            changed = torch.ones(batch_dims[0], dtype=torch.bool, device=device)
            update_cnt = 0

            for i in range(steps):
                move, stay = step_probs(timesteps[i], step_size)
                if changed.any():
                    # Otherwise x is unchanged, so the mask and conditionals
                    # from the last refresh are still valid.
                    mask = x == graph.dim - 1
                    p_condition[changed] = model.get_log_condition(x[changed]).exp()
                    p_condition_mask = p_condition[mask]
                probs_mask = p_condition_mask * move
                # Index graph.dim - 1 is the mask token: "stay masked".
                probs_mask[..., -1] = stay
                x_old = x.clone()
                x[mask] = sample(probs_mask)
                changed = (x != x_old).any(dim=-1)
                update_cnt += int(changed.sum())

            total_update_cnt += update_cnt

            if denoise:
                sampling_score_fn = mutils.get_score_fn(model, train=False, sampling=True)
                x = projector(x)
                t = timesteps[-1] * torch.ones(x.shape[0], 1, device=device)
                x = denoiser.update_fn(sampling_score_fn, x, t)
            return x

        return cached_sampler

    if predictor == 'cached_analytic':
        return _make_cached_sampler(
            _analytic_step_probs,
            lambda probs: sample_categorical(probs, type=torch.float64))
    if predictor == 'cached_euler':
        return _make_cached_sampler(
            _euler_step_probs,
            lambda probs: sample_categorical(probs.to(torch.float32)))

    step_predictor = get_predictor(predictor)(graph, noise)

    @torch.no_grad()
    def pc_sampler(model):
        sampling_score_fn = mutils.get_score_fn(model, train=False, sampling=True)
        x = graph.sample_limit(*batch_dims).to(device)
        timesteps = torch.linspace(1, eps, steps + 1, device=device)
        dt = (1 - eps) / steps

        for i in range(steps):
            t = timesteps[i] * torch.ones(x.shape[0], 1, device=device)
            x = projector(x)
            x = step_predictor.update_fn(sampling_score_fn, x, t, dt)

        if denoise:
            # denoising step
            x = projector(x)
            t = timesteps[-1] * torch.ones(x.shape[0], 1, device=device)
            x = denoiser.update_fn(sampling_score_fn, x, t)

        return x

    return pc_sampler

