import time
import os
import json
import re

import utils.utils as utils
import utils.gpt as gpt
import retrieval.auto_retrieval as auto_retrieval

# Matches the first ```-fenced block. The optional non-capturing prefix
# swallows the block's info string: either a language tag on its own line
# (```smt\n, ```smt2\n, ```lisp\n, or a bare ```\n) or an inline "smt..." tag
# directly followed by code (```smt (assert ...)```). Group 1 is the code.
_FENCED_CODE_RE = re.compile(
    r'```(?:[ \t]*[\w.+-]*[ \t]*\n|smt[\w.+-]*)?(.*?)```', re.DOTALL)


def parse_formal_problem(formal_problem):
    """Extract the SMT-LIB code from a model response.

    Args:
        formal_problem: raw response text, possibly containing a fenced
            code block.

    Returns:
        The stripped contents of the leftmost ```-fenced block, with any
        language tag (``smt``, ``smt2``, ``smtlib``, ``lisp``, ...) removed.
        An empty fenced block yields ``""``. A response with no fenced block
        is returned unchanged. The result is always a string.

    Note:
        A bare word written directly after the opening fence and followed by
        a newline is treated as a language tag and dropped.
    """
    match = _FENCED_CODE_RE.search(formal_problem)
    if match is None:
        return formal_problem
    return match.group(1).strip()

def _save_problem_record(save_file_path, thy_series):
    """Write the per-problem JSON record to disk, replacing any existing file."""
    with open(save_file_path, 'w', encoding='utf-8') as file:
        json.dump(thy_series, file, indent=4)


def _load_problem_record(save_file_path, question, solution):
    """Load the per-problem JSON record, creating and saving it if absent.

    A new record holds the question, the reference solution, and the answer
    extracted from it by ``utils.parse_answer``.
    """
    if os.path.exists(save_file_path):
        with open(save_file_path, encoding='utf-8') as file:
            return json.load(file)
    thy_series = {'problem': question,
                  'solution': solution,
                  'answer': utils.parse_answer(solution)}
    _save_problem_record(save_file_path, thy_series)
    return thy_series


def autoformalize(data, index, client, logger, args):
    """Translate an informal problem into SMT-LIB by calling GPT.

    The per-problem record lives at
    ``<args.save_folder_path>/problem_<args.index[0] + index>.json`` and is
    created from ``data`` on first use. The informal problem is read from the
    record's ``generation_<args.ref_iter>`` entry. The parsed SMT-LIB code is
    stored as ``"formal problem"`` under ``generation_<max(args.save_iter, 0)>``,
    and the record is written back to disk.

    Args:
        data: dataset entry with ``'question'`` and ``'answer'`` keys.
        index: offset of this entry relative to ``args.index[0]``.
        client: client object passed through to ``gpt.gpt4_response``.
        logger: logger used for progress messages.
        args: namespace providing ``index``, ``save_folder_path``,
            ``save_iter``, ``ref_iter``, ``overwrite`` and ``num_retrive``.

    Raises:
        KeyError: if the record has no ``generation_<args.ref_iter>`` entry.
    """
    question, solution = data['question'], data['answer']
    file_name = "problem_%s" % (args.index[0] + index)
    save_file_path = os.path.join(args.save_folder_path, file_name + '.json')
    thy_series = _load_problem_record(save_file_path, question, solution)

    gen_key = f'generation_{max(args.save_iter, 0)}'
    thy = thy_series.get(gen_key, {})
    if thy.get("formal problem", "") != "" and args.overwrite == 0:
        logger.info('skip')
        return

    ref_key = f"generation_{args.ref_iter}"
    if ref_key not in thy_series:
        raise KeyError(f"{save_file_path} has no '{ref_key}' entry to formalize from")
    informal_problem = thy_series[ref_key].get("informal problem", "")
    if not informal_problem:
        # Don't spend a GPT call on an empty prompt. Caching its output would
        # also make later runs hit the "skip" branch above and never retry.
        logger.warning('no informal problem in %s of %s; skip', ref_key, save_file_path)
        return

    t0 = time.perf_counter()
    prompt = f"Translate the natural language problem into SMT-LIB language: {informal_problem}"
    prob_examples = auto_retrieval.prob_sample(prompt, k=args.num_retrive)
    response = gpt.gpt4_response(client, prompt, prob_examples, top_p=0.95)
    formal_problem = parse_formal_problem(response)
    logger.info('%s\nTO ========>>>\n%s', informal_problem, formal_problem)
    logger.info('success! In (%s)s', time.perf_counter() - t0)

    thy["formal problem"] = formal_problem
    thy_series[gen_key] = thy
    _save_problem_record(save_file_path, thy_series)