#!/usr/bin/env python3
"""
Script for downloading Parquet files from HuggingFace via direct links
FIXED - supports 302 redirects
"""

import os
import argparse
import requests
from pathlib import Path
from tqdm import tqdm
import json
from urllib.parse import urlparse
import time
import sys

sys.path.append(str(Path(__file__).parent.parent))
from settings.subsets import SUBSETS

def generate_download_url(subset_name, split):
    """
    Build the direct-download link for one split of one subset.

    The link points at the single-shard parquet file
    ``<split>-00000-of-00001.parquet`` inside the subset's folder of the
    ``pratyushmaini/llm_dataset_inference`` dataset repository.

    Args:
        subset_name (str): Name of the subset (folder inside the repository)
        split (str): Split name (train/val/test)

    Returns:
        str: Download URL, including the ``?download=true`` query string
    """
    repo_root = "https://huggingface.co/datasets/pratyushmaini/llm_dataset_inference/resolve/main"
    # Path of the parquet shard relative to the repository root.
    relative_path = f"{subset_name}/{split}-00000-of-00001.parquet"
    return f"{repo_root}/{relative_path}?download=true"

def _content_length(headers):
    """Return the Content-Length header as a non-negative int (0 if absent or malformed)."""
    try:
        return max(int(headers.get('content-length', 0)), 0)
    except (TypeError, ValueError):
        return 0


def download_file(url, output_path, chunk_size=8192, timeout=60):
    """
    Download file from given URL, following HuggingFace redirects.

    The body is streamed into ``<output_path>.part`` and only moved onto
    ``output_path`` once the transfer completed, so a failed or interrupted
    download never leaves a truncated file at the destination (and never
    destroys a file that was already there).

    Args:
        url (str): Download URL
        output_path (Path | str): Destination path
        chunk_size (int): Chunk size in bytes
        timeout (int): Timeout in seconds (passed to each HTTP request)

    Returns:
        bool: True if the file was downloaded completely, False otherwise.
        A file that lacks the parquet magic bytes only triggers a warning;
        it is still reported as a successful download.
    """
    output_path = Path(output_path)
    partial_path = output_path.with_name(output_path.name + ".part")
    progress_step = 1024 * 1024  # report every MiB when the size is unknown

    try:
        # The session (and the streamed response below) are context-managed so
        # pooled connections are released even when an error occurs.
        with requests.Session() as session:
            session.headers.update({
                'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
            })
            print("    🔍 Checking file availability...")
            head_response = session.head(url, timeout=timeout, allow_redirects=True)
            if head_response.status_code != 200:
                print(f"    ✗ File not available (HTTP {head_response.status_code})")
                return False
            file_size = _content_length(head_response.headers)
            if file_size == 0:
                print("    ⚠️  Could not determine file size")
            print("    📥 Starting download...")
            with session.get(url, stream=True, timeout=timeout, allow_redirects=True) as response:
                response.raise_for_status()
                # Prefer the length announced by the GET itself; fall back to HEAD.
                expected_size = _content_length(response.headers) or file_size
                # With a Content-Encoding, requests decodes the body on the fly,
                # so the bytes written need not match Content-Length.
                encoding = (response.headers.get('content-encoding') or 'identity').strip().lower()
                content_encoded = encoding not in ('', 'identity')

                output_path.parent.mkdir(parents=True, exist_ok=True)
                with open(partial_path, 'wb') as f:
                    if expected_size > 0:
                        with tqdm(
                            total=expected_size,
                            unit='B',
                            unit_scale=True,
                            unit_divisor=1024,
                            desc=f"    {output_path.name}",
                            leave=False
                        ) as pbar:
                            for chunk in response.iter_content(chunk_size=chunk_size):
                                if chunk:
                                    f.write(chunk)
                                    pbar.update(len(chunk))
                    else:
                        total_size = 0
                        next_report = progress_step
                        for chunk in response.iter_content(chunk_size=chunk_size):
                            if chunk:
                                f.write(chunk)
                                total_size += len(chunk)
                                # Threshold-based: chunks may be shorter than
                                # chunk_size, so an exact-multiple test would
                                # almost never fire.
                                if total_size >= next_report:
                                    print(f"    📊 Downloaded: {total_size / (1024*1024):.1f} MB", end='\r')
                                    while next_report <= total_size:
                                        next_report += progress_step

        actual_size = partial_path.stat().st_size
        if expected_size > 0 and not content_encoded and actual_size != expected_size:
            print(f"    ✗ Incomplete download: expected {expected_size} bytes, got {actual_size}")
            return False

        # Publish the finished file atomically (same directory, same filesystem).
        partial_path.replace(output_path)

        size_mb = actual_size / (1024 * 1024)
        print(f"    ✅ Downloaded: {output_path.name} ({size_mb:.1f} MB)")
        if actual_size > 0:
            with open(output_path, 'rb') as f:
                magic_bytes = f.read(4)
                if magic_bytes != b'PAR1':
                    print(f"    ⚠️  Warning: File may not be a valid parquet (magic bytes: {magic_bytes})")
        return True
    except requests.exceptions.RequestException as e:
        print(f"    ✗ Download error: {e}")
        return False
    except Exception as e:
        # Best-effort downloader: report anything else (disk errors, ...) as a
        # failed download instead of aborting the whole batch.
        print(f"    ✗ Unexpected error: {e}")
        return False
    finally:
        # Remove the temporary file if it was not promoted to output_path.
        try:
            if partial_path.exists():
                partial_path.unlink()
        except OSError:
            pass

def download_subset(subset_name, output_dir, splits=("train", "val"), delay=2.0):
    """
    Download all files for a given subset

    Files that already exist and are larger than 0.1 MB are counted as
    successful and skipped; smaller existing files are treated as invalid,
    removed and downloaded again.

    Args:
        subset_name (str): Subset name
        output_dir (Path | str): Destination folder
        splits (iterable of str): Splits to download (any iterable is accepted)
        delay (float): Delay in seconds after each download attempt
            (no delay is applied for files that are skipped)

    Returns:
        dict: Download statistics with the keys ``total``, ``success``,
        ``failed`` and ``files`` (per-split status and, when known, size in MB)
    """
    print(f"📁 Downloading subset: {subset_name}")

    # Materialize once: the default is an immutable tuple, and callers may pass
    # a generator, which has no len() and can only be iterated a single time.
    splits = list(splits)
    subset_dir = Path(output_dir) / subset_name
    stats = {"total": len(splits), "success": 0, "failed": 0, "files": {}}

    for split in splits:
        url = generate_download_url(subset_name, split)
        filename = f"{split}-00000-of-00001.parquet"
        output_path = subset_dir / filename
        print(f"  📄 {split}.parquet")
        print(f"    🔗 URL: {url}")
        if output_path.exists():
            size_mb = output_path.stat().st_size / (1024 * 1024)
            if size_mb > 0.1:
                print(f"    ⏩ File already exists ({size_mb:.1f} MB)")
                stats["success"] += 1
                stats["files"][split] = {"status": "exists", "size_mb": size_mb}
                continue
            else:
                print(f"    🗑️  Removing invalid file ({size_mb:.1f} MB)")
                output_path.unlink()
        success = download_file(url, output_path)
        if success:
            stats["success"] += 1
            size_mb = output_path.stat().st_size / (1024 * 1024)
            stats["files"][split] = {"status": "downloaded", "size_mb": size_mb}
        else:
            stats["failed"] += 1
            stats["files"][split] = {"status": "failed"}
            # Do not keep a partial file that a later run would mistake for a
            # finished download.
            if output_path.exists():
                output_path.unlink()
        # The pause also runs after the last split so that consecutive subsets
        # stay rate-limited.
        if delay > 0:
            print(f"    ⏸️  Waiting {delay}s...")
            time.sleep(delay)
    return stats

def test_single_download():
    """
    Smoke test: download the ``train`` split of the ``arxiv`` subset.

    The file is written to ``test_arxiv_train.parquet`` in the current working
    directory and the outcome is printed; nothing is returned.
    """
    print("🧪 Test single file download...")
    link = generate_download_url("arxiv", "train")
    target = Path("test_arxiv_train.parquet")

    print(f"URL: {link}")
    if not download_file(link, target):
        print("❌ Test failed!")
        return

    print("✅ Test successful!")
    if target.exists():
        print(f"File size: {target.stat().st_size / (1024*1024):.1f} MB")

def generate_download_links_file(output_dir, subsets, splits=("train", "val")):
    """
    Generates a file with all download links

    Writes ``download_links.txt`` into ``output_dir`` (created if missing):
    one section per subset with one ``split: url`` line per split, followed by
    example wget/curl commands.

    Args:
        output_dir (Path | str): Folder that receives ``download_links.txt``
        subsets (iterable of str): Subset names to list
        splits (iterable of str): Split names to list for every subset
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    # splits is walked once per subset, so a one-shot iterable must be
    # materialized or every subset after the first would get no links.
    splits = list(splits)
    links_file = output_dir / "download_links.txt"

    with open(links_file, 'w', encoding='utf-8') as f:
        f.write("# Download links for llm_dataset_inference dataset\n")
        f.write("# Generated automatically\n")
        f.write("# These links include 302 redirects - use proper tools that follow redirects\n\n")

        for subset in subsets:
            f.write(f"## {subset}\n")
            for split in splits:
                url = generate_download_url(subset, split)
                f.write(f"{split}: {url}\n")
            f.write("\n")

        # wget follows HTTP redirects on its own; its -L flag means --relative
        # (restrict recursion to relative links), so it is not used here.
        f.write("# wget example (follows redirects by default):\n")
        f.write("# wget --content-disposition 'URL_HERE'\n\n")
        f.write("# curl example (follows redirects):\n")
        f.write("# curl -L -o filename.parquet 'URL_HERE'\n")

    print(f"📄 Download links saved in: {links_file}")

def save_download_stats(output_dir, all_stats):
    """
    Saves download statistics

    Aggregates the per-subset counters and writes them, together with the
    per-subset details, to ``download_stats.json`` inside ``output_dir``.

    Args:
        output_dir (Path): Folder that receives ``download_stats.json``
        all_stats (dict): Maps subset name to its statistics dict; each entry
            must provide the ``total``, ``success`` and ``failed`` counters
    """
    destination = output_dir / "download_stats.json"

    file_count = 0
    for entry in all_stats.values():
        file_count += entry["total"]
    ok_count = 0
    for entry in all_stats.values():
        ok_count += entry["success"]
    failed_count = 0
    for entry in all_stats.values():
        failed_count += entry["failed"]

    # Avoid dividing by zero when nothing was attempted.
    if file_count > 0:
        rate_text = f"{(ok_count/file_count)*100:.1f}%"
    else:
        rate_text = "0%"

    overview = {
        "total_subsets": len(all_stats),
        "total_files": file_count,
        "successful_downloads": ok_count,
        "failed_downloads": failed_count,
        "success_rate": rate_text,
    }
    report = {
        "summary": overview,
        "subsets": all_stats,
        "source": "https://huggingface.co/datasets/pratyushmaini/llm_dataset_inference",
        "note": "HuggingFace uses 302 redirects for downloads",
    }

    with open(destination, 'w', encoding='utf-8') as handle:
        json.dump(report, handle, indent=2, ensure_ascii=False)

    print(f"📊 Stats saved in: {destination}")

def main():
    """
    Command-line entry point.

    Depending on the flags this either runs the single-file smoke test, only
    writes the links file, or writes the links file and then downloads every
    requested subset, saving statistics and printing a summary at the end.
    A Ctrl-C during the downloads stops the loop; the subsets finished so far
    are still saved and summarized.
    """
    cli = argparse.ArgumentParser(
        description="Download Parquet files from HuggingFace via direct links (supports 302 redirects)"
    )
    cli.add_argument("output_dir", type=str, help="Destination folder to save files")
    cli.add_argument(
        "--subsets", nargs='+', default=None,
        help="List of specific subsets to download (default: all)",
    )
    cli.add_argument(
        "--splits", nargs='+', default=["train", "val"],
        help="List of splits to download (default: train val)",
    )
    cli.add_argument(
        "--delay", type=float, default=2.0,
        help="Delay between downloads in seconds (default: 2.0)",
    )
    cli.add_argument(
        "--test-single", action="store_true",
        help="Test downloading a single file only",
    )
    cli.add_argument(
        "--generate-links-only", action="store_true",
        help="Only generate links file, do not download",
    )
    options = cli.parse_args()

    # The smoke test does not touch the output folder at all.
    if options.test_single:
        test_single_download()
        return

    output_dir = Path(options.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Fall back to the project-wide subset list when none was given.
    subsets_to_download = options.subsets or SUBSETS

    print("🚀 HuggingFace Dataset Downloader (Fixed for 302 redirects)")
    print(f"📂 Output folder: {output_dir.absolute()}")
    print(f"📋 Subsets: {', '.join(subsets_to_download)}")
    print(f"🔄 Splits: {', '.join(options.splits)}")
    print(f"⏱️  Delay: {options.delay}s")
    print("=" * 60)

    # The links file is written in every mode except the smoke test.
    generate_download_links_file(output_dir, subsets_to_download, options.splits)

    if options.generate_links_only:
        print("✅ Links file generated. Download skipped.")
        return

    all_stats = {}
    try:
        for position, subset_name in enumerate(subsets_to_download, 1):
            print(f"\n[{position}/{len(subsets_to_download)}] ", end="")
            all_stats[subset_name] = download_subset(
                subset_name, output_dir, options.splits, options.delay
            )
    except KeyboardInterrupt:
        print("\n\n⚠️  Download interrupted by user")

    # Nothing finished: there is nothing to save or summarize.
    if not all_stats:
        return

    save_download_stats(output_dir, all_stats)

    print("\n" + "=" * 60)
    print("📈 SUMMARY:")
    succeeded = sum(entry["success"] for entry in all_stats.values())
    failed = sum(entry["failed"] for entry in all_stats.values())
    attempted = succeeded + failed

    print(f"✅ Successfully downloaded: {succeeded} files")
    print(f"❌ Failed: {failed} files")

    if attempted > 0:
        success_rate = (succeeded / attempted) * 100
        print(f"📊 Success rate: {success_rate:.1f}%")

    print(f"📁 Files saved in: {output_dir.absolute()}")

# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
