#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Download and extract the PUB dataset from Hugging Face.
"""

import os
import sys
import requests
import zipfile
from pathlib import Path
from tqdm import tqdm
import logging
import concurrent.futures

# Root logging setup: INFO and above, timestamped, written to stderr
# (the stream StreamHandler uses when none is given).
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stderr)],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Logger used by every function in this script.
logger = logging.getLogger("PUB-Downloader")

# Hugging Face location of the PUB archives; a file name is appended to it
# to form each download URL.
BASE_URL = "https://huggingface.co/datasets/cfilt/PUB/resolve/main/data/"

# Archive names to fetch: task_1.zip through task_14.zip.
TASK_FILES = [f"task_{number}.zip" for number in range(1, 15)]

# Local layout: archives are stored in <root>/data/raw/pub, where <root> is
# two directory levels above this file, and are extracted in the same place.
ROOT_DIR = Path(__file__).parent.parent
RAW_DATA_DIR = ROOT_DIR.joinpath("data", "raw", "pub")
EXTRACT_DIR = RAW_DATA_DIR


def download_file(url, output_path, timeout=(10, 60)):
    """Download a file from a URL with progress tracking.

    The body is streamed into a sibling ``<name>.part`` file and renamed onto
    ``output_path`` only after the transfer completes, so a failed or
    interrupted download never leaves a truncated file at ``output_path``.

    Args:
        url: Address of the file to fetch.
        output_path: Destination path (``str`` or ``Path``).
        timeout: Forwarded to ``requests.get``; by default a
            ``(connect, read)`` pair in seconds so a stalled connection
            cannot block forever.

    Returns:
        True if the file was downloaded and moved into place, False otherwise
        (the error is logged and the partial file is removed).
    """
    output_path = Path(output_path)
    partial_path = output_path.with_name(output_path.name + ".part")
    block_size = 8192

    try:
        # The context manager releases the connection even if streaming fails.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()

            expected_size = int(response.headers.get('content-length', 0))
            # Content-Length counts bytes on the wire, while iter_content
            # yields decoded bytes; the two are only comparable when the
            # server applied no Content-Encoding.
            size_is_checkable = (
                expected_size > 0 and 'content-encoding' not in response.headers
            )

            written = 0
            with open(partial_path, 'wb') as f:
                # total=None (unknown size) makes tqdm show a plain counter
                # instead of a bar stuck at 0 bytes.
                with tqdm(total=expected_size or None, unit='B', unit_scale=True,
                          desc=output_path.name) as pbar:
                    for chunk in response.iter_content(chunk_size=block_size):
                        if chunk:
                            f.write(chunk)
                            written += len(chunk)
                            pbar.update(len(chunk))

            if size_is_checkable and written != expected_size:
                raise OSError(
                    f"incomplete download: got {written} of {expected_size} bytes"
                )

        # Publish the finished file in a single rename step.
        partial_path.replace(output_path)
        return True
    except Exception as e:
        # Broad on purpose: callers rely on a boolean result, never an exception.
        logger.error(f"Error downloading {url}: {e}")
        try:
            partial_path.unlink()
        except OSError:
            # Best-effort cleanup; the partial file may never have been created.
            pass
        return False


def extract_zip(zip_path, extract_dir):
    """Unpack every member of the archive at ``zip_path`` into ``extract_dir``.

    Returns:
        True when extraction completes, False if any error occurs (the error
        is logged rather than raised).
    """
    try:
        with zipfile.ZipFile(zip_path) as archive:
            archive.extractall(path=extract_dir)
    except Exception as err:
        logger.error(f"Error extracting {zip_path}: {err}")
        return False
    else:
        return True


def download_and_extract_task(task_file):
    """Fetch one task archive (unless already on disk) and unpack it.

    Args:
        task_file: Archive file name; it is appended to ``BASE_URL`` for the
            download and joined onto ``RAW_DATA_DIR`` for the local copy.

    Returns:
        True if the archive is present (downloaded now or earlier) and was
        extracted into ``EXTRACT_DIR``; False if either step failed.
    """
    source_url = BASE_URL + task_file
    archive_path = RAW_DATA_DIR / task_file

    # An archive already on disk is reused as-is; only missing ones are fetched.
    if archive_path.exists():
        logger.info(f"File {task_file} already exists, skipping download.")
    else:
        logger.info(f"Downloading {task_file}...")
        fetched = download_file(source_url, archive_path)
        if not fetched:
            return False

    # Unpack the archive.
    logger.info(f"Extracting {task_file}...")
    if extract_zip(archive_path, EXTRACT_DIR):
        logger.info(f"Successfully processed {task_file}")
        return True
    return False


def main():
    """Download and extract every archive in ``TASK_FILES``.

    Returns:
        0 when all archives were processed successfully, 1 otherwise
        (intended as the process exit status).
    """
    # Make sure the target directory exists before any worker writes to it.
    RAW_DATA_DIR.mkdir(parents=True, exist_ok=True)

    total = len(TASK_FILES)
    logger.info(f"Downloading and extracting {total} PUB task files...")

    # Run up to four archives at a time; map() yields one bool per archive.
    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as pool:
        outcomes = list(pool.map(download_and_extract_task, TASK_FILES))

    succeeded = sum(outcomes)
    logger.info(f"Downloaded and extracted {succeeded}/{total} task files.")

    if succeeded == total:
        return 0

    logger.warning("Some files failed to download or extract.")
    return 1


# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())