#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os  
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))  
import argparse  
import json  
from openai import AzureOpenAI  
from tqdm import tqdm
import pdb
import re
# Assuming this script is saved in a package and the import below refers to another module within the same package    
from concurrent.futures import ThreadPoolExecutor, as_completed  
from functools import partial  
from utils.tools import capture_output, exec_table
from contextlib import contextmanager    
from io import StringIO    
import sys
from model_api_call import get_chat_response_azure
import timeout_decorator
import ast
import logging
from prompts.wikitq_prompt_combine.prompt.wikitq_all import exec_string
# from wikitq_prompt_combine.new.wikitq_baseline import exec_string
from code_utils import remove_intermediate_prints

def extract_code(result):
    """Return the body of the first ```python fenced block in *result*.

    The block is terminated by either a closing fence on its own line or by
    the ``#n```` variant the pattern also accepts. Returns ``None`` when no
    such block is found.
    """
    fenced = re.search(r"```python\n(.*?)(?:\n```|\n#n```)", result, re.DOTALL)
    return fenced.group(1) if fenced else None

def create_message(prompt_string):
    """Build the two-turn chat payload: a fixed system turn plus the user prompt."""
    system_turn = {"role": "system", "content": "You are a helpful assistant."}
    user_turn = {"role": "user", "content": prompt_string}
    return [system_turn, user_turn]

@contextmanager
def capture_output():
    """Temporarily redirect ``sys.stdout`` and ``sys.stderr`` into in-memory buffers.

    Yields a ``(stdout_buffer, stderr_buffer)`` pair of ``StringIO`` objects.
    The previous streams are put back when the ``with`` block exits, whether
    it finishes normally or raises.

    Note: this definition shadows the ``capture_output`` name imported from
    ``utils.tools`` at the top of the file.
    """
    saved_streams = (sys.stdout, sys.stderr)
    buffers = (StringIO(), StringIO())
    sys.stdout, sys.stderr = buffers
    try:
        yield buffers
    finally:
        # Always restore, so a failure inside the block cannot leave the
        # process writing into a discarded buffer.
        sys.stdout, sys.stderr = saved_streams
        
@timeout_decorator.timeout(30)
def exec_(code):
    """Run *code* with ``exec`` and return what it wrote to stdout and stderr.

    The call is wrapped by ``timeout_decorator.timeout(30)``. The code runs in
    a fresh globals dict that contains only ``__builtins__``. Any ``Exception``
    raised by the executed code is not propagated; its message is appended to
    the captured stderr text instead.

    Returns:
        tuple: ``(captured_stdout, captured_stderr)`` as strings.
    """
    # A dedicated globals dict lets functions defined by the executed code
    # see each other, while keeping this module's own names out of reach.
    sandbox_globals = {"__builtins__": __builtins__}

    with capture_output() as (stdout_buf, stderr_buf):
        try:
            exec(code, sandbox_globals)
        except Exception as exc:
            # Report the failure through the error buffer rather than raising.
            stderr_buf.write(f"Error executing code: {exc}\n")

    # Read the text out before the buffers are closed.
    captured = (stdout_buf.getvalue(), stderr_buf.getvalue())
    stdout_buf.close()
    stderr_buf.close()
    return captured

def extract_main_function(input_code):
    """Extract the top-level ``def main():`` and the ``__main__`` guard from code.

    Args:
        input_code (str): Python source text.

    Returns:
        str: The text of ``def main():`` up to (but not including) the next
        line that starts with a word character, or up to the end of the input
        when ``main`` is the last definition; then a newline; then everything
        from ``if __name__ == "__main__":`` to the end of the input. Either
        part is an empty string when it is not found.
    """
    # The function body ends at the next line starting in column 0 with a
    # word character. `\Z` covers the case where main() is the last thing in
    # the input, which would otherwise fail the lookahead and yield nothing.
    main_pattern = re.compile(r'def main\(\):.*?(?=^\w|\Z)', re.DOTALL | re.MULTILINE)
    main_match = main_pattern.search(input_code)
    main_function = main_match.group(0) if main_match else ""

    # The guard block is assumed to be un-nested and to run to the end of the input.
    guard_pattern = re.compile(r'if __name__ == "__main__":.*', re.DOTALL)
    guard_match = guard_pattern.search(input_code)
    name_block = guard_match.group(0) if guard_match else ""

    return f"{main_function}\n{name_block}"

def remove_print_statements(code):
    """Blank out lines that consist solely of a ``print(...)`` call.

    Matching is per line: optional leading whitespace, ``print(``, anything,
    a closing ``)``, optional trailing whitespace. Calls that span several
    lines, or prints embedded in a larger statement, are not guaranteed to be
    caught.
    """
    print_line = re.compile(r'^\s*print\(.*\)\s*$', re.MULTILINE)
    return print_line.sub('', code)

def remove_print_df_head(code: str) -> str:
    """
    Removes print lines that preview a DataFrame via '.head()' or '.head(3)'.

    A line is dropped when, ignoring surrounding whitespace, it starts with
    'print(' and contains '.head()' or '.head(3)'. Lines that merely use
    '.head()' in other statements (e.g. an assignment) are kept so the
    remaining code still runs.

    Parameters:
    code (str): The input code string.

    Returns:
    str: The modified code string with the specified lines removed.
    """
    head_calls = ('.head()', '.head(3)')

    def is_head_preview(line: str) -> bool:
        stripped = line.strip()
        # Either marker is enough to drop the line. Testing
        # "A not in line or B not in line" instead would keep nearly
        # every line and make the filter a no-op.
        return stripped.startswith('print(') and any(call in stripped for call in head_calls)

    return '\n'.join(line for line in code.split('\n') if not is_head_preview(line))



def contains_exit_statements(code):
    """
    Report whether the code mentions any call that could terminate the script.

    The check is a plain substring search for 'exit', 'quit', 'sys.exit' and
    'os._exit', tried in that order; the first one found is logged as an
    error. Because it is a substring search, identifiers or text that merely
    contain one of these words also count as a hit.

    Args:
    - code (str): The code to be checked.

    Returns:
    - bool: True if one of the forbidden names occurs in the code, False otherwise.
    """
    offender = next(
        (name for name in ('exit', 'quit', 'sys.exit', 'os._exit') if name in code),
        None,
    )
    if offender is None:
        return False

    logging.error(f"Unsafe call detected: {offender}")
    return True

def modify_main_function(code_str):
    """
    Rewrite the ``if __name__ == "__main__": main()`` guard into a bare ``main()`` call.

    Code run through exec() with a custom globals dict does not have
    ``__name__`` set to "__main__", so the guarded call would never fire.
    The guard is recognised with either single or double quotes around
    ``__main__`` and with any amount of spaces/tabs around the tokens and
    before the indented ``main()`` call.

    Args:
        code_str (str): Python source text.

    Returns:
        str: The source with every matching guard replaced by ``main()``;
        unchanged when no guard is present.
    """
    # \1 forces the closing quote to match the opening one. Only horizontal
    # whitespace is allowed between tokens, so the match never spans
    # unrelated lines.
    guard_pattern = r'''if[ \t]+__name__[ \t]*==[ \t]*(["'])__main__\1[ \t]*:[ \t]*\n[ \t]+main\(\)'''
    return re.sub(guard_pattern, 'main()', code_str)

def load_and_infer_from_jsonl_base_exec(prompt_path, result_path, num_threads=5):
    """
    Execute the code stored in each record of a .jsonl file and save the outcome.

    Every non-blank line of ``prompt_path`` is parsed as a JSON object and one
    line is written to ``result_path`` for it:

    * Records whose ``require_opt`` field equals the string 'false' are copied
      through unchanged.
    * Records with an empty or missing ``first_code`` get
      ``execution = 'no effective code'``.
    * Otherwise ``first_code`` is substituted into the ``exec_string`` template
      (placeholders ``[[final_code]]`` and ``[[data_path]]``), cleaned up,
      checked for exit calls and executed. ``execution`` receives the captured
      stdout, or the captured stderr when stdout is empty, or a message
      describing why the code was not run / failed.

    Args:
        prompt_path (str): Input .jsonl file.
        result_path (str): Output .jsonl file (overwritten).
        num_threads (int): Accepted for interface compatibility; the records
            are processed sequentially and this value is not used.
    """
    # Count lines up front so tqdm can show a total. The context manager makes
    # sure this extra handle is closed again.
    with open(prompt_path, 'r', encoding='utf-8') as line_counter:
        total_lines = sum(1 for _ in line_counter)

    with open(prompt_path, 'r', encoding='utf-8') as infile, \
            open(result_path, 'w', encoding='utf-8') as outfile:
        for line in tqdm(infile, total=total_lines, desc="running"):
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not line.strip():
                continue

            data = json.loads(line)
            data_path = data.get('data_path', '')

            if data.get('require_opt', '') == 'false':
                outfile.write(json.dumps(data, ensure_ascii=False) + '\n')
                continue

            final_code = data.get('first_code', '')   # rag or training

            if final_code:
                final_code = remove_intermediate_prints(final_code)  # reduce unfair print
                final_code = exec_string.replace('[[final_code]]', final_code)
                final_code = final_code.replace('[[data_path]]', data_path)

                final_code = remove_print_df_head(final_code)
                final_code = modify_main_function(final_code)

                # Refuse to run code that could terminate this process.
                if contains_exit_statements(final_code):
                    data['execution'] = 'Code contains unsafe exit statements.'
                else:
                    try:
                        result_string, errors_string = exec_(final_code)
                        # Prefer stdout; fall back to the error text when nothing was printed.
                        data['execution'] = result_string if result_string else errors_string
                    except Exception as exec_exception:
                        data['execution'] = f'Execution error: {exec_exception}'
            else:
                data['execution'] = 'no effective code'

            outfile.write(json.dumps(data, ensure_ascii=False) + '\n')
  
def inference():
    """Parse the command-line arguments and run the execution pass.

    Flags:
        --prompt_path (required): input .jsonl file.
        --result_path (required): output .jsonl file.
        --num_threads (optional, default 5): forwarded to
            ``load_and_infer_from_jsonl_base_exec``.
    """
    parser = argparse.ArgumentParser(description='Call OpenAI API with specified parameters and configurations.')
    parser.add_argument('--prompt_path', type=str, required=True, help='Path to the input .jsonl file containing prompts.')
    parser.add_argument('--result_path', type=str, required=True, help='Path where the output .jsonl file with results will be saved.')
    # default=5 mirrors the worker's own default, so omitting the flag behaves
    # the same as before while an explicit value is no longer dropped.
    parser.add_argument('--num_threads', type=int, required=False, default=5, help='if your API could be run in parallel')

    args = parser.parse_args()

    load_and_infer_from_jsonl_base_exec(
        args.prompt_path,
        args.result_path,
        num_threads=args.num_threads,
    )
  
# Script entry point: run the command-line driver only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":  
    inference()  
