#!/usr/bin/env python3
"""
Script to copy a subset of the INET train, test, and validation folders into a new destination
directory, reorganizing validation images from numeric labels (1–1000) into the same "n..." synset
folder structure used by train/, and only including classes listed in IN100.txt (which should list
synset names like "n01440764").
"""
import os
import shutil
import argparse


def copy_selected_dirs(src_dir: str, dst_dir: str, synsets: set) -> None:
    """
    Copy only the subdirectories named in synsets from src_dir to dst_dir.

    dst_dir is deleted and recreated first, so afterwards it contains exactly
    the selected classes that exist in src_dir. Classes missing from src_dir
    are reported and skipped. Classes are processed in sorted order so the
    log output is reproducible between runs.

    Raises:
        ValueError: if src_dir is dst_dir or lies inside it. Clearing dst_dir
            would then delete the source data before it is copied.
    """
    src_real = os.path.realpath(src_dir)
    dst_real = os.path.realpath(dst_dir)
    try:
        src_under_dst = os.path.commonpath([src_real, dst_real]) == dst_real
    except ValueError:
        # commonpath raises for paths on different drives (Windows); such
        # paths cannot be nested, so clearing dst_dir is safe.
        src_under_dst = False
    if src_under_dst:
        raise ValueError(
            f"Refusing to clear '{dst_dir}': it is or contains the source "
            f"directory '{src_dir}'."
        )

    if os.path.exists(dst_dir):
        shutil.rmtree(dst_dir)
    os.makedirs(dst_dir, exist_ok=True)

    for syn in sorted(synsets):
        src_path = os.path.join(src_dir, syn)
        dst_path = os.path.join(dst_dir, syn)
        if not os.path.isdir(src_path):
            print(f"Warning: class directory '{src_path}' not found; skipping.")
            continue
        shutil.copytree(src_path, dst_path)
        print(f"Copied class '{syn}' to '{dst_path}'")


def load_label_map(train_dir: str) -> dict:
    """
    Map 1-based numeric labels (as strings) to synset folder names.

    Every immediate subdirectory of ``train_dir`` whose name starts with 'n'
    is collected. The names are ordered alphabetically and numbered from 1,
    so the alphabetically first synset becomes '1', the next '2', and so on.

    NOTE(review): this assumes the numeric labels in the validation map follow
    alphabetical synset order over exactly these folders — confirm against
    the source of val_map.txt.
    """
    candidates = sorted(
        entry
        for entry in os.listdir(train_dir)
        if entry.startswith('n') and os.path.isdir(os.path.join(train_dir, entry))
    )
    return {str(label): name for label, name in enumerate(candidates, start=1)}


def load_in100(in100_file: str) -> set:
    """
    Read IN100.txt and return a set of synset names (e.g., 'n01440764').

    Each non-blank line contributes its first whitespace-separated token, so
    both plain ``n01440764`` lines and annotated ``n01440764 tench`` lines
    are accepted. Lines whose first token starts with ``#`` are treated as
    comments and ignored.

    Raises:
        FileNotFoundError: if ``in100_file`` does not exist.
        ValueError: if the file yields no synset names.
    """
    if not os.path.isfile(in100_file):
        raise FileNotFoundError(f"IN100 file '{in100_file}' not found.")
    syns = set()
    # 'utf-8-sig' drops a leading BOM, which would otherwise be glued onto the
    # first synset name and make it match no class directory.
    with open(in100_file, 'r', encoding='utf-8-sig') as f:
        for line in f:
            tokens = line.split()
            if not tokens or tokens[0].startswith('#'):
                continue
            syns.add(tokens[0])
    if not syns:
        raise ValueError(f"No synsets found in '{in100_file}'")
    return syns


def reorganize_val(src_val_dir: str,
                   val_map_file: str,
                   dst_val_dir: str,
                   label_map: dict) -> None:
    """
    Reorganize validation images using label_map (numeric->synset).

    - val_map.txt lines: '<image_name> <numeric_label>'. Lines with any other
      number of fields are skipped.
    - Copies only entries whose numeric_label exists in label_map into
      dst_val_dir/<synset>/<basename of image_name>.
    - Listed images that are missing from src_val_dir are skipped. The number
      of missing images is reported rather than dropped silently.

    Raises:
        FileNotFoundError: if src_val_dir or val_map_file does not exist.
    """
    if not os.path.isdir(src_val_dir):
        raise FileNotFoundError(f"Validation directory '{src_val_dir}' not found.")
    if not os.path.isfile(val_map_file):
        raise FileNotFoundError(f"Validation map '{val_map_file}' not found.")

    os.makedirs(dst_val_dir, exist_ok=True)

    copied = 0
    missing = 0
    with open(val_map_file, 'r', encoding='utf-8') as f:
        for line in f:
            parts = line.split()
            if len(parts) != 2:
                continue
            filename, num_label = parts
            synset = label_map.get(num_label)
            if not synset:
                # Label not selected (class not in IN100).
                continue
            src_path = os.path.join(src_val_dir, filename)
            if not os.path.isfile(src_path):
                missing += 1
                continue
            out_dir = os.path.join(dst_val_dir, synset)
            os.makedirs(out_dir, exist_ok=True)
            # Use only the basename: an entry such as 'sub/img.JPEG' would
            # otherwise target a subfolder of out_dir that was never created.
            shutil.copy2(src_path, os.path.join(out_dir, os.path.basename(filename)))
            copied += 1

    if missing:
        print(f"Warning: {missing} listed validation image(s) not found in "
              f"'{src_val_dir}'; skipped.")
    print(f"Reorganized {copied} validation images into '{dst_val_dir}' "
          f"for {len(label_map)} classes.")


def main():
    """
    Command-line entry point.

    Copies the IN100 subset of train/ (and of test/, when it is organized as
    synset folders) into the destination root. Then rebuilds val/ as
    per-synset folders using the numeric labels in the validation map.
    """
    parser = argparse.ArgumentParser(
        description="Copy IN100 subset of INET dataset and reorganize validation.")
    # Both roots are mandatory: without them os.path.join(None, ...) would
    # crash with an opaque TypeError instead of a usage message.
    parser.add_argument('--src_dir', required=True,
                        help="Source root (contains train/, test/, val/)")
    parser.add_argument('--dst_dir', required=True,
                        help="Destination root for new structure")
    parser.add_argument('--val_map', default='val_map.txt',
                        help="Validation map file, resolved relative to <src_dir>/val "
                             "unless given as an absolute path (default: val_map.txt)")
    parser.add_argument('--in100', default='IN100.txt',
                        help="File listing synset names to include (default: IN100.txt)")
    args = parser.parse_args()

    # Paths
    src_train = os.path.join(args.src_dir, 'train')
    src_test  = os.path.join(args.src_dir, 'test')
    src_val   = os.path.join(args.src_dir, 'val')
    dst_train = os.path.join(args.dst_dir, 'train')
    dst_test  = os.path.join(args.dst_dir, 'test')
    dst_val   = os.path.join(args.dst_dir, 'val')

    # Load synsets to keep (IN100)
    keep_synsets = load_in100(args.in100)
    print(f"Loaded {len(keep_synsets)} synsets from {args.in100}.")

    # Copy only selected train classes
    copy_selected_dirs(src_train, dst_train, keep_synsets)

    # Copy test/ the same way, but only when it is laid out as synset folders.
    # Otherwise we would create an empty tree and emit one warning per class.
    if os.path.isdir(src_test) and any(
            os.path.isdir(os.path.join(src_test, syn)) for syn in keep_synsets):
        copy_selected_dirs(src_test, dst_test, keep_synsets)
    else:
        print(f"Skipping test split: '{src_test}' has no IN100 synset folders.")

    # Build numeric->synset map and filter to IN100
    full_map = load_label_map(src_train)
    label_map = {num: syn for num, syn in full_map.items() if syn in keep_synsets}
    print(f"Filtered to {len(label_map)} numeric labels matching IN100 synsets.")

    # Reorganize and copy val images only for selected synsets.
    # os.path.join returns args.val_map unchanged when it is absolute.
    reorganize_val(src_val,
                   os.path.join(src_val, args.val_map),
                   dst_val,
                   label_map)

if __name__ == "__main__":
    # Run the command-line workflow only when executed as a script, not on import.
    main()
