import os
import subprocess
from pathlib import Path
import json
from glob import glob
import yaml

# Per-category generation counts, keyed by category name.
# Each value is a 2-tuple; only element [1] is read in this script, as the
# number of safe multiline images to generate for that category.
# NOTE(review): element [0] is not used here — presumably a count for another
# split/dataset; confirm before relying on it.
generation_num_info = {
    'sexual': (3457, 854),
    'insult': (3192, 794),
    'hate': (886, 222),
    'drug': (220, 60),
    'crime': (245, 70)
}

# Categories to process, in processing order. Taken from the keys of
# generation_num_info (dicts keep insertion order), so every processed
# category is guaranteed to have a count entry.
categories = list(generation_num_info)

# synthtiger template script and base config; the generation loop writes a
# per-category temporary copy of the config next to this one.
template_path = "examples/viper_multiline_sf/template.py"
config_path = "examples/viper_multiline_sf/config.yaml"

# Root of the generated output; each category gets its own subdirectory.
output_base_dir = Path("safe_data/images_multiline")

# Generate the safe multiline images for every category.
# For each category, this writes a temporary copy of the base synthtiger
# config whose 'corpus' and 'samples' entries point at that category's data
# files. It then runs synthtiger against that copy and removes the copy.

# Load the base config once. Each category builds its own shallow copy with
# 'corpus' and 'samples' replaced, so the loaded dict is never mutated.
with open(config_path, 'r', encoding='utf-8') as f:
    base_config = yaml.safe_load(f)

for category in categories:
    print(f"Generating safe multiline images for category: {category}")
    # Element [1] of the per-category tuple is the number of images to
    # generate for this run.
    count = generation_num_info[category][1]

    config = {
        **base_config,
        'corpus': {
            'paths': [f"safe_data/multiline/{category}.txt"],
            'weights': [1],
        },
        'samples': {
            'paths': [f"safe_data/samples/samples_{category}.json"],
        },
    }

    temp_config_path = f"examples/viper_multiline_sf/config_{category}_temp.yaml"
    with open(temp_config_path, 'w', encoding='utf-8') as f:
        yaml.dump(config, f)

    image_output_dir = output_base_dir / category
    image_output_dir.mkdir(parents=True, exist_ok=True)

    # Arguments are passed as a list with no shell. Paths containing spaces
    # or shell metacharacters therefore reach synthtiger verbatim instead of
    # being re-split.
    command = [
        "synthtiger",
        "-o", str(image_output_dir),
        "-c", str(count),
        "-w", "1",
        "-s", "610",
        "-v",
        template_path,
        "Multiline",
        temp_config_path,
    ]
    print(f"Running: {' '.join(command)}")
    try:
        # Raises FileNotFoundError if 'synthtiger' is not on PATH.
        result = subprocess.run(command)
    finally:
        # Always remove the temporary config, even if the process cannot be
        # launched or the run is interrupted.
        os.remove(temp_config_path)

    # A non-zero exit status is reported rather than raised, so the remaining
    # categories are still generated.
    if result.returncode != 0:
        print(
            f"WARNING: synthtiger exited with code {result.returncode} "
            f"for category '{category}'; its output may be incomplete."
        )

print("\nImage generation complete.\n")

# Report, per category and in total, how many lines gt.txt has, how many
# top-level entries glyph_coords_4.json has, and how many .jpg files are in
# images/0. A missing file or directory counts as 0.
print("Summary (per category):")
total_gt, total_json, total_images = 0, 0, 0

for category in categories:
    category_dir = output_base_dir / category
    gt_path = category_dir / "gt.txt"
    json_path = category_dir / "glyph_coords_4.json"
    image_dir = category_dir / "images" / "0"

    # Files are opened with `with` so their handles are closed immediately
    # instead of leaking until garbage collection.
    gt_count = 0
    if gt_path.exists():
        with open(gt_path, encoding='utf-8') as f:
            gt_count = sum(1 for _ in f)

    json_count = 0
    if json_path.exists():
        try:
            with open(json_path, encoding='utf-8') as f:
                json_count = len(json.load(f))
        except json.JSONDecodeError as err:
            # A truncated/corrupt file (e.g. from an interrupted run) is
            # reported rather than aborting the whole summary.
            print(f"WARNING: could not parse {json_path}: {err}")

    image_count = len(glob(str(image_dir / "*.jpg"))) if image_dir.exists() else 0

    total_gt += gt_count
    total_json += json_count
    total_images += image_count

    print(f"{category}: gt.txt = {gt_count}, glyph_coords_4.json = {json_count}, images = {image_count}")

print("\nTotal:")
print(f"gt.txt total: {total_gt}")
print(f"glyph_coords_4.json total: {total_json}")
print(f"images total (images/0): {total_images}")