#!/usr/bin/env python
# -*- coding: utf-8 -*-
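"""
new_mean_consistency_benchmark_cluster.py

Benchmark driver: optionally augments each dataset graph with GOCM_Cluster
("MC Cluster" augmentation) and then trains and evaluates the detectors
registered in ``model_detector_dict``, reporting AUROC, AUPRC and RecK
(mean and standard deviation over trials) per model and dataset.
"""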

import argparse
import time
import os
import sys
import random
import warnings
from typing import Optional, Tuple, Dict, Any

import numpy as np
import pandas
import torch
import dgl
from torch_geometric.data import Data

sys.path.insert(0, os.path.join(os.getcwd(), 'gocm'))

from utils import Dataset, model_detector_dict, save_results
from gocm.gocm_cluster import GOCM_Cluster

# Silence every warning category process-wide for the whole benchmark run.
warnings.filterwarnings(action="ignore")

# Fixed pool of candidate seeds: 3407, 3417, ... (step 10), all below 10000.
SEED_LIST = [*range(3407, 10000, 10)]


def set_seed(seed: int = 3407) -> None:
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Value used for every random number generator.
    """
    # NOTE: changing PYTHONHASHSEED at runtime does not re-seed hashing in the
    # current interpreter; it is only inherited by child processes.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Deterministic kernels alone are not enough: the cuDNN autotuner may pick
    # a different algorithm on each run, so it has to be switched off as well.
    # Both flags are plain settings and are harmless on CPU-only machines.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


class DataConverter:
    """Static helpers converting between DGL graphs and PyG ``Data`` objects.

    ``dgl_to_pyg`` flattens a DGL graph into a single PyG ``Data`` object and
    records, in a ``meta_info`` dict, what is needed to rebuild a DGL graph;
    ``pyg_to_dgl`` performs the reverse conversion on the (possibly augmented)
    data.
    """

    # Edge-data keys tried, after the caller-supplied name, when looking for
    # per-edge timestamps.
    _FALLBACK_TIME_ATTRS = ('timestamp', 'time', 'edge_timestamp')

    @staticmethod
    def _find_time_tensor(edge_data, time_attr: str):
        """Return ``(key, tensor)`` of the first timestamp key present.

        Args:
            edge_data: Mapping-like edge data view supporting ``in`` and ``[]``.
            time_attr: Preferred key, tried before the fallback names.

        Returns:
            ``(key, tensor)`` for the first match, or ``(None, None)``.
        """
        for candidate in (time_attr,) + DataConverter._FALLBACK_TIME_ATTRS:
            if candidate in edge_data:
                return candidate, edge_data[candidate]
        return None, None

    @staticmethod
    def dgl_to_pyg(dgl_graph, time_attr: str = 'edge_time') -> Tuple["Data", Dict[str, Any]]:
        """Convert a DGL graph into a PyG ``Data`` object.

        Args:
            dgl_graph: Source graph. Node data must contain ``feature``,
                ``label``, ``train_mask``, ``val_mask`` and ``test_mask``.
            time_attr: Preferred edge-data key holding edge timestamps.

        Returns:
            ``(pyg_data, meta_info)``. ``meta_info`` stores node/edge type
            names, the original node count, the optional multi-trial masks and
            flags describing which edge attributes were found.
        """
        is_single_ntype = len(dgl_graph.ntypes) == 1
        is_multi_etype = len(dgl_graph.etypes) > 1

        meta_info: Dict[str, Any] = {
            'ntypes': dgl_graph.ntypes,
            'etypes': dgl_graph.etypes,
            'is_single_ntype': is_single_ntype,
            'is_multi_etype': is_multi_etype,
            'time_attr': time_attr,
        }

        # Both layouts expose the same keys; only the access path differs.
        if is_single_ntype:
            node_data = dgl_graph.nodes[dgl_graph.ntypes[0]].data
        else:
            node_data = dgl_graph.ndata

        x = node_data['feature']
        y = node_data['label']
        train_mask = node_data['train_mask']
        val_mask = node_data['val_mask']
        test_mask = node_data['test_mask']
        if 'train_masks' in node_data:
            for key in ('train_masks', 'val_masks', 'test_masks'):
                meta_info[key] = node_data[key]

        meta_info['original_num_nodes'] = x.shape[0]

        edge_time = None
        edge_type = None

        if is_multi_etype:
            src_parts, dst_parts, type_parts, time_parts = [], [], [], []

            for type_id, etype in enumerate(dgl_graph.etypes):
                src, dst = dgl_graph.edges(etype=etype)
                src_parts.append(src)
                dst_parts.append(dst)
                type_parts.append(torch.full_like(src, type_id))

                _, found = DataConverter._find_time_tensor(
                    dgl_graph.edges[etype].data, time_attr
                )
                time_parts.append(found)

            edge_index = torch.stack(
                [torch.cat(src_parts), torch.cat(dst_parts)], dim=0
            ).long()
            edge_type = torch.cat(type_parts).long()

            # Pad EVERY edge type lacking timestamps (including those that
            # precede the first timestamped type) so edge_time stays aligned
            # one-to-one with the columns of edge_index.
            if any(part is not None for part in time_parts):
                edge_time = torch.cat([
                    part.float() if part is not None
                    else torch.zeros_like(src, dtype=torch.float)
                    for src, part in zip(src_parts, time_parts)
                ])
        else:
            src, dst = dgl_graph.edges()
            edge_index = torch.stack([src, dst], dim=0).long()

            edata = dgl_graph.edata
            found_name, found = DataConverter._find_time_tensor(edata, time_attr)
            if found is not None:
                edge_time = found.float()
                # Remembered so pyg_to_dgl writes the times back under the same key.
                meta_info['original_time_attr'] = found_name

            if 'edge_type' in edata:
                edge_type = edata['edge_type'].long()

        if edge_time is None:
            edge_time = DataConverter._infer_edge_time_from_features(x, edge_index)
            if edge_time is not None:
                meta_info['edge_time_inferred'] = True

        meta_info['has_edge_time'] = edge_time is not None
        meta_info['has_edge_type'] = edge_type is not None

        pyg_data = Data(
            x=x,
            edge_index=edge_index,
            y=y.long(),
            train_mask=train_mask,
            val_mask=val_mask,
            test_mask=test_mask
        )

        if edge_time is not None:
            pyg_data.edge_time = edge_time
            inferred_flag = " (Inferred from node features)" if meta_info.get('edge_time_inferred', False) else ""
            print(f"      [DataConverter] ✓ Edge time attribute added edge_time{inferred_flag}")
            # min()/max() raise on an empty tensor (graph without edges).
            if edge_time.numel() > 0:
                print(f"        Range: {edge_time.min():.1f} ~ {edge_time.max():.1f}")

        if edge_type is not None:
            pyg_data.edge_type = edge_type
            print(f"      [DataConverter] ✓ Edge type attribute added edge_type (Num types: {edge_type.unique().numel()})")

        return pyg_data, meta_info

    @staticmethod
    def _infer_edge_time_from_features(x: torch.Tensor, edge_index: torch.Tensor) -> Optional[torch.Tensor]:
        """Guess edge timestamps from the first node-feature column.

        The first column is treated as a time step when it is integer-like
        (within 0.01), non-constant and lies in ``[0.5, 100]``. Each edge then
        gets the larger of its two endpoint values.

        Args:
            x: Node feature matrix.
            edge_index: ``(2, num_edges)`` tensor of source/destination ids.

        Returns:
            A float tensor with one value per edge, or ``None`` when the
            heuristic does not apply (including empty or single-column ``x``).
        """
        # Empty feature matrices would make min()/max() raise below.
        if x.dim() < 2 or x.shape[0] == 0 or x.shape[1] < 2:
            return None

        # Cast once so round()/allclose() also work for integer features.
        first_col = x[:, 0].float()
        col_min, col_max = first_col.min().item(), first_col.max().item()

        is_integer_like = torch.allclose(first_col, first_col.round(), atol=0.01)
        is_valid_range = (col_min >= 0.5) and (col_max <= 100) and (col_max > col_min)

        if not (is_valid_range and is_integer_like):
            return None

        src_time = first_col[edge_index[0]]
        dst_time = first_col[edge_index[1]]
        edge_time = torch.maximum(src_time, dst_time).float()

        print(f"      [DataConverter] 💡 Detected feature[:,0] might be timestep (Range: {col_min:.0f}~{col_max:.0f})")
        print(f"        → Inferred edge time from node features edge_time = max(src_time, dst_time)")

        return edge_time

    @staticmethod
    def pyg_to_dgl(pyg_data: "Data", meta_info: Dict[str, Any], device: torch.device):
        """Convert a PyG ``Data`` object back to a DGL graph.

        Args:
            pyg_data: Data with ``x``, ``y``, ``edge_index`` and the three
                masks; ``edge_type`` / ``edge_time`` are used when present.
            meta_info: Dictionary produced by :meth:`dgl_to_pyg`.
            device: Device the resulting graph is moved to.

        Returns:
            A heterograph (single node type, several edge types) when the
            original graph had that layout, otherwise a homogeneous graph.
        """
        is_single_ntype = meta_info['is_single_ntype']
        is_multi_etype = meta_info['is_multi_etype']
        time_attr = meta_info.get('time_attr', 'edge_time')

        # The graph is assembled on CPU (node data below is moved with .cpu()),
        # so edge tensors must live on CPU too or DGL rejects the assignment.
        edge_index = pyg_data.edge_index.cpu()
        aug_edge_type = getattr(pyg_data, 'edge_type', None)
        aug_edge_time = getattr(pyg_data, 'edge_time', None)
        if aug_edge_type is not None:
            aug_edge_type = aug_edge_type.cpu()
        if aug_edge_time is not None:
            aug_edge_time = aug_edge_time.cpu()

        # Passed explicitly so nodes without any edge (e.g. freshly generated
        # or isolated trailing nodes) are not dropped by DGL's id inference.
        num_nodes = pyg_data.x.shape[0]

        if is_multi_etype and is_single_ntype:
            ntype = meta_info['ntypes'][0]
            edge_dict = {}
            edge_time_dict = {}

            for type_id, etype in enumerate(meta_info['etypes']):
                etype_str = etype if isinstance(etype, str) else str(etype)

                if aug_edge_type is not None:
                    mask = (aug_edge_type == type_id)
                    src = edge_index[0][mask]
                    dst = edge_index[1][mask]

                    if aug_edge_time is not None:
                        edge_time_dict[etype_str] = aug_edge_time[mask]
                else:
                    # Without type ids every relation receives all edges.
                    src, dst = edge_index

                edge_dict[(ntype, etype_str, ntype)] = (src, dst)

            dgl_graph = dgl.heterograph(edge_dict, num_nodes_dict={ntype: num_nodes})

            node_frame = dgl_graph.nodes[ntype].data
            node_frame['feature'] = pyg_data.x.cpu()
            node_frame['label'] = pyg_data.y.cpu()
            node_frame['train_mask'] = pyg_data.train_mask.cpu()
            node_frame['val_mask'] = pyg_data.val_mask.cpu()
            node_frame['test_mask'] = pyg_data.test_mask.cpu()

            for etype_str, et in edge_time_dict.items():
                canonical = (ntype, etype_str, ntype)
                if canonical in dgl_graph.canonical_etypes:
                    dgl_graph.edges[canonical].data[time_attr] = et
        else:
            dgl_graph = dgl.graph((edge_index[0], edge_index[1]), num_nodes=num_nodes)
            dgl_graph.ndata['feature'] = pyg_data.x.cpu()
            dgl_graph.ndata['label'] = pyg_data.y.cpu()
            dgl_graph.ndata['train_mask'] = pyg_data.train_mask.cpu()
            dgl_graph.ndata['val_mask'] = pyg_data.val_mask.cpu()
            dgl_graph.ndata['test_mask'] = pyg_data.test_mask.cpu()

            if aug_edge_type is not None:
                dgl_graph.edata['edge_type'] = aug_edge_type

            if aug_edge_time is not None:
                original_time_attr = meta_info.get('original_time_attr', time_attr)
                dgl_graph.edata[original_time_attr] = aug_edge_time

        DataConverter._extend_trial_masks(dgl_graph, pyg_data, meta_info)

        return dgl_graph.to(device)

    @staticmethod
    def _extend_trial_masks(dgl_graph, pyg_data: "Data", meta_info: Dict[str, Any]) -> None:
        """Attach the multi-trial masks to the rebuilt graph.

        Rows for newly added nodes are appended: they are marked as training
        nodes in every trial and never as validation or test nodes. When no
        node was added the original masks are attached unchanged, so they are
        still available on the rebuilt graph.

        Args:
            dgl_graph: Graph whose node data receives the masks.
            pyg_data: Augmented data; only ``x.shape[0]`` is read.
            meta_info: Dictionary produced by :meth:`dgl_to_pyg`.
        """
        if 'train_masks' not in meta_info:
            return

        added = pyg_data.x.shape[0] - meta_info['original_num_nodes']
        if added < 0:
            # Fewer nodes than before: the old rows cannot be mapped safely.
            return

        if meta_info['is_single_ntype']:
            node_frame = dgl_graph.nodes[meta_info['ntypes'][0]].data
        else:
            node_frame = dgl_graph.ndata

        # key -> whether appended rows are set (True) or cleared (False).
        for key, mark_new in (('train_masks', True), ('val_masks', False), ('test_masks', False)):
            # Node data of the rebuilt graph lives on CPU at this point.
            masks = meta_info[key].cpu()
            if added > 0:
                make = torch.ones if mark_new else torch.zeros
                # Reuse the original dtype so concatenation never promotes it.
                padding = make(added, masks.shape[1], dtype=masks.dtype)
                masks = torch.cat([masks, padding], dim=0)
            node_frame[key] = masks


class MCClusterAugmentor:
    """Augment a dataset graph with nodes generated by ``GOCM_Cluster``.

    The constructor only stores hyper-parameters; the model itself is created
    inside :meth:`augment`, once per call.
    """

    def __init__(
        self,
        hid_dim: int = 64,
        vae_epochs: int = 50,
        
        cons_dim: int = 128,
        cons_epochs: int = 50,
        sample_steps: int = 1,
        
        mc_T_type: str = 'baseline',
        mc_T_k: float = 48.0,
        mc_T_eps: float = 0.002,
        mc_W_type: str = 'constant1',
        
        mc_schedule: str = 'linear',
        mc_eta: float = 0.0,
        mc_s_min: float = 0.002,
        mc_step_clip: Optional[float] = None,
        mc_rho: float = 7.0,
        mc_heun: bool = False,
        mc_single_s: str = 'zero',
        
        lr: float = 0.001,
        batch_size: int = 4096,
        gen_ratio: float = 1.0,
        reuse_ae: bool = False,
        reuse_cm: bool = False,
        cached_neg: bool = False,
        device: str = 'cuda',
        verbose: bool = False,
        time_attr: str = 'edge_time'
    ):
        """Store the hyper-parameters later forwarded to ``GOCM_Cluster``.

        ``mc_single_s`` is not stored verbatim: it is reduced to the boolean
        ``mc_single_use_s_min`` (true only for the value ``'s_min'``).
        """
        # Auto-encoder and consistency-model sizes / training lengths.
        self.hid_dim, self.vae_epochs = hid_dim, vae_epochs
        self.cons_dim, self.cons_epochs = cons_dim, cons_epochs
        self.sample_steps = sample_steps

        # T / W selection.
        self.mc_T_type, self.mc_T_k, self.mc_T_eps = mc_T_type, mc_T_k, mc_T_eps
        self.mc_W_type = mc_W_type

        # Sampling controls.
        self.mc_schedule, self.mc_eta = mc_schedule, mc_eta
        self.mc_s_min, self.mc_step_clip = mc_s_min, mc_step_clip
        self.mc_rho, self.mc_heun = mc_rho, mc_heun
        self.mc_single_use_s_min = mc_single_s == 's_min'

        # Optimisation, caching and runtime options.
        self.lr, self.batch_size, self.gen_ratio = lr, batch_size, gen_ratio
        self.reuse_ae, self.reuse_cm, self.cached_neg = reuse_ae, reuse_cm, cached_neg
        self.device, self.verbose, self.time_attr = device, verbose, time_attr

    def _gocm_device_id(self) -> int:
        """Translate ``self.device`` into the integer id given to GOCM_Cluster.

        Returns ``-1`` when the device string does not mention CUDA, the
        number after the last ``:`` for strings such as ``cuda:1``, and ``0``
        when no index is given or it is not an integer.
        """
        if 'cuda' not in self.device:
            return -1
        _, colon, index_text = self.device.rpartition(':')
        if not colon:
            return 0
        try:
            return int(index_text)
        except ValueError:
            return 0

    def augment(self, dataset: Dataset) -> Tuple[Dataset, float, float]:
        """Run the augmentation and replace ``dataset.graph`` in place.

        Args:
            dataset: Dataset whose ``graph`` is a DGL graph.

        Returns:
            ``(dataset, mc_gen_time, total_time)`` where ``mc_gen_time`` is the
            model's ``last_gen_time`` attribute (``0.0`` if absent) and
            ``total_time`` is the wall-clock duration of this call in seconds.

        Raises:
            ValueError: If the model returns ``None``.
        """
        rule = '=' * 60
        print(f"\n{rule}")
        print(f"[MC-Cluster] Start data augmentation: {dataset.name}")
        print(f"{rule}")

        started_at = time.time()

        print("[1/4] Converting data format: DGL -> PyG ...")
        pyg_data, meta_info = DataConverter.dgl_to_pyg(dataset.graph, time_attr=self.time_attr)
        node_count = pyg_data.x.shape[0]
        edge_count = pyg_data.edge_index.shape[1]
        print(f"      Num nodes: {node_count}, Num edges: {edge_count}")

        print("[2/4] Initializing GOCM_Cluster model ...")

        model_kwargs = dict(
            name=f"{dataset.name}_cluster",
            hid_dim=self.hid_dim,
            cons_dim=self.cons_dim,
            vae_epochs=self.vae_epochs,
            cons_epochs=self.cons_epochs,
            lr=self.lr,
            batch_size=self.batch_size,
            sample_steps=self.sample_steps,
            device=self._gocm_device_id(),
            verbose=self.verbose,
            reuse_ae=self.reuse_ae,
            reuse_cm=self.reuse_cm,
            gen_ratio=self.gen_ratio,
            cached_neg=self.cached_neg,
            mc_T_type=self.mc_T_type,
            mc_T_k=self.mc_T_k,
            mc_T_eps=self.mc_T_eps,
            mc_W_type=self.mc_W_type,
            mc_schedule=self.mc_schedule,
            mc_eta=self.mc_eta,
            mc_s_min=self.mc_s_min,
            mc_step_clip=self.mc_step_clip,
            mc_rho=self.mc_rho,
            mc_heun=self.mc_heun,
            mc_single_use_s_min=self.mc_single_use_s_min,
        )
        gocm = GOCM_Cluster(**model_kwargs)

        if hasattr(pyg_data, 'edge_type'):
            print(f"      ✓ Edge type detected (edge_type), multi-edge type support enabled")
        if not hasattr(pyg_data, 'edge_time'):
            print(f"      ✗ Edge time information not detected, running in non-temporal mode")
        else:
            print(f"      ✓ Edge time detected (edge_time), temporal modeling enabled")

        print("[3/4] Executing MC Cluster version data augmentation ...")
        aug_pyg_data = gocm(pyg_data)

        # The model may expose how long generation took; default to zero.
        mc_gen_time = getattr(gocm, 'last_gen_time', 0.0)

        if aug_pyg_data is None:
            raise ValueError("GOCM_Cluster returned None, augmentation failed")

        augmented_count = aug_pyg_data.x.shape[0]
        print(f"      Augmented num nodes: {augmented_count}")
        print(f"      Added num nodes: {augmented_count - meta_info['original_num_nodes']}")

        print("[4/4] Converting data format: PyG -> DGL ...")
        dataset.graph = DataConverter.pyg_to_dgl(
            aug_pyg_data, meta_info, torch.device(self.device)
        )

        total_time = time.time() - started_at

        print(f"\n[MC-Cluster] Data augmentation completed!")
        print(f"             MC inference time: {mc_gen_time:.2f}s, Total time: {total_time:.2f}s")
        print(f"{rule}\n")

        return dataset, mc_gen_time, total_time


class MCClusterBenchmark:
    """Run every selected detector on every selected dataset.

    Each model/dataset pair is optionally augmented with
    :class:`MCClusterAugmentor`, trained for ``args.trials`` trials and
    summarised as mean/std of AUROC, AUPRC and RecK plus timing columns.
    """

    # Names addressable by index through ``--datasets``.
    ALL_DATASETS = (
        'reddit', 'weibo', 'amazon', 'yelp', 'tfinance',
        'elliptic', 'tolokers', 'questions', 'dgraphfin', 'tsocial',
        'hetero/amazon', 'hetero/yelp'
    )

    # Per-dataset result columns, in output order.
    METRIC_COLUMNS = (
        'AUROC mean', 'AUROC std', 'AUPRC mean', 'AUPRC std',
        'RecK mean', 'RecK std', 'Time', 'MCGenTime', 'AugTime'
    )

    def __init__(self, args):
        """Resolve device, datasets, models and seeds, then build the augmentor.

        Args:
            args: Namespace as returned by ``parse_args``.
        """
        self.args = args
        self.device = self._setup_device()
        self.datasets = self._parse_datasets()
        self.models = self._parse_models()
        self.results = self._init_results_df()
        self.seed_list = self._generate_seed_list()

        self.augmentor = MCClusterAugmentor(
            hid_dim=args.hid_dim,
            vae_epochs=args.vae_epochs,
            cons_dim=args.cons_dim,
            cons_epochs=args.cons_epochs,
            sample_steps=args.sample_steps,
            mc_T_type=args.mc_T_type,
            mc_T_k=args.mc_T_k,
            mc_T_eps=args.mc_T_eps,
            mc_W_type=args.mc_W_type,
            mc_schedule=args.mc_schedule,
            mc_eta=args.mc_eta,
            mc_s_min=args.mc_s_min,
            mc_step_clip=args.mc_step_clip,
            mc_rho=args.mc_rho,
            mc_heun=args.mc_heun,
            mc_single_s=args.mc_single_s,
            lr=args.lr,
            batch_size=args.batch_size,
            gen_ratio=args.gen_ratio,
            reuse_ae=args.reuse_ae,
            reuse_cm=args.reuse_cm,
            cached_neg=args.cached_neg,
            device=self.device,
            verbose=args.verbose,
            time_attr=args.time_attr
        )

    @staticmethod
    def _cuda_index(requested: str, gpu: int) -> int:
        """Pick the CUDA index: explicit ``cuda:N`` first, then ``gpu``, else 0."""
        _, colon, index_text = requested.partition(':')
        if colon:
            try:
                explicit = int(index_text)
            except ValueError:
                explicit = -1
            if explicit >= 0:
                return explicit
        # A negative --gpu combined with a CUDA request would otherwise yield
        # the invalid device string 'cuda:-1'.
        return gpu if gpu >= 0 else 0

    def _setup_device(self) -> str:
        """Return the device string (``'cpu'`` or ``'cuda:N'``) to run on.

        A non-negative ``args.gpu`` promotes the default ``'cpu'`` to CUDA
        (pass ``--gpu -1`` to stay on CPU). Any non-CPU request falls back to
        CPU when CUDA is unavailable. ``args.device`` is updated to reflect a
        promotion or fallback.
        """
        args = self.args
        requested = str(args.device)

        if args.gpu >= 0 and requested == 'cpu':
            requested = 'cuda'
            args.device = 'cuda'

        if requested != 'cpu' and not torch.cuda.is_available():
            print("⚠️ Warning: CUDA not available, falling back to CPU")
            requested = 'cpu'
            args.device = 'cpu'

        if requested == 'cpu':
            device = 'cpu'
        else:
            device = f'cuda:{self._cuda_index(requested, args.gpu)}'

        print(f"🖥️ Using device: {device}")
        return device

    def _generate_seed_list(self) -> list:
        """Build one seed per trial.

        With ``args.randomize_seeds`` the seeds are drawn from Python's global
        ``random`` module, seeded with ``args.seed_master`` when given and with
        the current time otherwise. Without it the seeds form the arithmetic
        sequence ``seed_base + seed_offset + k * seed_step``.
        """
        args = self.args
        randomize = getattr(args, 'randomize_seeds', False)

        if randomize:
            master = getattr(args, 'seed_master', None)
            if master is None:
                master = int(time.time() * 1000) % (2**31 - 1)
            random.seed(int(master))
            seed_list = [random.randint(1, 2**31 - 1) for _ in range(args.trials)]
        else:
            start = int(getattr(args, 'seed_base', 3407)) + int(getattr(args, 'seed_offset', 0))
            step = int(getattr(args, 'seed_step', 10))
            seed_list = [start + step * k for k in range(args.trials)]

        print(f"🎲 Seed mode: randomize={randomize}, seeds(head)={seed_list[:5]}")
        return seed_list

    def _parse_datasets(self) -> list:
        """Translate ``args.datasets`` into dataset names.

        The specification is a comma-separated list whose items are either a
        single index (``3``) or an inclusive range (``0-5``); both forms may
        be mixed (``0-2,5``). ``None`` selects every dataset.

        Raises:
            ValueError: If an item is not an integer or integer range.
            IndexError: If a single index is out of range.
        """
        all_datasets = list(self.ALL_DATASETS)

        spec = self.args.datasets
        if spec is None:
            return all_datasets

        datasets = []
        for item in spec.split(','):
            item = item.strip()
            if not item:
                continue  # tolerate stray commas such as "0,1,"
            first, dash, last = item.partition('-')
            if dash:
                datasets.extend(all_datasets[int(first):int(last) + 1])
            else:
                datasets.append(all_datasets[int(item)])

        print(f"📊 Evaluation datasets: {datasets}")
        return datasets

    def _parse_models(self) -> list:
        """Translate ``args.models`` into a list of model names.

        Names are separated by ``-`` when the string contains one, otherwise
        by ``,``. ``None`` selects every key of ``model_detector_dict``.
        """
        spec = self.args.models
        if spec is None:
            return list(model_detector_dict.keys())

        separator = '-' if '-' in spec else ','
        # Empty names (e.g. from a trailing separator) could never be resolved.
        models = [name.strip() for name in spec.split(separator) if name.strip()]
        print(f"🤖 Evaluation models: {models}")
        return models

    def _init_results_df(self) -> pandas.DataFrame:
        """Create the empty summary frame: ``name`` plus metric columns per dataset."""
        columns = ['name'] + [
            f'{dataset}-{metric}'
            for dataset in self.datasets
            for metric in self.METRIC_COLUMNS
        ]
        return pandas.DataFrame(columns=columns)

    def _should_skip(self, model: str, dataset: str) -> bool:
        """Return True for CAREGNN/H2FD on datasets whose name lacks 'hetero'."""
        return model in ('CAREGNN', 'H2FD') and 'hetero' not in dataset

    def run(self) -> None:
        """Evaluate all model/dataset pairs, saving and printing results."""
        print("\n" + "="*70)
        print("🚀 Start MC Cluster Version + GADBench Benchmark")
        print("="*70 + "\n")

        for model in self.models:
            model_result = {'name': model}

            for dataset_name in self.datasets:
                if self._should_skip(model, dataset_name):
                    continue

                result = self._evaluate_single(model, dataset_name)
                model_result.update(result)

                self._save_single_result(model, dataset_name, result)

            model_result_df = pandas.DataFrame(model_result, index=[0])
            # ignore_index gives each model its own row label instead of a
            # repeated 0 for every row.
            self.results = pandas.concat([self.results, model_result_df], ignore_index=True)
            print(f"\nCurrent results summary:\n{self.results}")

        self._print_summary()

    def _evaluate_single(self, model: str, dataset_name: str) -> Dict[str, Any]:
        """Evaluate a single model on a single dataset.

        Args:
            model: Key into ``model_detector_dict``.
            dataset_name: Name passed to ``Dataset``.

        Returns:
            Dictionary keyed ``'<dataset>-<metric>'`` with mean/std of AUROC,
            AUPRC and RecK, the mean training time per trial and the
            augmentation timings.

        Raises:
            Exception: Re-raised from the augmentation when ``args.strict``.
        """
        print(f"\n{'─'*50}")
        print(f"📈 Evaluating: {model} @ {dataset_name}")
        print(f"{'─'*50}")

        time_cost = 0
        mc_gen_time = 0.0
        augmentation_time = 0.0

        train_config = {
            'device': self.device,
            'epochs': self.args.epochs,
            'patience': self.args.patience,
            'metric': 'AUPRC',
            'inductive': bool(self.args.inductive)
        }

        data = Dataset(dataset_name)

        if self.args.use_mc:
            # NOTE(review): the augmentation is fitted on the trial-0 split and
            # reused for every trial below — confirm this is intended.
            try:
                # Seed before augmenting; otherwise the generated nodes differ
                # between runs even with a fixed seed sequence.
                if self.seed_list:
                    set_seed(self.seed_list[0])
                data.split(self.args.semi_supervised, 0)
                data, mc_gen_time, augmentation_time = self.augmentor.augment(data)
            except Exception as e:
                import traceback
                print(f"❌ MC Cluster augmentation failed: {e}")
                traceback.print_exc()
                if self.args.strict:
                    raise
                print("⚠️ Continuing with original data...")
                mc_gen_time = 0.0
                augmentation_time = 0.0

        model_config = {'model': model, 'lr': 0.01, 'drop_rate': 0}
        if dataset_name == 'tsocial':
            model_config['h_feats'] = 16

        auc_list, pre_list, rec_list = [], [], []
        detector = None  # stays defined even when no trial runs

        for t in range(self.args.trials):
            torch.cuda.empty_cache()
            print(f"  Trial {t+1}/{self.args.trials}", end=" ")

            data.split(self.args.semi_supervised, t)
            seed = self.seed_list[t]
            set_seed(seed)
            train_config['seed'] = seed

            detector = model_detector_dict[model](train_config, model_config, data)

            st = time.time()
            test_score = detector.train()
            ed = time.time()

            time_cost += ed - st
            auc_list.append(test_score['AUROC'])
            pre_list.append(test_score['AUPRC'])
            rec_list.append(test_score['RecK'])

            print(f"| AUROC: {test_score['AUROC']:.4f} | AUPRC: {test_score['AUPRC']:.4f}")

        # Drop the references so the graph and model can be freed before the
        # next model/dataset pair is loaded.
        del detector, data

        result = {
            f'{dataset_name}-AUROC mean': np.nanmean(auc_list),
            f'{dataset_name}-AUROC std': np.nanstd(auc_list),
            f'{dataset_name}-AUPRC mean': np.nanmean(pre_list),
            f'{dataset_name}-AUPRC std': np.nanstd(pre_list),
            f'{dataset_name}-RecK mean': np.nanmean(rec_list),
            f'{dataset_name}-RecK std': np.nanstd(rec_list),
            # max(..., 1) avoids a ZeroDivisionError when trials is 0.
            f'{dataset_name}-Time': time_cost / max(self.args.trials, 1),
            f'{dataset_name}-MCGenTime': mc_gen_time,
            f'{dataset_name}-AugTime': augmentation_time
        }

        print(f"\n  📊 Result: AUROC={np.nanmean(auc_list):.4f}±{np.nanstd(auc_list):.4f}, "
              f"AUPRC={np.nanmean(pre_list):.4f}±{np.nanstd(pre_list):.4f}")

        return result

    def _save_single_result(self, model: str, dataset_name: str, result: Dict[str, Any]) -> None:
        """Persist one model/dataset result through ``save_results``.

        The dataset name is prefixed with ``mc_cluster_`` when augmentation
        is enabled.
        """
        single_result = {'name': model}
        single_result.update(result)
        df = pandas.DataFrame(single_result, index=[0])

        prefix = "mc_cluster_" if self.args.use_mc else ""
        save_results(df, None, dataset_name=f"{prefix}{dataset_name}", model_name=model)

    def _print_summary(self) -> None:
        """Print the closing banner and a few usage examples."""
        print("\n" + "="*70)
        print("✅ Experiment Completed!")
        print("="*70)
        print("\nResults saved to results/ directory")
        print("\nUsage Examples:")
        print("  # Use MC Cluster augmentation (default)")
        print("  python new_mean_consistency_benchmark_cluster.py --models GCN --datasets 0,1")
        print("\n  # Disable MC Cluster augmentation")
        print("  python new_mean_consistency_benchmark_cluster.py --use_mc 0 --models GCN")
        print("\n  # Custom MC parameters")
        print("  python new_mean_consistency_benchmark_cluster.py --cons_epochs 100 --sample_steps 2")
        print("\n  # Enable Heun second-order correction")
        print("  python new_mean_consistency_benchmark_cluster.py --mc_heun")


def parse_args(argv: Optional[list] = None):
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description='MC Cluster Version + GADBench Benchmark',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:

  # Specify models and datasets
  python new_mean_consistency_benchmark_cluster.py --models GCN-GAT --datasets 0,1,2

        """
    )
    
    exp_group = parser.add_argument_group('Experiment Settings')
    exp_group.add_argument('--trials', type=int, default=10,
                           help='Number of experiment trials (default: 10)')
    exp_group.add_argument('--semi_supervised', type=int, default=0,
                           help='Whether to use semi-supervised setting (default: 0)')
    exp_group.add_argument('--inductive', type=int, default=0,
                           help='Whether to use inductive learning (default: 0)')
    exp_group.add_argument('--epochs', type=int, default=250,
                           help='Detector training epochs (default: 250)')
    exp_group.add_argument('--patience', type=int, default=50,
                           help='Early stopping patience (default: 50)')
    
    seed_group = parser.add_argument_group('Random Seed Control')
    seed_group.add_argument('--randomize_seeds', action='store_true',
                            help='Enable random seed for each trial (default: off = fixed seed sequence)')
    seed_group.add_argument('--seed_base', type=int, default=3407,
                            help='Start value for fixed seed sequence (default: 3407)')
    seed_group.add_argument('--seed_step', type=int, default=10,
                            help='Step size for fixed seed sequence (default: 10)')
    seed_group.add_argument('--seed_offset', type=int, default=0,
                            help='Offset for fixed seed sequence (default: 0)')
    seed_group.add_argument('--seed_master', type=int, default=None,
                            help='Master seed for initializing random numbers when randomize_seeds is enabled (repeatable randomness)')
    
    data_group = parser.add_argument_group('Models and Datasets')
    data_group.add_argument('--models', type=str, default=None,
                            help='Models to evaluate, separated by - or , (e.g., GCN-GAT-GraphSAGE)')
    data_group.add_argument('--datasets', type=str, default=None,
                            help='Datasets indices to evaluate, separated by comma or range (e.g., 0,1,2 or 0-5)')
    
    mc_group = parser.add_argument_group('MC Cluster Augmentation')
    mc_group.add_argument('--use_mc', type=int, default=1,
                          help='Whether to use MC Cluster data augmentation (default: 1)')
    
    vae_group = parser.add_argument_group('VGAE Parameters')
    vae_group.add_argument('--hid_dim', type=int, default=64,
                           help='VGAE hidden dimension (default: 64)')
    vae_group.add_argument('--vae_epochs', type=int, default=50,
                           help='VGAE training epochs (default: 50)')
    
    cons_group = parser.add_argument_group('MeanConsistency Model Parameters')
    cons_group.add_argument('--cons_dim', type=int, default=128,
                            help='MeanConsistency model dimension (default: 128)')
    cons_group.add_argument('--cons_epochs', type=int, default=50,
                            help='MeanConsistency training epochs (default: 50)')
    cons_group.add_argument('--sample_steps', type=int, default=1,
                            help='MC sampling steps (default: 1)')
    
    robust_group = parser.add_argument_group('MC Robust T/W Parameters')
    robust_group.add_argument('--mc_T_type', type=str, default='baseline',
                              choices=['baseline', 'saturated'],
                              help='MC Robust T selection (default: baseline)')
    robust_group.add_argument('--mc_T_k', type=float, default=48.0,
                              help='Saturated T upper bound parameter k (default: 48.0)')
    robust_group.add_argument('--mc_T_eps', type=float, default=0.002,
                              help='Lower bound for s in T calculation (default: 0.002)')
    robust_group.add_argument('--mc_W_type', type=str, default='constant1',
                              choices=['constant1'],
                              help='MC Weight W selection (default: constant1)')
    
    sample_group = parser.add_argument_group('MC Sampling Parameters')
    sample_group.add_argument('--mc_schedule', type=str, default='linear',
                              choices=['linear', 'rho'],
                              help='MC multi-step sampling time schedule (default: linear)')
    sample_group.add_argument('--mc_eta', type=float, default=0.0,
                              help='MC sampling noise injection intensity (default: 0.0)')
    sample_group.add_argument('--mc_s_min', type=float, default=0.002,
                              help='Lower bound of s in MC sampling (default: 0.002)')
    sample_group.add_argument('--mc_step_clip', type=float, default=None,
                              help='Gradient clipping threshold by sample norm for MC single step update (default: None)')
    sample_group.add_argument('--mc_rho', type=float, default=7.0,
                              help='MC rho schedule shape parameter (default: 7.0)')
    sample_group.add_argument('--mc_heun', action='store_true',
                              help='Enable Heun second-order correction (default: False)')
    sample_group.add_argument('--mc_single_s', type=str, default='zero',
                              choices=['zero', 's_min'],
                              help='End point s selection for single step sampling (default: zero)')
    
    other_group = parser.add_argument_group('Other Parameters')
    # '--gocm_lr' is an alias; both spellings store into ``args.lr``.
    other_group.add_argument('--lr', '--gocm_lr', type=float, default=0.001,
                             dest='lr', help='GOCM learning rate (default: 0.001)')
    other_group.add_argument('--batch_size', type=int, default=4096,
                             help='Graph clustering batch size (default: 4096)')
    other_group.add_argument('--gen_ratio', type=float, default=1.0,
                             help='Ratio of generated nodes to anomaly nodes in training set (default: 1.0)')
    other_group.add_argument('--reuse_ae', action='store_true',
                             help='Reuse trained VGAE weights')
    other_group.add_argument('--reuse_cm', action='store_true',
                             help='Reuse trained MeanConsistency weights')
    other_group.add_argument('--cached_neg', action='store_true',
                             help='Enable negative sampling cache reuse')
    other_group.add_argument('--time_attr', type=str, default='edge_time',
                             help='Edge time attribute name (default: edge_time)')
    
    device_group = parser.add_argument_group('Device Settings')
    device_group.add_argument('--device', type=str, default='cpu',
                              help='Computing device (cpu or cuda)')
    device_group.add_argument('--gpu', type=int, default=0,
                              help='GPU ID (default: 0)')
    
    debug_group = parser.add_argument_group('Debug Options')
    debug_group.add_argument('--verbose', action='store_true',
                             help='Show verbose output')
    debug_group.add_argument('--strict', action='store_true',
                             help='Abort experiment if MC augmentation fails (instead of continuing with original data)')
    
    return parser.parse_args(argv)


if __name__ == '__main__':
    # Script entry point: parse the CLI options, build the benchmark, run it.
    MCClusterBenchmark(parse_args()).run()
