#!/usr/bin/env python3
import argparse
import yaml
import os
import sys
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import time
from pathlib import Path
from typing import Dict, Any

sys.path.append(str(Path(__file__).parent))
from core.utils import set_seed
from core.data import load_dataset
from core.models import create_model_from_config
from scripts.trainer import GraphDROTrainer
from core.perturbations import NodeCentricPerturbations
from core.metrics import compute_comprehensive_metrics, evaluate_robustness, accuracy, sp, eo


def stratified_split(labels, train_ratio=0.6, val_ratio=0.2, random_state=42):
    """Split node indices into train/val/test sets, one class at a time.

    For every distinct value in ``labels`` the matching indices are shuffled
    and cut into a training slice of ``int(train_ratio * count)`` entries, a
    validation slice of ``int(val_ratio * count)`` entries, and a test slice
    holding whatever remains.

    Note that the *global* NumPy and PyTorch random generators are reseeded
    with ``random_state`` as a side effect.

    Args:
        labels: 1-D tensor of class labels.
        train_ratio: Fraction of each class assigned to the training split.
        val_ratio: Fraction of each class assigned to the validation split.
        random_state: Seed applied to both global generators.

    Returns:
        Tuple ``(train_idx, val_idx, test_idx)`` of ``torch.long`` tensors.
    """
    np.random.seed(random_state)
    torch.manual_seed(random_state)

    # One growing Python list per split, in (train, val, test) order.
    buckets = ([], [], [])
    for cls in torch.unique(labels):
        members = torch.where(labels == cls)[0]
        count = len(members)
        train_end = int(train_ratio * count)
        val_end = train_end + int(val_ratio * count)
        shuffled = members[torch.randperm(count)]
        chunks = (shuffled[:train_end], shuffled[train_end:val_end], shuffled[val_end:])
        for bucket, chunk in zip(buckets, chunks):
            bucket.extend(chunk.tolist())

    return tuple(torch.tensor(bucket, dtype=torch.long) for bucket in buckets)


def show_supported_datasets(config_path: str):
    """Print every dataset entry declared under ``datasets`` in a YAML config.

    For each entry the path, prediction attribute, sensitive attribute,
    sensitive attribute index and label count are printed. An empty config
    file, or one without a ``datasets`` section, is reported as having no
    dataset configuration instead of raising.

    Args:
        config_path: Path to the YAML configuration file.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        # safe_load yields None for an empty document; normalise to a dict so
        # the membership test below cannot fail with a TypeError.
        config = yaml.safe_load(f) or {}

    if 'datasets' not in config:
        print(" No dataset configuration found in the config file")
        return

    print(" Supported datasets:")
    print("=" * 60)
    # An empty "datasets:" key parses to None; treat it as "no entries".
    for name, dataset_config in (config['datasets'] or {}).items():
        print(f"  {name}:")
        print(f"   Path: {dataset_config['path']}")
        print(f"   Predict attribute: {dataset_config['predict_attr']}")
        print(f"   Sensitive attribute: {dataset_config['sens_attr']}")
        print(f"   Sensitive attribute index: {dataset_config['sens_attr_idx']}")
        print(f"   Number of labels: {dataset_config['label_number']}")
        print()


def load_config(config_path: str, dataset_name: str = None, model_name: str = None) -> Dict[str, Any]:
    """Load the YAML experiment config and apply optional dataset/model overrides.

    Args:
        config_path: Path to the YAML configuration file.
        dataset_name: Optional key into the config's ``datasets`` section.
            When found, ``config['dataset']`` and ``config['data']`` are set
            from it and, if a ``logging`` section exists, its ``save_path``
            is pointed at ``runs/<dataset>_<model>/``. When not found, a
            warning is printed and the config is left as loaded.
        model_name: Optional model name. ``config['model']['name']`` is set to
            its lower-cased form; for GCN, GraphSAGE and GIN a fixed set of
            hyper-parameters is written over the config's model section.

    Returns:
        The (possibly modified) configuration dictionary.

    Raises:
        ValueError: If the YAML file parses to ``None`` (e.g. it is empty).
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    if config is None:
        raise ValueError(f"Failed to load config file {config_path}, please check the YAML format")

    if dataset_name:
        available = config.get('datasets') or {}
        if dataset_name in available:
            print(f" Using dataset: {dataset_name}")
            config['dataset'] = dataset_name
            config['data'] = available[dataset_name]
            if 'logging' in config:
                config['logging']['save_path'] = f"runs/{dataset_name}_{model_name if model_name else 'graphdro'}/"
            print(" Dataset configuration updated:")
            print(f"  Path: {config['data']['path']}")
            print(f"  Predict attribute: {config['data']['predict_attr']}")
            print(f"  Sensitive attribute: {config['data']['sens_attr']}")
            print(f"  Sensitive attribute index: {config['data']['sens_attr_idx']}")
        else:
            # Previously this case was skipped without any message, so a
            # mistyped dataset name ran with whatever 'data' the file held.
            print(f" Warning: dataset '{dataset_name}' not found in the config 'datasets' section; "
                  f"dataset settings left unchanged")

    if model_name:
        # Hyper-parameter presets keyed by the upper-cased model name.
        presets = {
            'GCN': {
                'hidden_dim': 32,
                'dropout': 0.3,
                'num_layers': 2,
                'activation': 'relu'
            },
            'GRAPHSAGE': {
                'name': 'sage',
                'hidden_dim': 64,
                'dropout': 0.5,
                'num_layers': 2,
                'activation': 'relu'
            },
            'GIN': {
                'hidden_dim': 64,
                'dropout': 0.5,
                'num_layers': 3,
                'activation': 'relu',
                'eps': 0.0
            },
        }
        print(f" Using model: {model_name}")
        # Create the section when the file has none so the override still applies.
        model_cfg = config.setdefault('model', {})
        model_cfg['name'] = model_name.lower()
        model_cfg.update(presets.get(model_name.upper(), {}))
        # .get keeps this purely informational output from raising when a
        # model without a preset is combined with a sparse model section.
        print(" Model configuration updated:")
        print(f"  Model type: {model_cfg.get('name')}")
        print(f"  Hidden dimension: {model_cfg.get('hidden_dim')}")
        print(f"  Number of layers: {model_cfg.get('num_layers')}")
        print(f"  Dropout: {model_cfg.get('dropout')}")
    return config


def compute_fairness_metrics(y_true, y_pred, s):
    """Compute accuracy plus statistical-parity and equal-opportunity gaps.

    Args:
        y_true: Tensor of ground-truth labels (positives are ``1``).
        y_pred: Integer tensor of predicted labels; also printed as a histogram.
        s: Tensor of sensitive-attribute values; groups are ``s == 0`` and ``s == 1``.

    Returns:
        Dict with keys ``accuracy``, ``sp_gap``, ``eo_gap``, ``sp_0``, ``sp_1``,
        ``eo_0`` and ``eo_1``. ``sp_*`` is the mean prediction within each
        sensitive group, ``eo_*`` the mean prediction among that group's
        true positives. When either group needed for a gap is empty, that gap
        and both of its per-group rates are reported as ``0.0``.
    """
    pred_counts = torch.bincount(y_pred)
    print(f"Prediction distribution: {pred_counts.tolist()}")
    accuracy = (y_true == y_pred).float().mean().item()

    def gap_between(mask_0, mask_1):
        # The mean of an empty selection is NaN, which would poison every
        # downstream average; report zeros instead when a group is empty.
        if mask_0.sum() > 0 and mask_1.sum() > 0:
            rate_0 = y_pred[mask_0].float().mean().item()
            rate_1 = y_pred[mask_1].float().mean().item()
            return rate_0, rate_1, abs(rate_0 - rate_1)
        return 0.0, 0.0, 0.0

    group_0 = s == 0
    group_1 = s == 1
    positives = y_true == 1
    sp_0, sp_1, sp_gap = gap_between(group_0, group_1)
    eo_0, eo_1, eo_gap = gap_between(group_0 & positives, group_1 & positives)

    return {
        'accuracy': accuracy,
        'sp_gap': sp_gap,
        'eo_gap': eo_gap,
        'sp_0': sp_0,
        'sp_1': sp_1,
        'eo_0': eo_0,
        'eo_1': eo_1
    }


def apply_perturbations(data, eps_e=0.1, eps_x=0.01, eps_l=0.1, gamma=0.2):
    """Return a randomly perturbed clone of ``data``; the input is left untouched.

    Four independent perturbations are applied, each skipped when its
    strength is not positive:

    * ``eps_e`` - fraction of edge columns removed from ``edge_index``.
    * ``eps_x`` - standard deviation of Gaussian noise added to ``x``; the
      noisy features are then clamped to ``[0, 1]``.
    * ``eps_l`` - fraction of entries of ``y`` replaced by ``1 - y``.
    * ``gamma`` - fraction of entries of ``s`` replaced by ``1 - s``.

    A one-line summary of every applied perturbation is printed.
    """
    perturbed_data = data.clone()

    def flip_fraction(values, fraction):
        # Pick int(fraction * n) random positions and replace v with 1 - v.
        total = values.size(0)
        count = int(fraction * total)
        chosen = torch.randperm(total)[:count]
        flipped = values.clone()
        flipped[chosen] = 1 - flipped[chosen]
        return flipped, count, total

    if eps_e > 0:
        total_edges = data.edge_index.size(1)
        dropped = int(eps_e * total_edges)
        # Keep a random subset of columns; the kept edges come out shuffled.
        kept = torch.randperm(total_edges)[:total_edges - dropped]
        perturbed_data.edge_index = data.edge_index[:, kept]
        print(f"    Edge perturbation: Removed {dropped}/{total_edges} edges ({eps_e*100:.1f}%)")

    if eps_x > 0:
        noisy = data.x + torch.randn_like(data.x) * eps_x
        perturbed_data.x = torch.clamp(noisy, 0, 1)
        # Mean per-row L2 distance, measured after clamping.
        shift = torch.norm(perturbed_data.x - data.x, p=2, dim=1).mean().item()
        print(f"    Feature perturbation magnitude: {shift:.4f}")

    if eps_l > 0:
        perturbed_data.y, flipped_count, node_count = flip_fraction(data.y, eps_l)
        print(f"    Label perturbation: Flipped {flipped_count}/{node_count} labels ({eps_l*100:.1f}%)")

    if gamma > 0:
        perturbed_data.s, flipped_count, node_count = flip_fraction(data.s, gamma)
        print(f"    Sensitive attribute perturbation: Flipped {flipped_count}/{node_count} attributes ({gamma*100:.1f}%)")

    return perturbed_data


def train_baseline_model(config: Dict[str, Any], data, train_idx, val_idx, test_idx, num_epochs: int = 60, use_perturbed_data: bool = True) -> tuple:
    """Train a class-weighted baseline classifier on the unperturbed graph.

    The model comes from ``create_model_from_config`` with two output
    classes and is optimised with Adam. Training is full-batch: each epoch
    runs one forward pass over the whole graph and back-propagates the
    weighted cross-entropy of the training nodes, with gradient-norm clipping
    at 1.0. Every 10 epochs the validation split is scored.

    Early stopping counts *validation checks* (one per 10 epochs) without an
    accuracy improvement, not epochs. The weights at the end of training are
    the ones evaluated; no best checkpoint is restored.

    Args:
        config: Experiment configuration; ``config['training']['device']`` and
            ``config['optimizer']['lr']`` are read here.
        data: Graph object exposing ``x``, ``y``, ``s``, ``edge_index`` and ``to()``.
        train_idx: Node indices used for the loss.
        val_idx: Node indices used for periodic validation.
        test_idx: Node indices used for the final metrics.
        num_epochs: Maximum number of training epochs.
        use_perturbed_data: Accepted for call compatibility but unused;
            training always runs on the original data.

    Returns:
        ``(model, test_metrics)`` where ``test_metrics`` is the dictionary
        returned by ``compute_fairness_metrics`` for the test split.
    """
    print("=" * 60)
    print("Training baseline model")
    print("=" * 60)
    device = torch.device(config['training']['device'])
    data = data.to(device)
    print("Using original data for training (to avoid backpropagation issues)")
    train_data = data
    print(f" Dataset information:")
    print(f"  Label distribution: {torch.bincount(data.y.long())}")
    print(f"  Sensitive attribute distribution: {torch.bincount(data.s.long())}")
    print(f"  Training set size: {len(train_idx)}")
    print(f"  Validation set size: {len(val_idx)}")
    print(f"  Test set size: {len(test_idx)}")
    actual_feature_dim = data.x.size(1)
    print(f" Actual feature dimension: {actual_feature_dim}")
    try:
        model = create_model_from_config(config, actual_feature_dim, 2).to(device)
        print(f" Model created successfully: {type(model).__name__}")
    except Exception as e:
        # Log the failure with a traceback, then let it propagate unchanged.
        print(f" Model creation failed: {e}")
        import traceback
        traceback.print_exc()
        raise
    optimizer = optim.Adam(model.parameters(), lr=config['optimizer']['lr'])

    # Class-index targets must be integer typed for CrossEntropyLoss; cast
    # once here, matching the .long() already used for the class counts.
    train_labels = data.y[train_idx].long()
    # minlength=2 keeps the weight vector aligned with the two model outputs
    # even if the training split happens to contain a single class.
    class_counts = torch.bincount(train_labels, minlength=2)
    total_samples = class_counts.sum()
    # clamp avoids an infinite weight for a class absent from the split.
    class_weights = total_samples / (len(class_counts) * class_counts.clamp(min=1).float())
    print(f"  Class distribution: {class_counts.tolist()}")
    print(f"  Class weights: {class_weights.tolist()}")
    criterion = nn.CrossEntropyLoss(weight=class_weights.to(device), reduction='mean')

    best_val_acc = 0
    patience = 10
    patience_counter = 0
    model.train()
    start_time = time.time()
    for epoch in range(num_epochs):
        optimizer.zero_grad()
        out = model(train_data.edge_index, train_data.x)
        loss = criterion(out[train_idx], train_labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        if (epoch + 1) % 10 == 0:
            model.eval()
            with torch.no_grad():
                val_out = model(data.edge_index, data.x)
                val_pred = val_out[val_idx].argmax(dim=1)
                val_metrics = compute_fairness_metrics(data.y[val_idx], val_pred, data.s[val_idx])
                val_pred_counts = torch.bincount(val_pred, minlength=2)
                val_true_counts = torch.bincount(data.y[val_idx].long(), minlength=2)
                print(f"Epoch {epoch+1:3d}: Loss={loss.item():.4f}, Val_Acc={val_metrics['accuracy']:.4f}")
                print(f"  Validation true distribution: {val_true_counts.tolist()}")
                print(f"  Validation prediction distribution: {val_pred_counts.tolist()}")
                if val_metrics['accuracy'] > best_val_acc:
                    best_val_acc = val_metrics['accuracy']
                    patience_counter = 0
                else:
                    patience_counter += 1
                if patience_counter >= patience:
                    print(f" Early stopping triggered at epoch {epoch+1}")
                    break
            model.train()
    training_time = time.time() - start_time
    print(f" Baseline model training completed! Time taken: {training_time:.2f} seconds")

    model.eval()
    with torch.no_grad():
        test_out = model(data.edge_index, data.x)
        test_pred = test_out[test_idx].argmax(dim=1)
        test_metrics = compute_fairness_metrics(data.y[test_idx], test_pred, data.s[test_idx])
    return model, test_metrics


def train_graphdro_model(config: Dict[str, Any], data, train_idx, val_idx, test_idx, num_epochs: int = 60, use_perturbed_data: bool = True) -> tuple:
    """Train a model with ``GraphDROTrainer``, falling back to simplified training.

    Boolean train/val/test masks are attached to ``data`` (in place),
    ``config['training']['num_epochs']`` is overwritten with ``num_epochs``,
    and ``GraphDROTrainer(config).train(model, data)`` is run under a
    3000-second watchdog where the platform supports ``SIGALRM``. If the
    trainer times out or raises any ``Exception``, training is redone with
    ``train_simplified_graphdro``.

    Args:
        config: Experiment configuration passed to the trainer and model factory.
        data: Graph object exposing ``x``, ``y``, ``s``; receives the mask attributes.
        train_idx: Node indices of the training split.
        val_idx: Node indices of the validation split.
        test_idx: Node indices of the test split.
        num_epochs: Number of epochs written into the training config.
        use_perturbed_data: Accepted for call compatibility but unused here.

    Returns:
        ``(model, training_results)`` from the trainer, or whatever
        ``train_simplified_graphdro`` returns when the fallback is taken.
    """
    print("=" * 60)
    print("Training GraphDRO model")
    print("=" * 60)
    print(" GraphDRO uses mixed empirical distribution for training (original + perturbed data)")
    print("   - Original data weight: η")
    print("   - K perturbed graphs weight: (1-η)/K")
    print("   - Perturbation types: topology noise + feature noise + label noise + sensitive attribute noise")
    print(f"Dataset information:")
    print(f"  Label distribution: {torch.bincount(data.y.long())}")
    print(f"  Sensitive attribute distribution: {torch.bincount(data.s.long())}")
    print(f"  Training set size: {len(train_idx)}")
    print(f"  Validation set size: {len(val_idx)}")
    print(f"  Test set size: {len(test_idx)}")

    # Attach boolean split masks to the graph object handed to the trainer.
    n = data.x.size(0)
    for mask_name, idx in (('train_mask', train_idx), ('val_mask', val_idx), ('test_mask', test_idx)):
        mask = torch.zeros(n, dtype=torch.bool)
        mask[idx] = True
        setattr(data, mask_name, mask)

    config['training']['num_epochs'] = num_epochs
    print(" Creating GraphDRO trainer...")
    trainer = GraphDROTrainer(config)
    print(" Creating model...")
    actual_feature_dim = data.x.size(1)
    print(f" Actual feature dimension: {actual_feature_dim}")
    model = create_model_from_config(config, actual_feature_dim, 2)
    print(" Starting GraphDRO training...")
    print("  Note: GraphDRO training might be slow, please be patient...")
    import signal

    def timeout_handler(signum, frame):
        raise TimeoutError("GraphDRO training timeout")

    # SIGALRM / alarm() exist only on Unix; elsewhere train without a watchdog
    # rather than failing with AttributeError before training even starts.
    use_alarm = hasattr(signal, 'SIGALRM')
    previous_handler = None
    if use_alarm:
        previous_handler = signal.signal(signal.SIGALRM, timeout_handler)
        signal.alarm(3000)

    use_fallback = False
    training_results = None
    try:
        start_time = time.time()
        training_results = trainer.train(model, data)
        training_time = time.time() - start_time
        print(f" GraphDRO training completed! Time taken: {training_time:.2f} seconds")
    except TimeoutError:
        print(" GraphDRO training timeout, using simplified training...")
        use_fallback = True
    except Exception as e:
        # Deliberate best-effort: any trainer failure switches to the fallback.
        print(f" GraphDRO training error: {e}")
        use_fallback = True
    finally:
        # Always disarm the watchdog (also on KeyboardInterrupt) and give the
        # signal back to whoever owned it before this call.
        if use_alarm:
            signal.alarm(0)
            if previous_handler is not None:
                signal.signal(signal.SIGALRM, previous_handler)

    # The fallback starts only after the alarm is cancelled, so the watchdog
    # can never fire in the middle of the simplified training.
    if use_fallback:
        return train_simplified_graphdro(config, data, train_idx, val_idx, test_idx, num_epochs)
    return model, training_results


def train_simplified_graphdro(config: Dict[str, Any], data, train_idx, val_idx, test_idx, num_epochs: int = 60):
    """Fallback trainer: class-weighted training with occasional random perturbation.

    A fresh model is created and trained full-batch with Adam. In each epoch,
    with probability 0.3, the forward pass uses a perturbed copy of the graph
    (10% of edges removed, feature noise with ``eps_x=0.05``, no label or
    sensitive-attribute flips); otherwise the original graph is used. The
    loss is always computed against the original labels of the training nodes.
    Validation accuracy is printed every 10 epochs.

    Args:
        config: Experiment configuration; ``config['training']['device']`` and
            ``config['optimizer']['lr']`` are read here.
        data: Graph object exposing ``x``, ``y``, ``s``, ``edge_index`` and ``to()``.
        train_idx: Node indices used for the loss.
        val_idx: Node indices used for periodic validation.
        test_idx: Node indices used for the final metrics.
        num_epochs: Number of training epochs.

    Returns:
        ``(model, {'final_metrics': test_metrics})`` where ``test_metrics``
        comes from ``compute_fairness_metrics`` on the test split.
    """
    print(" Using simplified GraphDRO training...")
    device = torch.device(config['training']['device'])
    data = data.to(device)
    model = create_model_from_config(config, data.x.size(1), 2).to(device)
    optimizer = optim.Adam(model.parameters(), lr=config['optimizer']['lr'])

    # Integer class-index targets are required by CrossEntropyLoss; cast once,
    # matching the .long() already applied when counting classes.
    train_labels = data.y[train_idx].long()
    class_counts = torch.bincount(train_labels)
    total_samples = class_counts.sum()
    class_weights = total_samples / (len(class_counts) * class_counts.float())
    # These two lines were missing the f-prefix and printed the braces literally.
    print(f"  Class distribution: {class_counts.tolist()}")
    print(f"  Class weights: {class_weights.tolist()}")
    criterion = nn.CrossEntropyLoss(weight=class_weights.to(device))

    print(" Using simplified perturbation training strategy...")
    model.train()
    start_time = time.time()
    for epoch in range(num_epochs):
        optimizer.zero_grad()
        if torch.rand(1).item() < 0.3:
            perturbed_data = apply_perturbations(data, eps_e=0.1, eps_x=0.05, eps_l=0.0, gamma=0.0)
            x_full = perturbed_data.x
            edge_index = perturbed_data.edge_index
        else:
            x_full = data.x
            edge_index = data.edge_index
        out = model(edge_index, x_full)
        loss = criterion(out[train_idx], train_labels)
        loss.backward()
        optimizer.step()
        if (epoch + 1) % 10 == 0:
            model.eval()
            with torch.no_grad():
                val_out = model(data.edge_index, data.x)
                val_pred = val_out[val_idx].argmax(dim=1)
                val_metrics = compute_fairness_metrics(data.y[val_idx], val_pred, data.s[val_idx])
                print(f"Epoch {epoch+1:3d}: Loss={loss.item():.4f}, Val_Acc={val_metrics['accuracy']:.4f}")
            model.train()
    training_time = time.time() - start_time
    print(f" Simplified GraphDRO training completed! Time taken: {training_time:.2f} seconds")

    model.eval()
    with torch.no_grad():
        test_out = model(data.edge_index, data.x)
        test_pred = test_out[test_idx].argmax(dim=1)
        test_metrics = compute_fairness_metrics(data.y[test_idx], test_pred, data.s[test_idx])
    return model, {'final_metrics': test_metrics}


def evaluate_robustness(model, data, test_idx, eps_e=0.5, eps_x=0.15, eps_l=0.1, gamma=0.2):
    """Score ``model`` on the test split with clean and with perturbed inputs.

    The clean pass predicts from the original graph. The attack pass predicts
    from a copy produced by ``apply_perturbations``. Both sets of predictions
    are scored against the *original* labels and sensitive attributes, so
    ``eps_l`` and ``gamma`` only influence the printed perturbation ratios,
    not the returned metrics.

    Args:
        model: Trained module called as ``model(edge_index, x)``.
        data: Graph object exposing ``x``, ``y``, ``s``, ``edge_index`` and ``to()``.
        test_idx: Node indices to evaluate on.
        eps_e: Edge-removal fraction for the attack.
        eps_x: Feature-noise scale for the attack.
        eps_l: Label-flip fraction for the attack.
        gamma: Sensitive-attribute flip fraction for the attack.

    Returns:
        ``(clean_metrics, attack_metrics)``, each a dictionary returned by
        ``compute_fairness_metrics``.
    """
    model.eval()
    device = next(model.parameters()).device
    # Evaluate on the model's own device; a model moved by its trainer would
    # otherwise be fed tensors living somewhere else.
    data = data.to(device)
    with torch.no_grad():
        test_labels = data.y[test_idx]
        test_sens = data.s[test_idx]
        clean_total = len(test_idx)
        # Only used for the printed ratios; avoids dividing by zero when the
        # test split is empty.
        denom = max(clean_total, 1)

        print(" Clean Acc evaluation: using original data for prediction")
        clean_out = model(data.edge_index, data.x)
        clean_pred = clean_out[test_idx].argmax(dim=1)
        clean_metrics = compute_fairness_metrics(test_labels, clean_pred, test_sens)
        clean_correct = (clean_pred == test_labels).sum().item()
        print(f"  Clean prediction: correct={clean_correct}/{clean_total} = {clean_correct/denom:.4f}")
        # bincount needs integer input; cast as the rest of the file does.
        print(f"  True label distribution: {torch.bincount(test_labels.long(), minlength=2).tolist()}")
        print(f"  Prediction distribution: {torch.bincount(clean_pred, minlength=2).tolist()}")

        print(f" Attack Acc evaluation: using perturbed data for prediction")
        print(f"  Perturbation parameters: eps_e={eps_e}, eps_x={eps_x}, eps_l={eps_l}, gamma={gamma}")
        perturbed_data = apply_perturbations(data, eps_e, eps_x, eps_l, gamma)
        edge_perturbation_ratio = 1.0 - (perturbed_data.edge_index.size(1) / data.edge_index.size(1))
        feature_perturbation_magnitude = torch.norm(perturbed_data.x - data.x, p=2, dim=1).mean().item()
        label_perturbation_ratio = 0.0
        if eps_l > 0:
            label_changes = (perturbed_data.y != data.y).sum().item()
            label_perturbation_ratio = label_changes / data.y.size(0)
        sens_perturbation_ratio = 0.0
        if gamma > 0:
            sens_changes = (perturbed_data.s != data.s).sum().item()
            sens_perturbation_ratio = sens_changes / data.s.size(0)
        print(f"  Edge perturbation ratio: {edge_perturbation_ratio:.4f}")
        print(f"  Feature perturbation magnitude: {feature_perturbation_magnitude:.4f}")
        print(f"  Label perturbation ratio: {label_perturbation_ratio:.4f}")
        print(f"  Sensitive attribute perturbation ratio: {sens_perturbation_ratio:.4f}")

        attack_out = model(perturbed_data.edge_index, perturbed_data.x)
        attack_pred = attack_out[test_idx].argmax(dim=1)
        attack_correct = (attack_pred == test_labels).sum().item()
        attack_total = clean_total
        print(f"  Attack prediction: correct={attack_correct}/{attack_total} = {attack_correct/denom:.4f}")
        print(f"  Prediction distribution: {torch.bincount(attack_pred, minlength=2).tolist()}")
        prediction_changes = (clean_pred != attack_pred).sum().item()
        print(f"  Prediction changes: {prediction_changes}/{clean_total} = {prediction_changes/denom:.4f}")
        # Scored against the original labels and sensitive attributes.
        attack_metrics = compute_fairness_metrics(test_labels, attack_pred, test_sens)
    return clean_metrics, attack_metrics


def comparison_experiment(config: Dict[str, Any], num_runs: int = 3, num_epochs: int = 60,
                         attack_params: Dict[str, float] = None) -> Dict[str, Any]:
    """Train baseline and GraphDRO models repeatedly and compare their robustness.

    The dataset is loaded and split once (stratified 60/20/20, seeded with the
    config seed) and reused for every run; run ``i`` is seeded with
    ``config seed + i``. In each run both models are trained, optionally
    saved, and evaluated with ``evaluate_robustness`` on clean and perturbed
    inputs. Afterwards mean and standard deviation of accuracy, SP gap and
    EO gap are computed over the runs and a comparison report is printed.

    Args:
        config: Experiment configuration. Besides what the training helpers
            read, ``seed`` (default 42), ``save_models`` (default False) and
            ``model_save_dir`` (default ``'saved_models'``) are used here.
        num_runs: Number of independent training runs.
        num_epochs: Epoch budget handed to both training functions.
        attack_params: Perturbation strengths with keys ``eps_e``, ``eps_x``,
            ``eps_l`` and ``gamma``; defaults to 0.5 / 0.15 / 0.1 / 0.2.

    Returns:
        Dict with ``baseline`` and ``graphdro`` entries (each holding
        ``stats`` and ``individual_results``), a ``comparison`` entry with
        the accuracy and robustness differences, plus ``num_runs`` and
        ``num_epochs``.
    """
    print(" Starting comparison experiment...")
    print(f" Will run {num_runs} independent experiments, each with {num_epochs} epochs")
    if attack_params is None:
        attack_params = {
            'eps_e': 0.5,
            'eps_x': 0.15,
            'eps_l': 0.1,
            'gamma': 0.2
        }
    print(f"Attack parameters: {attack_params}")
    baseline_results = []
    graphdro_results = []
    cached_indices = None
    cached_data = None
    for run in range(num_runs):
        print(f"\n{'='*60}")
        print(f"Starting run {run+1}/{num_runs}")
        print(f"{'='*60}")
        current_seed = config.get('seed', 42) + run
        set_seed(current_seed)
        print(f"Set random seed: {current_seed}")
        if run == 0:
            # Load and split once; later runs reuse the same graph and split.
            print("Loading dataset...")
            data, in_dim, out_dim = load_dataset(config)
            labels = data.y.long()
            train_idx, val_idx, test_idx = stratified_split(
                labels,
                train_ratio=0.6,
                val_ratio=0.2,
                random_state=config.get('seed', 42)
            )
            cached_indices = (train_idx, val_idx, test_idx)
            cached_data = data
        else:
            data = cached_data
            train_idx, val_idx, test_idx = cached_indices
        print(f"Data split: train={len(train_idx)}, val={len(val_idx)}, test={len(test_idx)}")
        print()
        baseline_model, baseline_metrics = train_baseline_model(config, data, train_idx, val_idx, test_idx, num_epochs, use_perturbed_data=True)
        graphdro_model, graphdro_results_single = train_graphdro_model(config, data, train_idx, val_idx, test_idx, num_epochs, use_perturbed_data=True)
        # config is a dict, so the flag must be looked up as a key; the former
        # hasattr() test was always False and the models were never saved.
        if config.get('save_models', False):
            save_dir = config.get('model_save_dir', 'saved_models')
            os.makedirs(save_dir, exist_ok=True)
            baseline_model_path = os.path.join(save_dir, f'baseline_model_run_{run+1}.pth')
            graphdro_model_path = os.path.join(save_dir, f'graphdro_model_run_{run+1}.pth')
            torch.save(baseline_model.state_dict(), baseline_model_path)
            torch.save(graphdro_model.state_dict(), graphdro_model_path)
        print(" Evaluating model robustness...")
        baseline_clean, baseline_attack = evaluate_robustness(
            baseline_model, data, test_idx,
            eps_e=attack_params['eps_e'],
            eps_x=attack_params['eps_x'],
            eps_l=attack_params['eps_l'],
            gamma=attack_params['gamma']
        )
        graphdro_clean, graphdro_attack = evaluate_robustness(
            graphdro_model, data, test_idx,
            eps_e=attack_params['eps_e'],
            eps_x=attack_params['eps_x'],
            eps_l=attack_params['eps_l'],
            gamma=attack_params['gamma']
        )
        baseline_results.append({
            'clean': baseline_clean,
            'attack': baseline_attack
        })
        graphdro_results.append({
            'clean': graphdro_clean,
            'attack': graphdro_attack
        })
        print(f" Run {run+1} completed")
    print(f"\n{'='*60}")
    print(" Calculating average results from multiple runs")
    print(f"{'='*60}")

    def summarize(results):
        # Mean and std over runs of each reported metric, per evaluation phase.
        summary = {}
        for phase in ('clean', 'attack'):
            summary[phase] = {}
            for metric in ('accuracy', 'sp_gap', 'eo_gap'):
                values = [r[phase][metric] for r in results]
                summary[phase][metric] = {'mean': np.mean(values), 'std': np.std(values)}
        return summary

    def as_percent(stat):
        # "mean±std" rendered as percentages with two decimals.
        return f"{stat['mean']*100:.2f}±{stat['std']*100:.2f}"

    def as_gap(stat):
        # "mean±std" rendered as raw values with four decimals.
        return f"{stat['mean']:.4f}±{stat['std']:.4f}"

    baseline_stats = summarize(baseline_results)
    graphdro_stats = summarize(graphdro_results)

    b_clean_acc = baseline_stats['clean']['accuracy']['mean']
    b_attack_acc = baseline_stats['attack']['accuracy']['mean']
    g_clean_acc = graphdro_stats['clean']['accuracy']['mean']
    g_attack_acc = graphdro_stats['attack']['accuracy']['mean']
    b_sp = baseline_stats['clean']['sp_gap']['mean']
    g_sp = graphdro_stats['clean']['sp_gap']['mean']
    b_eo = baseline_stats['clean']['eo_gap']['mean']
    g_eo = graphdro_stats['clean']['eo_gap']['mean']

    print("=" * 60)
    print("Summary of multiple runs (mean ± std)")
    print("=" * 60)
    print(" Experimental design:")
    print("  - Training: Baseline uses original data, GraphDRO uses mixed distribution (original+perturbed)")
    print("  - Clean Acc: Prediction on original data")
    print("  - Attack Acc: Prediction on perturbed data")
    print("  - ΔSP and ΔEO: Calculated from predictions on original data")
    print(f"  - Perturbation types: Edge removal(eps_e={attack_params['eps_e']}) + Feature noise(eps_x={attack_params['eps_x']}) + Label perturbation(eps_l={attack_params['eps_l']}) + Sensitive attribute perturbation(gamma={attack_params['gamma']})")
    print()
    # Accuracy lost under attack; smaller means more robust.
    baseline_delta_acc = b_clean_acc - b_attack_acc
    graphdro_delta_acc = g_clean_acc - g_attack_acc
    print(" Core metrics comparison:")
    print(f"{'Metric':<15} {'Baseline':<20} {'GraphDRO':<20} {'Improvement':<10}")
    print("-" * 70)
    baseline_clean_str = as_percent(baseline_stats['clean']['accuracy'])
    graphdro_clean_str = as_percent(graphdro_stats['clean']['accuracy'])
    clean_diff = (g_clean_acc - b_clean_acc)*100
    print(f"{'Clean Acc(%)':<15} {baseline_clean_str:<20} {graphdro_clean_str:<20} {clean_diff:+.2f}")
    baseline_attack_str = as_percent(baseline_stats['attack']['accuracy'])
    graphdro_attack_str = as_percent(graphdro_stats['attack']['accuracy'])
    attack_diff = (g_attack_acc - b_attack_acc)*100
    print(f"{'Attack Acc(%)':<15} {baseline_attack_str:<20} {graphdro_attack_str:<20} {attack_diff:+.2f}")
    baseline_delta_str = f"{baseline_delta_acc*100:.2f}"
    graphdro_delta_str = f"{graphdro_delta_acc*100:.2f}"
    delta_diff = (graphdro_delta_acc - baseline_delta_acc)*100
    print(f"{'∆Acc(%)':<15} {baseline_delta_str:<20} {graphdro_delta_str:<20} {delta_diff:+.2f}")
    print()
    print(" Fairness metrics (on clean data):")
    print(f"{'Metric':<15} {'Baseline':<20} {'GraphDRO':<20} {'Improvement':<10}")
    print("-" * 70)
    baseline_sp_str = as_gap(baseline_stats['clean']['sp_gap'])
    graphdro_sp_str = as_gap(graphdro_stats['clean']['sp_gap'])
    sp_diff = g_sp - b_sp
    print(f"{'∆SP':<15} {baseline_sp_str:<20} {graphdro_sp_str:<20} {sp_diff:+.4f}")
    baseline_eo_str = as_gap(baseline_stats['clean']['eo_gap'])
    graphdro_eo_str = as_gap(graphdro_stats['clean']['eo_gap'])
    eo_diff = g_eo - b_eo
    print(f"{'∆EO':<15} {baseline_eo_str:<20} {graphdro_eo_str:<20} {eo_diff:+.4f}")
    print()
    print("  Robustness analysis:")
    # Retention rate: share of clean accuracy that survives the attack, in percent.
    baseline_robustness = b_attack_acc / b_clean_acc * 100
    graphdro_robustness = g_attack_acc / g_clean_acc * 100
    print(f"  Baseline model robustness: {baseline_robustness:.2f}% (accuracy retention rate)")
    print(f"  GraphDRO robustness: {graphdro_robustness:.2f}% (accuracy retention rate)")
    print(f"  Robustness difference: {graphdro_robustness - baseline_robustness:+.2f}%")
    print()
    if graphdro_delta_acc < baseline_delta_acc:
        print(" GraphDRO performs better: smaller ∆Acc, stronger robustness")
    else:
        print(" GraphDRO underperforms: larger ∆Acc, weaker robustness")
    if g_sp < b_sp:
        print(" GraphDRO has better fairness: smaller ∆SP")
    else:
        print(" GraphDRO fairness not improved: ∆SP not reduced")
    print(f"\n Detailed comparison:")
    clean_acc_diff = g_clean_acc - b_clean_acc
    attack_acc_diff = g_attack_acc - b_attack_acc
    print(f"  Clean accuracy difference: {clean_acc_diff:+.4f} (GraphDRO - Baseline)")
    print(f"  Attack accuracy difference: {attack_acc_diff:+.4f} (GraphDRO - Baseline)")
    # Same retention rate as above, this time as a plain ratio for the result dict.
    baseline_robustness = b_attack_acc / b_clean_acc
    graphdro_robustness = g_attack_acc / g_clean_acc
    robustness_diff = graphdro_robustness - baseline_robustness
    print(f"  Robustness difference: {robustness_diff:+.4f} (GraphDRO - Baseline)")
    print("\n Detailed results:")
    for i in range(num_runs):
        print(f"  Run {i+1}:")
        print(f"    Baseline: Clean={baseline_results[i]['clean']['accuracy']:.4f}, Attack={baseline_results[i]['attack']['accuracy']:.4f}")
        print(f"    GraphDRO: Clean={graphdro_results[i]['clean']['accuracy']:.4f}, Attack={graphdro_results[i]['attack']['accuracy']:.4f}")
    return {
        'baseline': {
            'stats': baseline_stats,
            'individual_results': baseline_results
        },
        'graphdro': {
            'stats': graphdro_stats,
            'individual_results': graphdro_results
        },
        'comparison': {
            'clean_acc_diff': clean_acc_diff,
            'attack_acc_diff': attack_acc_diff,
            'robustness_diff': robustness_diff
        },
        'num_runs': num_runs,
        'num_epochs': num_epochs
    }


def train_model(config: Dict[str, Any]) -> Dict[str, Any]:
    """Load the dataset, down-sample it if large, and run simplified training.

    Graphs with more than 10000 nodes are reduced to a random node subset of
    ``min(12000, 60% of nodes)``: features, labels and sensitive attributes
    are sliced, only edges with both endpoints in the subset are kept and
    re-indexed, and random 60/20/20 train/val/test masks are attached.

    A model is instantiated here only to report its parameter count; it is
    not passed on. Training itself is delegated to ``train_simplified_single``.

    Args:
        config: Experiment configuration; ``seed`` (default 42) is read here.

    Returns:
        Whatever ``train_simplified_single(config, data)`` returns.
    """
    print("Starting GraphDRO training...")
    set_seed(config.get('seed', 42))
    print(" Loading dataset...")
    data, in_dim, out_dim = load_dataset(config)
    print(f" Dataset loaded: {in_dim} features, {out_dim} classes")
    if data.x.size(0) > 10000:
        num_nodes = data.x.size(0)
        print(f" Large dataset detected ({num_nodes} nodes), performing sampling...")
        sample_size = min(12000, int(num_nodes * 0.6))
        sample_indices = torch.randperm(num_nodes)[:sample_size]
        sampled_data = data.clone()
        sampled_data.x = data.x[sample_indices]
        sampled_data.y = data.y[sample_indices]
        sampled_data.s = data.s[sample_indices]

        edge_index_cpu = data.edge_index.cpu()
        sample_indices_cpu = sample_indices.cpu()
        # Keep only edges whose two endpoints were both sampled.
        edge_mask = torch.isin(edge_index_cpu[0], sample_indices_cpu) & torch.isin(edge_index_cpu[1], sample_indices_cpu)
        kept_edges = edge_index_cpu[:, edge_mask]
        # Old node id -> position in the sample, as a lookup tensor. One
        # indexing operation replaces the per-edge Python loop and keeps the
        # (2, 0) shape when no edge survives.
        node_map = torch.full((num_nodes,), -1, dtype=torch.long)
        node_map[sample_indices_cpu] = torch.arange(sample_indices_cpu.numel())
        sampled_data.edge_index = node_map[kept_edges]

        # Random 60/20/20 split of the sampled nodes, stored as boolean masks.
        n_sampled = sampled_data.x.size(0)
        train_size = int(0.6 * n_sampled)
        val_size = int(0.2 * n_sampled)
        indices = torch.randperm(n_sampled)
        train_mask = torch.zeros(n_sampled, dtype=torch.bool)
        val_mask = torch.zeros(n_sampled, dtype=torch.bool)
        test_mask = torch.zeros(n_sampled, dtype=torch.bool)
        train_mask[indices[:train_size]] = True
        val_mask[indices[train_size:train_size + val_size]] = True
        test_mask[indices[train_size + val_size:]] = True
        sampled_data.train_mask = train_mask
        sampled_data.val_mask = val_mask
        sampled_data.test_mask = test_mask
        data = sampled_data
        print(f" Sampling completed: {data.x.size(0)} nodes, {data.edge_index.size(1)} edges")
    print(" Creating model...")
    actual_feature_dim = data.x.size(1)
    print(f" Actual feature dimension: {actual_feature_dim}")
    model = create_model_from_config(config, actual_feature_dim, 2)
    print(f" Model created: {sum(p.numel() for p in model.parameters())} parameters")
    print("Using simplified training (skipping GraphDRO trainer)...")
    return train_simplified_single(config, data)


def train_simplified_single(config: Dict[str, Any], data) -> Dict[str, Any]:
    """Train the configured model full-batch on one device and save its weights.

    Graphs with more than 10000 nodes are first reduced to a random node
    subset (at most 10000 nodes, at most 60% of the graph); only edges whose
    two endpoints were both kept survive, and fresh random 60/20/20
    train/val/test masks are drawn over the kept nodes.

    Args:
        config: Experiment configuration. Reads ``config['training']``
            (``device`` defaulting to ``'cuda'``, ``num_epochs``),
            ``config['optimizer']['lr']`` and ``config['logging']['save_path']``.
        data: Graph object exposing ``x``, ``y``, ``edge_index``,
            ``train_mask``, ``val_mask`` and optionally ``s``, plus
            ``clone()`` and ``to(device)``.

    Returns:
        ``{'final_metrics': {'accuracy': <validation accuracy>},
        'training_time': <seconds>}``. The accuracy is the one measured after
        the last epoch, or NaN when no epoch was run.
    """
    print(" Using simplified training (single GPU version)...")
    device_config = config['training'].get('device', 'cuda')
    main_device = torch.device(device_config)
    print(f"Using device: {main_device}")
    # The model below is always built with two output classes.
    num_classes = 2
    num_nodes = data.x.size(0)
    if num_nodes > 10000:
        print(f" Large dataset detected ({num_nodes} nodes), performing sampling...")
        sample_size = min(10000, int(num_nodes * 0.6))
        sample_indices = torch.randperm(num_nodes)[:sample_size]
        sampled_data = data.clone()
        sampled_data.x = data.x[sample_indices]
        sampled_data.y = data.y[sample_indices]
        # The sensitive attribute is treated as optional, as it is after sampling.
        if hasattr(data, 's'):
            sampled_data.s = data.s[sample_indices]
        edge_index_cpu = data.edge_index.cpu()
        sample_indices_cpu = sample_indices.cpu()
        # Keep only edges whose both endpoints survived the sampling.
        edge_mask = (
            torch.isin(edge_index_cpu[0], sample_indices_cpu)
            & torch.isin(edge_index_cpu[1], sample_indices_cpu)
        )
        # Vectorised old-id -> new-id lookup. Unlike a per-edge Python loop it
        # stays fast on large graphs and yields a (2, 0) tensor when no edge
        # survives, so edge_index.size(1) below remains valid.
        relabel = torch.full((num_nodes,), -1, dtype=torch.long)
        relabel[sample_indices_cpu] = torch.arange(sample_indices_cpu.numel())
        sampled_data.edge_index = relabel[edge_index_cpu[:, edge_mask].long()]
        n_sampled = sampled_data.x.size(0)
        train_size = int(0.6 * n_sampled)
        val_size = int(0.2 * n_sampled)
        indices = torch.randperm(n_sampled)
        train_mask = torch.zeros(n_sampled, dtype=torch.bool)
        val_mask = torch.zeros(n_sampled, dtype=torch.bool)
        test_mask = torch.zeros(n_sampled, dtype=torch.bool)
        train_mask[indices[:train_size]] = True
        val_mask[indices[train_size:train_size + val_size]] = True
        test_mask[indices[train_size + val_size:]] = True
        sampled_data.train_mask = train_mask
        sampled_data.val_mask = val_mask
        sampled_data.test_mask = test_mask
        data = sampled_data
        print(f" Sampling completed: {data.x.size(0)} nodes, {data.edge_index.size(1)} edges")
    data = data.to(main_device)
    data.x = data.x.detach()
    if hasattr(data, 's'):
        data.s = data.s.detach()
    actual_feature_dim = data.x.size(1)
    print(f" Actual feature dimension: {actual_feature_dim}")
    model = create_model_from_config(config, actual_feature_dim, num_classes).to(main_device)
    optimizer = optim.Adam(model.parameters(), lr=config['optimizer']['lr'])
    train_labels = data.y[data.train_mask]
    # minlength keeps the weight vector as long as the model output even when a
    # class is absent from the training split; the clamp avoids a division by
    # zero (infinite weight) for such an absent class.
    class_counts = torch.bincount(train_labels.long(), minlength=num_classes)
    total_samples = class_counts.sum()
    class_weights = total_samples / (len(class_counts) * class_counts.clamp(min=1).float())
    print(f"  Class distribution: {class_counts.tolist()}")
    print(f"  Class weights: {class_weights.tolist()}")
    criterion = nn.CrossEntropyLoss(weight=class_weights.to(main_device))
    print(" Using single GPU training strategy...")
    model.train()
    start_time = time.time()
    num_epochs = min(20, config['training']['num_epochs'])
    # Defined up front so the return value exists even if no epoch runs.
    val_acc = float('nan')
    for epoch in range(num_epochs):
        optimizer.zero_grad()
        out = model(data.edge_index, data.x)
        loss = criterion(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()
        if epoch % 5 == 0 and main_device.type == 'cuda':
            torch.cuda.empty_cache()
        # Validate every 5th epoch and always after the last one, so the
        # reported accuracy describes the model that is actually saved.
        is_last_epoch = epoch == num_epochs - 1
        if (epoch + 1) % 5 == 0 or is_last_epoch:
            model.eval()
            with torch.no_grad():
                val_out = model(data.edge_index, data.x)
                val_pred = val_out[data.val_mask].argmax(dim=1)
                val_acc = (val_pred == data.y[data.val_mask]).float().mean().item()
                print(f"Epoch {epoch+1:3d}: Loss={loss.item():.4f}, Val_Acc={val_acc:.4f}")
            model.train()
    training_time = time.time() - start_time
    print(f" Single GPU training completed! Time taken: {training_time:.2f} seconds")
    save_path = Path(config['logging']['save_path'])
    save_path.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), save_path / 'best_model.pt')
    print(f" Model saved to: {save_path / 'best_model.pt'}")
    return {'final_metrics': {'accuracy': val_acc}, 'training_time': training_time}


def evaluate_model(config: Dict[str, Any], model_path: str) -> Dict[str, Any]:
    """Load a saved state dict and report accuracy on the dataset's test mask.

    Graphs with more than 10000 nodes are reduced to a random node subset
    (at most 12000 nodes, at most 60% of the graph) with fresh random
    60/20/20 masks before the model is run.

    Args:
        config: Experiment configuration; reads ``config['training']['device']``
            and is forwarded to ``load_dataset`` and ``create_model_from_config``.
        model_path: Path to a state dict written with ``torch.save``.

    Returns:
        ``{'accuracy': <test accuracy as float>}``.
    """
    print(" Starting model evaluation...")
    data, in_dim, out_dim = load_dataset(config)
    print(f" Evaluation dataset loaded: {in_dim} features, {out_dim} classes")
    num_nodes = data.x.size(0)
    if num_nodes > 10000:
        # NOTE(review): this subset and its masks are drawn independently of the
        # ones used at training time, so the "test" nodes here may overlap with
        # nodes the model was trained on -- confirm this is intended.
        print(f" Large dataset detected during evaluation ({num_nodes} nodes), performing sampling...")
        sample_size = min(12000, int(num_nodes * 0.6))
        sample_indices = torch.randperm(num_nodes)[:sample_size]
        sampled_data = data.clone()
        sampled_data.x = data.x[sample_indices]
        sampled_data.y = data.y[sample_indices]
        # The sensitive attribute is not used by the evaluation below, so do
        # not require it to be present.
        if hasattr(data, 's'):
            sampled_data.s = data.s[sample_indices]
        edge_index_cpu = data.edge_index.cpu()
        sample_indices_cpu = sample_indices.cpu()
        # Keep only edges whose both endpoints survived the sampling.
        edge_mask = (
            torch.isin(edge_index_cpu[0], sample_indices_cpu)
            & torch.isin(edge_index_cpu[1], sample_indices_cpu)
        )
        # Vectorised old-id -> new-id lookup; also yields a well-formed (2, 0)
        # tensor when no edge survives, so edge_index.size(1) stays valid.
        relabel = torch.full((num_nodes,), -1, dtype=torch.long)
        relabel[sample_indices_cpu] = torch.arange(sample_indices_cpu.numel())
        sampled_data.edge_index = relabel[edge_index_cpu[:, edge_mask].long()]
        n_sampled = sampled_data.x.size(0)
        train_size = int(0.6 * n_sampled)
        val_size = int(0.2 * n_sampled)
        indices = torch.randperm(n_sampled)
        train_mask = torch.zeros(n_sampled, dtype=torch.bool)
        val_mask = torch.zeros(n_sampled, dtype=torch.bool)
        test_mask = torch.zeros(n_sampled, dtype=torch.bool)
        train_mask[indices[:train_size]] = True
        val_mask[indices[train_size:train_size + val_size]] = True
        test_mask[indices[train_size + val_size:]] = True
        sampled_data.train_mask = train_mask
        sampled_data.val_mask = val_mask
        sampled_data.test_mask = test_mask
        data = sampled_data
        print(f" Evaluation sampling completed: {data.x.size(0)} nodes, {data.edge_index.size(1)} edges")
    actual_feature_dim = data.x.size(1)
    print(f" Evaluation actual feature dimension: {actual_feature_dim}")
    device = torch.device(config['training']['device'])
    model = create_model_from_config(config, actual_feature_dim, 2)
    # map_location lets a checkpoint saved on GPU be evaluated on a CPU-only
    # machine (and vice versa).
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    data = data.to(device)
    with torch.no_grad():
        out = model(data.edge_index, data.x)
        pred = out[data.test_mask].argmax(dim=1)
        acc = (pred == data.y[data.test_mask]).float().mean().item()
        print(f" Evaluation accuracy: {acc:.4f}")
    return {'accuracy': acc}


def experiment(config: Dict[str, Any]) -> Dict[str, Any]:
    """Run training followed by evaluation and persist both result sets.

    The model is evaluated from ``best_model.pt`` inside
    ``config['logging']['save_path']``, and the combined results are written
    to ``experiment_results.yaml`` in that same directory.

    Args:
        config: Experiment configuration, forwarded unchanged to
            ``train_model`` and ``evaluate_model``.

    Returns:
        ``{'training': <train_model result>, 'evaluation': <evaluate_model result>}``.
    """
    print(" Starting full experiment...")
    save_path = Path(config['logging']['save_path'])
    training_results = train_model(config)
    eval_results = evaluate_model(config, str(save_path / 'best_model.pt'))
    results = {
        'training': training_results,
        'evaluation': eval_results
    }
    # Make sure the output directory exists before writing; this is a no-op
    # when an earlier step already created it.
    save_path.mkdir(parents=True, exist_ok=True)
    results_file = save_path / 'experiment_results.yaml'
    # Explicit encoding, matching how the config file is read elsewhere in
    # this module.
    with open(results_file, 'w', encoding='utf-8') as f:
        yaml.dump(results, f)
    print(f" Experiment results saved to: {results_file}")
    return results


def main():
    """Parse command-line arguments and dispatch to the requested command.

    Commands: ``train``, ``eval``, ``experiment``, ``comparison`` and
    ``list-datasets``. Every command needs ``--config``; ``eval`` additionally
    needs ``--model-path``. Missing required arguments are reported through
    ``parser.error``, which prints the usage text and exits with status 2.

    ``--device``, ``--seed`` and ``--verbose`` override the matching entries
    of the loaded config only when they are actually given on the command
    line.
    """
    parser = argparse.ArgumentParser(description='GraphDRO Comparison Experiment')
    parser.add_argument('command', choices=['train', 'eval', 'experiment', 'comparison', 'list-datasets'],
                       help='Command to execute')
    parser.add_argument('--config', type=str, help='Path to config file (required)')
    parser.add_argument('--dataset', type=str, choices=['german', 'bail', 'credit', 'pokec_z', 'pokec_n', 'nba'],
                       help='Specify dataset (optional, overrides config)')
    parser.add_argument('--model', type=str, choices=['GCN', 'GraphSAGE', 'GIN', 'GAT'],
                       help='Specify model type (optional, overrides config)')
    parser.add_argument('--model-path', type=str, help='Model file path (for eval only)')
    # No argparse default: a default of 'cuda' would be indistinguishable from
    # an explicit --device cuda and would always clobber the config's device.
    parser.add_argument('--device', type=str, default=None,
                       help='Device (overrides config; cuda is used when neither is set)')
    parser.add_argument('--seed', type=int, help='Random seed')
    parser.add_argument('--verbose', action='store_true', help='Verbose output')
    parser.add_argument('--num-runs', type=int, default=3, help='Number of runs for comparison experiment')
    parser.add_argument('--num-epochs', type=int, default=60, help='Number of epochs for each training')
    parser.add_argument('--eps-e', type=float, default=0.5, help='Edge perturbation ratio (0-1)')
    parser.add_argument('--eps-x', type=float, default=0.15, help='Feature perturbation magnitude')
    parser.add_argument('--eps-l', type=float, default=0.1, help='Label perturbation ratio (0-1)')
    parser.add_argument('--gamma', type=float, default=0.2, help='Sensitive attribute perturbation ratio (0-1)')
    args = parser.parse_args()

    if not args.config:
        if args.command == 'list-datasets':
            parser.error("list-datasets command requires --config argument")
        parser.error("--config argument is required")

    if args.command == 'list-datasets':
        show_supported_datasets(args.config)
        return

    # Validate before the config is loaded so a bad invocation fails fast.
    if args.command == 'eval' and not args.model_path:
        parser.error("eval command requires --model-path argument")

    config = load_config(args.config, args.dataset, args.model)

    if args.device is not None:
        config['training']['device'] = args.device
    else:
        # Keep the historical fallback for configs that do not name a device.
        config['training'].setdefault('device', 'cuda')
    # `is not None` so that an explicit `--seed 0` is honoured.
    if args.seed is not None:
        config['seed'] = args.seed
    if args.verbose:
        config['logging']['verbose'] = True

    if args.command == 'train':
        train_model(config)
    elif args.command == 'eval':
        evaluate_model(config, args.model_path)
    elif args.command == 'experiment':
        experiment(config)
    elif args.command == 'comparison':
        attack_params = {
            'eps_e': args.eps_e,
            'eps_x': args.eps_x,
            'eps_l': args.eps_l,
            'gamma': args.gamma
        }
        comparison_experiment(config, args.num_runs, args.num_epochs, attack_params)
    print(" Experiment completed!")


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()