#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os  
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))  
import argparse  
import json  
from openai import AzureOpenAI  
from tqdm import tqdm
import sqlite3
import pandas as pd
import pdb
import re
from tabulate import tabulate
# Assuming this script is saved in a package and the import below refers to another module within the same package    
from concurrent.futures import ThreadPoolExecutor, as_completed  
from model_api_call import get_chat_response_azure  
from functools import partial  
import timeout_decorator
import ast
import logging

import sys    
from contextlib import contextmanager    
from io import StringIO    
import pandas as pd    
    

from contextlib import contextmanager    
from io import StringIO    
import sqlparse
from sqlparse.sql import Identifier, IdentifierList, Parenthesis
from sqlparse.tokens import Keyword, DML
from tabulate import tabulate

@contextmanager
def capture_output():
    """
    Temporarily redirect sys.stdout and sys.stderr into in-memory buffers.

    Yields:
    - tuple: (stdout_buffer, stderr_buffer), both StringIO objects that
      receive everything written while the context is active.

    The previous streams are put back on exit, even if the body raises.
    """
    saved_streams = (sys.stdout, sys.stderr)
    captured_out = StringIO()
    captured_err = StringIO()
    try:
        sys.stdout = captured_out
        sys.stderr = captured_err
        yield captured_out, captured_err
    finally:
        # Always restore the streams that were active on entry.
        sys.stdout, sys.stderr = saved_streams
        

def extract_code(result):
    """
    Return the body of the first ```python fenced block found in `result`.

    Args:
    - result (str): Text that may contain a ```python fenced block.

    Returns:
    - str or None: The text between the opening fence line and the closing
      fence, or None when no such block is present.
    """
    fence_re = r"```python\n(.*?)(?:\n```|\n#n```)"
    found = re.search(fence_re, result, re.DOTALL)
    return found.group(1) if found else None
    
def extract_sqlite(result):
    """
    Return the body of the first ```sqlite fenced block found in `result`.

    Args:
    - result (str): Text that may contain a ```sqlite fenced block.

    Returns:
    - str or None: The fenced content, or None when no block is present.
    """
    found = re.search(r"```sqlite\n(.*?)(?:\n```|\n#n```)", result, re.DOTALL)
    if found is None:
        return None
    return found.group(1)
    

def extract_code_plan(result):
    """
    Return the body of the first ```code_plan fenced block found in `result`.

    Args:
    - result (str): Text that may contain a ```code_plan fenced block.

    Returns:
    - str or None: The fenced content, or None when no block is present.
    """
    plan_fence = r"```code_plan\n(.*?)(?:\n```|\n#n```)"
    hit = re.search(plan_fence, result, re.DOTALL)
    return hit.group(1) if hit else None

def extract_meta_guideline(result):
    """
    Return the body of the first "```successful plan suggestions:" block.

    Args:
    - result (str): Text that may contain a fenced block whose opening line
      is "```successful plan suggestions:".

    Returns:
    - str or None: The fenced content, or None when no block is present.
    """
    found = re.search(
        r"```successful plan suggestions:\n(.*?)(?:\n```|\n#n```)",
        result,
        re.DOTALL,
    )
    if not found:
        return None
    return found.group(1)

def remove_print_statements(code):
    """
    Blank out lines that consist solely of a single-line print(...) call.

    Args:
    - code (str): Source text to clean.

    Returns:
    - str: The text with each matching print call replaced by an empty string.

    Multi-line print calls and prints that share a line with other code are
    not matched by the pattern.
    """
    print_line = re.compile(r'^\s*print\(.*\)\s*$', re.MULTILINE)
    return print_line.sub('', code)

def get_sample_data_markdown(db_path):
    """
    Render up to three sample rows from every table of a SQLite database.

    Args:
    - db_path (str): Path to the SQLite database file.

    Returns:
    - str: A Markdown document with a heading for the database path and, for
      each table listed in sqlite_master, a heading followed by either a
      Markdown table of the first rows or "No data available".

    The connection is closed even when a query or the rendering raises.
    """
    sections = [f"# Database Path: {db_path}\n\n"]
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")

        for (table_name,) in cursor.fetchall():
            sections.append(f"## Table: {table_name} (3 rows)\n\n")

            # Backticks quote table names that contain spaces or keywords.
            query = f"SELECT * FROM `{table_name}` LIMIT 3"
            df = pd.read_sql_query(query, conn)

            if df.empty:
                sections.append("No data available\n\n")
            else:
                sections.append(df.to_markdown(index=False) + "\n...\n")
    finally:
        conn.close()
    return "".join(sections)

def tsv_to_markdown(path):
    """
    Render the first three rows of a tab-separated file as a pipe table.

    Args:
    - path: Location of the TSV file, passed straight to pandas.read_csv.

    Returns:
    - str: The table text produced by tabulate with the 'pipe' format,
      column names as headers and no index column.
    """
    preview = pd.read_csv(path, sep='\t').head(3)
    return tabulate(preview, headers='keys', tablefmt='pipe', showindex=False)

def contains_exit_statements(code):
    """
    Check if the given code contains any statements that could terminate the script.

    Matches the names exit, quit, _exit, sys.exit and os._exit only when they
    appear as whole identifiers, so text such as ``exit_code``, ``exited`` or
    ``quite`` is no longer reported. Matching is purely textual: the names are
    also found inside comments and string literals.

    Args:
    - code (str): The code to be checked.

    Returns:
    - bool: True if the code contains unsafe exit calls, False otherwise.
    """
    # Longer alternatives come first so the log shows the fullest name found.
    forbidden = re.search(
        r'(?<!\w)(?:sys\.exit|os\._exit|_exit|exit|quit)(?!\w)', code
    )
    if forbidden:
        logging.error("Unsafe call detected: %s", forbidden.group(0))
        return True
    return False

def replace_code_with_placeholder_with_steps(code: str) -> str:
    """
    Keep imports and comments, and replace the code under each comment with
    a '[Fill Your Code]' placeholder.

    Despite the function name, no step prefix is added to the comments; they
    are copied through unchanged.

    Args:
        code (str): The string containing the code.

    Returns:
        str: Import lines and comment lines from the input, with one
        placeholder after each comment run that is directly followed by a
        non-blank code line or that ends the input.
    """
    placeholder = '[Fill Your Code]\n'
    output_lines = []
    # True while the most recent comment has not yet been followed by code.
    pending_comment = False

    for line in code.split('\n'):
        stripped = line.strip()
        if stripped.startswith(('import ', 'from ')):
            # Imports are kept and do not affect the pending-comment state.
            output_lines.append(line)
        elif stripped.startswith('#'):
            output_lines.append(line)
            pending_comment = True
        else:
            # Only the first non-blank code line after a comment produces a
            # placeholder; a blank line cancels the pending comment.
            if pending_comment and stripped:
                output_lines.append(placeholder)
            pending_comment = False

    # In case the last lines are a comment block without following code
    if pending_comment:
        output_lines.append(placeholder)

    return '\n'.join(output_lines)


def remove_print_df_head(code: str) -> str:
    """
    Remove lines that print a DataFrame preview via ``.head(...)``.

    A line is dropped when it contains any of: 'print(df.head())',
    'print(df.head(3))', 'print(df.head(10))', 'print(data.head(3))' or
    'print(data.head(10))'.

    Parameters:
    code (str): The input code string.

    Returns:
    str: The modified code string with the specified lines removed.
    """
    preview_prints = (
        'print(df.head())',
        'print(df.head(3))',
        'print(df.head(10))',
        'print(data.head(3))',
        'print(data.head(10))',
    )
    # Each pattern needs its own membership test; chaining the literals with
    # `or` is always truthy and would keep every line.
    kept_lines = [
        line
        for line in code.split('\n')
        if not any(target in line for target in preview_prints)
    ]
    return '\n'.join(kept_lines)


@contextmanager
def capture_output():
    """
    Temporarily redirect sys.stdout and sys.stderr into in-memory buffers.

    Yields:
    - tuple: (stdout_buffer, stderr_buffer), both StringIO objects that
      receive everything written while the context is active.

    The previous streams are put back on exit, even if the body raises.
    """
    saved_streams = (sys.stdout, sys.stderr)
    captured_out = StringIO()
    captured_err = StringIO()
    try:
        sys.stdout = captured_out
        sys.stderr = captured_err
        yield captured_out, captured_err
    finally:
        # Always restore the streams that were active on entry.
        sys.stdout, sys.stderr = saved_streams
        
@timeout_decorator.timeout(30)
def exec_(code):
    """
    Run a code string and return what it wrote to stdout and stderr.

    The call is wrapped in timeout_decorator.timeout(30).

    Args:
    - code (str): Python source to execute.

    Returns:
    - tuple: (captured stdout text, captured stderr text). When the code
      raises an Exception, "Error executing code: <message>" is appended to
      the stderr text instead of propagating.

    NOTE(review): this runs arbitrary code with the full builtins available;
    it is not a sandbox.
    """
    # A fresh globals dict is passed so functions defined by the code can
    # see each other.
    exec_globals = {"__builtins__": __builtins__}

    with capture_output() as (stdout_buf, stderr_buf):
        try:
            exec(code, exec_globals)
        except Exception as exc:
            # Report the failure through the captured stderr text.
            stderr_buf.write(f"Error executing code: {exc}\n")

    # Read both buffers before closing them.
    captured = (stdout_buf.getvalue(), stderr_buf.getvalue())
    stdout_buf.close()
    stderr_buf.close()
    return captured


def create_message(prompt_string):
    """Build a two-message chat list: a fixed system prompt plus the user prompt."""
    system_message = {"role": "system", "content": "You are a helpful assistant."}
    user_message = {"role": "user", "content": prompt_string}
    return [system_message, user_message]


def extract_formula_plan(result):
    """
    Return the body of the first ```formula_plan fenced block in `result`.

    Args:
    - result (str): Text that may contain a ```formula_plan fenced block.

    Returns:
    - str or None: The fenced content, or None when no block is present.
    """
    found = re.search(r"```formula_plan\n(.*?)(?:\n```|\n#n```)", result, re.DOTALL)
    return None if found is None else found.group(1)
    
def extract_formula_completion(result):
    """
    Return the body of the first ```formula_completion fenced block in `result`.

    Args:
    - result (str): Text that may contain a ```formula_completion block.

    Returns:
    - str or None: The fenced content, or None when no block is present.
    """
    completion_fence = r"```formula_completion\n(.*?)(?:\n```|\n#n```)"
    hit = re.search(completion_fence, result, re.DOTALL)
    if hit:
        return hit.group(1)
    return None
    

def mask_formula(text):
    """
    Replace the formula on every 'Current Formula' / 'Final Formula' line.

    Args:
    - text (str): Multi-line text to mask.

    Returns:
    - str: The text where each line mentioning one of the two labels keeps
      only what precedes its first ':' and is suffixed with
      ': [Fill Your Formula]'. Other lines are unchanged.
    """
    def mask_line(line):
        if 'Current Formula' not in line and 'Final Formula' not in line:
            return line
        # Keep the label (text before the first colon), drop the formula.
        label = line.partition(':')[0]
        return label + ': [Fill Your Formula]'

    return '\n'.join(mask_line(line) for line in text.split('\n'))


def extract_final_formula(text):
    """
    Return the formula from the first 'Final Formula: ...' line in `text`.

    Args:
    - text (str): Multi-line text to search.

    Returns:
    - str: Everything after the first ':' on the first line that contains
      'Final Formula' and a colon, stripped of surrounding whitespace.
      Returns '' when no such line exists.
    """
    for line in text.split('\n'):
        if 'Final Formula' not in line:
            continue
        # partition never raises; a label line without ':' is skipped
        # instead of triggering an IndexError.
        _, colon, formula = line.partition(':')
        if colon:
            return formula.strip()

    return ''  # No final formula found

def extract_final_formula(text):
    """
    Return the formula from the first 'Final Formula: ...' line in `text`.

    Args:
    - text (str): Multi-line text to search.

    Returns:
    - str: Everything after the first ':' on the first line that contains
      'Final Formula' and a colon, stripped of surrounding whitespace.
      Returns '' when no such line exists.
    """
    for line in text.split('\n'):
        if 'Final Formula' not in line:
            continue
        # partition never raises; a label line without ':' is skipped
        # instead of triggering an IndexError.
        _, colon, formula = line.partition(':')
        if colon:
            return formula.strip()

    return ''  # No final formula found

def extract_result_from_anchor(text):
    """
    Return the content of the first <a>...</a> pair, or `text` itself.

    Args:
    - text (str): Text that may contain an <a></a> pair on a single line.

    Returns:
    - str: The text between the first <a> and the next </a>; the unchanged
      input when no such pair is found.
    """
    anchor = re.search(r'<a>(.*?)<\/a>', text)
    if anchor is None:
        return text
    return anchor.group(1)

def extract_code_plan(result):
    """
    Return the body of the first ```code_plan fenced block found in `result`.

    Args:
    - result (str): Text that may contain a ```code_plan fenced block.

    Returns:
    - str: The fenced content, or '' when no block is present.
    """
    plan_fence = r"```code_plan\n(.*?)(?:\n```|\n#n```)"
    hit = re.search(plan_fence, result, re.DOTALL)
    return hit.group(1) if hit else ''
    
    
def mask_sub_sqls(sql_plan):
    """
    Replace parenthesised SELECT sub-queries with '([[FILL YOUR SQL]])'.

    Args:
    - sql_plan (str): SQL text, parsed with sqlparse.parse.

    Returns:
    - str: The statements rendered back to text and concatenated, with each
      masked parenthesis replaced by the placeholder.

    NOTE(review): only Parenthesis, IdentifierList and Identifier groups are
    descended into; sub-queries nested inside other sqlparse group types are
    presumably not reached. Confirm against sqlparse's grouping.
    """
    def is_subquery(parsed):
        # A group counts as a sub-query when one of its direct tokens is
        # the DML keyword SELECT.
        if not parsed.is_group:
            return False
        return any(
            item.ttype is DML and item.value.upper() == 'SELECT'
            for item in parsed.tokens
        )

    def mask_tokens(tokens):
        for token in tokens:
            if isinstance(token, Parenthesis) and is_subquery(token):
                # Swap the whole parenthesised group for a single placeholder.
                token.tokens = [sqlparse.sql.Token(None, '([[FILL YOUR SQL]])')]
            elif isinstance(token, (Parenthesis, IdentifierList, Identifier)):
                mask_tokens(token.tokens)

    statements = sqlparse.parse(sql_plan)
    for statement in statements:
        mask_tokens(statement.tokens)

    return ''.join(map(str, statements))


def mask_code_with_comments(code_string):
    """
    Turn commented code into a plan: comment text followed by placeholders.

    Each run of consecutive comment lines becomes one output line. The first
    comment of a run has its leading '#' removed, whatever its indentation;
    later comments of the run are appended after a space with their '#'
    kept. The first non-comment line after a run, blank or not, becomes
    '[Fill Your Code]'. All other lines are dropped.

    Args:
    - code_string (str): Source code whose comments describe the steps.

    Returns:
    - str: The masked plan, lines joined with newlines.
    """
    masked_code_lines = []
    previous_line_was_comment = False

    for line in code_string.split('\n'):
        stripped_line = line.strip()
        if stripped_line.startswith('#'):
            if previous_line_was_comment:
                masked_code_lines[-1] += f" {stripped_line}"
            else:
                # Slice the stripped line so indented comments lose their
                # '#' too, not just comments starting in column 0.
                masked_code_lines.append(stripped_line[1:].strip())
            previous_line_was_comment = True
        else:
            if previous_line_was_comment:
                masked_code_lines.append("[Fill Your Code]\n")
            previous_line_was_comment = False

    # Ensure the last line is masked if it followed a comment
    if previous_line_was_comment:
        masked_code_lines.append("[Fill Your Code]\n")

    return '\n'.join(masked_code_lines)

def modify_main_function(code_str):
    """
    Replace the ``if __name__ == "__main__":`` guard with a bare ``main()`` call.

    Only the exact text 'if __name__ == "__main__":' followed by a newline
    and a four-space-indented 'main()' is replaced; every occurrence is
    rewritten and any other spelling of the guard is left untouched.

    Args:
    - code_str (str): Python source text.

    Returns:
    - str: The source with the guard replaced by 'main()'.
    """
    guarded_call = 'if __name__ == "__main__":\n    main()'
    direct_call = 'main()'
    return code_str.replace(guarded_call, direct_call)


def extract_meta_guideline(result):
    """
    Return the body of the first "```successful plan suggestions:" block.

    Args:
    - result (str): Text that may contain a fenced block whose opening line
      is "```successful plan suggestions:".

    Returns:
    - str or None: The fenced content, or None when no block is present.
    """
    found = re.search(
        r"```successful plan suggestions:\n(.*?)(?:\n```|\n#n```)",
        result,
        re.DOTALL,
    )
    if not found:
        return None
    return found.group(1)


@contextmanager
def capture_output():
    """
    Temporarily redirect sys.stdout and sys.stderr into in-memory buffers.

    Yields:
    - tuple: (stdout_buffer, stderr_buffer), both StringIO objects that
      receive everything written while the context is active.

    The previous streams are put back on exit, even if the body raises.
    """
    saved_streams = (sys.stdout, sys.stderr)
    captured_out = StringIO()
    captured_err = StringIO()
    try:
        sys.stdout = captured_out
        sys.stderr = captured_err
        yield captured_out, captured_err
    finally:
        # Always restore the streams that were active on entry.
        sys.stdout, sys.stderr = saved_streams
    
def exec_table(code, table_path):
    """
    Execute a code string after filling in its table path, capturing output.

    Args:
    - code (str): Python source containing the placeholder "[[table_path]]".
    - table_path: Path substituted for the placeholder, wrapped in single
      quotes.

    Returns:
    - tuple: (captured stdout text, captured stderr text). When the code
      raises an Exception, "Error executing code: <message>" is written to
      the stderr text instead of propagating.

    NOTE(review): exec() is called without an explicit namespace, so the
    executed code runs against this module's globals and this function's
    locals. A path containing a single quote or backslashes would presumably
    yield a broken string literal; confirm against callers.
    """
    # Replace the placeholder in the code string with the actual table path
    code = code.replace("[[table_path]]", f"'{table_path}'")

    # Use the custom context manager to capture output and errors
    with capture_output() as (out, err):
        try:
            # Execute the modified code
            exec(code)
        except Exception as e:
            # Write any exceptions to the err buffer
            err.write(f"Error executing code: {e}\n")

    # Get the content from both buffers
    result_output = out.getvalue()
    result_errors = err.getvalue()

    # Close the buffers
    out.close()
    err.close()

    # Return the captured output and errors
    return result_output, result_errors


@timeout_decorator.timeout(30)
def exec_(code):
    """
    Run a code string and return what it wrote to stdout and stderr.

    The call is wrapped in timeout_decorator.timeout(30).

    Args:
    - code (str): Python source to execute.

    Returns:
    - tuple: (captured stdout text, captured stderr text). When the code
      raises an Exception, "Error executing code: <message>" is appended to
      the stderr text instead of propagating.

    NOTE(review): this runs arbitrary code with the full builtins available;
    it is not a sandbox.
    """
    # A fresh globals dict is passed so functions defined by the code can
    # see each other.
    exec_globals = {"__builtins__": __builtins__}

    with capture_output() as (stdout_buf, stderr_buf):
        try:
            exec(code, exec_globals)
        except Exception as exc:
            # Report the failure through the captured stderr text.
            stderr_buf.write(f"Error executing code: {exc}\n")

    # Read both buffers before closing them.
    captured = (stdout_buf.getvalue(), stderr_buf.getvalue())
    stdout_buf.close()
    stderr_buf.close()
    return captured


def contains_exit_statements(code):
    """
    Check if the given code contains any statements that could terminate the script.

    Matches the names exit, quit, _exit, sys.exit and os._exit only when they
    appear as whole identifiers, so text such as ``exit_code``, ``exited`` or
    ``quite`` is no longer reported. Matching is purely textual: the names are
    also found inside comments and string literals.

    Args:
    - code (str): The code to be checked.

    Returns:
    - bool: True if the code contains unsafe exit calls, False otherwise.
    """
    # Longer alternatives come first so the log shows the fullest name found.
    forbidden = re.search(
        r'(?<!\w)(?:sys\.exit|os\._exit|_exit|exit|quit)(?!\w)', code
    )
    if forbidden:
        logging.error("Unsafe call detected: %s", forbidden.group(0))
        return True
    return False


def remove_intermediate_prints(code: str) -> str:
    """
    Drop every line containing 'print(' except the last such line.

    Args:
    - code (str): Source text.

    Returns:
    - str: The remaining lines joined with '\\n'. Input with zero or one
      print line keeps all of its lines.
    """
    source_lines = code.splitlines()
    print_positions = [
        pos for pos, text in enumerate(source_lines) if 'print(' in text
    ]
    # Everything but the final print line is discarded.
    discarded = set(print_positions[:-1])
    return '\n'.join(
        text for pos, text in enumerate(source_lines) if pos not in discarded
    )

@timeout_decorator.timeout(30)
def exec_(code):
    """
    Run a code string and return what it wrote to stdout and stderr.

    The call is wrapped in timeout_decorator.timeout(30).

    Args:
    - code (str): Python source to execute.

    Returns:
    - tuple: (captured stdout text, captured stderr text). When the code
      raises an Exception, "Error executing code: <message>" is appended to
      the stderr text instead of propagating.

    NOTE(review): this runs arbitrary code with the full builtins available;
    it is not a sandbox.
    """
    # A fresh globals dict is passed so functions defined by the code can
    # see each other.
    exec_globals = {"__builtins__": __builtins__}

    with capture_output() as (stdout_buf, stderr_buf):
        try:
            exec(code, exec_globals)
        except Exception as exc:
            # Report the failure through the captured stderr text.
            stderr_buf.write(f"Error executing code: {exc}\n")

    # Read both buffers before closing them.
    captured = (stdout_buf.getvalue(), stderr_buf.getvalue())
    stdout_buf.close()
    stderr_buf.close()
    return captured

def create_message(prompt_string):
    """Build a two-message chat list: a fixed system prompt plus the user prompt."""
    system_message = {"role": "system", "content": "You are a helpful assistant."}
    user_message = {"role": "user", "content": prompt_string}
    return [system_message, user_message]


def remove_comments(code: str) -> str:
    """
    Strip '#' comments and drop lines that end up blank.

    Everything from the first '#' on a line is removed, including a '#' that
    sits inside a string literal, and the remainder is right-stripped. Lines
    without '#' are kept as they are unless they are whitespace-only.

    Args:
    - code (str): Source text.

    Returns:
    - str: The surviving lines joined with newlines.
    """
    surviving = []
    for raw_line in code.split('\n'):
        before_hash, hash_found, _ = raw_line.partition('#')
        # Only lines that actually contained '#' are trimmed.
        text = before_hash.rstrip() if hash_found else raw_line
        if text.strip():
            surviving.append(text)
    return '\n'.join(surviving)

def extract_action(text):
    """
    Return the stripped content of the first <action>...</action> pair.

    Args:
    - text (str): Text that may contain an <action> element; the content may
      span several lines.

    Returns:
    - str: The content without surrounding whitespace, or '' when no pair
      is found.
    """
    found = re.search(r'<action>(.*?)</action>', text, re.DOTALL)
    return found.group(1).strip() if found else ''
    

def extract_plan(code_string):
    """
    Collect the '# Step ...' comment lines of a code string into a plan.

    Args:
    - code_string (str): Source text whose step comments start with '# Step'
      after leading whitespace is ignored.

    Returns:
    - str: One line per step comment, with the leading '# ' removed, joined
      with newlines. Empty when there are no step comments.
    """
    trimmed_lines = (raw.strip() for raw in code_string.split('\n'))
    steps = [
        entry[2:].strip()  # drop the leading '# '
        for entry in trimmed_lines
        if entry.startswith('# Step')
    ]
    return "\n".join(steps)


def extract_score(text):
    """
    Return the stripped content of the first <s>...</s> pair.

    Args:
    - text (str): Text that may contain an <s> element; the content may
      span several lines.

    Returns:
    - str: The content without surrounding whitespace, or '' when no pair
      is found.
    """
    found = re.search(r'<s>(.*?)</s>', text, re.DOTALL)
    if found is None:
        return ''
    return found.group(1).strip()