#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Consistency Benchmark - Graph Clustering Version


"""

import os
import sys
import argparse
import warnings
import pandas as pd
import time
from pathlib import Path
import random

# Prepend <cwd>/gocm to the import path; resolution therefore depends on the
# directory the script is launched from.
sys.path.insert(0, os.path.join(os.getcwd(), 'gocm'))

# DGL selects its backend from DGLBACKEND when it is first imported, so the
# variable must be set before `import dgl` below to have any effect.
os.environ['DGLBACKEND'] = 'pytorch'

import torch
import dgl
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score, recall_score

from utils import *
from gocm.gocm import GOCM_Consistency_Cluster

# Silence all Python warnings for the rest of the run (import-time warnings
# emitted above are still shown).
warnings.filterwarnings("ignore")


def set_seed(seed: int = 3407) -> None:
    """Seed every RNG used by the benchmark so a trial can be reproduced.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and all CUDA
    devices) and exports ``PYTHONHASHSEED``. When CUDA is available, cuDNN is
    also switched to deterministic algorithm selection.

    Note: ``PYTHONHASHSEED`` only influences interpreters started after this
    call; the current process's ``hash()`` randomization is fixed at startup.

    Args:
        seed: Seed value applied to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        # Benchmark mode autotunes kernels and may pick non-deterministic ones,
        # which would defeat the deterministic flag above.
        torch.backends.cudnn.benchmark = False

def parse_args(argv=None):
    """Parse command line arguments (aligned with mean_consistency_benchmark_cluster.py).

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before; passing a list allows
            programmatic or test use.

    Returns:
        argparse.Namespace holding the benchmark, seed, GOCM and consistency
        sampling options.
    """
    parser = argparse.ArgumentParser(description='Consistency Benchmark - Graph Clustering Version')

    # Benchmark / experiment selection.
    parser.add_argument('--trials', type=int, default=10)
    parser.add_argument('--semi_supervised', type=int, default=0)
    parser.add_argument('--inductive', type=int, default=0)
    parser.add_argument('--models', type=str, default=None)
    parser.add_argument('--datasets', type=str, default=None)
    parser.add_argument('--device', type=str, default='cpu')
    parser.add_argument('--gpu', type=int, default=0)

    # Per-trial seed generation.
    parser.add_argument('--randomize_seeds', action='store_true',
                        help='Enable random seed for each trial (default: off = fixed seed sequence)')
    parser.add_argument('--seed_base', type=int, default=3407,
                        help='Start value for fixed seed sequence (default: 3407)')
    parser.add_argument('--seed_step', type=int, default=10,
                        help='Step size for fixed seed sequence (default: 10)')
    parser.add_argument('--seed_offset', type=int, default=0,
                        help='Offset for fixed seed sequence (default: 0)')
    parser.add_argument('--seed_master', type=int, default=None,
                        help='Master seed for initializing random numbers when randomize_seeds is enabled (repeatable randomness)')

    # GOCM model / training options.
    parser.add_argument('--hid_dim', type=int, default=64, help='VGAE hidden dimension')
    parser.add_argument('--cons_dim', type=int, default=128, help='Consistency model dimension')
    parser.add_argument('--vae_epochs', type=int, default=50, help='VGAE training epochs')
    parser.add_argument('--cons_epochs', type=int, default=50, help='Consistency training epochs')
    parser.add_argument('--gocm_lr', type=float, default=0.001, help='GOCM learning rate')
    parser.add_argument('--batch_size', type=int, default=4096, help='Graph clustering batch size')
    parser.add_argument('--sample_steps', type=int, default=1, help='Consistency sampling steps')
    parser.add_argument('--verbose', action='store_true', help='Verbose output')
    parser.add_argument('--reuse_ae', action='store_true', help='Reuse trained VGAE weights, skip VGAE training if exists')
    parser.add_argument('--reuse_cm', action='store_true', help='Reuse trained Consistency weights, skip Consistency training if exists')
    parser.add_argument('--gen_ratio', type=float, default=1.0, help='Ratio of generated nodes to anomaly nodes in training set')
    parser.add_argument('--cached_neg', action='store_true', help='Enable negative sampling cache reuse (default: off = resample every batch)')

    # Consistency sampling schedule.
    parser.add_argument('--mc_schedule', type=str, default='linear', choices=['linear', 'rho'],
                        help='Consistency multi-step sampling time schedule')
    parser.add_argument('--mc_eta', type=float, default=0.0, help='Consistency sampling noise injection intensity')
    parser.add_argument('--mc_s_min', type=float, default=0.002, help='Lower bound of s in Consistency sampling for stabilization')
    parser.add_argument('--mc_step_clip', type=float, default=None, help='Gradient clipping threshold by sample norm for Consistency single step update')
    parser.add_argument('--mc_rho', type=float, default=7.0, help='Consistency rho schedule shape parameter (effective when schedule=rho)')
    parser.add_argument('--mc_heun', action='store_true', help='Enable Heun second-order correction (default: off = Euler)')
    parser.add_argument('--mc_single_s', type=str, default='zero', choices=['zero', 's_min'],
                        help='End point s selection for single step sampling: zero=use 0; s_min=use mc_s_min')

    return parser.parse_args(argv)


def convert_dgl_to_pyg(g):
    """Convert a DGL graph into a PyTorch Geometric ``Data`` object.

    Node features are taken from ``ndata['feature']``, falling back to
    ``ndata['feat']``. ``label`` and the ``train/val/test`` masks are copied
    when present. Otherwise labels default to all zeros, the train mask to
    all True, and the val/test masks to all False.

    Args:
        g: DGL graph whose ``ndata`` holds at least a feature tensor.

    Returns:
        torch_geometric.data.Data with ``x``, ``edge_index`` (int64, stacked as
        ``[src; dst]``), ``y`` (int64) and the three masks.
    """
    from torch_geometric.data import Data

    x = g.ndata['feature'] if 'feature' in g.ndata else g.ndata['feat']
    num_nodes = x.size(0)
    # Allocate defaults next to the features so every tensor in the returned
    # Data lives on the same device (previously defaults were always on CPU).
    device = x.device

    src, dst = g.edges()
    edge_index = torch.stack([src.to(torch.long), dst.to(torch.long)], dim=0)

    if 'label' in g.ndata:
        y = g.ndata['label']
    else:
        y = torch.zeros(num_nodes, dtype=torch.long, device=device)
    y = y.to(torch.long)

    def _mask_or_default(key, fill):
        # Copy an existing split mask, or build a constant boolean mask.
        if key in g.ndata:
            return g.ndata[key]
        return torch.full((num_nodes,), fill, dtype=torch.bool, device=device)

    return Data(
        x=x,
        edge_index=edge_index,
        y=y,
        train_mask=_mask_or_default('train_mask', True),
        val_mask=_mask_or_default('val_mask', False),
        test_mask=_mask_or_default('test_mask', False),
    )


def convert_pyg_to_dgl(data, device):
    """Convert a PyG graph back to DGL graph format on ``device``.

    Copies ``x`` to ``ndata['feature']``, ``y`` to ``ndata['label']`` and the
    train/val/test masks to ``ndata['*_mask']``. A 2-D mask is reduced to a
    single 1-D mask by taking its first column, or its first row when the
    first dimension does not match the node count.

    Args:
        data: PyG ``Data`` with ``x``, ``y``, ``edge_index`` and the three masks.
        device: Target device for the returned graph.

    Returns:
        dgl.DGLGraph with exactly ``data.x.size(0)`` nodes.
    """
    num_nodes = data.x.size(0)

    # Build on CPU (all ndata below is CPU), then move once at the end; building
    # on a CUDA edge_index and assigning CPU features would be rejected by DGL.
    src, dst = data.edge_index.cpu()
    # Pass num_nodes explicitly: otherwise DGL infers the node count from the
    # largest edge endpoint, silently dropping isolated trailing nodes (e.g.
    # generated nodes without edges) and breaking the ndata assignments.
    g = dgl.graph((src, dst), num_nodes=num_nodes)

    g.ndata['feature'] = data.x.cpu()
    g.ndata['label'] = data.y.cpu()

    def _to_1d_mask(mask):
        # Collapse a (nodes, trials) / (trials, nodes) mask to one 1-D mask.
        if mask.dim() > 1:
            mask = mask[:, 0] if mask.size(0) == num_nodes else mask[0, :]
        return mask.cpu()

    g.ndata['train_mask'] = _to_1d_mask(data.train_mask)
    g.ndata['val_mask'] = _to_1d_mask(data.val_mask)
    g.ndata['test_mask'] = _to_1d_mask(data.test_mask)

    return g.to(device)


def _extend_split_masks(orig_ndata, aug_graph, added):
    """Pad the original per-trial split masks with rows for ``added`` new nodes.

    Only acts when the original graph has ``train_masks``; ``val_masks`` and
    ``test_masks`` are then required too (a missing one raises ``KeyError``).
    New nodes are marked as training nodes in every trial column and excluded
    from validation and test.
    """
    if 'train_masks' not in orig_ndata:
        return
    if added < 0:
        raise RuntimeError(f"Augmentation removed {-added} nodes; cannot extend split masks")
    target_device = aug_graph.device
    for key, fill_value in (('train_masks', True), ('val_masks', False), ('test_masks', False)):
        old = orig_ndata[key].to(target_device)
        # Pad using each mask's own trial (column) count.
        pad = torch.full((added, old.shape[1]), fill_value, dtype=torch.bool, device=target_device)
        aug_graph.ndata[key] = torch.cat([old, pad], dim=0)


def apply_gocm_consistency_cluster(dataset_obj, args, device):
    """Replace ``dataset_obj.graph`` with its GOCM (Consistency) augmented version.

    The DGL graph is converted to PyG, passed through
    ``GOCM_Consistency_Cluster`` configured from the CLI ``args``, converted
    back to DGL on ``device`` and stored on ``dataset_obj.graph``. Per-trial
    split masks (if present) are extended so added nodes are training-only.
    Timings are stored as attributes ``aug_time_<name>`` and
    ``mc_gen_time_<name>`` on ``dataset_obj``.

    Args:
        dataset_obj: Dataset wrapper exposing ``graph`` (DGL graph) and ``name``.
        args: Parsed CLI namespace from ``parse_args``.
        device: torch device for the augmented graph.

    Returns:
        The same ``dataset_obj``, mutated in place.
    """
    dgl_graph = dataset_obj.graph

    print(f"Applying Graph Clustering GOCM (Consistency) Augmentation (Dataset: {dataset_obj.name})...")
    aug_start_time = time.time()

    pyg_data = convert_dgl_to_pyg(dgl_graph)

    # GOCM receives an integer GPU id; -1 is passed when not running on CUDA.
    device_id_for_pygod = args.gpu if getattr(args, 'device', None) == 'cuda' else -1

    gocm = GOCM_Consistency_Cluster(
        name=f"{dataset_obj.name}_consistency_cluster",
        hid_dim=args.hid_dim,
        cons_dim=args.cons_dim,
        vae_epochs=args.vae_epochs,
        cons_epochs=args.cons_epochs,
        lr=args.gocm_lr,
        batch_size=args.batch_size,
        sample_steps=args.sample_steps,
        device=device_id_for_pygod,
        verbose=args.verbose,
        reuse_ae=args.reuse_ae,
        reuse_cm=args.reuse_cm,
        gen_ratio=args.gen_ratio,
        cached_neg=args.cached_neg,
        mc_schedule=args.mc_schedule,
        mc_eta=args.mc_eta,
        mc_s_min=args.mc_s_min,
        mc_step_clip=args.mc_step_clip,
        mc_rho=args.mc_rho,
        mc_heun=args.mc_heun,
        mc_single_use_s_min=(args.mc_single_s == 's_min'),
    )

    augmented_pyg_data = gocm(pyg_data)

    aug_time = time.time() - aug_start_time
    # GOCM may not expose its inference time; report 0.0 in that case.
    consistency_gen_time = getattr(gocm, 'last_gen_time', 0.0)

    aug_dgl = convert_pyg_to_dgl(augmented_pyg_data, device)

    original_num_nodes = dgl_graph.number_of_nodes()
    new_num_nodes = aug_dgl.number_of_nodes()
    _extend_split_masks(dgl_graph.ndata, aug_dgl, new_num_nodes - original_num_nodes)

    dataset_obj.graph = aug_dgl
    setattr(dataset_obj, f'aug_time_{dataset_obj.name}', aug_time)
    setattr(dataset_obj, f'mc_gen_time_{dataset_obj.name}', consistency_gen_time)

    print(f"Data augmentation completed: Original {original_num_nodes} nodes -> Augmented {new_num_nodes} nodes")
    print(f"Augmentation time: {aug_time:.2f}s, Consistency inference time: {consistency_gen_time:.2f}s")
    return dataset_obj


def build_results_dataframe(selected_datasets):
    """Return an empty results table with one column group per dataset.

    The first column is ``name``. Each dataset then contributes nine columns
    named ``<dataset>-<metric>``: mean/std of AUROC, AUPRC and RecK, followed
    by ``Time``, ``AugTime`` and ``ConsGenTime``.
    """
    metric_names = ('AUROC mean', 'AUROC std', 'AUPRC mean', 'AUPRC std',
                    'RecK mean', 'RecK std', 'Time', 'AugTime', 'ConsGenTime')
    header = ['name'] + [
        ds + '-' + metric
        for ds in selected_datasets
        for metric in metric_names
    ]
    return pd.DataFrame(columns=header)


def parse_datasets_arg(arg_value: str, base_list):
    """Select datasets from ``base_list`` by index.

    ``arg_value`` is a comma-separated list whose items are either a single
    index (``"3"``) or an inclusive index range (``"1-4"``). Items may be
    mixed, e.g. ``"0-2,5"``. Surrounding whitespace and empty items (such as a
    trailing comma) are ignored. Ranges use slice semantics, so an upper bound
    past the end is clamped; a single out-of-range index raises ``IndexError``.

    Args:
        arg_value: Selection string, or ``None`` to select every dataset.
        base_list: Ordered list of available dataset names.

    Returns:
        The selected dataset names, in the order given.

    Raises:
        ValueError: If an item is not an integer or ``start-end`` range, or
            the selection string contains no items.
    """
    if arg_value is None:
        return base_list
    items = [item.strip() for item in arg_value.split(',') if item.strip()]
    if not items:
        raise ValueError(f"No dataset indices given in {arg_value!r}")
    selected = []
    for item in items:
        if '-' in item:
            start, end = item.split('-')
            selected.extend(base_list[int(start):int(end) + 1])
        else:
            selected.append(base_list[int(item)])
    return selected


def _resolve_device(args):
    """Normalize ``args.device`` in place and return the torch device to use.

    ``--gpu 1`` combined with the default ``--device cpu`` is treated as a
    request for CUDA. CUDA requests fall back to CPU when no GPU is available.
    """
    if args.gpu == 1 and args.device == 'cpu':
        args.device = 'cuda'
    if args.device == 'cuda' and not torch.cuda.is_available():
        print("Warning: CUDA requested but not available. Switching to CPU mode automatically.")
        args.device = 'cpu'
    if args.device == 'cuda':
        print("GPU enabled, using CUDA mode")
    else:
        print("GPU disabled, using CPU mode")
    # NOTE(review): args.gpu serves both as the on/off switch above and as the
    # CUDA index here, so `--gpu 1` selects cuda:1 — confirm this is intended.
    return torch.device(args.device if args.device == 'cpu' else f'cuda:{args.gpu}')


def _build_seed_list(args):
    """Return one seed per trial.

    With ``--randomize_seeds`` the seeds are drawn from Python's ``random``,
    seeded from ``--seed_master`` if given and otherwise from the wall clock.
    Without it they form the sequence ``seed_base + seed_offset + k * seed_step``
    for ``k = 0 .. trials - 1``.
    """
    if getattr(args, 'randomize_seeds', False):
        if args.seed_master is not None:
            random.seed(int(args.seed_master))
        else:
            random.seed(int(time.time() * 1000) % (2**31 - 1))
        return [random.randint(1, 2**31 - 1) for _ in range(args.trials)]
    seed_start = int(args.seed_base) + int(args.seed_offset)
    step = int(args.seed_step)
    return list(range(seed_start, seed_start + step * args.trials, step))


def _summarize_trials(dataset_name, auc_list, pre_list, rec_list,
                      time_cost, n_trials, aug_time, gen_time):
    """Aggregate per-trial scores into the ``<dataset>-<metric>`` result columns."""
    summary = {}
    for metric, values in (('AUROC', auc_list), ('AUPRC', pre_list), ('RecK', rec_list)):
        summary[dataset_name + '-' + metric + ' mean'] = np.nanmean(values)
        summary[dataset_name + '-' + metric + ' std'] = np.nanstd(values)
    # Mean training time per trial; NaN instead of ZeroDivisionError for 0 trials.
    summary[dataset_name + '-Time'] = time_cost / n_trials if n_trials else float('nan')
    summary[dataset_name + '-AugTime'] = aug_time
    summary[dataset_name + '-ConsGenTime'] = gen_time
    return summary


def main():
    """Run every selected model on every selected dataset and save the results.

    For each (model, dataset) pair the dataset is loaded and augmented once
    with GOCM. If augmentation fails, the original graph is used instead. The
    detector is then trained ``--trials`` times with the seeds from
    ``_build_seed_list``. Each pair's summary is saved via ``save_results``,
    and the accumulated table is printed after each model.
    """
    args = parse_args()
    device_torch = _resolve_device(args)

    base_datasets = ['reddit', 'weibo', 'amazon', 'yelp', 'tfinance',
                     'elliptic', 'tolokers', 'questions', 'dgraphfin', 'tsocial',
                     'hetero/amazon', 'hetero/yelp']
    datasets = parse_datasets_arg(args.datasets, base_datasets)
    print('Evaluation Datasets: ', datasets)

    if args.models is None:
        models = model_detector_dict.keys()
    elif '-' in args.models:
        models = args.models.split('-')
    else:
        models = args.models.split(',')
    models = [m.strip() for m in models]
    print('Evaluation Baseline Models: ', models)

    results = build_results_dataframe(datasets)

    seed_list = _build_seed_list(args)
    print(f"Seed mode -> randomize={bool(args.randomize_seeds)}, seeds(head)={seed_list[:5]}")

    for model in models:
        model_result = {'name': model}
        for dataset_name in datasets:
            # CAREGNN / H2FD are only evaluated on the hetero/* datasets.
            if model in ['CAREGNN', 'H2FD'] and 'hetero' not in dataset_name:
                continue

            # NOTE(review): the detector gets args.device ('cuda') while the
            # augmented graph is placed on cuda:<gpu>; these differ if gpu != 0.
            train_config = {
                'device': args.device,
                'epochs': 250,
                'patience': 50,
                'metric': 'AUPRC',
                'inductive': bool(args.inductive),
            }

            data = Dataset(dataset_name)
            # Initial split before augmentation (presumably materializes the
            # split masks that augmentation extends — confirm in Dataset.split).
            data.split(args.semi_supervised, 0)

            aug_time = 0.0
            consistency_gen_time_record = 0.0
            try:
                aug_st = time.time()
                data = apply_gocm_consistency_cluster(data, args, device_torch)
                aug_time = time.time() - aug_st
                # The augmenter stores the timing under the dataset object's own
                # name, which is not guaranteed to equal dataset_name.
                consistency_gen_time_record = getattr(data, f'mc_gen_time_{data.name}', 0.0)
            except Exception as e:
                print(f"Error applying Graph Clustering GOCM (Consistency) on {dataset_name}: {e}")
                import traceback
                traceback.print_exc()
                print("Continuing with original data...")
                aug_time = 0.0

            model_config = {'model': model, 'lr': 0.01, 'drop_rate': 0}
            if dataset_name == 'tsocial':
                model_config['h_feats'] = 16

            auc_list, pre_list, rec_list = [], [], []
            time_cost = 0.0
            detector = None  # keeps the `del` below valid when --trials is 0
            for t in range(args.trials):
                if args.device == 'cuda':
                    torch.cuda.empty_cache()
                print("Dataset {}, Model {}, Trial {}".format(dataset_name, model, t))
                data.split(args.semi_supervised, t)
                seed = seed_list[t]
                set_seed(seed)
                train_config['seed'] = seed
                detector = model_detector_dict[model](train_config, model_config, data)
                st = time.time()
                print(detector.model)
                test_score = detector.train()
                auc_list.append(test_score['AUROC'])
                pre_list.append(test_score['AUPRC'])
                rec_list.append(test_score['RecK'])
                time_cost += time.time() - st
            # Drop references before the next dataset is loaded.
            del detector, data

            summary = _summarize_trials(dataset_name, auc_list, pre_list, rec_list,
                                        time_cost, args.trials, aug_time,
                                        consistency_gen_time_record)
            model_result.update(summary)
            dataset_model_df = pd.DataFrame({'name': model, **summary}, index=[0])
            save_results(dataset_model_df, None, dataset_name=f"CM_{dataset_name}", model_name=model)

        model_result = pd.DataFrame(model_result, index=[0])
        results = pd.concat([results, model_result])
        print(results)

    print("\nExperiment completed. Results saved to results/ directory by dataset-model combination")


# Single entry point: the guard previously appeared twice, so running the
# script executed the entire benchmark two times in a row.
if __name__ == '__main__':
    main()

