"""
Multi-Task Benchmark for GraGR vs Gradient Conflict Resolution Methods
====================================================================

This module implements a fair benchmark comparing GraGR against other gradient
conflict resolution methods on genuine multi-task graph datasets.

Methods compared:
1. Vanilla Average (GD)
2. MGDA (Multiple Gradient Descent Algorithm)
3. PCGrad (Projecting Conflicting Gradients)
4. CAGrad (Conflict-Averse Gradient Descent)
5. GraGR Core
6. GraGR++

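Datasets:
1. QM9: 11 molecular properties (regression)
2. TUDataset (ENZYMES): Multiple graph classification tasks
3. OGB-MolHIV: Multi-task molecular dataset built by the dataset loader
4. MedMNIST: Multiple medical tasks (not yet wired into ``run_benchmark``,
   which skips datasets it has no loader for)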
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, GINConv, SAGEConv
from torch_geometric.data import Data, DataLoader
from torch_geometric.utils import degree
import numpy as np
import time
from typing import Dict, List, Tuple, Optional
import warnings
warnings.filterwarnings('ignore')

from multi_task_datasets import MultiTaskDatasetLoader
from gragr_complete import GraGRCore, GraGRPlusPlus

class MultiTaskGNN(nn.Module):
    """Multi-task GNN: a shared three-layer message-passing trunk plus one
    linear head per task.

    Args:
        input_dim: Number of input node features.
        hidden_dim: Width of every shared layer. With the ``'GAT'`` backbone it
            must be divisible by 8, because each layer concatenates 8 attention
            heads of ``hidden_dim // 8`` channels.
        num_tasks: Number of task heads.
        task_types: Per-task type. ``'classification'`` builds a
            ``num_classes``-way head; any other value builds a scalar
            regression head. Only the first ``num_tasks`` entries are used.
        backbone: One of ``'GCN'``, ``'GAT'``, ``'GIN'`` or ``'SAGE'``.
        num_classes: Output width of every classification head. Defaults to 6,
            the value previously hard-coded for ENZYMES.

    Raises:
        ValueError: If ``backbone`` is unknown, ``task_types`` has fewer than
            ``num_tasks`` entries, or ``hidden_dim`` is not divisible by 8 for
            the GAT backbone.
    """

    _BACKBONES = ('GCN', 'GAT', 'GIN', 'SAGE')
    _GAT_HEADS = 8

    def __init__(self, input_dim: int, hidden_dim: int, num_tasks: int,
                 task_types: List[str], backbone: str = 'GCN',
                 num_classes: int = 6):
        super().__init__()
        # Validate before building any layer, so a bad config fails here
        # instead of later as a missing ``conv1`` attribute in forward().
        if backbone not in self._BACKBONES:
            raise ValueError(
                f"Unknown backbone {backbone!r}; expected one of {self._BACKBONES}")
        if len(task_types) < num_tasks:
            raise ValueError(
                f"task_types has {len(task_types)} entries but num_tasks={num_tasks}")
        if backbone == 'GAT' and hidden_dim % self._GAT_HEADS != 0:
            raise ValueError(
                f"hidden_dim={hidden_dim} must be divisible by {self._GAT_HEADS} "
                f"for the GAT backbone")

        self.num_tasks = num_tasks
        self.task_types = task_types
        self.backbone = backbone

        # Shared GNN layers, created in a fixed order so parameter order and
        # initialisation order are stable.
        self.conv1 = self._make_conv(input_dim, hidden_dim)
        self.conv2 = self._make_conv(hidden_dim, hidden_dim)
        self.conv3 = self._make_conv(hidden_dim, hidden_dim)

        # Task-specific heads: num_classes logits for classification, and one
        # scalar for anything else (regression).
        self.task_heads = nn.ModuleList(
            nn.Linear(hidden_dim,
                      num_classes if task_types[i] == 'classification' else 1)
            for i in range(num_tasks)
        )

        self.dropout = nn.Dropout(0.1)

    def _make_conv(self, in_dim: int, out_dim: int) -> nn.Module:
        """Build one shared message-passing layer for ``self.backbone``."""
        if self.backbone == 'GCN':
            return GCNConv(in_dim, out_dim)
        if self.backbone == 'GAT':
            # The heads' outputs are concatenated back to out_dim channels.
            return GATConv(in_dim, out_dim // self._GAT_HEADS,
                           heads=self._GAT_HEADS, dropout=0.1)
        if self.backbone == 'GIN':
            return GINConv(nn.Sequential(
                nn.Linear(in_dim, out_dim),
                nn.ReLU(),
                nn.Linear(out_dim, out_dim),
            ))
        if self.backbone == 'SAGE':
            return SAGEConv(in_dim, out_dim)
        raise ValueError(f"Unknown backbone {self.backbone!r}")

    def forward(self, x, edge_index, batch=None):
        """Return a list with one prediction tensor per task.

        Args:
            x: Node feature matrix.
            edge_index: Graph connectivity, passed unchanged to each conv layer.
            batch: Optional node-to-graph assignment vector. When given, node
                embeddings are mean-pooled per graph and every head predicts
                per graph. When ``None``, heads predict per node.
        """
        h = x
        for conv in (self.conv1, self.conv2, self.conv3):
            h = self.dropout(F.relu(conv(h, edge_index)))

        if batch is not None:
            # Mean pooling: sum node embeddings per graph, then divide by the
            # node count. Allocate with h's dtype/device so a double-precision
            # model does not hit a dtype mismatch in scatter_add.
            num_graphs = int(batch.max().item()) + 1
            summed = h.new_zeros(num_graphs, h.size(1)).scatter_add(
                0, batch.unsqueeze(1).expand_as(h), h)
            counts = torch.bincount(batch, minlength=num_graphs).to(h.dtype)
            h = summed / counts.unsqueeze(1).clamp(min=1)

        return [head(h) for head in self.task_heads]

class GradientConflictResolver:
    """Common interface for strategies that merge per-task gradients.

    Subclasses override :meth:`resolve_conflicts`. ``method_name`` is the
    human-readable label of the strategy.
    """

    def __init__(self, method_name: str):
        self.method_name: str = method_name

    def resolve_conflicts(self, task_gradients: List[torch.Tensor],
                          task_losses: List[torch.Tensor]) -> torch.Tensor:
        """Merge ``task_gradients`` into a single update direction.

        Args:
            task_gradients: One gradient tensor per task.
            task_losses: The matching per-task loss values.

        Raises:
            NotImplementedError: Always; concrete strategies must override this.
        """
        raise NotImplementedError

class VanillaAverage(GradientConflictResolver):
    """Plain gradient descent on the mean loss: the element-wise mean of the
    task gradients, with no conflict handling."""

    def __init__(self):
        super().__init__("Vanilla Average")

    def resolve_conflicts(self, task_gradients: List[torch.Tensor],
                          task_losses: List[torch.Tensor]) -> torch.Tensor:
        """Return the element-wise mean of ``task_gradients``.

        ``task_losses`` is accepted for interface compatibility and ignored.
        """
        stacked = torch.stack(task_gradients, dim=0)
        return torch.mean(stacked, dim=0)

class MGDA(GradientConflictResolver):
    """Multiple Gradient Descent Algorithm (Sener & Koltun, 2018).

    Returns the minimum-norm element of the convex hull of the task gradients,
    ``sum_i w_i * g_i`` with ``w`` on the probability simplex. The weights are
    found with Frank-Wolfe on the task Gram matrix, using an exact line search.

    The previous implementation returned the single task gradient with the
    smallest norm. That is a vertex of the hull, not its min-norm point, and
    it ignores every other task.

    Args:
        max_iter: Maximum number of Frank-Wolfe iterations.
        tol: Stop once the optimal line-search step falls below this value.
    """

    def __init__(self, max_iter: int = 250, tol: float = 1e-6):
        super().__init__("MGDA")
        self.max_iter = max_iter
        self.tol = tol

    def resolve_conflicts(self, task_gradients: List[torch.Tensor],
                          task_losses: List[torch.Tensor]) -> torch.Tensor:
        """Return the min-norm convex combination of ``task_gradients``.

        Raises:
            ValueError: If ``task_gradients`` is empty.
        """
        if not task_gradients:
            raise ValueError("MGDA needs at least one task gradient")
        if len(task_gradients) == 1:
            return task_gradients[0]

        stacked = torch.stack([g.flatten() for g in task_gradients])  # (T, P)
        # Solve for the weights on detached copies; only the combination below
        # touches the (possibly graph-carrying) inputs.
        weights = self._min_norm_weights(stacked.detach())
        combined = weights @ stacked
        return combined.view_as(task_gradients[0])

    def _min_norm_weights(self, grads: torch.Tensor) -> torch.Tensor:
        """Frank-Wolfe solver for ``argmin_{w in simplex} ||w @ grads||^2``."""
        gram = grads @ grads.t()
        num_tasks = gram.size(0)
        weights = torch.full((num_tasks,), 1.0 / num_tasks,
                             dtype=gram.dtype, device=gram.device)
        for _ in range(self.max_iter):
            # The linear minimiser over the simplex is the vertex (task) whose
            # gradient has the smallest inner product with the current point.
            vertex = int(torch.argmin(gram @ weights))
            v1v1 = float(weights @ gram @ weights)
            v1v2 = float(weights @ gram[:, vertex])
            v2v2 = float(gram[vertex, vertex])
            # Exact minimiser of ||(1-gamma) v1 + gamma v2||^2 on [0, 1].
            denom = v1v1 - 2.0 * v1v2 + v2v2
            if denom <= 0.0:
                break  # v1 == v2: no descent possible along this direction
            gamma = min(max((v1v1 - v1v2) / denom, 0.0), 1.0)
            if gamma <= self.tol:
                break  # duality gap ~0: current point is (near) optimal
            weights = (1.0 - gamma) * weights
            weights[vertex] += gamma
        return weights

class PCGrad(GradientConflictResolver):
    """Projecting Conflicting Gradients (Yu et al., 2020).

    For each task ``i``, start from its gradient ``g_i``. For every other task
    ``j`` whose original gradient conflicts with it (``g_i . g_j < 0``),
    remove the component of ``g_i`` along ``g_j``. The projected gradients are
    then averaged; the paper sums them, and the mean only rescales by
    ``1/T``, keeping the same scale as ``VanillaAverage``.

    The previous implementation projected against a running sum of gradients.
    That makes the result depend on task order and is not PCGrad.

    Args:
        shuffle: Visit the other tasks in random order for each task, as in
            the reference algorithm (uses torch's global RNG). Set ``False``
            for a fixed order.
    """

    def __init__(self, shuffle: bool = True):
        super().__init__("PCGrad")
        self.shuffle = shuffle

    def resolve_conflicts(self, task_gradients: List[torch.Tensor],
                          task_losses: List[torch.Tensor]) -> torch.Tensor:
        """Apply gradient surgery to every task and average the results.

        Raises:
            ValueError: If ``task_gradients`` is empty.
        """
        if not task_gradients:
            raise ValueError("PCGrad needs at least one task gradient")
        if len(task_gradients) < 2:
            return task_gradients[0]

        flat = [g.flatten() for g in task_gradients]
        num_tasks = len(flat)
        projected = []
        for i in range(num_tasks):
            g_i = flat[i].clone()
            others = [j for j in range(num_tasks) if j != i]
            if self.shuffle:
                others = [others[k] for k in torch.randperm(len(others)).tolist()]
            for j in others:
                g_j = flat[j]
                dot = torch.dot(g_i, g_j)
                # dot < 0 implies g_j != 0, so the division below is safe.
                if dot < 0:
                    g_i = g_i - dot / torch.dot(g_j, g_j) * g_j
            projected.append(g_i)

        combined = torch.stack(projected).mean(dim=0)
        return combined.view_as(task_gradients[0])

class CAGrad(GradientConflictResolver):
    """Conflict-Averse Gradient Descent (Liu et al., 2021).

    With ``g_0`` the average gradient and ``c = rho * ||g_0||``, CAGrad solves
    the dual problem

        w* = argmin_{w in simplex}  <g_w, g_0> + c * ||g_w||,   g_w = sum_i w_i g_i

    and returns ``d = g_0 + c / ||g_w*|| * g_w*``, rescaled by
    ``1 / (1 + rho^2)`` (the reference implementation's default rescaling).
    The objective is convex in ``w``, so Frank-Wolfe with the classic
    ``2 / (k + 2)`` step converges to the optimum.

    The previous implementation initialised its best score to +inf and only
    accepted strictly larger scores. It therefore always returned the plain
    average and never used ``rho``.

    Args:
        rho: Radius ``c`` of the trust region around ``g_0``, relative to
            ``||g_0||``.
        max_iter: Number of Frank-Wolfe iterations over the task weights.
    """

    def __init__(self, rho: float = 0.4, max_iter: int = 100):
        super().__init__("CAGrad")
        self.rho = rho
        self.max_iter = max_iter

    def resolve_conflicts(self, task_gradients: List[torch.Tensor],
                          task_losses: List[torch.Tensor]) -> torch.Tensor:
        """Return the CAGrad update direction for ``task_gradients``.

        Raises:
            ValueError: If ``task_gradients`` is empty.
        """
        if not task_gradients:
            raise ValueError("CAGrad needs at least one task gradient")
        if len(task_gradients) < 2:
            return task_gradients[0]

        eps = 1e-8
        stacked = torch.stack([g.flatten() for g in task_gradients])  # (T, P)
        grads = stacked.detach()
        num_tasks = grads.size(0)
        gram = grads @ grads.t()
        uniform = torch.full((num_tasks,), 1.0 / num_tasks,
                             dtype=gram.dtype, device=gram.device)
        gram_g0 = gram @ uniform                     # <g_i, g_0> for each task
        g0_norm = torch.sqrt(uniform @ gram_g0 + eps)
        radius = self.rho * g0_norm

        weights = uniform.clone()
        for k in range(self.max_iter):
            gw_norm = torch.sqrt(weights @ gram @ weights + eps)
            # Gradient of <g_w, g_0> + radius * ||g_w|| with respect to w.
            obj_grad = gram_g0 + radius * (gram @ weights) / gw_norm
            vertex = int(torch.argmin(obj_grad))
            step = 2.0 / (k + 2.0)
            weights = (1.0 - step) * weights
            weights[vertex] += step

        g0 = uniform @ stacked
        gw = weights @ stacked
        combined = g0 + radius / (gw.detach().norm() + eps) * gw
        combined = combined / (1.0 + self.rho ** 2)
        return combined.view_as(task_gradients[0])

class MultiTaskBenchmarkRunner:
    """Train and evaluate multi-task GNNs under several gradient-combination
    methods.

    Results are stored in ``self.results[dataset_name][method_name]`` as the
    metrics dict that :meth:`evaluate_model` returns on the test split.
    """

    def __init__(self, device='cpu'):
        self.device = device
        self.results = {}

    @staticmethod
    def _make_resolver(method_name):
        """Return the explicit resolver for ``method_name``.

        Returns ``None`` for any other name (e.g. the GraGR entries); those
        methods are trained on the summed loss.
        """
        if method_name == "Vanilla Average":
            return VanillaAverage()
        if method_name == "MGDA":
            return MGDA()
        if method_name == "PCGrad":
            return PCGrad()
        if method_name == "CAGrad":
            return CAGrad()
        return None

    @staticmethod
    def _task_targets(y, task_idx, num_tasks, batch_size):
        """Select the labels of task ``task_idx`` from a batch label tensor.

        Supported layouts:
        * 2-D ``(batch, num_tasks)``: one column per task.
        * 1-D of length ``batch * num_tasks`` with ``num_tasks > 1``: labels
          flattened graph-major (``[g0_t0, g0_t1, ..., g1_t0, ...]``).
        * Any other 1-D tensor: one label vector shared by every task.

        Raises:
            ValueError: For label tensors with any other number of dimensions.
        """
        if y.dim() == 2:
            return y[:, task_idx]
        if y.dim() == 1:
            if num_tasks > 1 and y.numel() == batch_size * num_tasks:
                return y.reshape(batch_size, num_tasks)[:, task_idx]
            return y
        raise ValueError(f"Unsupported label tensor with {y.dim()} dimensions")

    def _task_loss_and_targets(self, prediction, y, task_idx, dataset_info):
        """Return ``(loss, targets)`` for one task's prediction."""
        targets = self._task_targets(
            y, task_idx, dataset_info['num_tasks'], prediction.size(0))
        if dataset_info['task_type'] == 'classification':
            return F.cross_entropy(prediction, targets), targets
        # Regression heads emit (N, 1). Squeeze only the last dim so a batch
        # holding a single graph keeps shape (1,) instead of becoming a scalar.
        return F.mse_loss(prediction.squeeze(-1), targets.float()), targets

    @staticmethod
    def _flat_task_gradients(task_losses, params):
        """Flatten d(loss_i)/d(params) for each task into aligned 1-D vectors.

        Parameters a task does not reach (e.g. other tasks' heads) get explicit
        zeros. Every vector therefore has the same layout as ``params``, which
        lets resolvers combine them element-wise.
        """
        flat_grads = []
        for loss in task_losses:
            grads = torch.autograd.grad(loss, params, retain_graph=True,
                                        allow_unused=True)
            flat_grads.append(torch.cat([
                (g if g is not None else torch.zeros_like(p)).reshape(-1)
                for g, p in zip(grads, params)
            ]))
        return flat_grads

    @staticmethod
    def _assign_flat_gradient(params, flat_grad):
        """Write consecutive slices of ``flat_grad`` into each ``p.grad``.

        Raises:
            ValueError: If ``flat_grad`` does not match the total parameter size.
        """
        flat_grad = flat_grad.detach()
        expected = sum(p.numel() for p in params)
        if flat_grad.numel() != expected:
            raise ValueError(f"resolved gradient has {flat_grad.numel()} "
                             f"entries, expected {expected}")
        offset = 0
        for p in params:
            size = p.numel()
            p.grad = flat_grad[offset:offset + size].view_as(p).clone()
            offset += size

    def train_model(self, model, train_loader, val_loader, test_loader,
                    dataset_info, method_name, num_epochs=100):
        """Train ``model`` with the gradient-combination method ``method_name``.

        For the explicit resolvers, per-task gradients are computed separately,
        combined by the resolver, and written into ``.grad`` before each Adam
        step. Any other method name trains on the summed loss.

        Args:
            dataset_info: Dict with ``'name'``, ``'num_tasks'`` and
                ``'task_type'`` entries.
            num_epochs: Number of passes over ``train_loader``. Validation runs
                every 5th epoch.

        Returns:
            The :meth:`evaluate_model` metrics on ``test_loader``.
        """
        print(f"\nTraining {method_name} on {dataset_info['name']}...")

        optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-4)
        resolver = self._make_resolver(method_name)
        params = [p for p in model.parameters() if p.requires_grad]

        for epoch in range(num_epochs):
            # evaluate_model switches to eval mode, so (re)enter train mode at
            # the start of every epoch to keep dropout active.
            model.train()
            epoch_loss = 0.0
            num_batches = 0
            for batch in train_loader:
                batch = batch.to(self.device)
                optimizer.zero_grad()

                predictions = model(batch.x, batch.edge_index, batch.batch)
                task_losses = [
                    self._task_loss_and_targets(predictions[i], batch.y, i, dataset_info)[0]
                    for i in range(dataset_info['num_tasks'])
                ]

                if resolver is not None:
                    task_gradients = self._flat_task_gradients(task_losses, params)
                    resolved = resolver.resolve_conflicts(task_gradients, task_losses)
                    self._assign_flat_gradient(params, resolved)
                else:
                    # NOTE(review): 'GraGR Core' / 'GraGR++' currently train the
                    # same plain MultiTaskGNN on the summed loss; GraGRCore and
                    # GraGRPlusPlus are not wired in here.
                    sum(task_losses).backward()
                optimizer.step()

                epoch_loss += sum(loss.item() for loss in task_losses)
                num_batches += 1

            if epoch % 5 == 0:
                val_metrics = self.evaluate_model(model, val_loader, dataset_info)
                train_loss = epoch_loss / max(num_batches, 1)
                print(f"Epoch {epoch}: Train Loss = {train_loss:.4f}, "
                      f"Val Loss = {val_metrics['total_loss']:.4f}")

        return self.evaluate_model(model, test_loader, dataset_info)

    def evaluate_model(self, model, data_loader, dataset_info):
        """Evaluate ``model`` on ``data_loader`` without gradient tracking.

        Returns:
            Dict with:
            * ``'total_loss'``: sum of task losses, averaged over batches
              (NaN for an empty loader);
            * ``'task_accuracies'``: one sample-weighted accuracy per task
              (0.0 for regression tasks);
            * ``'avg_accuracy'``: mean of ``task_accuracies``.

        The model's train/eval mode is restored before returning.
        """
        was_training = model.training
        model.eval()
        num_tasks = dataset_info['num_tasks']
        is_classification = dataset_info['task_type'] == 'classification'

        total_loss = 0.0
        num_batches = 0
        correct = [0] * num_tasks
        seen = [0] * num_tasks

        with torch.no_grad():
            for batch in data_loader:
                batch = batch.to(self.device)
                predictions = model(batch.x, batch.edge_index, batch.batch)
                for i in range(num_tasks):
                    loss, targets = self._task_loss_and_targets(
                        predictions[i], batch.y, i, dataset_info)
                    total_loss += loss.item()
                    if is_classification:
                        pred = predictions[i].argmax(dim=1)
                        correct[i] += (pred == targets).sum().item()
                        seen[i] += targets.numel()
                num_batches += 1

        model.train(was_training)

        # Regression tasks have no accuracy and report 0.0.
        task_accuracies = [c / n if n else 0.0 for c, n in zip(correct, seen)]
        avg_accuracy = float(np.mean(task_accuracies)) if task_accuracies else 0
        return {
            'total_loss': total_loss / num_batches if num_batches else float('nan'),
            'avg_accuracy': avg_accuracy,
            'task_accuracies': task_accuracies,
        }

    def run_benchmark(self, datasets: Optional[List[str]] = None,
                      num_epochs: int = 5, seed: Optional[int] = None):
        """Run every method on each dataset in ``datasets`` and print a summary.

        Args:
            datasets: Dataset names. Defaults to ``['TUDataset', 'MedMNIST']``.
                Names without a loader are reported and skipped.
            num_epochs: Training epochs per method.
            seed: If given, ``torch.manual_seed(seed)`` is called before each
                model is built, so every method starts from the same weights.
        """
        if datasets is None:
            datasets = ['TUDataset', 'MedMNIST']

        print("Multi-Task Benchmark: GraGR vs Gradient Conflict Resolution Methods")
        print("=" * 80)

        dataset_loader = MultiTaskDatasetLoader()
        loaders = {
            'QM9': dataset_loader.load_qm9_dataset,
            'TUDataset': lambda: dataset_loader.load_tudataset_multi_task('ENZYMES'),
            'OGB-MolHIV': dataset_loader.load_ogb_molhiv_multi_task,
        }

        for dataset_name in datasets:
            print(f"\n{'='*60}")
            print(f"BENCHMARKING ON {dataset_name}")
            print(f"{'='*60}")

            load = loaders.get(dataset_name)
            if load is None:
                print(f"Skipping {dataset_name}: no loader available")
                continue
            train_loader, val_loader, test_loader, dataset_info = load()

            # Infer the node-feature width from the first training batch.
            try:
                first_batch = next(iter(train_loader))
            except StopIteration:
                print(f"Skipping {dataset_name}: training split is empty")
                continue
            input_dim = first_batch.x.size(1)

            print(f"Input dimension: {input_dim}")
            print(f"Number of tasks: {dataset_info['num_tasks']}")
            print(f"Task type: {dataset_info['task_type']}")

            # Methods to compare (reduced for faster execution)
            methods = ['Vanilla Average', 'PCGrad', 'GraGR Core', 'GraGR++']

            dataset_results = {}
            for method in methods:
                print(f"\n--- {method} ---")
                if seed is not None:
                    torch.manual_seed(seed)

                # NOTE(review): every method, including the GraGR entries,
                # currently uses the same plain MultiTaskGNN.
                model = MultiTaskGNN(
                    input_dim=input_dim,
                    hidden_dim=64,
                    num_tasks=dataset_info['num_tasks'],
                    task_types=[dataset_info['task_type']] * dataset_info['num_tasks'],
                    backbone='GCN',
                ).to(self.device)

                metrics = self.train_model(model, train_loader, val_loader, test_loader,
                                           dataset_info, method, num_epochs=num_epochs)
                dataset_results[method] = metrics
                print(f"Test Loss: {metrics['total_loss']:.4f}, "
                      f"Avg Accuracy: {metrics['avg_accuracy']:.4f}")

            self.results[dataset_name] = dataset_results

        self.print_summary()

    def print_summary(self):
        """Print the test loss and average accuracy of every method per dataset."""
        print(f"\n{'='*80}")
        print("MULTI-TASK BENCHMARK SUMMARY")
        print(f"{'='*80}")

        for dataset_name, results in self.results.items():
            print(f"\n{dataset_name}:")
            print("-" * 40)
            for method, metrics in results.items():
                print(f"{method:15s}: Loss = {metrics['total_loss']:.4f}, "
                      f"Accuracy = {metrics['avg_accuracy']:.4f}")

def main():
    """Script entry point: benchmark every configured method on QM9 and
    OGB-MolHIV, running on the CPU."""
    target_datasets = ['QM9', 'OGB-MolHIV']
    benchmark = MultiTaskBenchmarkRunner(device='cpu')
    benchmark.run_benchmark(target_datasets)


if __name__ == "__main__":
    # Only launch the benchmark when executed directly, not on import.
    main()
