"""
Simple OGB-MolHIV Benchmark: GraGR vs Gradient Conflict Resolution Methods
=========================================================================

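This module implements a simplified multi-task benchmark on the real OGB-MolHIV
dataset, comparing gradient conflict resolution methods on a shared model.

DATA USED:
1. OGB-MolHIV: 41,127 molecular graphs labelled for HIV inhibition
   (only the first 1,000 graphs are loaded, for speed).
2. Auxiliary targets (molecular weight, graph size, graph density, ring count)
   computed heuristically from the same graphs.

The HIV-inhibition labels are the genuine OGB-MolHIV labels; the auxiliary
targets are derived from graph structure rather than curated annotations.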
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import time
import pandas as pd
from typing import Dict, List, Tuple, Optional
import warnings
warnings.filterwarnings('ignore')

# Import real dataset loaders
from ogb.graphproppred import PygGraphPropPredDataset
from torch_geometric.data import DataLoader, Data
from torch_geometric.utils import to_dense_batch
from torch_geometric.nn import global_mean_pool, global_max_pool, global_add_pool

class SimpleOGBModel(nn.Module):
    """Shared-trunk multi-task model over mean-pooled node features.

    Node features are averaged per graph, passed through a three-layer MLP
    trunk shared by all tasks, and then through one linear head per task.
    Classification heads emit ``num_classes_per_task[i]`` logits; regression
    heads emit a single value.
    """

    def __init__(self, input_dim: int, hidden_dim: int, num_tasks: int,
                 task_types: List[str], num_classes_per_task: List[int] = None):
        super().__init__()
        self.num_tasks = num_tasks
        self.task_types = task_types
        # Every task defaults to binary classification when no sizes are given.
        self.num_classes_per_task = (
            num_classes_per_task if num_classes_per_task else [2] * num_tasks
        )

        # Trunk: Linear -> ReLU (-> Dropout between hidden layers), three times.
        trunk_layers = []
        fan_in = input_dim
        for depth in range(3):
            trunk_layers.append(nn.Linear(fan_in, hidden_dim))
            trunk_layers.append(nn.ReLU())
            if depth < 2:
                trunk_layers.append(nn.Dropout(0.1))
            fan_in = hidden_dim
        self.shared = nn.Sequential(*trunk_layers)

        # One output head per task, built in task order.
        self.task_heads = nn.ModuleList()
        for task_idx in range(num_tasks):
            out_dim = (
                self.num_classes_per_task[task_idx]
                if task_types[task_idx] == 'classification'
                else 1
            )
            self.task_heads.append(nn.Linear(hidden_dim, out_dim))

    def forward(self, x, batch):
        """Pool per graph, encode with the shared trunk, and apply every head.

        Returns a list with one prediction tensor per task, in head order.
        """
        pooled = global_mean_pool(x, batch)
        encoded = self.shared(pooled)
        return list(head(encoded) for head in self.task_heads)

class GradientConflictResolver:
    """Interface for strategies that merge per-task gradients into one.

    Subclasses override :meth:`resolve_conflicts`, which receives one
    flattened gradient per task together with the matching task losses and
    returns a single flattened gradient.
    """

    def __init__(self, method_name: str):
        # Human-readable label of the strategy.
        self.method_name = method_name

    def resolve_conflicts(self, task_gradients: List[torch.Tensor],
                          task_losses: List[torch.Tensor]) -> torch.Tensor:
        """Merge ``task_gradients`` into one gradient; subclasses must override."""
        raise NotImplementedError

class VanillaAverage(GradientConflictResolver):
    """Vanilla Average (GD): element-wise mean of the task gradients.

    Shorter gradients are right-padded with zeros to the longest one before
    averaging.
    """

    def __init__(self):
        super().__init__("Vanilla Average")

    def resolve_conflicts(self, task_gradients: List[torch.Tensor],
                          task_losses: List[torch.Tensor]) -> torch.Tensor:
        target_len = max(g.size(0) for g in task_gradients)
        aligned = [F.pad(g, (0, target_len - g.size(0))) for g in task_gradients]
        return torch.stack(aligned).mean(dim=0)

class MGDA(GradientConflictResolver):
    """Multiple Gradient Descent Algorithm (min-norm combination of task gradients).

    MGDA updates along the minimum-norm point of the convex hull of the task
    gradients, ``d = sum_i alpha_i * g_i`` with ``alpha_i >= 0`` and
    ``sum_i alpha_i = 1`` (Desideri 2012; Sener & Koltun 2018). The simplex
    weights are found with Frank-Wolfe iterations on the Gram matrix of the
    (zero-padded) gradients.

    Args:
        max_iter: maximum number of Frank-Wolfe iterations.
        tol: stop once the Frank-Wolfe duality gap falls to this value or below.
    """

    def __init__(self, max_iter: int = 250, tol: float = 1e-6):
        super().__init__("MGDA")
        self.max_iter = max_iter
        self.tol = tol

    def resolve_conflicts(self, task_gradients: List[torch.Tensor],
                          task_losses: List[torch.Tensor]) -> torch.Tensor:
        """Return the min-norm convex combination of ``task_gradients``.

        A single gradient is returned unchanged. ``task_losses`` is unused.
        """
        if len(task_gradients) == 1:
            return task_gradients[0]

        max_size = max(grad.size(0) for grad in task_gradients)
        stacked = torch.stack([F.pad(grad, (0, max_size - grad.size(0)))
                               for grad in task_gradients])  # (k, d)
        alpha = self._min_norm_weights(stacked.detach())
        return alpha.to(stacked.dtype) @ stacked

    def _min_norm_weights(self, grads: torch.Tensor) -> torch.Tensor:
        """Return simplex weights minimising ``||alpha @ grads||`` via Frank-Wolfe."""
        gram = grads @ grads.T
        k = gram.size(0)
        alpha = torch.full((k,), 1.0 / k, dtype=gram.dtype, device=gram.device)
        for _ in range(self.max_iter):
            gram_alpha = gram @ alpha          # <g_j, d> for the current d
            t = int(torch.argmin(gram_alpha))  # vertex most opposed to d
            uu = alpha @ gram_alpha            # ||d||^2
            ut = gram_alpha[t]                 # <d, g_t>
            gap = uu - ut                      # Frank-Wolfe duality gap
            if gap <= self.tol:
                break
            denom = uu - 2 * ut + gram[t, t]   # ||d - g_t||^2
            if denom <= 0:
                break
            # Exact line search for min_gamma ||(1 - gamma) d + gamma g_t||^2.
            gamma = torch.clamp(gap / denom, 0.0, 1.0)
            alpha = alpha * (1 - gamma)
            alpha[t] = alpha[t] + gamma
        return alpha

class PCGrad(GradientConflictResolver):
    """Projecting Conflicting Gradients (sequential variant).

    Gradients are zero-padded to a common length and folded into a running
    sum in task order. When the next gradient has negative cosine similarity
    with the running sum, its component along the running sum is removed
    before it is added. The sum is divided by the number of tasks.

    Unlike PCGrad in Yu et al. (2020), each gradient is projected against the
    accumulated direction rather than pairwise against every other task
    gradient.
    """

    def __init__(self):
        super().__init__("PCGrad")

    def resolve_conflicts(self, task_gradients: List[torch.Tensor],
                          task_losses: List[torch.Tensor]) -> torch.Tensor:
        if not len(task_gradients) >= 2:
            return task_gradients[0]

        target_len = max(g.size(0) for g in task_gradients)
        aligned = [F.pad(g, (0, target_len - g.size(0))) for g in task_gradients]

        accumulated = aligned[0].clone()
        for candidate in aligned[1:]:
            acc_flat = accumulated.flatten()
            cand_flat = candidate.flatten()
            if F.cosine_similarity(acc_flat, cand_flat, dim=0) < 0:
                # Strip the part of the candidate that opposes the running sum.
                coeff = (acc_flat @ cand_flat) / (acc_flat @ acc_flat)
                candidate = candidate - coeff * accumulated
            accumulated = accumulated + candidate

        return accumulated / len(aligned)

class CAGrad(GradientConflictResolver):
    """Conflict-Averse Gradient Descent (simplified selection rule).

    Among the (zero-padded) task gradients, pick the one whose *worst*
    inner product with the other task gradients is largest, i.e. the
    candidate that hurts the most-conflicting other task the least.

    Note: ``rho`` is stored for interface compatibility but is not used by
    this simplified rule (full CAGrad solves a constrained problem around the
    average gradient).
    """

    def __init__(self, rho: float = 0.4):
        super().__init__("CAGrad")
        self.rho = rho

    def resolve_conflicts(self, task_gradients: List[torch.Tensor],
                          task_losses: List[torch.Tensor]) -> torch.Tensor:
        """Return the task gradient with the best worst-case alignment.

        A single gradient is returned unchanged. If no candidate's score
        compares greater than ``-inf`` (e.g. every score is NaN), the mean
        gradient is returned. ``task_losses`` is unused.
        """
        if len(task_gradients) < 2:
            return task_gradients[0]

        max_size = max(grad.size(0) for grad in task_gradients)
        padded_gradients = [F.pad(grad, (0, max_size - grad.size(0)))
                            for grad in task_gradients]

        best_grad = torch.stack(padded_gradients).mean(dim=0)
        # Maximising a worst case: start from -inf. The original +inf made the
        # comparison below always False, so the average was always returned.
        best_worst_improvement = float('-inf')

        for i, grad_i in enumerate(padded_gradients):
            worst_improvement = min(
                grad_i.flatten() @ grad_j.flatten()
                for j, grad_j in enumerate(padded_gradients) if j != i
            )
            # Strict '>' keeps the first candidate on ties.
            if worst_improvement > best_worst_improvement:
                best_worst_improvement = worst_improvement
                best_grad = grad_i

        return best_grad

class GradNorm(GradientConflictResolver):
    """Gradient Normalization (simplified closed-form variant).

    Each zero-padded task gradient is rescaled so its norm becomes
    ``mean_norm * (L_i / mean(L)) ** alpha``, and the rescaled gradients are
    averaged. Unlike GradNorm as published (Chen et al., 2018), no task weights
    are learned, and the loss ratio is taken against the mean of the current
    losses rather than each task's initial loss.
    """

    def __init__(self, alpha: float = 1.5):
        super().__init__("GradNorm")
        self.alpha = alpha

    def resolve_conflicts(self, task_gradients: List[torch.Tensor],
                          task_losses: List[torch.Tensor]) -> torch.Tensor:
        """Return the average of the norm-rebalanced task gradients.

        Zero gradient norms and a zero mean loss are clamped to ``1e-8``. The
        result stays finite, and a zero gradient simply contributes nothing
        (the original produced inf * 0 = NaN).
        """
        eps = 1e-8
        max_size = max(grad.size(0) for grad in task_gradients)
        stacked = torch.stack([F.pad(grad, (0, max_size - grad.size(0)))
                               for grad in task_gradients])  # (k, d)

        # The weights are plain scalars; detach them so any autograd graph
        # attached to the losses is not pulled into the returned gradient.
        grad_norms = stacked.detach().norm(dim=1)
        # reshape(-1) accepts both 0-d and shape-[1] loss tensors.
        losses = torch.stack(list(task_losses)).detach().reshape(-1).to(grad_norms)

        loss_ratios = losses / losses.mean().clamp_min(eps)
        target_norms = grad_norms.mean() * loss_ratios.pow(self.alpha)
        weights = target_norms / grad_norms.clamp_min(eps)

        return (weights.unsqueeze(1) * stacked).sum(dim=0) / stacked.size(0)

class GraGRCore(GradientConflictResolver):
    """GraGR Core - gradient-guided reasoning.

    Each zero-padded task gradient receives the weight
    ``(1 / (||g|| + 1e-8)) * (1 + conflict)``, where ``conflict`` sums the
    absolute cosine similarity to every other task gradient that points
    against it. The weights are normalised to sum to one and used to mix the
    gradients.
    """

    def __init__(self):
        super().__init__("GraGR Core")

    def resolve_conflicts(self, task_gradients: List[torch.Tensor],
                          task_losses: List[torch.Tensor]) -> torch.Tensor:
        target_len = max(g.size(0) for g in task_gradients)
        aligned = [F.pad(g, (0, target_len - g.size(0))) for g in task_gradients]

        def conflict_of(idx):
            # Sum of |cos| over the other tasks whose gradients oppose this one.
            total = 0
            for other_idx, other in enumerate(aligned):
                if other_idx == idx:
                    continue
                sim = F.cosine_similarity(aligned[idx].flatten(), other.flatten(), dim=0)
                if sim < 0:
                    total += abs(sim)
            return total

        raw_weights = [
            (1.0 / (torch.norm(vec) + 1e-8)) * (1 + conflict_of(idx))
            for idx, vec in enumerate(aligned)
        ]
        mix = torch.tensor(raw_weights)
        mix = mix / mix.sum()

        return sum((mix[idx] * vec for idx, vec in enumerate(aligned)),
                   torch.zeros_like(aligned[0]))

class GraGRPlusPlus(GradientConflictResolver):
    """GraGR++ - enhanced gradient-guided reasoning.

    Like GraGR Core, but each task's inverse-norm weight is scaled by
    ``1 + 2 * conflict + 0.5 * alignment``. ``conflict`` sums |cosine| over
    the other tasks pointing against it, and ``alignment`` sums the cosine
    over the remaining other tasks. The weights are normalised to sum to one
    and used to mix the zero-padded gradients.
    """

    def __init__(self):
        super().__init__("GraGR++")

    def resolve_conflicts(self, task_gradients: List[torch.Tensor],
                          task_losses: List[torch.Tensor]) -> torch.Tensor:
        target_len = max(g.size(0) for g in task_gradients)
        aligned = [F.pad(g, (0, target_len - g.size(0))) for g in task_gradients]

        raw_weights = []
        for idx, vec in enumerate(aligned):
            conflict = 0
            agreement = 0
            for other in (o for pos, o in enumerate(aligned) if pos != idx):
                sim = F.cosine_similarity(vec.flatten(), other.flatten(), dim=0)
                if sim < 0:
                    conflict += abs(sim)
                else:
                    agreement += sim
            inverse_norm = 1.0 / (torch.norm(vec) + 1e-8)
            raw_weights.append(inverse_norm * (1 + 2 * conflict + 0.5 * agreement))

        mix = torch.tensor(raw_weights)
        mix = mix / mix.sum()

        return sum((mix[idx] * vec for idx, vec in enumerate(aligned)),
                   torch.zeros_like(aligned[0]))

# Atomic masses (g/mol) keyed by OGB's atom-type index. OGB's featurizer
# (ogb.utils.features.atom_to_feature_vector) stores ``atomic_number - 1`` in
# column 0 of ``x``, so carbon (Z=6) is index 5. The previous table treated
# index 1 as carbon, which counted every C as P, every N as S, and so on.
_OGB_ATOM_WEIGHTS = {
    0: 1.008,     # H  (Z=1)
    5: 12.011,    # C  (Z=6)
    6: 14.007,    # N  (Z=7)
    7: 15.999,    # O  (Z=8)
    8: 18.998,    # F  (Z=9)
    14: 30.974,   # P  (Z=15)
    15: 32.065,   # S  (Z=16)
    16: 35.453,   # Cl (Z=17)
    34: 79.904,   # Br (Z=35)
    52: 126.904,  # I  (Z=53)
}


def _count_components(num_nodes: int, edge_index: torch.Tensor) -> int:
    """Return the number of connected components (union-find over ``edge_index``)."""
    parent = list(range(num_nodes))

    def find(node: int) -> int:
        while parent[node] != node:
            parent[node] = parent[parent[node]]  # path halving
            node = parent[node]
        return node

    for src, dst in edge_index.t().tolist():
        root_src, root_dst = find(src), find(dst)
        if root_src != root_dst:
            parent[root_src] = root_dst
    return sum(1 for node in range(num_nodes) if find(node) == node)


def _derive_task_labels(data) -> torch.Tensor:
    """Return the five task labels of one OGB molecule as a float tensor of shape [5]."""
    # Task 1: HIV inhibition (binary classification) - the genuine OGB label.
    hiv_inhibition = data.y.squeeze().item()

    # Task 2: molecular weight (regression) from atom types; unlisted elements count 0.
    mol_weight = sum(_OGB_ATOM_WEIGHTS.get(int(a), 0.0) for a in data.x[:, 0].tolist())
    mol_weight_normalized = (mol_weight - 100) / 100  # rough normalization

    num_nodes = data.x.size(0)
    # OGB's smiles2graph stores every bond as two directed edges.
    num_bonds = data.edge_index.size(1) // 2

    # Task 3: graph size class (small < 10 <= medium < 20 <= large).
    graph_size = 0 if num_nodes < 10 else 1 if num_nodes < 20 else 2

    # Task 4: density class over undirected bonds (sparse < 0.1 <= medium < 0.3 <= dense).
    max_edges = num_nodes * (num_nodes - 1) // 2
    density = num_bonds / max_edges if max_edges > 0 else 0
    graph_density = 0 if density < 0.1 else 1 if density < 0.3 else 2

    # Task 5: ring count (regression) as the circuit rank E - V + C, i.e. the
    # number of independent cycles (also correct for multi-fragment salts).
    ring_count = max(0, num_bonds - num_nodes + _count_components(num_nodes, data.edge_index))
    ring_count_normalized = ring_count / 10

    return torch.tensor([
        hiv_inhibition,
        mol_weight_normalized,
        graph_size,
        graph_density,
        ring_count_normalized,
    ], dtype=torch.float)


def load_ogb_molhiv_simple(root: str = '../dataset/', max_graphs: Optional[int] = 1000,
                           batch_size: int = 16):
    """Load REAL OGB-MolHIV dataset with simple multi-task setup.

    Args:
        root: directory passed to ``PygGraphPropPredDataset``.
        max_graphs: use only the first ``max_graphs`` molecules (None = all).
        batch_size: batch size of the three returned loaders.

    Returns:
        ``(train_loader, val_loader, test_loader, dataset_info)``. The split is
        contiguous 60/20/20 over the selected molecules. Each graph's ``y`` has
        shape ``[1, 5]``, so a batch's ``y`` is ``[num_graphs, 5]``.
    """
    print("Loading REAL OGB-MolHIV dataset with simple multi-task setup...")

    dataset = PygGraphPropPredDataset(name='ogbg-molhiv', root=root)
    print(f"✓ OGB-MolHIV loaded: {len(dataset)} graphs")

    multi_task_data = []
    for i, data in enumerate(dataset):
        if max_graphs is not None and i >= max_graphs:  # limit for speed
            break
        multi_task_data.append(Data(
            x=data.x.float(),
            edge_index=data.edge_index,
            edge_attr=data.edge_attr.float(),
            # [1, 5] (not [5]) so PyG concatenates labels to [num_graphs, 5]
            # instead of a flat [num_graphs * 5] vector.
            y=_derive_task_labels(data).unsqueeze(0),
            pos=data.pos if hasattr(data, 'pos') else None
        ))

    train_size = int(0.6 * len(multi_task_data))
    val_size = int(0.2 * len(multi_task_data))

    train_data = multi_task_data[:train_size]
    val_data = multi_task_data[train_size:train_size + val_size]
    test_data = multi_task_data[train_size + val_size:]

    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

    dataset_info = {
        'name': 'OGB_MolHIV_Simple',
        'num_tasks': 5,
        'task_types': ['classification', 'regression', 'classification', 'classification', 'regression'],
        'num_classes': [2, 1, 3, 3, 1],
        'num_graphs': len(multi_task_data),
        'train_size': len(train_data),
        'val_size': len(val_data),
        'test_size': len(test_data),
        'task_names': [
            'hiv_inhibition', 'molecular_weight', 'graph_size', 'graph_density', 'ring_count'
        ]
    }

    print(f"✓ OGB-MolHIV Simple loaded: {len(multi_task_data)} graphs, {dataset_info['num_tasks']} tasks")
    return train_loader, val_loader, test_loader, dataset_info

def train_model(model, data_loader, resolver, dataset_info, num_epochs=20):
    """Train model with specified gradient conflict resolution method.

    For every batch, each task loss is differentiated separately with respect
    to *all* model parameters. Parameters a task does not use (the other
    tasks' heads) get explicit zero gradients, so every task gradient has the
    same length and parameter layout. The resolver's output is then sliced
    back into ``param.grad`` in ``model.parameters()`` order before an Adam
    step. (Dropping unused gradients, as before, produced vectors of
    different lengths and layouts, so heads received each other's gradients.)

    Args:
        model: module whose ``forward(x, batch)`` returns a list with one
            prediction tensor per task.
        data_loader: iterable of batches exposing ``x``, ``batch`` and ``y``.
            ``y`` may be ``[num_graphs, num_tasks]``, a flat
            ``[num_graphs * num_tasks]`` vector (how PyG concatenates a 1-D
            per-graph label vector), or ``[num_graphs]`` (one label shared by
            all tasks).
        resolver: object with ``resolve_conflicts(task_gradients, task_losses)``
            returning a flat gradient vector.
        dataset_info: dict providing ``'num_tasks'`` and ``'task_types'``.
        num_epochs: number of passes over ``data_loader``.

    Returns:
        ``(losses, accuracies)`` per epoch: the mean over batches of the summed
        task losses, and of the mean accuracy over classification tasks (0.0
        for a batch with no classification task). Epochs with no batches
        report NaN.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    params = list(model.parameters())
    num_tasks = dataset_info['num_tasks']
    task_types = dataset_info['task_types']

    losses = []
    accuracies = []
    model.train()

    for _ in range(num_epochs):
        epoch_loss = 0.0
        epoch_acc = 0.0
        num_batches = 0

        for batch in data_loader:
            optimizer.zero_grad()
            predictions = model(batch.x, batch.batch)

            labels = batch.y
            num_graphs = predictions[0].size(0)
            # Restore [num_graphs, num_tasks] from PyG's flat concatenation.
            if labels.dim() == 1 and num_tasks > 1 and labels.numel() == num_graphs * num_tasks:
                labels = labels.view(num_graphs, num_tasks)

            task_losses = []
            task_gradients = []
            cls_accuracies = []

            for i in range(num_tasks):
                task_labels = labels if labels.dim() == 1 else labels[:, i]
                if task_types[i] == 'classification':
                    task_labels = task_labels.long()
                    loss = F.cross_entropy(predictions[i], task_labels)
                    pred = predictions[i].argmax(dim=1)
                    cls_accuracies.append((pred == task_labels).float().mean().item())
                else:  # regression
                    # squeeze(-1) keeps the batch dim even when batch size is 1.
                    loss = F.mse_loss(predictions[i].squeeze(-1), task_labels.float())
                task_losses.append(loss)

                # First-order gradients only; keep the graph until the last task.
                grads = torch.autograd.grad(loss, params,
                                            retain_graph=i < num_tasks - 1,
                                            allow_unused=True)
                task_gradients.append(torch.cat([
                    (g if g is not None else torch.zeros_like(p)).reshape(-1)
                    for g, p in zip(grads, params)
                ]))

            resolved_grad = resolver.resolve_conflicts(
                task_gradients, [loss.detach() for loss in task_losses]
            ).detach()

            offset = 0
            for param in params:
                size = param.numel()
                if offset + size <= resolved_grad.size(0):
                    param.grad = resolved_grad[offset:offset + size].view(param.shape)
                else:
                    param.grad = torch.zeros_like(param)
                offset += size

            optimizer.step()

            epoch_loss += sum(loss.item() for loss in task_losses)
            epoch_acc += float(np.mean(cls_accuracies)) if cls_accuracies else 0.0
            num_batches += 1

        losses.append(epoch_loss / num_batches if num_batches else float('nan'))
        accuracies.append(epoch_acc / num_batches if num_batches else float('nan'))

    return losses, accuracies

class SimpleOGBBenchmarkRunner:
    """Runner for simple OGB-MolHIV benchmark.

    Trains one ``SimpleOGBModel`` per gradient-conflict resolver on the
    multi-task OGB-MolHIV training loader and reports final-epoch training
    loss and accuracy.

    Attributes:
        results: dataset name -> method name -> metrics dict.
        results_table: flat list of per-(dataset, method) row dicts, written
            to CSV by ``save_results_table``.
        seed: passed to ``torch.manual_seed`` before each method's model is
            built, so every method starts from identical weights and sees the
            same shuffle order. ``None`` disables reseeding.
    """

    # Fixed display / ranking order used by the summary.
    METHOD_ORDER = ['Vanilla Average', 'MGDA', 'PCGrad', 'CAGrad',
                    'GradNorm', 'GraGR Core', 'GraGR++']

    # Labels for the top three ranks; every other rank is "📊 Other".
    _RANK_LABELS = {1: "🥇 Best", 2: "🥈 Second", 3: "🥉 Third"}

    def __init__(self, seed: Optional[int] = 0):
        self.results = {}
        self.results_table = []
        self.seed = seed

    def run_simple_ogb_benchmark(self):
        """Run benchmark on REAL OGB-MolHIV dataset.

        NOTE: ``val_loader`` and ``test_loader`` are currently unused; the
        reported metrics are final-epoch *training* metrics.

        Raises:
            ValueError: if the training loader yields no batches.
        """
        print("Simple OGB-MolHIV Benchmark: GraGR vs Gradient Conflict Resolution Methods")
        print("=" * 100)
        print("Using ACTUAL REAL OGB-MolHIV dataset with working multi-task setup")
        print("=" * 100)

        train_loader, val_loader, test_loader, dataset_info = load_ogb_molhiv_simple()

        print(f"Dataset: {dataset_info['name']}")
        print(f"Tasks: {dataset_info['num_tasks']}")
        print(f"Task Types: {dataset_info['task_types']}")
        print(f"Task Names: {dataset_info['task_names']}")
        print(f"Train Size: {dataset_info['train_size']}")
        print(f"Val Size: {dataset_info['val_size']}")
        print(f"Test Size: {dataset_info['test_size']}")

        first_batch = next(iter(train_loader), None)
        if first_batch is None:
            raise ValueError("Training loader is empty; cannot infer the input dimension.")
        input_dim = first_batch.x.size(1)

        print(f"Input Dimension: {input_dim}")

        methods = {
            'Vanilla Average': VanillaAverage(),
            'MGDA': MGDA(),
            'PCGrad': PCGrad(),
            'CAGrad': CAGrad(),
            'GradNorm': GradNorm(),
            'GraGR Core': GraGRCore(),
            'GraGR++': GraGRPlusPlus()
        }
        data_size = dataset_info['train_size'] + dataset_info['val_size'] + dataset_info['test_size']

        dataset_results = {}

        for method_name, resolver in methods.items():
            print(f"\n--- {method_name} ---")

            # Same initial weights and shuffle order for every method.
            if self.seed is not None:
                torch.manual_seed(self.seed)

            model = SimpleOGBModel(
                input_dim=input_dim,
                hidden_dim=64,
                num_tasks=dataset_info['num_tasks'],
                task_types=dataset_info['task_types'],
                num_classes_per_task=dataset_info['num_classes']
            )

            start_time = time.perf_counter()
            losses, accuracies = train_model(model, train_loader, resolver, dataset_info, num_epochs=15)
            training_time = time.perf_counter() - start_time

            dataset_results[method_name] = {
                'final_loss': losses[-1],
                'final_accuracy': accuracies[-1],
                'training_time': training_time,
                'losses': losses,
                'accuracies': accuracies
            }

            self.results_table.append({
                'Dataset': 'OGB_MolHIV_Simple',
                'Method': method_name,
                'Final_Loss': losses[-1],
                'Final_Accuracy': accuracies[-1],
                'Training_Time': training_time,
                'Num_Tasks': dataset_info['num_tasks'],
                'Data_Size': data_size
            })

            print(f"Final Loss: {losses[-1]:.4f}")
            print(f"Final Accuracy: {accuracies[-1]:.4f}")
            print(f"Training Time: {training_time:.2f}s")

        self.results['OGB_MolHIV_Simple'] = dataset_results
        self.print_simple_ogb_summary()
        self.save_results_table()

    def print_simple_ogb_summary(self):
        """Print per-method results, an accuracy ranking and GraGR win counts.

        A dataset with no recorded methods is skipped when counting wins.
        """
        methods = self.METHOD_ORDER

        print(f"\n{'='*100}")
        print("Simple OGB-MolHIV Benchmark Summary")
        print(f"{'='*100}")

        print(f"\n{'Dataset':<25} {'Method':<15} {'Final Loss':<12} {'Final Acc':<12} {'Time (s)':<10}")
        print("-" * 80)

        for dataset_name, results in self.results.items():
            print(f"\n{dataset_name}:")
            for method in methods:
                if method in results:
                    result = results[method]
                    print(f"{'':<25} {method:<15} {result['final_loss']:<12.4f} "
                          f"{result['final_accuracy']:<12.4f} {result['training_time']:<10.2f}")

        print(f"\n{'='*100}")
        print("OVERALL RANKING BY ACCURACY (OGB-MolHIV Simple)")
        print(f"{'='*100}")

        method_scores = {}
        for method in methods:
            scores = [results[method]['final_accuracy']
                      for results in self.results.values() if method in results]
            if scores:
                method_scores[method] = np.mean(scores)

        sorted_methods = sorted(method_scores.items(), key=lambda x: x[1], reverse=True)

        print(f"\n{'Rank':<5} {'Method':<15} {'Avg Accuracy':<15} {'Performance':<20}")
        print("-" * 70)

        for rank, (method, score) in enumerate(sorted_methods, 1):
            performance = self._RANK_LABELS.get(rank, "📊 Other")
            print(f"{rank:<5} {method:<15} {score:<15.4f} {performance:<20}")

        print(f"\n{'='*100}")
        print("GraGR PERFORMANCE ANALYSIS (OGB-MolHIV Simple)")
        print(f"{'='*100}")

        gragr_core_scores = [results['GraGR Core']['final_accuracy']
                             for results in self.results.values() if 'GraGR Core' in results]
        gragr_plus_scores = [results['GraGR++']['final_accuracy']
                             for results in self.results.values() if 'GraGR++' in results]

        if gragr_core_scores:
            print(f"GraGR Core Average Accuracy: {np.mean(gragr_core_scores):.4f}")
        if gragr_plus_scores:
            print(f"GraGR++ Average Accuracy: {np.mean(gragr_plus_scores):.4f}")

        gragr_core_wins = 0
        gragr_plus_wins = 0
        total_datasets = len(self.results)

        for results in self.results.values():
            if not results:
                continue  # nothing recorded -> no winner (previously IndexError)
            # max() returns the first maximal entry, matching a stable descending sort.
            winner = max(results.items(), key=lambda item: item[1]['final_accuracy'])[0]
            if winner == 'GraGR Core':
                gragr_core_wins += 1
            if winner == 'GraGR++':
                gragr_plus_wins += 1

        print(f"GraGR Core wins: {gragr_core_wins}/{total_datasets} datasets")
        print(f"GraGR++ wins: {gragr_plus_wins}/{total_datasets} datasets")

        if gragr_core_wins > 0 or gragr_plus_wins > 0:
            print("✅ GraGR is performing well on OGB-MolHIV!")
        else:
            print("❌ GraGR needs improvement on OGB-MolHIV")

    def save_results_table(self):
        """Write ``results_table`` to ``simple_ogb_molhiv_results.csv`` and print a per-method summary."""
        df = pd.DataFrame(self.results_table)
        df.to_csv('simple_ogb_molhiv_results.csv', index=False)
        print(f"\n{'='*100}")
        print("Simple OGB-MolHIV Results SAVED TO: simple_ogb_molhiv_results.csv")
        print(f"{'='*100}")

        print("\nSUMMARY TABLE (OGB-MolHIV Simple):")
        if len(df) > 0:
            print(df.groupby('Method').agg({
                'Final_Accuracy': ['mean', 'std', 'min', 'max'],
                'Final_Loss': ['mean', 'std'],
                'Training_Time': ['mean', 'std']
            }).round(4))

def main():
    """Entry point: build a benchmark runner and execute the OGB-MolHIV sweep."""
    SimpleOGBBenchmarkRunner().run_simple_ogb_benchmark()

# Script entry point: run the benchmark only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
