import torch
import torch.nn as nn

import utils.graph_lib as Graphs

class GraphPreconditioner(nn.Module):
    """Wrap a network so its outputs can be returned raw or divided by a time rate.

    The wrapped ``net`` is called as ``net(x, cond + 1)``: condition labels are
    shifted by one so that label ``0`` is free to act as the "empty" token used
    for classifier-free guidance (CFG). ``forward`` returns the (optionally
    CFG-combined) network output when ``return_probs=True`` and otherwise divides
    it by ``e^sigma_int - 1``.
    """

    def __init__(self, net, graph: "Graphs.Absorbing", cfg_p=.1) -> None:
        """
        Args:
            net: module called as ``net(x, shifted_cond)``.
            graph: stored as ``self.graph``; not read by this class.
            cfg_p: stored as ``self.cfg_prob``; must lie in the closed interval [0, 1].
                Presumably the label-dropout rate for CFG training (not read here).
        """
        super().__init__()
        self.net = net
        self.graph = graph
        assert 0 <= cfg_p <= 1, 'CFG Probability has to be in [0,1]'
        self.cfg_prob = cfg_p

    @staticmethod
    def _time_rate(sigma_int):
        """Return ``e^sigma_int - 1`` reshaped to ``(-1, 1, 1)`` for broadcasting.

        Assumes the network output is 3-D with the batch on the first axis
        (TODO confirm: presumably [batch, length, vocab]).
        """
        # Both branches compute e^sigma - 1; expm1 is used for small sigma where
        # exp() - 1 would lose precision.
        return torch.where(sigma_int < 0.5, torch.expm1(sigma_int), sigma_int.exp() - 1).view(-1, 1, 1)

    def forward_with_cfg(self, x, sigma_int, cond, cfg_scale, force_condition_class=None):
        """Combine conditional and reference network outputs with CFG and renormalize.

        Computes ``cond**cfg_scale * ref**(1 - cfg_scale)`` elementwise, where
        ``ref`` is the output for the empty label (``force_condition_class is None``)
        or for ``force_condition_class``. Entries where ``ref == 0`` are set to 0.
        The result is normalized over the last axis, and rows whose total is not
        positive become all zeros.

        Returns:
            The normalized guided tensor (not divided by the time rate).
        """
        if force_condition_class is None:
            # -1 turns into the empty token 0 after forward()'s +1 shift.
            cfg_cond = -torch.ones_like(cond)
        else:
            cfg_cond = torch.ones_like(cond) * force_condition_class
        # Work on raw outputs: dividing both by the time rate first would cancel in
        # the normalization below, and it risks inf/NaN when expm1(sigma) is tiny.
        cond_score = self(x, sigma_int, cond, return_probs=True)
        uncond_score = self(x, sigma_int, cfg_cond, return_probs=True)
        score_w = torch.where(uncond_score != 0, cond_score ** cfg_scale * uncond_score ** (1 - cfg_scale), 0.)
        normalization = score_w.sum(dim=-1, keepdim=True)
        return torch.where(normalization > 0, score_w / normalization, 0.)

    def forward(self, x, sigma_int, cond, cfg_scale=1., return_probs=False, force_condition_class=None):
        """Run the wrapped network, optionally with classifier-free guidance.

        Args:
            x: input passed through to ``net``.
            sigma_int: per-sample tensor used for the ``e^sigma_int - 1`` time rate.
            cond: condition labels; shifted by +1 before reaching ``net``.
            cfg_scale: guidance scale; ``1.`` disables CFG (one network call).
            return_probs: if True, return the network (or CFG-combined) output
                without dividing by the time rate.
            force_condition_class: reference label for CFG instead of the empty
                token; ignored when ``cfg_scale == 1.``.
        """
        if cfg_scale == 1.:
            # +1 keeps 0 reserved as the CFG empty token.
            out = self.net(x, cond + 1)
        else:
            out = self.forward_with_cfg(x, sigma_int, cond, cfg_scale, force_condition_class)
        if return_probs:
            return out
        return out / self._time_rate(sigma_int)
        
def get_exact_model(distributions, mask_token):
    """Build a model that returns exact per-position conditionals from joint tables.

    Args:
        distributions: indexable by an integer class label (list, dict, tensor, ...);
            each entry is a nonnegative joint table of shape ``[M] * D``.
        mask_token: value in ``x`` that marks a masked (unobserved) position.

    Returns:
        ``exact_model(x, cond)``: maps ``x`` of shape ``[N, D]`` and class labels
        ``cond`` (one per row, or a single label for the whole batch) to an
        ``[N, D, M]`` tensor of conditionals.
    """

    def _marginal_along(t, axis):
        """Sum ``t`` over every axis except ``axis`` (``t`` is returned as-is if 1-D)."""
        if t.dim() == 1:
            # Handled separately so ``sum`` is never called with an empty ``dim`` list.
            return t
        return t.sum(dim=[a for a in range(t.dim()) if a != axis])

    @torch.no_grad()
    def conditionals_from_joint(p_joint: torch.Tensor, x: torch.LongTensor, mask_token: int, eps: float = 1e-12,
                                observed_one_hot: bool = True,
                                ) -> torch.Tensor:
        """
        p_cond[n, i, v] = P(X_i = v | X_j = x[n,j] for observed j).
        p_joint: [M]*D tensor (exactly D axes), nonnegative (normalized inside).
        x:       [N, D] long, observed in [0..M-1], masked == mask_token.
        Returns: [N, D, M] float tensor with conditionals.

        Observed positions get a one-hot row when ``observed_one_hot`` is True.
        Otherwise they get ``P(x with x_i := v, other observed) / P(all observed)``
        (masked positions marginalized). That is a ratio and is not renormalized.
        Rows whose observed pattern has zero mass get a uniform distribution.

        Raises:
            ValueError: on wrong shapes, out-of-range observed values, or a
                non-positive total mass.
        """
        if x.dim() != 2:
            raise ValueError(f"`x` must be [N, D], got {tuple(x.shape)}")
        N, D = x.shape
        M = p_joint.shape[-1]
        # The axis count must match D as well: extra axes would otherwise be summed
        # out silently, and missing axes would fail later with an IndexError.
        if p_joint.dim() != D or any(s != M for s in p_joint.shape):
            raise ValueError(f"`p_joint` must have shape {[M]*D}, got {tuple(p_joint.shape)}")
        if ((x != mask_token) & ((x < 0) | (x >= M))).any():
            raise ValueError("Observed entries in `x` must be in [0, M-1].")

        p = p_joint.to(dtype=torch.get_default_dtype())
        Z = p.sum()
        if Z <= 0:
            raise ValueError("p_joint must have positive total mass.")
        p = p / Z

        p_cond = torch.empty((N, D, M), dtype=p.dtype, device=p.device)
        uniform = torch.full((M,), 1.0 / M, dtype=p.dtype, device=p.device)

        for n, row in enumerate(x.tolist()):
            observed = [v != mask_token for v in row]
            # Pin observed axes to their values; masked axes stay free (slices).
            idx = [v if obs else slice(None) for v, obs in zip(row, observed)]
            denom = p[tuple(idx)].sum()
            denom_ok = bool(denom > 0)

            for i in range(D):
                if observed[i] and observed_one_hot:
                    p_cond[n, i] = 0.
                    p_cond[n, i, row[i]] = 1.
                    continue
                if not denom_ok:
                    p_cond[n, i] = uniform
                    continue
                # Free axis i and marginalize the other free axes: this yields
                # nums[v] for every v at once instead of M separate slice sums.
                idx_i = list(idx)
                idx_i[i] = slice(None)
                # Position of axis i among the axes that survive integer indexing.
                axis = sum(1 for s in idx_i[:i] if isinstance(s, slice))
                nums = _marginal_along(p[tuple(idx_i)], axis)
                p_cond[n, i] = nums / (denom + eps)

        return p_cond

    def exact_model(x, cond):
        """Return exact conditionals for ``x``, using ``distributions[label]`` per row."""
        labels = cond.reshape(-1)
        # A single label or a homogeneous batch: one joint table serves every row.
        if x.dim() != 2 or labels.numel() != x.shape[0] or bool((labels == labels[0]).all()):
            return conditionals_from_joint(distributions[labels[0].item()], x, mask_token)
        # Mixed labels: evaluate each class on its own rows, then scatter back.
        out = None
        for label in torch.unique(labels):
            rows = (labels == label).nonzero(as_tuple=True)[0]
            part = conditionals_from_joint(distributions[label.item()], x[rows.to(x.device)], mask_token)
            if out is None:
                out = part.new_empty((x.shape[0],) + tuple(part.shape[1:]))
            out[rows.to(out.device)] = part
        return out

    return exact_model


def get_preconditioned_model(net, graph, cfg_prob=.1):
    """Wrap ``net`` in a ``GraphPreconditioner``.

    ``graph`` is handed through unchanged, and ``cfg_prob`` becomes the
    preconditioner's ``cfg_p`` argument.
    """
    preconditioner = GraphPreconditioner(net=net, graph=graph, cfg_p=cfg_prob)
    return preconditioner