#!/usr/bin/env python
"""
Extract image features from folders using a lightweight ViT.
"""

import os
import pickle
import torch
from PIL import Image
from tqdm import tqdm
from transformers import AutoFeatureExtractor, AutoModel


def extract_features_from_folder(folder_path, extractor, model, device, extract_synth=False):
    """
    Run every supported image file in ``folder_path`` through ``extractor`` and
    ``model`` and collect the first-token ([CLS]) embedding of each.

    Files are processed in sorted filename order, so the returned list has a
    reproducible order across runs and platforms (``os.listdir`` order is
    arbitrary, and no filenames are returned alongside the features).

    Args:
        folder_path: Directory containing the images. Not searched recursively.
        extractor: Called as ``extractor(images=img, return_tensors='pt')`` and
            expected to return a mapping of tensors.
        model: Called with the extractor's tensors. Its output must expose
            ``last_hidden_state``.
        device: torch device the inputs are moved to before the forward pass.
        extract_synth: If true, keep only images whose *filename* contains
            'synth'. If false, keep only images whose filename does not.

    Returns:
        A list with one embedding (a list of floats) per selected image.
    """
    supported_exts = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.gif')
    want_synth = bool(extract_synth)
    # Select by filename only. Testing the full path would misclassify every
    # image whenever the folder path itself contains 'synth'. Filtering here,
    # rather than inside the loop, keeps the progress-bar total accurate.
    image_files = sorted(
        fn for fn in os.listdir(folder_path)
        if fn.lower().endswith(supported_exts)
        and ('synth' in fn) == want_synth
        and os.path.isfile(os.path.join(folder_path, fn))
    )

    features = []
    for img_file in tqdm(image_files, desc=f'Processing {folder_path}'):
        img_path = os.path.join(folder_path, img_file)
        # The context manager closes the file handle promptly. Multi-frame
        # formats such as GIF/TIFF otherwise keep it open. convert() returns
        # an independent, loaded copy that remains valid after the block.
        with Image.open(img_path) as img:
            rgb = img.convert('RGB')
        inputs = extractor(images=rgb, return_tensors='pt')
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = model(**inputs)
        # Take batch item 0, token 0. This assumes last_hidden_state is
        # [1, seq_len, dim] for a single image (TODO confirm for non-ViT
        # models). Explicit indexing avoids squeeze() also dropping a
        # size-1 hidden dimension.
        cls_emb = outputs.last_hidden_state[0, 0, :].cpu().tolist()
        features.append(cls_emb)
    return features


def main():
    """
    Command-line entry point.

    Parses arguments, loads the feature extractor and model once, and writes
    one ``<folder basename>_features.pkl`` file per input folder into
    ``--output_dir``. Each file contains the list returned by
    ``extract_features_from_folder``.
    """
    import argparse
    parser = argparse.ArgumentParser(
        description='Extract image features from one or more folders.'
    )
    parser.add_argument(
        '--folders', nargs='+', required=True,
        help='Paths to folders containing images.'
    )
    parser.add_argument(
        '--model', type=str, default='facebook/deit-small-patch16-224',
        help='HuggingFace model to use for feature extraction.'
    )
    parser.add_argument(
        '--output_dir', type=str, default='.',
        help='Where to save the .pkl files.'
    )
    parser.add_argument(
        '--extract_synth', action='store_true',
        help='Only extract features from images with "synth" in the filename.'
    )
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    extractor = AutoFeatureExtractor.from_pretrained(args.model)
    model = AutoModel.from_pretrained(args.model).to(device)
    # Put the model in inference mode (e.g. disables dropout).
    model.eval()

    # Create the output directory up front. Otherwise open() below would
    # raise FileNotFoundError only after the slow extraction had finished.
    os.makedirs(args.output_dir, exist_ok=True)

    written = set()
    for folder in args.folders:
        if not os.path.isdir(folder):
            print(f"Skipping {folder}: not a directory.")
            continue
        feats = extract_features_from_folder(
            folder, extractor, model, device, args.extract_synth
        )
        # The output name uses only the last path component. Different folders
        # with the same basename (e.g. a/data and b/data) map to the same file.
        basename = os.path.basename(os.path.normpath(folder))
        out_path = os.path.join(args.output_dir, f'{basename}_features.pkl')
        if out_path in written:
            print(f'Warning: overwriting {out_path}, already written in this run.')
        with open(out_path, 'wb') as f:
            pickle.dump(feats, f)
        written.add(out_path)
        print(f'Saved {len(feats)} feature vectors to {out_path}')


if __name__ == '__main__':
    # Run the CLI only when this file is executed directly, not when imported.
    main()

# Example usage:
#   python extract_features.py --folders flowers_augmented flowers_heldout --output_dir .
# Add --extract_synth to process only images whose filename contains "synth".