# Copyright 2024 The Chain-of-Table authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import copy
import re
from tqdm import tqdm
import numpy as np
import os
import traceback
import shutil
import ast
import logging
import os

from operations import *
from utils.helper import *
from utils.prompts import *

from collections import Counter
from sqlalchemy.sql import text
import multiprocessing as mp
from utils.helper import table2string
from collections import defaultdict
import pickle

#################################################################################################################################################################
#################################################################################################################################################################
#################################################################################################################################################################

import ast

def clean_element(element):
    """Strip whitespace and surrounding double quotes from one answer element.

    If exactly one end carries a double quote (a dangling quote), only that
    single character is removed. Otherwise every leading/trailing double
    quote is stripped.

    Args:
        element: raw element text.

    Returns:
        The cleaned string.
    """
    text = element.strip()
    opens = text.startswith('"')
    closes = text.endswith('"')
    if opens == closes:
        # Quotes on both ends (or on neither): drop all quote chars at the ends.
        return text.strip('"')
    # Exactly one side is quoted: remove just that dangling quote.
    return text[1:] if opens else text[:-1]

def merge_elements(elements):
    """Clean a list of answer elements, re-joining one split quoted value.

    When the list has exactly two items and the first starts with a double
    quote while the second ends with one, the two halves are rejoined with
    ", " and cleaned as a single element. Otherwise every element is cleaned
    individually with ``clean_element``.

    Args:
        elements: sequence of element strings.

    Returns:
        list[str] of cleaned elements.
    """
    should_merge = (
        len(elements) == 2
        and elements[0].startswith('"')
        and elements[1].endswith('"')
    )
    if should_merge:
        head, tail = elements[0], elements[1]
        return [clean_element(head + ', ' + tail)]
    return list(map(clean_element, elements))

def process_string(line):
    """Normalise a raw answer string into a list of cleaned strings.

    - A string that looks like a Python list literal (e.g. ``"['2', '5']"``)
      is parsed with ``ast.literal_eval``. Every element is converted to
      ``str`` and cleaned by ``merge_elements``.
    - A bracketed string that is not a valid literal (e.g.
      ``"[independent]"``) and contains exactly one ``[`` and one ``]``
      becomes a single element with the brackets removed.
    - Anything else becomes a single element cleaned by ``clean_element``.

    Args:
        line: raw answer text.

    Returns:
        list[str]; empty for the literal ``"[]"``.
    """
    line = line.strip()
    if not (line.startswith('[') and line.endswith(']')):
        # Plain (non-list) answers are treated as one element.
        return [clean_element(line)]

    try:
        parsed = ast.literal_eval(line)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # literal_eval documents all of these for malformed input. TypeError,
        # for instance, comes from unhashable dict keys / set items such as
        # "[{[1]: 2}]", and previously escaped this handler and crashed.
        if line.count('[') == 1 and line.count(']') == 1:
            return [line.strip('[]').strip()]
        return [clean_element(line)]

    return merge_elements([str(elem) for elem in parsed])

def is_list_of_strings(variable):
    """Return True if ``variable`` is a list whose items are all ``str``.

    An empty list counts as a list of strings. Tuples and other sequences
    are rejected.
    """
    if not isinstance(variable, list):
        return False
    for item in variable:
        if not isinstance(item, str):
            return False
    return True

# WikiTQ fallback: answer the question directly with the LLM via wikitq_simple_query (no SQL).
def wikitq_fall_back(fb_table, sample, llm):
    """Answer a WikiTQ question directly with the LLM, without SQL.

    Runs ``wikitq_simple_query`` on ``fb_table``, lower-cases the first
    returned answer, strips trailing periods, and post-processes it with
    ``process_string``.

    Args:
        fb_table: table (list of rows) to query.
        sample: the WikiTQ sample being answered.
        llm: model wrapper passed through to ``wikitq_simple_query``.

    Returns:
        A one-element list wrapping the list produced by ``process_string``
        (i.e. a single "row" of answer strings).
    """
    query_result = wikitq_simple_query(
        sample,
        {"table_text": fb_table},
        llm,
        debug=False,
        use_demo=False,
        llm_options=None,
    )

    first_step = query_result["chain"][0]
    raw_answer = first_step['parameter_and_conf'][0][0]
    raw_answer = raw_answer.lower().rstrip('.')

    print(type(raw_answer))
    print('wikitq answer for fall back before:\n', raw_answer)

    parsed_answer = process_string(raw_answer)
    print('wikitq answer for fall back after:\n', parsed_answer)

    return [parsed_answer]

def wikitq_natural_language_plan_construct_prompt(sample, intermediate_table, action, table_name, statement):
    """Build the WikiTQ prompt that turns one natural-language plan step into SQL.

    Args:
        sample: the WikiTQ sample (not read here; kept for interface parity
            with the TabFact variant).
        intermediate_table: current table as a list of rows; row 0 is the
            header.
        action: natural-language description of the step to convert.
        table_name: table name the SQL must select FROM. This should be the
            same name the SQL executor registers the table under. (Previously
            this argument was ignored and 'table_sql' was hard-coded.)
        statement: unused; kept for interface compatibility.

    Returns:
        The prompt string.
    """
    # " or "-joined header names; for a single column this is just its name.
    existing_cols = " or ".join(f"{item}" for item in intermediate_table[0])

    parts = [
        wikitq_natural_language_step_demo,
        "\n####\n",
        "Given this table:\n",
        "/*\n" + table2string(intermediate_table) + "\n*/\n",
    ]

    data_type = table2df(intermediate_table).dtypes

    parts.append("\nData types of columns:\n")
    for col, dtype in data_type.items():
        # pandas reports text columns as "object"; show them as "string" to the LLM.
        dtype_str = "string" if dtype == "object" else str(dtype)
        parts.append(f"- {col}: {dtype_str}\n")

    parts.append(f"\nWrite a SQL command that: {action}\n")

    parts.append("\n\nConstraints for your SQL:\n")

    parts.append("\n1.The columns used in your SQL MUST be: ")
    parts.append(f"{existing_cols}.")
    parts.append("\n Otherwise, you will be PENALIZED!")

    parts.append(f"\n2.{syntax_instr1} If adding new columns, they MUST be different than existing columns {existing_cols}")

    parts.append("\n3.Your SQL command MUST be compatible and executable by python sqlite3.")
    parts.append(f"\n4.If using FROM, the table to be selected MUST be {table_name}.")
    parts.append(f"\n5.You MUST look at the cell contents in {table_name} and consider data format to avoid problems of exact matchings in SQL. Sometimes, data in the action and Table are not in the same format.")
    parts.append("\n6.If there are conflicting data types between the table and the natural language action, you should convert them into the same data type. Note: 'object' data type corresponds to 'string' data type.")

    parts.append("\n####\n")
    return "".join(parts)

def wikitq_natural_language_chain_exec_one_sample(sample, llm, llm_options=None, strategy="top", debug=False):
    """Answer one WikiTQ question by executing LLM-generated plans as SQL.

    Plans come from ``wikitq_generate_natural_language_planning``. For every
    plan, each step is converted to SQL by the LLM and executed on the
    current table. After the plan loop, the last plan's final table is
    inspected. If a step failed, or the table has no data rows, or its first
    data row contains None, the answer comes from ``wikitq_fall_back`` on the
    original table instead.

    NOTE(review): every plan is executed, but only the last plan's outcome
    is used to build the answer. With K_plans > 1 the earlier plans' work is
    discarded — confirm this is intended.

    Args:
        sample: WikiTQ sample with 'statement', 'table_text', 'answer', 'id'.
        llm: model wrapper providing ``generate_plus_with_score``.
        llm_options: options for the per-step SQL generation.
        strategy, debug: forwarded to the planner.

    Returns:
        ``(sample_id, final_answer, is_sql_executable, groundtruth, {},
        fall_back_llm)``. ``final_answer`` is the data rows of the final table
        or the fallback answer. On planning failure or an unexpected error it
        is ``'N/A'`` with ``is_sql_executable`` False, and ``fall_back_llm`` is
        None unless it had already been decided.
    """
    table_name = 'table_sql'
    # Reassigned per operation below (question for the last step, else None);
    # the debug prints after the loop show whatever value it ends with.
    question = sample['statement']
    table_text = sample['table_text']
    answer = sample['answer']
    sample_id = sample['id']

    logger, log_filename = wikitq_setup_logger(sample_id)

    original_table = copy.deepcopy(table_text)  # Store the original table
    groundtruth = answer
    is_sql_executable = True
    # Initialised before the try so the except handler can always return it
    # (previously an early exception raised UnboundLocalError there).
    fall_back_llm = None

    try:
        # PLANNING
        plans, plans_generated_successfully = wikitq_generate_natural_language_planning(
            sample, llm=llm, llm_options=llm_options, strategy=strategy, debug=debug
        )

        if not plans or not plans_generated_successfully:
            logger.error('Failed to generate plans or initial executable flag is False!')
            print('ERR2: Failed to generate plans or initial executable flag is False!')
            is_sql_executable = False
            # Same 6-tuple shape as every other return path of this function.
            return sample_id, 'N/A', is_sql_executable, groundtruth, {}, fall_back_llm

        for plan_idx, plan in enumerate(plans):
            intermediate_table = copy.deepcopy(original_table)  # Reset the table for each plan
            all_operations_successful = True
            tag = f'Sample {sample_id} - Plan {plan_idx+1}'

            logger.info('*' * 120)

            logger.info(f'{tag}: Query: {sample["statement"]}')
            logger.info(f'{tag}: Groundtruth: {groundtruth}')
            logger.info(f'{tag}: X-Original table pd: \n{table2df(intermediate_table)}')
            logger.info(f'{tag}: Caption: none')

            logger.info(f'{tag}: Original table: {intermediate_table}')

            print('DB1: Generated plan:')
            for operation in plan:
                print('DB1: operation:', operation)

            for operation_idx, operation in enumerate(plan):
                # Only the last step is told the question itself.
                question = sample["statement"] if operation_idx == len(plan) - 1 else None

                prompt = wikitq_natural_language_plan_construct_prompt(sample, intermediate_table, operation, table_name, question)
                logger.info('#' * 120)
                try:
                    responses = llm.generate_plus_with_score(prompt, options=llm_options, end_str="\n\n")

                    if responses and len(responses[0]) > 0:
                        sql_command = extract_sql_code(responses[0][0])
                    else:
                        logger.error("No responses or unexpected response format.")
                        print('ERR3: No responses or unexpected response format:', responses)
                        continue  # Skip this step; the table is left unchanged.

                    logger.info(f'{tag}: Operation {operation_idx + 1}: {operation}')

                    if SQL_EXECUTOR == 'SQL_ALCHEMY':
                        intermediate_table, selected_indices = transform_table_with_sqlalchemy(intermediate_table, text(sql_command), table_name)
                    else:
                        intermediate_table, selected_indices = transform_table_with_sql(intermediate_table, sql_command, table_name)
                    logger.info(f'{tag}: Selected indices: {selected_indices}')

                    logger.info(f'{tag}: X-Table after operation df:\n{table2df(intermediate_table)}')
                    logger.info(f'{tag}: Table after operation: {(intermediate_table)}')

                except Exception as e:
                    logger.error(f"SQL execution error in operation {operation_idx + 1}: {e}")
                    print(f'ERR4: SQL execution error in operation {operation_idx + 1}: {e}')
                    print(traceback.format_exc())  # Print the detailed traceback information
                    all_operations_successful = False
                    break

        # Post-processing uses the state left by the LAST plan (see NOTE above).
        # The fallback always answers from the ORIGINAL table.
        fall_back_llm = True
        if all_operations_successful:
            # Remove the header for the final answer
            final_answer = intermediate_table[1:]

            print('formatted answer:', final_answer)

            if len(final_answer) == 0 or None in final_answer[0]:
                print('empty final ans:', final_answer)
                final_answer = wikitq_fall_back(original_table, sample, llm)

                print('DB1: Question:', question)
                print('DB1: Answer:', answer)
                print('WIKITQ final answer for fall back 3:\n', final_answer)
                logger.info('Fall-back: TRUE')
            else:
                logger.info('Fall-back: FALSE')
                fall_back_llm = False

            logger.info(f'Answer from plan {plan_idx + 1}: {final_answer}')
            logger.info(f'Groundtruth: {groundtruth}')

        else:
            logger.error("Intermediate table does not have the expected structure.")
            final_answer = wikitq_fall_back(original_table, sample, llm)

            logger.info('Fall-back: TRUE')
            logger.info(f'Answer from plan {plan_idx + 1}: {final_answer}')
            logger.info(f'Groundtruth: {groundtruth}')

            print('DB1: Question:', question)
            print('DB1: Answer:', answer)
            print('WIKITQ final answer for fall back 2:\n', final_answer)

        return sample_id, final_answer, is_sql_executable, groundtruth, {}, fall_back_llm

    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        print(f'ERR1: Unexpected error occurred: {e}')
        print(traceback.format_exc())  # Print the detailed traceback information
        is_sql_executable = False
        return sample_id, 'N/A', is_sql_executable, groundtruth, {}, fall_back_llm


def wikitq_generate_natural_language_planning(sample, debug=False, llm=None, llm_options=None, strategy="top"):
    """Ask the LLM for step-by-step natural-language plans answering a WikiTQ question.

    Args:
        sample: dict with 'statement' (the question) and 'table_text'
            (list of rows, header first).
        debug, strategy: unused; kept for interface compatibility.
        llm: model wrapper providing ``get_model_options`` and
            ``generate_plus_with_score``.
        llm_options: base sampling options. They are copied before ``n`` /
            ``temperature`` / ``top_p`` are overridden, so the caller's dict
            is left untouched.

    Returns:
        ``(plans, ok)``. ``plans`` is a list of step lists (one per sampled
        response, highest score first) and ``ok`` is True. If the LLM call
        raises, returns ``(None, False)``.
    """
    # Copy the options: callers reuse their llm_options for the per-step SQL
    # generation, which must not inherit the planning overrides below.
    base_options = llm.get_model_options() if llm_options is None else llm_options
    llm_options = copy.copy(base_options)
    llm_options["n"] = K_plans  # Request multiple responses for a single prompt

    if llm_options["n"] > 1:
        llm_options["temperature"] = 0.8
        llm_options["top_p"] = 1.0

    statement = sample['statement']
    table_text = sample['table_text']
    num_rows = len(table_text) - 1  # excluding the header row

    prompt = ""

    prompt += wikitq_natural_language_plan_demo + "\n"

    prompt += "\n### Here come to your task!\n"
    # We dont have table caption in WikiTQ
    prompt += "/*\n" + table2string(table_text) + "\n*/\n"
    prompt += f"This Table has {num_rows} rows.\n"
    prompt += "Question: " + statement + "\n"

    prompt += """
Let's develop a precise and detailed step-by-step plan to answer the given Question based on the provided Table.

Your steps will later be converted to SQL commands to transform the Table into the final answer for the question. You MUST thoroughly analyze and understand the Question before writing the plan.

Plan Steps:
1. Each step in your plan should be atomic and straightforward, ensuring they can be easily executed or converted into SQL.
2. You MUST closely examine the Question and ensure all conditions are checked accurately.

Step Order:
1. The order of steps is crucial! Ensure the steps logically support the correct answering.
2. Each step will be executed sequentially, with the next step operating on the output table of the previous step. The first step will be executed on the given Table.

Final Step:
Ensure the last step involves selecting the relevant cells or calculating the values that correctly answer the Question.

Plan:\n
    """

    try:
        responses = llm.generate_plus_with_score(
            prompt, options=llm_options, end_str="\n\n"
        )
    except Exception as e:
        print('ERR1: Cannot generate plans:', (e))
        return None, False

    # Highest-scoring response first.
    responses.sort(key=lambda x: x[1], reverse=True)
    plans = [plan_to_step_list(response) for response, _score in responses]
    return plans, True




##################################################################################################################
##################################################################################################################
##################################################################################################################
##################################################################################################################

# TabFact fallback: ask the LLM directly via simple_query and map its yes/no reply to TRUE/FALSE.
def fall_back(fb_table, sample, llm):
    """Verify a TabFact statement directly with the LLM, without SQL.

    Runs ``simple_query`` on ``fb_table`` and maps the first returned answer
    to a verdict. Surrounding whitespace and trailing periods are ignored, so
    "Yes." or " no " are still recognised; previously such replies became
    'N/A'.

    Args:
        fb_table: table (list of rows) to query.
        sample: the TabFact sample being verified.
        llm: model wrapper passed through to ``simple_query``.

    Returns:
        'TRUE' for a yes reply, 'FALSE' for a no reply, otherwise 'N/A'.
    """
    sample_copy = simple_query(
        sample, {"table_text": fb_table}, llm,
        debug=False, use_demo=False, llm_options=None,
    )

    print('answer for fall back:\n', sample_copy)

    raw_answer = sample_copy["chain"][0]['parameter_and_conf'][0][0]
    normalized = raw_answer.strip().rstrip('.').strip().upper()

    return {'YES': 'TRUE', 'NO': 'FALSE'}.get(normalized, 'N/A')

def natural_language_plan_construct_prompt(sample, intermediate_table, action, table_name, statement):
    """Build the TabFact prompt that turns one natural-language plan step into SQL.

    Args:
        sample: the TabFact sample; only used to count the rows of the
            table returned by ``get_table_info``.
        intermediate_table: current table as a list of rows; row 0 is the
            header.
        action: natural-language description of the step to convert.
        table_name: table name the SQL must select FROM.
        statement: unused; kept for interface compatibility.

    Returns:
        The prompt string.
    """
    header = intermediate_table[0]
    if len(header) > 1:
        existing_cols = " or ".join(f"{col_name}" for col_name in header)
    else:
        existing_cols = f"{header[0]}"

    sections = [
        natural_language_step_demo,
        "\n####\n",
        "Given this table:\n",
        "/*\n" + table2string(intermediate_table) + "\n*/\n",
    ]

    column_types = table2df(intermediate_table).dtypes

    sections.append("\nData types of columns:\n")
    # pandas reports text columns as "object"; present them as "string".
    sections.extend(
        f"- {col}: {'string' if dtype == 'object' else str(dtype)}\n"
        for col, dtype in column_types.items()
    )

    sections.append(f"\nWrite a SQL command that: {action}\n")

    original_rows = len(get_table_info(sample)["table_text"]) - 1
    sections.append(f"The original table has {original_rows} rows.\n")

    sections.append("\n\nConstraints for your SQL:\n")
    sections.append(f"\n1.{syntax_instr1} If adding new columns, they should be different than columns {existing_cols}")
    sections.append("\n2. Your SQL command MUST be compatible and executable by python sqlite3 and pandas.")
    sections.append(f"\n3. If using FROM, the table to be selected MUST be {table_name}.")

    sections.append("\n####\n")
    return "".join(sections)

def _tabfact_correctness_dir(final_answer, groundtruth):
    """Return the TP/TN/FP/FN bucket name, or None if the answer is neither TRUE nor FALSE."""
    if final_answer == "TRUE":
        return "TP" if groundtruth == "TRUE" else "FP"
    if final_answer == "FALSE":
        return "TN" if groundtruth == "FALSE" else "FN"
    return None


def natural_language_chain_exec_one_sample(sample, llm, llm_options=None, strategy="top", debug=False):
    """Verify one TabFact statement by executing LLM-generated plans as SQL.

    Plans come from ``generate_natural_language_planning``. For each plan,
    every step is converted to SQL by the LLM and executed on the current
    table. A step that yields a two-row 'verification_result' /
    'comparison_result' table ends the plan early.

    A plan's answer is the first cell of the first data row if it is
    "TRUE"/"FALSE". Otherwise — or when a step failed to execute — the
    answer comes from ``fall_back`` on the original table. The final answer
    is the majority vote over all plans. When it is TRUE or FALSE, the
    sample's log file is moved into a TP/TN/FP/FN sub-directory next to it.

    Args:
        sample: TabFact sample with 'id', 'statement', 'label',
            'table_caption' and the fields ``get_table_info`` reads.
        llm: model wrapper providing ``generate_plus_with_score``.
        llm_options: options for the per-step SQL generation.
        strategy, debug: forwarded to the planner.

    Returns:
        ``(sample_id, final_answer, is_sql_executable, groundtruth,
        result_dict, fall_back_llm)``. ``result_dict`` holds the vote counts
        and ``fall_back_llm`` reflects the last plan only. On planning
        failure or an unexpected error returns
        ``(sample_id, 'N/A', False, groundtruth, {}, None)``.
    """
    logger, log_filename = setup_logger(sample["id"])
    table_info = get_table_info(sample)
    table_name = 'table_sql'
    sample_id = sample["id"]
    table_caption = sample['table_caption']

    original_table = copy.deepcopy(table_info["table_text"])  # Store the original table
    groundtruth = "TRUE" if sample["label"] == 1 else "FALSE"
    results = []

    try:
        # PLANNING
        plans, plans_generated_successfully = generate_natural_language_planning(
            sample, llm=llm, llm_options=llm_options, strategy=strategy, debug=debug
        )

        if not plans or not plans_generated_successfully:
            logger.error('Failed to generate plans or initial executable flag is False!')
            print('ERR2:Failed to generate plans or initial executable flag is False!')

            return sample_id, 'N/A', False, groundtruth, {}, None

        for plan_idx, plan in enumerate(plans):
            intermediate_table = copy.deepcopy(original_table)  # Reset the table for each plan
            all_operations_successful = True
            tag = f'Sample {sample_id} - Plan {plan_idx+1}'

            logger.info('*' * 120)
            logger.info(f'{tag}: Statement: {sample["statement"]}')
            logger.info(f'{tag}: Groundtruth: {groundtruth}')
            logger.info(f'{tag}: X-Original table pd: \n{table2df(intermediate_table)}')
            logger.info(f'{tag}: Caption: {table_caption}')
            logger.info(f'{tag}: Original table: {intermediate_table}')

            for operation_idx, operation in enumerate(plan):
                # Only the last step is given the statement.
                statement = sample["statement"] if operation_idx == len(plan) - 1 else None

                prompt = natural_language_plan_construct_prompt(sample, intermediate_table, operation, table_name, statement)
                logger.info('#' * 120)
                try:
                    # PLAN TO SQL
                    responses = llm.generate_plus_with_score(prompt, options=llm_options, end_str="\n\n")

                    if responses and len(responses[0]) > 0:
                        sql_command = extract_sql_code(responses[0][0])
                    else:
                        logger.error("No responses or unexpected response format.")
                        print('ERR3: No responses or unexpected response format:', responses)
                        continue  # Skip this step; the table is left unchanged.

                    logger.info(f'{tag}: Operation {operation_idx + 1}: {operation}')

                    if SQL_EXECUTOR == 'SQL_ALCHEMY':
                        intermediate_table, selected_indices = transform_table_with_sqlalchemy(intermediate_table, text(sql_command), table_name)
                    else:
                        intermediate_table, selected_indices = transform_table_with_sql(intermediate_table, sql_command, table_name)
                    logger.info(f'{tag}: Selected indices: {selected_indices}')

                    logger.info(f'{tag}: X-Table after operation df:\n{table2df(intermediate_table)}')
                    logger.info(f'{tag}: Table after operation: {(intermediate_table)}')

                    # A single verdict row means the plan has already answered.
                    if len(intermediate_table) == 2 and intermediate_table[0][0] in ('verification_result', 'comparison_result'):
                        break
                except Exception as e:
                    logger.error(f"SQL execution error in operation {operation_idx + 1}: {e}")
                    print(f'{tag}: Operation {operation_idx + 1}: {operation}')
                    print(f'ERR4: SQL execution error in operation {operation_idx + 1}: {e}')
                    # Mark the plan as failed so it takes the "SQL failed"
                    # fallback below (previously this flag was never cleared,
                    # which made that branch unreachable).
                    all_operations_successful = False
                    break

            fall_back_llm = True
            if all_operations_successful:
                if len(intermediate_table) > 1 and len(intermediate_table[1]) > 0:
                    answer = intermediate_table[1][0]
                    if answer in ["TRUE", "FALSE"]:
                        logger.info(f'Answer from plan {plan_idx + 1}: {answer}')
                        logger.info(f'Groundtruth: {groundtruth}')
                        fall_back_llm = False
                        logger.info('Fall-back: FALSE')
                    else:
                        # SQL ran but the answer is not in the right format.
                        answer = fall_back(original_table, sample, llm)
                        print('final answer for fall back 3:\n', answer)
                        logger.info('Fall-back: TRUE')

                        logger.info(f'Answer from plan {plan_idx + 1}: {answer}')
                        logger.info(f'Groundtruth: {groundtruth}')
                else:
                    # SQL ran but produced no data cell to read an answer from.
                    answer = fall_back(original_table, sample, llm)
                    logger.info('Fall-back: TRUE')

                    logger.info(f'Answer from plan {plan_idx + 1}: {answer}')
                    logger.info(f'Groundtruth: {groundtruth}')
                    print('final answer for fall back 2:\n', answer)
            else:
                # A SQL step failed: answer from the original table instead.
                answer = fall_back(original_table, sample, llm)
                logger.info('Fall-back: TRUE')

                logger.info(f'Answer from plan {plan_idx + 1}: {answer}')
                logger.info(f'Groundtruth: {groundtruth}')
                print('final answer for fall back 1:\n', answer)

            results.append(answer)

        if results:
            # Majority vote
            result_counter = Counter(results)
            final_answer = result_counter.most_common(1)[0][0]
            is_sql_executable = True
            result_dict = dict(result_counter)
        else:
            final_answer = 'N/A'
            is_sql_executable = False
            print('it comes here')
            result_dict = {}

        # Bucket the log by correctness. An answer that is neither TRUE nor
        # FALSE (e.g. 'N/A' from fall_back) has no bucket and its log is left
        # in place; previously this raised UnboundLocalError and the result
        # was discarded.
        correctness_dir = _tabfact_correctness_dir(final_answer, groundtruth)
        if correctness_dir is not None:
            log_directory = os.path.dirname(log_filename)
            target_directory = os.path.join(log_directory, correctness_dir)
            os.makedirs(target_directory, exist_ok=True)
            shutil.move(log_filename, os.path.join(target_directory, os.path.basename(log_filename)))

        return sample_id, final_answer, is_sql_executable, groundtruth, result_dict, fall_back_llm

    except Exception as e:
        logger.error(f"Unexpected error in generating plans: {e}")
        print(f'ERR5: Planning failed!: {e}')

        return sample_id, 'N/A', False, groundtruth, {}, None

def generate_natural_language_planning(
    sample,
    debug=False,
    llm=None,
    llm_options=None,
    strategy="top",
):
    """Ask the LLM for step-by-step plans verifying a TabFact statement.

    Args:
        sample: TabFact sample with 'statement', 'table_caption' and the
            fields ``get_table_info`` reads.
        debug, strategy: unused; kept for interface compatibility.
        llm: model wrapper providing ``get_model_options`` and
            ``generate_plus_with_score``.
        llm_options: base sampling options. They are copied before ``n`` /
            ``temperature`` / ``top_p`` are overridden, so the caller's dict
            is left untouched.

    Returns:
        ``(plans, ok)``. ``plans`` is a list of step lists (one per sampled
        response, highest score first) and ``ok`` is True. If the LLM call
        raises, returns ``(None, False)``.
    """
    # Copy the options: callers reuse their llm_options for the per-step SQL
    # generation, which must not inherit the planning overrides below.
    base_options = llm.get_model_options() if llm_options is None else llm_options
    llm_options = copy.copy(base_options)
    llm_options["n"] = K_plans  # Request multiple responses for a single prompt

    if llm_options["n"] > 1:
        llm_options["temperature"] = 0.8
        llm_options["top_p"] = 1.0

    table_info = get_table_info(sample)
    caption = sample["table_caption"]
    num_rows = len(table_info["table_text"]) - 1  # excluding the header row

    prompt = ""

    prompt += natural_language_plan_demo + "\n"

    prompt += "\n### Here come to your task!\n"
    prompt += f"table caption: {caption}\n"
    prompt += "/*\n" + table2string(table_info["table_text"]) + "\n*/\n"
    prompt += f"This Table has {num_rows} rows.\n"
    prompt += "Statement: " + sample["statement"] + "\n"

    prompt += """
Let's develop a step-by-step plan to verify if the given Statement is TRUE or FALSE on the given Table!
You MUST think carefully analyze the Statement and comprehend it before writing the plan!

Plan Steps: Each step in your plan should be very atomic and straightforward, ensuring they can be easily executed or converted into SQL.
You MUST make sure all conditions (except those mentioned in the table caption) are checked properly in the steps.

Step order: The order of steps is crucial! You must ensure the orders support the correct information retrieval and verification!
The next step will be executed on the output table of the previous step. The first setp will be executed on the given Table.
For comparative or superlative Statement involving "highest," "lowest," "earliest," "latest," "better," "faster," "earlier," etc., you should order the table accordingly before selecting rows. This ensures that the desired comparative or superlative data is correctly retrieved.

Plan:\n
    """

    try:
        responses = llm.generate_plus_with_score(
            prompt, options=llm_options, end_str="\n\n"
        )
    except Exception as e:
        print('ERR1: Cannot generate plans:', (e))
        return None, False

    # Highest-scoring response first.
    responses.sort(key=lambda x: x[1], reverse=True)
    plans = [plan_to_step_list(response) for response, _score in responses]
    return plans, True


##################################################################################################################
##################################################################################################################
##################################################################################################################
##################################################################################################################


# proc_sample ~= init_samples
def fixed_chain_exec_mp(llm, init_samples, fixed_op_list, n_proc=10, chunk_size=50):
    """Apply a fixed sequence of operations to all samples, stage by stage.

    Each entry of ``fixed_op_list`` is ``(op_name, solver_func, kargs,
    llm_kargs)``. Stage i runs ``conduct_single_solver_mp`` on the output of
    stage i-1 (stage 0 starts from a deep copy of ``init_samples``).

    Returns:
        ``(final_result, history)``. ``final_result`` is the sample list after
        the last stage (None when ``fixed_op_list`` is empty). ``history``
        maps ``"(i) op0->...->op_i"`` to the sample list after stage i.
    """
    history = {}
    current_samples = copy.deepcopy(init_samples)
    chain_key = ""

    if DEBUG:
        print(fixed_op_list)

    for stage_idx, (op_name, solver_func, kargs, llm_kargs) in enumerate(fixed_op_list):
        # Here is where the table manipulation is done
        if stage_idx == 0:
            chain_key += op_name
        else:
            chain_key += f"->{op_name}"

        if DEBUG:
            print(op_name)

        stage_options = llm.get_model_options(**llm_kargs)
        current_samples = conduct_single_solver_mp(
            llm=llm,
            all_samples=current_samples,
            solver_func=solver_func,
            tqdm_tag=op_name,
            n_proc=n_proc,
            chunk_size=chunk_size,
            llm_options=stage_options,
            **kargs,
        )

        history[f"({stage_idx}) {chain_key}"] = current_samples

    final_result = current_samples if len(fixed_op_list) > 0 else None
    return final_result, history


def conduct_single_solver(llm, all_samples, solver_func, tqdm_tag=None, **kwargs):
    """Run ``solver_func`` over every sample sequentially (single process).

    For each sample, ``get_table_info`` is called with the optional
    ``skip_op`` / ``first_n_op`` kwargs, then
    ``solver_func(sample, table_info, llm, **kwargs)``. A sample that raises
    is reported and left as None.

    Returns:
        list with one solver result (or None) per input sample, in order.
    """
    outputs = [None] * len(all_samples)

    for position, current in tqdm(enumerate(all_samples), total=len(all_samples), desc=tqdm_tag):
        try:
            info = get_table_info(
                current,
                skip_op=kwargs.get("skip_op", []),
                first_n_op=kwargs.get("first_n_op", None),
            )
            outputs[position] = solver_func(current, info, llm, **kwargs)
        except Exception as e:
            print(f"CKPT1: Error in {position}th sample: {e}")

    return outputs


def _conduct_single_solver_mp_core(arg):
    idx, sample, llm, solver_func, kwargs = arg
    try:
        table_info = get_table_info(
            sample,
            skip_op=kwargs.get("skip_op", []),
            first_n_op=kwargs.get("first_n_op", None),
        )
        proc_sample = solver_func(sample, table_info, llm, **kwargs)
        return idx, proc_sample
    except Exception as e:
        traceback.print_exc()
        print(f"CKPT: Error in {idx}-th sample: {e}")
        return idx, None


def conduct_single_solver_mp(
    llm, all_samples, solver_func, tqdm_tag=None, n_proc=10, chunk_size=50, **kwargs
):
    """Apply ``solver_func`` to every sample in parallel using a process pool.

    ``imap_unordered`` yields results in completion order, so each result is
    written back at its original index. A sample whose worker failed stays
    ``None``. When the input sample carries an ``is_sql_executable`` flag,
    the flag is copied onto the processed sample so later stages still see it.

    Returns:
        A list of processed samples aligned with ``all_samples``.
    """
    result_samples = [None] * len(all_samples)

    args = [
        (idx, sample, llm, solver_func, kwargs)
        for idx, sample in enumerate(all_samples)
    ]

    with mp.Pool(n_proc) as p:
        for idx, proc_sample in tqdm(
            p.imap_unordered(_conduct_single_solver_mp_core, args, chunksize=chunk_size),
            total=len(all_samples),
            desc=tqdm_tag,
        ):
            result_samples[idx] = proc_sample
            if DEBUG:
                print(proc_sample)
                print(all_samples[idx].get('is_sql_executable'))
            # A failed worker returns None. Indexing it (or reading a flag the
            # input never had) used to raise here, aborting the whole pool run
            # and discarding every other result.
            if proc_sample is not None and 'is_sql_executable' in all_samples[idx]:
                proc_sample['is_sql_executable'] = all_samples[idx]['is_sql_executable']

    return result_samples


def get_act_func(name, using_sql):
    """Return the table-action function for operation ``name``.

    Resolves ``f"{name}_act_sql"`` when ``using_sql is True``, and
    ``f"{name}_act"`` otherwise, among this module's global names (which
    include everything star-imported from ``operations``).

    If no such function exists, returns a no-op action that returns a deep
    copy of its first argument. In that case "Unknown operation" is printed,
    unless ``name`` contains "query".
    """
    func_name = f"{name}_act_sql" if using_sql is True else f"{name}_act"
    # A plain lookup in the module namespace (the same namespace eval() used)
    # avoids evaluating an operation name as code. It also stops a bare
    # except from masking unrelated errors.
    act_func = globals().get(func_name)
    if act_func is not None:
        return act_func

    def _default_act(table_text, *args, **kwargs):
        return copy.deepcopy(table_text)

    if "query" not in name:
        print("Unknown operation: ", name)
    return _default_act


def get_table_info(sample, skip_op=None, first_n_op=None):
    """Replay the sample's operation chain on its original table.

    Args:
        sample: dict with ``"table_text"`` and ``"chain"`` (a list of operation
            dicts, each with an ``"operation_name"``). An optional
            ``"using_sql"`` flag selects the SQL-backed action functions. When
            the flag is absent, the plain ones are used.
        skip_op: forwarded to every action function (defaults to a fresh ``[]``).
        first_n_op: if given, only the first ``first_n_op`` operations are replayed.

    Returns:
        The ``table_info`` dict returned by the last action function. With an
        empty chain this is ``{"table_text": ..., "act_chain": []}``.
    """
    if skip_op is None:
        skip_op = []
    table_text = sample["table_text"]
    chain = sample["chain"]
    # Samples that never went through the SQL path carry no "using_sql" key.
    # Previously that left `using_sql` unbound, so any non-empty chain raised
    # UnboundLocalError. Default to the non-SQL actions, like get_table_log().
    using_sql = sample.get("using_sql", False)

    if first_n_op is not None:
        chain = chain[:first_n_op]

    table_info = {
        "table_text": table_text,
        "act_chain": [],
    }
    for operation in chain:
        act_func = get_act_func(operation["operation_name"], using_sql)
        if DEBUG:
            print(table_info)
            print(operation)
        table_info = act_func(table_info, operation, skip_op=skip_op)

    return table_info


def get_table_log(sample, skip_op=[], first_n_op=None):
    """Replay the sample's chain with the non-SQL actions, recording each step.

    Returns:
        A list whose first element is the initial ``table_info``
        (``{"table_text": ..., "act_chain": []}``). Each following element is
        the ``table_info`` after one replayed operation. Two adjustments are
        made along the way:
        - For operations whose name contains "row", the last ``act_chain``
          entry is replaced by ``table_info["_real_select_rows"]`` when the
          action provided that key.
        - For operations whose name contains "query", ``"<name>()"`` is
          appended to ``act_chain`` and ``operation["parameter_and_conf"][0][0]``
          is stored as ``"cotable_result"``.
    """
    current = {"table_text": sample["table_text"], "act_chain": []}
    operations = sample["chain"]
    if first_n_op is not None:
        operations = operations[:first_n_op]

    snapshots = [current]
    for op in operations:
        name = op["operation_name"]
        current = get_act_func(name, using_sql=False)(current, op, skip_op=skip_op)
        if DEBUG:
            print(name)
        if "row" in name and "_real_select_rows" in current:
            current["act_chain"][-1] = current["_real_select_rows"]
        if "query" in name:
            current["act_chain"].append(f"{name}()")
            current["cotable_result"] = op["parameter_and_conf"][0][0]
        snapshots.append(current)

    return snapshots

def get_operation_name(string):
    """Return the operation name from a call string such as ``f_select_row(...)``.

    Only the first ``f_<name>(...)`` occurrence is used. Raises IndexError
    when ``string`` contains no such call.
    """
    matches = re.findall(r"f_(.*?)\(.*\)", string)
    first = matches[0]
    return first


def get_all_operation_names(string):
    """Parse a ``"f_a(...) -> f_b(...) -> <END>"`` chain into operation names.

    The string is split on "->" and each segment is stripped:
    - A segment equal to "<END>" is kept verbatim.
    - Other segments contribute the name captured from their first
      ``f_<name>(...)`` call.
    - Segments without such a call are dropped.
    """
    if DEBUG:
        print('Here print the operation names:')
        print(string)
    names = []
    for segment in (piece.strip() for piece in string.split("->")):
        if segment == "<END>":
            names.append(segment)
            continue
        found = re.findall(r"f_(.*?)\(.*\)", segment)
        if found:
            names.append(found[0])
    return names

def generate_prompt_for_next_step(
    sample,
    debug=False,
    llm=None,
    llm_options=None,
    strategy="top",
):
    """Choose the next table operation for ``sample``'s chain.

    Building the candidate list:
    - Start from ``possible_next_operation_dict[last_operation]``.
    - Drop operations the chain marked as skipped.
    - In SQL mode, also drop row/sort/group operations when the table has at
      most one data row.

    Choosing among candidates:
    - No candidates left: "<END>" is returned.
    - Exactly one: it is returned without querying the LLM.
    - Otherwise: a few-shot planning prompt is built and the LLM completions
      are parsed according to ``strategy``.

    Args:
        sample: sample dict; needs "statement" plus what ``get_table_info`` reads.
        debug: print the act chains.
        llm: model exposing ``generate_plus_with_score(prompt, options, end_str)``.
        llm_options: options forwarded to that call.
        strategy: "top" takes the first admissible operation of the first
            completion. "voting" sums ``exp(score)`` per admissible operation
            over all completions.

    Returns:
        ``(next_operation, log)``. ``next_operation`` is an operation name or
        "<END>". ``log`` records the candidates, prompt and response used.

    Raises:
        ValueError: if the LLM has to be queried and ``strategy`` is neither
            "top" nor "voting".
    """
    table_info = get_table_info(sample)
    act_chain = table_info["act_chain"]

    if debug:
        print("Act Chain: ", act_chain, flush=True)

    kept_act_chain = [x for x in act_chain if not x.startswith("skip")]
    kept_act_chain_str = " -> ".join(kept_act_chain)
    if kept_act_chain_str:
        kept_act_chain_str += " ->"

    skip_act_chain = [x for x in act_chain if x.startswith("skip")]
    # Skipped entries look like "skip f_xxx(...)"; strip the "skip " prefix.
    skip_act_chain_op_names = [
        get_operation_name(op[len("skip "):]) for op in skip_act_chain
    ]

    if debug:
        print("Kept Act Chain: ", kept_act_chain, flush=True)
        print("Skip Act Chain: ", skip_act_chain, flush=True)

    def _make_log(last_op, candidates, chosen, prompt=None, response=None,
                  generate_operations=None):
        return {
            "act_chain": act_chain,
            "last_operation": last_op,
            "possible_next_operations": candidates,
            "prompt": prompt,
            "response": response,
            "generate_operations": generate_operations,
            "next_operation": chosen,
        }

    # In SQL mode the chain always starts with select_column.
    if USING_SQL is True and not kept_act_chain:
        return "select_column", _make_log("<init>", ["select_column"], "select_column")

    last_operation = (
        "<init>" if not kept_act_chain else get_operation_name(kept_act_chain[-1])
    )
    possible_next_operations = [
        x for x in possible_next_operation_dict[last_operation]
        if x not in skip_act_chain_op_names
    ]

    table_text = table_info["table_text"]
    if DEBUG:
        print('Table text:')
        print(table_text)
        print('Possible next operations:')
        print(possible_next_operations)

    # The table has only one data row besides the header: drop sorting,
    # row selection and grouping in SQL mode.
    if USING_SQL is True and len(table_text) <= 2:
        if DEBUG:
            print('before:', possible_next_operations)
            print('One-row table! Skipping operations')
        possible_next_operations = [
            op for op in possible_next_operations
            if op not in ("sort_column", "select_row", "group_column")
        ]
        if DEBUG:
            print('after:', possible_next_operations)

    if DEBUG:
        print("Last Operation: ", last_operation, flush=True)
        print("Final Possible Next Operations: ", possible_next_operations, flush=True)

    # Nothing admissible is left: end the chain instead of asking the LLM to
    # choose from an empty list. That choice could only yield "<END>", and the
    # SQL error fallback crashed indexing [0] on the empty list.
    if not possible_next_operations:
        return "<END>", _make_log(last_operation, possible_next_operations, "<END>")

    if len(possible_next_operations) == 1:
        only = possible_next_operations[0]
        return only, _make_log(last_operation, possible_next_operations, only)

    if strategy not in ("top", "voting"):
        raise ValueError(
            f"Unknown strategy: {strategy!r} (expected 'top' or 'voting')"
        )

    # Few-shot demos are module-level strings named plan_<op>_demo[_sql].
    demo_suffix = "_sql" if USING_SQL is True else ""
    prompt = ""
    for operation in possible_next_operations:
        if operation == "<END>":
            continue
        prompt += globals()[f"plan_{operation}_demo{demo_suffix}"] + "\n\n"

    prompt += plan_full_demo_simple + "\n\n"

    if USING_SQL:
        prompt += "########\nHere is your actual task.\n"

    prompt += "/*\n" + table2string(table_info["table_text"]) + "\n*/\n"
    prompt += "Statement: " + sample["statement"] + "\n"

    candidates_str = " or ".join(
        [f"f_{op}()" if op != "<END>" else op for op in possible_next_operations]
    )
    # At least two candidates remain here (0 and 1 returned above).
    prompt += f"The next operation must be one of {candidates_str}.\n"
    prompt += "Function Chain: " + kept_act_chain_str
    if DEBUG:
        print('Before prompting to get operation')
        print(prompt)

    try:
        responses = llm.generate_plus_with_score(
            prompt, options=llm_options, end_str="\n\n"
        )
    except Exception:
        # Only the SQL pipeline degrades gracefully here (e.g. the prompt
        # exceeds the context window) by taking the first candidate.
        if USING_SQL is not True:
            raise
        print("Error when prompting model in generating next operation. Maybe the prompt is too long and exceeds context length!")
        first = possible_next_operations[0]
        return first, _make_log(last_operation, possible_next_operations, first)

    if DEBUG:
        print('Model response:')
        print(responses)
        print('*' * 100)

    # Both are recorded in the log; they stay None if there are no completions.
    response = None
    generate_operations = None

    if strategy == "top":
        response = responses[0][0]
        generate_operations = get_all_operation_names(response)
        if DEBUG:
            print('Prompt:', prompt.split("\n\n")[-1])
            print('Response:', response)
            print("Generated Operations: ", generate_operations)
        next_operation = next(
            (op for op in generate_operations if op in possible_next_operations),
            "<END>",
        )
    else:  # "voting"
        next_operation_conf_dict = defaultdict(float)
        # The log keeps the last completion seen, as before.
        for response, score in responses:
            generate_operations = get_all_operation_names(response)
            voted = next(
                (op for op in generate_operations if op in possible_next_operations),
                None,
            )
            if voted:
                next_operation_conf_dict[voted] += np.exp(score)
        if next_operation_conf_dict:
            # max() keeps the first-inserted operation on ties, like a stable sort.
            next_operation = max(next_operation_conf_dict.items(), key=lambda x: x[1])[0]
        else:
            next_operation = "<END>"

    log = _make_log(
        last_operation,
        possible_next_operations,
        next_operation,
        prompt=prompt,
        response=response,
        generate_operations=generate_operations,
    )
    return next_operation, log


def dynamic_chain_exec_one_sample(
    sample,
    llm,
    llm_options=None,
    strategy="top",
    debug=False,
    operation_parameter_dict=None,
):
    """Grow and execute an operation chain for one sample until "<END>".

    Each step does three things:
    - Asks ``generate_prompt_for_next_step`` for the next operation.
    - Looks up ``(op_name, solver_func, kargs, op_llm_options)`` for it in
      ``operation_parameter_dict``.
    - Lets ``solver_func`` generate the operation's arguments on the current
      table state and return the extended sample.

    ``sample`` itself is deep-copied and never mutated. When the working
    sample's ``"using_sql"`` flag is True, ``solver_func`` must return
    ``(sample, is_sql_executable, op_name)``; otherwise it returns the sample
    only. A missing flag means the non-SQL solvers.

    Returns:
        ``(processed_sample, dynamic_chain_log, sql_executable)``.
        ``sql_executable`` is True if at least one SQL step reported that its
        SQL was executable.

    Raises:
        ValueError: if an operation is chosen but ``operation_parameter_dict``
            is None.
    """
    dynamic_chain_log = []
    current_sample = copy.deepcopy(sample)
    # Becomes True once any SQL-backed step reports executable SQL.
    sql_executable = False

    while True:
        next_operation, log = generate_prompt_for_next_step(
            current_sample,
            llm=llm,
            llm_options=llm_options,
            strategy=strategy,
            debug=debug,
        )
        dynamic_chain_log.append(log)
        if DEBUG:
            print('Expanding the chain...')
            print(next_operation)
        if debug:
            print(next_operation)

        if next_operation == "<END>":
            break

        if operation_parameter_dict is None:
            raise ValueError(
                "operation_parameter_dict is required to execute operation "
                f"{next_operation!r}"
            )
        op_name, solver_func, kargs, op_llm_options = operation_parameter_dict[next_operation]
        if DEBUG:
            print('op_name:')
            print(op_name)
            print('Table info BEFORE table manipulation operation:')

        # Replay the chain built so far to obtain the current table state.
        table_info = get_table_info(current_sample)
        if DEBUG:
            print('Table info AFTER table manipulation operation:')
            print(table_info)

        # Samples without the flag use the plain (non-SQL) solvers instead of
        # raising KeyError.
        if current_sample.get('using_sql') is True:
            current_sample, is_sql_executable, _ = solver_func(
                current_sample, table_info, llm=llm, llm_options=op_llm_options, **kargs
            )
            if is_sql_executable is True:
                sql_executable = True
        else:
            current_sample = solver_func(
                current_sample, table_info, llm=llm, llm_options=op_llm_options, **kargs
            )
        # NOTE: stopping the chain at the first non-executable SQL step (and
        # forwarding the sample straight to the final query) was tried and
        # noticeably degraded accuracy, so unexecutable steps are tolerated.

    return current_sample, dynamic_chain_log, sql_executable


def dynamic_chain_exec_with_cache_for_loop(
    all_samples,
    llm,
    llm_options=None,
    strategy="voting",
    cache_dir="./cache/debug",
    operation_parameter_dict=None,
):
    """Sequential (single-process) variant of the dynamic chain executor.

    Each sample's result is cached as ``<cache_dir>/case-<sample id>.pkl``,
    holding ``(sample, proc_sample, log)``. Cached samples are loaded instead
    of recomputed. A sample that raises is reported and left as ``None``.

    Args:
        operation_parameter_dict: mapping operation -> ``(op_name, solver_func,
            kargs, op_llm_options)``, forwarded to
            ``dynamic_chain_exec_one_sample``. Without it, every sample whose
            chain needs an operation fails.

    Returns:
        ``(result_samples, dynamic_chain_log_list)``, both aligned with
        ``all_samples``.
    """
    os.makedirs(cache_dir, exist_ok=True)
    result_samples = [None] * len(all_samples)
    dynamic_chain_log_list = [None] * len(all_samples)

    cache_filename = "case-{}.pkl"

    def _func(idx):
        sample = all_samples[idx]
        cache_path = os.path.join(cache_dir, cache_filename.format(sample["id"]))
        if os.path.exists(cache_path):
            # Local cache written below; never point pickle at untrusted files.
            with open(cache_path, "rb") as f:
                _, proc_sample, log = pickle.load(f)
        else:
            # dynamic_chain_exec_one_sample returns a 3-tuple. Unpacking it into
            # two names raised ValueError for every uncached sample.
            proc_sample, log, _ = dynamic_chain_exec_one_sample(
                sample,
                llm=llm,
                llm_options=llm_options,
                strategy=strategy,
                operation_parameter_dict=operation_parameter_dict,
            )
            with open(cache_path, "wb") as f:
                pickle.dump((sample, proc_sample, log), f)
        result_samples[idx] = proc_sample
        dynamic_chain_log_list[idx] = log

    for idx in tqdm(range(len(all_samples)), total=len(all_samples)):
        try:
            _func(idx)
        except Exception as e:
            print(f"IDX={idx}: {e}", flush=True)

    return result_samples, dynamic_chain_log_list

def _wikitq_natural_language_chain_exec_with_cache_mp_core(arg):
    """Pool worker for WikiTQ natural-language planning, with a JSON result cache.

    ``arg`` is ``(idx, sample, llm, llm_options, strategy, cache_dir)``; only
    ``sample``, ``llm``, ``llm_options`` and ``strategy`` are used.

    - When ``LLM == 'GPT3-5'``, the sample is always recomputed.
    - Otherwise, results are looked up in the JSON file ``result_file_name``.
      Missing results are computed and saved back, stored as
      ``{id: {id: record}}``.

    Returns:
        ``(sample, sample_id, answer, is_sql_executable, groundtruth,
        answer_plans, None, fallback_llm)``.
    """
    idx, sample, llm, llm_options, strategy, cache_dir = arg
    sample_id = sample['id']

    if LLM == 'GPT3-5':
        sample_id, answer, is_sql_executable, groundtruth, result_dict, fall_back_llm = wikitq_natural_language_chain_exec_one_sample(
            sample, llm=llm, llm_options=llm_options, strategy=strategy,
        )
        return sample, sample_id, answer, is_sql_executable, groundtruth, result_dict, None, fall_back_llm

    # NOTE(review): every pool worker loads, extends and rewrites this same
    # file. Concurrent workers can therefore overwrite each other's newly added
    # entries (last save wins). Confirm whether that loss is acceptable.
    processed_samples = load_processed_samples(result_file_name)
    key = str(sample_id)
    if key in processed_samples:
        print(f"Skipping already processed sample {sample_id}")
        cached = processed_samples[key][key]
        return (
            sample,
            sample_id,
            cached['answer'],
            cached['is_sql_executable'],
            cached['groundtruth'],
            cached['answer_plans'],
            None,
            cached['fallback_LLM'],
        )

    sample_id, answer, is_sql_executable, groundtruth, result_dict, fall_back_llm = wikitq_natural_language_chain_exec_one_sample(sample, llm=llm, llm_options=llm_options, strategy=strategy)

    processed_samples[str(sample_id)] = {
        f'{sample_id}': {
            'input': sample,
            'id': sample_id,
            'answer': answer,
            'answer_plans': result_dict,
            'groundtruth': groundtruth,
            'fallback_LLM': fall_back_llm,
            'is_sql_executable': is_sql_executable,
        }
    }
    print('SAVING..')
    save_processed_samples(result_file_name, processed_samples)

    return sample, sample_id, answer, is_sql_executable, groundtruth, result_dict, None, fall_back_llm


import json
import os
import multiprocessing as mp
from tqdm import tqdm

def load_processed_samples(result_file_path):
    """Load the JSON results cache stored at ``result_file_path``.

    Returns:
        The parsed JSON. Returns ``{}`` when the file does not exist, or (after
        printing a warning) when its content is not valid JSON.
    """
    if not os.path.exists(result_file_path):
        return {}
    with open(result_file_path, 'r') as handle:
        try:
            return json.load(handle)
        except json.JSONDecodeError:
            print(f"Warning: {result_file_path} is empty or corrupted. Starting with an empty dictionary.")
            return {}


def save_processed_samples(result_file_path, result_samples):
    """Write ``result_samples`` to ``result_file_path`` as indented JSON.

    The data is first written to a process-specific temporary file next to
    the target, then moved into place with ``os.replace``. A reader therefore
    never sees a half-written file, and a failed dump leaves the previous
    content intact.

    Previously the target was truncated up front. An interrupted or failed
    write left corrupt JSON, which ``load_processed_samples`` treats as ``{}``,
    so the next save silently wiped every cached result.

    Note: if several processes load, modify and save the same file
    concurrently, the last ``os.replace`` wins.
    """
    tmp_path = f"{result_file_path}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, 'w') as f:
            json.dump(result_samples, f, indent=4)
        os.replace(tmp_path, result_file_path)
    except BaseException:
        # Don't leave a stray temp file behind; re-raise the original error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise

# Run POS only
def _natural_language_chain_exec_with_cache_mp_core(arg):
    """Pool worker for natural-language planning, with a JSON result cache.

    ``arg`` is ``(idx, sample, llm, llm_options, strategy, cache_dir)``; only
    ``sample``, ``llm``, ``llm_options`` and ``strategy`` are used.

    - When ``LLM == 'GPT3-5'``, the sample is always recomputed.
    - For other models, results are looked up in the JSON file
      ``result_file_name``. Missing results are computed and saved back,
      stored as ``{id: {id: record}}``.

    Returns:
        ``(unprocessed, sample, sample_id, answer, is_sql_executable,
        groundtruth, answer_plans, None, fallback_llm)``. ``unprocessed`` is
        False only when the result came from the cache.
    """
    idx, sample, llm, llm_options, strategy, cache_dir = arg
    sample_id = sample['id']

    if LLM == 'GPT3-5':
        sample_id, answer, is_sql_executable, groundtruth, result_dict, fb_llm = natural_language_chain_exec_one_sample(
            sample, llm=llm, llm_options=llm_options, strategy=strategy,
        )
        return True, sample, sample_id, answer, is_sql_executable, groundtruth, result_dict, None, fb_llm

    # NOTE(review): every pool worker loads, extends and rewrites this same
    # file. Concurrent workers can therefore overwrite each other's newly added
    # entries (last save wins). Confirm whether that loss is acceptable.
    processed_samples = load_processed_samples(result_file_name)
    key = str(sample_id)
    if key in processed_samples:
        print(f"Skipping already processed sample {sample_id}")
        cached = processed_samples[key][key]
        return (
            False,
            sample,
            sample_id,
            cached['answer'],
            cached['is_sql_executable'],
            cached['groundtruth'],
            cached['answer_plans'],
            None,
            cached['fallback_LLM'],
        )

    sample_id, answer, is_sql_executable, groundtruth, result_dict, fb_llm = natural_language_chain_exec_one_sample(sample, llm=llm, llm_options=llm_options, strategy=strategy)

    processed_samples[str(sample_id)] = {
        f'{sample_id}': {
            'input': sample,
            'id': sample_id,
            'answer': answer,
            'answer_plans': result_dict,
            'groundtruth': groundtruth,
            'fallback_LLM': fb_llm,
            'is_sql_executable': is_sql_executable,
        }
    }
    print('SAVING..')
    save_processed_samples(result_file_name, processed_samples)
    return True, sample, sample_id, answer, is_sql_executable, groundtruth, result_dict, None, fb_llm

def _dynamic_chain_operation_parameters(llm, using_sql):
    """Build the ``{operation: (op_name, solver_func, kargs, llm_options)}`` table.

    ``using_sql`` selects the ``*_func_sql`` solvers instead of ``*_func``.
    Row and column selection sample 8 completions at temperature 0.5. The
    other operations decode with temperature 0.0.
    """
    greedy = dict(
        temperature=0.0,
        per_example_max_decode_steps=150,
        per_example_top_p=1.0,
    )
    sampled = dict(
        temperature=0.5,
        per_example_max_decode_steps=150,
        per_example_top_p=1.0,
        n_sample=8,
    )
    if using_sql:
        solvers = (
            add_column_func_sql,
            select_row_func_sql,
            select_column_func_sql,
            group_column_func_sql,
            sort_column_func_sql,
        )
    else:
        solvers = (
            add_column_func,
            select_row_func,
            select_column_func,
            group_column_func,
            sort_column_func,
        )
    add_solver, select_row_solver, select_column_solver, group_solver, sort_solver = solvers
    return {
        "add_column": ("addColumn", add_solver, {}, llm.get_model_options(**greedy)),
        "select_row": ("selectRow", select_row_solver, {}, llm.get_model_options(**sampled)),
        "select_column": ("selectColumn", select_column_solver, {}, llm.get_model_options(**sampled)),
        "group_column": ("groupColumn", group_solver, dict(skip_op=[]), llm.get_model_options(**greedy)),
        "sort_column": ("sortColumn", sort_solver, dict(skip_op=[]), llm.get_model_options(**greedy)),
    }


def _dynamic_chain_exec_with_cache_mp_core(arg):
    """Pool worker: build and execute the dynamic operation chain for one sample.

    ``arg`` is ``(idx, sample, llm, llm_options, strategy, cache_dir)``.

    When ``USING_SQL`` is enabled, the SQL-backed solvers are tried first. The
    chain is rebuilt with the plain solvers if no step's SQL was executable,
    or if the SQL run raised. ``sample['using_sql']`` is set to the mode of
    the last run.

    ``(sample, proc_sample, log)`` is pickled to ``<cache_dir>/case-<idx>.pkl``.
    The file is only written here, never read.

    Returns:
        ``(idx, proc_sample, log, is_sql_executable)``.
    """
    idx, sample, llm, llm_options, strategy, cache_dir = arg
    cache_path = os.path.join(cache_dir, "case-{}.pkl".format(idx))

    sql_params = _dynamic_chain_operation_parameters(llm, using_sql=True)
    plain_params = _dynamic_chain_operation_parameters(llm, using_sql=False)

    def _run(using_sql, params):
        # Runs one full chain with the given solver family.
        sample['using_sql'] = using_sql
        return dynamic_chain_exec_one_sample(
            sample, llm=llm, llm_options=llm_options, strategy=strategy,
            operation_parameter_dict=params,
        )

    done_by_SQL = 0
    is_sql_executable = False
    if USING_SQL is True:
        try:
            proc_sample, log, is_sql_executable = _run(True, sql_params)
            if is_sql_executable:
                done_by_SQL += 1
            else:
                print("SQL execution failed, trying alternative method...")
                proc_sample, log, _ = _run(False, plain_params)
        except Exception as e:
            # NOTE: this also catches a failure of the plain fallback above,
            # which is therefore retried once here.
            print(f"Error during SQL execution: {e}, trying alternative method...")
            is_sql_executable = False
            proc_sample, log, _ = _run(False, plain_params)
    else:
        proc_sample, log, _ = _run(False, plain_params)

    with open(cache_path, "wb") as f:
        pickle.dump((sample, proc_sample, log), f)

    print('Done by SQL:', done_by_SQL)
    return idx, proc_sample, log, is_sql_executable

def dynamic_chain_exec_with_cache_mp(
    all_samples,
    llm,
    llm_options=None,
    strategy="voting",
    cache_dir="./results/debug",
    n_proc=10,
    chunk_size=50,
):
    """Run the configured planning pipeline over all samples with a process pool.

    - ``NATURAL_LANGUAGE_PLANNING is True``: returns a dict keyed by
      ``f'{sample_id}'`` holding each sample's answer record. The log list
      becomes ``None`` once at least one result has arrived.
    - otherwise, ``OTG_PLANNING is True``: returns processed samples aligned
      with ``all_samples``, each tagged with ``is_sql_executable``, plus the
      per-sample chain logs.
    - neither flag: both returned lists stay filled with ``None``.
    """
    os.makedirs(cache_dir, exist_ok=True)
    total = len(all_samples)
    result_samples = [None] * total
    dynamic_chain_log_list = [None] * total

    args = [
        (idx, sample, llm, llm_options, strategy, cache_dir)
        for idx, sample in enumerate(all_samples)
    ]

    if NATURAL_LANGUAGE_PLANNING is True:
        result_samples = {}
        with mp.Pool(n_proc) as pool:
            results = pool.imap_unordered(
                _natural_language_chain_exec_with_cache_mp_core, args, chunksize=chunk_size
            )
            for _unprocessed, sample, sample_id, answer, is_sql_executable, groundtruth, plans, subs, fb_llm in tqdm(
                results, total=total
            ):
                sample['sub_statements'] = subs
                result_samples[f'{sample_id}'] = {
                    'input': sample,
                    'id': sample_id,
                    'answer': answer,
                    'answer_plans': plans,
                    'groundtruth': groundtruth,
                    'fallback_LLM': fb_llm,
                    'is_sql_executable': is_sql_executable,
                }
                # This mode produces no chain logs.
                dynamic_chain_log_list = None

    elif OTG_PLANNING is True:
        sql_cnt = 0
        with mp.Pool(n_proc) as pool:
            results = pool.imap_unordered(
                _dynamic_chain_exec_with_cache_mp_core, args, chunksize=chunk_size
            )
            for idx, proc_sample, log, is_sql_executable in tqdm(results, total=total):
                if is_sql_executable is True:
                    sql_cnt += 1
                proc_sample['is_sql_executable'] = is_sql_executable
                result_samples[idx] = proc_sample
                dynamic_chain_log_list[idx] = log

        print('Number of samples that have tables (record 1) edited by SQLs:', sql_cnt)

    print('Number of total samples tested:', total)

    return result_samples, dynamic_chain_log_list


def wikitq_dynamic_chain_exec_with_cache_mp(
    all_samples,
    llm,
    llm_options=None,
    strategy="voting",
    cache_dir="./results/debug",
    n_proc=10,
    chunk_size=50,
):
    """WikiTQ variant of the pooled pipeline (natural-language planning only).

    When ``NATURAL_LANGUAGE_PLANNING is True``, returns a dict keyed by the
    raw ``sample_id`` holding each sample's answer record. The log list
    becomes ``None`` once at least one result has arrived. Otherwise both
    returned lists stay filled with ``None``.
    """
    os.makedirs(cache_dir, exist_ok=True)
    result_samples = [None] * len(all_samples)
    dynamic_chain_log_list = [None] * len(all_samples)

    args = [
        (idx, sample, llm, llm_options, strategy, cache_dir)
        for idx, sample in enumerate(all_samples)
    ]

    if NATURAL_LANGUAGE_PLANNING is True:
        result_samples = {}
        with mp.Pool(n_proc) as pool:
            results = pool.imap_unordered(
                _wikitq_natural_language_chain_exec_with_cache_mp_core, args, chunksize=chunk_size
            )
            for sample, sample_id, answer, is_sql_executable, groundtruth, plans, subs, fb_llm in tqdm(
                results, total=len(all_samples)
            ):
                sample['sub_statements'] = subs
                result_samples[sample_id] = {
                    'input': sample,
                    'id': sample_id,
                    'answer': answer,
                    'answer_plans': plans,
                    'groundtruth': groundtruth,
                    'fallback_LLM': fb_llm,
                    'is_sql_executable': is_sql_executable,
                }
                # This mode produces no chain logs.
                dynamic_chain_log_list = None

    return result_samples, dynamic_chain_log_list
