import random
import json
import argparse
from collections import defaultdict
# Seed the global RNG at import time so that the random.choice() draws made
# while generating sentences are reproducible from run to run.
random.seed(0)

# Prefix prepended to every file name this module reads or writes (metadata
# inputs and generated outputs). It is a relative path, so it is resolved
# against the current working directory.
file_path_prefix = "../dev/"

def read_metadata(subject):
    """Load the prompt metadata for ``subject`` from its JSON file.

    Args:
        subject: Base name of the metadata file. The file that is read is
            ``<file_path_prefix><subject>.json``.

    Returns:
        The parsed JSON content of that file.

    Raises:
        OSError: If the file cannot be opened (e.g. ``FileNotFoundError``).
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON text is UTF-8; decode it explicitly instead of relying on the
    # platform's locale-dependent default encoding.
    with open(f"{file_path_prefix}{subject}.json", "r", encoding="utf-8") as file:
        return json.load(file)

def a_or_an(word):
    """Return the indefinite article to put in front of ``word``.

    The choice looks only at the first character: "an" when it is one of
    the letters a, e, i, o, u (case-insensitive), otherwise "a".
    """
    first_letter = word[0].lower()
    if first_letter in "aeiou":
        return "an"
    return "a"

def generate_grouped_sentences(group_by, subject, data=None):
    """Build prompt sentences and group them by one protected category.

    One sentence of the form
    ``"<prefix> <a|an> <age word> <ethnicity word> <gender word> <suffix>."``
    is produced for every combination of prefix, suffix, age group,
    ethnicity group and gender group found in ``data``. The concrete word
    used for each group is drawn with ``random.choice`` from that group's
    word list.

    Args:
        group_by: Category whose group keys become the keys of the result;
            one of ``"age"``, ``"ethnicity"`` or ``"gender"``.
        subject: Metadata file base name passed to ``read_metadata``; only
            used when ``data`` is ``None``.
        data: Optional already-loaded metadata. It must provide the
            iterables ``data["prefixes"]`` and ``data["suffixes"]`` and the
            mappings ``data["age"]``, ``data["ethnicity"]`` and
            ``data["gender"]`` from group key to a non-empty sequence of
            words.

    Returns:
        A ``defaultdict(list)`` mapping each group key of the selected
        category to the list of sentences generated for that group.

    Raises:
        ValueError: If ``group_by`` is not a supported category.
    """
    # Validate before reading any file or drawing any random word, so a bad
    # category is reported even when the metadata yields no combinations.
    if group_by not in ("age", "ethnicity", "gender"):
        raise ValueError(f"Unknown grouping category: {group_by}")

    if data is None:
        data = read_metadata(subject)

    grouped = defaultdict(list)

    for prefix in data["prefixes"]:
        for occupation in data["suffixes"]:
            for age_key, age_words in data["age"].items():
                # The age word is drawn once per (prefix, suffix, age group)
                # and shared by all ethnicity/gender combinations below.
                age_word = random.choice(age_words)
                # The article directly precedes the age word, so it depends
                # on that word only.
                article = a_or_an(age_word)
                for eth_key, eth_words in data["ethnicity"].items():
                    eth_word = random.choice(eth_words)
                    for gen_key, gender_words in data["gender"].items():
                        gender_word = random.choice(gender_words)

                        sentence = f"{prefix} {article} {age_word} {eth_word} {gender_word} {occupation}."

                        # File the sentence under the group key of the
                        # selected category.
                        group_keys = {
                            "age": age_key,
                            "ethnicity": eth_key,
                            "gender": gen_key,
                        }
                        grouped[group_keys[group_by]].append(sentence)
    return grouped

def parse_args(argv=None):
    """Parse the command-line options of the prompt generator.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) the arguments are taken from ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with the attributes ``protected_category``
        (required), ``subject`` (defaults to ``"occupation"``) and
        ``output_file`` (``None`` when not given).
    """
    parser = argparse.ArgumentParser(description="Generate grouped prompts based on protected category.")
    parser.add_argument(
        '--protected_category',
        choices=['age', 'ethnicity', 'gender'],
        required=True,
        help="Specify the category to group by (age, ethnicity, or gender)"
    )
    parser.add_argument(
        '--subject',
        choices=['occupation', 'leisure_activity', 'music_genres', 'technology', 'transportation', 'fashion_choices'],
        required=False,
        default='occupation',
        help="Specify the image categories being generated (occupation, leisure_activity, etc)"
    )
    parser.add_argument(
        '--output_file',
        required=False,
        help="Specify the base name of the output json file to write to (the '.json' extension is appended automatically)"
    )
    return parser.parse_args(argv)

if __name__ == "__main__":
    # Re-seed right before generating so the random word choices of a script
    # run do not depend on any random numbers consumed earlier.
    random.seed(0)
    args = parse_args()
    grouped_sentences = generate_grouped_sentences(args.protected_category, args.subject)

    # Default output name is "<category>_<subject>"; --output_file overrides it.
    output_name = args.output_file or f"{args.protected_category}_{args.subject}"
    # Accept a name that already carries the extension so the result is not
    # "<name>.json.json".
    if output_name.endswith(".json"):
        output_name = output_name[:-len(".json")]
    output_filename = f"{file_path_prefix}{output_name}.json"

    with open(output_filename, "w", encoding="utf-8") as outfile:
        json.dump(grouped_sentences, outfile, indent=4)

    print(f"Example prompts have been written to {output_filename}.")

# Example usage (to modify the inputs, manually edit the subject's metadata file, ../dev/<subject>.json):
# $ python3 generator.py --protected_category ethnicity # generates prompts of occupations stratified by ethnicity, stores in default ethnicity_occupation.json
# $ python3 generator.py --protected_category age --subject leisure_activity # generates prompts of leisure activities stratified by age, stores in default age_leisure_activity.json
# $ python3 generator.py --protected_category gender --output_file gender_train_prompts # generates prompts grouped by gender, stores results in gender_train_prompts.json

