"""
Complete GraGR Multi-Task Benchmark: All Methods Including PCGrad
==============================================================

This module implements a complete benchmark with ALL gradient conflict resolution methods:
1. Vanilla Average
2. CAGrad
3. GradNorm
4. PCGrad (restored)
5. Fixed GraGR Core
6. Fixed GraGR++

REAL DATASETS USED (each posed as 5 classification tasks: the dataset's original
graph label plus 4 labels derived from graph structure):
1. OGB-MolHIV: Molecular graphs for HIV-inhibition prediction (first 800 graphs)
2. TUDataset PROTEINS: Protein graphs (enzyme vs. non-enzyme)
3. TUDataset MUTAG: Molecular graphs (mutagenicity)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import time
import pandas as pd
from typing import Dict, List, Tuple, Optional
import warnings
warnings.filterwarnings('ignore')

# Import real dataset loaders
from ogb.graphproppred import PygGraphPropPredDataset
from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader, Data
from torch_geometric.utils import to_dense_batch
from torch_geometric.nn import global_mean_pool, global_max_pool, global_add_pool

class CompleteClassificationModel(nn.Module):
    """Shared-trunk MLP with one classification head per task.

    Node features are mean-pooled into one vector per graph, passed through a
    three-layer shared MLP, and then through each task's own two-layer head.

    Args:
        input_dim: Size of each node feature vector.
        hidden_dim: Width of the shared trunk (heads use ``hidden_dim // 2``).
        num_tasks: Number of task heads to build.
        num_classes_per_task: Output size of each head, indexed by task.
    """

    def __init__(self, input_dim: int, hidden_dim: int, num_tasks: int,
                 num_classes_per_task: List[int]):
        super().__init__()
        self.num_tasks = num_tasks
        self.num_classes_per_task = num_classes_per_task

        # Three Linear layers with ReLU; dropout after the first two.
        trunk_layers = [nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Dropout(0.1)]
        trunk_layers += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Dropout(0.1)]
        trunk_layers += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU()]
        self.shared = nn.Sequential(*trunk_layers)

        bottleneck = hidden_dim // 2
        self.task_heads = nn.ModuleList(
            nn.Sequential(
                nn.Linear(hidden_dim, bottleneck),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(bottleneck, self.num_classes_per_task[task]),
            )
            for task in range(num_tasks)
        )

    def forward(self, x, batch):
        """Return a list with one logits tensor per task."""
        pooled = global_mean_pool(x, batch)  # one row per graph id in `batch`
        trunk_out = self.shared(pooled)
        return [task_head(trunk_out) for task_head in self.task_heads]

class GradientConflictResolver:
    """Common interface for strategies that merge per-task gradients.

    A subclass overrides :meth:`resolve_conflicts`, which receives one gradient
    tensor and one loss per task and returns a single combined gradient.
    """

    def __init__(self, method_name: str):
        # Display name of the strategy, e.g. "Vanilla Average".
        self.method_name = method_name

    def resolve_conflicts(self, task_gradients: List[torch.Tensor],
                          task_losses: List[torch.Tensor]) -> torch.Tensor:
        """Combine the task gradients into one update (abstract)."""
        raise NotImplementedError()

class VanillaAverage(GradientConflictResolver):
    """Vanilla Average (GD): element-wise mean of the task gradients.

    Gradients shorter than the longest one are right-padded with zeros before
    averaging.
    """

    def __init__(self):
        super().__init__("Vanilla Average")

    def resolve_conflicts(self, task_gradients: List[torch.Tensor],
                          task_losses: List[torch.Tensor]) -> torch.Tensor:
        longest = max(g.size(0) for g in task_gradients)
        aligned = [
            g if g.size(0) >= longest
            else torch.cat([g, torch.zeros(longest - g.size(0), device=g.device)])
            for g in task_gradients
        ]
        return torch.stack(aligned).mean(dim=0)

class CAGrad(GradientConflictResolver):
    """Conflict-Averse Gradient Descent (CAGrad, Liu et al., NeurIPS 2021).

    CAGrad seeks the update ``d`` that maximises the worst task improvement
    ``min_i <g_i, d>`` while staying within distance ``rho * ||g0||`` of the
    average gradient ``g0``. Its dual is a small problem over simplex weights
    ``w``::

        min_w  <g_w, g0> + rho * ||g0|| * ||g_w||,     g_w = sum_i w_i g_i

    and the update is ``d = g0 + rho * ||g0|| / ||g_w|| * g_w``. The dual is
    solved here with projected gradient descent on the probability simplex.

    Args:
        rho: Trust-region coefficient (``c`` in the paper); ``0`` returns the
            plain average.
        rescale: Output scaling, following the authors' reference options:
            ``0`` -> ``d``; ``1`` -> ``d / (1 + rho**2)``; ``2`` -> ``d / (1 + rho)``.
        num_iters: Projected-gradient iterations used for the dual problem.
        step_size: Step size of those iterations (the problem is first
            normalised by the mean squared task-gradient norm).
    """

    def __init__(self, rho: float = 0.4, rescale: int = 1, num_iters: int = 100,
                 step_size: float = 0.1):
        super().__init__("CAGrad")
        self.rho = rho
        self.rescale = rescale
        self.num_iters = num_iters
        self.step_size = step_size

    @staticmethod
    def _project_to_simplex(v: torch.Tensor) -> torch.Tensor:
        """Euclidean projection of a 1-D tensor onto the probability simplex."""
        u, _ = torch.sort(v, descending=True)
        css = torch.cumsum(u, dim=0) - 1.0
        idx = torch.arange(1, v.numel() + 1, dtype=v.dtype, device=v.device)
        # Index 0 always satisfies the condition, so `support` is never empty.
        support = torch.nonzero(u - css / idx > 0).flatten()
        k = int(support[-1])
        theta = css[k] / (k + 1)
        return torch.clamp(v - theta, min=0.0)

    def resolve_conflicts(self, task_gradients: List[torch.Tensor],
                          task_losses: List[torch.Tensor]) -> torch.Tensor:
        """Return the CAGrad update for the given flattened task gradients.

        ``task_losses`` is accepted for interface compatibility and unused.
        """
        if len(task_gradients) < 2:
            return task_gradients[0]

        max_size = max(grad.size(0) for grad in task_gradients)
        padded_gradients = []
        for grad in task_gradients:
            if grad.size(0) < max_size:
                padding = torch.zeros(max_size - grad.size(0), device=grad.device)
                grad = torch.cat([grad, padding])
            padded_gradients.append(grad)

        grads = torch.stack(padded_gradients)       # (num_tasks, dim)
        num_tasks = grads.size(0)
        g_const = grads.detach()                    # the weights are constants
        gram = g_const @ g_const.t()                # gram[i, j] = <g_i, g_j>
        # For g0 = mean_i g_i, ||g0||^2 equals the mean of the Gram matrix.
        g0_norm = torch.sqrt(gram.mean() + 1e-8)

        # Normalise by the mean squared gradient norm so one fixed step size
        # works at any gradient scale. The objective is divided by `scale`,
        # so its minimiser w is unchanged.
        scale = torch.diagonal(gram).mean() + 1e-8
        gram_n = gram / scale
        c_n = self.rho * g0_norm / torch.sqrt(scale)

        w = torch.full((num_tasks,), 1.0 / num_tasks, dtype=g_const.dtype,
                       device=g_const.device)
        b_n = gram_n @ w                            # <g_i, g0> / scale
        for _ in range(self.num_iters):
            gram_w = gram_n @ w
            norm_gw = torch.sqrt(w @ gram_w + 1e-8)
            objective_grad = b_n + c_n * gram_w / norm_gw
            w = self._project_to_simplex(w - self.step_size * objective_grad)

        g_w = w @ grads
        lmbda = self.rho * g0_norm / (torch.norm(g_w.detach()) + 1e-8)
        d = grads.mean(dim=0) + lmbda * g_w

        if self.rescale == 1:
            return d / (1 + self.rho ** 2)
        if self.rescale == 2:
            return d / (1 + self.rho)
        return d

class GradNorm(GradientConflictResolver):
    """Simplified GradNorm-style weighting of task gradients.

    Each task gradient is rescaled towards the target norm
    ``mean_j ||g_j|| * (L_i / mean_j L_j) ** alpha``, so tasks whose current
    loss is above the cross-task mean contribute larger updates. The rescaled
    gradients are then averaged. This is a closed-form variant: there are no
    learned task weights, and the loss ratio uses the current mean loss rather
    than each task's initial loss.

    Args:
        alpha: Exponent applied to the loss ratio.
        eps: Guard added to denominators so that an all-zero task gradient or
            a zero mean loss yields finite weights instead of inf/NaN.
    """

    def __init__(self, alpha: float = 1.5, eps: float = 1e-8):
        super().__init__("GradNorm")
        self.alpha = alpha
        self.eps = eps

    def resolve_conflicts(self, task_gradients: List[torch.Tensor],
                          task_losses: List[torch.Tensor]) -> torch.Tensor:
        """Return the loss-ratio-weighted average of the zero-padded task gradients."""
        max_size = max(grad.size(0) for grad in task_gradients)
        padded_gradients = []
        for grad in task_gradients:
            if grad.size(0) < max_size:
                padding = torch.zeros(max_size - grad.size(0), device=grad.device)
                grad = torch.cat([grad, padding])
            padded_gradients.append(grad)

        grads = torch.stack(padded_gradients)               # (num_tasks, dim)
        # The weights are scalar multipliers; no gradient should flow through them.
        grad_norms = grads.detach().norm(dim=1)             # ||g_i||
        losses = torch.stack([loss.detach() for loss in task_losses]).to(grad_norms)

        loss_ratios = losses / (losses.mean() + self.eps)
        target_norms = grad_norms.mean() * loss_ratios ** self.alpha
        weights = target_norms / (grad_norms + self.eps)

        return (weights.unsqueeze(1) * grads).sum(dim=0) / grads.size(0)

class PCGrad(GradientConflictResolver):
    """Projecting Conflicting Gradients (PCGrad, Yu et al., NeurIPS 2020).

    Each task gradient ``g_i`` is processed separately. The other tasks are
    visited in a random order. Whenever the running ``g_i`` conflicts with
    ``g_j`` (``<g_i, g_j> < 0``), the component of ``g_i`` along the original
    ``g_j`` is removed: ``g_i <- g_i - <g_i, g_j> / ||g_j||^2 * g_j``. The
    projected task gradients are then averaged.

    Args:
        seed: Seed of a private generator for the visiting order, so the
            global torch RNG stream is left untouched.
    """

    def __init__(self, seed: int = 0):
        super().__init__("PCGrad")
        self._rng = torch.Generator().manual_seed(seed)

    def resolve_conflicts(self, task_gradients: List[torch.Tensor],
                          task_losses: List[torch.Tensor]) -> torch.Tensor:
        """Return the mean of the PCGrad-projected, zero-padded task gradients."""
        if len(task_gradients) < 2:
            return task_gradients[0]

        max_size = max(grad.size(0) for grad in task_gradients)
        padded_gradients = []
        for grad in task_gradients:
            if grad.size(0) < max_size:
                padding = torch.zeros(max_size - grad.size(0), device=grad.device)
                grad = torch.cat([grad, padding])
            padded_gradients.append(grad)

        num_tasks = len(padded_gradients)
        projected = []
        for i, g_i in enumerate(padded_gradients):
            surgery = g_i.clone()
            for j in torch.randperm(num_tasks, generator=self._rng).tolist():
                if j == i:
                    continue
                g_j = padded_gradients[j]
                dot = torch.sum(surgery * g_j)
                if dot < 0:  # conflict: drop the component along g_j
                    surgery = surgery - dot / (torch.sum(g_j * g_j) + 1e-12) * g_j
            projected.append(surgery)

        return torch.stack(projected).mean(dim=0)

class FixedGraGRCore(GradientConflictResolver):
    """Fixed GraGR Core: cosine-aware, inverse-norm weighting of task gradients.

    Each task starts from the weight ``1 / (||g_i|| + 1e-8)``. That weight is
    multiplied by:

    - a conflict factor, from the other tasks whose cosine similarity with
      ``g_i`` is below ``-conflict_threshold``;
    - an alignment factor, from the other tasks above ``+conflict_threshold``.

    The weights are normalised to sum to (almost) one and used to combine the
    zero-padded gradients.

    Args:
        conflict_threshold: Cosine margin separating conflict, neutral and
            alignment.
        learning_rate: Stored on the instance; ``resolve_conflicts`` does not
            read it.
    """

    def __init__(self, conflict_threshold: float = 0.1, learning_rate: float = 0.1):
        super().__init__("Fixed GraGR Core")
        self.conflict_threshold = conflict_threshold
        self.learning_rate = learning_rate
        self.iteration = 0  # incremented once per resolve_conflicts call

    def resolve_conflicts(self, task_gradients: List[torch.Tensor],
                          task_losses: List[torch.Tensor]) -> torch.Tensor:
        """Return the normalised, weighted sum of the zero-padded task gradients."""
        self.iteration += 1

        target_len = max(g.size(0) for g in task_gradients)
        vectors = [
            torch.cat([g, torch.zeros(target_len - g.size(0), device=g.device)])
            if g.size(0) < target_len
            else g
            for g in task_gradients
        ]

        # Identical for every task, so it cancels in the normalisation below
        # (apart from the 1e-8 guard).
        decay = 1.0 / (1.0 + 0.01 * self.iteration)
        threshold = self.conflict_threshold

        raw = []
        for i, vec in enumerate(vectors):
            inverse_norm = 1.0 / (torch.norm(vec) + 1e-8)

            conflict_total, conflict_count = 0, 0
            align_total, align_count = 0, 0
            for j, other in enumerate(vectors):
                if j == i:
                    continue
                # After padding all vectors share one length; no truncation needed.
                cos = F.cosine_similarity(vec.flatten(), other.flatten(), dim=0)
                if cos < -threshold:
                    conflict_total += abs(cos) * 1.5
                    conflict_count += 1
                elif cos > threshold:
                    align_total += cos * 0.3
                    align_count += 1

            conflict_factor = 1 + 1.5 * (conflict_total / conflict_count) if conflict_count else 1.0
            align_factor = 1 + 0.3 * (align_total / align_count) if align_count else 1.0
            raw.append(inverse_norm * conflict_factor * align_factor * decay)

        weights = torch.tensor(raw, device=task_gradients[0].device)
        weights = weights / (weights.sum() + 1e-8)

        combined = torch.zeros_like(vectors[0])
        for weight, vec in zip(weights, vectors):
            combined += weight * vec
        return combined

class FixedGraGRPlusPlus(GradientConflictResolver):
    """Fixed GraGR++: GraGR Core weighting plus a gradient-history factor.

    Like the Core variant, each task gets an inverse-norm base weight that is
    scaled by conflict and alignment factors (here with larger constants and a
    slower decay schedule). There is one extra per-task factor,
    ``1 + momentum * cos(g_i, g_i_prev)``, where ``g_i_prev`` is the task's
    gradient from the previous call. Weights are normalised and used to
    combine the zero-padded gradients.

    Args:
        conflict_threshold: Cosine margin separating conflict, neutral and
            alignment.
        learning_rate: Stored on the instance; ``resolve_conflicts`` does not
            read it.
        momentum: Scale of the agreement-with-previous-step factor.
    """

    def __init__(self, conflict_threshold: float = 0.1, learning_rate: float = 0.1,
                 momentum: float = 0.9):
        super().__init__("Fixed GraGR++")
        self.conflict_threshold = conflict_threshold
        self.learning_rate = learning_rate
        self.momentum = momentum
        self.iteration = 0  # incremented once per resolve_conflicts call
        # Detached, padded task gradients from the most recent calls (at most 3).
        self.gradient_memory = []

    def resolve_conflicts(self, task_gradients: List[torch.Tensor],
                          task_losses: List[torch.Tensor]) -> torch.Tensor:
        """Return the normalised, weighted sum of the zero-padded task gradients."""
        self.iteration += 1

        # Ensure all gradients have the same size
        max_size = max(grad.size(0) for grad in task_gradients)
        padded_gradients = []
        for grad in task_gradients:
            if grad.size(0) < max_size:
                padding = torch.zeros(max_size - grad.size(0), device=grad.device)
                padded_grad = torch.cat([grad, padding])
            else:
                padded_grad = grad
            padded_gradients.append(padded_grad)

        # Bounded history for the momentum factor. Store detached tensors: the
        # incoming gradients may carry an autograd graph (e.g. when produced
        # with create_graph=True), and holding them would keep those graphs alive.
        self.gradient_memory.append([g.detach() for g in padded_gradients])
        if len(self.gradient_memory) > 3:
            self.gradient_memory.pop(0)

        # Identical for every task, so it cancels in the normalisation below
        # (apart from the 1e-8 guard).
        lr_decay = 1.0 / (1.0 + 0.005 * self.iteration)

        weights = []
        for i, grad in enumerate(padded_gradients):
            # Base weight: inverse of gradient norm
            grad_norm = torch.norm(grad) + 1e-8
            base_weight = 1.0 / grad_norm
            grad_flat = grad.flatten()

            # Conflict and alignment analysis against every other task
            conflict_score = 0
            alignment_score = 0
            num_conflicts = 0
            num_alignments = 0
            for j, other_grad in enumerate(padded_gradients):
                if i == j:
                    continue
                other_flat = other_grad.flatten()
                min_size = min(grad_flat.size(0), other_flat.size(0))
                cos_sim = F.cosine_similarity(grad_flat[:min_size], other_flat[:min_size], dim=0)

                if cos_sim < -self.conflict_threshold:  # Strong conflict
                    conflict_score += abs(cos_sim) * 2.0
                    num_conflicts += 1
                elif cos_sim > self.conflict_threshold:  # Strong alignment
                    alignment_score += cos_sim * 0.5
                    num_alignments += 1

            if num_conflicts > 0:
                conflict_factor = 1 + 2.0 * (conflict_score / num_conflicts)
            else:
                conflict_factor = 1.0

            if num_alignments > 0:
                alignment_factor = 1 + 0.5 * (alignment_score / num_alignments)
            else:
                alignment_factor = 1.0

            # Agreement with the same task's gradient from the previous call.
            momentum_factor = 1.0
            if len(self.gradient_memory) > 1:
                prev_grads = self.gradient_memory[-2]
                if i < len(prev_grads):
                    prev_flat = prev_grads[i].flatten()
                    # Lengths can differ if the gradient size changed between calls.
                    min_size = min(grad_flat.size(0), prev_flat.size(0))
                    grad_similarity = F.cosine_similarity(
                        grad_flat[:min_size], prev_flat[:min_size], dim=0)
                    momentum_factor = 1 + self.momentum * grad_similarity

            adjusted_weight = base_weight * conflict_factor * alignment_factor * lr_decay * momentum_factor
            # Plain Python float: the weight is a constant scalar from here on.
            weights.append(float(adjusted_weight))

        # Normalize weights with stability
        weights = torch.tensor(weights, device=task_gradients[0].device)
        weights = weights / (weights.sum() + 1e-8)

        # Apply weights to gradients
        resolved_grad = torch.zeros_like(padded_gradients[0])
        for i, grad in enumerate(padded_gradients):
            resolved_grad += weights[i] * grad

        return resolved_grad

def load_ogb_molhiv_classification(seed: int = 0):
    """Load OGB-MolHIV (first 800 graphs) as a 5-task classification problem.

    Task 1 is the original HIV-inhibition label. Tasks 2-5 are 3-class labels
    derived from graph structure: node count, edge density, number of
    distinct values in node-feature column 0, and cycle rank. Graphs are
    shuffled with ``seed`` and split 60/20/20 into train/val/test.

    Args:
        seed: Seed for the shuffle that precedes the train/val/test split.

    Returns:
        ``(train_loader, val_loader, test_loader, dataset_info)``.
    """
    print("Loading OGB-MolHIV with classification-only multi-task setup...")

    # Load OGB-MolHIV dataset
    dataset = PygGraphPropPredDataset(name='ogbg-molhiv', root='../dataset/')
    print(f"✓ OGB-MolHIV loaded: {len(dataset)} graphs")

    # Create multi-task data (use first 800 samples)
    multi_task_data = []
    for i, data in enumerate(dataset):
        if i >= 800:
            break

        # Task 1: Original HIV inhibition prediction (binary classification)
        hiv_inhibition = data.y.squeeze().item()

        # Task 2: Graph size classification (small, medium, large)
        num_nodes = data.x.size(0)
        if num_nodes < 10:
            graph_size = 0  # Small
        elif num_nodes < 20:
            graph_size = 1  # Medium
        else:
            graph_size = 2  # Large

        # edge_index may list each bond in both directions (i->j and j->i).
        # Count each unordered pair of distinct atoms once, so that density
        # and cycle rank refer to bonds.
        pairs = torch.sort(data.edge_index, dim=0).values
        pairs = pairs[:, pairs[0] != pairs[1]]
        num_edges = torch.unique(pairs, dim=1).size(1) if pairs.size(1) > 0 else 0

        # Task 3: Graph density classification (sparse, medium, dense)
        max_edges = num_nodes * (num_nodes - 1) // 2
        density = num_edges / max_edges if max_edges > 0 else 0

        if density < 0.1:
            graph_density = 0  # Sparse
        elif density < 0.3:
            graph_density = 1  # Medium
        else:
            graph_density = 2  # Dense

        # Task 4: Molecular complexity classification (simple, medium, complex)
        # Distinct values of node-feature column 0 (presumably the atom type index).
        atom_types = len(torch.unique(data.x[:, 0]))
        if atom_types < 3:
            complexity = 0  # Simple
        elif atom_types < 5:
            complexity = 1  # Medium
        else:
            complexity = 2  # Complex

        # Task 5: Ring structure classification (no rings, few rings, many rings)
        # Cycle rank E - V + 1: exact for connected graphs, undercounts when the
        # graph has several components.
        ring_count = max(0, num_edges - num_nodes + 1)
        if ring_count == 0:
            ring_structure = 0  # No rings
        elif ring_count < 3:
            ring_structure = 1  # Few rings
        else:
            ring_structure = 2  # Many rings

        # Create multi-task labels (all classification)
        multi_task_labels = torch.tensor([
            hiv_inhibition,  # Task 1: HIV inhibition (2 classes)
            graph_size,  # Task 2: Graph size (3 classes)
            graph_density,  # Task 3: Graph density (3 classes)
            complexity,  # Task 4: Molecular complexity (3 classes)
            ring_structure  # Task 5: Ring structure (3 classes)
        ])

        multi_task_data.append(Data(
            x=data.x.float(),
            edge_index=data.edge_index,
            edge_attr=data.edge_attr.float(),
            y=multi_task_labels,
            pos=data.pos if hasattr(data, 'pos') else None
        ))

    # Seeded shuffle so the split does not depend on on-disk ordering.
    generator = torch.Generator().manual_seed(seed)
    order = torch.randperm(len(multi_task_data), generator=generator).tolist()
    multi_task_data = [multi_task_data[k] for k in order]

    # Split dataset
    train_size = int(0.6 * len(multi_task_data))
    val_size = int(0.2 * len(multi_task_data))

    train_data = multi_task_data[:train_size]
    val_data = multi_task_data[train_size:train_size + val_size]
    test_data = multi_task_data[train_size + val_size:]

    # Create data loaders
    train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=16, shuffle=False)
    test_loader = DataLoader(test_data, batch_size=16, shuffle=False)

    # Dataset info
    dataset_info = {
        'name': 'OGB_MolHIV_Classification',
        'num_tasks': 5,
        'num_classes_per_task': [2, 3, 3, 3, 3],
        'num_graphs': len(multi_task_data),
        'train_size': len(train_data),
        'val_size': len(val_data),
        'test_size': len(test_data),
        'task_names': [
            'hiv_inhibition', 'graph_size', 'graph_density', 'molecular_complexity', 'ring_structure'
        ]
    }

    print(f"✓ OGB-MolHIV Classification loaded: {len(multi_task_data)} graphs, {dataset_info['num_tasks']} tasks")
    return train_loader, val_loader, test_loader, dataset_info

def load_tudataset_proteins(seed: int = 0):
    """Load TUDataset PROTEINS as a 5-task graph classification problem.

    Task 1 is the original protein label. Tasks 2-5 are 3-class labels
    derived from graph structure: node count, edge density, average degree,
    and edge-to-node ratio. Graphs are shuffled with ``seed`` and split
    60/20/20 into train/val/test.

    Args:
        seed: Seed for the shuffle that precedes the train/val/test split.

    Returns:
        ``(train_loader, val_loader, test_loader, dataset_info)``.
    """
    print("Loading TUDataset PROTEINS with multi-task classification setup...")

    # Load TUDataset PROTEINS
    dataset = TUDataset(root='./data', name='PROTEINS')
    print(f"✓ TUDataset PROTEINS loaded: {len(dataset)} graphs")

    # Create multi-task data
    multi_task_data = []
    for data in dataset:
        # Task 1: Original protein classification (2 classes: enzyme/non-enzyme)
        protein_class = data.y.squeeze().item()

        # Task 2: Protein size classification (small, medium, large)
        num_nodes = data.x.size(0)
        if num_nodes < 20:
            size_class = 0  # Small
        elif num_nodes < 50:
            size_class = 1  # Medium
        else:
            size_class = 2  # Large

        # edge_index may list each undirected edge in both directions; count
        # each unordered pair of distinct nodes once.
        pairs = torch.sort(data.edge_index, dim=0).values
        pairs = pairs[:, pairs[0] != pairs[1]]
        num_edges = torch.unique(pairs, dim=1).size(1) if pairs.size(1) > 0 else 0

        # Task 3: Protein density classification (sparse, medium, dense)
        max_edges = num_nodes * (num_nodes - 1) // 2
        density = num_edges / max_edges if max_edges > 0 else 0

        if density < 0.1:
            density_class = 0  # Sparse
        elif density < 0.3:
            density_class = 1  # Medium
        else:
            density_class = 2  # Dense

        # Task 4: Protein connectivity classification (low, medium, high)
        avg_degree = (2 * num_edges) / num_nodes if num_nodes > 0 else 0
        if avg_degree < 2:
            connectivity_class = 0  # Low
        elif avg_degree < 4:
            connectivity_class = 1  # Medium
        else:
            connectivity_class = 2  # High

        # Task 5: Protein structure classification (linear, branched, complex)
        if num_edges < num_nodes:
            structure_class = 0  # Linear
        elif num_edges < 2 * num_nodes:
            structure_class = 1  # Branched
        else:
            structure_class = 2  # Complex

        # Create multi-task labels (all classification)
        multi_task_labels = torch.tensor([
            protein_class,  # Task 1: Protein class (2 classes)
            size_class,  # Task 2: Size class (3 classes)
            density_class,  # Task 3: Density class (3 classes)
            connectivity_class,  # Task 4: Connectivity class (3 classes)
            structure_class  # Task 5: Structure class (3 classes)
        ])

        multi_task_data.append(Data(
            x=data.x.float(),
            edge_index=data.edge_index,
            y=multi_task_labels,
            pos=data.pos if hasattr(data, 'pos') else None
        ))

    # Seeded shuffle so the split does not depend on on-disk ordering
    # (TUDataset files can be grouped by label).
    generator = torch.Generator().manual_seed(seed)
    order = torch.randperm(len(multi_task_data), generator=generator).tolist()
    multi_task_data = [multi_task_data[k] for k in order]

    # Split dataset
    train_size = int(0.6 * len(multi_task_data))
    val_size = int(0.2 * len(multi_task_data))

    train_data = multi_task_data[:train_size]
    val_data = multi_task_data[train_size:train_size + val_size]
    test_data = multi_task_data[train_size + val_size:]

    # Create data loaders
    train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=16, shuffle=False)
    test_loader = DataLoader(test_data, batch_size=16, shuffle=False)

    # Dataset info
    dataset_info = {
        'name': 'TUDataset_PROTEINS_Classification',
        'num_tasks': 5,
        'num_classes_per_task': [2, 3, 3, 3, 3],
        'num_graphs': len(multi_task_data),
        'train_size': len(train_data),
        'val_size': len(val_data),
        'test_size': len(test_data),
        'task_names': [
            'protein_class', 'size_class', 'density_class', 'connectivity_class', 'structure_class'
        ]
    }

    print(f"✓ TUDataset PROTEINS Classification loaded: {len(multi_task_data)} graphs, {dataset_info['num_tasks']} tasks")
    return train_loader, val_loader, test_loader, dataset_info

def load_tudataset_mutag(seed: int = 0):
    """Load TUDataset MUTAG as a 5-task graph classification problem.

    Task 1 is the original mutagenicity label. Tasks 2-5 are 3-class labels
    derived from graph structure: node count, bond count, number of distinct
    bond types, and cycle rank. Graphs are shuffled with ``seed`` and split
    60/20/20 into train/val/test.

    Args:
        seed: Seed for the shuffle that precedes the train/val/test split.

    Returns:
        ``(train_loader, val_loader, test_loader, dataset_info)``.
    """
    print("Loading TUDataset MUTAG with multi-task classification setup...")

    # Load TUDataset MUTAG
    dataset = TUDataset(root='./data', name='MUTAG')
    print(f"✓ TUDataset MUTAG loaded: {len(dataset)} graphs")

    # Create multi-task data
    multi_task_data = []
    for data in dataset:
        # Task 1: Original mutagenicity classification (2 classes: mutagenic/non-mutagenic)
        mutagenicity = data.y.squeeze().item()

        # Task 2: Molecular size classification (small, medium, large)
        num_nodes = data.x.size(0)
        if num_nodes < 10:
            size_class = 0  # Small
        elif num_nodes < 20:
            size_class = 1  # Medium
        else:
            size_class = 2  # Large

        # edge_index may list each bond in both directions; count each
        # unordered pair of distinct atoms once.
        pairs = torch.sort(data.edge_index, dim=0).values
        pairs = pairs[:, pairs[0] != pairs[1]]
        num_edges = torch.unique(pairs, dim=1).size(1) if pairs.size(1) > 0 else 0

        # Task 3: Molecular complexity classification (simple, medium, complex)
        if num_edges < 10:
            complexity_class = 0  # Simple
        elif num_edges < 20:
            complexity_class = 1  # Medium
        else:
            complexity_class = 2  # Complex

        # Task 4: Bond type classification (single, mixed, multiple)
        # For a multi-column (e.g. one-hot) edge_attr, the bond type is the
        # argmax column. Column 0 alone of a one-hot encoding takes at most two
        # values (0/1), which would make class 2 unreachable.
        edge_attr = getattr(data, 'edge_attr', None)
        if edge_attr is not None and edge_attr.numel() > 0:
            if edge_attr.dim() > 1 and edge_attr.size(1) > 1:
                bond_types = edge_attr.argmax(dim=1)
            else:
                bond_types = edge_attr.reshape(-1)
            unique_bonds = len(torch.unique(bond_types))
        else:
            unique_bonds = 1

        if unique_bonds == 1:
            bond_class = 0  # Single
        elif unique_bonds < 3:
            bond_class = 1  # Mixed
        else:
            bond_class = 2  # Multiple

        # Task 5: Ring structure classification (no rings, few rings, many rings)
        # Cycle rank E - V + 1: exact for connected graphs, undercounts when the
        # graph has several components.
        ring_count = max(0, num_edges - num_nodes + 1)
        if ring_count == 0:
            ring_class = 0  # No rings
        elif ring_count < 2:
            ring_class = 1  # Few rings
        else:
            ring_class = 2  # Many rings

        # Create multi-task labels (all classification)
        multi_task_labels = torch.tensor([
            mutagenicity,  # Task 1: Mutagenicity (2 classes)
            size_class,  # Task 2: Size class (3 classes)
            complexity_class,  # Task 3: Complexity class (3 classes)
            bond_class,  # Task 4: Bond class (3 classes)
            ring_class  # Task 5: Ring class (3 classes)
        ])

        multi_task_data.append(Data(
            x=data.x.float(),
            edge_index=data.edge_index,
            y=multi_task_labels,
            pos=data.pos if hasattr(data, 'pos') else None
        ))

    # Seeded shuffle so the split does not depend on on-disk ordering
    # (TUDataset files can be grouped by label).
    generator = torch.Generator().manual_seed(seed)
    order = torch.randperm(len(multi_task_data), generator=generator).tolist()
    multi_task_data = [multi_task_data[k] for k in order]

    # Split dataset
    train_size = int(0.6 * len(multi_task_data))
    val_size = int(0.2 * len(multi_task_data))

    train_data = multi_task_data[:train_size]
    val_data = multi_task_data[train_size:train_size + val_size]
    test_data = multi_task_data[train_size + val_size:]

    # Create data loaders
    train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=16, shuffle=False)
    test_loader = DataLoader(test_data, batch_size=16, shuffle=False)

    # Dataset info
    dataset_info = {
        'name': 'TUDataset_MUTAG_Classification',
        'num_tasks': 5,
        'num_classes_per_task': [2, 3, 3, 3, 3],
        'num_graphs': len(multi_task_data),
        'train_size': len(train_data),
        'val_size': len(val_data),
        'test_size': len(test_data),
        'task_names': [
            'mutagenicity', 'size_class', 'complexity_class', 'bond_class', 'ring_class'
        ]
    }

    print(f"✓ TUDataset MUTAG Classification loaded: {len(multi_task_data)} graphs, {dataset_info['num_tasks']} tasks")
    return train_loader, val_loader, test_loader, dataset_info

def train_model_complete(model, data_loader, resolver, dataset_info, num_epochs=40):
    """Train ``model`` with per-task gradients merged by ``resolver``.

    For every batch:

    1. Each task's cross-entropy gradient with respect to *all* model
       parameters is computed separately.
    2. Each gradient is flattened in ``model.parameters()`` order.
    3. The gradients are combined by ``resolver.resolve_conflicts``.
    4. The result is written back into ``param.grad`` before an Adam step.

    Args:
        model: Module whose ``forward(x, batch)`` returns one logits tensor per task.
        data_loader: Iterable of batches exposing ``x``, ``batch`` and ``y``.
            ``y`` holds ``num_tasks`` labels per graph, concatenated graph by
            graph. ``num_graphs`` is used when present.
        resolver: Object with ``resolve_conflicts(task_gradients, task_losses)``.
        dataset_info: Dict; only ``dataset_info['num_tasks']`` is read.
        num_epochs: Number of passes over ``data_loader``.

    Returns:
        ``(losses, accuracies)``: per-epoch means of the summed task loss and
        of the task-averaged accuracy.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)
    params = list(model.parameters())
    num_tasks = dataset_info['num_tasks']

    losses = []
    accuracies = []
    model.train()

    for epoch in range(num_epochs):
        epoch_loss = 0
        epoch_acc = 0
        num_batches = 0

        for batch in data_loader:
            optimizer.zero_grad()

            # Forward pass
            predictions = model(batch.x, batch.batch)

            # Number of graphs in the batch. PyG batches expose it directly;
            # counting unique batch ids would miss graphs without nodes.
            num_graphs = getattr(batch, 'num_graphs', None)
            if num_graphs is None:
                num_graphs = int(batch.batch.max()) + 1
            # Reshape y from [num_graphs * num_tasks] to [num_graphs, num_tasks]
            y_reshaped = batch.y.view(num_graphs, num_tasks)

            task_losses = []
            task_gradients = []
            task_accuracies = []
            for i in range(num_tasks):
                # All tasks are classification
                task_labels = y_reshaped[:, i].long()
                loss = F.cross_entropy(predictions[i], task_labels)
                acc = (predictions[i].argmax(dim=1) == task_labels).float().mean()

                task_losses.append(loss)
                task_accuracies.append(acc.item())

                # First-order gradients are enough. retain_graph is needed
                # because every task backpropagates through the same forward pass.
                grads = torch.autograd.grad(loss, params, retain_graph=True, allow_unused=True)
                # Parameters a task does not touch (e.g. other tasks' heads)
                # come back as None. Use zeros so every task vector has the full
                # length and the same per-parameter layout as `params`.
                task_gradients.append(torch.cat([
                    (g if g is not None else torch.zeros_like(p)).flatten()
                    for g, p in zip(grads, params)
                ]))

            # Resolve gradient conflicts
            resolved_grad = resolver.resolve_conflicts(task_gradients, task_losses)

            # Scatter the flat resolved gradient back onto the parameters
            param_idx = 0
            for param in params:
                param_size = param.numel()
                if param_idx + param_size <= resolved_grad.size(0):
                    param.grad = resolved_grad[param_idx:param_idx + param_size].view_as(param).detach()
                else:
                    param.grad = torch.zeros_like(param)
                param_idx += param_size

            optimizer.step()

            epoch_loss += sum(task_losses).item()
            epoch_acc += np.mean(task_accuracies)
            num_batches += 1

        losses.append(epoch_loss / num_batches)
        accuracies.append(epoch_acc / num_batches)

        scheduler.step()

        if epoch % 5 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}: Loss={losses[-1]:.4f}, Acc={accuracies[-1]:.4f}")

    return losses, accuracies

class CompleteBenchmarkRunner:
    """Runner for the complete GraGR benchmark with all methods.

    For every dataset loader, a fresh ``CompleteClassificationModel`` is trained
    once per gradient-conflict resolution method. The runner records the
    final-epoch loss, the final-epoch accuracy and the wall-clock training time.

    NOTE(review): ``final_accuracy`` is the last entry of the accuracy list
    returned by ``train_model_complete``, which is computed on the training
    loader. The validation/test loaders returned by the dataset loaders are not
    evaluated here.

    Attributes:
        results: ``{dataset_name: {method_name: result_dict}}``. A dataset is
            present only if at least one method finished training on it.
        results_table: flat list of per-(dataset, method) rows, written to CSV
            by ``save_results_table``.
    """

    # Canonical display/ranking order of the benchmarked methods.
    METHOD_NAMES = (
        'Vanilla Average',
        'CAGrad',
        'GradNorm',
        'PCGrad',
        'Fixed GraGR Core',
        'Fixed GraGR++',
    )

    def __init__(self, hidden_dim: int = 64, num_epochs: int = 40,
                 output_csv: str = 'complete_gragr_benchmark_results.csv',
                 seed: Optional[int] = None):
        """Create a runner.

        Args:
            hidden_dim: hidden width of every benchmarked model.
            num_epochs: training epochs per (dataset, method) pair; must be >= 1
                because the last epoch's loss/accuracy is reported.
            output_csv: path the flat results table is written to.
            seed: if not None, ``torch.manual_seed(seed)`` is called before each
                model is built, so every method starts from the same weights.

        Raises:
            ValueError: if ``num_epochs`` < 1.
        """
        if num_epochs < 1:
            raise ValueError(f"num_epochs must be >= 1, got {num_epochs}")
        self.hidden_dim = hidden_dim
        self.num_epochs = num_epochs
        self.output_csv = output_csv
        self.seed = seed
        self.results = {}
        self.results_table = []

    @staticmethod
    def _build_resolvers() -> Dict[str, object]:
        """Return fresh resolver instances keyed by method name (in METHOD_NAMES order).

        Called once per dataset so that any state a resolver may keep between
        calls cannot carry over from one dataset to the next.
        """
        return {
            'Vanilla Average': VanillaAverage(),
            'CAGrad': CAGrad(),
            'GradNorm': GradNorm(),
            'PCGrad': PCGrad(),
            'Fixed GraGR Core': FixedGraGRCore(),
            'Fixed GraGR++': FixedGraGRPlusPlus()
        }

    @staticmethod
    def _prepare_dataset(load_func):
        """Load a dataset, print its description and infer the node-feature dimension.

        Returns:
            ``(train_loader, dataset_info, input_dim)``.

        Raises:
            ValueError: if the training loader yields no batches (``input_dim``
                could not be inferred).
        """
        # val/test loaders are returned by the loader but not used by this benchmark.
        train_loader, _val_loader, _test_loader, dataset_info = load_func()

        print(f"Dataset: {dataset_info['name']}")
        print(f"Tasks: {dataset_info['num_tasks']}")
        print(f"Task Names: {dataset_info['task_names']}")
        print(f"Classes per Task: {dataset_info['num_classes_per_task']}")
        print(f"Train Size: {dataset_info['train_size']}")
        print(f"Val Size: {dataset_info['val_size']}")
        print(f"Test Size: {dataset_info['test_size']}")

        # Input dimension comes from the node features of the first batch.
        first_batch = next(iter(train_loader), None)
        if first_batch is None:
            raise ValueError("training loader is empty; cannot infer input dimension")
        input_dim = first_batch.x.size(1)

        print(f"Input Dimension: {input_dim}")
        return train_loader, dataset_info, input_dim

    def _train_one_method(self, train_loader, resolver, dataset_info, input_dim):
        """Train a fresh model with ``resolver`` and return its result record."""
        if self.seed is not None:
            # Same seed for every method -> identical initial weights (and the same
            # shuffle order too, if the loader draws from torch's global RNG).
            torch.manual_seed(self.seed)

        model = CompleteClassificationModel(
            input_dim=input_dim,
            hidden_dim=self.hidden_dim,
            num_tasks=dataset_info['num_tasks'],
            num_classes_per_task=dataset_info['num_classes_per_task']
        )

        start_time = time.perf_counter()
        losses, accuracies = train_model_complete(
            model, train_loader, resolver, dataset_info, num_epochs=self.num_epochs
        )
        training_time = time.perf_counter() - start_time

        return {
            'final_loss': losses[-1],
            'final_accuracy': accuracies[-1],
            'training_time': training_time,
            'losses': losses,
            'accuracies': accuracies
        }

    def run_complete_benchmark(self):
        """Run complete benchmark with all methods including PCGrad.

        If a dataset fails to load, the whole dataset is skipped. If one method
        fails to train, only that method is skipped, and the other methods'
        results for the same dataset are still kept. At the end, the summary is
        printed and the CSV table is written.
        """
        print("Complete GraGR Multi-Task Benchmark: All Methods Including PCGrad")
        print("=" * 100)
        print("Using REAL datasets with ALL gradient conflict resolution methods")
        print("=" * 100)

        # Classification datasets to benchmark
        datasets = {
            'OGB_MolHIV_Classification': load_ogb_molhiv_classification,
            'TUDataset_PROTEINS_Classification': load_tudataset_proteins,
            'TUDataset_MUTAG_Classification': load_tudataset_mutag
        }

        for dataset_name, load_func in datasets.items():
            print(f"\n{'='*80}")
            print(f"BENCHMARKING ON REAL DATASET: {dataset_name}")
            print(f"{'='*80}")

            try:
                train_loader, dataset_info, input_dim = self._prepare_dataset(load_func)
            except Exception as e:
                # Top-level benchmark boundary: report and move on to the next dataset.
                print(f"Error loading {dataset_name}: {e}")
                continue

            data_size = (dataset_info['train_size'] + dataset_info['val_size']
                         + dataset_info['test_size'])
            dataset_results = {}

            for method_name, resolver in self._build_resolvers().items():
                print(f"\n--- {method_name} ---")

                try:
                    result = self._train_one_method(
                        train_loader, resolver, dataset_info, input_dim
                    )
                except Exception as e:
                    # Keep the other methods' results for this dataset.
                    print(f"Error training {method_name} on {dataset_name}: {e}")
                    continue

                dataset_results[method_name] = result
                self.results_table.append({
                    'Dataset': dataset_name,
                    'Method': method_name,
                    'Final_Loss': result['final_loss'],
                    'Final_Accuracy': result['final_accuracy'],
                    'Training_Time': result['training_time'],
                    'Num_Tasks': dataset_info['num_tasks'],
                    'Data_Size': data_size
                })

                print(f"Final Loss: {result['final_loss']:.4f}")
                print(f"Final Accuracy: {result['final_accuracy']:.4f}")
                print(f"Training Time: {result['training_time']:.2f}s")

            # Only record datasets on which at least one method finished.
            if dataset_results:
                self.results[dataset_name] = dataset_results

        self.print_complete_summary()
        self.save_results_table()

    def print_complete_summary(self):
        """Print per-dataset results, the overall accuracy ranking and GraGR win counts.

        Safe to call with empty ``self.results`` or with datasets whose result
        dict is empty; such datasets are counted in the totals, but no winner
        is taken from them.
        """
        print(f"\n{'='*100}")
        print("Complete GraGR Multi-Task Benchmark Summary")
        print(f"{'='*100}")

        methods = self.METHOD_NAMES

        print(f"\n{'Dataset':<40} {'Method':<20} {'Loss':<10} {'Accuracy':<10} {'Time (s)':<10}")
        print("-" * 100)

        for dataset_name, results in self.results.items():
            print(f"\n{dataset_name}:")
            for method in methods:
                if method in results:
                    result = results[method]
                    print(f"{'':<40} {method:<20} {result['final_loss']:<10.4f} "
                          f"{result['final_accuracy']:<10.4f} {result['training_time']:<10.2f}")

        # Overall ranking: mean final accuracy over the datasets where the method has a result.
        print(f"\n{'='*100}")
        print("OVERALL RANKING BY ACCURACY (Complete with PCGrad)")
        print(f"{'='*100}")

        method_scores = {}
        for method in methods:
            scores = [results[method]['final_accuracy']
                      for results in self.results.values() if method in results]
            if scores:
                method_scores[method] = np.mean(scores)

        sorted_methods = sorted(method_scores.items(), key=lambda x: x[1], reverse=True)

        print(f"\n{'Rank':<5} {'Method':<20} {'Avg Accuracy':<15} {'Performance':<20}")
        print("-" * 70)

        rank_labels = {1: "🥇 Best", 2: "🥈 Second", 3: "🥉 Third"}
        for rank, (method, score) in enumerate(sorted_methods, 1):
            performance = rank_labels.get(rank, "📊 Other")
            print(f"{rank:<5} {method:<20} {score:<15.4f} {performance:<20}")

        # GraGR performance analysis
        print(f"\n{'='*100}")
        print("COMPLETE GraGR PERFORMANCE ANALYSIS")
        print(f"{'='*100}")

        gragr_core_scores = [results['Fixed GraGR Core']['final_accuracy']
                             for results in self.results.values()
                             if 'Fixed GraGR Core' in results]
        gragr_plus_scores = [results['Fixed GraGR++']['final_accuracy']
                             for results in self.results.values()
                             if 'Fixed GraGR++' in results]

        if gragr_core_scores:
            print(f"Fixed GraGR Core Average Accuracy: {np.mean(gragr_core_scores):.4f}")
        if gragr_plus_scores:
            print(f"Fixed GraGR++ Average Accuracy: {np.mean(gragr_plus_scores):.4f}")

        # Count wins (highest final accuracy on a dataset).
        gragr_core_wins = 0
        gragr_plus_wins = 0
        total_datasets = len(self.results)

        for results in self.results.values():
            if not results:
                continue  # no method finished on this dataset -> no winner
            # max() returns the first of equal scores, matching a stable descending sort.
            best_method = max(results.items(), key=lambda kv: kv[1]['final_accuracy'])[0]
            if best_method == 'Fixed GraGR Core':
                gragr_core_wins += 1
            elif best_method == 'Fixed GraGR++':
                gragr_plus_wins += 1

        print(f"Fixed GraGR Core wins: {gragr_core_wins}/{total_datasets} datasets")
        print(f"Fixed GraGR++ wins: {gragr_plus_wins}/{total_datasets} datasets")

        if gragr_core_wins > 0 or gragr_plus_wins > 0:
            print("✅ Complete GraGR is performing well!")
        else:
            print("❌ Complete GraGR needs further optimization")

    def save_results_table(self):
        """Save ``self.results_table`` to ``self.output_csv`` and print per-method aggregates."""
        df = pd.DataFrame(self.results_table)
        df.to_csv(self.output_csv, index=False)
        print(f"\n{'='*100}")
        print(f"Complete GraGR Benchmark Results SAVED TO: {self.output_csv}")
        print(f"{'='*100}")

        # Print summary table
        print("\nSUMMARY TABLE (Complete with PCGrad):")
        if not df.empty:
            print(df.groupby('Method').agg({
                'Final_Accuracy': ['mean', 'std', 'min', 'max'],
                'Final_Loss': ['mean', 'std'],
                'Training_Time': ['mean', 'std']
            }).round(4))

def main():
    """Script entry point: build a benchmark runner and execute every dataset/method pair."""
    CompleteBenchmarkRunner().run_complete_benchmark()


if __name__ == "__main__":
    # Run the (long) benchmark only when executed as a script, never on import.
    main()
