#!/usr/bin/env python3
"""
GraGR Dataset Download Script
============================

This script downloads all datasets used in the GraGR research.
"""

import os
import sys
import argparse
import torch
from torch_geometric.datasets import Planetoid, WikiCS, WebKB, TUDataset, Coauthor, Amazon, Reddit
from ogb.nodeproppred import PygNodePropPredDataset
from ogb.graphproppred import PygGraphPropPredDataset
import warnings
warnings.filterwarnings('ignore')

def download_citation_networks(output_dir):
    """Download the Planetoid citation networks (Cora, CiteSeer, PubMed).

    Every dataset is attempted independently, so one failure does not stop
    the remaining ones from being tried.

    Args:
        output_dir: Directory handed to ``Planetoid`` as ``root``.

    Returns:
        list: Names of the datasets whose download raised, in the order they
        were attempted. Empty when every download succeeded.
    """
    print("📚 Downloading Citation Networks...")

    failed = []
    for dataset_name in ('Cora', 'CiteSeer', 'PubMed'):
        print(f"  → Downloading {dataset_name}...")
        try:
            # The constructor is called only for its side effect; the
            # resulting dataset object is not needed here.
            Planetoid(root=output_dir, name=dataset_name)
        except Exception as e:
            # Deliberately broad: this is a best-effort batch job, so any
            # failure is reported and recorded, then the loop moves on.
            print(f"    ✗ Error downloading {dataset_name}: {e}")
            failed.append(dataset_name)
        else:
            print(f"    ✓ {dataset_name} downloaded successfully")
    return failed

def download_webkb_datasets(output_dir):
    """Download the WebKB datasets (Texas, Cornell, Wisconsin).

    Every dataset is attempted independently, so one failure does not stop
    the remaining ones from being tried.

    Args:
        output_dir: Directory handed to ``WebKB`` as ``root``.

    Returns:
        list: Names of the datasets whose download raised, in the order they
        were attempted. Empty when every download succeeded.
    """
    print("🌐 Downloading WebKB Datasets...")

    failed = []
    for dataset_name in ('Texas', 'Cornell', 'Wisconsin'):
        print(f"  → Downloading {dataset_name}...")
        try:
            # Called only for its side effect; the instance is discarded.
            WebKB(root=output_dir, name=dataset_name)
        except Exception as e:
            # Deliberately broad: best-effort batch job, so report, record
            # and continue with the next dataset.
            print(f"    ✗ Error downloading {dataset_name}: {e}")
            failed.append(dataset_name)
        else:
            print(f"    ✓ {dataset_name} downloaded successfully")
    return failed

def download_structural_graphs(output_dir):
    """Download the structural graph dataset (WikiCS).

    Args:
        output_dir: Directory handed to ``WikiCS`` as ``root``.

    Returns:
        list: ``['WikiCS']`` if the download raised, otherwise an empty list.
    """
    print("📊 Downloading Structural Graphs...")

    failed = []
    print("  → Downloading WikiCS...")
    try:
        # NOTE(review): WikiCS is given no dataset name, and
        # download_additional_datasets passes this very same root to Reddit.
        # Presumably both then keep their files directly under output_dir —
        # confirm the two cannot overwrite or shadow each other.
        WikiCS(root=output_dir)
    except Exception as e:
        # Deliberately broad: best-effort job, so report and record the
        # failure instead of aborting the whole run.
        print(f"    ✗ Error downloading WikiCS: {e}")
        failed.append('WikiCS')
    else:
        print("    ✓ WikiCS downloaded successfully")
    return failed

def download_molecular_datasets(output_dir):
    """Download the molecular dataset (OGB ``ogbg-molhiv``).

    Args:
        output_dir: Directory handed to ``PygGraphPropPredDataset`` as
            ``root``.

    Returns:
        list: ``['OGB-MolHIV']`` if the download raised, otherwise an empty
        list.
    """
    print("🧬 Downloading Molecular Datasets...")

    failed = []
    print("  → Downloading OGB-MolHIV...")
    try:
        # Called only for its side effect; the instance is discarded.
        PygGraphPropPredDataset(name='ogbg-molhiv', root=output_dir)
    except Exception as e:
        # Deliberately broad: best-effort job, so report and record the
        # failure instead of aborting the whole run.
        print(f"    ✗ Error downloading OGB-MolHIV: {e}")
        failed.append('OGB-MolHIV')
    else:
        print("    ✓ OGB-MolHIV downloaded successfully")
    return failed

def download_tu_datasets(output_dir):
    """Download the core TU datasets (MUTAG, PROTEINS, ENZYMES, NCI1, NCI109).

    Every dataset is attempted independently, so one failure does not stop
    the remaining ones from being tried.

    Args:
        output_dir: Directory handed to ``TUDataset`` as ``root``.

    Returns:
        list: Names of the datasets whose download raised, in the order they
        were attempted. Empty when every download succeeded.
    """
    print("🔬 Downloading TU Datasets...")

    failed = []
    for dataset_name in ('MUTAG', 'PROTEINS', 'ENZYMES', 'NCI1', 'NCI109'):
        print(f"  → Downloading {dataset_name}...")
        try:
            # Called only for its side effect; the instance is discarded.
            TUDataset(root=output_dir, name=dataset_name)
        except Exception as e:
            # Deliberately broad: best-effort batch job, so report, record
            # and continue with the next dataset.
            print(f"    ✗ Error downloading {dataset_name}: {e}")
            failed.append(dataset_name)
        else:
            print(f"    ✓ {dataset_name} downloaded successfully")
    return failed

def download_additional_datasets(output_dir):
    """Download additional datasets for comprehensive evaluation.

    Attempts, in this order: Coauthor CS and Physics, Amazon Computers and
    Photo, Reddit, then the TU datasets COLLAB, IMDB-BINARY, IMDB-MULTI,
    REDDIT-BINARY, REDDIT-MULTI-5K and REDDIT-MULTI-12K. Every dataset is
    attempted independently, so one failure does not stop the others.

    Args:
        output_dir: Directory handed to every dataset class as ``root``.

    Returns:
        list: Display labels (as printed in the progress messages) of the
        datasets whose download raised, in the order they were attempted.
        Empty when every download succeeded.
    """
    print("🔍 Downloading Additional Datasets...")

    # Each job pairs the label shown to the user with a zero-argument callable
    # that constructs the dataset. Construction is deferred so that it happens
    # inside the try block of the loop below.
    jobs = [
        ("Coauthor CS", lambda: Coauthor(root=output_dir, name='CS')),
        ("Coauthor Physics", lambda: Coauthor(root=output_dir, name='Physics')),
        ("Amazon Computers", lambda: Amazon(root=output_dir, name='Computers')),
        ("Amazon Photo", lambda: Amazon(root=output_dir, name='Photo')),
        # NOTE(review): Reddit is given no dataset name, and
        # download_structural_graphs passes this very same root to WikiCS.
        # Presumably both then keep their files directly under output_dir —
        # confirm the two cannot overwrite or shadow each other.
        ("Reddit", lambda: Reddit(root=output_dir)),
    ]
    additional_tu_datasets = (
        'COLLAB', 'IMDB-BINARY', 'IMDB-MULTI',
        'REDDIT-BINARY', 'REDDIT-MULTI-5K', 'REDDIT-MULTI-12K',
    )
    # Bind tu_name as a default argument: the lambdas run after this loop has
    # finished, so a plain closure would see only the last name.
    jobs.extend(
        (tu_name, lambda tu_name=tu_name: TUDataset(root=output_dir, name=tu_name))
        for tu_name in additional_tu_datasets
    )

    failed = []
    for label, build in jobs:
        print(f"  → Downloading {label}...")
        try:
            # Called only for its side effect; the instance is discarded.
            build()
        except Exception as e:
            # Deliberately broad: best-effort batch job, so report, record
            # and continue with the next dataset.
            print(f"    ✗ Error downloading {label}: {e}")
            failed.append(label)
        else:
            print(f"    ✓ {label} downloaded successfully")
    return failed

def main():
    """Parse the command line and run the selected dataset downloaders.

    Command-line options:
        --output_dir: Directory the datasets are written to (created if
            missing). Defaults to ``../processed``.
        --datasets: One or more of ``citation``, ``webkb``, ``structural``,
            ``molecular``, ``tu``, ``additional`` or ``all`` (the default).

    Downloaders run in a fixed order regardless of the order given on the
    command line. If a downloader returns a list or tuple of dataset names it
    could not fetch, those names are collected and the closing message
    reports them instead of claiming success. A downloader that returns
    ``None`` is treated as having nothing to report.
    """
    parser = argparse.ArgumentParser(description='Download GraGR datasets')
    parser.add_argument('--output_dir', type=str, default='../processed',
                       help='Output directory for datasets')
    parser.add_argument('--datasets', nargs='+',
                       choices=['citation', 'webkb', 'structural', 'molecular', 'tu', 'additional', 'all'],
                       default=['all'],
                       help='Which datasets to download')

    args = parser.parse_args()

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    print("🚀 GraGR Dataset Downloader")
    print("=" * 50)
    print(f"Output directory: {args.output_dir}")
    print(f"Datasets to download: {args.datasets}")
    print("=" * 50)

    # Group name -> downloader. Dict insertion order fixes the run order.
    downloaders = {
        'citation': download_citation_networks,
        'webkb': download_webkb_datasets,
        'structural': download_structural_graphs,
        'molecular': download_molecular_datasets,
        'tu': download_tu_datasets,
        'additional': download_additional_datasets,
    }

    selected = set(args.datasets)
    run_everything = 'all' in selected

    failed = []
    for group, downloader in downloaders.items():
        if not (run_everything or group in selected):
            continue
        reported = downloader(args.output_dir)
        # Only aggregate an explicit list/tuple of failed names; any other
        # return value (notably None) means the downloader reported nothing.
        if isinstance(reported, (list, tuple)):
            failed.extend(reported)

    if failed:
        # Do not claim success when some downloads are known to have failed.
        print(f"\n⚠️ Dataset download finished with {len(failed)} failure(s): "
              f"{', '.join(str(name) for name in failed)}")
    else:
        print("\n✅ Dataset download completed!")
    print(f"📁 Datasets saved to: {args.output_dir}")

if __name__ == "__main__":
    main()
