#!/usr/bin/env python3
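"""
Download and set up the datasets used by the OFMU experiments (TOFU, WMDP, CIFAR).

Usage:
    python setup_data.py --dataset tofu
    python setup_data.py --dataset wmdp --domain bio
    python setup_data.py --verify_only
"""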

import os
import sys
import argparse
import json
import logging
from pathlib import Path
from typing import Dict, List, Optional
import urllib.request
import zipfile
import tarfile

# Make the directory one level above this script's folder importable, so
# project-level modules can be imported when the script is run directly.
# NOTE(review): nothing visible in this file imports from it — confirm it is still needed.
sys.path.append(os.fspath(Path(__file__).parent.parent))


class DatasetDownloader:
    """Handles downloading and setup of all experimental datasets.

    Each dataset gets its own subdirectory of ``data_dir`` (``tofu/``,
    ``wmdp/``, ``cifar/``). The ``setup_*`` methods return True only when
    every requested file is present afterwards, so callers can detect
    failed downloads.
    """

    # TOFU forget splits: file name -> download URL.
    # TODO: placeholder URLs; replace with the real TOFU dataset URLs.
    TOFU_URLS: Dict[str, str] = {
        "forget01.json": "SAMPLE_LINK_HERE",
        "forget05.json": "SAMPLE_LINK_HERE",
        "forget10.json": "SAMPLE_LINK_HERE",
    }

    # WMDP domains: domain -> download URL (saved as wmdp_<domain>_test.json).
    # TODO: placeholder URLs; replace with the real WMDP dataset URLs.
    WMDP_URLS: Dict[str, str] = {
        "bio": "SAMPLE_LINK_HERE",
        "cyber": "SAMPLE_LINK_HERE",
        "chem": "SAMPLE_LINK_HERE",
    }

    # Suffixes routed to tarfile. "foo.tar.gz" has suffix ".gz"; tarfile's
    # 'r:*' mode detects the compression itself.
    _TAR_SUFFIXES = ('.tar', '.gz', '.tgz', '.bz2', '.tbz2', '.xz', '.txz')

    def __init__(self, data_dir: str = "data"):
        """Create ``data_dir`` (including any missing parents) and set up logging."""
        self.data_dir = Path(data_dir)
        # parents=True so nested locations such as "runs/exp1/data" work.
        self.data_dir.mkdir(parents=True, exist_ok=True)
        self.setup_logging()

    def setup_logging(self):
        """Configure root logging (file + console) and bind ``self.logger``.

        ``logging.basicConfig`` does nothing once the root logger has
        handlers. The handlers are therefore only built when they will
        actually be installed; otherwise the FileHandler would open
        ``data_setup.log`` (in the current working directory) and leak it.
        """
        if not logging.getLogger().handlers:
            logging.basicConfig(
                level=logging.INFO,
                format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                handlers=[
                    logging.FileHandler('data_setup.log'),
                    logging.StreamHandler(),
                ],
            )
        self.logger = logging.getLogger(__name__)

    @staticmethod
    def _wmdp_filename(domain: str) -> str:
        """Return the on-disk file name used for a WMDP domain's test split."""
        return f"wmdp_{domain}_test.json"

    def download_file(self, url: str, output_path: Path) -> bool:
        """Download ``url`` to ``output_path``, creating parent directories.

        The data is first written to ``<output_path>.part`` and only renamed
        into place once complete. An interrupted or failed download therefore
        never leaves a truncated file that later runs would treat as present.

        Returns:
            True on success; False if the download failed (the error is logged).
        """
        tmp_path = output_path.with_name(output_path.name + ".part")
        try:
            self.logger.info(f"Downloading {url} to {output_path}")
            output_path.parent.mkdir(parents=True, exist_ok=True)

            urllib.request.urlretrieve(url, tmp_path)
            tmp_path.replace(output_path)
            self.logger.info(f"Successfully downloaded {output_path}")
            return True

        except Exception as e:
            self.logger.error(f"Failed to download {url}: {e}")
            # Best-effort removal of a partial download.
            try:
                tmp_path.unlink()
            except OSError:
                pass
            return False

    @staticmethod
    def _is_within(root: Path, candidate: Path) -> bool:
        """Return True if resolved path ``candidate`` lies inside resolved ``root``."""
        try:
            candidate.relative_to(root)
        except ValueError:
            return False
        return True

    @classmethod
    def _tar_member_is_safe(cls, member: tarfile.TarInfo, root: Path) -> bool:
        """Return True if extracting ``member`` under resolved ``root`` stays inside it."""
        dest = (root / member.name).resolve()
        if not cls._is_within(root, dest):
            return False
        if member.issym():
            # Symlink targets are interpreted relative to the link's directory.
            return cls._is_within(root, (dest.parent / member.linkname).resolve())
        if member.islnk():
            # Hard-link targets are archive paths, relative to the root.
            return cls._is_within(root, (root / member.linkname).resolve())
        return True

    def extract_archive(self, archive_path: Path, extract_to: Path) -> bool:
        """Extract a zip or tar archive into ``extract_to``.

        Suffix matching is case-insensitive. Tar archives whose members (or
        link targets) would land outside ``extract_to`` are rejected and
        nothing is extracted. Zip member names are already sanitised by
        ``ZipFile.extractall``.

        Returns:
            True on success. False for an unsupported format, an unsafe tar
            archive, or any extraction error (all are logged).
        """
        try:
            extract_to.mkdir(parents=True, exist_ok=True)
            suffix = archive_path.suffix.lower()

            if suffix == '.zip':
                with zipfile.ZipFile(archive_path, 'r') as zip_ref:
                    zip_ref.extractall(extract_to)
            elif suffix in self._TAR_SUFFIXES:
                with tarfile.open(archive_path, 'r:*') as tar_ref:
                    root = extract_to.resolve()
                    unsafe = [m.name for m in tar_ref.getmembers()
                              if not self._tar_member_is_safe(m, root)]
                    if unsafe:
                        self.logger.error(
                            f"Refusing to extract {archive_path}: unsafe member paths {unsafe}"
                        )
                        return False
                    # Python 3.12+ (and security backports) provide extraction
                    # filters; 'data' also blocks device files and unsafe modes.
                    if hasattr(tarfile, 'data_filter'):
                        tar_ref.extractall(extract_to, filter='data')
                    else:
                        tar_ref.extractall(extract_to)
            else:
                self.logger.error(f"Unsupported archive format: {archive_path}")
                return False

            self.logger.info(f"Extracted {archive_path} to {extract_to}")
            return True

        except Exception as e:
            self.logger.error(f"Failed to extract {archive_path}: {e}")
            return False

    def setup_tofu_dataset(self) -> bool:
        """Download any missing TOFU forget splits into ``data_dir/tofu``.

        Files that already exist are left untouched.

        Returns:
            True if every TOFU file is present afterwards, False otherwise.
        """
        self.logger.info("Setting up TOFU dataset...")

        tofu_dir = self.data_dir / "tofu"
        tofu_dir.mkdir(parents=True, exist_ok=True)

        success = True
        for filename, url in self.TOFU_URLS.items():
            output_path = tofu_dir / filename
            if not output_path.exists() and not self.download_file(url, output_path):
                success = False

        if not success:
            # Report the failure to the caller instead of claiming success.
            self.logger.error("Official TOFU dataset not available")
            return False

        self.logger.info("TOFU dataset setup complete")
        return True

    def setup_wmdp_dataset(self, domain: Optional[str] = None) -> bool:
        """Download WMDP test splits into ``data_dir/wmdp``.

        Args:
            domain: One of ``WMDP_URLS``' keys. A falsy value downloads
                every domain.

        Returns:
            True if every requested domain file is present afterwards. False
            if any download failed or ``domain`` is not a known domain.
        """
        self.logger.info("Setting up WMDP dataset...")

        wmdp_dir = self.data_dir / "wmdp"
        wmdp_dir.mkdir(parents=True, exist_ok=True)

        domains_to_download = [domain] if domain else list(self.WMDP_URLS)

        success = True
        for domain_name in domains_to_download:
            url = self.WMDP_URLS.get(domain_name)
            if url is None:
                self.logger.error(f"Unknown WMDP domain: {domain_name}")
                success = False
                continue

            output_path = wmdp_dir / self._wmdp_filename(domain_name)
            if not output_path.exists() and not self.download_file(url, output_path):
                success = False

        if success:
            self.logger.info("WMDP dataset setup complete")
        else:
            self.logger.error("WMDP dataset setup incomplete; see errors above")
        return success

    def setup_cifar_datasets(self) -> bool:
        """Create ``data_dir/cifar/info.json`` describing CIFAR-10/100.

        The images themselves are not downloaded here; per the note written
        to ``info.json``, torchvision downloads them automatically.

        Returns:
            Always True (a failure to write the file raises instead).
        """
        self.logger.info("CIFAR datasets will be downloaded automatically by torchvision")

        # Create cifar directory for any auxiliary files
        cifar_dir = self.data_dir / "cifar"
        cifar_dir.mkdir(parents=True, exist_ok=True)

        info = {
            "note": "CIFAR-10 and CIFAR-100 datasets are automatically downloaded by torchvision",
            "datasets": ["CIFAR-10", "CIFAR-100"],
            "classes": {
                "cifar10": 10,
                "cifar100": 100
            },
            "train_samples": 50000,
            "test_samples": 10000
        }

        with open(cifar_dir / "info.json", 'w', encoding='utf-8') as f:
            json.dump(info, f, indent=2)

        return True

    def verify_dataset_integrity(self) -> Dict[str, bool]:
        """Report which datasets have their expected files on disk.

        Returns:
            A dict with keys ``"tofu"``, ``"wmdp"`` and ``"cifar"``:
            - ``tofu`` is True only if every TOFU split file exists.
            - ``wmdp`` is True if at least one domain file exists, because a
              single domain can be requested via ``--domain``.
            - ``cifar`` is True if ``cifar/info.json`` exists.
        """
        tofu_dir = self.data_dir / "tofu"
        wmdp_dir = self.data_dir / "wmdp"
        return {
            "tofu": all((tofu_dir / name).is_file() for name in self.TOFU_URLS),
            "wmdp": any((wmdp_dir / self._wmdp_filename(d)).is_file()
                        for d in self.WMDP_URLS),
            "cifar": (self.data_dir / "cifar" / "info.json").is_file(),
        }

    def print_setup_summary(self):
        """Print per-dataset status, dataset locations and a next-step hint to stdout."""
        verification = self.verify_dataset_integrity()

        print("\n" + "="*60)
        print("DATASET SETUP SUMMARY")
        print("="*60)

        for dataset, status in verification.items():
            status_str = "✓ OK" if status else "✗ MISSING"
            print(f"{dataset.upper():<10} {status_str}")

        print("\nDataset Locations:")
        print(f"  TOFU:   {self.data_dir / 'tofu'}")
        print(f"  WMDP:   {self.data_dir / 'wmdp'}")
        print(f"  CIFAR:  {self.data_dir / 'cifar'}")

        print("\nNext Steps:")
        if all(verification.values()):
            print("  All datasets ready! You can now run experiments.")
        else:
            print("  Some datasets are missing. Check error messages above.")

        print("="*60)


def main():
    """Command-line entry point: set up the requested datasets.

    Returns:
        The process exit status. It is 0 if every requested setup step
        succeeded (or ``--verify_only`` was given). It is 1 if a step
        reported failure, raised, or was interrupted, so shell scripts and
        CI can detect missing data.
    """
    parser = argparse.ArgumentParser(description="Setup datasets for OFMU experiments")
    parser.add_argument("--dataset", choices=["tofu", "wmdp", "cifar", "all"], default="all",
                        help="Which dataset to setup")
    parser.add_argument("--domain", choices=["bio", "cyber", "chem"],
                        help="WMDP domain (for --dataset wmdp)")
    parser.add_argument("--data_dir", type=str, default="data",
                        help="Directory to store datasets")
    parser.add_argument("--verify_only", action="store_true",
                        help="Only verify existing datasets")

    args = parser.parse_args()

    downloader = DatasetDownloader(args.data_dir)

    if args.verify_only:
        downloader.print_setup_summary()
        return 0

    if args.domain and args.dataset != "wmdp":
        # --domain only filters the wmdp-only run; "all" fetches every domain.
        downloader.logger.warning(
            f"--domain {args.domain} is ignored unless --dataset wmdp is given"
        )

    try:
        if args.dataset == "all":
            results = [
                downloader.setup_tofu_dataset(),
                downloader.setup_wmdp_dataset(),
                downloader.setup_cifar_datasets(),
            ]
        elif args.dataset == "tofu":
            results = [downloader.setup_tofu_dataset()]
        elif args.dataset == "wmdp":
            results = [downloader.setup_wmdp_dataset(args.domain)]
        else:  # "cifar" -- argparse `choices` rules out anything else
            results = [downloader.setup_cifar_datasets()]

    except KeyboardInterrupt:
        print("\nSetup interrupted by user", file=sys.stderr)
        return 1
    except Exception as e:
        print(f"Setup failed with error: {e}", file=sys.stderr)
        return 1

    downloader.print_setup_summary()
    # Previously this returned 0 even when downloads failed.
    return 0 if all(results) else 1


if __name__ == "__main__":
    # sys.exit rather than the `exit` builtin, which `site` injects and which is
    # absent under `python -S` or in some embedded/frozen interpreters.
    sys.exit(main())