import os
import requests
import tarfile
from tqdm import tqdm
import logging

# Root-logger setup for this script: INFO and above, each record prefixed
# with a timestamp and level. Note that this runs at import time.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Ten WordNet synset IDs taken from the ImageNet-100 class list; these are
# the classes fetched when the script is run directly.
IMAGENET_10_CLASSES = [
    "n02869837",
    "n01749939",
    "n02488291",
    "n02107142",
    "n13037406",
    "n02091831",
    "n04517823",
    "n04589890",
    "n03062245",
    "n01773797",
]

# Per-synset archive URL; the ``{}`` placeholder receives a WordNet ID.
IMAGENET_URL_TEMPLATE = "https://image-net.org/data/winter21_whole/{}.tar"

# Destination for extracted images (one sub-directory per class ID).
DATASET_DIR = "dataset/imagenet-10-subset"
# Scratch location for downloaded .tar archives before extraction.
DOWNLOAD_DIR = "downloads"

def _download_file(url, dest_path, desc, timeout):
    """Stream ``url`` to ``dest_path`` while showing a byte-level progress bar.

    Raises ``requests.exceptions.RequestException`` on network errors, on a
    timeout, or on a non-2xx status. The response is used as a context
    manager, so its connection is released even if writing fails part-way.
    """
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()  # Raise an exception for bad status codes
        # content-length may be missing; total=None gives tqdm an open-ended counter
        # instead of a bogus 0-sized bar.
        total_size = int(response.headers.get('content-length', 0)) or None
        with open(dest_path, 'wb') as f, tqdm(
            total=total_size, unit='iB', unit_scale=True, desc=desc
        ) as pbar:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
                pbar.update(len(chunk))


def _is_within_directory(directory, target):
    """Return True if ``target`` resolves to a location inside ``directory``."""
    root = os.path.realpath(directory)
    resolved = os.path.realpath(target)
    try:
        return os.path.commonpath([root, resolved]) == root
    except ValueError:
        # commonpath raises for paths on different drives (Windows): treat as outside.
        return False


def _extract_first_files(tar_path, dest_dir, limit, desc):
    """Extract up to ``limit`` regular files from ``tar_path`` into ``dest_dir``.

    Members are iterated lazily, so the archive is read only as far as needed
    to collect ``limit`` files. Non-regular members (directories, links,
    devices) are ignored and do not count toward ``limit``. Members whose path
    would escape ``dest_dir`` (absolute names, ``..`` components) are skipped
    with a warning, because the archive comes from the network.

    Returns the number of files actually extracted (may be less than
    ``limit`` if the archive is smaller). Raises ``tarfile.TarError`` if the
    archive cannot be read.
    """
    if limit <= 0:
        return 0
    # Where the interpreter supports extraction filters (3.12+, and security
    # backports), the 'data' filter adds a second line of defense and silences
    # the 3.12+ DeprecationWarning for filter-less extraction.
    extract_kwargs = {'filter': 'data'} if hasattr(tarfile, 'data_filter') else {}
    extracted = 0
    with tarfile.open(tar_path, 'r') as tar, tqdm(total=limit, desc=desc) as pbar:
        for member in tar:
            if not member.isfile():
                continue
            if not _is_within_directory(dest_dir, os.path.join(dest_dir, member.name)):
                logging.warning(f"Skipping unsafe archive member {member.name!r} in {tar_path}")
                continue
            tar.extract(member, path=dest_dir, **extract_kwargs)
            extracted += 1
            pbar.update(1)
            if extracted >= limit:
                break
    return extracted


def download_and_extract_images(class_ids, num_images_per_class=100, *, timeout=60):
    """
    Downloads and extracts a subset of images for the specified ImageNet classes.

    For each WordNet ID in ``class_ids``:

    1. The synset archive at ``IMAGENET_URL_TEMPLATE`` is streamed into
       ``DOWNLOAD_DIR``.
    2. Up to ``num_images_per_class`` regular files from the archive are
       extracted into ``DATASET_DIR/<class_id>``.
    3. The downloaded archive is always deleted afterwards.

    A class is skipped when its directory already holds at least
    ``num_images_per_class`` entries. Errors for one class are logged and do
    not stop the remaining classes.

    Args:
        class_ids: Iterable of WordNet synset IDs (e.g. ``"n01749939"``).
        num_images_per_class: Maximum number of files to extract per class.
        timeout: Connect/read timeout in seconds for each HTTP request, so a
            stalled server cannot hang the script indefinitely.
    """
    os.makedirs(DATASET_DIR, exist_ok=True)
    os.makedirs(DOWNLOAD_DIR, exist_ok=True)

    for class_id in class_ids:
        class_dir = os.path.join(DATASET_DIR, class_id)
        if os.path.exists(class_dir):
            existing = len(os.listdir(class_dir))
            if existing >= num_images_per_class:
                logging.info(f"Class {class_id} already has {existing} images. Skipping.")
                continue

        os.makedirs(class_dir, exist_ok=True)
        tar_url = IMAGENET_URL_TEMPLATE.format(class_id)
        tar_path = os.path.join(DOWNLOAD_DIR, f"{class_id}.tar")

        try:
            logging.info(f"Downloading {tar_url}...")
            _download_file(tar_url, tar_path, f"Downloading {class_id}", timeout)

            logging.info(f"Extracting images for {class_id}...")
            extracted = _extract_first_files(
                tar_path, class_dir, num_images_per_class, f"Extracting {class_id}"
            )

            if extracted < num_images_per_class:
                logging.warning(
                    f"Archive for {class_id} yielded only {extracted} files "
                    f"(requested {num_images_per_class})."
                )
            # Report the real count, not the requested one.
            logging.info(f"Successfully downloaded and extracted {extracted} images for {class_id}.")

        except requests.exceptions.RequestException as e:
            logging.error(f"Failed to download {tar_url}: {e}")
        except tarfile.TarError as e:
            logging.error(f"Failed to extract {tar_path}: {e}")
        except Exception as e:
            # Per-class boundary: keep going with other classes, but keep the traceback.
            logging.exception(f"An unexpected error occurred for class {class_id}: {e}")
        finally:
            # Clean up the downloaded tar file (also removes partial downloads).
            if os.path.exists(tar_path):
                os.remove(tar_path)

if __name__ == "__main__":
    # Script entry point: fetch every configured class using the default
    # per-class image cap, bracketed by start/finish log lines. The finish
    # line is logged unconditionally; per-class failures are only visible in
    # the log output emitted during the run.
    logging.info("Starting image download process...")
    selected_classes = IMAGENET_10_CLASSES
    download_and_extract_images(selected_classes)
    logging.info("Image download process completed.")