"""
Quick Benchmark Runner for GraGR vs Other Gradient Conflict Resolution Methods

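This script runs a focused benchmark comparing GraGR with:
- MGDA (Multiple Gradient Descent Algorithm)
- PCGrad (Projecting Conflicting Gradients)
- CAGrad (Conflict-Averse Gradient Descent)
- GradNorm (Gradient Normalization) -- cited for context only; not currently run
- Vanilla Average (baseline)

The MGDA, PCGrad and CAGrad variants used here are simplified approximations of
the referenced algorithms, not faithful reproductions.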

References:
[1] Sener, O., & Koltun, V. (2018). Multi-task learning as multi-objective optimization. NeurIPS.
[2] Yu, T., et al. (2020). Gradient surgery for multi-task learning. NeurIPS.
[3] Chen, Z., et al. (2018). GradNorm: Gradient normalization for adaptive loss balancing in deep multitask networks. ICML.
[4] Liu, B., et al. (2021). Conflict-averse gradient descent for multi-task learning. NeurIPS.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import time
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
from torch_geometric.datasets import Planetoid, WebKB
from torch_geometric.transforms import NormalizeFeatures
import warnings
warnings.filterwarnings('ignore')

# Import GraGR models
from gragr_complete import GraGRCore, GraGRPlusPlus, BaselineGCN

class SimpleBenchmarkRunner:
    """Simple benchmark runner for gradient conflict resolution methods.

    Each method is trained on a synthetic multi-task version of a node
    classification graph (see ``create_multi_task_dataset``). The runner reports
    validation/test accuracy on the original labels (``data.y``) and the share of
    task-gradient pairs that conflict, i.e. have a negative dot product.
    """

    # Methods benchmarked, in the order they are run and reported.
    METHODS = ('Vanilla Average', 'MGDA', 'PCGrad', 'CAGrad', 'GraGR Core', 'GraGR++')

    def __init__(self, device='cpu'):
        self.device = device

    def load_dataset(self, dataset_name, split=0):
        """Load a single-graph dataset and move it to ``self.device``.

        Args:
            dataset_name: cora/citeseer/pubmed (Planetoid) or
                texas/cornell/wisconsin (WebKB), case-insensitive.
            split: Column to select when a mask is 2-D (several stacked
                train/val/test splits). Ignored for 1-D masks.

        Returns:
            The first graph of the dataset, on ``self.device``.

        Raises:
            ValueError: If ``dataset_name`` is not supported.
        """
        transform = NormalizeFeatures()
        key = dataset_name.lower()

        if key in ('cora', 'citeseer', 'pubmed'):
            dataset = Planetoid(root=f'data/{dataset_name}', name=dataset_name, transform=transform)
        elif key in ('texas', 'cornell', 'wisconsin'):
            dataset = WebKB(root=f'data/{dataset_name}', name=dataset_name, transform=transform)
        else:
            raise ValueError(f"Unknown dataset: {dataset_name}")

        data = dataset[0]
        # Some datasets (e.g. PyG's WebKB) store several splits as
        # [num_nodes, num_splits] masks. Indexing [N, C] logits with such a mask
        # fails, so reduce each mask to a single 1-D split.
        for mask_name in ('train_mask', 'val_mask', 'test_mask'):
            mask = getattr(data, mask_name, None)
            if mask is not None and mask.dim() == 2:
                setattr(data, mask_name, mask[:, split])
        return data.to(self.device)

    def create_multi_task_dataset(self, data, num_tasks=3):
        """Attach synthetic multi-task labels to ``data``.

        Tasks, in column order:
          0. the original node labels ``data.y``;
          1. out-degree (occurrences in ``edge_index[0]``), binned into
             ``num_classes`` equal-width bins over [min, max];
          2. the sum of the first five feature columns, binned the same way.

        ``num_classes`` is the number of distinct values in ``data.y``.

        Args:
            data: Graph object with ``x``, ``y`` and ``edge_index`` attributes.
            num_tasks: How many of the tasks above to keep (1-3).

        Returns:
            ``data`` itself, with ``task_labels`` of shape
            ``[num_nodes, num_tasks]`` and ``num_tasks`` set.

        Raises:
            ValueError: If ``num_tasks`` is not in 1..3.
        """
        max_tasks = 3
        if not 1 <= num_tasks <= max_tasks:
            raise ValueError(f"num_tasks must be between 1 and {max_tasks}, got {num_tasks}")

        num_nodes = data.x.size(0)
        num_classes = len(torch.unique(data.y))

        # Task 1: original node classification.
        task1_labels = data.y.clone()

        # Task 2: degree-based classification. bincount is O(E) and stays on
        # edge_index's device, so the later torch.stack works on CUDA too.
        degree = torch.bincount(data.edge_index[0], minlength=num_nodes)[:num_nodes]
        task2_labels = self._bin_into_classes(degree.float(), num_classes)

        # Task 3: feature-based classification.
        feature_sum = data.x[:, :5].sum(dim=1)
        task3_labels = self._bin_into_classes(feature_sum, num_classes)

        tasks = [task1_labels, task2_labels, task3_labels][:num_tasks]
        data.task_labels = torch.stack(tasks, dim=1)
        data.num_tasks = num_tasks
        return data

    @staticmethod
    def _bin_into_classes(values, num_classes):
        """Map 1-D ``values`` to ``num_classes`` equal-width bins over [min, max]."""
        bins = torch.linspace(values.min().item(), values.max().item(), num_classes + 1,
                              dtype=values.dtype, device=values.device)
        # bucketize maps the minimum edge to index 0, so shifting by -1 yields -1
        # there. The clamp folds it (and any overflow) into a valid class id.
        return torch.clamp(torch.bucketize(values, bins) - 1, 0, num_classes - 1)

    def run_benchmark(self, datasets=('cora', 'citeseer'), epochs=20, seed=None):
        """Run every method in ``METHODS`` on each dataset and print a summary.

        Args:
            datasets: Iterable of names accepted by ``load_dataset``.
            epochs: Training epochs per method.
            seed: If not None, ``torch.manual_seed(seed)`` is called before each
                model is built, so runs are repeatable.

        Returns:
            ``{dataset_name: {method: results_dict}}``.
        """
        print("🚀 Starting Gradient Conflict Resolution Benchmark")
        print("="*70)

        all_results = {}

        for dataset_name in datasets:
            print(f"\n📊 Benchmarking on {dataset_name.upper()}")
            print("-" * 40)

            data = self.load_dataset(dataset_name)
            data = self.create_multi_task_dataset(data, num_tasks=3)
            in_dim = data.x.size(1)
            num_classes = len(torch.unique(data.y))

            dataset_results = {}

            for method in self.METHODS:
                print(f"\n🔬 Testing {method}...")
                if seed is not None:
                    torch.manual_seed(seed)

                if method == 'GraGR Core':
                    model = GraGRCore('gcn', in_dim, 64, num_classes,
                                      num_tasks=data.num_tasks, dataset_name=dataset_name).to(self.device)
                    results = self.run_gragr_model(model, data, epochs)
                elif method == 'GraGR++':
                    model = GraGRPlusPlus('gcn', in_dim, 64, num_classes,
                                          num_tasks=data.num_tasks, dataset_name=dataset_name).to(self.device)
                    results = self.run_gragr_model(model, data, epochs)
                else:
                    # Baseline GCN with a gradient-combination strategy.
                    model = BaselineGCN(in_dim, 64, num_classes,
                                        num_tasks=data.num_tasks).to(self.device)
                    results = self.run_model_with_optimizer(model, data, method, epochs)

                dataset_results[method] = results
                print(f"✅ {method}: Test Acc: {results['best_test_acc']:.4f}")

            all_results[dataset_name] = dataset_results

        self.generate_comparison_table(all_results)
        return all_results

    @staticmethod
    def _evaluate(model, data):
        """Return ``(val_acc, test_acc)`` on ``data.y``, restoring train mode afterwards.

        If the model returns a list of per-task logits, only the first is scored.
        """
        model.eval()
        with torch.no_grad():
            logits = model(data.x, data.edge_index)
            if isinstance(logits, list):
                logits = logits[0]
            pred = logits.argmax(dim=1)
            val_acc = (pred[data.val_mask] == data.y[data.val_mask]).float().mean().item()
            test_acc = (pred[data.test_mask] == data.y[data.test_mask]).float().mean().item()
        model.train()
        return val_acc, test_acc

    def run_gragr_model(self, model, data, epochs):
        """Train a GraGR model with Adam and track the best-validation epoch.

        The loss is cross-entropy on ``data.y`` over ``data.train_mask``. If the
        model reports an ``'enhanced_conflict_loss'`` signal, 0.1 times that
        signal is added.

        NOTE(review): unlike ``run_model_with_optimizer``, this trains only on
        the original labels, not on all synthetic tasks. The comparison with the
        multi-task baselines is therefore not like-for-like; confirm this is
        intended.

        Returns:
            dict with ``best_val_acc``, ``best_test_acc`` (test accuracy at the
            best-validation epoch) and ``avg_conflict_pct`` (mean of the
            ``'initial_conflict_pct'`` signals, or 0 if none were reported).
        """
        model.train()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

        best_val_acc = 0.0
        best_test_acc = 0.0
        conflict_percentages = []

        for epoch in range(epochs):
            optimizer.zero_grad()

            # Forward pass with reasoning; the epoch index and total epoch count
            # are passed through to the model.
            logits, signals = model.forward_with_reasoning(data.x, data.edge_index, epoch, epochs)
            task_logits = logits[0] if isinstance(logits, list) else logits

            loss = F.cross_entropy(task_logits[data.train_mask], data.y[data.train_mask])
            if 'enhanced_conflict_loss' in signals:
                # Out-of-place add to avoid mutating a tensor in the autograd graph.
                loss = loss + 0.1 * signals['enhanced_conflict_loss']

            loss.backward()
            optimizer.step()

            val_acc, test_acc = self._evaluate(model, data)
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                best_test_acc = test_acc

            if 'initial_conflict_pct' in signals:
                conflict_percentages.append(signals['initial_conflict_pct'])

        return {
            'best_val_acc': best_val_acc,
            'best_test_acc': best_test_acc,
            'avg_conflict_pct': np.mean(conflict_percentages) if conflict_percentages else 0
        }

    @staticmethod
    def _flatten_grads(grads, params):
        """Concatenate per-parameter gradients into one vector; None -> zeros."""
        return torch.cat([
            (g if g is not None else torch.zeros_like(p)).reshape(-1)
            for g, p in zip(grads, params)
        ])

    @staticmethod
    def _assign_grads(flat_grad, params):
        """Scatter a flat gradient vector back into each parameter's ``.grad``."""
        offset = 0
        for p in params:
            n = p.numel()
            p.grad = flat_grad[offset:offset + n].reshape_as(p).clone()
            offset += n

    @staticmethod
    def _pcgrad(task_grads):
        """Project each gradient off every conflicting peer, then average.

        Each projection is taken against the peer's original gradient, in a
        fixed order (the published PCGrad shuffles the order per step).
        """
        projected = []
        for i, g_i in enumerate(task_grads):
            g_pc = g_i.clone()
            for j, g_j in enumerate(task_grads):
                if i == j:
                    continue
                dot = torch.dot(g_pc, g_j)
                if dot < 0:  # conflicting directions
                    g_pc = g_pc - dot / (g_j.norm() ** 2 + 1e-8) * g_j
            projected.append(g_pc)
        return torch.stack(projected).mean(dim=0)

    @classmethod
    def _combine_gradients(cls, task_grads, method):
        """Combine flat per-task gradients into one update vector (see run_model_with_optimizer)."""
        if method == 'MGDA':
            # Simplified MGDA: the single task gradient with minimum L2 norm.
            norms = torch.stack([g.norm() for g in task_grads])
            return task_grads[int(torch.argmin(norms))]
        if method == 'PCGrad':
            return cls._pcgrad(task_grads)
        if method == 'CAGrad':
            # Simplified CAGrad: the task gradient closest to the mean gradient.
            mean_grad = torch.stack(task_grads).mean(dim=0)
            dists = torch.stack([(g - mean_grad).norm() for g in task_grads])
            return task_grads[int(torch.argmin(dists))]
        # Vanilla Average (and any unrecognised name).
        return torch.stack(task_grads).mean(dim=0)

    @staticmethod
    def _conflict_pct(task_grads):
        """Percent of task-gradient pairs with negative dot product; None if < 2 tasks."""
        count = len(task_grads)
        pairs = [(i, j) for i in range(count) for j in range(i + 1, count)]
        if not pairs:
            return None
        conflicts = sum(1 for i, j in pairs if torch.dot(task_grads[i], task_grads[j]) < 0)
        return 100.0 * conflicts / len(pairs)

    def run_model_with_optimizer(self, model, data, optimizer_name, epochs):
        """Train ``model`` on all tasks, combining task gradients per ``optimizer_name``.

        Each epoch, the gradient of every task loss with respect to all
        trainable parameters is flattened into one vector. The vectors are
        combined into a single update, written to ``param.grad``, and followed by
        one Adam step.

        Methods (simplified approximations of the cited algorithms):
          * 'MGDA'   - task gradient with the smallest L2 norm;
          * 'PCGrad' - project each gradient off conflicting ones, then average;
          * 'CAGrad' - task gradient closest to the mean gradient;
          * otherwise ('Vanilla Average') - mean of the task gradients.

        Returns:
            dict with ``best_val_acc``, ``best_test_acc`` (test accuracy at the
            best-validation epoch, both on ``data.y``) and ``avg_conflict_pct``.
        """
        model.train()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
        params = [p for p in model.parameters() if p.requires_grad]

        best_val_acc = 0.0
        best_test_acc = 0.0
        conflict_percentages = []

        for _ in range(epochs):
            optimizer.zero_grad()
            logits = model(data.x, data.edge_index)

            task_grads = []
            for task_idx in range(data.num_tasks):
                task_logits = logits[task_idx] if isinstance(logits, list) else logits
                task_labels = data.task_labels[:, task_idx]
                task_loss = F.cross_entropy(task_logits[data.train_mask], task_labels[data.train_mask])
                # First-order gradients only. The shared graph is kept alive until
                # the last task, and allow_unused covers per-task heads that do
                # not contribute to this task's loss.
                grads = torch.autograd.grad(
                    task_loss, params,
                    retain_graph=task_idx < data.num_tasks - 1,
                    allow_unused=True,
                )
                task_grads.append(self._flatten_grads(grads, params))

            self._assign_grads(self._combine_gradients(task_grads, optimizer_name), params)
            optimizer.step()

            val_acc, test_acc = self._evaluate(model, data)
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                best_test_acc = test_acc

            conflict_pct = self._conflict_pct(task_grads)
            if conflict_pct is not None:
                conflict_percentages.append(conflict_pct)

        return {
            'best_val_acc': best_val_acc,
            'best_test_acc': best_test_acc,
            'avg_conflict_pct': np.mean(conflict_percentages) if conflict_percentages else 0
        }

    def generate_comparison_table(self, all_results):
        """Print per-dataset results, cross-dataset averages and two rankings.

        Args:
            all_results: ``{dataset_name: {method: results_dict}}`` as returned
                by ``run_benchmark``. Methods missing from a dataset are skipped.
        """
        print("\n" + "="*80)
        print("🏆 GRADIENT CONFLICT RESOLUTION BENCHMARK RESULTS")
        print("="*80)

        # Per-dataset results.
        for dataset_name, dataset_results in all_results.items():
            print(f"\n📈 {dataset_name.upper()} - Performance Comparison")
            print("-" * 60)
            print(f"{'Method':<15} {'Test Acc':<10} {'Val Acc':<10} {'Conflict %':<12}")
            print("-" * 60)

            for method in self.METHODS:
                if method in dataset_results:
                    results = dataset_results[method]
                    print(f"{method:<15} {results['best_test_acc']:<10.4f} {results['best_val_acc']:<10.4f} "
                          f"{results['avg_conflict_pct']:<12.1f}")

        # Averages across datasets.
        print("\n📊 AVERAGE PERFORMANCE ACROSS ALL DATASETS")
        print("-" * 60)
        print(f"{'Method':<15} {'Avg Test Acc':<15} {'Avg Val Acc':<15} {'Avg Conflict %':<15}")
        print("-" * 60)

        method_averages = {}
        for method in self.METHODS:
            per_dataset = [res[method] for res in all_results.values() if method in res]
            if not per_dataset:
                continue
            averages = {
                'avg_test_acc': np.mean([r['best_test_acc'] for r in per_dataset]),
                'avg_val_acc': np.mean([r['best_val_acc'] for r in per_dataset]),
                'avg_conflict_pct': np.mean([r['avg_conflict_pct'] for r in per_dataset]),
            }
            method_averages[method] = averages
            print(f"{method:<15} {averages['avg_test_acc']:<15.4f} {averages['avg_val_acc']:<15.4f} "
                  f"{averages['avg_conflict_pct']:<15.1f}")

        # Rankings.
        print("\n🏅 METHOD RANKING BY TEST ACCURACY")
        print("-" * 40)
        sorted_methods = sorted(method_averages.items(), key=lambda x: x[1]['avg_test_acc'], reverse=True)
        for i, (method, metrics) in enumerate(sorted_methods, 1):
            print(f"{i}. {method:<15} - {metrics['avg_test_acc']:.4f}")

        print("\n🎯 CONFLICT REDUCTION RANKING")
        print("-" * 40)
        sorted_conflict = sorted(method_averages.items(), key=lambda x: x[1]['avg_conflict_pct'])
        for i, (method, metrics) in enumerate(sorted_conflict, 1):
            print(f"{i}. {method:<15} - {metrics['avg_conflict_pct']:.1f}% conflicts")


def main():
    """Benchmark every method on Cora, CiteSeer and PubMed for 20 epochs.

    Uses CUDA when available, otherwise the CPU.

    Returns:
        The nested ``{dataset: {method: results}}`` dict from ``run_benchmark``.
    """
    chosen_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"🖥️  Using device: {chosen_device}")

    benchmark = SimpleBenchmarkRunner(device=chosen_device)
    outcome = benchmark.run_benchmark(['cora', 'citeseer', 'pubmed'], epochs=20)

    print("\n🎉 Benchmark completed!")
    return outcome


if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not when imported.
    main()
