
import argparse
import fcntl
import json
import os
import re
import threading
from concurrent.futures import ThreadPoolExecutor

from spellchecker import SpellChecker


def parse_args():
    """Build the command-line parser for this script and parse ``sys.argv``.

    Options:
        --temperature / -t (float, required): temperature of the completions.
        --approach (str, default "ss"): approach used to generate completions.
        --max-workers (int, default 50): maximum number of concurrent workers.
        --force (optional value, default False): parsed as the string "True"
            when the flag is given bare or with an affirmative value
            ("true", "1", "yes", "y", "on", in any letter case). Any other
            value is passed through unchanged.

    Returns:
        argparse.Namespace with the attributes ``temperature``, ``approach``,
        ``max_workers`` and ``force``.
    """
    def force_value(raw):
        # The script compares this option against the exact string "True",
        # so every affirmative spelling is normalised to that value instead
        # of being silently ignored.
        if raw.strip().lower() in ('true', '1', 'yes', 'y', 'on'):
            return "True"
        return raw

    parser = argparse.ArgumentParser(
        description='Score spelling and formatting errors of generated completions')
    parser.add_argument('--temperature', '-t', type=float, required=True,
                       help='Temperature parameter for completions')
    parser.add_argument('--approach', type=str, default="ss",
                       help='approach used to generate completions')
    parser.add_argument('--max-workers', type=int, default=50,
                       help='Maximum number of concurrent workers (default: 50)')
    # nargs='?' + const lets a bare `--force` work as a flag, while the
    # historical `--force True` spelling keeps producing the same value.
    parser.add_argument('--force', type=force_value, nargs='?', const="True",
                       default=False,
                       help='Force re-run even if the output file already exists '
                            '(use --force or --force True)')
    return parser.parse_args()

# Parse the command line at import time; the statements below read `args` and `T`.
args = parse_args()
T = args.temperature

# Read the completions file for this approach/temperature, one JSON document per line.
with open(f'final_data/{args.approach}_completions_temp_{T}.jsonl', 'r') as f:
    baseline_completions_0 = f.readlines()

# Decode every non-blank line; whitespace-only lines (e.g. a stray trailing
# newline) are skipped instead of making json.loads raise.
solutions = [
    json.loads(completion)
    for completion in baseline_completions_0
    if completion.strip()
]



# Bold section headers of the form "**...**:" are stripped before spell-checking.
_SECTION_HEADER_RE = re.compile(r'\*\*.*?\*\*:')
# Any character that is neither a word character nor whitespace.
_PUNCTUATION_RE = re.compile(r'[^\w\s]')

# One SpellChecker per thread: constructing it on every call repeats the same
# setup work, while a per-thread instance keeps each thread's checker private
# exactly as the per-call instances were.
_spell_checker_cache = threading.local()


def _get_spell_checker():
    """Return the calling thread's SpellChecker, creating it on first use."""
    checker = getattr(_spell_checker_cache, 'checker', None)
    if checker is None:
        checker = SpellChecker()
        _spell_checker_cache.checker = checker
    return checker


def spelling_error(solution_0, solution_1):
    """Return the mean spelling-error rate of two completion texts.

    For each text, section headers and punctuation are removed, the text is
    split on whitespace, and the rate is ``len(spell.unknown(words)) /
    len(words)`` (0 when the text has no words).

    Args:
        solution_0: first completion text.
        solution_1: second completion text.

    Returns:
        The average of the two per-text rates.
    """
    spell = _get_spell_checker()

    def count_misspelled(text):
        # Drop section headers first, then all remaining punctuation.
        text = _SECTION_HEADER_RE.sub('', text)
        text = _PUNCTUATION_RE.sub('', text)

        words = text.split()
        if not words:
            return 0
        # NOTE(review): SpellChecker.unknown appears to return the *distinct*
        # unknown words, so a word misspelled repeatedly would count once in
        # the numerator while the denominator counts every token - confirm
        # this is the intended metric.
        return len(spell.unknown(words)) / len(words)

    # Average error rate across both solutions
    return (count_misspelled(solution_0) + count_misspelled(solution_1)) / 2



def formatting_error(solution_0,solution_1):
    """Return the fraction (0.0, 0.5 or 1.0) of the two texts that are malformed.

    A text is well formed only if it ends with "\\n\\nTHE END" and contains
    each of the markers "\\n\\n**Act 1", "\\n\\n**Act 2" and "\\n\\n**Act 3".
    """
    act_markers = ("\n\n**Act 1", "\n\n**Act 2", "\n\n**Act 3")

    def is_malformed(text):
        # 1 flags an error, 0 flags a correctly formatted text.
        ends_properly = text.endswith("\n\nTHE END")
        has_every_act = all(marker in text for marker in act_markers)
        return 0 if (ends_properly and has_every_act) else 1

    flags = [is_malformed(solution_0), is_malformed(solution_1)]
    return sum(flags) / 2

def process_solution(solution):
    """Score one record's pair of completions for spelling and formatting.

    Args:
        solution: mapping with a 'prompt' entry and a 'completions' entry
            that unpacks into exactly two completion texts.

    Returns:
        dict with the keys 'prompt', 'spelling_error', 'formatting_error' and
        'fully_formatted'; the last is 0 when formatting_error < 0.01 and
        spelling_error < 0.2, otherwise 1.
    """
    prompt_text = solution['prompt']
    first, second = solution['completions']

    scores = {
        'prompt': prompt_text,
        'spelling_error': spelling_error(first, second),
        'formatting_error': formatting_error(first, second),
    }
    # NOTE(review): 0 marks the passing case, which reads inverted for a key
    # named 'fully_formatted' - confirm consumers expect this polarity.
    within_limits = scores['formatting_error'] < 0.01 and scores['spelling_error'] < 0.2
    scores['fully_formatted'] = 0 if within_limits else 1
    return scores
 
def process_and_write(solution, output_file):
    """Score a single solution record and append the resulting row to ``output_file``."""
    write_result(process_solution(solution), output_file)
    
def write_result(result, output_file):
    """Append ``result`` as one JSON line to ``output_file`` under an exclusive lock.

    Args:
        result: JSON-serialisable object to record; ``None`` writes nothing.
        output_file: path of the JSONL file, opened in append mode.

    Raises:
        TypeError: if ``result`` cannot be serialised by ``json.dumps``.
        OSError: if the file cannot be opened, locked or written.
    """
    if result is None:
        return

    # Serialise before taking the lock so the lock is held only for the write.
    line = json.dumps(result) + '\n'

    with open(output_file, 'a') as f:
        fcntl.flock(f.fileno(), fcntl.LOCK_EX)
        try:
            f.write(line)
            # Push the buffered text to the OS while the lock is still held.
            # Without this the data would only be written when the `with`
            # block closes the file, which happens after LOCK_UN.
            f.flush()
        finally:
            fcntl.flock(f.fileno(), fcntl.LOCK_UN)
    
output_file = f'final_data/quality_judge/{args.approach}_temp_{T}.jsonl'

# Appending to a file inside a missing directory fails, so create it up front.
os.makedirs(os.path.dirname(output_file), exist_ok=True)

# With --force True, discard earlier results so every solution is re-scored.
if args.force == "True":
    if os.path.exists(output_file):
        os.remove(output_file)

# Resume support: skip solutions whose prompt already has a row in the output.
if os.path.exists(output_file):
    with open(output_file, 'r') as f:
        # A set gives constant-time membership tests; blank lines are ignored.
        # Assumes prompts are hashable (e.g. strings) - TODO confirm.
        seen_prompts = {json.loads(line)['prompt'] for line in f if line.strip()}

    solutions = [sol for sol in solutions if sol['prompt'] not in seen_prompts]

with ThreadPoolExecutor(max_workers=args.max_workers) as executor:
    futures = [
        executor.submit(process_and_write, solution, output_file)
        for solution in solutions
    ]

    # Wait for every task; result() re-raises any exception from a worker.
    for future in futures:
        future.result()


print(f"Finished processing {args.approach} temperature {T}")