#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Mean Consistency Benchmark - Heterogeneous Version

"""

import os
import sys
import argparse
import warnings
import pandas as pd
import time
import random


# Put ``<current working directory>/gocm`` at the front of the module search
# path so the imports below can resolve modules that live there.
sys.path.insert(0, os.path.join(os.getcwd(), 'gocm'))

# DGL selects its backend while it is being imported, so the environment
# variable has to be set before ``import dgl``; assigning it afterwards has no
# effect on the already-loaded library.
os.environ['DGLBACKEND'] = 'pytorch'

import torch
import dgl
import numpy as np

from utils import *
from gocm.gocm_mivae import GOCM_MIVAE

# Suppress every warning category for the rest of the run.
warnings.filterwarnings("ignore")


def set_seed(seed: int = 3407):
    """Seed the Python, NumPy and PyTorch random number generators.

    ``PYTHONHASHSEED`` is also written to the process environment. When CUDA
    is available the CUDA generators are seeded too and cuDNN is switched to
    deterministic mode.

    Args:
        seed: Value passed to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


def _mask_stats(graph, trial_id: int, prefix: str = ""):
    
    y = graph.ndata['label']
    train = graph.ndata['train_mask']
    val = graph.ndata['val_mask']
    test = graph.ndata['test_mask']
    def cnt(mask):
        m = mask.bool()
        return int((y[m] == 1).sum()), int((y[m] == 0).sum()), int(m.sum())
    tr_pos, tr_neg, tr_tot = cnt(train)
    va_pos, va_neg, va_tot = cnt(val)
    te_pos, te_neg, te_tot = cnt(test)
    print(f"{prefix}trial={trial_id}: "
          f"train pos/neg/total={tr_pos}/{tr_neg}/{tr_tot}, "
          f"val pos/neg/total={va_pos}/{va_neg}/{va_tot}, "
          f"test pos/neg/total={te_pos}/{te_neg}/{te_tot}")


def parse_args():
    """Define the benchmark's command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace: One attribute per option registered below.
    """
    parser = argparse.ArgumentParser(description='Mean Consistency Benchmark - Heterogeneous Version')
    add = parser.add_argument

    # --- Benchmark run control ---
    add('--trials', type=int, default=10)
    add('--semi_supervised', type=int, default=0)
    add('--inductive', type=int, default=0)
    add('--models', type=str, default=None)
    add('--datasets', type=str, default=None)
    add('--device', type=str, default='cpu')
    add('--gpu', type=int, default=0)

    # --- Seed selection ---
    add('--randomize_seeds', action='store_true')
    add('--seed_base', type=int, default=3407)
    add('--seed_step', type=int, default=10)
    add('--seed_offset', type=int, default=0)
    add('--seed_master', type=int, default=None)

    # --- GOCM / VGAE / MeanConsistency hyper-parameters ---
    add('--hid_dim', type=int, default=64, help='VGAE hidden dimension')
    add('--cons_dim', type=int, default=128, help='MeanConsistency model dimension')
    add('--vae_epochs', type=int, default=50, help='VGAE training epochs')
    add('--cons_epochs', type=int, default=50, help='MeanConsistency training epochs')
    add('--gocm_lr', type=float, default=1e-3, help='GOCM learning rate')
    add('--batch_size', type=int, default=2048,
        help='Batch size (not explicitly used in heterogeneous skeleton currently)')
    add('--sample_steps', type=int, default=1, help='MC sampling steps')
    add('--verbose', action='store_true', help='Verbose output')
    add('--reuse_ae', action='store_true', help='Reuse trained VGAE weights')
    add('--reuse_cm', action='store_true', help='Reuse trained MeanConsistency weights')
    add('--gen_ratio', type=float, default=1.0,
        help='Ratio to the number of anomaly nodes in training set (not explicitly used in heterogeneous skeleton currently)')
    add('--relink_ratio', type=float, default=0.0,
        help='Ratio of generated nodes relinking to original nodes')
    add('--cached_neg', action='store_true',
        help='Enable negative sampling cache (not explicitly used in heterogeneous skeleton currently)')

    # --- "mc_*" options, forwarded unchanged to the GOCM model ---
    add('--mc_T_type', type=str, default='baseline', choices=['baseline', 'saturated'])
    add('--mc_T_k', type=float, default=48.0)
    add('--mc_T_eps', type=float, default=0.002)
    add('--mc_W_type', type=str, default='constant1', choices=['constant1'])
    add('--mc_schedule', type=str, default='linear', choices=['linear', 'rho'])
    add('--mc_eta', type=float, default=0.0)
    add('--mc_s_min', type=float, default=0.002)
    add('--mc_step_clip', type=float, default=None)
    add('--mc_rho', type=float, default=7.0)
    add('--mc_heun', action='store_true')

    # --- Heterogeneous-graph options ---
    add('--target_ntype', type=str, default='user',
        help='Target node type (amazon=user / yelp=review)')
    add('--relations', type=str, default='all',
        help='Comma-separated subset of relation names; default all means automatically discovering all target->target relations')
    add('--fusion_strategy', type=str, default='concat_linear',
        choices=['concat_linear', 'attention', 'mean'],
        help='Cross-relation fusion strategy: concat_linear (paper method), attention, mean')
    add('--k_prototypes', type=int, default=3,
        help='Number of KNN neighbors for edge generation (default 3)')

    # --- Deprecated options, still accepted so old command lines keep parsing ---
    add('--combine', type=str, default='concat_nodes', choices=['concat_nodes','dedup_edges'],
        help='[Deprecated] Relation branch aggregation strategy in old architecture (not applicable in new architecture)')
    add('--dedup_undirected', action='store_true',
        help='[Deprecated] Edge deduplication option in old architecture (not applicable in new architecture)')
    add('--branch_gen_ratio', type=str, default=None,
        help='[Deprecated] Per-relation generation ratio in old architecture (use global gen_ratio in new architecture)')

    return parser.parse_args()


def _ensure_multi_trial_masks(g: dgl.DGLGraph, trials: int):
    """Guarantee that ``g.ndata`` holds ``train_masks``/``val_masks``/``test_masks``.

    If any of the three per-trial fields is missing, all three are (re)built
    by repeating the matching single-trial mask ``trials`` times along a new
    second dimension. A single-trial mask that is itself missing falls back to
    all-True for train and all-False for val and test.

    Args:
        g: Graph whose node data is inspected and possibly extended in place.
        trials: Number of columns to create in each per-trial mask.

    Returns:
        The same graph object ``g``.
    """
    n = g.number_of_nodes()
    device = g.device

    # (single-trial field, per-trial field, factory for the fallback mask)
    specs = (
        ('train_mask', 'train_masks', torch.ones),
        ('val_mask', 'val_masks', torch.zeros),
        ('test_mask', 'test_masks', torch.zeros),
    )

    if all(stacked in g.ndata for _, stacked, _ in specs):
        return g

    # Resolve every single-trial mask before writing anything back.
    base_masks = [
        g.ndata[single] if single in g.ndata
        else make_fallback(n, dtype=torch.bool, device=device)
        for single, _, make_fallback in specs
    ]
    for (_, stacked, _), mask in zip(specs, base_masks):
        g.ndata[stacked] = mask.unsqueeze(1).repeat(1, int(trials))
    return g


def _parse_relations_arg(relations: str):
    if relations is None or relations == 'all':
        return 'all'
    items = [r.strip() for r in str(relations).split(',') if len(r.strip()) > 0]
    return items if len(items) > 0 else 'all'


def _parse_branch_gen_ratio_arg(r: str):
    if r is None:
        return None
    try:
        vals = [float(x) for x in str(r).split(',') if len(x) > 0]
        return vals if len(vals) > 0 else None
    except Exception:
        return None


def _hetero_to_homo(g_hetero: dgl.DGLHeteroGraph, target_ntype: str) -> dgl.DGLGraph:
    """Collapse the target-to-target relations of ``g_hetero`` into one graph.

    Edges of every canonical edge type whose source and destination node type
    both equal ``target_ntype`` are concatenated (duplicates are kept) into a
    single graph with the same number of ``target_ntype`` nodes. All node data
    stored for ``target_ntype`` is copied over after a ``.cpu()`` call.

    Args:
        g_hetero: Source heterogeneous graph.
        target_ntype: Node type whose nodes and features are retained.

    Returns:
        The newly built graph; edge tensors and node data are on the CPU.
    """
    num_nodes = g_hetero.number_of_nodes(target_ntype)

    # One (src, dst) pair of edge endpoint tensors per qualifying relation.
    endpoint_pairs = [
        g_hetero.edges(etype=canonical)
        for canonical in g_hetero.canonical_etypes
        if canonical[0] == target_ntype and canonical[2] == target_ntype
    ]

    if endpoint_pairs:
        edge_src = torch.cat([pair[0].cpu() for pair in endpoint_pairs], dim=0)
        edge_dst = torch.cat([pair[1].cpu() for pair in endpoint_pairs], dim=0)
    else:
        # No target->target relation: build an edgeless graph.
        edge_src = torch.empty(0, dtype=torch.long)
        edge_dst = torch.empty(0, dtype=torch.long)

    g_homo = dgl.graph((edge_src, edge_dst), num_nodes=num_nodes)

    for feat_name, feat in g_hetero.nodes[target_ntype].data.items():
        g_homo.ndata[feat_name] = feat.cpu()

    return g_homo


def apply_gocm_mivae(dataset_obj, args, device_torch):
    """Replace ``dataset_obj.graph`` with its GOCM_MIVAE-augmented version.

    Side effects visible to the caller:
      * ``args.target_ntype`` is overwritten with the graph's first node type
        when the requested type is not among the graph's node types.
      * ``dataset_obj.graph`` is replaced by the augmented graph.
      * ``aug_time_<name>`` and ``mc_gen_time_<name>`` entries are written into
        ``dataset_obj.__dict__`` (``<name>`` being ``dataset_obj.name``).

    Args:
        dataset_obj: Object exposing ``graph`` and ``name``.
        args: Parsed command-line namespace.
        device_torch: Accepted for interface compatibility; not used here.

    Returns:
        The same ``dataset_obj``.
    """
    source_graph = dataset_obj.graph

    ntypes = list(source_graph.ntypes) if hasattr(source_graph, 'ntypes') else []
    if ntypes and args.target_ntype not in ntypes:
        print(f"Warning: Specified target_ntype='{args.target_ntype}' not in graph, automatically switching to '{ntypes[0]}'")
        args.target_ntype = ntypes[0]

    print(f"Applying MIVAE augmentation (Dataset: {dataset_obj.name}, Target Node Type: {args.target_ntype})...")
    aug_start_time = time.time()

    # The model takes an integer device id; -1 is passed whenever CUDA is not selected.
    use_cuda = getattr(args, 'device', None) == 'cuda'
    device_id_for_pygod = args.gpu if use_cuda else -1

    # Deprecated switches are still parsed; tell the user they have no effect.
    if getattr(args, 'combine', 'concat_nodes') != 'concat_nodes':
        print(f"  [Warning] --combine={args.combine} is no longer applicable in new architecture (joint training does not need merge strategy)")
    if getattr(args, 'dedup_undirected', False):
        print("  [Warning] --dedup_undirected is no longer applicable in new architecture")
    if getattr(args, 'branch_gen_ratio', None):
        print("  [Warning] --branch_gen_ratio is no longer applicable in new architecture (use global --gen_ratio)")

    if args.relink_ratio > 0:
        print(f"  [Relink] Relinking enabled: relink_ratio={args.relink_ratio}")

    model_kwargs = {
        'name': f"{dataset_obj.name}_mivae",
        'target_ntype': args.target_ntype,
        'relations': _parse_relations_arg(getattr(args, 'relations', 'all')),
        'hid_dim': args.hid_dim,
        'cons_dim': args.cons_dim,
        'vae_epochs': args.vae_epochs,
        'cons_epochs': args.cons_epochs,
        'lr': args.gocm_lr,
        'batch_size': args.batch_size,
        'sample_steps': args.sample_steps,
        'device': device_id_for_pygod,
        'verbose': args.verbose,
        'reuse_ae': args.reuse_ae,
        'reuse_cm': args.reuse_cm,
        'gen_ratio': args.gen_ratio,
        'relink_ratio': args.relink_ratio,
        'fusion_strategy': getattr(args, 'fusion_strategy', 'concat_linear'),
    }
    # The mc_* options share their names between the CLI and the constructor.
    for option in ('mc_T_type', 'mc_T_k', 'mc_T_eps', 'mc_W_type', 'mc_schedule',
                   'mc_eta', 'mc_s_min', 'mc_step_clip', 'mc_rho', 'mc_heun'):
        model_kwargs[option] = getattr(args, option)
    model_kwargs['k_prototypes'] = getattr(args, 'k_prototypes', 3)

    gocm = GOCM_MIVAE(**model_kwargs)
    augmented_hetero = gocm(source_graph)

    aug_time = time.time() - aug_start_time
    mc_gen_time = getattr(gocm, 'last_gen_time', 0.0)

    dataset_obj.graph = augmented_hetero
    dataset_obj.__dict__.update({
        f'aug_time_{dataset_obj.name}': aug_time,
        f'mc_gen_time_{dataset_obj.name}': mc_gen_time,
    })

    print(f"Data augmentation completed: Hetero Graph Nodes({args.target_ntype})={augmented_hetero.number_of_nodes(args.target_ntype)}")
    print(f"Augmentation Time: {aug_time:.2f}s, MC Inference Time: {mc_gen_time:.2f}s")
    return dataset_obj


def build_results_dataframe(selected_datasets):
    """Create an empty results table with one column group per dataset.

    The first column is ``name``; each dataset then contributes nine
    ``<dataset>-<metric>`` columns in a fixed metric order.

    Args:
        selected_datasets: Iterable of dataset name strings.

    Returns:
        pandas.DataFrame: A frame with the columns above and no rows.
    """
    metric_suffixes = (
        'AUROC mean', 'AUROC std',
        'AUPRC mean', 'AUPRC std',
        'RecK mean', 'RecK std',
        'Time', 'AugTime', 'MCGenTime',
    )
    header = ['name']
    header.extend(
        dataset + '-' + suffix
        for dataset in selected_datasets
        for suffix in metric_suffixes
    )
    return pd.DataFrame(columns=header)


def parse_datasets_arg(arg_value: str, base_list):
    """Select datasets from ``base_list`` by index.

    Accepted forms of ``arg_value``:
      * ``None``  - return ``base_list`` itself.
      * ``"a-b"`` - inclusive index range; requires ``0 <= a <= b <= max``.
      * ``"i,j"`` - explicit indices, kept in the given order.

    Args:
        arg_value: Raw ``--datasets`` value, or ``None``.
        base_list: Sequence of available dataset names.

    Returns:
        The selected dataset names as a list.

    Raises:
        SystemExit: With code 1 when the value cannot be parsed or an index is
            out of range; the valid choices are printed first.
    """
    if arg_value is None:
        return base_list

    max_idx = len(base_list) - 1

    try:
        if '-' in arg_value:
            st, ed = arg_value.split('-')
            start_idx, end_idx = int(st), int(ed)
            if start_idx < 0 or end_idx > max_idx:
                raise ValueError(f"Dataset index range {start_idx}-{end_idx} out of valid range [0, {max_idx}]")
            # A reversed range would slice to an empty list and make the
            # benchmark silently run on nothing, so reject it explicitly.
            if start_idx > end_idx:
                raise ValueError(f"Dataset index range {start_idx}-{end_idx} is reversed (start must not exceed end)")
            return base_list[start_idx:end_idx + 1]

        indices = [int(t) for t in arg_value.split(',')]
        for idx in indices:
            if idx < 0 or idx > max_idx:
                raise ValueError(f"Dataset index {idx} out of valid range [0, {max_idx}]")
        return [base_list[idx] for idx in indices]
    except (ValueError, IndexError) as e:
        print(f"\nError: Invalid dataset argument '{arg_value}'")
        print(f"Valid dataset index range: [0, {max_idx}]")
        print("Available dataset list:")
        for i, dataset in enumerate(base_list):
            print(f"  {i}: {dataset}")
        print(f"\nError details: {e}")
        raise SystemExit(1)


def _build_seed_list(args):
    """Return one integer seed per trial.

    With ``--randomize_seeds`` the seeds are drawn from ``random`` after it has
    been seeded with ``--seed_master`` (or the wall clock when that is unset).
    Otherwise they are ``seed_base + seed_offset + i * seed_step`` for ``i`` in
    ``range(trials)``.
    """
    trials = int(args.trials)
    if getattr(args, 'randomize_seeds', False):
        if args.seed_master is not None:
            random.seed(int(args.seed_master))
        else:
            random.seed(int(time.time() * 1000) % (2**31 - 1))
        return [random.randint(1, 2**31 - 1) for _ in range(trials)]

    seed_start = int(args.seed_base) + int(args.seed_offset)
    seed_step = int(args.seed_step)
    # Computed arithmetically rather than with range(start, stop, step):
    # range() raises ValueError for a step of 0, while reusing one seed for
    # every trial is a legitimate request.
    return [seed_start + seed_step * i for i in range(trials)]


def _summarize_trials(dataset_name, auc_list, pre_list, rec_list, avg_time, aug_time, mc_gen_time):
    """Build the ``<dataset>-<metric>`` result entries for one model/dataset pair."""
    return {
        dataset_name + '-AUROC mean': np.mean(auc_list),
        dataset_name + '-AUROC std': np.std(auc_list),
        dataset_name + '-AUPRC mean': np.mean(pre_list),
        dataset_name + '-AUPRC std': np.std(pre_list),
        dataset_name + '-RecK mean': np.mean(rec_list),
        dataset_name + '-RecK std': np.std(rec_list),
        dataset_name + '-Time': avg_time,
        dataset_name + '-AugTime': aug_time,
        dataset_name + '-MCGenTime': mc_gen_time,
    }


def main():
    """Run the benchmark: augment each dataset, then train every model on it.

    For each (model, dataset) pair the dataset is loaded, augmented through
    ``apply_gocm_mivae`` (falling back to the unaugmented data if that raises),
    and the detector is trained ``--trials`` times with per-trial seeds. The
    mean/std of AUROC, AUPRC and RecK plus timing figures are handed to
    ``save_results`` per pair and accumulated into a printed summary table.

    ``Dataset``, ``model_detector_dict`` and ``save_results`` are not defined
    in this file; they are presumably supplied by ``from utils import *``.

    Raises:
        SystemExit: With code 1 when ``--trials`` is smaller than 1, or from
            ``parse_datasets_arg`` on an invalid ``--datasets`` value.
    """
    args = parse_args()

    # The per-dataset averages divide by the trial count and the cleanup after
    # the trial loop needs at least one iteration to have run.
    if args.trials < 1:
        print(f"\nError: --trials must be at least 1 (got {args.trials})")
        raise SystemExit(1)

    if args.gpu == 1 and args.device == 'cpu':
        args.device = 'cuda'
    if args.device == 'cuda' and not torch.cuda.is_available():
        print("Warning: CUDA requested but not available. Automatically switching to CPU mode.")
        args.device = 'cpu'
    if args.device == 'cuda':
        print("GPU enabled, using CUDA mode")
    else:
        print("GPU disabled, using CPU mode")
    device_torch = torch.device(args.device if args.device == 'cpu' else f'cuda:{args.gpu}')

    base_datasets = ['hetero/amazon', 'hetero/yelp']
    print(f'\nAvailable heterogeneous datasets (Total {len(base_datasets)}):')
    for i, ds in enumerate(base_datasets):
        print(f'  [{i}] {ds}')

    datasets = parse_datasets_arg(args.datasets, base_datasets)
    print(f'\nDatasets for this evaluation: {datasets}')

    # --models accepts either '-' or ',' as the separator.
    models = model_detector_dict.keys() if args.models is None else (
        args.models.split('-') if '-' in args.models else args.models.split(',')
    )
    models = [m.strip() for m in models]
    print('Benchmark models: ', models)

    results = build_results_dataframe(datasets)

    seed_list = _build_seed_list(args)
    print(f"Seed mode -> randomize={bool(args.randomize_seeds)}, seeds(head)={seed_list[:min(5, len(seed_list))]}")

    for model in models:
        model_result = {'name': model}
        for dataset_name in datasets:
            if model in ['CAREGNN', 'H2FD'] and 'hetero' not in dataset_name:
                continue

            time_cost = 0.0
            train_config = {
                'device': args.device,
                'epochs': 250,
                'patience': 50,
                'metric': 'AUPRC',
                'inductive': bool(args.inductive)
            }

            data = Dataset(dataset_name)

            aug_time = 0.0
            mc_gen_time_record = 0.0
            try:
                aug_st = time.time()
                data = apply_gocm_mivae(data, args, device_torch)
                aug_time = time.time() - aug_st
                mc_gen_time_record = getattr(data, f'mc_gen_time_{dataset_name}', 0.0)
            except Exception as e:
                print(f"Error applying Hetero GOCM on {dataset_name}: {e}")
                import traceback
                traceback.print_exc()
                print("Continuing with original data...")
                aug_time = 0.0
            else:
                # Purely diagnostic: a failure here must not be reported as an
                # augmentation failure nor reset the recorded augmentation time.
                try:
                    _mask_stats(data.graph, trial_id=-1, prefix="[After Aug] ")
                except Exception as e:
                    print(f"Warning: could not print mask statistics after augmentation: {e}")

            model_config = {'model': model, 'lr': 0.01, 'drop_rate': 0}
            if dataset_name == 'tsocial':
                model_config['h_feats'] = 16

            auc_list, pre_list, rec_list = [], [], []
            for t in range(args.trials):
                if args.device == 'cuda':
                    torch.cuda.empty_cache()
                print("Dataset {}, Model {}, Trial {}".format(dataset_name, model, t))
                data.split(args.semi_supervised, t)
                _mask_stats(data.graph, trial_id=t, prefix="[Split] ")
                seed = seed_list[t]
                set_seed(seed)
                train_config['seed'] = seed
                detector = model_detector_dict[model](train_config, model_config, data)
                st = time.time()
                print(detector.model)
                test_score = detector.train()
                auc_list.append(test_score['AUROC'])
                pre_list.append(test_score['AUPRC'])
                rec_list.append(test_score['RecK'])
                ed = time.time()
                time_cost += ed - st
            # Drop the references before the next dataset is loaded.
            del detector, data

            summary = _summarize_trials(
                dataset_name, auc_list, pre_list, rec_list,
                time_cost / args.trials, aug_time, mc_gen_time_record,
            )
            model_result.update(summary)

            dataset_model_df = pd.DataFrame({'name': model, **summary}, index=[0])
            save_results(dataset_model_df, None, dataset_name=dataset_name, model_name=model)

        model_result = pd.DataFrame(model_result, index=[0])
        results = pd.concat([results, model_result])
        print(results)

    print("\nExperiment completed. Results saved to results/ directory by dataset-model combination")
    
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
