import random
from collections import defaultdict
import itertools
import dgl.sparse as dglsp
import math
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..logic_constraint import solve_edge_constraint, solve_dag_constraint
from einops import rearrange
import json

def check_predecessor_balance(x_n, z_t, rho):
    """Find destination nodes whose predecessors are too unbalanced in label.

    For every destination ``v`` let ``n0`` / ``n1`` be the number of its
    predecessors whose label is 0 / 1 and ``n = n0 + n1``. A node with
    ``n > 0`` is flagged when ``floor(|n0 - n1| / 2) / (n / 2) > rho``.

    Parameters
    ----------
    x_n : indexable
        Node labels; node id ``u`` is looked up as ``x_n[u - 1][0]``.
    z_t : torch.Tensor of shape (2, m)
        Edge index with ``z_t[0]`` holding destination ids and ``z_t[1]``
        holding source ids.
    rho : float
        Imbalance threshold.

    Returns
    -------
    (list, list) or (None, None)
        The flagged destination ids and the de-duplicated predecessor ids of
        all flagged nodes, or ``(None, None)`` when nothing is flagged.
    """
    # Row 0 holds destinations, row 1 holds sources.
    src_ids = z_t[1].cpu().tolist()
    dst_ids = z_t[0].cpu().tolist()

    predecessors = defaultdict(list)
    for src, dst in zip(src_ids, dst_ids):
        predecessors[dst].append(src)

    unsatisfied = []
    involved_srcs = set()
    for v, preds in predecessors.items():
        n_v_0 = sum(1 for u in preds if x_n[u - 1][0] == 0)
        n_v_1 = sum(1 for u in preds if x_n[u - 1][0] == 1)
        n_v = n_v_0 + n_v_1
        if n_v == 0:
            # No predecessor carries label 0 or 1: nothing to balance.
            continue
        imbalance_ratio = np.floor(abs(n_v_0 - n_v_1) / 2) / (n_v / 2)
        if imbalance_ratio > rho:
            unsatisfied.append(v)
            involved_srcs.update(preds)

    if not unsatisfied:
        return None, None
    return unsatisfied, list(involved_srcs)
    
def check_predecessor_balance_src_dst(x_n, src_list, dst_list, rho):
    """Flag destinations whose predecessors' labels are too unbalanced.

    Edges are given as parallel ``src_list`` / ``dst_list`` sequences and a
    node id ``u`` is looked up as ``x_n[u - 1]``. With ``n0`` / ``n1`` the
    counts of predecessors labelled 0 / 1 and ``n = n0 + n1 > 0``, a
    destination is flagged when ``floor(|n0 - n1| / 2) / (n / 2) > rho``.

    Returns the flagged destinations and the de-duplicated predecessors of
    those destinations, or ``(None, None)`` if no destination is flagged.
    """
    preds_by_dst = defaultdict(list)
    for s, d in zip(src_list, dst_list):
        preds_by_dst[d].append(s)

    flagged = []
    involved = set()
    for node, preds in preds_by_dst.items():
        labels = [x_n[p - 1] for p in preds]
        zeros = sum(1 for lab in labels if lab == 0)
        ones = sum(1 for lab in labels if lab == 1)
        total = zeros + ones
        if total <= 0:
            continue
        ratio = np.floor(abs(zeros - ones) / 2) / (total / 2)
        if ratio > rho:
            flagged.append(node)
            involved.update(preds)

    return (flagged, list(involved)) if flagged else (None, None)

class SinusoidalPE(nn.Module):
    """Sinusoidal embedding of scalar positions.

    The output always has exactly ``pe_size`` columns: the sine features
    followed by the cosine features, truncated so that an odd ``pe_size``
    does not produce ``pe_size + 1`` columns. For ``pe_size == 0`` an
    ``(len(position), 0)`` tensor is returned.
    """

    def __init__(self, pe_size):
        super().__init__()

        self.pe_size = pe_size
        if pe_size > 0:
            div_term = torch.exp(torch.arange(0, pe_size, 2) *
                                 (-math.log(10000.0) / pe_size))
            # Frozen Parameter: follows the module across .to(device) and is
            # part of state_dict, but is never trained.
            self.div_term = nn.Parameter(div_term, requires_grad=False)

    def forward(self, position):
        """Embed ``position`` (expected to broadcast against ``div_term``,
        e.g. shape (N, 1)) into ``pe_size`` features."""
        if self.pe_size == 0:
            return torch.zeros(len(position), 0).to(position.device)

        angles = position * self.div_term
        # arange(0, pe_size, 2) has ceil(pe_size / 2) entries, so sin+cos has
        # pe_size + 1 columns when pe_size is odd; trim to the declared size.
        return torch.cat([torch.sin(angles), torch.cos(angles)],
                         dim=-1)[..., :self.pe_size]

class BiMPNNLayer(nn.Module):
    """One bidirectional message-passing layer.

    Messages are aggregated along ``A`` and along its transpose ``A_T`` with
    separate linear maps, added to a self projection, and passed through
    GELU. When ``A`` has no stored entries only the self projection is used.
    """

    def __init__(self, in_size, out_size):
        super().__init__()

        self.W = nn.Linear(in_size, out_size)
        self.W_trans = nn.Linear(in_size, out_size)
        self.W_self = nn.Linear(in_size, out_size)

    def forward(self, A, A_T, h_n):
        self_term = self.W_self(h_n)
        if A.nnz != 0:
            forward_msg = A @ self.W(h_n)
            backward_msg = A_T @ self.W_trans(h_n)
            return F.gelu(forward_msg + backward_msg + self_term)
        return F.gelu(self_term)

class OneHotPE(nn.Module):
    """One-hot encoding of integer positions, clamped above at ``pe_size - 1``.

    Returns a float tensor of shape (N, pe_size), or (N, 0) when
    ``pe_size == 0``.
    """

    def __init__(self, pe_size):
        super().__init__()

        self.pe_size = pe_size

    def forward(self, position):
        """Encode ``position`` given as shape (N, 1) or (N,)."""
        if self.pe_size == 0:
            return torch.zeros(len(position), 0).to(position.device)

        idx = position.clamp(max=self.pe_size - 1).long()
        # Drop the trailing singleton axis of (N, 1) input only; an
        # unconditional squeeze(-1) would turn a (1,) input into a scalar
        # and yield a (pe_size,) result instead of (1, pe_size).
        if idx.dim() > 1:
            idx = idx.squeeze(-1)
        # Float output, matching the dtype of the pe_size == 0 branch.
        return F.one_hot(idx, num_classes=self.pe_size).float()

class MultiEmbedding(nn.Module):
    """Embed each categorical feature separately and concatenate the results.

    ``num_x_n_cat`` holds the category count of every feature. A 1-D input
    is treated as a single feature and embedded with the first table; a 2-D
    input (N, num_feats) yields (N, num_feats * hidden_size).
    """

    def __init__(self, num_x_n_cat, hidden_size):
        super().__init__()

        self.emb_list = nn.ModuleList(
            nn.Embedding(n_cat, hidden_size) for n_cat in num_x_n_cat.tolist()
        )

    def forward(self, x_n_cat):
        if x_n_cat.dim() == 1:
            return self.emb_list[0](x_n_cat)
        per_feature = [table(x_n_cat[:, col])
                       for col, table in enumerate(self.emb_list)]
        return torch.cat(per_feature, dim=1)

class BiMPNNEncoder(nn.Module):
    """Bidirectional MPNN encoder producing node or pooled graph embeddings.

    Input node features are the concatenation of categorical embeddings, an
    optional level encoding (``pe``) and an optional label encoding (``y``).
    After an input MLP, ``num_mpnn_layers`` BiMPNN layers are applied; the
    input and every layer output are concatenated and projected back to
    ``hidden_size``. With ``pool`` set, node embeddings are summed or
    averaged per graph via ``A_n2g`` and batch-normalized.
    """

    # Accepted configuration values. Anything else is rejected in __init__
    # instead of failing later in forward.
    _PE_TYPES = (None, 'relative_level', 'abs_level', 'relative_level_one_hot')
    _POOL_TYPES = (None, 'sum', 'mean')

    def __init__(self,
                 num_x_n_cat,
                 x_n_emb_size,
                 pe_emb_size,
                 hidden_size,
                 num_mpnn_layers,
                 pe=None,
                 y_emb_size=0,
                 pool=None):
        super().__init__()

        if pe not in self._PE_TYPES:
            raise ValueError(
                f"Unsupported pe {pe!r}; expected one of {self._PE_TYPES}")
        if pool not in self._POOL_TYPES:
            raise ValueError(
                f"Unsupported pool {pool!r}; expected one of {self._POOL_TYPES}")

        self.pe = pe
        if self.pe in ['relative_level', 'abs_level']:
            self.level_emb = SinusoidalPE(pe_emb_size)
        elif self.pe == 'relative_level_one_hot':
            self.level_emb = OneHotPE(pe_emb_size)

        self.x_n_emb = MultiEmbedding(num_x_n_cat, x_n_emb_size)
        self.y_emb = SinusoidalPE(y_emb_size)

        self.proj_input = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size)
        )

        self.mpnn_layers = nn.ModuleList()
        for _ in range(num_mpnn_layers):
            self.mpnn_layers.append(BiMPNNLayer(hidden_size, hidden_size))

        self.project_output_n = nn.Sequential(
            nn.Linear((num_mpnn_layers + 1) * hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size)
        )

        self.pool = pool
        if pool is not None:
            self.bn_g = nn.BatchNorm1d(hidden_size)

    def forward(self, A, x_n, abs_level=None, rel_level=None, y=None, A_n2g=None):
        """Encode a (batched) graph.

        Returns node embeddings when ``pool`` is None, otherwise one
        batch-normalized embedding per row of ``A_n2g``.
        """
        A_T = A.T
        h_n = self.x_n_emb(x_n)

        if self.pe == 'abs_level':
            h_n = torch.cat([h_n, self.level_emb(abs_level)], dim=-1)
        elif self.pe is not None:
            # 'relative_level' or 'relative_level_one_hot'
            h_n = torch.cat([h_n, self.level_emb(rel_level)], dim=-1)

        if y is not None:
            h_n = torch.cat([h_n, self.y_emb(y)], dim=-1)

        h_n = self.proj_input(h_n)
        # Keep the input and every layer's output; project their concatenation.
        h_n_cat = [h_n]
        for layer in self.mpnn_layers:
            h_n = layer(A, A_T, h_n)
            h_n_cat.append(h_n)
        h_n = self.project_output_n(torch.cat(h_n_cat, dim=-1))

        if self.pool is None:
            return h_n

        h_g = A_n2g @ h_n
        if self.pool == 'mean':
            num_nodes = A_n2g.sum(dim=1)
            # Graphs with no nodes (row sum 0, assuming non-negative A_n2g
            # entries) would give 0/0 = NaN, which BatchNorm would spread to
            # the whole batch; divide their (zero) sum by 1 instead.
            num_nodes = num_nodes.masked_fill(num_nodes == 0, 1)
            h_g = h_g / num_nodes.unsqueeze(-1)
        return self.bn_g(h_g)

class GraphClassifier(nn.Module):
    """Graph-level classifier: a graph encoder followed by a two-layer MLP head.

    The encoder is expected to return one embedding of size ``emb_size`` per
    graph (e.g. a pooled BiMPNNEncoder); the head maps it to
    ``num_classes`` unnormalized logits.
    """

    def __init__(self,
                 graph_encoder,
                 emb_size,
                 num_classes):
        super().__init__()

        self.graph_encoder = graph_encoder
        head_layers = [
            nn.Linear(emb_size, emb_size),
            nn.GELU(),
            nn.Linear(emb_size, num_classes),
        ]
        self.predictor = nn.Sequential(*head_layers)

    def forward(self, A, x_n, abs_level, rel_level, A_n2g, y=None):
        return self.predictor(
            self.graph_encoder(A, x_n, abs_level, rel_level, y, A_n2g))

class TransformerLayer(nn.Module):
    """Post-norm Transformer layer whose attention stays within each graph.

    Nodes of a batch are stored contiguously per graph; ``num_query_cumsum``
    delimits the graphs, and a node only attends to nodes of its own graph.
    """

    def __init__(self,
                 hidden_size,
                 num_heads,
                 dropout):
        super().__init__()

        self.to_v = nn.Linear(hidden_size, hidden_size)
        self.to_qk = nn.Linear(hidden_size, hidden_size * 2)

        self._reset_parameters()

        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        assert head_dim * num_heads == hidden_size, "hidden_size must be divisible by num_heads"
        self.scale = head_dim ** -0.5

        self.proj_new = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Dropout(dropout)
        )
        self.norm1 = nn.LayerNorm(hidden_size)

        self.out_proj = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.GELU(),
            nn.Linear(4 * hidden_size, hidden_size),
            nn.Dropout(dropout)
        )
        self.norm2 = nn.LayerNorm(hidden_size)

    def _reset_parameters(self):
        nn.init.xavier_uniform_(self.to_v.weight)
        nn.init.xavier_uniform_(self.to_qk.weight)

    def attn(self, q, k, v, num_query_cumsum):
        """
        Parameters
        ----------
        q : torch.Tensor of shape (N, F)
            Query matrix for node representations.
        k : torch.Tensor of shape (N, F)
            Key matrix for node representations.
        v : torch.Tensor of shape (N, F)
            Value matrix for node representations.
        num_query_cumsum : torch.Tensor of shape (B + 1)
            num_query_cumsum[0] is 0, num_query_cumsum[i] is the number of queries
            for the first i graphs in the batch for i > 0.

        Returns
        -------
        torch.Tensor of shape (N, F)
            Updated hidden representations of query nodes for the batch of graphs.
        """
        num_query_cumsum = num_query_cumsum.to(q.device)
        num_query_nodes = torch.diff(num_query_cumsum)
        batch_size = num_query_nodes.shape[0]
        if batch_size == 0:
            # No graphs: nothing to attend over (max() below needs >= 1 graph).
            return v.new_zeros(0, v.shape[-1])
        max_num_nodes = int(num_query_nodes.max().item())
        start = int(num_query_cumsum[0].item())
        end = int(num_query_cumsum[-1].item())

        # valid[i, j] is True iff slot j of graph i holds a real node. The
        # row-major order of True entries equals the row order of q[start:end],
        # so one masked assignment pads every graph at once.
        slots = torch.arange(max_num_nodes, device=q.device)
        valid = slots.unsqueeze(0) < num_query_nodes.unsqueeze(1)
        pad_mask = ~valid

        def pad_and_split_heads(x):
            # (N, H * D) -> (B, H, max_num_nodes, D)
            x_padded = x.new_zeros(batch_size, max_num_nodes, x.shape[-1])
            x_padded[valid] = x[start:end]
            head_dim = x.shape[-1] // self.num_heads
            return x_padded.view(batch_size, max_num_nodes,
                                 self.num_heads, head_dim).transpose(1, 2)

        q_padded = pad_and_split_heads(q)
        k_padded = pad_and_split_heads(k)
        v_padded = pad_and_split_heads(v)

        # Q * K^T / sqrt(D): (B, H, max_num_nodes, max_num_nodes)
        dot = torch.matmul(q_padded, k_padded.transpose(-1, -2)) * self.scale
        # Mask unnormalized attention logits for non-existent (padded) keys.
        dot = dot.masked_fill(pad_mask[:, None, None, :], float('-inf'))

        attn_scores = F.softmax(dot, dim=-1)
        # (B, H, max_num_nodes, D)
        h_n_padded = torch.matmul(attn_scores, v_padded)
        # (B * max_num_nodes, H * D) = (B * max_num_nodes, F)
        h_n_padded = h_n_padded.transpose(1, 2).reshape(
            batch_size * max_num_nodes, v.shape[-1])

        # Unpad the aggregation results: (N, F). Rows of empty graphs (which
        # are all -inf before softmax) are padding and are dropped here.
        return h_n_padded[valid.reshape(-1)]

    def forward(self, h_n, num_query_cumsum):
        # Compute value matrix
        v_n = self.to_v(h_n)

        # Compute query and key matrices
        q_n, k_n = self.to_qk(h_n).chunk(2, dim=-1)

        h_n_new = self.attn(q_n, k_n, v_n, num_query_cumsum)
        h_n_new = self.proj_new(h_n_new)

        # Add & Norm
        h_n = self.norm1(h_n + h_n_new)
        h_n = self.norm2(h_n + self.out_proj(h_n))

        return h_n

class NodePredModel(nn.Module):
    """Denoising network that predicts clean node attributes for a new layer.

    Each query (new) node is represented by the embeddings of its current
    noisy attributes, the embedding of its graph, and a timestep embedding.
    Transformer layers mix queries within each graph, and one MLP head per
    attribute produces logits over that attribute's categories.
    """

    def __init__(self,
                 graph_encoder,
                 num_x_n_cat,
                 x_n_emb_size,
                 t_emb_size,
                 in_hidden_size,
                 out_hidden_size,
                 num_transformer_layers,
                 num_heads,
                 dropout):
        super().__init__()

        self.graph_encoder = graph_encoder
        # Category counts without the last category of every feature.
        real_cats = num_x_n_cat - 1
        self.x_n_emb = MultiEmbedding(real_cats, x_n_emb_size)
        self.t_emb = SinusoidalPE(t_emb_size)

        # Width of [attribute embeddings | graph embedding | timestep embedding].
        proj_in_size = in_hidden_size + t_emb_size + len(real_cats) * x_n_emb_size
        self.project_h_n = nn.Sequential(
            nn.Linear(proj_in_size, out_hidden_size),
            nn.GELU()
        )

        self.trans_layers = nn.ModuleList(
            TransformerLayer(out_hidden_size, num_heads, dropout)
            for _ in range(num_transformer_layers)
        )

        self.pred_list = nn.ModuleList(
            nn.Sequential(
                nn.Linear(out_hidden_size, out_hidden_size),
                nn.GELU(),
                nn.Linear(out_hidden_size, n_classes)
            )
            for n_classes in real_cats.tolist()
        )

    def forward_with_h_g(self, h_g, x_n_t,
                         t, query2g, num_query_cumsum):
        """Predict per-attribute logits given precomputed graph embeddings."""
        graph_ctx = torch.cat([h_g, self.t_emb(t)], dim=1)

        queries = torch.cat([self.x_n_emb(x_n_t), graph_ctx[query2g]], dim=1)
        queries = self.project_h_n(queries)

        for block in self.trans_layers:
            queries = block(queries, num_query_cumsum)

        return [head(queries) for head in self.pred_list]

    def forward(self, A, x_n, abs_level, rel_level, A_n2g, x_n_t,
                t, query2g, num_query_cumsum, y=None):
        """
        Parameters
        ----------
        x_n_t : torch.LongTensor of shape (Q) or (Q, num_feats)
        t : torch.LongTensor of shape (B, 1)
        query2g : torch.LongTensor of shape (Q)
        num_query_cumsum : torch.LongTensor of shape (B + 1)
        """
        h_g = self.graph_encoder(A, x_n, abs_level,
                                 rel_level, y=y, A_n2g=A_n2g)
        return self.forward_with_h_g(h_g, x_n_t, t, query2g, num_query_cumsum)

class EdgePredModel(nn.Module):
    """Edge denoiser: scores each candidate (src, dst) pair with 2 logits.

    The pair feature is [timestep embedding | h[src] | h[dst]], where h are
    node embeddings from ``graph_encoder``.
    """

    def __init__(self,
                 graph_encoder,
                 t_emb_size,
                 in_hidden_size,
                 out_hidden_size):
        super().__init__()

        self.graph_encoder = graph_encoder
        self.t_emb = SinusoidalPE(t_emb_size)
        pair_size = 2 * in_hidden_size + t_emb_size
        self.pred = nn.Sequential(
            nn.Linear(pair_size, out_hidden_size),
            nn.GELU(),
            nn.Linear(out_hidden_size, 2),
        )

    def forward(self, A, x_n, abs_level, rel_level, t,
                query_src, query_dst, y=None):
        """
        t : torch.tensor of shape (num_queries, 1)
        """
        node_repr = self.graph_encoder(A, x_n, abs_level, rel_level, y=y)
        pair_feats = torch.cat(
            (self.t_emb(t), node_repr[query_src], node_repr[query_dst]),
            dim=-1,
        )
        return self.pred(pair_feats)

class EdgeRefineModel(nn.Module):
    """Per-edge binary classifier used to refine candidate edges.

    For each candidate pair it concatenates the encoder embeddings of the
    source and destination with a learned embedding of the source node's
    category ``x_n[src]`` and maps the result to 2 logits.
    """

    def __init__(self, graph_encoder, t_emb_size, in_hidden_size, out_hidden_size, x_n_dim=8,
                 num_x_n_classes=3):
        """
        Parameters
        ----------
        graph_encoder : nn.Module
            Node encoder returning (N, in_hidden_size) embeddings.
        t_emb_size : int
            Accepted but not used by this model.
        in_hidden_size, out_hidden_size : int
            Encoder output width and MLP hidden width.
        x_n_dim : int
            Width of the source-category embedding.
        num_x_n_classes : int
            Number of distinct values ``x_n`` may take (default 3, the value
            previously hard-coded).
        """
        super().__init__()

        self.graph_encoder = graph_encoder
        self.x_n_encoder = nn.Embedding(num_x_n_classes, x_n_dim)

        self.pred = nn.Sequential(
            nn.Linear(2 * in_hidden_size + x_n_dim, out_hidden_size),
            nn.LayerNorm(out_hidden_size),
            nn.GELU(),
            nn.Linear(out_hidden_size, out_hidden_size),
            nn.GELU(),
            nn.Linear(out_hidden_size, 2)
        )

    def forward(self, A, x_n, t, query_src, query_dst, number_list=None, edge_list=None, abs_level=None, rel_level=None, y=None):
        """Return (num_queries, 2) logits for the pairs (query_src, query_dst).

        ``t``, ``number_list`` and ``edge_list`` are accepted but not used.
        """
        h_n = self.graph_encoder(A, x_n, abs_level, rel_level, y=y)  # [N, in_hidden_size]
        x_feat = self.x_n_encoder(x_n)  # [N, x_n_dim]
        h_e = torch.cat([
            h_n[query_src],
            h_n[query_dst],
            x_feat[query_src]
        ], dim=-1)

        return self.pred(h_e)


class EdgeRefineModelWithEdgeTransformer(nn.Module):
    """TransformerDecoder over token sequences.

    ``tgt_ids`` attends causally to itself and fully to ``src_ids``. Inputs
    are sequence-first, (length, batch). Token id 2 is the embedding
    padding index, and positions come from a learned table of size
    ``max_length``.
    """

    def __init__(self, vocab_size=5, d_model=64, nhead=4, num_layers=8, dim_feedforward=256, max_length=128):
        super().__init__()
        # Number of positions the learned positional table can index.
        self.max_length = max_length
        self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=2)
        self.pos_emb = nn.Embedding(max_length, d_model)
        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers)
        self.out = nn.Linear(d_model, vocab_size)

    def _embed(self, ids):
        # Token embedding plus position embedding, broadcast over the batch.
        positions = torch.arange(ids.shape[0], device=ids.device)[:, None]
        return self.token_emb(ids) + self.pos_emb(positions)

    def forward(self, src_ids, tgt_ids):
        """
        src_ids: (S,  B)   S=2L+2,  B=batch
        tgt_ids: (T,  B)   T=L+1

        Returns logits of shape (T, B, vocab_size).

        Raises ValueError if S or T exceeds ``max_length`` (instead of an
        opaque index error from the positional embedding lookup).
        """
        S, _ = src_ids.shape
        T, _ = tgt_ids.shape
        if S > self.max_length or T > self.max_length:
            raise ValueError(
                f"sequence length {max(S, T)} exceeds max_length={self.max_length}")

        src = self._embed(src_ids)
        tgt = self._embed(tgt_ids)
        # Causal mask (-inf strictly above the diagonal, 0 elsewhere; the same
        # values as nn.Transformer.generate_square_subsequent_mask), built
        # directly on the target device so it does not depend on that helper
        # being a staticmethod in the installed PyTorch version.
        tgt_mask = torch.triu(
            torch.full((T, T), float('-inf'), device=src_ids.device), diagonal=1)
        # No src mask (full attention)
        out = self.decoder(tgt, src, tgt_mask=tgt_mask)    # (T, B, D)
        return self.out(out)                               # (T, B, V)

class LayerDAG(nn.Module):
    def __init__(self,
                 device,
                 num_x_n_cat,
                 node_count_encoder_config,
                 max_layer_size,
                 node_diffusion,
                 node_pred_graph_encoder_config,
                 node_predictor_config,
                 edge_diffusion,
                 edge_pred_graph_encoder_config,
                 edge_predictor_config,
                 rho,
                 max_level=None,
                 load_refine_model=False,
                 refine_with_transformer=False,):
        """Build the layer-size, node-attribute and edge sub-models.

        Parameters
        ----------
        device :
            Device the sub-models are moved to.
        num_x_n_cat :
            Case1: int
            Case2: torch.LongTensor of shape (num_feats)
            (a 0-dim tensor is promoted to shape (1,)).
        node_count_encoder_config, node_pred_graph_encoder_config,
        edge_pred_graph_encoder_config : dict
            BiMPNNEncoder keyword arguments; each must contain
            'x_n_emb_size', 'pe_emb_size' and 'y_emb_size'.
        max_layer_size : int
            Largest layer size; the node-count classifier has
            max_layer_size + 1 classes.
        rho : float
            Stored as ``self.rho``.
        max_level : optional
            Stored as ``self.max_level``.
        load_refine_model, refine_with_transformer : bool
            Accepted but not read by this constructor.
        """
        super().__init__()
        self.rho = rho

        # Normalize num_x_n_cat to a 1-D LongTensor (one entry per feature).
        if isinstance(num_x_n_cat, int):
            num_x_n_cat = torch.LongTensor([num_x_n_cat])
        elif isinstance(num_x_n_cat, torch.Tensor) and num_x_n_cat.dim() == 0:
            num_x_n_cat = num_x_n_cat.unsqueeze(0)

        # Index of the last category of every feature.
        self.dummy_x_n = num_x_n_cat - 1
        num_feats = len(num_x_n_cat)

        def encoder_width(cfg):
            # Input width of a BiMPNNEncoder: per-feature embeddings plus the
            # level-encoding and label-encoding widths.
            return (num_feats * cfg['x_n_emb_size']
                    + cfg['pe_emb_size']
                    + cfg['y_emb_size'])

        # Sub-models are built in a fixed order; it determines both the
        # random initialization sequence and the state_dict layout.
        count_width = encoder_width(node_count_encoder_config)
        count_encoder = BiMPNNEncoder(num_x_n_cat,
                                      hidden_size=count_width,
                                      **node_count_encoder_config).to(device)
        self.node_count_model = GraphClassifier(
            count_encoder,
            emb_size=count_width,
            num_classes=max_layer_size + 1).to(device)

        self.node_diffusion = node_diffusion
        node_width = encoder_width(node_pred_graph_encoder_config)
        node_encoder = BiMPNNEncoder(num_x_n_cat, hidden_size=node_width,
                                     **node_pred_graph_encoder_config).to(device)
        self.node_pred_model = NodePredModel(node_encoder,
                                             num_x_n_cat,
                                             node_pred_graph_encoder_config['x_n_emb_size'],
                                             in_hidden_size=node_width,
                                             **node_predictor_config).to(device)

        self.edge_diffusion = edge_diffusion
        edge_width = encoder_width(edge_pred_graph_encoder_config)
        edge_encoder = BiMPNNEncoder(num_x_n_cat, hidden_size=edge_width,
                                     **edge_pred_graph_encoder_config).to(device)
        self.edge_pred_model = EdgePredModel(edge_encoder,
                                             in_hidden_size=edge_width,
                                             **edge_predictor_config).to(device)

        self.max_level = max_level
        # NOTE(review): RefineWithEdgeTransformer is not defined in the visible
        # part of this module (the class above is named
        # EdgeRefineModelWithEdgeTransformer) -- confirm it exists. Unlike the
        # sub-models above, it is not moved to `device` here.
        self.refine_model = RefineWithEdgeTransformer()

    def posterior(self, Z_t, Q_t, Q_bar_s, Q_bar_t, Z_0):
        """Categorical posterior over the step-s state, averaged over Z_0.

        For every row r and class c this returns::

            sum_k Z_0[r, k] * (Z_t Q_t^T)[r, c] * Q_bar_s[k, c] / (Q_bar_t Z_t^T)^T[r, k]

        where zero denominators are replaced by 1. The result is not
        renormalized.

        Parameters
        ----------
        Z_t : torch.Tensor of shape (num_rows, num_classes)
            Current state per row (one-hot in the callers in this file).
        Q_t, Q_bar_s, Q_bar_t : torch.Tensor of shape (num_classes, num_classes)
        Z_0 : torch.Tensor of shape (num_rows, num_classes)
            Predicted distribution over the clean state.

        Returns
        -------
        torch.Tensor of shape (num_rows, num_classes)
        """
        # (num_rows, 1, num_classes)
        forward_term = (Z_t @ Q_t.transpose(-1, -2)).unsqueeze(-2)
        # (num_rows, num_classes, num_classes)
        joint = forward_term * Q_bar_s.unsqueeze(-3)

        # (num_rows, num_classes, 1)
        norm = (Q_bar_t @ Z_t.transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
        # Guard against division by zero.
        norm = norm.masked_fill(norm == 0., 1.)

        # Weight by Z_0 and sum over its classes: (num_rows, num_classes)
        return (Z_0.unsqueeze(-1) * (joint / norm)).sum(dim=-2)

    def posterior_edge(self,
                       Z_t,
                       alpha_t,
                       alpha_bar_s,
                       alpha_bar_t,
                       Z_0,
                       marginal_list,
                       num_new_nodes_list,
                       num_query_list,):
        batch_size = len(num_new_nodes_list)

        Z_t_list = torch.split(Z_t, num_query_list, dim=0)
        Z_0_list = torch.split(Z_0, num_query_list, dim=0)
        device = Z_t.device
        e_mask_list = []

        for i in range(batch_size):
            Z_t_i = Z_t_list[i]
            Z_0_i = Z_0_list[i]

            Q_t_i, Q_bar_s_i, Q_bar_t_i = self.edge_diffusion.get_Qs(
                alpha_t, alpha_bar_s, alpha_bar_t, marginal_list[i])
            Q_t_i = Q_t_i.to(device)
            Q_bar_s_i = Q_bar_s_i.to(device)
            Q_bar_t_i = Q_bar_t_i.to(device)

            # (num_rows, num_classes)
            left_term_i = Z_t_i @ torch.transpose(Q_t_i, -1, -2)
            # (num_rows, 1, num_classes)
            left_term_i = left_term_i.unsqueeze(dim=-2)
            # (1, num_classes, num_classes)
            right_term_i = Q_bar_s_i.unsqueeze(dim=-3)
            # (num_rows, num_classes, num_classes)
            numerator_i = left_term_i * right_term_i

            # (num_classes, num_rows)
            prod_i = Q_bar_t_i @ torch.transpose(Z_t_i, -1, -2)
            # (num_rows, num_classes)
            prod_i = torch.transpose(prod_i, -1, -2)
            # (num_rows, num_classes, 1)
            denominator_i = prod_i.unsqueeze(-1)
            denominator_i[denominator_i == 0.] = 1.
            # (num_rows, num_classes, num_classes)
            out_i = numerator_i / denominator_i

            # (num_rows, num_classes, num_classes)
            prob_i = Z_0_i.unsqueeze(-1) * out_i
            # (num_rows, num_classes)
            prob_i = prob_i.sum(dim=-2)
            prob_i = prob_i / (prob_i.sum(dim=-1, keepdim=True) + 1e-6)

            # Get the probabilities for edge existence.
            prob_i = prob_i[:, 1]
            prob_i = prob_i.reshape(num_new_nodes_list[i], -1)
            e_mask_i = torch.bernoulli(prob_i)

            isolated_mask_i = (e_mask_i.sum(dim=1) == 0).bool()
            if isolated_mask_i.any():
                e_mask_i[isolated_mask_i, prob_i[isolated_mask_i].argmax(dim=1)] = 1

            e_mask_list.append(e_mask_i.reshape(-1))

        return torch.cat(e_mask_list).bool()

    @torch.no_grad()
    def sample_node_layer(self,
                          A,
                          x_n,
                          abs_level,
                          rel_level,
                          A_n2g,
                          curr_level=None,
                          y=None,
                          min_num_steps_n=None,
                          max_num_steps_n=None):
        device = A.device

        node_count_logits = self.node_count_model(A, x_n, abs_level,
                                                  rel_level, A_n2g=A_n2g, y=y)

        # For the first layer, the layer size must be nonzero.
        if curr_level == 0:
            node_count_logits[:, 0] = float('-inf')

        node_count_probs = node_count_logits.softmax(dim=-1)
        num_new_nodes = node_count_probs.multinomial(1)

        num_new_nodes_total = num_new_nodes.sum().item()
        batch_size = num_new_nodes.shape[0]
        if num_new_nodes_total == 0:
            return [torch.LongTensor([]).to(device)
                    for _ in range(batch_size)]

        num_classes_list = self.node_diffusion.num_classes_list
        marginal_list = self.node_diffusion.m_list
        D = len(num_classes_list)

        x_n_t = []
        for d in range(D):
            marginal_d = marginal_list[d]
            prior_d = marginal_d[0][None, :].expand(num_new_nodes_total, -1)
            # (num_new_nodes_total)
            x_n_t_d = prior_d.multinomial(1).squeeze(-1)
            x_n_t.append(x_n_t_d)
        x_n_t = torch.stack(x_n_t, dim=1).to(device)

        # Iteratively sample p(D^s | D^t) for t = 1, ..., T, with s = t - 1.
        h_g = self.node_pred_model.graph_encoder(A, x_n, abs_level, rel_level,
                                                 y=y, A_n2g=A_n2g)

        num_query_cumsum = torch.cumsum(torch.tensor(
            [0] + num_new_nodes.squeeze(-1).tolist()), dim=0).long().to(device)
        query2g = []
        for i in range(batch_size):
            query2g.append(torch.ones(num_query_cumsum[i+1] - num_query_cumsum[i]).fill_(i).long())
        query2g = torch.cat(query2g).to(device)

        T_x_n = self.node_diffusion.T
        if max_num_steps_n is not None:
            T_x_n = min(T_x_n, max_num_steps_n)

        time_x_n_list = list(reversed(range(0, T_x_n)))
        if min_num_steps_n is not None:
            num_steps_n = min_num_steps_n + int(
                (T_x_n - min_num_steps_n) * (curr_level / self.max_level)
            )
            time_x_n_list = time_x_n_list[-num_steps_n:]

        for s_x_n in time_x_n_list:
            t_x_n = s_x_n + 1

            # Note that computing Q_bar_t from alpha_bar_t is the same
            # as computing Q_t from alpha_t.
            alpha_t = self.node_diffusion.alphas[t_x_n]
            alpha_bar_s = self.node_diffusion.alpha_bars[s_x_n]
            alpha_bar_t = self.node_diffusion.alpha_bars[t_x_n]

            t_x_n_tensor = torch.LongTensor([[t_x_n]]).expand(batch_size, -1).to(device)
            x_n_0_logits = self.node_pred_model.forward_with_h_g(
                h_g, x_n_t, t_x_n_tensor, query2g,
                num_query_cumsum)

            x_n_s = []
            for d in range(D):
                Q_t_d = self.node_diffusion.get_Q(alpha_t, d).to(device)
                Q_bar_s_d = self.node_diffusion.get_Q(alpha_bar_s, d).to(device)
                Q_bar_t_d = self.node_diffusion.get_Q(alpha_bar_t, d).to(device)

                x_n_0_probs_d = x_n_0_logits[d].softmax(dim=-1)
                # (num_new_nodes, num_classes)
                x_n_t_one_hot_d = F.one_hot(x_n_t[:, d], num_classes=num_classes_list[d]).float()

                x_n_s_probs_d = self.posterior(x_n_t_one_hot_d, Q_t_d, Q_bar_s_d,
                                               Q_bar_t_d, x_n_0_probs_d)
                x_n_s_d = x_n_s_probs_d.multinomial(1).squeeze(-1)
                x_n_s.append(x_n_s_d)

            x_n_t = torch.stack(x_n_s, dim=1)

        return torch.split(x_n_t, num_new_nodes.squeeze(-1).tolist())

    @torch.no_grad()
    def sample_edge_layer(self, num_nodes_cumsum, edge_index_list,
                          batch_x_n, batch_abs_level, batch_rel_level,
                          num_new_nodes_list, batch_query_src, batch_query_dst,
                          query_src_list, query_dst_list, x_n_list,
                          batch_y=None,
                          curr_level=None,
                          min_num_steps_e=None,
                          max_num_steps_e=None, refine=False,
                          inner_refine=False, refine_with_transformer=False, check_with_refine=False):
        device = batch_x_n.device

        e_t_mask_list = []
        batch_size = len(num_new_nodes_list)
        marginal_list = []
        num_query_list = []
        for i in range(batch_size):
            num_query_i = len(query_src_list[i])
            num_query_list.append(num_query_i)

            num_new_nodes_i = num_new_nodes_list[i]
            prior_i = torch.ones(num_query_i).reshape(num_new_nodes_i, -1)
            mean_in_deg_i = min(self.edge_diffusion.avg_in_deg, prior_i.shape[1])
            marginal_i = mean_in_deg_i / prior_i.shape[1]
            marginal_list.append(marginal_i)
            prior_i = prior_i * marginal_i
            e_t_mask_i = torch.bernoulli(prior_i)
            isolated_mask = (e_t_mask_i.sum(dim=1) == 0).bool()
            if isolated_mask.any():
                e_t_mask_i[isolated_mask, torch.zeros(int(isolated_mask.sum().item())).long()] = 1
            e_t_mask_list.append(e_t_mask_i.reshape(-1))

        e_t_mask = torch.cat(e_t_mask_list).bool().to(device)


        num_nodes = len(batch_x_n)
        num_queries = len(batch_query_src)

        batch_edge_index = self.get_batch_A(
            num_nodes_cumsum, edge_index_list, device,
            return_edge_index=True)

        # Iteratively sample p(D^s | D^t) for t = 1, ..., T, with s = t - 1.
        T_x_e = self.edge_diffusion.T
        if max_num_steps_e is not None:
            T_x_e = min(T_x_e, max_num_steps_e)

        time_x_e_list = list(reversed(range(0, T_x_e)))
        if min_num_steps_e is not None:
            num_steps_e = min_num_steps_e + int(
                (T_x_e - min_num_steps_e) * (curr_level / self.max_level)
            )
            time_x_e_list = time_x_e_list[-num_steps_e:]

        query_src_dst_map_list = []

        for i in range(len(query_src_list)):
            query_src = query_src_list[i].tolist()
            query_dst = query_dst_list[i].tolist()
            query_src_dst_map = {}
            index = 0
            for j in range(len(query_src)):
                query_src_dst_map[(query_src[j], query_dst[j])] = index
                index += 1
            query_src_dst_map_list.append(query_src_dst_map.copy())
        if inner_refine:
            if not refine_with_transformer:
                e_t_mask = self.sample_refine(e_t_mask, batch_query_dst, batch_query_src, query_src_list,
                                          query_dst_list, num_nodes_cumsum, time_x_e_list[0], batch_edge_index,
                                          batch_x_n, device, batch_size, num_query_list, query_src_dst_map_list)
            else:
                e_t_mask = self.sample_refine_with_transformer(e_t_mask, query_src_list,query_dst_list,
                                          x_n_list, device, batch_size, num_query_list, query_src_dst_map_list)

        for s_x_e in time_x_e_list:
            t_x_e = s_x_e + 1

            # Note that computing Q_bar_t from alpha_bar_t is the same
            # as computing Q_t from alpha_t.
            alpha_t = self.edge_diffusion.alphas[t_x_e]
            alpha_bar_s = self.edge_diffusion.alpha_bars[s_x_e]
            alpha_bar_t = self.edge_diffusion.alpha_bars[t_x_e]

            edge_index_t = torch.stack([
                batch_query_dst[e_t_mask],
                batch_query_src[e_t_mask]
            ]).to(device)

            A = dglsp.spmatrix(
                torch.cat([batch_edge_index, edge_index_t], dim=1),
                shape=(num_nodes, num_nodes)).to(device)
            t_x_e_tensor = torch.LongTensor([t_x_e])[None, :].expand(
                num_queries, -1).to(device)
            e_0_logits = self.edge_pred_model(
                A, batch_x_n, batch_abs_level, batch_rel_level, t_x_e_tensor,
                batch_query_src, batch_query_dst, batch_y)
            e_0_probs = e_0_logits.softmax(dim=-1)
            # (num_queries, num_classes)
            e_t_one_hot = F.one_hot(e_t_mask.long(), num_classes=2).float()
            e_t_mask = self.posterior_edge(e_t_one_hot,
                                           alpha_t,
                                           alpha_bar_s,
                                           alpha_bar_t,
                                           e_0_probs,
                                           marginal_list,
                                           num_new_nodes_list,
                                           num_query_list)
            if inner_refine:
                if not refine_with_transformer:
                    e_t_mask = self.sample_refine(e_t_mask, batch_query_dst, batch_query_src, query_src_list,
                                              query_dst_list, num_nodes_cumsum, 0, batch_edge_index,
                                              batch_x_n, device, batch_size, num_query_list, query_src_dst_map_list)
                else:
                    e_t_mask = self.sample_refine_with_transformer(e_t_mask, query_src_list,query_dst_list,
                                          x_n_list, device, batch_size, num_query_list, query_src_dst_map_list)

        if refine and (not inner_refine):
            if not refine_with_transformer:
                e_t_mask = self.sample_refine(e_t_mask, batch_query_dst, batch_query_src, query_src_list,
                                          query_dst_list, num_nodes_cumsum, 0, batch_edge_index,
                                          batch_x_n, device, batch_size, num_query_list, query_src_dst_map_list)
            else:
                e_t_mask = self.sample_refine_with_transformer(e_t_mask, query_src_list,query_dst_list,
                                          x_n_list, device, batch_size, num_query_list, query_src_dst_map_list)

        num_query_split = [len(query_src_i) for query_src_i in query_src_list]
        e_t_mask_split = torch.split(e_t_mask, num_query_split)

        edge_index_list_ = []
        for i in range(len(edge_index_list)):
            edge_index_i = edge_index_list[i]
            e_t_mask_i = e_t_mask_split[i]
            edge_index_l_i = torch.stack([
                query_dst_list[i][e_t_mask_i],
                query_src_list[i][e_t_mask_i]
            ])
            edge_index_i = torch.cat([edge_index_i, edge_index_l_i], dim=1)
            edge_index_list_.append(edge_index_i)
        edge_index_list = edge_index_list_

        return edge_index_list

    def sample_refine(self, e_t_mask, batch_query_dst, batch_query_src, query_src_list, query_dst_list,
                      num_nodes_cumsum, t_x_e, batch_edge_index, batch_x_n, device, batch_size, num_query_list, query_src_dst_map_list):
        """Let ``self.refine_model`` revise the currently selected query edges.

        The selected query edges of every graph (``e_t_mask``) are passed to the
        refiner together with the fixed edges ``batch_edge_index`` and the step
        ``t_x_e``. The (src, dst) pairs the refiner returns are mapped back to
        query positions through ``query_src_dst_map_list`` to form the new mask.

        Returns:
            A boolean tensor on ``device`` with one entry per query edge, True
            exactly for the query edges returned by the refiner.
        """
        # Edge indices are laid out as [dst; src].
        noisy_edge_index = torch.stack([
            batch_query_dst[e_t_mask],
            batch_query_src[e_t_mask],
        ]).to(device)
        per_graph_masks = torch.split(e_t_mask.cpu(), num_query_list)
        assert len(query_src_list) == len(query_dst_list) == len(per_graph_masks) == batch_size

        noisy_src_per_graph = []
        noisy_dst_per_graph = []
        for graph_src, graph_dst, graph_mask in zip(query_src_list, query_dst_list, per_graph_masks):
            keep = graph_mask.tolist()
            kept_src = graph_src[keep].tolist()
            kept_dst = graph_dst[keep].tolist()
            assert len(graph_src) == len(graph_dst)
            # Every graph must contribute at least one selected edge.
            assert len(kept_src) != 0
            noisy_src_per_graph.append(kept_src)
            noisy_dst_per_graph.append(kept_dst)

        item_data_map = {
            'batch_noisy_src': noisy_src_per_graph,
            'batch_noisy_dst': noisy_dst_per_graph,
            'num_nodes_cumsum': num_nodes_cumsum,
            'batch_size': batch_size,
            'batch_t': [torch.tensor(t_x_e).unsqueeze(dim=0) for _ in range(batch_size)],
        }
        _, refined_src, refined_dst, _ = self.refine_model.refine(
            batch_edge_index, noisy_edge_index, batch_x_n.squeeze(dim=1), item_data_map)

        new_mask = []
        for g in range(batch_size):
            position_of = query_src_dst_map_list[g]
            graph_mask = [False] * len(query_src_list[g])
            src_ids = refined_src[g].tolist()
            dst_ids = refined_dst[g].tolist()
            for j, src_id in enumerate(src_ids):
                graph_mask[position_of[(src_id, dst_ids[j])]] = True
            new_mask.extend(graph_mask)

        assert len(new_mask) == len(e_t_mask)
        return torch.tensor(new_mask, dtype=torch.bool).to(device)

    def sample_refine_with_transformer(self, e_t_mask, query_src_list, query_dst_list, x_n_list, device, batch_size, num_query_list, query_src_dst_map_list):
        """Re-predict the incoming edges of nodes that violate predecessor balance.

        For each graph, the currently selected query edges (``e_t_mask``) are
        checked with ``check_predecessor_balance_src_dst``. All selected edges
        into a violating destination node are re-predicted by
        ``self.refine_model.refine``. Selected edges into other destinations are
        kept unchanged.

        Args:
            e_t_mask: Boolean mask over the concatenated query edges of the batch.
            query_src_list, query_dst_list: Per-graph tensors of query source /
                destination node ids.
            x_n_list: Per-graph node tensors including the dummy node at row 0;
                ``squeeze(dim=1)`` is applied, so presumably shape (N, 1).
            device: Device of the returned mask.
            batch_size: Number of graphs in the batch.
            num_query_list: Number of query edges per graph, used to split
                ``e_t_mask``.
            query_src_dst_map_list: Per-graph dict mapping ``(src, dst)`` to the
                position of that query edge within the graph's queries.

        Returns:
            The refined boolean mask (one entry per query edge), or ``e_t_mask``
            itself when no graph violates the constraint.
        """
        e_t_mask_list = torch.split(e_t_mask.cpu(), num_query_list)
        assert len(query_src_list) == len(query_dst_list) == len(e_t_mask_list) == len(x_n_list) == batch_size
        src_list = []
        noisy_list = []
        new_query_src_list = []
        new_query_dst_list = []
        split_list = []
        unsatisfy_data_list = []
        for i in range(batch_size):
            input_x_n = x_n_list[i].squeeze(dim=1).tolist()
            e_t_mask_i = e_t_mask_list[i].tolist()
            selected_src = query_src_list[i][e_t_mask_i].tolist()
            selected_dst = query_dst_list[i][e_t_mask_i].tolist()

            # input_x_n[1:] drops the dummy node, matching the ``x_n[u - 1]``
            # lookup done by check_predecessor_balance_src_dst.
            unsatisfy_dst_list, unsatisfy_src_list = check_predecessor_balance_src_dst(
                input_x_n[1:], selected_src, selected_dst, self.rho)
            if unsatisfy_dst_list is None:
                unsatisfy_data_list.append({'refine': False, 'e_t_mask_i': e_t_mask_i.copy()})
                continue

            new_query_src = sorted(set(unsatisfy_src_list))
            new_query_dst = sorted(set(unsatisfy_dst_list))
            # Set views for O(1) membership tests (the lists stay sorted for
            # building the refiner input below).
            new_query_src_set = set(new_query_src)
            new_query_dst_set = set(new_query_dst)
            query_src_ = []
            query_dst_ = []
            remain_query = []
            for src, dst in zip(selected_src, selected_dst):
                if dst in new_query_dst_set:
                    assert src in new_query_src_set
                    query_src_.append(src)
                    query_dst_.append(dst)
                else:
                    remain_query.append((src, dst))
            assert len(query_src_) != 0
            # Node values of the candidate sources, in ascending node-id order.
            x_n = [input_x_n[u] for u in new_query_src]
            df = pd.DataFrame({'src': query_src_, 'dst': query_dst_})
            unique_src = sorted(df['src'].unique())
            num_groups = len(df.groupby('dst'))
            # One 0/1 row per violating destination (ascending dst) with one
            # column per candidate source (ascending src): 1 where that edge is
            # currently selected.
            result = (
                df.groupby('dst')['src']
                .apply(lambda grp: [int(x in grp.values) for x in unique_src])
                .sort_index()
                .tolist()
            )
            assert len(result) == num_groups
            src_list.extend([x_n] * num_groups)
            noisy_list.extend(result)

            new_query_src_list.extend([new_query_src] * num_groups)
            new_query_dst_list.extend([[val] * len(new_query_src) for val in new_query_dst])
            split_list.append(num_groups)
            unsatisfy_data_list.append({'refine': True, 'index': i, 'remain_query': remain_query.copy(),
                                        'e_t_mask_i': e_t_mask_i.copy()})

        if len(src_list) == 0:
            # No graph violates the constraint; keep the mask unchanged.
            return e_t_mask
        refine_edges = self.refine_model.refine(src_list, noisy_list)

        idx = 0
        group = 0
        new_e_t_mask_list = []
        for unsatisfy_data in unsatisfy_data_list:
            if not unsatisfy_data['refine']:
                new_e_t_mask_list.extend(unsatisfy_data['e_t_mask_i'])
                continue
            length = split_list[group]
            group += 1
            pred = np.concatenate(refine_edges[idx:idx + length])
            cand_src = np.concatenate(new_query_src_list[idx:idx + length])
            cand_dst = np.concatenate(new_query_dst_list[idx:idx + length])
            idx += length
            mask = (pred == 1)
            query_src_dst_map = query_src_dst_map_list[unsatisfy_data['index']]
            new_e_t_mask_i = [False] * len(unsatisfy_data['e_t_mask_i'])
            # Edges the refiner keeps for the violating destinations ...
            for src, dst in zip(cand_src[mask].tolist(), cand_dst[mask].tolist()):
                new_e_t_mask_i[query_src_dst_map[(src, dst)]] = True
            # ... plus every selected edge into a non-violating destination.
            for src, dst in unsatisfy_data['remain_query']:
                new_e_t_mask_i[query_src_dst_map[(src, dst)]] = True
            new_e_t_mask_list.extend(new_e_t_mask_i)

        assert len(new_e_t_mask_list) == len(e_t_mask)
        return torch.tensor(new_e_t_mask_list, dtype=torch.bool).to(device)

    def get_batch_A(self, num_nodes_cumsum, edge_index_list, device, return_edge_index=False):
        """Merge per-graph edge indices into one block-diagonal batch graph.

        Args:
            num_nodes_cumsum: Node-count prefix sums. Entry ``k`` is the id offset
                of graph ``k`` in the batch; the last entry is the total number
                of nodes.
            edge_index_list: Per-graph ``(2, E_k)`` edge-index tensors in local
                node ids.
            device: Device of the returned sparse matrix.
            return_edge_index: If True, return the concatenated, offset edge
                index instead of building the sparse matrix.

        Returns:
            The batched ``(2, E)`` edge index, or a ``dglsp`` sparse adjacency
            matrix of shape ``(N, N)`` on ``device``.
        """
        # Shifting local ids by each graph's cumulative node count places the
        # graphs in disjoint id ranges.
        offset_edges = [
            edges + num_nodes_cumsum[k]
            for k, edges in enumerate(edge_index_list)
        ]
        batch_edge_index = torch.cat(offset_edges, dim=1)
        if return_edge_index:
            return batch_edge_index

        total_nodes = num_nodes_cumsum[-1].item()
        return dglsp.spmatrix(batch_edge_index, shape=(total_nodes, total_nodes)).to(device)

    def get_batch_A_n2g(self, num_nodes_cumsum, device):
        """Build the sparse graph-to-node membership matrix of a batch.

        Graph ``g`` owns node ids ``num_nodes_cumsum[g]`` up to (excluding)
        ``num_nodes_cumsum[g + 1]``. Row ``g`` of the result has a nonzero entry
        in each of those columns.

        Returns:
            A ``dglsp`` sparse matrix of shape ``(batch_size, N)`` on ``device``,
            where ``batch_size = len(num_nodes_cumsum) - 1`` and
            ``N = num_nodes_cumsum[-1]``.
        """
        num_graphs = len(num_nodes_cumsum) - 1
        node_ids = []
        graph_ids = []
        for g, (start, stop) in enumerate(zip(num_nodes_cumsum[:-1], num_nodes_cumsum[1:])):
            node_ids.append(torch.arange(start, stop).long())
            graph_ids.append(torch.ones(stop - start).fill_(g).long())

        # Row 0 holds the graph id (sparse row), row 1 the node id (sparse column).
        membership = torch.stack([torch.cat(graph_ids, dim=0), torch.cat(node_ids, dim=0)])

        total_nodes = num_nodes_cumsum[-1].item()
        return dglsp.spmatrix(membership, shape=(num_graphs, total_nodes)).to(device)

    def get_batch_y(self, y_list, x_n_list, device):
        """Broadcast each graph's label to every node of that graph.

        Returns:
            None when ``y_list`` is None. Otherwise a tensor of shape
            ``(total_nodes, 1)`` (default float dtype) on ``device``, where every
            row belonging to graph ``k`` holds ``y_list[k]``.
        """
        if y_list is None:
            return None

        per_graph = [
            torch.zeros(len(nodes), 1).fill_(y_list[k])
            for k, nodes in enumerate(x_n_list)
        ]
        return torch.cat(per_graph).to(device)

    @torch.no_grad()
    def sample(self,
               device,
               batch_size=1,
               y=None,
               min_num_steps_n=None,
               max_num_steps_n=None,
               min_num_steps_e=None,
               max_num_steps_e=None,
               check=False, solve=False, refine=False,
               inner_refine=False, refine_with_transformer=False,
               check_with_refine=False, mode='train'):
        """Generate a batch of graphs one node layer at a time.

        Each iteration samples a new layer of nodes with ``sample_node_layer``.
        From the second layer on, query edges from every existing non-dummy node
        to every new node are passed to ``sample_edge_layer``. A graph is
        finished when ``sample_node_layer`` returns no new nodes for it, or when
        ``self.max_level`` is reached.

        Args:
            device: Device on which tensors are created.
            batch_size: Number of graphs to generate.
            y: Optional per-graph conditioning labels; must have ``batch_size``
                entries when given.
            min_num_steps_n, max_num_steps_n: Forwarded to ``sample_node_layer``.
            min_num_steps_e, max_num_steps_e: Forwarded to ``sample_edge_layer``.
            check: If True, test every graph after each edge layer with
                ``check_predecessor_balance``. Violating graphs are dropped
                unless repaired by ``solve`` or ``check_with_refine``.
            solve: With ``check``, try to repair violating graphs with
                ``solve_edge_constraint``.
            refine, inner_refine, refine_with_transformer: Forwarded to
                ``sample_edge_layer``.
            check_with_refine: With ``check`` (and not ``inner_refine``), try to
                repair violating graphs with ``sample_with_refine_transformer``.
                Also forwarded to ``sample_edge_layer``.
            mode: File prefix for the JSONL examples written when ``check`` and
                ``solve`` are set, ``y`` is None and ``self.rho > 0.6``.

        Returns:
            ``(edge_index_list, x_n_list)``, or ``(edge_index_list, x_n_list,
            y_list)`` when ``y`` is given. The dummy node is removed and edge
            indices are shifted down by one accordingly.
        """
        if y is not None:
            assert batch_size == len(y)
        y_list = y

        edge_index_list = [
            torch.LongTensor([[], []]).to(device)
            for _ in range(batch_size)
        ]

        # Every graph starts from a single dummy node at index 0. It is removed
        # (and edge ids shifted down by one) before the graphs are returned.
        if isinstance(self.dummy_x_n, int):
            init_x_n = torch.LongTensor([[self.dummy_x_n]]).to(device)
        elif isinstance(self.dummy_x_n, torch.Tensor):
            init_x_n = self.dummy_x_n.to(device).unsqueeze(0)
        else:
            raise NotImplementedError
        x_n_list = [init_x_n for _ in range(batch_size)]
        batch_x_n = torch.cat(x_n_list)
        batch_y = self.get_batch_y(y_list, x_n_list, device)

        level = 0.
        abs_level_list = [
            torch.tensor([[level]]).to(device)
            for _ in range(batch_size)
        ]
        batch_abs_level = torch.cat(abs_level_list)
        batch_rel_level = batch_abs_level.max() - batch_abs_level

        edge_index_finished = []
        x_n_finished = []
        if y is not None:
            y_finished = []

        num_nodes_cumsum = torch.cumsum(torch.tensor([0] + [len(x_n_i) for x_n_i in x_n_list]), dim=0)
        while True:
            batch_A = self.get_batch_A(num_nodes_cumsum, edge_index_list, device)
            batch_A_n2g = self.get_batch_A_n2g(num_nodes_cumsum, device)
            x_n_l_list = self.sample_node_layer(
                batch_A, batch_x_n, batch_abs_level, batch_rel_level,
                batch_A_n2g, curr_level=level,
                y=batch_y,
                min_num_steps_n=min_num_steps_n,
                max_num_steps_n=max_num_steps_n)

            edge_index_list_ = []
            x_n_list_ = []
            abs_level_list_ = []
            query_src_list = []
            query_dst_list = []
            num_new_nodes_list = []
            batch_query_src = []
            batch_query_dst = []

            if y is not None:
                y_list_ = []
            else:
                y_list_ = None

            level += 1
            node_count = 0
            for i, x_n_l_i in enumerate(x_n_l_list):
                if len(x_n_l_i) == 0:
                    # No new nodes: this graph is complete.
                    edge_index_finished.append(edge_index_list[i] - 1)
                    x_n_finished.append(x_n_list[i][1:])
                    if y is not None:
                        y_finished.append(y_list[i])
                else:
                    edge_index_list_.append(edge_index_list[i])
                    x_n_list_.append(torch.cat([x_n_list[i], x_n_l_i]))
                    if y is not None:
                        y_list_.append(y_list[i])
                    abs_level_list_.append(
                        torch.cat([
                            abs_level_list[i],
                            torch.zeros(len(x_n_l_i), 1).fill_(level).to(device)
                        ])
                    )

                    N_old_i = len(x_n_list[i])
                    N_new_i = len(x_n_l_i)

                    # Candidate edges: every existing non-dummy node -> every new node.
                    query_src_i = []
                    query_dst_i = []

                    src_candidates_i = list(range(1, N_old_i))
                    for dst_i in range(N_old_i, N_old_i + N_new_i):
                        query_src_i.extend(src_candidates_i)
                        query_dst_i.extend([dst_i] * len(src_candidates_i))
                    query_src_i = torch.LongTensor(query_src_i).to(device)
                    query_dst_i = torch.LongTensor(query_dst_i).to(device)

                    query_src_list.append(query_src_i)
                    query_dst_list.append(query_dst_i)
                    batch_query_src.append(query_src_i + node_count)
                    batch_query_dst.append(query_dst_i + node_count)
                    num_new_nodes_list.append(N_new_i)

                    node_count = node_count + N_old_i + N_new_i

            edge_index_list = edge_index_list_
            x_n_list = x_n_list_
            y_list = y_list_
            abs_level_list = abs_level_list_
            if len(edge_index_list) == 0:
                break
            num_nodes_cumsum = torch.cumsum(torch.tensor([0] + [len(x_n_i) for x_n_i in x_n_list]), dim=0)

            batch_x_n = torch.cat(x_n_list)
            batch_abs_level = torch.cat(abs_level_list)
            batch_rel_level = batch_abs_level.max() - batch_abs_level
            batch_y = self.get_batch_y(y_list, x_n_list, device)

            if level == 1:
                # The first node layer has no earlier non-dummy nodes to connect from.
                continue

            batch_query_src = torch.cat(batch_query_src)
            batch_query_dst = torch.cat(batch_query_dst)
            edge_index_list_old = edge_index_list.copy()
            edge_index_list = self.sample_edge_layer(
                num_nodes_cumsum, edge_index_list, batch_x_n, batch_abs_level,
                batch_rel_level, num_new_nodes_list, batch_query_src,
                batch_query_dst, query_src_list, query_dst_list, x_n_list, batch_y,
                curr_level=level,
                min_num_steps_e=min_num_steps_e,
                max_num_steps_e=max_num_steps_e, refine=refine, inner_refine=inner_refine,
                refine_with_transformer=refine_with_transformer, check_with_refine=check_with_refine)

            if check:
                assert len(x_n_list) == len(edge_index_list)
                x_n_list_temp = []
                edge_index_list_temp = []
                abs_level_list_temp = []
                unsatisfy_x_n_list_temp = []
                unsatisfy_abs_level_list_temp = []

                unsatisfy_edge_index_list_temp = []
                unsatisfy_list = []
                unsatisfy_src_list = []
                if y is not None:
                    y_list_temp = []
                    unsatisfy_y_list_temp = []
                    # The label is bound to ``y_i``. Reusing the name ``y`` would
                    # overwrite the argument that the ``y is not None`` tests read.
                    for x_n_groups, z_t, z_t_old, abs_level, y_i in zip(x_n_list, edge_index_list, edge_index_list_old, abs_level_list, y_list):
                        x_n = [row.tolist() for row in x_n_groups[1:, :]]
                        unsatisfy, src_list = check_predecessor_balance(x_n, z_t, rho=self.rho)
                        if unsatisfy is None and src_list is None:
                            x_n_list_temp.append(x_n_groups)
                            edge_index_list_temp.append(z_t)
                            abs_level_list_temp.append(abs_level)
                            y_list_temp.append(y_i)
                        else:
                            unsatisfy_x_n_list_temp.append(x_n_groups)
                            unsatisfy_abs_level_list_temp.append(abs_level)
                            unsatisfy_y_list_temp.append(y_i)
                            unsatisfy_edge_index_list_temp.append(z_t)
                            unsatisfy_list.append(unsatisfy)
                            unsatisfy_src_list.append(src_list)
                            if solve:
                                z_t_new = solve_edge_constraint(x_n, z_t, z_t_old, unsatisfy, rho=self.rho)
                                if z_t_new is not None:
                                    x_n_list_temp.append(x_n_groups)
                                    edge_index_list_temp.append(z_t_new)
                                    abs_level_list_temp.append(abs_level)
                                    y_list_temp.append(y_i)
                else:
                    for x_n_groups, z_t, z_t_old, abs_level in zip(x_n_list, edge_index_list, edge_index_list_old, abs_level_list):
                        x_n = [row.tolist() for row in x_n_groups[1:, :]]
                        unsatisfy, src_list = check_predecessor_balance(x_n, z_t, rho=self.rho)
                        if unsatisfy is None and src_list is None:
                            x_n_list_temp.append(x_n_groups)
                            edge_index_list_temp.append(z_t)
                            abs_level_list_temp.append(abs_level)
                        else:
                            unsatisfy_x_n_list_temp.append(x_n_groups)
                            unsatisfy_abs_level_list_temp.append(abs_level)
                            unsatisfy_edge_index_list_temp.append(z_t)
                            unsatisfy_list.append(unsatisfy)
                            unsatisfy_src_list.append(src_list)
                            if solve:
                                z_t_new = solve_edge_constraint(x_n, z_t, z_t_old, unsatisfy, self.rho)
                                if z_t_new is not None:
                                    new_unsatisfy, new_src_list = check_predecessor_balance(x_n, z_t_new, rho=self.rho)
                                    assert new_unsatisfy is None and new_src_list is None
                                    x_n_list_temp.append(x_n_groups)
                                    edge_index_list_temp.append(z_t_new)
                                    abs_level_list_temp.append(abs_level)
                                    if self.rho > 0.6:
                                        self._dump_solved_example(x_n, z_t, z_t_old, z_t_new, mode)

                if check_with_refine and (not inner_refine):
                    refine_index, new_edge_index_list_temp = self.sample_with_refine_transformer(unsatisfy_x_n_list_temp, unsatisfy_list, unsatisfy_src_list, unsatisfy_edge_index_list_temp, device)
                    for i, new_edge_index in zip(refine_index, new_edge_index_list_temp):
                        x_n_list_temp.append(unsatisfy_x_n_list_temp[i])
                        edge_index_list_temp.append(new_edge_index)
                        abs_level_list_temp.append(unsatisfy_abs_level_list_temp[i])
                        if y is not None:
                            y_list_temp.append(unsatisfy_y_list_temp[i])
                        x_n = [row.tolist() for row in unsatisfy_x_n_list_temp[i][1:, :]]
                        u, s = check_predecessor_balance(x_n, new_edge_index, rho=self.rho)
                        assert u is None and s is None

                # Rebuild the label list only now, after any refined graphs were
                # appended, so that it stays aligned with x_n_list.
                if y is not None:
                    y_list = list(y_list_temp)
                x_n_list = list(x_n_list_temp)
                edge_index_list = list(edge_index_list_temp)
                abs_level_list = list(abs_level_list_temp)
                if len(edge_index_list) == 0:
                    break
                num_nodes_cumsum = torch.cumsum(torch.tensor([0] + [len(x_n_i) for x_n_i in x_n_list]), dim=0)
                batch_x_n = torch.cat(x_n_list)
                batch_abs_level = torch.cat(abs_level_list)
                batch_rel_level = batch_abs_level.max() - batch_abs_level
                batch_y = self.get_batch_y(y_list, x_n_list, device)

            if self.max_level is not None and level == self.max_level:
                break

        for i in range(len(edge_index_list)):
            edge_index_finished.append(edge_index_list[i] - 1)
            x_n_finished.append(x_n_list[i][1:])

        if y is None:
            return edge_index_finished, x_n_finished
        else:
            y_finished.extend(y_list)
            return edge_index_finished, x_n_finished, y_finished

    def _dump_solved_example(self, x_n, z_t, z_t_old, z_t_new, mode):
        """Append one (violating, repaired) edge-layer example to ``./data/{mode}_1.jsonl``.

        ``z_t_old`` holds the edges that existed before the current layer. The
        edges of ``z_t`` / ``z_t_new`` past that prefix are recorded as the
        violating / repaired layer.
        """
        # Edge indices are laid out as [dst; src].
        src_list_new, dst_list_new = z_t_new[1].cpu().tolist(), z_t_new[0].cpu().tolist()
        src_list_old, dst_list_old = z_t[1].cpu().tolist(), z_t[0].cpu().tolist()
        input_src, input_dst = z_t_old[1].cpu().tolist(), z_t_old[0].cpu().tolist()
        assert len(input_src) == len(input_dst)
        num_input = len(input_dst)
        noisy_src = src_list_new[num_input:]
        noisy_dst = dst_list_new[num_input:]
        noisy_src_old = src_list_old[num_input:]
        noisy_dst_old = dst_list_old[num_input:]

        # check_predecessor_balance_src_dst compares x_n[u - 1] with 0/1
        # directly, so it needs one scalar per node, not the per-node row lists
        # that x_n holds (check_predecessor_balance reads column 0 of each row).
        flat_x_n = [row[0] for row in x_n]
        new_unsatisfy, _ = check_predecessor_balance_src_dst(x_n=flat_x_n,
                                                             src_list=input_src + noisy_src,
                                                             dst_list=input_dst + noisy_dst,
                                                             rho=self.rho)
        old_unsatisfy, _ = check_predecessor_balance_src_dst(x_n=flat_x_n,
                                                             src_list=input_src + noisy_src_old,
                                                             dst_list=input_dst + noisy_dst_old,
                                                             rho=self.rho)
        # The repaired layer must satisfy the constraint; the sampled one must not.
        assert new_unsatisfy is None and old_unsatisfy is not None
        data = {'input_src': input_src,
                'input_dst': input_dst,
                'noisy_src': noisy_src,
                'noisy_dst': noisy_dst,
                'noisy_src_old': noisy_src_old,
                'noisy_dst_old': noisy_dst_old,
                # Stored as per-node row lists, as before, to keep the file format.
                'input_x_n': x_n,
                't': [0],
                }
        with open(f'./data/{mode}_1.jsonl', "a") as f:
            json.dump(data, f)
            f.write('\n')

    def sample_with_refine_transformer(self, x_n_list, unsatisfy_list, unsatisfy_src_list, z_t_list, device,
                                       max_num_rounds=5, max_num_src=62):
        """Try to repair graphs that violate the predecessor-balance constraint.

        For every graph, the incoming edges of its unsatisfied destination nodes
        are removed and re-predicted by ``self.refine_model.refine``. Each graph
        gets up to ``max_num_rounds`` attempts. If a graph's unsatisfied
        destinations have more than ``max_num_src`` candidate sources, it is not
        sent to the refiner. Instead it is accepted with those incoming edges
        removed.

        Args:
            x_n_list: Per-graph 2-D node tensors including the dummy node at
                row 0; only column 0 is read as the node value.
            unsatisfy_list: Per-graph lists of unsatisfied destination node ids.
            unsatisfy_src_list: Per-graph lists of the predecessors of those nodes.
            z_t_list: Per-graph edge indices laid out as [dst; src].
            device: Device for the returned edge indices.
            max_num_rounds: Number of refinement attempts (default 5).
            max_num_src: Largest number of candidate sources sent to the
                refiner (default 62).

        Returns:
            ``(refine_index, refine_z_t_list)``: positions (into the input lists)
            of the repaired graphs and their new int64 edge indices, in matching
            order. Graphs that could not be repaired are omitted.
        """
        assert len(x_n_list) == len(unsatisfy_list) == len(unsatisfy_src_list) == len(z_t_list)
        refine_index = []
        refine_z_t_list = []
        unsatisfy_data_list = []
        for i, (x_n, dst, src, z_t) in enumerate(zip(x_n_list, unsatisfy_list, unsatisfy_src_list, z_t_list)):
            # Per-node rows without the dummy node (format of check_predecessor_balance).
            x_n_rows = [row.tolist() for row in x_n[1:, :]]
            # Scalar node values indexed by node id; id 0 is the dummy node.
            input_x_n = [x[0] for x in x_n.cpu().tolist()]
            query_src = sorted(set(src))
            query_dst = sorted(set(dst))
            src_x_n = [input_x_n[u] for u in query_src]
            result_src_np = np.array(z_t[1].cpu().tolist())
            result_dst_np = np.array(z_t[0].cpu().tolist())

            # Edges into an unsatisfied destination are the ones to re-predict.
            indices = np.isin(result_dst_np, query_dst)
            new_query_src_ = result_src_np[indices].tolist()
            new_query_dst_ = result_dst_np[indices].tolist()

            unsatisfy_data = {'edges_src': src_x_n, 'x_n': x_n_rows, 'refine': True}
            current_src = result_src_np[~indices].tolist()
            current_dst = result_dst_np[~indices].tolist()
            # Force int64: torch.tensor([]) would otherwise infer float32 when
            # every edge points at an unsatisfied destination.
            unsatisfy_data['z_t'] = torch.stack([
                torch.tensor(current_dst, dtype=torch.long, device=device),
                torch.tensor(current_src, dtype=torch.long, device=device),
            ])

            df = pd.DataFrame({'src': new_query_src_, 'dst': new_query_dst_})
            unique_src = sorted(df['src'].unique())
            num_groups = len(df.groupby('dst'))
            # One 0/1 row per unsatisfied destination (ascending dst) with one
            # column per candidate source (ascending src): 1 where the edge exists.
            result = (
                df.groupby('dst')['src']
                .apply(lambda grp: [int(x in grp.values) for x in unique_src])
                .sort_index()
                .tolist()
            )
            assert len(result) == num_groups
            unsatisfy_data['edges_noisy'] = result
            unsatisfy_data['query_src'] = query_src
            unsatisfy_data['query_dst'] = query_dst
            unsatisfy_data['index'] = i
            unsatisfy_data_list.append(unsatisfy_data)

        is_first = True
        for _ in range(max_num_rounds):
            batch_src = []
            batch_noisy = []
            batch_query_src = []
            batch_query_dst = []
            split_list = []
            index_list = []
            for unsatisfy_data in unsatisfy_data_list:
                if not unsatisfy_data['refine']:
                    continue
                num_rows = len(unsatisfy_data['edges_noisy'])
                src_rows = [unsatisfy_data['edges_src']] * num_rows
                noisy_rows = unsatisfy_data['edges_noisy']
                too_large = False
                for src_row, noisy_row in zip(src_rows, noisy_rows):
                    assert len(src_row) == len(noisy_row)
                    if len(noisy_row) > max_num_src:
                        too_large = True
                        break
                if too_large:
                    # Too many candidate sources for the refiner: accept the
                    # graph with the offending incoming edges removed.
                    refine_index.append(unsatisfy_data['index'])
                    unsatisfy_data['refine'] = False
                    refine_z_t_list.append(unsatisfy_data['z_t'])
                    continue
                batch_src.extend(src_rows)
                batch_noisy.extend(noisy_rows)
                batch_query_src.extend([unsatisfy_data['query_src']] * num_rows)
                batch_query_dst.extend([[val] * len(unsatisfy_data['query_src']) for val in unsatisfy_data['query_dst']])
                split_list.append(num_rows)
                index_list.append(unsatisfy_data['index'])

            if not index_list:
                break
            refine_edges = self.refine_model.refine(batch_src, batch_noisy, is_first)
            is_first = False
            idx = 0

            for i, length in zip(index_list, split_list):
                pred = np.concatenate(refine_edges[idx:idx + length])
                cand_src = np.concatenate(batch_query_src[idx:idx + length])
                cand_dst = np.concatenate(batch_query_dst[idx:idx + length])
                idx += length
                mask = (pred == 1)
                unsatisfy_data = unsatisfy_data_list[i]
                edge_src = unsatisfy_data['z_t'][1].cpu().tolist()
                edge_dst = unsatisfy_data['z_t'][0].cpu().tolist()
                edge_src.extend(cand_src[mask].tolist())
                edge_dst.extend(cand_dst[mask].tolist())
                src_np = np.array(edge_src)
                dst_np = np.array(edge_dst)

                order = np.argsort(dst_np)
                new_z_t = torch.stack([
                    torch.tensor(dst_np[order].tolist(), dtype=torch.long, device=device),
                    torch.tensor(src_np[order].tolist(), dtype=torch.long, device=device),
                ])
                u, s = check_predecessor_balance(unsatisfy_data['x_n'], new_z_t, rho=self.rho)
                if u is None and s is None:
                    refine_index.append(unsatisfy_data['index'])
                    unsatisfy_data['refine'] = False
                    refine_z_t_list.append(new_z_t)

        return refine_index, refine_z_t_list

class Refine:
    """Edge refiner built on a pretrained ``EdgeRefineModel`` checkpoint.

    For each graph in a batch, every pair formed from the graph's unique noisy
    source ids and unique noisy destination ids is scored by the model; the
    pairs whose argmax class is 1 become that graph's refined noisy edges.
    """

    def __init__(self, model_path='./model/model_latent_preferential_smt.pth'):
        """Load the checkpoint at ``model_path`` and build the model.

        The model is placed on ``cuda:0`` when CUDA is available, else on CPU.
        The checkpoint must contain ``'num_x_n_cat'`` and ``'model_state_dict'``.
        """
        self.path = model_path
        device_str = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device_str)
        # Load onto CPU so a checkpoint saved on a GPU also opens on CPU-only
        # hosts; load_state_dict below copies the weights onto the model's device.
        ckpt = torch.load(self.path, map_location='cpu', weights_only=True)
        num_x_n_cat = ckpt['num_x_n_cat']
        x_n_emb_size = 256
        num_mpnn_layers = 4
        t_emb_size = 256
        out_hidden_size = 320
        # NOTE(review): presumably one x_n_emb_size-wide embedding per entry of
        # num_x_n_cat, concatenated — confirm against BiMPNNEncoder.
        hidden_size = len(num_x_n_cat) * x_n_emb_size

        edge_pred_graph_encoder = BiMPNNEncoder(num_x_n_cat, hidden_size=hidden_size,
                                                x_n_emb_size=x_n_emb_size, pe_emb_size=0,
                                                num_mpnn_layers=num_mpnn_layers).to(self.device)
        self.model = EdgeRefineModel(graph_encoder=edge_pred_graph_encoder,
                                     in_hidden_size=hidden_size, out_hidden_size=out_hidden_size,
                                     t_emb_size=t_emb_size).to(self.device)
        self.model.load_state_dict(ckpt['model_state_dict'])

    @staticmethod
    def _candidate_pairs(noisy_src, noisy_dst):
        """Return (src, dst) lists pairing every unique source with every unique destination.

        Pairs are ordered destination-major: for each destination (ascending),
        all sources (ascending).
        """
        # Convert tensors to Python ints first: a set of 0-d tensors would
        # deduplicate by object identity rather than by value.
        if isinstance(noisy_src, torch.Tensor):
            noisy_src = noisy_src.tolist()
        if isinstance(noisy_dst, torch.Tensor):
            noisy_dst = noisy_dst.tolist()
        srcs = sorted(set(noisy_src))
        dsts = sorted(set(noisy_dst))
        query_src = srcs * len(dsts)
        query_dst = [dst for dst in dsts for _ in srcs]
        return query_src, query_dst

    @staticmethod
    def _as_int64(values):
        """Copy ``values`` (a list or a tensor) into a new int64 tensor."""
        if isinstance(values, torch.Tensor):
            # clone().detach() is the recommended way to copy an existing tensor;
            # torch.tensor(tensor) gives the same values but emits a UserWarning.
            return values.clone().detach().to(torch.int64)
        return torch.tensor(values, dtype=torch.int64)

    def refine(self, batch_edge_index, batch_noisy_edge_index, batch_x_n, item_data_map):
        """Predict refined noisy edges for every graph in a batch.

        Args:
            batch_edge_index: Edge index of the batched graph. It is concatenated
                with ``batch_noisy_edge_index`` along dim 1 to build a sparse
                ``(num_nodes, num_nodes)`` adjacency (presumably ``(2, E)`` —
                confirm with caller).
            batch_noisy_edge_index: Noisy edges, same layout as ``batch_edge_index``.
            batch_x_n: Node features; ``len(batch_x_n)`` is used as the node count.
            item_data_map: Dict with keys ``'batch_noisy_src'`` / ``'batch_noisy_dst'``
                (per-graph local node ids, lists or tensors), ``'num_nodes_cumsum'``
                (per-graph offset added to local ids), ``'batch_size'`` and
                ``'batch_t'`` (per-graph tensor expanded to one row per query).

        Returns:
            Tuple ``(noisy_edge_index, new_batch_noisy_src, new_batch_noisy_dst,
            pred_total)``:
              * ``noisy_edge_index``: the offset ids stacked as ``[dst, src]``.
              * ``new_batch_noisy_src`` / ``new_batch_noisy_dst``: per-graph int64
                tensors of local ids.
              * ``pred_total``: the argmax class of every query pair, on CPU.

            If the model accepts no pair for a graph, that graph keeps its
            original noisy edges and a line is appended to ``log.txt``.
        """
        self.model.eval()
        num_nodes = len(batch_x_n)
        A = dglsp.spmatrix(
            torch.cat([batch_edge_index, batch_noisy_edge_index], dim=1),
            shape=(num_nodes, num_nodes)).to(self.device)
        batch_noisy_src = item_data_map['batch_noisy_src']
        batch_noisy_dst = item_data_map['batch_noisy_dst']
        num_nodes_cumsum = item_data_map['num_nodes_cumsum']
        batch_size = item_data_map['batch_size']
        batch_t = item_data_map['batch_t']

        # Build the candidate (src, dst) queries for every graph.
        query_src_list = []
        query_dst_list = []
        batch_query_src = []
        batch_query_dst = []
        t = []
        number_query_list = []
        for i in range(batch_size):
            query_src_i, query_dst_i = self._candidate_pairs(batch_noisy_src[i], batch_noisy_dst[i])
            query_src_list.append(query_src_i)
            query_dst_list.append(query_dst_i)
            # Explicit int64: torch.tensor([]) would otherwise be float and
            # promote the concatenated query-index tensor to float.
            batch_query_src.append(torch.tensor(query_src_i, dtype=torch.int64) + num_nodes_cumsum[i])
            batch_query_dst.append(torch.tensor(query_dst_i, dtype=torch.int64) + num_nodes_cumsum[i])
            t.append(batch_t[i].expand(len(query_src_i), -1))
            number_query_list.append(len(query_src_i))

        batch_query_src = torch.cat(batch_query_src)
        batch_query_dst = torch.cat(batch_query_dst)
        batch_t = torch.cat(t)

        # Inference only: skip autograd bookkeeping for the forward pass.
        with torch.no_grad():
            logits = self.model(A, batch_x_n.to(self.device), batch_t.to(self.device),
                                batch_query_src.to(self.device), batch_query_dst.to(self.device))

        pred_total = torch.argmax(logits, dim=-1).cpu()
        pred_list = torch.split(pred_total, number_query_list)
        assert len(pred_list) == batch_size
        assert len(batch_noisy_src) == len(batch_noisy_dst) == batch_size

        # Keep the query pairs predicted as class 1.
        new_batch_noisy_src = []
        new_batch_noisy_dst = []
        for i in range(batch_size):
            pred = pred_list[i].tolist()
            query_src_i = query_src_list[i]
            query_dst_i = query_dst_list[i]
            assert len(query_src_i) == len(query_dst_i) == len(pred)
            pred_src = [src for src, p in zip(query_src_i, pred) if p == 1]
            pred_dst = [dst for dst, p in zip(query_dst_i, pred) if p == 1]
            if not pred_src:
                # Nothing accepted: record it and keep the graph's original noisy edges.
                with open('log.txt', 'a') as f:
                    f.write('error, model need to train!!!\n')
                pred_src = batch_noisy_src[i]
                pred_dst = batch_noisy_dst[i]
            new_batch_noisy_src.append(self._as_int64(pred_src))
            new_batch_noisy_dst.append(self._as_int64(pred_dst))

        # Shift local ids by each graph's offset to get batch-global ids.
        noisy_src = torch.cat(
            [src + num_nodes_cumsum[i] for i, src in enumerate(new_batch_noisy_src)], dim=0)
        noisy_dst = torch.cat(
            [dst + num_nodes_cumsum[i] for i, dst in enumerate(new_batch_noisy_dst)], dim=0)

        noisy_edge_index = torch.stack([noisy_dst, noisy_src])
        return noisy_edge_index, new_batch_noisy_src, new_batch_noisy_dst, pred_total


class RefineWithEdgeTransformer:
    """Autoregressive edge refiner backed by ``EdgeRefineModelWithEdgeTransformer``.

    If the checkpoint file does not exist, ``self.model`` is ``None`` and
    :meth:`refine` returns its ``noisy_lists`` argument unchanged.
    """

    def __init__(self, model_path='./model/model_latent_preferential_smt_ar.pth'):
        """Load the model from ``model_path`` if that file exists.

        The model is placed on ``cuda:0`` when CUDA is available, else on CPU.
        """
        self.path = model_path
        device_str = "cuda:0" if torch.cuda.is_available() else "cpu"
        # Set unconditionally so the attribute exists even when no model is loaded.
        self.device = torch.device(device_str)
        if os.path.exists(self.path):
            # Load onto CPU so GPU-saved checkpoints also open on CPU-only hosts;
            # load_state_dict copies the weights onto the model's device.
            ckpt = torch.load(self.path, map_location='cpu', weights_only=True)
            self.model = EdgeRefineModelWithEdgeTransformer().to(self.device)
            self.model.load_state_dict(ckpt['model_state_dict'])
        else:
            self.model = None

    def refine(self, src_lists, noisy_lists, is_first=True, temperature=0.8, max_src_len=128):
        """Generate one decoded token per entry of each ``src_lists[i]``.

        Args:
            src_lists: Per-sample token sequences; ``len(src_lists[i])`` decides
                how many generated tokens sample ``i`` receives.
            noisy_lists: Per-sample token sequences appended after the source in
                the encoder prefix.
            is_first: If True, decode greedily (argmax); otherwise sample.
            temperature: Softmax temperature for sampling. A value ``<= 0`` falls
                back to greedy decoding instead of dividing by zero.
            max_src_len: Minimum padded encoder length. Padding grows to fit the
                longest prefix instead of failing on it.

        Returns:
            A list with one list of generated token ids per sample; sample ``i``
            gets ``len(src_lists[i])`` tokens. If a non-empty result sums to 0, one
            random position is set to 1. Returns ``noisy_lists`` unchanged when no
            model was loaded.
        """
        if self.model is None:
            return noisy_lists
        batch_size = len(src_lists)
        if batch_size == 0:
            return []
        self.model.eval()

        # Encoder input: 3 + src + 4 + noisy. NOTE(review): 3/4 look like
        # delimiter token ids and 2 like the pad id — confirm against training.
        prefixes = [
            [3] + list(src) + [4] + list(noisy)
            for src, noisy in zip(src_lists, noisy_lists)
        ]
        pad_id = 2
        pad_len = max(max_src_len, max((len(p) for p in prefixes), default=0))
        src_ids = torch.full((pad_len, batch_size), pad_id, dtype=torch.long, device=self.device)
        for i, prefix in enumerate(prefixes):
            src_ids[:len(prefix), i] = torch.tensor(prefix, dtype=torch.long, device=self.device)

        gen_steps = max(len(src) for src in src_lists)
        dec_inp = torch.full((1, batch_size), 1, dtype=torch.long, device=self.device)  # <sos>
        greedy = is_first or temperature <= 0

        # Inference only: without no_grad every step would keep its autograd graph alive.
        with torch.no_grad():
            for _ in range(gen_steps):
                last_logits = self.model(src_ids, dec_inp)[-1]  # (B, V)
                if greedy:
                    next_tok = last_logits.argmax(dim=-1, keepdim=True)  # (B, 1)
                else:
                    probs = F.softmax(last_logits / temperature, dim=-1)
                    next_tok = torch.multinomial(probs, num_samples=1)  # (B, 1)
                dec_inp = torch.cat([dec_inp, next_tok.transpose(0, 1)], dim=0)  # (T+1, B)

        gens = dec_inp[1:].cpu()  # drop <sos>: (gen_steps, B)
        results = []
        for i, src in enumerate(src_lists):
            seq = gens[:len(src), i].tolist()
            # Guarantee at least one selected position; skip empty sequences,
            # where randint(0, -1) would raise.
            if seq and sum(seq) == 0:
                seq[random.randint(0, len(seq) - 1)] = 1
            results.append(seq)
        return results
