"""
GOCM_MIVAE: Heterogeneous Graph Augmentation based on Full Graph Structure Clustering

Core Improvements (compared to previous relation-wise independent clustering):
1. Merge edges of all relations for METIS clustering -> Based on full graph structure
2. Jointly process all relations within each cluster -> Generated nodes have edges in all relations
3. Support edge_type to distinguish different relations -> Consider edge type during VGAE training
"""

import os
import math
import time
import copy
import tqdm
from typing import Dict, List, Tuple, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from torch_geometric.data import Data, Batch
from torch_geometric.transforms import BaseTransform, ToUndirected
from torch_geometric.loader import ClusterData, ClusterLoader
from torch_geometric.utils import add_self_loops, negative_sampling
from torch.optim.lr_scheduler import ReduceLROnPlateau

try:
    import pygod
except ImportError:
    pygod = None

# GraphSAGE Convolution Layer (Reuse implementation from gocm_cluster)
try:
    from .sage import SAGEConv
except (ImportError, ModuleNotFoundError):
    from sage import SAGEConv

# MeanConsistency Model
try:
    from .OriginConsistency_cluster import MeanConsistencyCluster, mean_consistency_sampling
except (ImportError, ModuleNotFoundError):
    from MeanConsistency_cluster import MeanConsistencyCluster, mean_consistency_sampling


def _dgl_hetero_to_pyg_merged(
    g: dgl.DGLHeteroGraph,
    target_ntype: str,
    relations: List[Tuple[str, str, str]]
) -> Tuple[Data, Dict[int, str], Dict[str, int]]:
    """
    Convert DGL HeteroGraph to merged-edge PyG Data

    Key Design:
    - Merge edges of all relations into one edge_index
    - Use edge_type to mark which relation each edge belongs to
    - Prepare data for relation-specific convolution + fusion architecture

    Args:
        g: DGL HeteroGraph
        target_ntype: Target node type
        relations: List of relations to process [(src, rel, dst), ...]

    Returns:
        data: PyG Data, containing merged edge_index and edge_type
        etype_to_name: {edge_type_id: relation_name}
        name_to_etype: {relation_name: edge_type_id}

    Raises:
        KeyError: If the target node type stores neither a 'feature' nor a
            'feat' field, since every other tensor is sized from it.
    """
    node_data = g.nodes[target_ntype].data

    # Node features: 'feature' is preferred, 'feat' is the fallback name.
    x = node_data.get('feature', node_data.get('feat'))
    if x is None:
        # Fail here with a clear message instead of an AttributeError on
        # `None.size(0)` a few lines below.
        raise KeyError(
            f"Node type '{target_ntype}' has neither a 'feature' nor a 'feat' field"
        )
    num_nodes = x.size(0)

    # Labels and split masks; when absent, every node is treated as an
    # unlabeled (0) training node and the val/test masks are empty.
    y = node_data.get('label', torch.zeros(num_nodes, dtype=torch.long))
    train_mask = node_data.get('train_mask', torch.ones(num_nodes, dtype=torch.bool))
    val_mask = node_data.get('val_mask', torch.zeros(num_nodes, dtype=torch.bool))
    test_mask = node_data.get('test_mask', torch.zeros(num_nodes, dtype=torch.bool))

    # Merge edges from all relations; the position of a relation in
    # `relations` becomes its integer edge-type id.
    all_src, all_dst, all_etype = [], [], []
    etype_to_name = {}
    name_to_etype = {}

    for i, (s, r, d) in enumerate(relations):
        src, dst = g.edges(etype=(s, r, d))
        all_src.append(src)
        all_dst.append(dst)
        all_etype.append(torch.full((src.size(0),), i, dtype=torch.long))
        etype_to_name[i] = r
        name_to_etype[r] = i

    if all_src:
        merged_src = torch.cat(all_src)
        merged_dst = torch.cat(all_dst)
        merged_etype = torch.cat(all_etype)
    else:
        # No relation requested: produce an empty [2, 0] edge_index.
        merged_src = torch.empty(0, dtype=torch.long)
        merged_dst = torch.empty(0, dtype=torch.long)
        merged_etype = torch.empty(0, dtype=torch.long)

    edge_index = torch.stack([merged_src, merged_dst], dim=0)

    # Construct PyG Data (use clone to avoid contaminating original data)
    # Critical: edge_type must be set correctly for relation-specific convolution
    data = Data(
        x=x.cpu().clone(),
        edge_index=edge_index.cpu().clone(),
        edge_type=merged_etype.cpu().clone(),  # Core: Edge type labels
        y=y.cpu().clone(),
        train_mask=train_mask.cpu().clone(),
        val_mask=val_mask.cpu().clone(),
        test_mask=test_mask.cpu().clone(),
    )

    # Internal invariants: one edge-type id per merged edge.
    assert hasattr(data, 'edge_type'), "PyG Data must have edge_type attribute!"
    assert data.edge_type.size(0) == edge_index.size(1), \
        f"edge_type size mismatch: {data.edge_type.size(0)} vs {edge_index.size(1)}"

    return data, etype_to_name, name_to_etype


class VGAE_Hetero(nn.Module):
    """
    Heterogeneous Graph VGAE - Relation-specific Convolution + Cross-Relation Fusion

    Architecture:
    1. Intra-Relation Convolution: Independent GNN parameters for each relation
    2. Cross-Relation Fusion: 'concat_linear', 'attention' or 'mean'
    3. Latent Parameter Projection: Independent mu and log_sigma projections

    The node label is injected both into the encoder input (``map_label_e``)
    and into the latent code before decoding (``map_label_d``).

    Args:
        in_dim: Dimension of the node features.
        hid_dim: Dimension of the per-relation embeddings and of the latent code.
        etypes: Number of relations (edge types).
        threshold: Edge probability above which ``sample`` emits an edge.
        temporal: Whether a time decoder is created and passed to the convolutions.
        t_min: Lower bound of the time range (used when ``temporal``).
        t_max: Upper bound of the time range (used when ``temporal``).
        fusion_strategy: One of 'concat_linear', 'attention', 'mean'.

    Raises:
        ValueError: If ``fusion_strategy`` is not one of the supported names.
    """
    def __init__(
        self,
        in_dim: int,
        hid_dim: int,
        etypes: int = 1,
        threshold: float = 0.5,
        temporal: bool = False,
        t_min: int = 0,
        t_max: int = 1024,
        fusion_strategy: str = 'concat_linear'  # 'concat_linear', 'attention', 'mean'
    ):
        super().__init__()
        self.in_dim = in_dim
        self.hid_dim = hid_dim
        self.etypes = etypes
        self.threshold = threshold
        self.temporal = temporal
        self.t_min = t_min
        self.t_max = t_max
        self.time_len = int(t_max - t_min + 1) if temporal else None
        self.fusion_strategy = fusion_strategy

        # ============ Intra-Relation Convolution ============
        # Independent encoder for each relation (no shared parameters).
        # Each convolution only ever sees the edges of one relation, hence etypes=1.
        self.relation_encoders = nn.ModuleList([
            nn.ModuleDict({
                'shared': SAGEConv(in_dim, hid_dim, temporal=temporal,
                                  time_len=self.time_len, etypes=1),
            })
            for _ in range(etypes)
        ])

        # ============ Cross-Relation Fusion ============
        if fusion_strategy == 'concat_linear':
            # Concatenate the per-relation embeddings and project back to hid_dim.
            self.fusion_layer = nn.Sequential(
                nn.Linear(etypes * hid_dim, hid_dim),
                nn.ReLU(),
                nn.LayerNorm(hid_dim)
            )
        elif fusion_strategy == 'attention':
            self.fusion_layer = RelationAttentionFusion(hid_dim, etypes)
        elif fusion_strategy == 'mean':
            # Plain averaging has no parameters.
            self.fusion_layer = None
        else:
            raise ValueError(f"Unknown fusion strategy: {fusion_strategy}")

        # ============ Latent Parameter Projection ============
        # Independent mu and log_sigma projection layers
        self.proj_mu = nn.Linear(hid_dim, hid_dim)
        self.proj_sigma = nn.Linear(hid_dim, hid_dim)

        # ============ Decoder ============
        self.dec_attr = nn.Linear(hid_dim, in_dim)
        # Edge heads score the concatenation [z_src ; z_dst].
        self.dec_stru = nn.Linear(2 * hid_dim, 1)
        self.dec_type = nn.Linear(2 * hid_dim, etypes) if etypes > 1 else None
        self.dec_time = nn.Linear(2 * hid_dim, 1) if temporal else None

        # Label conditioning for the encoder input and the decoder input.
        self.map_label_e = nn.Linear(1, in_dim, bias=False)
        self.map_label_d = nn.Linear(1, hid_dim, bias=False)

    def encode(self, h, edge_index, label, edge_time=None, edge_type=None, verbose_stats=False):
        """Encode nodes into the parameters of the latent Gaussian.

        Args:
            h: Node features.
            edge_index: Merged edge index of all relations.
            label: Node labels as a float column, added to ``h`` via ``map_label_e``.
            edge_time: Optional edge timestamps forwarded to the convolutions.
                Assumed to hold one entry per column of ``edge_index`` -
                TODO confirm against SAGEConv.
            edge_type: Optional relation id per edge. When omitted, all edges
                are attributed to relation 0.
            verbose_stats: Print per-relation edge counts and embedding norms.

        Returns:
            ``(mu, log_std)`` with ``log_std`` clamped to ``[-10, 10]``.
        """
        # Label conditioning
        h = h + self.map_label_e(label)

        # ============ Step 1: Intra-Relation Convolution ============
        relation_embeddings = []
        rel_edge_counts = []  # Record edge count for each relation (for debugging)

        for rel_id in range(self.etypes):
            if edge_type is not None:
                rel_mask = (edge_type == rel_id)
                edge_index_r = edge_index[:, rel_mask]
                # Keep the timestamps aligned with the filtered edges; passing
                # the full tensor would pair edges with the wrong times.
                edge_time_r = edge_time[rel_mask] if edge_time is not None else None
            else:
                # Without type information every edge belongs to relation 0.
                edge_index_r = edge_index if rel_id == 0 else edge_index[:, :0]
                edge_time_r = edge_time

            num_edges_r = edge_index_r.size(1)
            rel_edge_counts.append(num_edges_r)

            if num_edges_r == 0:
                # A relation without edges in this batch contributes zeros;
                # new_zeros keeps the dtype and device of h.
                h_r = h.new_zeros(h.size(0), self.hid_dim)
            else:
                encoder = self.relation_encoders[rel_id]
                h_r = encoder['shared'](h, edge_index_r, edge_time_r, None)
                h_r = F.relu(h_r)

            relation_embeddings.append(h_r)

        if verbose_stats:
            print(f"  [RelConv] Relation edge counts: {rel_edge_counts}")
            for i, h_r in enumerate(relation_embeddings):
                norm = h_r.norm(dim=1).mean().item()
                print(f"  [RelConv] Relation {i} embedding norm: {norm:.4f}")

        # ============ Step 2: Cross-Relation Fusion ============
        if self.fusion_strategy == 'mean':
            h_fused = torch.stack(relation_embeddings, dim=0).mean(dim=0)
        elif self.fusion_strategy == 'attention':
            # Attention fusion
            h_fused = self.fusion_layer(relation_embeddings)
        else:
            h_concat = torch.cat(relation_embeddings, dim=1)  # [N, etypes * hid_dim]
            h_fused = self.fusion_layer(h_concat)  # [N, hid_dim]

        if verbose_stats:
            print(f"  [Fusion] Fused embedding norm: {h_fused.norm(dim=1).mean().item():.4f}")

        # ============ Step 3: Latent Parameter Projection ============
        mu = self.proj_mu(h_fused)
        log_std = self.proj_sigma(h_fused)

        # Bound log_std so that exp() in reparameterize cannot overflow.
        log_std = torch.clamp(log_std, min=-10.0, max=10.0)

        return mu, log_std

    def reparameterize(self, mu, log_std):
        """Draw ``z = mu + eps * exp(log_std)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(log_std)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, pos_edge_index, neg_edge_index, label):
        """Decode latent codes for training.

        Args:
            z: Latent node codes.
            pos_edge_index: Observed edges.
            neg_edge_index: Sampled non-edges.
            label: Node labels as a float column, added to ``z`` via ``map_label_d``.

        Returns:
            ``(x_rec, edge_pred, t_pred, p_pred)``: reconstructed features,
            edge logits for positives followed by negatives, time predictions
            for the positive edges (``None`` unless temporal) and relation
            logits for the positive edges (``None`` when there is one relation).
        """
        z = z + self.map_label_d(label)
        x_rec = self.dec_attr(z)

        # Edge representation = concatenation of both endpoint codes.
        pos_ze = torch.cat([z[pos_edge_index[0]], z[pos_edge_index[1]]], dim=1)
        neg_ze = torch.cat([z[neg_edge_index[0]], z[neg_edge_index[1]]], dim=1)
        pos_edge_pred = self.dec_stru(pos_ze).squeeze(-1)
        neg_edge_pred = self.dec_stru(neg_ze).squeeze(-1)
        edge_pred = torch.cat([pos_edge_pred, neg_edge_pred], dim=0)

        p_pred = self.dec_type(pos_ze) if self.etypes > 1 else None

        t_pred = self.dec_time(pos_ze).squeeze(-1) if self.temporal else None

        return x_rec, edge_pred, t_pred, p_pred

    def sample(self, z, label):
        """Generate a graph from latent codes.

        Every ordered node pair is scored with ``dec_stru``; pairs whose
        probability exceeds ``self.threshold`` become edges, and a self-loop
        is appended for every node.

        Args:
            z: Latent node codes.
            label: Node labels as a float column, added to ``z`` via ``map_label_d``.

        Returns:
            ``(x_rec, edge_index, t_pred, p_pred)``: decoded features, the
            generated edge index, per-edge times (``None`` unless temporal)
            and the arg-max relation id per edge (``None`` when there is one
            relation).
        """
        z = z + self.map_label_d(label)
        x_rec = self.dec_attr(z)

        n = z.size(0)
        # dec_stru is linear in [z_i ; z_j], so its score splits into a source
        # term and a destination term:
        #     w . [z_i ; z_j] + b = w_src . z_i + w_dst . z_j + b
        # Broadcasting the two [n] vectors avoids materialising the
        # [n, n, 2 * hid_dim] tensor of all concatenated pairs.
        w_src, w_dst = self.dec_stru.weight.reshape(-1).split(self.hid_dim)
        logits = (z @ w_src).unsqueeze(1) + (z @ w_dst).unsqueeze(0) + self.dec_stru.bias

        adj = torch.sigmoid(logits)
        edge_index = (adj > self.threshold).nonzero().T
        # NOTE(review): pairs (i, i) that already pass the threshold are kept
        # above and a second copy is appended here - confirm whether duplicate
        # self-loops are intended.
        edge_index = add_self_loops(edge_index, num_nodes=n)[0]

        pos_ze = torch.cat([z[edge_index[0]], z[edge_index[1]]], dim=1)
        p_pred = self.dec_type(pos_ze).argmax(-1) if self.etypes > 1 else None

        if self.temporal:
            # Clamp to [0, 1], then scale by the width of the time range.
            t_pred = self.dec_time(pos_ze).squeeze(-1)
            t_pred = torch.clamp(t_pred, min=0, max=1)
            t_pred = t_pred * (self.t_max - self.t_min)
        else:
            t_pred = None

        return x_rec, edge_index, t_pred, p_pred


class RelationAttentionFusion(nn.Module):
    """Fuse per-relation node embeddings with a learned softmax attention.

    Each relation embedding is scored by a small two-layer network; the scores
    are normalised across relations for every node, the embeddings are
    averaged with those weights and the result is passed through a final
    linear projection.
    """

    def __init__(self, hid_dim: int, num_relations: int):
        super().__init__()
        self.hid_dim = hid_dim
        self.num_relations = num_relations

        # Scoring network: hid_dim -> hid_dim // 2 -> 1 scalar per (relation, node).
        scorer_layers = [
            nn.Linear(hid_dim, hid_dim // 2),
            nn.Tanh(),
            nn.Linear(hid_dim // 2, 1),
        ]
        self.attention = nn.Sequential(*scorer_layers)

        # Projection applied to the attention-weighted embedding.
        self.output_proj = nn.Linear(hid_dim, hid_dim)

    def forward(self, relation_embeddings: List[torch.Tensor]) -> torch.Tensor:
        """Return the fused embedding for a list of per-relation embeddings.

        Args:
            relation_embeddings: One tensor per relation, all of the same shape.

        Returns:
            The projected, attention-weighted combination of the inputs.
        """
        # [R, N, D]: relations are stacked along a new leading axis.
        per_relation = torch.stack(relation_embeddings, dim=0)

        # [R, N, 1] scores, normalised over the relation axis.
        weights = F.softmax(self.attention(per_relation), dim=0)

        # Weighted sum over relations -> [N, D].
        combined = (per_relation * weights).sum(dim=0)

        return self.output_proj(combined)


class GOCM_MIVAE(BaseTransform):
    

    def __init__(
        self,
        name: str = "",
        target_ntype: str = 'user',
        relations: Optional[List[str]] = None,
        # VGAE/MC Parameters
        hid_dim: int = 64,
        cons_dim: int = 128,
        vae_epochs: int = 100,
        cons_epochs: int = 100,
        patience: int = 50,
        lr: float = 1e-3,
        wd: float = 0.0,
        batch_size: int = 4096,
        threshold: float = 0.75,
        # Relink Parameters
        relink_ratio: float = 0.0,
        relink_max_candidates: int = 64,
        relink_threshold: float | None = None,
        # Loss Weights
        wx: float = 1.0,
        we: float = 0.5,
        beta: float = 1e-3,
        wp: float = 0.3,
        # Generation Parameters
        gen_ratio: float = 1.0,
        sample_steps: int = 1,
        device: int | str = 0,
        verbose: bool = False,
        reuse_ae: bool = False,
        reuse_cm: bool = False,

        fusion_strategy: str = 'concat_linear',  # 'concat_linear', 'attention', 'mean'

        mc_T_type: str = 'baseline',
        mc_T_k: float = 48.0,
        mc_T_eps: float = 0.002,
        mc_W_type: str = 'constant1',
        mc_schedule: str = 'linear',
        mc_eta: float = 0.0,
        mc_s_min: float = 0.002,
        mc_step_clip: float | None = None,
        mc_rho: float = 7.0,
        mc_heun: bool = False,
        **kwargs
    ):
        """Store the configuration; no model is built until ``forward`` runs.

        Args:
            name: Tag used in checkpoint and cluster-cache file names.
            target_ntype: Node type whose nodes are augmented.
            relations: Relation names to use; ``None`` or ``'all'`` keeps
                every relation between target nodes.
            hid_dim: Hidden/latent size of the VGAE.
            cons_dim: Passed as ``dim_t`` to the consistency model.
            vae_epochs: Maximum VGAE training epochs.
            cons_epochs: Maximum consistency-model training epochs.
            patience: Early-stopping patience (epochs without improvement).
            lr: Adam learning rate for both models.
            wd: Adam weight decay for both models.
            batch_size: Target node count per METIS partition and per
                generation chunk.
            threshold: Edge probability threshold of the VGAE sampler.
            relink_ratio: Fraction of generated nodes to relink; 0 disables it.
            relink_max_candidates: Candidate original nodes tried per relinked node.
            relink_threshold: Threshold used when relinking; ``None`` falls
                back to ``threshold``.
            wx: Weight of the feature reconstruction loss.
            we: Weight of the edge reconstruction loss.
            beta: Weight of the KL term.
            wp: Weight of the edge-type loss.
            gen_ratio: Number of generated nodes as a multiple of the summed
                training labels.
            sample_steps: Sampling steps of the consistency model; 0 decodes
                raw noise.
            device: CUDA index (``-1`` selects the CPU) or a device string.
            verbose: Print progress information.
            reuse_ae: Load an existing VGAE checkpoint instead of training.
            reuse_cm: Load an existing consistency-model checkpoint instead
                of training.
            fusion_strategy: Cross-relation fusion used by the VGAE.
            mc_T_type, mc_T_k, mc_T_eps, mc_W_type: Options forwarded to the
                consistency model constructor.
            mc_schedule, mc_eta, mc_s_min, mc_step_clip, mc_rho, mc_heun:
                Options forwarded to the consistency sampler.
            **kwargs: Accepted and ignored.
        """
        # ---- Identity and graph selection ----
        self.name = name
        self.target_ntype = target_ntype
        self.relations = relations

        # ---- Model sizes and optimisation ----
        self.hid_dim = hid_dim
        self.cons_dim = cons_dim
        self.fusion_strategy = fusion_strategy
        self.vae_epochs = vae_epochs
        self.cons_epochs = cons_epochs
        self.patience = patience
        self.lr = lr
        self.wd = wd
        self.batch_size = batch_size
        self.threshold = threshold

        # ---- Relinking ----
        self.relink_ratio = relink_ratio
        self.relink_max_candidates = relink_max_candidates
        self.relink_threshold = relink_threshold

        # ---- Loss weights ----
        self.wx = wx
        self.we = we
        self.beta = beta
        self.wp = wp

        # ---- Generation and run-time behaviour ----
        self.gen_ratio = gen_ratio
        self.sample_steps = sample_steps
        self.verbose = verbose
        self.reuse_ae = reuse_ae
        self.reuse_cm = reuse_cm

        # ---- Device ----
        # Strings are taken verbatim; an integer is a CUDA index, with -1
        # (or a machine without CUDA) resolving to the CPU.
        if not isinstance(device, int):
            self.device = torch.device(device)
        elif device == -1 or not torch.cuda.is_available():
            self.device = torch.device('cpu')
        else:
            self.device = torch.device(f'cuda:{device}')

        # ---- Consistency-model options ----
        self.mc_T_type = mc_T_type
        self.mc_T_k = mc_T_k
        self.mc_T_eps = mc_T_eps
        self.mc_W_type = mc_W_type
        self.mc_schedule = mc_schedule
        self.mc_eta = mc_eta
        self.mc_s_min = mc_s_min
        self.mc_step_clip = mc_step_clip
        self.mc_rho = mc_rho
        self.mc_heun = mc_heun

        # ---- State filled in while the transform runs ----
        self.ae = None
        self.cm = None
        self.y_orig = None
        self.mean = None
        self.std = None
        self.etypes = 1
        self.etype_to_name = {}
        self.name_to_etype = {}
        self.last_gen_time = 0.0

    def _discover_relations(self, g: dgl.DGLHeteroGraph) -> List[Tuple[str, str, str]]:
        """Return the canonical edge types that connect target nodes to target nodes.

        ``self.relations`` filters by relation name: ``None`` or ``'all'``
        keeps every such relation, a collection keeps the names it contains,
        and a single name given as a plain string keeps exactly that relation.

        Args:
            g: Heterogeneous graph whose ``canonical_etypes`` are inspected.

        Returns:
            Matching ``(src_type, relation, dst_type)`` triples in graph order.
        """
        wanted = self.relations
        if wanted is None or wanted == 'all':
            allowed = None
        elif isinstance(wanted, str):
            # `r in "some_name"` would be a substring test and could match
            # unrelated relations; wrap a lone name so it is compared exactly.
            allowed = [wanted]
        else:
            allowed = wanted

        return [
            et for et in g.canonical_etypes
            if et[0] == self.target_ntype
            and et[2] == self.target_ntype
            and (allowed is None or et[1] in allowed)
        ]

    def preprocess(self, data: Data) -> Data:
        """Symmetrise the graph and standardise the node features.

        A directed graph is passed through ``ToUndirected(reduce='min')``.
        The per-column mean and standard deviation of ``data.x`` are stored
        on ``self.mean`` / ``self.std`` so that ``postprocess`` can invert
        the scaling.

        Args:
            data: Graph to normalise; its ``x`` is replaced.

        Returns:
            The (possibly symmetrised) graph with standardised features.
        """
        graph = ToUndirected(reduce='min')(data) if data.is_directed() else data

        feats = graph.x
        self.mean = feats.mean(0)
        raw_std = feats.std(0)
        # Constant columns have zero deviation; scale them by 1 instead so
        # the division below is well defined.
        self.std = torch.where(raw_std == 0, torch.ones_like(raw_std), raw_std)
        graph.x = (feats - self.mean) / self.std

        return graph

    def postprocess(self, data: Data) -> Data:
        """Map standardised features back to the original scale.

        Uses the ``self.std`` / ``self.mean`` statistics recorded by
        ``preprocess`` and overwrites ``data.x`` in place.

        Args:
            data: Graph whose features are in standardised units.

        Returns:
            The same object with rescaled features.
        """
        rescaled = data.x * self.std
        data.x = rescaled + self.mean
        return data

    def recon_loss(self, x, x_rec, edge_label, edge_pred, p=None, p_pred=None):
        """Weighted reconstruction loss of the VGAE.

        Args:
            x: Target node features.
            x_rec: Reconstructed node features.
            edge_label: 1/0 targets for the scored edges.
            edge_pred: Edge logits aligned with ``edge_label``.
            p: Relation id per positive edge (used when there are several relations).
            p_pred: Relation logits per positive edge, or ``None``.

        Returns:
            ``wx * MSE + we * BCE-with-logits + wp * cross-entropy``; the last
            term is 0 when there is a single relation or no type prediction.
        """
        attr_term = F.mse_loss(x_rec, x)
        link_term = F.binary_cross_entropy_with_logits(edge_pred, edge_label)

        # The type head only exists (and only makes sense) for multi-relation graphs.
        if self.etypes > 1 and p_pred is not None:
            type_term = F.cross_entropy(p_pred, p)
        else:
            type_term = 0.0

        weighted = self.wx * attr_term + self.we * link_term
        return weighted + self.wp * type_term

    def train_ae(self, dataloader):
        """Train ``self.ae`` and replace it with the best checkpoint.

        For every batch the edges are complemented with sampled negatives,
        the batch is encoded, re-parameterised and decoded, and the weighted
        reconstruction loss plus ``beta * KL`` is minimised with Adam.
        Batches whose loss is NaN/Inf are skipped. Whenever the node-weighted
        epoch loss improves, the whole module is saved under ``ckpt/``;
        training stops early after ``self.patience`` epochs without
        improvement.

        Args:
            dataloader: Iterable of batches exposing ``x``, ``edge_index``,
                ``y`` and ``edge_type``.

        Raises:
            ValueError: If a batch has no ``edge_type`` attribute.
            RuntimeError: If no checkpoint exists once training has finished.
        """
        if self.verbose:
            print('[GOCM_MIVAE] Training VGAE (Relation-specific Conv + Fusion)...', flush=True)

        optimizer = torch.optim.Adam(self.ae.parameters(), lr=self.lr, weight_decay=self.wd)
        best_loss = float('inf')
        patience_count = 0
        skipped_nan = 0

        for epoch in range(self.vae_epochs):
            start = time.time()
            self.ae.train()
            total_loss = 0
            num_nodes = 0
            batch_count = 0

            for batch in dataloader:
                batch_size = batch.x.size(0)
                x = batch.x.to(self.device)
                edge_index = batch.edge_index.to(self.device)
                y = batch.y.float().unsqueeze(1).to(self.device)

                if not hasattr(batch, 'edge_type'):
                    raise ValueError(
                        "Batch must have 'edge_type' attribute for relation-specific convolution! "
                        "Ensure _dgl_hetero_to_pyg_merged() correctly sets edge_type."
                    )
                edge_type = batch.edge_type.to(self.device)

                # One-off diagnostic: how the relations are represented in a batch.
                if self.verbose and epoch == 0 and batch_count == 0:
                    print(f'  [VGAE] First batch edge type distribution:', flush=True)
                    for rel_id in range(self.etypes):
                        count = (edge_type == rel_id).sum().item()
                        rel_name = self.etype_to_name.get(rel_id, f'rel_{rel_id}')
                        print(f'    Relation {rel_id} ({rel_name}): {count} edges', flush=True)

                # Negatives for the link objective; targets are 1 for observed
                # edges followed by 0 for the sampled non-edges.
                neg_edge_index = negative_sampling(edge_index, num_nodes=batch_size)
                edge_label = torch.cat([
                    torch.ones(edge_index.size(1), device=self.device),
                    torch.zeros(neg_edge_index.size(1), device=self.device)
                ])

                verbose_stats = (epoch == 0 and batch_count == 0 and self.verbose)
                mu, log_std = self.ae.encode(x, edge_index, y, edge_type=edge_type, verbose_stats=verbose_stats)
                z = self.ae.reparameterize(mu, log_std)
                x_rec, edge_pred, _, p_pred = self.ae.decode(z, edge_index, neg_edge_index, y)

                batch_count += 1

                # Loss
                loss = self.recon_loss(x, x_rec, edge_label, edge_pred, edge_type, p_pred)

                # KL Divergence against N(0, I); the encoder outputs log(std),
                # hence the factor 2 to obtain log(var).
                kl = -0.5 * (1 + 2*log_std - mu**2 - torch.exp(2*log_std)).sum(1).mean()
                loss = loss + self.beta * kl

                if torch.isnan(loss) or torch.isinf(loss):
                    skipped_nan += 1
                    if self.verbose and skipped_nan <= 5:
                        print(f'  [VGAE] NaN/Inf detected, skipping batch (epoch={epoch}, total_skipped={skipped_nan})', flush=True)
                    optimizer.zero_grad()
                    continue

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.ae.parameters(), 1.0)
                optimizer.step()

                total_loss += loss.item() * batch_size
                num_nodes += batch_size

            # An epoch in which every batch was skipped must not score 0.0:
            # that would be stored as the best loss, checkpoint an unimproved
            # model and make every later, real epoch look like a regression.
            curr_loss = total_loss / num_nodes if num_nodes > 0 else float('inf')

            if curr_loss < best_loss:
                best_loss = curr_loss
                patience_count = 0
                os.makedirs('ckpt', exist_ok=True)
                torch.save(self.ae, f"ckpt/{self.name.replace('/', '_')}_ae_hetero.pt")
            else:
                patience_count += 1
                if patience_count >= self.patience:
                    if self.verbose:
                        print(f'  [VGAE] Early stopping at epoch {epoch}', flush=True)
                    break

            if self.verbose and (epoch % 10 == 0 or epoch == self.vae_epochs - 1):
                print(f'  [VGAE] Epoch {epoch:03d}, Loss: {curr_loss:.6f}, Time: {time.time()-start:.2f}s', flush=True)

        # Restore the best model.
        # NOTE(review): if this run never saved, a checkpoint left behind by an
        # earlier run with the same name is loaded here - confirm that is wanted.
        ckpt_path = f"ckpt/{self.name.replace('/', '_')}_ae_hetero.pt"
        if os.path.exists(ckpt_path):
            # The file holds a whole pickled module written by this class, so
            # weights_only must be off (newer PyTorch defaults to True and
            # would refuse it). Only load checkpoints from a trusted location.
            self.ae = torch.load(ckpt_path, map_location=self.device, weights_only=False)
        else:
            raise RuntimeError(f"VGAE checkpoint not found: {ckpt_path}. Training may have failed.")

    def train_cm(self, dataloader):
        """Train MeanConsistency Model (in fused latent space)

        Each batch is encoded with the already trained ``self.ae`` (no
        gradients); the resulting ``mu`` is the training input of ``self.cm``,
        whose forward call returns the loss. Batches with a NaN/Inf loss are
        skipped. The learning rate is reduced on plateaus, the whole module
        is saved under ``ckpt/`` whenever the sample-weighted epoch loss
        improves, and training stops early after ``self.patience`` epochs
        without improvement. Finally ``self.cm`` is replaced by the best
        checkpoint.

        Args:
            dataloader: Iterable of batches exposing ``x``, ``edge_index``,
                ``y`` and ``edge_type``.

        Raises:
            ValueError: If a batch has no ``edge_type`` attribute.
            RuntimeError: If no checkpoint exists once training has finished.
        """
        if self.verbose:
            print('[GOCM_MIVAE] Training MeanConsistency (in fused latent space)...', flush=True)

        optimizer = torch.optim.Adam(self.cm.parameters(), lr=self.lr, weight_decay=self.wd)
        # `verbose` is deprecated on PyTorch schedulers and False was its
        # default, so it is simply left out.
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=20)
        best_loss = float('inf')
        patience_count = 0
        skipped_nan = 0

        for epoch in range(self.cons_epochs):
            pbar = tqdm.tqdm(dataloader, total=len(dataloader), disable=not self.verbose)
            pbar.set_description(f"Epoch {epoch}")

            batch_loss = 0.0
            len_input = 0

            for batch in pbar:
                x = batch.x.to(self.device)
                edge_index = batch.edge_index.to(self.device)
                y = batch.y.float().unsqueeze(1).to(self.device)

                if not hasattr(batch, 'edge_type'):
                    raise ValueError("Batch must have 'edge_type' for relation-specific encoding!")
                edge_type = batch.edge_type.to(self.device)

                # The VGAE is frozen here; only its posterior mean is used.
                with torch.no_grad():
                    mu, _ = self.ae.encode(x, edge_index, y, edge_type=edge_type)
                inputs = mu.detach()

                loss = self.cm(inputs, y)

                if torch.isnan(loss) or torch.isinf(loss):
                    skipped_nan += 1
                    if self.verbose and skipped_nan <= 5:
                        print(f'  [MC] NaN/Inf detected, skipping batch (epoch={epoch}, total_skipped={skipped_nan})', flush=True)
                    optimizer.zero_grad()
                    continue

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.cm.parameters(), 1.0)
                optimizer.step()

                batch_loss += loss.item() * len(inputs)
                len_input += len(inputs)
                pbar.set_postfix({"Loss": f"{loss.item():.4f}"})

            if len_input > 0:
                curr_loss = batch_loss / len_input
                scheduler.step(curr_loss)
            else:
                # Every batch was skipped: there is no loss to report. Scoring
                # the epoch as 0.0 would be recorded as the best result and
                # block all later improvements, so treat it as a failed epoch.
                curr_loss = float('inf')

            if curr_loss < best_loss:
                best_loss = curr_loss
                patience_count = 0
                os.makedirs('ckpt', exist_ok=True)
                torch.save(self.cm, f"ckpt/{self.name.replace('/', '_')}_cm_hetero.pt")
            else:
                patience_count += 1
                if patience_count >= self.patience:
                    if self.verbose:
                        print(f'  [MC] Early stopping at epoch {epoch}', flush=True)
                    break

        # Restore the best model.
        # NOTE(review): if this run never saved, a checkpoint left behind by an
        # earlier run with the same name is loaded here - confirm that is wanted.
        ckpt_path = f"ckpt/{self.name.replace('/', '_')}_cm_hetero.pt"
        if os.path.exists(ckpt_path):
            # The file holds a whole pickled module written by this class, so
            # weights_only must be off (newer PyTorch defaults to True and
            # would refuse it). Only load checkpoints from a trusted location.
            self.cm = torch.load(ckpt_path, map_location=self.device, weights_only=False)
        else:
            raise RuntimeError(f"MC checkpoint not found: {ckpt_path}. Training may have failed.")

    def sample(self, graph_size: int) -> Data:
        """Generate new graph

        Gaussian noise is turned into latent codes by the consistency model
        (skipped when ``self.sample_steps`` is 0) and decoded by the VGAE.
        Every generated node gets label 1 and is marked as a training node.

        Args:
            graph_size: Number of nodes to generate.

        Returns:
            A CPU ``Data`` object with ``x``, ``edge_index``, ``edge_type``,
            ``y`` and the three split masks.
        """
        noise = torch.randn(graph_size, self.hid_dim, device=self.device)
        label = torch.ones(graph_size, 1, device=self.device)

        if self.sample_steps > 0:
            z = mean_consistency_sampling(
                self.cm, noise, label,
                num_steps=self.sample_steps,
                schedule=self.mc_schedule,
                eta=self.mc_eta,
                s_min=self.mc_s_min,
                step_clip=self.mc_step_clip,
                rho=self.mc_rho,
                heun=self.mc_heun
            )
        else:
            z = noise

        # VGAE Decoding. The outputs are detached below, so skip building an
        # autograd graph over the dense pairwise edge scores.
        with torch.no_grad():
            x_rec, edge_index, _, edge_type = self.ae.sample(z, label)

        # The decoder predicts no edge type when there is a single relation.
        # The original graph always carries edge_type, so give every edge
        # relation id 0 to keep the attribute set identical across the graphs
        # that are later merged with Batch.from_data_list.
        if edge_type is None:
            edge_type = torch.zeros(edge_index.size(1), dtype=torch.long)

        # Construct generated graph
        data = Data(
            x=x_rec.cpu().detach(),
            edge_index=edge_index.cpu().detach(),
            y=torch.ones(graph_size, dtype=torch.long),
            train_mask=torch.ones(graph_size, dtype=torch.bool),
            val_mask=torch.zeros(graph_size, dtype=torch.bool),
            test_mask=torch.zeros(graph_size, dtype=torch.bool),
        )
        data.edge_type = edge_type.cpu().detach()

        return data

    def forward(self, dgl_hetero: dgl.DGLHeteroGraph) -> dgl.DGLHeteroGraph:
        """Augment a heterogeneous graph with generated target nodes.

        Pipeline:
        1. Pick the relations between target nodes and merge them into one
           PyG graph with an ``edge_type`` per edge.
        2. Hide the labels of non-training nodes, symmetrise and standardise.
        3. Partition the merged graph with METIS (``ClusterData``).
        4. Train (or reload) the VGAE, then the consistency model.
        5. Generate ``gen_ratio * sum(training labels)`` nodes in chunks of
           at most ``batch_size``.
        6. Append the generated graphs to the original one, optionally relink
           them, undo the standardisation and convert back to DGL.

        Args:
            dgl_hetero: Input graph. If ``self.target_ntype`` is not one of
                its node types, the first node type is used instead.

        Returns:
            The augmented heterogeneous graph.

        Raises:
            ValueError: If no relation connects target nodes to target nodes.
        """
        if self.verbose:
            print(f'\n{"="*60}')
            print(f'[GOCM_MIVAE] Start processing: {self.name}')
            print(f'{"="*60}')

        t_start = time.time()

        # Fall back to the first node type when the configured one is absent.
        if self.target_ntype not in dgl_hetero.ntypes:
            self.target_ntype = dgl_hetero.ntypes[0]
        ntype = self.target_ntype

        relations = self._discover_relations(dgl_hetero)
        if len(relations) == 0:
            raise ValueError(f"No self-loop relation found for target node type {ntype}")
        self.etypes = len(relations)

        if self.verbose:
            print(f'[GOCM_MIVAE] Target node: {ntype}, Relations: {[r[1] for r in relations]}, Num edge types: {self.etypes}')

        # Merge all relations into a single PyG graph.
        data, self.etype_to_name, self.name_to_etype = _dgl_hetero_to_pyg_merged(
            dgl_hetero, ntype, relations
        )
        num_orig = data.x.size(0)

        if self.verbose:
            edge_type_tensor = data.edge_type
            print(f'[GOCM_MIVAE] Edge statistics (before merge):')
            for rel_id, rel_name in self.etype_to_name.items():
                count = (edge_type_tensor == rel_id).sum().item()
                print(f'  Relation {rel_id} ({rel_name}): {count} edges')

        # Keep the true labels aside; the models only see training labels.
        self.y_orig = data.y.clone()
        data.y[data.train_mask == 0] = 0  # mask out val/test

        data = self.preprocess(data)

        # METIS partitioning; partitions are cached on disk per name and count.
        num_parts = max(1, data.num_nodes // self.batch_size)
        save_dir = os.path.join('ckpt', 'cluster_cache', f'{self.name.replace("/", "_")}_parts{num_parts}')
        os.makedirs(save_dir, exist_ok=True)

        if self.verbose:
            print(f'[GOCM_MIVAE] METIS Clustering: nodes={data.num_nodes}, edges={data.edge_index.size(1)}, parts={num_parts}')

        cluster_data = ClusterData(data, num_parts=num_parts, log=self.verbose, save_dir=save_dir)
        dataloader = ClusterLoader(cluster_data, batch_size=2, shuffle=False, num_workers=4)

        # Derive sizes that were left unspecified.
        if self.hid_dim is None:
            self.hid_dim = 2 ** int(math.log2(data.x.size(1)) - 1)
        if self.cons_dim is None:
            self.cons_dim = 2 * self.hid_dim

        self.ae = VGAE_Hetero(
            in_dim=data.x.size(1),
            hid_dim=self.hid_dim,
            etypes=self.etypes,
            threshold=self.threshold,
            fusion_strategy=self.fusion_strategy
        ).to(self.device)

        if self.verbose:
            print(f'[GOCM_MIVAE] VGAE Architecture: num_relations={self.etypes}, fusion_strategy={self.fusion_strategy}')
            param_count = sum(p.numel() for p in self.ae.parameters())
            print(f'[GOCM_MIVAE] VGAE Parameters: {param_count:,}')

        ae_ckpt = f"ckpt/{self.name.replace('/', '_')}_ae_hetero.pt"
        if self.reuse_ae and os.path.exists(ae_ckpt):
            if self.verbose:
                print(f'[GOCM_MIVAE] Reuse VGAE: {ae_ckpt}')
            # The checkpoint is a whole pickled module, so weights_only must
            # be off (newer PyTorch defaults to True and would refuse it).
            # Only reuse checkpoints from a trusted location.
            self.ae = torch.load(ae_ckpt, map_location=self.device, weights_only=False)
        else:
            self.train_ae(dataloader)

        self.cm = MeanConsistencyCluster(
            d_in=self.hid_dim,
            dim_t=self.cons_dim,
            opts={
                'T_type': self.mc_T_type,
                'T_k': self.mc_T_k,
                'T_eps': self.mc_T_eps,
                'W_type': self.mc_W_type,
            },
            device=self.device
        ).to(self.device)

        cm_ckpt = f"ckpt/{self.name.replace('/', '_')}_cm_hetero.pt"
        if self.reuse_cm and os.path.exists(cm_ckpt):
            if self.verbose:
                print(f'[GOCM_MIVAE] Reuse MC: {cm_ckpt}')
            # Same whole-module pickle format as the VGAE checkpoint above.
            self.cm = torch.load(cm_ckpt, map_location=self.device, weights_only=False)
        else:
            self.train_cm(dataloader)

        # Number of nodes to generate: gen_ratio times the summed training
        # labels, but at least one.
        gen_nodes = int(max(1, self.y_orig[data.train_mask.bool()].sum().item() * self.gen_ratio))
        if self.verbose:
            print(f'[GOCM_MIVAE] Generated nodes: {gen_nodes}')

        # Generate in chunks of at most batch_size nodes.
        gen_start = time.time()
        gen_gs = []
        remaining = gen_nodes
        while remaining > self.batch_size:
            gen_gs.append(self.sample(self.batch_size))
            remaining -= self.batch_size
        if remaining > 0:
            gen_gs.append(self.sample(remaining))
        self.last_gen_time = time.time() - gen_start

        # Restore the labels that were hidden for training.
        data.y = self.y_orig

        # Original graph first, so generated nodes get indices >= num_orig.
        aug_data = Batch.from_data_list([data] + gen_gs)
        if self.relink_ratio > 0:
            if self.verbose:
                print(f'[GOCM_MIVAE] Start relinking: relink_ratio={self.relink_ratio}, generated_nodes={aug_data.x.size(0) - num_orig}')
            aug_data = self._relink_generated_nodes(aug_data, num_orig)
        aug_data = self.postprocess(aug_data)

        g_aug = self._pyg_to_dgl_hetero(
            aug_data, ntype, relations, num_orig, dgl_hetero
        )

        t_end = time.time()
        if self.verbose:
            print(f'\n{"="*60}')
            print(f'[GOCM_MIVAE] Augmentation completed!')
            print(f'  Original nodes: {num_orig}')
            print(f'  Augmented nodes: {g_aug.number_of_nodes(ntype)}')
            print(f'  Generation time: {self.last_gen_time:.2f}s')
            print(f'  Total time: {t_end - t_start:.2f}s')
            print(f'{"="*60}\n')

        return g_aug

    def _relink_generated_nodes(self, aug_data: Data, num_orig: int) -> Data:
        """Connect a random subset of generated nodes to original nodes.

        Nodes with index ``>= num_orig`` are treated as generated. A fraction
        ``self.relink_ratio`` of them is sampled at random. For every sampled
        node, up to ``self.relink_max_candidates`` random original nodes are
        scored by ``self.ae.dec_stru`` on the concatenated latent means of the
        pair; each pair whose sigmoid score reaches the threshold
        (``self.relink_threshold``, falling back to ``self.threshold`` when it
        is ``None``) receives an edge in both directions. The relation of a new
        edge is the argmax of ``self.ae.dec_type`` when ``self.etypes > 1`` and
        that decoder exists, otherwise relation ``0``.

        Args:
            aug_data: Merged graph with original nodes first and generated
                nodes after them. It is modified in place (``edge_index`` and
                ``edge_type`` are extended; ``edge_type`` is created as all
                zeros when missing).
            num_orig: Number of original nodes at the front of ``aug_data``.

        Returns:
            The same ``aug_data`` object, with the new edges appended when any
            candidate pair passed the threshold.
        """
        num_total = aug_data.x.size(0)
        num_gen = num_total - num_orig
        if num_gen <= 0:
            if self.verbose:
                print(f'  [Relink] Skip: No generated nodes (num_gen={num_gen})')
            return aug_data

        num_relink = int(num_gen * self.relink_ratio)
        if num_relink <= 0 or num_orig <= 0:
            if self.verbose:
                print(f'  [Relink] Skip: num_relink={num_relink}, num_orig={num_orig}')
            return aug_data

        if self.verbose:
            print(f'  [Relink] Planned relink nodes: {num_relink} / {num_gen} (ratio={self.relink_ratio:.3f})')

        # Ensure edge_type exists; create it next to edge_index so the final
        # concatenation never mixes devices.
        if not hasattr(aug_data, 'edge_type') or aug_data.edge_type is None:
            aug_data.edge_type = torch.zeros(
                aug_data.edge_index.size(1),
                dtype=torch.long,
                device=aug_data.edge_index.device,
            )

        # Encode current merged graph using VGAE
        x = aug_data.x.to(self.device)
        edge_index = aug_data.edge_index.to(self.device)
        edge_type = aug_data.edge_type.to(self.device)
        y = aug_data.y.unsqueeze(1).float().to(self.device)
        with torch.no_grad():
            mu_all, _ = self.ae.encode(x, edge_index, y, edge_type=edge_type)

        # relink_ratio may exceed 1, so cap the sample size at num_gen.
        num_relink = min(num_relink, num_gen)
        gen_indices = torch.arange(num_orig, num_total, device=self.device)
        selected_gen = gen_indices[torch.randperm(num_gen, device=self.device)[:num_relink]]

        threshold = self.relink_threshold if self.relink_threshold is not None else self.threshold
        # Loop invariants: candidate count and whether a relation decoder is used.
        k = min(self.relink_max_candidates, num_orig)
        predict_type = self.etypes > 1 and self.ae.dec_type is not None
        new_src, new_dst, new_type = [], [], []

        if k > 0:
            # Scoring is inference only: nothing below is back-propagated, so
            # keep autograd off to avoid building a graph per generated node.
            with torch.no_grad():
                for gen_idx in selected_gen:
                    cand_orig = torch.randperm(num_orig, device=self.device)[:k]
                    z_gen = mu_all[gen_idx].expand(cand_orig.size(0), -1)
                    pair_embed = torch.cat([z_gen, mu_all[cand_orig]], dim=1)

                    score = torch.sigmoid(self.ae.dec_stru(pair_embed)).view(-1)
                    mask = score >= threshold
                    if not mask.any():
                        continue

                    valid_orig = cand_orig[mask]
                    if predict_type:
                        # Rows of pair_embed[mask] are exactly the accepted
                        # (generated, original) latent pairs.
                        et_pred = torch.argmax(self.ae.dec_type(pair_embed[mask]), dim=1).long()
                    else:
                        et_pred = torch.zeros(valid_orig.size(0), dtype=torch.long, device=self.device)

                    # Add both directions so the new link is symmetric.
                    gen_rep = gen_idx.repeat(valid_orig.size(0))
                    new_src.append(torch.cat([gen_rep, valid_orig]))
                    new_dst.append(torch.cat([valid_orig, gen_rep]))
                    new_type.append(torch.cat([et_pred, et_pred]))

        if not new_src:
            if self.verbose:
                print(f'  [Relink] No new edges added (all candidate pairs failed threshold)')
            return aug_data

        # Move the new edges to wherever the existing tensors live instead of
        # assuming the graph is on the CPU.
        edge_index_new = torch.stack(
            [torch.cat(new_src), torch.cat(new_dst)], dim=0
        ).to(aug_data.edge_index.device)
        edge_type_new = torch.cat(new_type).to(aug_data.edge_type.device)

        num_new_edges = edge_index_new.size(1)
        old_num_edges = aug_data.edge_index.size(1)

        aug_data.edge_index = torch.cat([aug_data.edge_index, edge_index_new], dim=1)
        aug_data.edge_type = torch.cat([aug_data.edge_type, edge_type_new], dim=0)

        if self.verbose:
            print(f'  [Relink] Done: Added {num_new_edges} new edges (old_edges={old_num_edges}, new_edges={aug_data.edge_index.size(1)})')

        return aug_data

    def _pyg_to_dgl_hetero(
        self,
        aug_data: Data,
        ntype: str,
        relations: List[Tuple[str, str, str]],
        num_orig: int,
        dgl_orig: dgl.DGLHeteroGraph
    ) -> dgl.DGLHeteroGraph:
        """Rebuild a DGL heterograph from the augmented merged PyG graph.

        Edges are split back into relations using ``edge_type`` (the i-th entry
        of ``relations`` owns the edges whose type equals ``i``). Node
        features, labels and the train/val/test masks are copied to the new
        graph, and any multi-trial masks found on ``dgl_orig`` are extended
        with rows for the generated nodes.

        Args:
            aug_data: Augmented merged graph; nodes ``>= num_orig`` are the
                generated ones.
            ntype: Node type under which all nodes are stored.
            relations: Canonical edge types, ordered by edge-type id.
            num_orig: Number of original nodes.
            dgl_orig: Original graph, read only for its multi-trial masks.

        Returns:
            The augmented DGL heterograph.
        """
        total_nodes = aug_data.x.size(0)
        generated = total_nodes - num_orig

        # Split edges by edge type
        merged_edges = aug_data.edge_index
        if hasattr(aug_data, 'edge_type'):
            rel_ids = aug_data.edge_type
        else:
            rel_ids = torch.zeros(merged_edges.size(1), dtype=torch.long)

        if self.verbose:
            print(f'[GOCM_MIVAE] Split edges of augmented graph by edge type:')

        edges_by_relation = {}
        for rel_id, (src_type, rel_name, dst_type) in enumerate(relations):
            selected = (rel_ids == rel_id)
            heads = merged_edges[0, selected]
            tails = merged_edges[1, selected]
            edges_by_relation[(src_type, rel_name, dst_type)] = (heads.cpu(), tails.cpu())

            if self.verbose:
                # An edge counts as new when either endpoint is a generated node.
                total_edges = heads.size(0)
                new_edges = ((heads >= num_orig) | (tails >= num_orig)).sum().item()
                print(f'  Relation {rel_name}: Total edges={total_edges}, New edges={new_edges}')

        # Construct HeteroGraph
        g_aug = dgl.heterograph(edges_by_relation, num_nodes_dict={ntype: total_nodes})

        # Node features and labels
        g_aug.nodes[ntype].data['feature'] = aug_data.x.cpu()
        g_aug.nodes[ntype].data['label'] = aug_data.y.cpu()

        # Split masks: when a mask is missing, every node is a training node
        # and none are validation/test nodes.
        split_defaults = (
            ('train_mask', torch.ones),
            ('val_mask', torch.zeros),
            ('test_mask', torch.zeros),
        )
        for split_name, make_default in split_defaults:
            if hasattr(aug_data, split_name):
                split_mask = getattr(aug_data, split_name).cpu()
            else:
                split_mask = make_default(total_nodes, dtype=torch.bool)
            g_aug.nodes[ntype].data[split_name] = split_mask

        # Multi-trial masks
        for key in ('train_masks', 'val_masks', 'test_masks'):
            if key not in dgl_orig.nodes[ntype].data:
                continue
            trial_masks = dgl_orig.nodes[ntype].data[key].cpu()
            trial_count = trial_masks.size(1)
            # Add generated nodes only for the first 10 trials (fully supervised)
            supervised = min(10, trial_count)
            gen_rows = torch.zeros(generated, trial_count, dtype=torch.bool)
            if key == 'train_masks':
                gen_rows[:, :supervised] = True
            g_aug.nodes[ntype].data[key] = torch.cat([trial_masks, gen_rows], dim=0)

        return g_aug
