import json
import os
from collections import defaultdict, deque

def create_directory(path):
    """Create the directory ``path``, including any missing parent directories.

    Does nothing when the directory already exists.

    Args:
        path: Directory path to create.

    Raises:
        FileExistsError: If ``path`` exists but is not a directory.
    """
    # exist_ok=True replaces a separate os.path.exists() check, which is racy:
    # another process could create the directory between the check and makedirs.
    os.makedirs(path, exist_ok=True)

def load_json(file_path):
    """Return the parsed content of the UTF-8 encoded JSON file at ``file_path``."""
    with open(file_path, encoding='utf-8') as handle:
        raw_document = handle.read()
    return json.loads(raw_document)

def load_linked_words(file_path):
    """Parse a linked-words file into ``(raw_word, trans_word)`` pairs.

    Every non-blank line must have the form ``raw_word||trans_word``.
    Leading/trailing whitespace of each line is stripped before parsing and
    blank lines are skipped.

    Args:
        file_path: Path of the UTF-8 text file to read.

    Returns:
        A list of ``(raw_word, trans_word)`` tuples in file order.

    Raises:
        ValueError: If a non-blank line does not contain exactly one ``||``
            separator.
    """
    linked_words = []
    with open(file_path, 'r', encoding='utf-8') as file:
        for line_number, line in enumerate(file, start=1):
            entry = line.strip()
            if not entry:
                # A blank line carries no pair; skipping it avoids an opaque
                # unpacking error on files with stray empty lines.
                continue
            parts = entry.split('||')
            if len(parts) != 2:
                raise ValueError(
                    f"{file_path}:{line_number}: expected 'raw_word||trans_word', "
                    f"got {entry!r}"
                )
            linked_words.append((parts[0], parts[1]))
    return linked_words

def update_samples(samples, linked_words):
    """Attach translated words and texts to the samples that have a linked word.

    Samples are processed in order. Each sample is matched with the first
    not-yet-used ``(raw_word, trans_word)`` pair whose raw word equals the
    sample's ``'raw_word'``; every pair is used at most once. A matched sample
    gets two new keys: ``'trans_word'`` and ``'trans_text'`` (its
    ``'raw_text'`` with every occurrence of the raw word replaced by the
    translated word). Samples without a remaining match are dropped.

    Args:
        samples: Iterable of dicts with ``'raw_word'`` and ``'raw_text'`` keys.
            Matched dicts are updated in place.
        linked_words: Iterable of ``(raw_word, trans_word)`` pairs. It is not
            modified.

    Returns:
        A list of the matched sample dicts, in their original order.
    """
    # Group translations per raw word, keeping their original order, so each
    # sample finds its next unused translation in O(1) without mutating the
    # caller's list.
    pending = defaultdict(deque)
    for raw, trans in linked_words:
        pending[raw].append(trans)

    updated_samples = []
    for sample in samples:
        raw_word = sample['raw_word']
        # .get() avoids creating empty queues for unknown words.
        queue = pending.get(raw_word)
        if not queue:
            continue
        trans_word = queue.popleft()
        sample['trans_word'] = trans_word
        sample['trans_text'] = sample['raw_text'].replace(raw_word, trans_word)
        updated_samples.append(sample)
    return updated_samples

def save_json(data, file_path):
    """Write ``data`` to ``file_path`` as UTF-8 JSON.

    The output is indented by four spaces and non-ASCII characters are
    written as-is rather than escaped.
    """
    dump_options = {'ensure_ascii': False, 'indent': 4}
    with open(file_path, mode='w', encoding='utf-8') as out:
        json.dump(data, out, **dump_options)

def save_transformed_texts(samples, file_path):
    """Write each sample's ``'trans_text'`` to ``file_path``, newline-terminated.

    The file is written as UTF-8 and overwritten if it already exists.
    """
    with open(file_path, mode='w', encoding='utf-8') as out:
        out.writelines(entry['trans_text'] + '\n' for entry in samples)

def main(embedding_methods, categories, p_values):
    """Generate transformed samples for every combination and print a summary.

    For each (embedding method, perturbation, category) combination the output
    directory is created first. If both the category's samples JSON and the
    combination's linked-words file exist, the samples are updated and written
    to ``samples.json`` and ``transformed_texts.txt`` in that directory;
    otherwise the combination is recorded as ``'file not found'``.

    Args:
        embedding_methods: Names substituted into the ``test_{embedding_method}``
            path prefix.
        categories: Category names substituted into the input/output paths.
        p_values: Perturbation labels substituted into the ``output_*_{p}``
            path components.
    """
    output_dir_template = 'test_{embedding_method}/output_multiline_{p}/{category}'
    samples_path_template = 'test_data/sampled_multiline/samples_{category}.json'
    linked_words_path_template = 'test_{embedding_method}/output_word_{p}/{category}/linked_words.txt'

    summary = {}

    for embedding_method in embedding_methods:
        method_summary = summary[embedding_method] = {}
        for p in p_values:
            counts = method_summary[p] = {}
            total = 0
            for category in categories:
                placeholders = {
                    'embedding_method': embedding_method,
                    'p': p,
                    'category': category,
                }
                # The output directory is created even when the inputs turn
                # out to be missing.
                output_dir = output_dir_template.format(**placeholders)
                create_directory(output_dir)

                samples_path = samples_path_template.format(category=category)
                linked_words_path = linked_words_path_template.format(**placeholders)

                inputs_present = os.path.exists(samples_path) and os.path.exists(linked_words_path)
                if not inputs_present:
                    counts[category] = 'file not found'
                    continue

                samples = load_json(samples_path)
                linked_words = load_linked_words(linked_words_path)
                updated_samples = update_samples(samples, linked_words)

                save_json(updated_samples, os.path.join(output_dir, 'samples.json'))
                save_transformed_texts(updated_samples, os.path.join(output_dir, 'transformed_texts.txt'))

                matched = len(updated_samples)
                counts[category] = matched
                total += matched
            counts['Total'] = total

    # Report the per-category counts gathered above.
    print("\n===== Summary =====\n")
    for embedding_method in embedding_methods:
        print(f"=== {embedding_method.upper()} ===")
        for p in p_values:
            print(f"--- Perturbation {p} ---")
            counts = summary[embedding_method][p]
            for category in categories:
                print(f"{category}: {counts[category]}")
            print(f"Total: {counts['Total']}")
        print("=============================")

if __name__ == "__main__":
    # Grid of combinations processed when the file is run as a script.
    embedding_methods = ['easyocr', 'paddleocr', 'viper']
    categories = ['sexual', 'insult', 'hate', 'drug', 'crime']
    p_values = ['p05', 'p10']
    main(
        embedding_methods=embedding_methods,
        categories=categories,
        p_values=p_values,
    )