"""
Comprehensive Multi-Task Benchmark: GraGR vs All Gradient Conflict Resolution Methods
====================================================================================

This module implements a comprehensive benchmark comparing GraGR against ALL major
gradient conflict resolution methods on multiple datasets.

Methods compared:
1. Vanilla Average (GD)
2. MGDA (Multiple Gradient Descent Algorithm)
3. PCGrad (Projecting Conflicting Gradients)
4. CAGrad (Conflict-Averse Gradient Descent)
5. GradNorm (Gradient Normalization)
6. GraGR Core
7. GraGR++

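Datasets:
1. QM9: 11 molecular properties (regression)
2. TUDataset: Multiple graph classification tasks
3. MedMNIST: Multiple medical tasks
4. Synthetic Multi-Task: Custom dataset with controlled conflicts

Note: this module itself only generates the synthetic multi-task datasets
(see ``create_synthetic_datasets``); no loader for QM9, TUDataset, or MedMNIST
is defined in this file.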
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import time
from typing import Dict, List, Tuple, Optional
import warnings
warnings.filterwarnings('ignore')

class ComprehensiveMultiTaskModel(nn.Module):
    """Shared-trunk MLP with one linear output head per task.

    The trunk is three ``Linear`` + ``ReLU`` stages (dropout of 0.1 after the
    first two). Each task gets its own ``Linear`` head on top of the trunk:
    classification heads emit ``num_classes_per_task[i]`` logits, every other
    task type is treated as regression and emits a single value.
    """

    def __init__(self, input_dim: int, hidden_dim: int, num_tasks: int, 
                 task_types: List[str], num_classes_per_task: List[int] = None):
        super().__init__()
        self.num_tasks = num_tasks
        self.task_types = task_types
        # A missing (or empty) class list falls back to 3 classes for every task.
        if not num_classes_per_task:
            num_classes_per_task = [3] * num_tasks
        self.num_classes_per_task = num_classes_per_task

        # Trunk shared by all tasks.
        trunk_layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        ]
        self.shared = nn.Sequential(*trunk_layers)

        # One output head per task, created in task order.
        self.task_heads = nn.ModuleList([
            nn.Linear(
                hidden_dim,
                self.num_classes_per_task[idx] if task_types[idx] == 'classification' else 1,
            )
            for idx in range(num_tasks)
        ])

    def forward(self, x):
        """Return a list with one prediction tensor per task head."""
        trunk_out = self.shared(x)
        outputs = []
        for head in self.task_heads:
            outputs.append(head(trunk_out))
        return outputs

class GradientConflictResolver:
    """Common interface for strategies that merge per-task gradients.

    The base class only records a display name; concrete strategies override
    :meth:`resolve_conflicts`.
    """

    def __init__(self, method_name: str):
        # Human-readable name of the strategy.
        self.method_name = method_name

    def resolve_conflicts(self, task_gradients: List[torch.Tensor], 
                         task_losses: List[torch.Tensor]) -> torch.Tensor:
        """Combine per-task gradients into a single gradient.

        Args:
            task_gradients: One gradient tensor per task.
            task_losses: One loss tensor per task.

        Raises:
            NotImplementedError: Always; subclasses must provide the logic.
        """
        raise NotImplementedError

class VanillaAverage(GradientConflictResolver):
    """Baseline resolver: the element-wise mean of all task gradients."""

    def __init__(self):
        super().__init__("Vanilla Average")

    def resolve_conflicts(self, task_gradients: List[torch.Tensor], 
                         task_losses: List[torch.Tensor]) -> torch.Tensor:
        """Average the task gradients; ``task_losses`` is not used.

        Gradients shorter than the longest one are right-padded with zeros so
        that all of them can be stacked.
        """
        target_len = max(g.size(0) for g in task_gradients)
        aligned = [
            g if g.size(0) >= target_len
            else torch.cat([g, torch.zeros(target_len - g.size(0), device=g.device)])
            for g in task_gradients
        ]
        return torch.mean(torch.stack(aligned), dim=0)

class MGDA(GradientConflictResolver):
    """Resolver registered under the name "MGDA".

    NOTE: the implementation is a selection heuristic -- it returns the single
    task gradient with the smallest L2 norm. It does not solve a min-norm
    problem over combinations of the task gradients.
    """

    def __init__(self):
        super().__init__("MGDA")

    def resolve_conflicts(self, task_gradients: List[torch.Tensor], 
                         task_losses: List[torch.Tensor]) -> torch.Tensor:
        """Return the task gradient whose norm is smallest.

        ``task_losses`` is not used. The selected tensor is returned as-is
        (no padding, no copy).
        """
        norms = torch.tensor([torch.norm(g) for g in task_gradients])
        smallest = int(norms.argmin())
        return task_gradients[smallest]

class PCGrad(GradientConflictResolver):
    """Projecting Conflicting Gradients (sequential accumulation variant).

    Gradients are folded into a running sum one at a time. When an incoming
    gradient has negative cosine similarity with the running sum, its
    component along the running sum is removed before it is added.
    """

    def __init__(self):
        super().__init__("PCGrad")

    def resolve_conflicts(self, task_gradients: List[torch.Tensor], 
                         task_losses: List[torch.Tensor]) -> torch.Tensor:
        """Merge the task gradients, projecting away conflicting components.

        ``task_losses`` is not used. With fewer than two gradients the first
        one is returned untouched. The accumulated sum is divided by the
        number of tasks before being returned.
        """
        if len(task_gradients) < 2:
            return task_gradients[0]

        # Right-pad shorter gradients with zeros so all share one length.
        width = max(g.size(0) for g in task_gradients)
        aligned = []
        for g in task_gradients:
            missing = width - g.size(0)
            if missing > 0:
                g = torch.cat([g, torch.zeros(missing, device=g.device)])
            aligned.append(g)

        first, *remaining = aligned
        merged = first.clone()

        for candidate in remaining:
            merged_flat = merged.flatten()
            candidate_flat = candidate.flatten()
            if F.cosine_similarity(merged_flat, candidate_flat, dim=0) < 0:
                # Conflict: drop the part of the candidate that points against
                # the running sum.
                scale = (merged_flat @ candidate_flat) / (merged_flat @ merged_flat)
                candidate = candidate - scale * merged
            merged = merged + candidate

        return merged / len(aligned)

class CAGrad(GradientConflictResolver):
    """Conflict-Averse Gradient Descent (simplified candidate search).

    This is a lightweight stand-in for CAGrad: instead of solving the
    constrained optimisation of the original method, it picks, among the
    average gradient and the individual task gradients, the candidate with the
    best worst-case alignment (largest minimum dot product) with the task
    gradients. The average gradient is the baseline and is kept unless a task
    gradient is strictly better.

    Args:
        rho: Stored on the instance for interface compatibility. The
            simplified search below does not use it.
    """

    def __init__(self, rho: float = 0.4):
        super().__init__("CAGrad")
        self.rho = rho

    def resolve_conflicts(self, task_gradients: List[torch.Tensor], 
                         task_losses: List[torch.Tensor]) -> torch.Tensor:
        """Return the candidate gradient with the best worst-case alignment.

        Args:
            task_gradients: One flat gradient per task. Shorter gradients are
                right-padded with zeros to the longest length.
            task_losses: Unused.

        Returns:
            The average gradient, or one (padded) task gradient if its minimum
            dot product with the *other* task gradients is strictly larger
            than the average gradient's minimum dot product with all task
            gradients. With fewer than two gradients the first one is returned
            unchanged.
        """
        if len(task_gradients) < 2:
            return task_gradients[0]

        # Handle different gradient sizes by padding to max size
        max_size = max(grad.size(0) for grad in task_gradients)
        padded_gradients = []
        for grad in task_gradients:
            if grad.size(0) < max_size:
                padding = torch.zeros(max_size - grad.size(0), device=grad.device)
                grad = torch.cat([grad, padding])
            padded_gradients.append(grad)

        avg_grad = torch.stack(padded_gradients).mean(dim=0)

        # Selection only needs values, so score on detached copies.
        detached = [grad.detach() for grad in padded_gradients]
        avg_detached = avg_grad.detach()

        # Baseline: worst-case alignment of the average gradient. Starting from
        # +inf (as the code previously did) made the `>` test below impossible
        # to satisfy, so the search never selected anything.
        best_grad = avg_grad
        best_worst_improvement = min(
            (avg_detached @ grad_j).item() for grad_j in detached
        )

        for i, grad_i in enumerate(detached):
            worst_improvement = min(
                (grad_i @ grad_j).item()
                for j, grad_j in enumerate(detached)
                if j != i
            )
            if worst_improvement > best_worst_improvement:
                best_worst_improvement = worst_improvement
                best_grad = padded_gradients[i]

        return best_grad

class GradNorm(GradientConflictResolver):
    """Gradient Normalization (single-shot re-weighting).

    Each task gradient is rescaled towards a target norm of
    ``mean_grad_norm * (loss_i / mean_loss) ** alpha`` and the rescaled
    gradients are averaged.

    Args:
        alpha: Exponent applied to the relative loss of each task.
    """

    # Guards the divisions against zero gradient norms / zero mean loss.
    _EPS = 1e-8

    def __init__(self, alpha: float = 1.5):
        super().__init__("GradNorm")
        self.alpha = alpha

    def resolve_conflicts(self, task_gradients: List[torch.Tensor], 
                         task_losses: List[torch.Tensor]) -> torch.Tensor:
        """Return the loss-ratio-weighted average of the task gradients.

        Args:
            task_gradients: One flat gradient per task. Shorter gradients are
                right-padded with zeros to the longest length.
            task_losses: One scalar loss tensor per task, in the same order.

        Returns:
            ``sum_i(w_i * g_i) / num_tasks`` where
            ``w_i = target_norm_i / (||g_i|| + eps)``. A task whose gradient is
            all zeros contributes zero instead of producing NaNs.
        """
        # Handle different gradient sizes by padding to max size
        max_size = max(grad.size(0) for grad in task_gradients)
        padded_gradients = []
        for grad in task_gradients:
            if grad.size(0) < max_size:
                padding = torch.zeros(max_size - grad.size(0), device=grad.device)
                grad = torch.cat([grad, padding])
            padded_gradients.append(grad)

        stacked = torch.stack(padded_gradients)  # (num_tasks, max_size)

        # The weights are plain scaling constants, so compute them without
        # tracking gradients through the norms or the losses.
        with torch.no_grad():
            grad_norms = stacked.norm(dim=1)
            losses = torch.stack(list(task_losses))
            loss_ratios = losses / (losses.mean() + self._EPS)
            target_norms = grad_norms.mean() * loss_ratios ** self.alpha
            weights = target_norms / (grad_norms + self._EPS)

        weighted_sum = (weights.unsqueeze(1) * stacked).sum(dim=0)
        return weighted_sum / len(padded_gradients)

class GraGRCore(GradientConflictResolver):
    """GraGR Core: conflict-aware, inverse-norm weighted gradient mix.

    Every task gradient receives a raw weight of
    ``(1 / (norm + 1e-8)) * (1 + conflict)``, where ``conflict`` sums the
    absolute cosine similarity to each other gradient it points against. The
    raw weights are normalised to sum to one.
    """

    def __init__(self):
        super().__init__("GraGR Core")

    def resolve_conflicts(self, task_gradients: List[torch.Tensor], 
                         task_losses: List[torch.Tensor]) -> torch.Tensor:
        """Return the normalised weighted sum of the task gradients.

        ``task_losses`` is not used. Gradients shorter than the longest one
        are right-padded with zeros first.
        """
        longest = max(g.size(0) for g in task_gradients)
        aligned = []
        for g in task_gradients:
            missing = longest - g.size(0)
            if missing > 0:
                g = torch.cat([g, torch.zeros(missing, device=g.device)])
            aligned.append(g)

        raw_weights = []
        for idx, current in enumerate(aligned):
            # Smaller gradients get a larger base weight.
            inverse_norm = 1.0 / (torch.norm(current) + 1e-8)

            similarities = [
                F.cosine_similarity(current.flatten(), other.flatten(), dim=0)
                for other_idx, other in enumerate(aligned)
                if other_idx != idx
            ]
            # Only negatively aligned pairs count as conflicts.
            conflict_total = sum(abs(sim) for sim in similarities if sim < 0)

            # More conflict -> larger weight.
            raw_weights.append(inverse_norm * (1 + conflict_total))

        mix = torch.tensor(raw_weights)
        mix = mix / mix.sum()

        combined = torch.zeros_like(aligned[0])
        for coefficient, g in zip(mix, aligned):
            combined += coefficient * g
        return combined

class GraGRPlusPlus(GradientConflictResolver):
    """GraGR++: inverse-norm weighting adjusted by conflict and alignment.

    Every task gradient receives a raw weight of
    ``(1 / (norm + 1e-8)) * (1 + 2 * conflict + 0.5 * alignment)``.
    ``conflict`` sums the absolute cosine similarity to gradients it points
    against; ``alignment`` sums the cosine similarity to all remaining
    gradients. The raw weights are normalised to sum to one.
    """

    def __init__(self):
        super().__init__("GraGR++")

    def resolve_conflicts(self, task_gradients: List[torch.Tensor], 
                         task_losses: List[torch.Tensor]) -> torch.Tensor:
        """Return the normalised weighted sum of the task gradients.

        ``task_losses`` is not used. Gradients shorter than the longest one
        are right-padded with zeros first.
        """
        longest = max(g.size(0) for g in task_gradients)
        aligned = []
        for g in task_gradients:
            missing = longest - g.size(0)
            if missing > 0:
                g = torch.cat([g, torch.zeros(missing, device=g.device)])
            aligned.append(g)

        raw_weights = []
        for idx, current in enumerate(aligned):
            # Smaller gradients get a larger base weight.
            inverse_norm = 1.0 / (torch.norm(current) + 1e-8)

            similarities = [
                F.cosine_similarity(current.flatten(), other.flatten(), dim=0)
                for other_idx, other in enumerate(aligned)
                if other_idx != idx
            ]
            # Negative similarity counts as conflict, everything else as alignment.
            conflict_total = sum(abs(sim) for sim in similarities if sim < 0)
            alignment_total = sum(sim for sim in similarities if not (sim < 0))

            # Conflicting tasks are boosted more strongly than aligned ones.
            raw_weights.append(
                inverse_norm * (1 + 2 * conflict_total + 0.5 * alignment_total)
            )

        mix = torch.tensor(raw_weights)
        mix = mix / mix.sum()

        combined = torch.zeros_like(aligned[0])
        for coefficient, g in zip(mix, aligned):
            combined += coefficient * g
        return combined

def create_synthetic_datasets():
    """Build four random multi-task datasets with four tasks each.

    Returns:
        A dict keyed by ``'High_Conflict'``, ``'Medium_Conflict'``,
        ``'Low_Conflict'`` and ``'Mixed_Tasks'``. Each value is a dict with
        ``'x'`` (features), ``'y'`` (one target column per task),
        ``'task_types'``, ``'num_classes'`` and ``'num_tasks'``.

    Data is drawn from the global torch RNG, so results vary between calls
    unless the caller seeds it.
    """

    def noisy_linear(features, coeff, count):
        # coeff * row-sum of the feature slice, plus Gaussian noise scaled by 0.1.
        return (coeff * features.sum(dim=1) + 0.1 * torch.randn(count)).unsqueeze(1)

    def sign_label(features):
        # Binary label: 1 where the row-sum of the feature slice is positive.
        return ((features.sum(dim=1) > 0).long()).unsqueeze(1)

    def pack(features, columns, task_types, num_classes):
        return {
            'x': features,
            'y': torch.cat(columns, dim=1),
            'task_types': task_types,
            'num_classes': num_classes,
            'num_tasks': len(columns),
        }

    datasets = {}

    # High conflict: tasks 1 and 2 are exact opposites on the same features.
    x_high = torch.randn(300, 15)
    high_columns = [
        noisy_linear(x_high[:, :5], 1.0, 300),
        noisy_linear(x_high[:, :5], -1.0, 300),
        noisy_linear(x_high[:, 5:10], 1.0, 300),
        sign_label(x_high[:, 10:15]),
    ]
    datasets['High_Conflict'] = pack(
        x_high, high_columns,
        ['regression', 'regression', 'regression', 'classification'],
        [1, 1, 1, 2],
    )

    # Medium conflict: task 2 is a half-strength negation of task 1.
    x_medium = torch.randn(250, 12)
    medium_columns = [
        noisy_linear(x_medium[:, :4], 1.0, 250),
        noisy_linear(x_medium[:, :4], -0.5, 250),
        sign_label(x_medium[:, 4:8]),
        sign_label(x_medium[:, 8:12]),
    ]
    datasets['Medium_Conflict'] = pack(
        x_medium, medium_columns,
        ['regression', 'regression', 'classification', 'classification'],
        [1, 1, 2, 2],
    )

    # Low conflict: task 2 is a scaled copy (0.8x) of task 1.
    x_low = torch.randn(200, 10)
    low_columns = [
        noisy_linear(x_low[:, :3], 1.0, 200),
        noisy_linear(x_low[:, :3], 0.8, 200),
        sign_label(x_low[:, 3:6]),
        sign_label(x_low[:, 6:10]),
    ]
    datasets['Low_Conflict'] = pack(
        x_low, low_columns,
        ['regression', 'regression', 'classification', 'classification'],
        [1, 1, 2, 2],
    )

    # Mixed task types; the 3-class labels are drawn at random, independent of x.
    x_mixed = torch.randn(180, 8)
    mixed_columns = [
        noisy_linear(x_mixed[:, :2], 1.0, 180),
        (torch.randint(0, 3, (180, 1)).float()),
        sign_label(x_mixed[:, 2:4]),
        noisy_linear(x_mixed[:, 4:6], 1.0, 180),
    ]
    datasets['Mixed_Tasks'] = pack(
        x_mixed, mixed_columns,
        ['regression', 'classification', 'classification', 'regression'],
        [1, 3, 2, 1],
    )

    return datasets

def train_model(model, x, y, resolver, dataset_info, num_epochs=20):
    """Train ``model`` full-batch, merging per-task gradients with ``resolver``.

    Each epoch runs one forward pass, computes one loss and one flat gradient
    per task, asks ``resolver.resolve_conflicts`` for a single merged
    gradient, writes it into the parameters' ``.grad`` and takes one Adam
    step (lr=0.01).

    Args:
        model: Module whose call returns a list with one prediction tensor per
            task.
        x: Input features passed to ``model``.
        y: Targets; column ``i`` holds the target of task ``i``.
        resolver: Object exposing
            ``resolve_conflicts(task_gradients, task_losses)``.
        dataset_info: Dict providing ``'num_tasks'`` and ``'task_types'``
            (``'classification'`` selects cross-entropy, anything else MSE).
        num_epochs: Number of full-batch update steps.

    Returns:
        ``(losses, accuracies)``: per-epoch sum of task losses, and per-epoch
        mean accuracy over the classification tasks (``nan`` when there is no
        classification task).
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    params = list(model.parameters())

    losses = []
    accuracies = []

    for _ in range(num_epochs):
        optimizer.zero_grad()

        predictions = model(x)

        task_losses = []
        task_gradients = []
        # Only classification tasks report an accuracy; a genuine 0.0 is kept.
        classification_accuracies = []

        for i in range(dataset_info['num_tasks']):
            if dataset_info['task_types'][i] == 'classification':
                target = y[:, i].long()
                loss = F.cross_entropy(predictions[i], target)
                hit_rate = (predictions[i].argmax(dim=1) == target).float().mean()
                classification_accuracies.append(hit_rate.item())
            else:  # regression
                # squeeze(-1) drops only the output dim, so a batch of one
                # still matches the target shape.
                loss = F.mse_loss(predictions[i].squeeze(-1), y[:, i].float())

            task_losses.append(loss)

            # retain_graph: the shared forward graph is reused by later tasks.
            grads = torch.autograd.grad(loss, params,
                                        retain_graph=True, allow_unused=True)
            # Parameters a task does not touch (other tasks' heads) get zeros,
            # so every task vector follows the full parameter layout and
            # position k always refers to the same parameter element.
            task_gradients.append(torch.cat([
                (torch.zeros_like(param) if grad is None else grad).reshape(-1)
                for param, grad in zip(params, grads)
            ]))

        # The merged gradient is only used for the update step, so cut it
        # loose from any graph the resolver may have kept (e.g. via losses).
        resolved_grad = resolver.resolve_conflicts(task_gradients, task_losses).detach()

        # Scatter the flat gradient back onto the parameters.
        offset = 0
        for param in params:
            count = param.numel()
            if offset + count <= resolved_grad.size(0):
                param.grad = resolved_grad[offset:offset + count].view_as(param).clone()
            else:
                # Resolver returned a shorter vector: leave this parameter untouched.
                param.grad = torch.zeros_like(param)
            offset += count

        optimizer.step()

        losses.append(sum(task_losses).item())
        if classification_accuracies:
            accuracies.append(np.mean(classification_accuracies))
        else:
            accuracies.append(float('nan'))

    return losses, accuracies

class ComprehensiveBenchmarkRunner:
    """Runs every resolver on every synthetic dataset and prints a report.

    ``self.results`` maps dataset name -> method name -> a dict with the keys
    ``'final_loss'``, ``'final_accuracy'``, ``'training_time'``, ``'losses'``
    and ``'accuracies'``.
    """
    
    def __init__(self):
        self.results = {}
    
    def run_comprehensive_benchmark(self):
        """Train a fresh model per (dataset, method) pair, then print the summary."""
        print("Comprehensive Multi-Task Benchmark: GraGR vs All Gradient Conflict Resolution Methods")
        print("=" * 100)
        
        # Create synthetic datasets
        datasets = create_synthetic_datasets()
        
        # All methods to compare
        methods = {
            'Vanilla Average': VanillaAverage(),
            'MGDA': MGDA(),
            'PCGrad': PCGrad(),
            'CAGrad': CAGrad(),
            'GradNorm': GradNorm(),
            'GraGR Core': GraGRCore(),
            'GraGR++': GraGRPlusPlus()
        }
        
        # Run experiments
        for dataset_name, dataset_info in datasets.items():
            print(f"\n{'='*80}")
            print(f"BENCHMARKING ON {dataset_name}")
            print(f"{'='*80}")
            print(f"Tasks: {dataset_info['num_tasks']}")
            print(f"Task Types: {dataset_info['task_types']}")
            print(f"Data Shape: {dataset_info['x'].shape}")
            
            dataset_results = {}
            
            for method_name, resolver in methods.items():
                print(f"\n--- {method_name} ---")
                
                # A new model per method so no run starts from trained weights.
                model = ComprehensiveMultiTaskModel(
                    input_dim=dataset_info['x'].shape[1],
                    hidden_dim=64,
                    num_tasks=dataset_info['num_tasks'],
                    task_types=dataset_info['task_types'],
                    num_classes_per_task=dataset_info['num_classes']
                )
                
                # Train model
                start_time = time.time()
                losses, accuracies = train_model(model, dataset_info['x'], dataset_info['y'], 
                                               resolver, dataset_info, num_epochs=15)
                end_time = time.time()
                
                # Store results
                dataset_results[method_name] = {
                    'final_loss': losses[-1],
                    'final_accuracy': accuracies[-1],
                    'training_time': end_time - start_time,
                    'losses': losses,
                    'accuracies': accuracies
                }
                
                print(f"Final Loss: {losses[-1]:.4f}")
                print(f"Final Accuracy: {accuracies[-1]:.4f}")
                print(f"Training Time: {end_time - start_time:.2f}s")
            
            self.results[dataset_name] = dataset_results
        
        # Print comprehensive summary
        self.print_comprehensive_summary()
    
    def print_comprehensive_summary(self):
        """Print the per-dataset table, the accuracy ranking and the GraGR analysis.

        Reads only ``'final_loss'``, ``'final_accuracy'`` and
        ``'training_time'`` from each entry of ``self.results``.
        """
        print(f"\n{'='*100}")
        print("COMPREHENSIVE MULTI-TASK BENCHMARK SUMMARY")
        print(f"{'='*100}")
        
        # Fixed display order of the methods.
        methods = ['Vanilla Average', 'MGDA', 'PCGrad', 'CAGrad', 'GradNorm', 'GraGR Core', 'GraGR++']
        
        print(f"\n{'Dataset':<20} {'Method':<15} {'Final Loss':<12} {'Final Acc':<12} {'Time (s)':<10}")
        print("-" * 80)
        
        for dataset_name, results in self.results.items():
            print(f"\n{dataset_name}:")
            for method in methods:
                if method in results:
                    result = results[method]
                    print(f"{'':<20} {method:<15} {result['final_loss']:<12.4f} "
                          f"{result['final_accuracy']:<12.4f} {result['training_time']:<10.2f}")
        
        # Overall ranking
        print(f"\n{'='*100}")
        print("OVERALL RANKING BY ACCURACY")
        print(f"{'='*100}")
        
        # Mean final accuracy of each method over the datasets it ran on.
        method_scores = {}
        for method in methods:
            scores = [
                results[method]['final_accuracy']
                for results in self.results.values()
                if method in results
            ]
            if scores:
                method_scores[method] = np.mean(scores)
        
        # Sort by average accuracy
        sorted_methods = sorted(method_scores.items(), key=lambda x: x[1], reverse=True)
        
        print(f"\n{'Rank':<5} {'Method':<15} {'Avg Accuracy':<15} {'Performance':<20}")
        print("-" * 60)
        
        for rank, (method, score) in enumerate(sorted_methods, 1):
            if rank == 1:
                performance = "🥇 Best"
            elif rank == 2:
                performance = "🥈 Second"
            elif rank == 3:
                performance = "🥉 Third"
            else:
                performance = "📊 Other"
            
            print(f"{rank:<5} {method:<15} {score:<15.4f} {performance:<20}")
        
        # GraGR performance analysis
        print(f"\n{'='*100}")
        print("GraGR PERFORMANCE ANALYSIS")
        print(f"{'='*100}")
        
        gragr_core_scores = []
        gragr_plus_scores = []
        
        # `results` is keyed by method name, so the metric must be read from
        # the method's own entry (indexing `results` by metric name raises
        # KeyError).
        for dataset_name, results in self.results.items():
            if 'GraGR Core' in results:
                gragr_core_scores.append(results['GraGR Core']['final_accuracy'])
            if 'GraGR++' in results:
                gragr_plus_scores.append(results['GraGR++']['final_accuracy'])
        
        if gragr_core_scores:
            print(f"GraGR Core Average Accuracy: {np.mean(gragr_core_scores):.4f}")
        if gragr_plus_scores:
            print(f"GraGR++ Average Accuracy: {np.mean(gragr_plus_scores):.4f}")
        
        # Count wins
        gragr_core_wins = 0
        gragr_plus_wins = 0
        total_datasets = len(self.results)
        
        for dataset_name, results in self.results.items():
            if not results:
                # Nothing ran on this dataset, so nobody can win it.
                continue
            # max() returns the first best entry, so ties go to the method
            # that was recorded first.
            winner, _ = max(results.items(), key=lambda item: item[1]['final_accuracy'])
            
            if winner == 'GraGR Core':
                gragr_core_wins += 1
            if winner == 'GraGR++':
                gragr_plus_wins += 1
        
        print(f"GraGR Core wins: {gragr_core_wins}/{total_datasets} datasets")
        print(f"GraGR++ wins: {gragr_plus_wins}/{total_datasets} datasets")
        
        if gragr_core_wins > 0 or gragr_plus_wins > 0:
            print("✅ GraGR is performing well!")
        else:
            print("❌ GraGR needs improvement")

def main():
    """Script entry point: build a benchmark runner and execute the full run."""
    ComprehensiveBenchmarkRunner().run_comprehensive_benchmark()

# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
