import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import math

from diffusion import diffusion_utils
import utils


class MaskedTransition:
    """Transition matrices whose limit distribution is one-hot on the last class.

    For both node (X) and edge (E) classes the marginal row puts all of its
    mass on the final class index (presumably the "mask"/absorbing state --
    confirm against callers). ``get_Qt_bar`` interpolates between the
    identity matrix and a matrix whose rows all equal that marginal.
    """

    def __init__(self, x_classes: int, e_classes: int, y_classes: int, cfg):
        """Store class counts and config, and build the marginal rows.

        Args:
            x_classes: number of node classes.
            e_classes: number of edge classes.
            y_classes: number of global-feature classes (stored only).
            cfg: config object; ``cfg.model.diffusion_noise_schedule`` is read.
        """
        super().__init__()
        self.X_classes = x_classes
        self.E_classes = e_classes
        self.y_classes = y_classes
        self.cfg = cfg
        self.noise_schedule = self.cfg.model.diffusion_noise_schedule

        # Marginal rows of shape (1, classes), one-hot on the last class.
        self.u_x = self._last_class_row(x_classes)
        self.u_e = self._last_class_row(e_classes)

    @staticmethod
    def _last_class_row(num_classes):
        """Return a (1, num_classes) row that is 1.0 at the last index, else 0."""
        row = torch.zeros(1, num_classes)
        row[:, -1] = 1.0
        return row

    def get_marginals(self, y=None, device=None):
        """Return the stored ``(u_x, u_e)`` marginal rows.

        ``y`` and ``device`` are accepted but not used here.
        """
        return self.u_x, self.u_e

    def get_Qt_bar(self, alpha_bar_t, device, y=None):
        """Build ``alpha_bar * I + (1 - alpha_bar) * U`` for nodes and edges.

        ``U`` is the square matrix whose every row equals the marginal row.

        Args:
            alpha_bar_t: either a single tensor shared by nodes and edges, or
                a ``(node, edge)`` tuple. In the tuple case ``node`` must be
                2-D ``(B, N)`` and ``edge`` must be viewable as ``(B, N, N)``.
                In the single-tensor case a trailing axis is inserted at
                dim 1 before broadcasting (presumably the input is
                ``(bs, 1)`` -- TODO confirm against callers).
            device: device on which the matrices are built.
            y: forwarded to ``get_marginals``.

        Returns:
            ``utils.PlaceHolder`` with ``X`` and ``E`` set to the transition
            matrices and ``y`` set to ``None``.
        """
        marg_x, marg_e = self.get_marginals(y, device)

        # Repeat each marginal row along a new axis: (1, d) -> (1, d, d).
        limit_x = marg_x.to(device).unsqueeze(1).expand(-1, self.X_classes, -1)
        limit_e = marg_e.to(device).unsqueeze(1).expand(-1, self.E_classes, -1)

        eye_x = torch.eye(self.X_classes, device=device)
        eye_e = torch.eye(self.E_classes, device=device)

        if not isinstance(alpha_bar_t, tuple):
            # Element-agnostic schedule: one coefficient shared by X and E.
            keep = alpha_bar_t.unsqueeze(1).to(device)
            q_x = keep * eye_x.unsqueeze(0) + (1 - keep) * limit_x
            q_e = keep * eye_e.unsqueeze(0) + (1 - keep) * limit_e
        else:
            # Element-wise schedule: one coefficient per node / per edge.
            keep_node, keep_edge = alpha_bar_t
            B, N = keep_node.shape
            keep_node = keep_node.view(B, N, 1, 1).to(device)
            keep_edge = keep_edge.view(B, N, N, 1, 1).to(device)

            q_x = (keep_node * eye_x[None, None, :, :]
                   + (1 - keep_node) * limit_x.unsqueeze(1))
            q_e = (keep_edge * eye_e[None, None, None, :, :]
                   + (1 - keep_edge) * limit_e[:, None, None, :, :])

        return utils.PlaceHolder(X=q_x, E=q_e, y=None)
    
    