import torch
import torch.nn.functional as F

class ContinuousRateMatrixDesigner:
    """Build node and edge rate matrices from a limit distribution.

    The limit distribution supplies a prior ``p0`` over node classes
    (``limit_dist.X``) and edge classes (``limit_dist.E``). Rates are computed
    from that prior, the current graph ``G_t`` and a prediction ``G_1_pred``.
    """

    def __init__(self, limit_dist):
        """Store ``limit_dist`` and derive class counts and priors from it.

        Args:
            limit_dist: object exposing ``X`` and ``E`` -- presumably 1-D
                tensors of class probabilities for nodes and edges (TODO
                confirm against caller).
        """
        self.update_limit_dist(limit_dist)

    def update_limit_dist(self, limit_dist):
        """Replace the limit distribution and refresh everything derived from it.

        Class counts are recomputed together with the priors so they always
        agree. ``compute_graph_rate_matrix`` relies on them for reshaping the
        prior and for one-hot masking of the current class.
        """
        self.limit_dist = limit_dist
        self.num_classes_X = len(limit_dist.X)
        self.num_classes_E = len(limit_dist.E)
        self.p0_X = limit_dist.X
        self.p0_E = limit_dist.E

    @staticmethod
    def _rate_matrix(z_t_idx, pt_k, p0, t, num_classes):
        """Compute the rate matrix for one variable type (nodes or edges).

        For current class ``z`` and target class ``k != z``::

            R(z, k) = [pt(k) * (1 + p0(z) - p0(k))
                       + (1 - pt(z) - pt(k)) * relu(p0(z) - p0(k))]
                      / (S * (1 - t + 1e-6) * p0(z))

        and ``R(z, z) = 0``.

        Args:
            z_t_idx: integer tensor of current class indices, any shape ``D``.
            pt_k: tensor of shape ``D + (num_classes,)`` holding the per-class
                prediction for each element.
            p0: 1-D tensor of length ``num_classes`` (the limit distribution).
            t: time tensor, reshaped with ``view(-1, 1, ...)`` to broadcast
                against ``D + (1,)``.
            num_classes: number of classes ``S``.

        Returns:
            Tensor of shape ``D + (num_classes,)`` with the current-class entry
            set to zero.

        NOTE(review): ``p0[z]`` appears in the denominator without an epsilon,
        so a current class with zero prior mass yields inf/NaN rates.
        """
        ones = [1] * z_t_idx.dim()
        t_b = t.view(-1, *ones)
        p0 = p0.to(z_t_idx.device)

        p0_zt = p0[z_t_idx].unsqueeze(-1)              # prior of current class
        p0_k = p0.view(*ones, num_classes)             # prior of every class
        pt_zt = pt_k.gather(-1, z_t_idx.unsqueeze(-1))  # prediction at current class

        # The 1e-6 keeps the denominator finite at t == 1.
        denom = num_classes * (1 - t_b + 1e-6) * p0_zt
        term1 = pt_k * (1 + p0_zt - p0_k)
        term2 = (1 - pt_zt - pt_k) * F.relu(p0_zt - p0_k)
        rates = (term1 + term2) / denom

        # No rate towards the class the element is already in.
        self_mask = F.one_hot(z_t_idx, num_classes=num_classes).bool()
        return rates.masked_fill_(self_mask, 0.0)

    def compute_graph_rate_matrix(self, t, node_mask, G_t, G_1_pred):
        """Compute node and edge rate matrices for the current graph state.

        Args:
            t: time tensor; reshaped with ``view(-1, 1, ...)``, so it should
                hold one value per batch element (or a single value).
            node_mask: ``(bs, n)`` mask of valid nodes, or ``None`` to skip
                masking.
            G_t: tuple ``(X_t, E_t)`` of node ``(bs, n, S_X)`` and edge
                ``(bs, n, n, S_E)`` tensors; the current class of each element
                is the argmax over the last axis.
            G_1_pred: tuple ``(X_1_pred, E_1_pred)`` with the same shapes as
                ``G_t`` -- assumed to be predicted class probabilities for the
                clean graph (TODO confirm).

        Returns:
            ``(R_t_X, R_t_E)`` rate tensors shaped like ``X_1_pred`` and
            ``E_1_pred``. The entry for each element's current class is zero.
            When ``node_mask`` is given, masked-out nodes and edges touching a
            masked-out node are zeroed.
        """
        X_t, E_t = G_t
        X_1_pred, E_1_pred = G_1_pred

        R_t_X = self._rate_matrix(
            X_t.argmax(dim=-1), X_1_pred, self.p0_X, t, self.num_classes_X
        )
        R_t_E = self._rate_matrix(
            E_t.argmax(dim=-1), E_1_pred, self.p0_E, t, self.num_classes_E
        )

        if node_mask is not None:
            R_t_X = R_t_X * node_mask.unsqueeze(-1)
            # An edge is kept only if both of its endpoints are valid nodes.
            e_mask = node_mask.unsqueeze(1) * node_mask.unsqueeze(2)
            R_t_E = R_t_E * e_mask.unsqueeze(-1)

        return R_t_X, R_t_E
