import json
import os
import random
import subprocess

import yaml

def _count_nonblank_lines(path):
    """Return how many lines of *path* contain non-whitespace text."""
    # Only the number of lines matters here, so undecodable bytes are replaced
    # instead of aborting the whole run on a UnicodeDecodeError.
    with open(path, 'r', encoding='utf-8', errors='replace') as f:
        return sum(1 for line in f if line.strip())


def run_viper_transformation(embedding_method, split_info, embedding_vec_path):
    """Run VIPER's glyph perturber for every split, perturbation level and category.

    For each split (``train``, ``test``) and perturbation level (``p10`` run
    with ``-p 1.0``, ``p05`` run with ``-p 0.5``), ``VIPER/glyphperturber.py``
    is invoked once per category with that category's word list on stdin.
    Results are written below
    ``{split}_{embedding_method}/output_word_{p_val}/{category}/``; the
    non-blank lines of ``transformed_words.txt`` and ``linked_words.txt`` are
    counted and reported per category, per split/level, and in total.

    Args:
        embedding_method: Label used to build the output directory names.
        split_info: Mapping of category name to ``(train_count, test_count)``.
        embedding_vec_path: Path handed to the perturber's ``-e`` option.

    Raises:
        FileNotFoundError: If an input word list (or the ``python3``
            executable) cannot be found.
        subprocess.CalledProcessError: If the perturber exits with a non-zero
            status, so stale or missing output is never counted as a success.
    """
    splits = ('train', 'test')
    perturbations = (('p10', 1.0), ('p05', 0.5))
    total_linked = 0
    total_transformed = 0

    # Create the whole output tree up front so every run has its target directory.
    for split in splits:
        for p_val, _ in perturbations:
            output_dir = f"{split}_{embedding_method}/output_word_{p_val}"
            for category in split_info:
                os.makedirs(os.path.join(output_dir, category), exist_ok=True)
    print("Created output directory structure.")

    print("\nRunning VIPER transformations...")
    for split in splits:
        for p_val, p_flag in perturbations:
            print(f"\nSplit: {split}, Perturbation: {p_val}")
            split_linked = 0
            split_transformed = 0
            for category, (train_count, test_count) in split_info.items():
                count = train_count if split == 'train' else test_count
                output_dir = f"{split}_{embedding_method}/output_word_{p_val}/{category}"

                # NOTE(review): the train input directory is hard-coded to
                # "train_easyocr" and does not follow embedding_method --
                # confirm this is intended.
                if split == 'train':
                    input_file = f"train_easyocr/original_words/{category}.txt"
                else:
                    input_file = f"test_data/sampled_words/{category}.txt"

                perturb_file = f"{output_dir}/dummy_store.txt"
                transformed_file = f"{output_dir}/transformed_words.txt"
                linked_file = f"{output_dir}/linked_words.txt"

                # An argument list (no shell) keeps paths containing spaces or
                # shell metacharacters intact.
                argv = [
                    'python3', 'VIPER/glyphperturber.py',
                    '-e', str(embedding_vec_path),
                    '-p', str(p_flag),
                    '-s', '42',
                    '-c', str(count),
                    '--perturbations-file', perturb_file,
                    '--transformed-file', transformed_file,
                    '--linked-file', linked_file,
                ]
                print(f"Running VIPER on [{split}] [{category}] p={p_flag}")
                # The word list is fed to the perturber on stdin, replacing
                # the shell's "< input_file" redirection.
                with open(input_file, 'rb') as words_in:
                    subprocess.run(argv, stdin=words_in, check=True)

                transformed_count = _count_nonblank_lines(transformed_file)
                linked_count = _count_nonblank_lines(linked_file)

                print(f"[{category}] transformed_words.txt: {transformed_count}, linked_words.txt: {linked_count}")
                split_linked += linked_count
                split_transformed += transformed_count

            print(f"Split [{split}], p={p_val}: Total transformed: {split_transformed}, linked: {split_linked}")
            total_linked += split_linked
            total_transformed += split_transformed

    print(f"\nCompleted VIPER: total transformed_words = {total_transformed}, total linked_words = {total_linked}")

def update_config_yaml(config_path, output_dir):
    """Point the YAML config at the word files inside *output_dir*.

    Loads the YAML document at ``config_path``, sets the ``paths`` entry of
    its ``linked_words`` section to ``[<output_dir>/linked_words.txt]`` and
    that of its ``multilingual_corpus`` section to
    ``[<output_dir>/transformed_words.txt]``, then writes the document back
    to the same file and prints a confirmation line.

    Args:
        config_path: Path of the YAML file to rewrite in place.
        output_dir: Directory holding ``linked_words.txt`` and
            ``transformed_words.txt``.
    """
    # (config section, file inside output_dir) pairs to rewire.
    section_files = (
        ('linked_words', 'linked_words.txt'),
        ('multilingual_corpus', 'transformed_words.txt'),
    )

    with open(config_path, 'r') as src:
        settings = yaml.safe_load(src)

    for section, filename in section_files:
        settings[section]['paths'] = [os.path.join(output_dir, filename)]

    with open(config_path, 'w') as dst:
        yaml.dump(settings, dst, default_flow_style=False)

    print(f"Updated config.yaml → {output_dir}")

def run_synthtiger(embedding_method, split_info):
    """Generate SynthTiger images for every split, perturbation level and category.

    For each split (``train``, ``test``), perturbation level (``p10``,
    ``p05``) and category, the SynthTiger config is first pointed at that
    category's VIPER output directory via ``update_config_yaml``; then the
    ``synthtiger`` command is run, writing into
    ``{split}_{embedding_method}/images_word_{p_val}/{category}``. Afterwards
    the entries in ``images/4``, the non-blank lines of ``gt.txt`` and the
    top-level entries of ``glyph_coords_4.json`` are counted and printed.

    Args:
        embedding_method: Label used to build the input/output directory names.
        split_info: Mapping of category name to ``(train_count, test_count)``;
            the count for the current split is passed to ``synthtiger -c``.

    Raises:
        FileNotFoundError: If the ``synthtiger`` executable or one of the
            expected output files is missing.
        subprocess.CalledProcessError: If ``synthtiger`` exits with a non-zero
            status, so output left over from an earlier run is never counted.
    """
    config_path = 'examples/viper_synthtiger/config.yaml'
    template_path = 'examples/viper_synthtiger/template.py'
    print("\nRunning SynthTiger image generation...")

    for split in ['train', 'test']:
        for p_val in ['p10', 'p05']:
            total_images = 0
            total_gt = 0
            total_coords = 0
            print(f"\nSplit: {split}, Perturbation: {p_val}")
            for category, (train_count, test_count) in split_info.items():
                count = train_count if split == 'train' else test_count
                output_dir = f"{split}_{embedding_method}/output_word_{p_val}/{category}"

                # The config file is shared by all runs, so it is rewritten
                # before each invocation.
                update_config_yaml(config_path, output_dir)

                image_output_dir = f"{split}_{embedding_method}/images_word_{p_val}/{category}"
                os.makedirs(image_output_dir, exist_ok=True)

                # An argument list (no shell) keeps paths intact; the same
                # config_path that was just updated is the one passed on.
                argv = [
                    'synthtiger',
                    '-o', image_output_dir,
                    '-c', str(count),
                    '-w', '1',
                    '-s', '610',
                    '-v',
                    template_path, 'SynthTiger', config_path,
                ]
                print(f"SynthTiger: {split} {category} p={p_val}")
                subprocess.run(argv, check=True)

                img_count = len(os.listdir(os.path.join(image_output_dir, 'images/4')))
                gt_file = os.path.join(image_output_dir, 'gt.txt')
                coords_file = os.path.join(image_output_dir, 'glyph_coords_4.json')

                # Only the line count is needed, so undecodable bytes are
                # replaced rather than raising.
                with open(gt_file, 'r', encoding='utf-8', errors='replace') as f:
                    gt_count = sum(1 for line in f if line.strip())
                # Binary mode lets json detect the encoding itself instead of
                # depending on the locale's default.
                with open(coords_file, 'rb') as f:
                    coords_count = len(json.load(f))

                print(f"[{category}] images: {img_count}, gt.txt: {gt_count}, glyph_coords_4.json: {coords_count}")

                total_images += img_count
                total_gt += gt_count
                total_coords += coords_count

            print(f"Split [{split}] p={p_val}: Total images: {total_images}, gt.txt: {total_gt}, coords: {total_coords}")

    print("\nSynthTiger generation complete.")

if __name__ == "__main__":
    # Per-category sample budget: category -> (train_count, test_count).
    # Categories are processed in the order they are listed here.
    category_counts = dict(
        sexual=(3457, 854),
        insult=(3192, 794),
        hate=(886, 222),
        drug=(220, 60),
        crime=(245, 70),
    )

    method = 'easyocr'
    vec_path = 'VIPER/output_easyocr/vec.normalized'

    # Stage 1: perturb the word lists; stage 2: render them as images.
    run_viper_transformation(method, category_counts, vec_path)
    run_synthtiger(method, category_counts)