#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os  
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))  
import argparse  
import json  
from openai import AzureOpenAI  
from tqdm import tqdm
import pdb
import re
# NOTE: `code_utils` (imported at the end of this block) is a local module; it must be importable from this script's directory or from the parent directory appended to sys.path above.
from concurrent.futures import ThreadPoolExecutor, as_completed  
from functools import partial  
from contextlib import contextmanager    
from io import StringIO    
import sys
import timeout_decorator
import ast
import logging
from code_utils import modify_main_function, extract_code, remove_intermediate_prints, contains_exit_statements, exec_, capture_output, exec_table, remove_print_df_head


def _build_program(prompt_template, code, data_path):
    """Splice *code* and *data_path* into *prompt_template* and post-process it.

    The result is passed through ``remove_intermediate_prints`` and
    ``remove_print_df_head`` (from ``code_utils``) before being returned.
    """
    program = prompt_template.replace('[[final_code]]', code)
    # Substitution runs over the whole program, so a [[data_path]] placeholder
    # inside the generated code itself is replaced as well.
    program = program.replace('[[data_path]]', data_path)
    program = remove_intermediate_prints(program)
    return remove_print_df_head(program)


def _execution_result(program):
    """Execute *program* and return the text to store under ``execution``.

    Returns a fixed message when ``contains_exit_statements`` flags the program.
    Otherwise returns the first value reported by ``exec_`` when it is truthy,
    else the second (error) value. Any exception raised while executing is
    turned into an ``'Execution error: ...'`` string, so one bad record cannot
    abort the whole run.
    """
    if contains_exit_statements(program):
        return 'Code contains unsafe exit statements.'
    # NOTE(review): this runs model-generated code in-process; any sandboxing
    # or timeout must come from exec_, which is not visible here.
    try:
        result_string, errors_string = exec_(program)
        return result_string if result_string else errors_string
    except Exception as exec_exception:  # deliberate best-effort boundary
        return f'Execution error: {exec_exception}'


def load_and_infer_from_jsonl_base_exec(prompt_path, result_path, prompt_template, eval_model, num_threads=5):
    """Execute the generated code stored in each record of a .jsonl file.

    Each non-blank input line is parsed as a JSON object. Records whose
    ``require_opt`` field is the string ``'false'`` are copied through
    unchanged. For every other record, the code field selected by
    *eval_model* is processed as follows:

    - spliced into *prompt_template* (``[[final_code]]`` / ``[[data_path]]``
      placeholders);
    - executed;
    - the outcome stored under the ``execution`` key (``'no effective code'``
      when the field is missing or empty).

    Every record is written, in input order, as one JSON line to
    *result_path*.

    Args:
        prompt_path: Input .jsonl file.
        result_path: Output .jsonl file; it is overwritten.
        prompt_template: Template containing ``[[final_code]]`` and optionally
            ``[[data_path]]``.
        eval_model: ``'baseline'`` (reads ``first_code``), ``'training'``
            (reads ``final_code``) or ``'inference'`` (reads
            ``rag_final_code``).
        num_threads: Currently unused; kept for interface compatibility.

    Raises:
        ValueError: If *eval_model* is not one of the supported values. This
            is checked before any file is opened, so *result_path* is not
            truncated.
    """
    code_fields = {
        'baseline': 'first_code',
        'training': 'final_code',
        'inference': 'rag_final_code',
    }
    try:
        code_field = code_fields[eval_model]
    except KeyError:
        raise ValueError(
            f"unsupported eval_model {eval_model!r}; expected one of {sorted(code_fields)}"
        ) from None

    # Count lines only to give tqdm a total; the handle is closed right away.
    with open(prompt_path, 'r', encoding='utf-8') as counter:
        total_lines = sum(1 for _ in counter)

    with open(prompt_path, 'r', encoding='utf-8') as infile, open(result_path, 'w', encoding='utf-8') as outfile:
        for line in tqdm(infile, total=total_lines, desc="running"):
            if not line.strip():
                continue  # tolerate blank lines such as a trailing newline
            data = json.loads(line)
            if data.get('require_opt', '') != 'false':
                code = data.get(code_field, '')
                if code:
                    # `or ''` guards against an explicit JSON null data_path.
                    program = _build_program(prompt_template, code, data.get('data_path') or '')
                    data['execution'] = _execution_result(program)
                else:
                    data['execution'] = 'no effective code'
            outfile.write(json.dumps(data, ensure_ascii=False) + '\n')
  
def inference(argv=None):
    """Command-line entry point: execute generated code for one dataset.

    Parses the CLI arguments and imports the dataset's ``exec_string``
    template from the matching ``prompts`` module. It then runs
    ``load_and_infer_from_jsonl_base_exec`` over the input file.

    Args:
        argv: Optional argument list to parse instead of ``sys.argv[1:]``.
            ``None`` (the default) keeps the normal command-line behavior.
    """
    parser = argparse.ArgumentParser(description='Execute generated code stored in a .jsonl file and record its output.')
    parser.add_argument('--prompt_path', type=str, required=True, help='Path to the input .jsonl file containing prompts.')
    parser.add_argument('--result_path', type=str, required=True, help='Path where the output .jsonl file with results will be saved.')
    parser.add_argument('--num_threads', type=int, required=False, default=5, help='Number of worker threads (currently unused by the executor).')
    # `choices` makes argparse reject unknown values with a clear usage error
    # instead of failing later with an unbound template / code field.
    parser.add_argument('--dataset', type=str, required=True, choices=('tabmwp', 'wikitq'), help='name of dataset')
    parser.add_argument('--eval_mode', type=str, required=True, choices=('baseline', 'training', 'inference'), help='mode of evaluation')

    args = parser.parse_args(argv)
    if args.dataset == "tabmwp":
        from prompts.tabmwp_prompt_combine.new.tabmwp_baseline import exec_string as prompt_template
    elif args.dataset == "wikitq":
        from prompts.wikitq_prompt_combine.prompt.wikitq_all import exec_string as prompt_template
    else:  # unreachable while `choices` above stays in sync
        parser.error(f"unsupported dataset: {args.dataset!r}")

    load_and_infer_from_jsonl_base_exec(
        args.prompt_path,
        args.result_path,
        prompt_template,
        eval_model=args.eval_mode,
        num_threads=args.num_threads,
    )
  
if __name__ == "__main__":
    # Only run the CLI when executed directly, not when imported as a module.
    inference()
