#!/usr/bin/env python3
"""
Download and set up datasets for OFMU experiments.

Usage:
    python setup_data.py --all
    python setup_data.py --dataset tofu --scenario forget05
    python setup_data.py --dataset wmdp --domain bio
"""

import argparse
import os
import sys
import json
import logging
from pathlib import Path
from typing import Dict, List, Optional, Union
import yaml

import torch
import torchvision
from huggingface_hub import snapshot_download, hf_hub_download
from datasets import load_dataset
import requests


class DatasetDownloader:
    """Download, prepare and verify the datasets used by OFMU experiments.

    Everything is written below ``data_root``:

    * ``tofu/``  - full TOFU export, forget/retain scenario files, eval data
    * ``wmdp/``  - WMDP test questions and (when available) forget corpora
    * ``cifar/`` - static JSON descriptions of CIFAR-10 / CIFAR-100

    Every ``download_*`` method returns ``True`` on success and ``False`` on
    failure; failures are logged rather than raised.
    """

    # Fraction of the full TOFU split that goes into the forget set.
    TOFU_FORGET_RATIOS = {"forget01": 0.01, "forget05": 0.05, "forget10": 0.10}
    # Ratio used when a scenario name is not listed in TOFU_FORGET_RATIOS.
    DEFAULT_FORGET_RATIO = 0.05
    # WMDP domains fetched when no specific domain is requested.
    WMDP_DOMAINS = ("bio", "cyber", "chem")

    def __init__(self, data_root: str = "./data", cache_dir: str = "./cache"):
        """Create the output directories, then set up logging and config.

        Args:
            data_root: Directory that receives the prepared dataset files.
            cache_dir: Directory handed to ``load_dataset`` as download cache.
        """
        self.data_root = Path(data_root)
        self.cache_dir = Path(cache_dir)
        # parents=True so nested locations (e.g. "runs/exp1/data") work too.
        self.data_root.mkdir(parents=True, exist_ok=True)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        # Logging must be configured first: load_dataset_configs() logs.
        self.setup_logging()
        self.load_dataset_configs()

    def setup_logging(self):
        """Configure root logging (file + console) and create ``self.logger``.

        The log file ``data_setup.log`` is created in the current working
        directory.
        """
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            handlers=[
                logging.FileHandler('data_setup.log'),
                logging.StreamHandler()
            ]
        )
        self.logger = logging.getLogger(__name__)

    def load_dataset_configs(self):
        """Populate ``self.config`` from ``config/datasets.yaml`` or defaults.

        The YAML file is looked up next to this source file. When it is
        missing, or parses to nothing (e.g. an empty file), the built-in
        defaults from :meth:`get_default_config` are used instead.
        """
        config_path = Path(__file__).parent / "config" / "datasets.yaml"
        loaded = None
        if config_path.exists():
            with open(config_path, 'r', encoding="utf-8") as f:
                loaded = yaml.safe_load(f)

        if loaded:
            self.config = loaded
        else:
            # Covers both "file missing" and "file present but empty".
            self.logger.warning("Dataset config not found, using defaults")
            self.config = self.get_default_config()

    def get_default_config(self) -> Dict:
        """Return the built-in dataset configuration.

        Returns:
            A nested dict keyed by ``"datasets"`` that maps each TOFU
            scenario and WMDP domain to a ``repo_id`` / ``filename`` pair.
        """
        return {
            "datasets": {
                "tofu": {
                    "base_repo": "locuslab/TOFU",
                    "scenarios": {
                        "forget01": {"repo_id": "locuslab/TOFU", "filename": "forget01.json"},
                        "forget05": {"repo_id": "locuslab/TOFU", "filename": "forget05.json"},
                        "forget10": {"repo_id": "locuslab/TOFU", "filename": "forget10.json"}
                    }
                },
                "wmdp": {
                    "domains": {
                        "bio": {"repo_id": "cais/wmdp", "filename": "wmdp-bio.jsonl"},
                        "cyber": {"repo_id": "cais/wmdp", "filename": "wmdp-cyber.jsonl"},
                        "chem": {"repo_id": "cais/wmdp", "filename": "wmdp-chem.jsonl"}
                    }
                }
            }
        }

    def download_tofu_dataset(self, scenario: Optional[str] = None) -> bool:
        """Download TOFU and materialise its forget/retain scenario files.

        Args:
            scenario: A single scenario name (e.g. ``"forget05"``) to build,
                or ``None`` to build every scenario in ``TOFU_FORGET_RATIOS``.

        Returns:
            ``True`` when the main dataset was fetched and the scenario files
            exist; ``False`` when anything in that path failed. A failure to
            fetch the optional evaluation split only logs a warning.
        """
        self.logger.info("Downloading TOFU dataset...")

        tofu_dir = self.data_root / "tofu"
        tofu_dir.mkdir(parents=True, exist_ok=True)

        try:
            # Download main TOFU dataset
            dataset = load_dataset(
                "locuslab/TOFU", split="train", cache_dir=str(self.cache_dir)
            )

            # Save full dataset
            dataset.to_json(tofu_dir / "full_dataset.json")

            if scenario is None:
                scenarios = list(self.TOFU_FORGET_RATIOS)
            else:
                scenarios = [scenario]

            # Existing scenario files are kept as-is, never regenerated.
            for scenario_name in scenarios:
                scenario_file = tofu_dir / f"{scenario_name}.json"
                if not scenario_file.exists():
                    self.create_tofu_scenario(dataset, scenario_name, scenario_file)

            # Evaluation data is best-effort: its absence is not a failure.
            try:
                eval_dataset = load_dataset(
                    "locuslab/TOFU", "evaluation", split="test",
                    cache_dir=str(self.cache_dir),
                )
                eval_dataset.to_json(tofu_dir / "eval_data.json")
            except Exception as e:
                self.logger.warning(f"Could not download TOFU evaluation data: {e}")

        except Exception as e:
            self.logger.error(f"Failed to download TOFU dataset: {e}")
            # Explicit False: callers combine the results as booleans.
            return False

        self.logger.info("TOFU dataset downloaded successfully")
        return True

    def create_tofu_scenario(self, dataset, scenario: str, output_file: Path):
        """Split ``dataset`` into forget/retain sets and write them as JSON.

        The first ``int(len(dataset) * ratio)`` samples form the forget set
        and the remainder the retain set.

        Args:
            dataset: Indexable collection supporting ``len()`` and integer
                indexing, whose items are JSON-serialisable.
            scenario: Scenario name used to look up the forget ratio.
            output_file: Destination path of the JSON document.
        """
        forget_ratio = self.TOFU_FORGET_RATIOS.get(scenario)
        if forget_ratio is None:
            # Keep the historical fallback, but make it visible in the log.
            forget_ratio = self.DEFAULT_FORGET_RATIO
            self.logger.warning(
                f"Unknown TOFU scenario '{scenario}', "
                f"using default forget ratio {forget_ratio}"
            )

        total_samples = len(dataset)
        forget_size = int(total_samples * forget_ratio)

        # Rows are fetched one index at a time so each entry is one sample.
        forget_set = [dataset[i] for i in range(forget_size)]
        retain_set = [dataset[i] for i in range(forget_size, total_samples)]

        scenario_data = {
            "forget_set": forget_set,
            "retain_set": retain_set,
            "metadata": {
                "total_samples": total_samples,
                "forget_ratio": forget_ratio,
                "forget_size": forget_size,
                "retain_size": len(retain_set)
            }
        }

        with open(output_file, 'w', encoding="utf-8") as f:
            json.dump(scenario_data, f, indent=2)

        self.logger.info(f"Created {scenario} scenario with {forget_size} forget samples")

    def download_wmdp_dataset(self, domain: Optional[str] = None) -> bool:
        """Download WMDP test questions and, when available, forget corpora.

        Args:
            domain: A single domain (``"bio"``, ``"cyber"`` or ``"chem"``),
                or ``None`` to fetch every domain in ``WMDP_DOMAINS``.

        Returns:
            ``True`` only if the test split of every requested domain was
            saved. A missing forget corpus logs a warning and does not
            affect the result.
        """
        self.logger.info("Downloading WMDP dataset...")

        wmdp_dir = self.data_root / "wmdp"
        wmdp_dir.mkdir(parents=True, exist_ok=True)

        domains = list(self.WMDP_DOMAINS) if domain is None else [domain]

        success = True
        for domain_name in domains:
            try:
                # Download main WMDP data
                dataset = load_dataset(
                    "cais/wmdp", f"wmdp-{domain_name}", split="test",
                    cache_dir=str(self.cache_dir),
                )
                dataset.to_json(wmdp_dir / f"wmdp_{domain_name}_test.json")

                # Corpora for unlearning are optional (best-effort).
                try:
                    corpora = load_dataset(
                        "cais/wmdp-corpora", f"{domain_name}-forget-corpus",
                        cache_dir=str(self.cache_dir),
                    )
                    self._export_forget_corpus(corpora, wmdp_dir, domain_name)
                except Exception as e:
                    self.logger.warning(f"Could not download {domain_name} corpora: {e}")

                self.logger.info(f"Downloaded WMDP {domain_name} dataset")

            except Exception as e:
                self.logger.error(f"Failed to download WMDP {domain_name}: {e}")
                success = False

        return success

    def _export_forget_corpus(self, corpora, wmdp_dir: Path, domain_name: str) -> None:
        """Write a loaded forget corpus to ``wmdp_dir`` as JSON.

        Because the corpus is loaded without a ``split`` argument, the loader
        may hand back a mapping of split name -> dataset instead of a single
        dataset; both shapes are handled here.
        """
        base_name = f"{domain_name}_forget_corpus"
        if hasattr(corpora, "to_json"):
            corpora.to_json(wmdp_dir / f"{base_name}.json")
            return

        splits = dict(corpora.items())
        for split_name, split_data in splits.items():
            # A lone split keeps the plain file name; several get a suffix.
            suffix = "" if len(splits) == 1 else f"_{split_name}"
            split_data.to_json(wmdp_dir / f"{base_name}{suffix}.json")

    def download_cifar_datasets(self) -> bool:
        """Write JSON description files for CIFAR-10 and CIFAR-100.

        No image data is fetched here; this only records static metadata
        (class count, sample counts, image size) under ``<data_root>/cifar``.

        Returns:
            ``True`` when both info files were written, ``False`` on an
            I/O error.
        """
        self.logger.info("Setting up CIFAR datasets...")

        cifar_dir = self.data_root / "cifar"
        cifar_dir.mkdir(parents=True, exist_ok=True)

        cifar10_info = {
            "name": "CIFAR-10",
            "source": "torchvision.datasets.CIFAR10",
            "num_classes": 10,
            "train_samples": 50000,
            "test_samples": 10000,
            "image_size": [32, 32, 3],
            "classes": ["airplane", "automobile", "bird", "cat", "deer",
                        "dog", "frog", "horse", "ship", "truck"]
        }

        cifar100_info = {
            "name": "CIFAR-100",
            "source": "torchvision.datasets.CIFAR100",
            "num_classes": 100,
            "train_samples": 50000,
            "test_samples": 10000,
            "image_size": [32, 32, 3]
        }

        try:
            with open(cifar_dir / "cifar10_info.json", 'w', encoding="utf-8") as f:
                json.dump(cifar10_info, f, indent=2)

            with open(cifar_dir / "cifar100_info.json", 'w', encoding="utf-8") as f:
                json.dump(cifar100_info, f, indent=2)
        except OSError as e:
            self.logger.error(f"Failed to setup CIFAR datasets: {e}")
            return False

        self.logger.info("CIFAR dataset info created")
        return True

    def verify_datasets(self) -> Dict[str, bool]:
        """Check which expected dataset files exist under ``data_root``.

        Returns:
            Mapping of dataset name to a presence flag. TOFU requires *all*
            of its files, WMDP is satisfied by *any* one domain file, and
            the remaining entries each check a single file.
        """
        verification_results = {}

        # Check TOFU (every scenario file plus the eval data must exist)
        tofu_dir = self.data_root / "tofu"
        tofu_files = ["forget01.json", "forget05.json", "forget10.json", "eval_data.json"]
        verification_results["tofu"] = all((tofu_dir / f).exists() for f in tofu_files)

        # Check WMDP (one downloaded domain is enough)
        wmdp_dir = self.data_root / "wmdp"
        wmdp_files = ["wmdp_bio_test.json", "wmdp_cyber_test.json", "wmdp_chem_test.json"]
        verification_results["wmdp"] = any((wmdp_dir / f).exists() for f in wmdp_files)

        # Check MUSE
        muse_dir = self.data_root / "muse"
        verification_results["muse"] = (muse_dir / "muse_eval.json").exists()

        # Check IDK
        verification_results["idk"] = (self.data_root / "idk.jsonl").exists()

        # Check CIFAR
        cifar_dir = self.data_root / "cifar"
        verification_results["cifar"] = (cifar_dir / "cifar10_info.json").exists()

        return verification_results

    def print_summary(self):
        """Print a per-dataset status table based on :meth:`verify_datasets`."""
        verification = self.verify_datasets()

        print("\n" + "="*70)
        print("DATASET SETUP SUMMARY")
        print("="*70)

        for dataset, status in verification.items():
            status_icon = "✅" if status else "❌"
            print(f"{dataset.upper():<10} {status_icon}")

        print(f"\nData Directory: {self.data_root.absolute()}")
        print(f"Cache Directory: {self.cache_dir.absolute()}")

        if all(verification.values()):
            print("\n🎉 All datasets ready for experiments!")
        else:
            missing = [k for k, v in verification.items() if not v]
            print(f"\n⚠️  Missing datasets: {', '.join(missing)}")

        print("="*70)


def main() -> int:
    """Parse command-line options, run the requested downloads, print a summary.

    Returns:
        Process exit code: ``0`` when every requested step succeeded (or when
        only verification was requested), ``1`` otherwise.
    """
    parser = argparse.ArgumentParser(description="Download and setup datasets for OFMU experiments")

    # Dataset selection
    parser.add_argument("--dataset", choices=["tofu", "wmdp", "muse", "idk", "cifar", "all"],
                       default="all", help="Which dataset to download")
    parser.add_argument("--scenario", type=str, help="TOFU scenario (forget01, forget05, forget10)")
    parser.add_argument("--domain", type=str, help="WMDP domain (bio, cyber, chem)")

    # Paths
    parser.add_argument("--data_root", type=str, default="./data",
                       help="Root directory for datasets")
    parser.add_argument("--cache_dir", type=str, default="./cache",
                       help="Cache directory for downloads")

    # Options
    parser.add_argument("--verify_only", action="store_true",
                       help="Only verify existing datasets")
    parser.add_argument("--force_download", action="store_true",
                       help="Force re-download even if datasets exist")

    args = parser.parse_args()

    # Initialize downloader
    downloader = DatasetDownloader(args.data_root, args.cache_dir)

    if args.verify_only:
        downloader.print_summary()
        return 0

    if args.force_download:
        # The flag is accepted for compatibility but nothing consumes it yet.
        print("⚠️  --force_download is currently not supported and will be ignored")

    # One zero-argument callable per download step of each --dataset choice.
    # "all" deliberately ignores --scenario / --domain and fetches everything.
    steps_by_dataset = {
        "all": [
            downloader.download_tofu_dataset,
            downloader.download_wmdp_dataset,
            downloader.download_cifar_datasets,
        ],
        "tofu": [lambda: downloader.download_tofu_dataset(args.scenario)],
        "wmdp": [lambda: downloader.download_wmdp_dataset(args.domain)],
        "cifar": [downloader.download_cifar_datasets],
    }

    success = True
    steps = steps_by_dataset.get(args.dataset)
    if steps is None:
        # e.g. "muse" / "idk": selectable, but no download step exists here.
        print(f"❌ No downloader is available for dataset '{args.dataset}'")
        steps = []
        success = False

    # Download datasets
    for step in steps:
        try:
            # Truthiness test (instead of `&=`) tolerates a step returning
            # None; every step still runs even after an earlier failure.
            if not step():
                success = False
        except KeyboardInterrupt:
            print("\n⏹️  Dataset download interrupted by user")
            return 1
        except Exception as e:
            # A crash in one step must not skip the remaining ones.
            print(f"❌ Dataset download failed: {e}")
            success = False

    # Print summary
    downloader.print_summary()

    return 0 if success else 1


# Script entry point: propagate main()'s return value as the process exit code.
# sys.exit is used instead of the bare `exit` helper, which is injected by the
# `site` module for interactive use and is not guaranteed to be available.
if __name__ == "__main__":
    sys.exit(main())

