#!/usr/bin/env python3
"""
GraGR Quick Start Demo
=====================

This script demonstrates the core functionality of GraGR and GraGR++ models
on a Planetoid citation dataset (Cora by default; CiteSeer and PubMed are
also supported).

Usage:
    python demo.py [--model MODEL] [--epochs EPOCHS] [--dataset DATASET] [--compare]
"""

import os
import sys
import argparse
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.data import Data
import numpy as np
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

# Add parent directory to path to import GraGR modules
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))

from src.core.gragr_complete import (
    GraGRCore, GraGRPlusPlus, BaselineGCN, BaselineGAT, BaselineGIN, BaselineSAGE,
    set_seed, compute_metrics
)

def load_dataset(dataset_name='Cora', root=None):
    """Load a Planetoid citation dataset for demonstration.

    Args:
        dataset_name: Case-insensitive dataset name: 'Cora', 'CiteSeer' or
            'PubMed'.
        root: Directory passed to ``Planetoid`` as its ``root``. Defaults to
            ``datasets/processed`` two levels above this script (the same
            project root this file appends to ``sys.path``), so the location
            does not depend on the current working directory.

    Returns:
        The dataset's first graph (``dataset[0]``), or ``None`` if loading
        failed. The error is printed rather than raised so the caller can
        report it and stop cleanly.
    """
    print(f"📊 Loading {dataset_name} dataset...")

    # Case-insensitive user input -> canonical Planetoid dataset name.
    canonical_names = {'cora': 'Cora', 'citeseer': 'CiteSeer', 'pubmed': 'PubMed'}

    if root is None:
        # Anchor at the script location instead of the CWD.
        script_dir = os.path.dirname(os.path.abspath(__file__))
        root = os.path.join(script_dir, '..', '..', 'datasets', 'processed')

    try:
        name = canonical_names.get(dataset_name.lower())
        if name is None:
            raise ValueError(f"Unsupported dataset: {dataset_name}")

        dataset = Planetoid(root=root, name=name)
        data = dataset[0]
        print(f"    ✓ {dataset_name} loaded: {data.x.size(0)} nodes, {data.edge_index.size(1)} edges")
        return data
    except Exception as e:
        # Deliberate best-effort: the demo reports the failure and the caller
        # checks for None.
        print(f"    ✗ Error loading {dataset_name}: {e}")
        return None

def _forward(model, data, epoch, epochs, uses_reasoning):
    """Run one forward pass and return ``(logits, signals)``.

    ``signals`` is the second value returned by ``forward_with_reasoning``,
    or ``None`` for models that do not define that method.
    """
    if uses_reasoning:
        logits, signals = model.forward_with_reasoning(
            data.x, data.edge_index, epoch=epoch, total_epochs=epochs
        )
        return logits, signals
    return model(data.x, data.edge_index), None


def _masked_accuracy(logits, labels, mask):
    """Return the fraction of nodes in ``mask`` whose argmax prediction equals ``labels``."""
    predictions = logits[mask].argmax(dim=1)
    return (predictions == labels[mask]).float().mean().item()


def train_model(model, data, epochs=50, lr=0.01, weight_decay=5e-4):
    """Train ``model`` full-batch on ``data`` and return its training history.

    Models that define ``forward_with_reasoning`` are called through it with
    the current epoch and the total epoch count, and must return
    ``(logits, signals)``. If ``signals`` contains ``'conflict_loss'``, that
    term is added to the loss with weight 0.1. Other models are called as
    ``model(x, edge_index)``.

    When the logits are a list, the training loss is the mean cross-entropy
    over its elements, while evaluation uses only the first element.

    Args:
        model: ``torch.nn.Module`` to train; moved to CUDA when available.
        data: Graph object exposing ``x``, ``edge_index``, ``y``,
            ``train_mask``, ``val_mask``, ``test_mask`` and ``.to(device)``.
        epochs: Number of training epochs; must be at least 1.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.

    Returns:
        dict with per-epoch ``'train_losses'``, ``'val_accuracies'`` and
        ``'test_accuracies'``, plus the last epoch's ``'final_val_acc'`` and
        ``'final_test_acc'``.

    Raises:
        ValueError: If ``epochs`` is smaller than 1. Without this check there
            would be no "final" epoch to report.
    """
    if epochs < 1:
        raise ValueError(f"epochs must be a positive integer, got {epochs}")

    print(f"🚀 Training {model.__class__.__name__}...")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    data = data.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    # Decided once. Reasoning models also get the epoch schedule and may
    # report an auxiliary conflict loss.
    uses_reasoning = hasattr(model, 'forward_with_reasoning')

    train_losses = []
    val_accuracies = []
    test_accuracies = []

    for epoch in range(epochs):
        # --- Training step ---
        model.train()
        optimizer.zero_grad()
        logits, signals = _forward(model, data, epoch, epochs, uses_reasoning)

        train_labels = data.y[data.train_mask]
        if isinstance(logits, list):
            # A list of logit tensors: average their cross-entropy losses.
            loss = sum(
                F.cross_entropy(out[data.train_mask], train_labels) for out in logits
            ) / len(logits)
        else:
            loss = F.cross_entropy(logits[data.train_mask], train_labels)

        if signals is not None and 'conflict_loss' in signals:
            loss = loss + 0.1 * signals['conflict_loss']

        loss.backward()
        optimizer.step()

        # --- Evaluation on the validation and test masks ---
        model.eval()
        with torch.no_grad():
            eval_logits, _ = _forward(model, data, epoch, epochs, uses_reasoning)
            if isinstance(eval_logits, list):
                # Only the first logits tensor is evaluated.
                eval_logits = eval_logits[0]
            val_acc = _masked_accuracy(eval_logits, data.y, data.val_mask)
            test_acc = _masked_accuracy(eval_logits, data.y, data.test_mask)

        loss_value = loss.item()
        train_losses.append(loss_value)
        val_accuracies.append(val_acc)
        test_accuracies.append(test_acc)

        if epoch % 10 == 0:
            print(f"    Epoch {epoch:3d}: Loss={loss_value:.4f}, Val Acc={val_acc:.4f}, Test Acc={test_acc:.4f}")

    return {
        'train_losses': train_losses,
        'val_accuracies': val_accuracies,
        'test_accuracies': test_accuracies,
        'final_val_acc': val_accuracies[-1],
        'final_test_acc': test_accuracies[-1]
    }

def _build_model(model_key, num_features, num_classes, num_nodes, dataset_name, hidden_dim=64):
    """Instantiate the model named by a ``--model`` choice, or return None if unknown.

    GraGR variants use a 'gcn' backbone and additionally receive the node
    count and the lower-cased dataset name.
    """
    if model_key == 'gragr_core':
        return GraGRCore('gcn', num_features, hidden_dim, num_classes, num_nodes, dataset_name=dataset_name)
    if model_key == 'gragr_plus':
        return GraGRPlusPlus('gcn', num_features, hidden_dim, num_classes, num_nodes, dataset_name=dataset_name)
    if model_key == 'baseline_gcn':
        return BaselineGCN(num_features, hidden_dim, num_classes)
    if model_key == 'baseline_gat':
        return BaselineGAT(num_features, hidden_dim, num_classes)
    return None


# Models trained by --compare, as (display name, --model key) pairs, in order.
_COMPARE_MODELS = (
    ('Baseline GCN', 'baseline_gcn'),
    ('GraGR Core', 'gragr_core'),
    ('GraGR++', 'gragr_plus'),
)


def main():
    """Run the demo: parse CLI arguments, load the dataset, then train and report.

    With ``--compare``, every model in ``_COMPARE_MODELS`` is trained in turn
    and summarized in a table. Otherwise only the ``--model`` choice is
    trained. Exits with status 1 if the dataset cannot be loaded.
    """
    parser = argparse.ArgumentParser(description='GraGR Quick Start Demo')
    parser.add_argument('--model', type=str, default='gragr_core',
                       choices=['gragr_core', 'gragr_plus', 'baseline_gcn', 'baseline_gat'],
                       help='Model to use for demo')
    parser.add_argument('--dataset', type=str, default='Cora',
                       choices=['Cora', 'CiteSeer', 'PubMed'],
                       help='Dataset to use for demo')
    parser.add_argument('--epochs', type=int, default=50,
                       help='Number of training epochs')
    parser.add_argument('--compare', action='store_true',
                       help='Compare multiple models')

    args = parser.parse_args()

    # Set random seed for reproducibility
    set_seed(42)

    print("🚀 GraGR Quick Start Demo")
    print("=" * 50)
    if args.compare:
        # --model is ignored in comparison mode, so don't report it as the model in use.
        print(f"Model: comparison ({', '.join(name for name, _ in _COMPARE_MODELS)})")
    else:
        print(f"Model: {args.model}")
    print(f"Dataset: {args.dataset}")
    print(f"Epochs: {args.epochs}")
    print("=" * 50)

    # Load dataset
    data = load_dataset(args.dataset)
    if data is None:
        print("❌ Failed to load dataset. Exiting.")
        # Non-zero status so shells/CI notice the failure.
        sys.exit(1)

    num_features = data.x.size(1)
    num_classes = data.y.max().item() + 1
    num_nodes = data.x.size(0)
    dataset_key = args.dataset.lower()

    print("📊 Dataset Info:")
    print(f"    Nodes: {num_nodes}")
    print(f"    Features: {num_features}")
    print(f"    Classes: {num_classes}")
    print(f"    Edges: {data.edge_index.size(1)}")

    if args.compare:
        print("\n🔄 Comparing Multiple Models...")

        results = {}
        for model_name, model_key in _COMPARE_MODELS:
            # Re-seed before each model so its initialization and training
            # randomness do not depend on which models were built or trained
            # before it.
            set_seed(42)
            model = _build_model(model_key, num_features, num_classes, num_nodes, dataset_key)

            print(f"\n{'='*30}")
            print(f"Training {model_name}")
            print(f"{'='*30}")

            result = train_model(model, data, epochs=args.epochs)
            results[model_name] = result

            print(f"✅ {model_name} completed:")
            print(f"    Final Validation Accuracy: {result['final_val_acc']:.4f}")
            print(f"    Final Test Accuracy: {result['final_test_acc']:.4f}")

        # Print summary table
        print("\n📊 Final Results Summary:")
        print(f"{'Model':<15} {'Val Acc':<10} {'Test Acc':<10}")
        print(f"{'-'*35}")
        for model_name, result in results.items():
            print(f"{model_name:<15} {result['final_val_acc']:<10.4f} {result['final_test_acc']:<10.4f}")

    else:
        # Single model demo
        print(f"\n🔧 Creating {args.model} model...")

        model = _build_model(args.model, num_features, num_classes, num_nodes, dataset_key)
        if model is None:
            # Defensive: argparse `choices` should already reject unknown names.
            print(f"❌ Unknown model: {args.model}")
            sys.exit(1)

        # Train model
        result = train_model(model, data, epochs=args.epochs)

        print("\n✅ Training completed!")
        print(f"    Final Validation Accuracy: {result['final_val_acc']:.4f}")
        print(f"    Final Test Accuracy: {result['final_test_acc']:.4f}")

    print("\n🎉 Demo completed successfully!")

if __name__ == '__main__':
    # Run the demo only when executed as a script, not when imported.
    main()
