#!/usr/bin/env python3
"""
Clean script to create a Hugging Face dataset from face images and metadata.

Usage:
    python create_hf_dataset.py --input-dir faces/faces4_done --output-dir hf_dataset
"""

import json
import os
import argparse
from typing import Dict, Any
from PIL import Image
from datasets import Dataset, Features, Image as HFImage, Value, ClassLabel


def create_features(metadata: Dict[str, Any]) -> Features:
    """Build the Hugging Face ``Features`` schema for the face dataset.

    The schema is fixed: one image column followed by string columns for
    identification fields, primary attributes (``p_*``) and secondary
    attributes (``s_*``).

    Args:
        metadata: Parsed metadata dictionary. It is accepted so existing
            callers keep working, but the schema does not depend on it, so
            no keys are required.

    Returns:
        The ``Features`` object describing every dataset column.
    """
    # Every non-image column is stored as a plain string.
    string_columns = (
        "image_name",
        "full_prompt",
        "id",
        "name",
        "p_gender",
        "p_eye_color",
        "p_hair_color",
        "p_hair_style",
        "p_accessories",
        "s_age_group",
        "s_skin_tone",
        "s_face_shape",
        "s_eyebrow_shape",
        "s_lip_shape",
        "s_facial_features",
    )

    schema = {"image": HFImage()}
    for column in string_columns:
        schema[column] = Value("string")
    return Features(schema)


def get_face_id(image_name: str) -> str:
    """Return the ID embedded in an image file name.

    The ID is the second underscore-separated piece of the name, truncated
    at its first dot (``"face_0001.png"`` -> ``"0001"``). A name without
    any underscore yields the placeholder ``"0000"``.
    """
    pieces = image_name.split('_')
    if len(pieces) < 2:
        return "0000"

    second_piece = pieces[1]
    dot_index = second_piece.find('.')
    if dot_index == -1:
        return second_piece
    return second_piece[:dot_index]


def get_name_for_face(image_name: str, gender: str, names_data: list) -> str:
    """Look up the display name for a face.

    The face ID is taken from ``image_name`` and matched against the
    ``"id"`` field of each entry in ``names_data``; the first match wins.
    The ``"name_male"`` value is returned when ``gender`` is ``"male"``,
    otherwise the ``"name_female"`` value. ``"Unknown"`` is returned when
    ``names_data`` is empty or no entry matches.
    """
    if not names_data:
        return "Unknown"

    wanted_id = get_face_id(image_name)
    match = next(
        (entry for entry in names_data if entry["id"] == wanted_id),
        None,
    )
    if match is None:
        return "Unknown"

    name_key = "name_male" if gender == "male" else "name_female"
    return match[name_key]


def create_dataset(input_dir: str, output_dir: str, names_file: str = None) -> Dataset:
    """Create a Hugging Face dataset from face images and their metadata.

    Reads ``face_metadata.json`` from ``input_dir``, keeps the entries of
    its ``"generated_faces"`` list whose ``"image_name"`` matches a
    ``face_*.png`` file in ``input_dir``, and builds one record per face.
    Faces whose image cannot be opened or whose metadata is incomplete are
    reported and skipped.

    Args:
        input_dir: Directory containing the images and ``face_metadata.json``.
        output_dir: Not used by this function; kept so existing callers
            keep working.
        names_file: Optional JSON file of name entries. It is ignored when
            it is falsy or does not exist, in which case every name is
            ``"Unknown"``.

    Returns:
        The dataset built from all successfully processed faces.

    Raises:
        OSError: If the metadata file (or an existing names file) cannot be read.
        KeyError: If the metadata has no ``"generated_faces"`` entry.
    """
    # Load metadata (JSON is read as UTF-8 regardless of the locale default)
    metadata_file = os.path.join(input_dir, "face_metadata.json")
    print(f"Loading metadata from: {metadata_file}")
    with open(metadata_file, 'r', encoding='utf-8') as f:
        metadata = json.load(f)

    # Load names data if provided
    names_data = []
    if names_file and os.path.exists(names_file):
        print(f"Loading names from: {names_file}")
        with open(names_file, 'r', encoding='utf-8') as f:
            names_data = json.load(f)
        print(f"Loaded {len(names_data)} name pairs")

    # Get available images
    available_images = {
        entry
        for entry in os.listdir(input_dir)
        if entry.startswith('face_') and entry.endswith('.png')
    }
    print(f"Found {len(available_images)} image files")

    # Process faces
    generated_faces = metadata["generated_faces"]
    print(f"Found {len(generated_faces)} face entries in metadata")

    # Filter valid faces
    valid_faces = [face for face in generated_faces if face["image_name"] in available_images]
    print(f"Valid faces with images: {len(valid_faces)}")

    # Create dataset records
    dataset_records = []
    for i, face_data in enumerate(valid_faces):
        if i % 100 == 0:
            print(f"Processing face {i+1}/{len(valid_faces)}")

        image_name = face_data["image_name"]
        image_path = os.path.join(input_dir, image_name)

        # Image.open is lazy and keeps the file open until the pixels are
        # read. The context manager closes each file right away, so a large
        # folder cannot exhaust the process's file descriptors.
        try:
            with Image.open(image_path) as opened:
                if opened.mode == 'RGB':
                    # No conversion needed: hand the dataset the file path
                    # (the Image feature accepts one) instead of holding
                    # decoded pixels for every image in memory.
                    image = os.path.abspath(image_path)
                else:
                    # convert() reads the pixels and returns a new in-memory
                    # image that no longer depends on the open file.
                    image = opened.convert('RGB')
        except Exception as e:
            # Best effort: an unreadable image must not abort the whole run.
            print(f"Error loading image {image_path}: {e}")
            continue

        try:
            # Get attributes
            attrs = face_data["attributes"]
            sec_attrs = attrs["secondary_attributes"]

            # Get name
            name = get_name_for_face(image_name, attrs["gender"], names_data)

            # Create record
            record = {
                "image": image,
                "image_name": image_name,
                "full_prompt": face_data["full_prompt"],
                "id": get_face_id(image_name),
                "name": name,
                "p_gender": attrs["gender"],
                "p_eye_color": attrs["eye_color"],
                "p_hair_color": attrs["hair_color"],
                "p_hair_style": attrs["hair_style"],
                "p_accessories": attrs["accessories"],
                "s_age_group": sec_attrs["age_group"],
                "s_skin_tone": sec_attrs["skin_tone"],
                "s_face_shape": sec_attrs["face_shape"],
                "s_eyebrow_shape": sec_attrs["eyebrow_shape"],
                "s_lip_shape": sec_attrs["lip_shape"],
                "s_facial_features": sec_attrs["facial_features"],
            }
        except (KeyError, TypeError) as e:
            # Incomplete or malformed metadata is reported as such rather
            # than being mislabelled as an image loading failure.
            print(f"Error building record for {image_path}: {e!r}")
            continue

        dataset_records.append(record)

    print(f"Successfully processed {len(dataset_records)} images")

    # Create dataset with proper features
    features = create_features(metadata)
    dataset = Dataset.from_list(dataset_records, features=features)

    return dataset


def save_dataset(dataset: Dataset, output_dir: str, metadata: Dict[str, Any]):
    """Write the dataset to ``output_dir`` and print a short summary.

    The output directory is created if it does not exist. ``metadata`` is
    accepted but not used by this function.
    """
    os.makedirs(output_dir, exist_ok=True)
    print(f"Saving dataset to: {output_dir}")

    # save_to_disk writes the dataset files into the output directory.
    dataset.save_to_disk(output_dir)

    print("Dataset saved successfully!")
    image_count = len(dataset)
    print(f"Total images: {image_count}")
    column_names = list(dataset.features.keys())
    print(f"Features: {column_names}")


def main():
    """Command-line entry point: build the dataset and save it to disk.

    Parses the command-line arguments, validates the input paths, builds
    the dataset, optionally truncates it to ``--limit`` rows, and saves it
    to ``--output-dir``.

    Raises:
        SystemExit: With an error message (non-zero exit status) when the
            input directory, metadata file, or names file is missing, or
            when ``--limit`` is negative.
    """
    parser = argparse.ArgumentParser(description="Create Hugging Face dataset from face images and metadata")
    parser.add_argument("--input-dir", required=True, help="Input directory containing images and face_metadata.json")
    parser.add_argument("--output-dir", required=True, help="Output directory for Hugging Face dataset")
    parser.add_argument("--names-file", default="names/output/names_combined.json", help="JSON file containing name pairs")
    parser.add_argument("--limit", type=int, help="Limit number of images to process (for testing)")
    # Parsed but not acted on in this function; kept so existing command
    # lines that pass the flag continue to work.
    parser.add_argument("--no-test", action="store_true", help="Skip testing the loaded dataset")

    args = parser.parse_args()

    # A negative limit would silently select an empty dataset; reject it.
    if args.limit is not None and args.limit < 0:
        parser.error("--limit must be a non-negative integer")

    # Validate inputs. Raising SystemExit with a message prints it to
    # stderr and exits with status 1, so shells and CI can detect failure.
    if not os.path.exists(args.input_dir):
        raise SystemExit(f"Error: Input directory not found: {args.input_dir}")

    metadata_file = os.path.join(args.input_dir, "face_metadata.json")
    if not os.path.exists(metadata_file):
        raise SystemExit(f"Error: Metadata file not found: {metadata_file}")

    if args.names_file and not os.path.exists(args.names_file):
        raise SystemExit(f"Error: Names file not found: {args.names_file}")

    # Create dataset
    dataset = create_dataset(args.input_dir, args.output_dir, args.names_file)

    # Apply limit if specified (a limit of 0 means "no limit")
    if args.limit:
        dataset = dataset.select(range(min(args.limit, len(dataset))))
        print(f"Limited dataset to {len(dataset)} images for testing")

    # Load metadata for saving
    with open(metadata_file, 'r', encoding='utf-8') as f:
        metadata = json.load(f)

    # Save dataset
    save_dataset(dataset, args.output_dir, metadata)
    
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()