#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os  
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))  
import argparse  
import json  
from openai import AzureOpenAI  
from tqdm import tqdm
import sqlite3
import pandas as pd
import pdb
import re
from tabulate import tabulate
from concurrent.futures import ThreadPoolExecutor  
from model_api_call import get_chat_response_azure  
from functools import partial  
from code_utils import extract_result_from_anchor, extract_code_plan, replace_code_with_placeholder_with_steps, extract_action


def wrap_up_opt_tabmwp_prompt(data, prompt_template, unit_template):
    """Render the TabMWP optimisation prompt for one record and store it on the record.

    The record's fields are substituted into ``unit_template`` at the
    ``[[question]]``, ``[[data_path]]``, ``[[data_overview]]``, ``[[final_code]]``,
    ``[[execution]]``, ``[[result]]`` and ``[[answer]]`` placeholders (in that
    order). The rendered unit then replaces ``[[last_round_case]]`` in
    ``prompt_template``.

    Args:
        data: One input record (dict). Missing or null fields render as empty
            text, except that an empty ``final_code`` / ``execution`` falls back
            to an explanatory placeholder sentence.
        prompt_template: Outer prompt text containing ``[[last_round_case]]``.
        unit_template: Per-case template text containing the field placeholders.

    Returns:
        The same ``data`` dict, mutated in place with a ``function_opt_prompt`` key.
    """
    def as_text(value):
        # str.replace() only accepts str; a missing/null field (e.g. no
        # 'data_overview' in the record) must not crash the whole load.
        return '' if value is None else str(value)

    final_code = data.get('final_code', '') or 'no mixed code available'
    execution = data.get('execution', '') or 'no execution returned'
    result = extract_result_from_anchor(data.get('result', ''))

    # Substitutions are applied sequentially and in this fixed order.
    substitutions = (
        ('[[question]]', as_text(data.get('question', ''))),
        ('[[data_path]]', as_text(data.get('data_path', ''))),
        ('[[data_overview]]', as_text(data.get('data_overview'))),
        ('[[final_code]]', as_text(final_code)),
        ('[[execution]]', as_text(execution)),
        ('[[result]]', as_text(result)),
        ('[[answer]]', str(data.get('answer', ''))),
    )
    for placeholder, value in substitutions:
        unit_template = unit_template.replace(placeholder, value)

    data['function_opt_prompt'] = prompt_template.replace('[[last_round_case]]', unit_template)
    return data

def wrap_up_opt_wikitq_prompt(data, prompt_template, unit_template):
    """Render the WikiTQ optimisation prompt for one record and store it on the record.

    The record's fields are substituted into ``unit_template`` at the
    ``[[question]]``, ``[[data_path]]``, ``[[data_overview]]``, ``[[final_code]]``,
    ``[[execution]]``, ``[[result]]`` and ``[[answer]]`` placeholders (in that
    order). The rendered unit then replaces ``[[last_round_case]]`` in
    ``prompt_template``.

    Args:
        data: One input record (dict). Missing or null fields render as empty
            text, except that an empty ``final_code`` / ``execution`` falls back
            to an explanatory placeholder sentence.
        prompt_template: Outer prompt text containing ``[[last_round_case]]``.
        unit_template: Per-case template text containing the field placeholders.

    Returns:
        The same ``data`` dict, mutated in place with a ``function_opt_prompt`` key.
    """
    def as_text(value):
        # str.replace() only accepts str; a missing/null field (e.g. no
        # 'data_overview' in the record) must not crash the whole load.
        return '' if value is None else str(value)

    final_code = data.get('final_code', '') or 'no mixed code available'
    execution = data.get('execution', '') or 'no execution returned'
    result = extract_result_from_anchor(data.get('result', ''))

    # Substitutions are applied sequentially and in this fixed order.
    substitutions = (
        ('[[question]]', as_text(data.get('question', ''))),
        ('[[data_path]]', as_text(data.get('data_path', ''))),
        ('[[data_overview]]', as_text(data.get('data_overview'))),
        ('[[final_code]]', as_text(final_code)),
        ('[[execution]]', as_text(execution)),
        ('[[result]]', as_text(result)),
        ('[[answer]]', str(data.get('answer', ''))),
    )
    for placeholder, value in substitutions:
        unit_template = unit_template.replace(placeholder, value)

    data['function_opt_prompt'] = prompt_template.replace('[[last_round_case]]', unit_template)
    return data

def load_and_infer_from_jsonl_parallel(prompt_path, result_path, client, model, prompt_template, temperature=0.0, max_tokens=60, top_p=1, frequency_penalty=0, presence_penalty=0, stop=None, max_retries=10, num_threads=5, dataset_name="wikitq", continue_previous=False):
    """Build optimisation prompts from a JSONL file, query the model in parallel, and append results.

    Each input record is rendered into a prompt by the dataset-specific
    ``wrap_up_opt_*_prompt`` helper (stored under ``function_opt_prompt``).
    Records that need optimisation are sent to the model on a thread pool.
    The remaining records are copied to the output unchanged.

    Args:
        prompt_path: Input ``.jsonl`` file, one JSON object per line.
        result_path: Output ``.jsonl`` file; always opened in append mode.
        client: API client forwarded to ``get_chat_response_azure``.
        model: Model/deployment name forwarded to ``get_chat_response_azure``.
        prompt_template: ``(prompt_sketch, prompt_unit)`` pair used to render prompts.
        temperature, max_tokens, top_p, frequency_penalty, presence_penalty,
        stop, max_retries: Forwarded unchanged to ``get_chat_response_azure``.
        num_threads: ``max_workers`` for the ``ThreadPoolExecutor`` (``None``
            lets the executor choose).
        dataset_name: ``"tabmwp"`` or ``"wikitq"``; selects the prompt builder.
            With any other value no ``function_opt_prompt`` is built, so an
            empty prompt is sent.
        continue_previous: If true, records whose ``function_opt_prompt`` is
            already present in ``result_path`` are skipped, so an interrupted
            run can be resumed without duplicating output.
    """
    import threading

    prompt_sketch, prompt_unit = prompt_template

    # Prompts already answered by a previous run (only consulted when resuming).
    done_prompts = set()
    if continue_previous and os.path.exists(result_path):
        with open(result_path, 'r', encoding='utf-8') as prev_file:
            for line in prev_file:
                if not line.strip():
                    continue
                done_prompts.add(json.loads(line).get('function_opt_prompt', ''))

    with open(prompt_path, 'r', encoding='utf-8') as infile:
        lines = [json.loads(line) for line in infile if line.strip()]
    if dataset_name == "tabmwp":
        lines = [wrap_up_opt_tabmwp_prompt(line, prompt_sketch, prompt_unit) for line in lines]
    elif dataset_name == "wikitq":
        lines = [wrap_up_opt_wikitq_prompt(line, prompt_sketch, prompt_unit) for line in lines]

    # 'require_opt' is a string flag ('true'/'false'). If the first record has no
    # such flag (or the file is empty), every record is treated as needing work.
    if not lines or not lines[0].get('require_opt', ''):
        dataset_wrong = lines
        dataset_right = []
    else:
        dataset_wrong = [data for data in lines if data['require_opt'] == 'true']
        dataset_right = [data for data in lines if data['require_opt'] == 'false']
    print('There are {} cases needed to be optimized!'.format(len(dataset_wrong)))

    # Skip anything already written by a previous run. This includes the
    # pass-through records, which would otherwise be appended again on each resume.
    dataset_wrong = [data for data in dataset_wrong if data.get('function_opt_prompt', '') not in done_prompts]
    dataset_right = [data for data in dataset_right if data.get('function_opt_prompt', '') not in done_prompts]

    # All workers share one file handle; serialise writes so lines never interleave.
    write_lock = threading.Lock()

    def process_line(data, idx, outfile):
        """Query the model for one record, attach extracted fields, and append it to ``outfile``."""
        prompt_string = data.get('function_opt_prompt', '')
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt_string}
        ]
        try:
            response = get_chat_response_azure(client=client, model=model, messages=messages, temperature=temperature,
                                               max_tokens=max_tokens, top_p=top_p, frequency_penalty=frequency_penalty,
                                               presence_penalty=presence_penalty, stop=stop, max_retries=max_retries)
            code = extract_code_plan(result=response)
            action = extract_action(response)
            data['function_code'] = replace_code_with_placeholder_with_steps(code) if code else ''
            if action:
                data['action'] = action
        except Exception as e:
            # Best effort: record the failure on this item and keep the run going.
            print(f"Error processing prompt #{idx}: {str(e)}")
            data['first_code'] = "Error: Unable to process the prompt."
            data['first_code_flag'] = 'false'

        serialized = json.dumps(data, ensure_ascii=False) + '\n'
        with write_lock:
            print(f"================ response #{idx}: ================\n")
            print(prompt_string)
            # 'function_code' is absent when the try-block failed early.
            print(data.get('function_code', ''))
            print(f"================ finish response #{idx} ================\n")
            outfile.write(serialized)
            # Flush per record so an interrupted run keeps what it finished
            # (needed for continue_previous to resume correctly).
            outfile.flush()
        return data

    # Open the file outside the executor: the executor is then shut down (all
    # workers joined) before the file is closed, even if a worker raises.
    with open(result_path, 'a', encoding='utf-8') as outfile:
        with ThreadPoolExecutor(max_workers=num_threads) as executor:
            tasks = [(data, idx, outfile) for idx, data in enumerate(dataset_wrong)]
            list(tqdm(executor.map(lambda p: process_line(*p), tasks), total=len(tasks), desc="running"))

        # Records that need no optimisation are copied through unchanged.
        for data_r in dataset_right:
            outfile.write(json.dumps(data_r, ensure_ascii=False) + '\n')

  
def inference():
    """Parse CLI arguments, build the Azure OpenAI client, and run batch inference.

    The prompt templates are selected by ``--dataset`` (``tabmwp`` or
    ``wikitq``). Everything is passed to ``load_and_infer_from_jsonl_parallel``.
    """
    parser = argparse.ArgumentParser(description='Call OpenAI API with specified parameters and configurations.')
    parser.add_argument('--deployment_name', type=str, required=True, help='Model name to use for the API call.')
    parser.add_argument('--temperature', type=float, default=0.0, help='Temperature for the response. Default is 0.0.')
    parser.add_argument('--max_tokens', type=int, default=60, help='Maximum number of tokens to generate. Default is 60.')
    parser.add_argument('--top_p', type=float, default=1, help='Top P value. Default is 1.')
    parser.add_argument('--frequency_penalty', type=float, default=0, help='Frequency penalty. Default is 0.')
    parser.add_argument('--presence_penalty', type=float, default=0, help='Presence penalty. Default is 0.')
    parser.add_argument('--stop', nargs='*', help='Stop sequence(s). Multiple values are allowed.')
    parser.add_argument('--api_key', type=str, required=True, help='OpenAI API key.')
    parser.add_argument('--api_base', type=str, default="https://api.openai.com", help='OpenAI API base URL. Default is the standard OpenAI API.')
    parser.add_argument('--api_version', type=str, default="v1", help='OpenAI API version. Default is "v1".')
    parser.add_argument('--api_type', type=str, default="azure", help='OpenAI API Type. Default is "Azure"')
    parser.add_argument('--prompt_path', type=str, required=True, help='Path to the input .jsonl file containing prompts.')
    parser.add_argument('--result_path', type=str, required=True, help='Path where the output .jsonl file with results will be saved.')
    parser.add_argument('--num_threads', type=int, required=False, help='Number of threads for parallel processing.')
    # Without a known dataset no prompt template can be loaded, so reject bad
    # values up front instead of failing later with an unbound-name error.
    parser.add_argument('--dataset', type=str, required=True, choices=['tabmwp', 'wikitq'],
                        help='Dataset type for loading specific prompts.')
    parser.add_argument('--continue_previous', action='store_true', help='Whether to continue from a previous incomplete run.')
    parser.add_argument('--max_retries', type=int, default=100,
                        help='Maximum number of retries per API call. Default is 100.')

    args = parser.parse_args()
    if args.dataset == "tabmwp":
        # NOTE(review): load_and_infer_from_jsonl_parallel unpacks the template as a
        # (sketch, unit) pair -- confirm baseline_llm_v2 has that shape.
        from prompts.tabmwp_prompt_combine.prompt.tabmwp_baseline import baseline_llm_v2 as prompt_template
    else:  # "wikitq" -- guaranteed by `choices` above.
        from prompts.wikitq_prompt_combine.prompt.wikitq_all import llm_opt_prompt_critic_text, llm_opt_unit
        prompt_template = llm_opt_prompt_critic_text, llm_opt_unit

    # Strip a trailing slash so the deployment URL never contains '//'.
    api_base = args.api_base.rstrip('/')
    client = AzureOpenAI(
        api_key=args.api_key,
        api_version=args.api_version,
        base_url=f"{api_base}/openai/deployments/{args.deployment_name}")

    print(args.num_threads)
    load_and_infer_from_jsonl_parallel(args.prompt_path, args.result_path, client, model=args.deployment_name,
                                       prompt_template=prompt_template, temperature=args.temperature,
                                       max_tokens=args.max_tokens, top_p=args.top_p,
                                       frequency_penalty=args.frequency_penalty,
                                       presence_penalty=args.presence_penalty, stop=args.stop,
                                       max_retries=args.max_retries, num_threads=args.num_threads,
                                       dataset_name=args.dataset, continue_previous=args.continue_previous)

  
if __name__ == "__main__":
    # Run the CLI only when executed as a script, never on import.
    inference()
