#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os  
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))  
import argparse  
import json  
from openai import AzureOpenAI  
from tqdm import tqdm  
import pdb  
import re  
  
# Update the import statements to use absolute imports  
from concurrent.futures import ThreadPoolExecutor, as_completed  
from model_api_call import get_chat_response_azure  
from functools import partial  
import sqlite3
import pandas as pd
import gc
from code_utils import extract_code, create_message

import json
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
from code_utils import mask_code_with_comments

def process_batch_data(data, batch_size):
    """Split ``data`` into consecutive slices of at most ``batch_size`` items.

    Args:
        data: Any sliceable sequence that supports ``len()``.
        batch_size: Step used both for the slice width and the stride
            between slice starts.

    Returns:
        A list of slices of ``data`` in their original order; the final
        slice may be shorter than ``batch_size``.
    """
    batches = []
    for start in range(0, len(data), batch_size):
        stop = start + batch_size
        batches.append(data[start:stop])
    return batches


def wrap_up_tabmwp_code_prompt(data, prompt_template):
    """Fill ``prompt_template`` from one record and store the result on it.

    The placeholders ``[[question]]``, ``[[data_path]]``,
    ``[[data_overview]]``, ``[[cot]]``, ``[[data_title]]`` and
    ``[[successful_plan_suggestions]]`` are replaced, in that order, with
    the record's ``question``, ``data_path``, ``data_overview``,
    ``rag_final_cot``, ``table_title`` and ``rag_mg`` values.

    Args:
        data: Record dict; it is mutated in place.
        prompt_template: Template string containing the placeholders.

    Returns:
        The same ``data`` dict with ``rag_final_code_prompt`` set to the
        filled-in prompt.
    """
    question = data.get('question', '')
    data_path = data.get('data_path', '')
    data_overview = data.get('data_overview', '')

    # Missing / empty values fall back to neutral text so that
    # ``str.replace`` never receives ``None``.
    cot = data.get('rag_final_cot', '') or ''
    success_sug = data.get('rag_mg', '') or 'No suggestions, just follow your thoughts.'
    data_title = data.get('table_title', '')
    if data_title is None:
        data_title = ''

    prompt = prompt_template.replace('[[question]]', question)
    prompt = prompt.replace('[[data_path]]', data_path)
    prompt = prompt.replace('[[data_overview]]', data_overview)
    prompt = prompt.replace('[[cot]]', cot)
    prompt = prompt.replace('[[data_title]]', data_title)
    prompt = prompt.replace('[[successful_plan_suggestions]]', success_sug)

    # The debugger breakpoint that used to sit here stopped every call and
    # made the function unusable in batch / non-interactive runs.
    data['rag_final_code_prompt'] = prompt
    return data


def wrap_up_clean_tabmwp_al_prompt(data, prompt_template):
    """Build the prompt for one record without a suggestions section.

    Replaces ``[[question]]``, ``[[data_path]]``, ``[[data_overview]]``,
    ``[[data_title]]`` and ``[[cot]]`` (in that order) using the record's
    ``question``, ``data_path``, ``data_overview``, ``table_title`` and
    ``rag_final_cot`` entries.

    Args:
        data: Record dict; it is mutated in place.
        prompt_template: Template string containing the placeholders.

    Returns:
        ``data`` itself, with ``rag_final_code_prompt`` holding the prompt.
    """
    title = data.get('table_title', '')
    if title is None:
        title = ''
    # A falsy chain of thought is treated as an empty string.
    chain = data.get('rag_final_cot', '') or ''

    # Order matters: each replacement runs on the output of the previous one.
    substitutions = (
        ('[[question]]', data.get('question', '')),
        ('[[data_path]]', data.get('data_path', '')),
        ('[[data_overview]]', data.get('data_overview', '')),
        ('[[data_title]]', title),
        ('[[cot]]', chain),
    )

    filled = prompt_template
    for placeholder, value in substitutions:
        filled = filled.replace(placeholder, value)

    data['rag_final_code_prompt'] = filled
    return data


def wrap_up_wikitq_clean_prompt(data, prompt_template):
    """Build the prompt for one record from four placeholders.

    ``[[question]]``, ``[[data_path]]``, ``[[data_overview]]`` and
    ``[[cot]]`` are replaced, in that order, with the record's
    ``question``, ``data_path``, ``data_overview`` and ``rag_final_cot``
    values. The chain of thought is passed through ``str()`` and a falsy
    value is replaced by a fixed fallback sentence.

    Args:
        data: Record dict; it is mutated in place.
        prompt_template: Template string containing the placeholders.

    Returns:
        ``data`` itself, with ``rag_final_code_prompt`` holding the prompt.
    """
    chain = data.get('rag_final_cot', '')
    if not chain:
        chain = 'no chain of thoughts needed, follow your thinking'

    data['rag_final_code_prompt'] = (
        prompt_template
        .replace('[[question]]', data.get('question', ''))
        .replace('[[data_path]]', data.get('data_path', ''))
        .replace('[[data_overview]]', data.get('data_overview', ''))
        .replace('[[cot]]', str(chain))
    )
    return data

def wrap_up_wikitq_sug_prompt(data, prompt_template, prompt_sug):
    """Build the prompt for one record, including a suggestions section.

    Replaces ``[[question]]``, ``[[data_path]]``, ``[[data_overview]]``,
    ``[[successful_plan_suggestions]]`` and ``[[cot]]`` in that order.

    Args:
        data: Record dict; it is mutated in place.
        prompt_template: Template string containing the placeholders.
        prompt_sug: Text substituted for ``[[successful_plan_suggestions]]``.

    Returns:
        ``data`` itself, with ``rag_final_code_prompt`` holding the prompt.
    """
    chain = data.get('rag_final_cot', '')
    if not chain:
        chain = 'no chain of thoughts needed, follow your thinking'

    # Insertion order of the mapping is the replacement order.
    fillers = {
        '[[question]]': data.get('question', ''),
        '[[data_path]]': data.get('data_path', ''),
        '[[data_overview]]': data.get('data_overview', ''),
        '[[successful_plan_suggestions]]': prompt_sug,
        '[[cot]]': chain,
    }

    text = prompt_template
    for marker in fillers:
        text = text.replace(marker, fillers[marker])

    data['rag_final_code_prompt'] = text
    return data



def wrap_up_code_ins_prompt(data, prompt_template_sketch, prompt_suggesitions):
    """Build the prompt for one record with title, suggestions and instructions.

    Replaces ``[[question]]``, ``[[data_path]]``, ``[[data_overview]]``,
    ``[[data_title]]``, ``[[successful_plan_suggestions]]`` and ``[[cot]]``
    in that order.

    Args:
        data: Record dict; it is mutated in place.
        prompt_template_sketch: Template string containing the placeholders.
        prompt_suggesitions: Text substituted for
            ``[[successful_plan_suggestions]]``.

    Returns:
        ``data`` itself, with ``rag_final_code_prompt`` holding the prompt.
    """
    instructions = data.get('rag_final_cot', '')
    if not instructions:
        instructions = 'No instructions, follow your thoughts'

    title = data.get('table_title', '')
    if title is None:
        title = ''

    # Each pair is applied to the output of the previous replacement.
    ordered_pairs = [
        ('[[question]]', data.get('question', '')),
        ('[[data_path]]', data.get('data_path', '')),
        ('[[data_overview]]', data.get('data_overview', '')),
        ('[[data_title]]', title),
        ('[[successful_plan_suggestions]]', prompt_suggesitions),
        ('[[cot]]', instructions),
    ]

    rendered = prompt_template_sketch
    for marker, replacement in ordered_pairs:
        rendered = rendered.replace(marker, replacement)

    data['rag_final_code_prompt'] = rendered
    return data



def load_and_infer_from_jsonl_dataset_al_gpt(prompt_path, result_path, client, model, prompt_template, temperature=0.0, max_tokens=60, top_p=1, frequency_penalty=0, presence_penalty=0, stop=None, max_retries=10, num_threads=5):
    """Run chat completions over a JSONL dataset in parallel and save the results.

    Every non-blank line of ``prompt_path`` is parsed as a JSON object, a
    prompt is built for it with ``wrap_up_wikitq_clean_prompt``, and the
    prompt is sent through ``get_chat_response_azure`` on a thread pool.
    All original fields of each record are kept. The value returned by
    ``extract_code`` for the response is stored under ``rag_final_code``
    when it is truthy; a record whose request raised gets
    ``result = 'error'`` instead. Records are written to ``result_path``
    as JSONL in input order.

    Args:
        prompt_path: Input ``.jsonl`` file, one JSON object per line.
        result_path: Output ``.jsonl`` file (overwritten).
        client: Client object forwarded to ``get_chat_response_azure``.
        model: Model / deployment name forwarded to the API helper.
        prompt_template: Pair ``(sketch_template, suggestions)``. Only the
            first element is used by the currently active prompt builder.
        temperature, max_tokens, top_p, frequency_penalty, presence_penalty,
        stop, max_retries: Forwarded unchanged to ``get_chat_response_azure``.
        num_threads: ``max_workers`` for the thread pool.
    """
    # Unpacking still validates that a 2-tuple was supplied; the second
    # element is only needed by the suggestion-based prompt builders.
    prompt_template_sketch, _ = prompt_template

    def process_line(data, idx):
        """Query the model for one record and annotate the record in place."""
        prompt_string = data['rag_final_code_prompt']
        messages = create_message(prompt_string=prompt_string)

        # Bound before the try block so the logging below still works when
        # the request raises; otherwise a single failed request would raise
        # UnboundLocalError and abort the whole run.
        response = None
        try:
            response = get_chat_response_azure(client=client, model=model, messages=messages, temperature=temperature,
                                               max_tokens=max_tokens, top_p=top_p, frequency_penalty=frequency_penalty,
                                               presence_penalty=presence_penalty, stop=stop, max_retries=max_retries)

            cot = extract_code(response)
            if cot:
                data['rag_final_code'] = cot

        except Exception as e:
            # Best effort: mark the record and keep processing the others.
            print(f"Error processing prompt: {str(e)}")
            data['result'] = 'error'

        # Print response info for tracking/debugging
        print(f"================ response #{idx}: ================\n")
        print(prompt_string)
        print(response)
        print(f"================ finish response #{idx} ================\n")

        return data

    # Load the dataset; blank lines (e.g. a trailing newline) are skipped
    # instead of crashing json.loads.
    with open(prompt_path, 'r', encoding='utf-8') as infile:
        dataset = [json.loads(line) for line in infile if line.strip()]

    # Build prompts after the input file has been closed.
    dataset = [wrap_up_wikitq_clean_prompt(data, prompt_template_sketch) for data in dataset]

    # executor.map yields results in submission order, so the output keeps
    # the input ordering.
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        results = list(tqdm(executor.map(process_line, dataset, range(len(dataset))),
                            total=len(dataset), desc="running"))

    with open(result_path, 'w', encoding='utf-8') as outfile:
        for data in results:
            outfile.write(json.dumps(data, ensure_ascii=False) + '\n')
                       
def inference():
    """Command-line entry point: parse arguments, pick prompts, run inference.

    Parses the CLI options, imports the prompt templates for the chosen
    ``--dataset`` (``tabmwp`` or ``wikitq``), creates an ``AzureOpenAI``
    client and delegates to ``load_and_infer_from_jsonl_dataset_al_gpt``.
    An unsupported dataset name ends the program through ``parser.error``.
    """
    parser = argparse.ArgumentParser(description='Call OpenAI API with specified parameters and configurations.')
    parser.add_argument('--deployment_name', type=str, required=True, help='Model name to use for the API call.')
    parser.add_argument('--temperature', type=float, default=0.0, help='Temperature for the response. Default is 0.0.')
    parser.add_argument('--max_tokens', type=int, default=60, help='Maximum number of tokens to generate. Default is 60.')
    parser.add_argument('--top_p', type=float, default=1, help='Top P value. Default is 1.')
    parser.add_argument('--frequency_penalty', type=float, default=0, help='Frequency penalty. Default is 0.')
    parser.add_argument('--presence_penalty', type=float, default=0, help='Presence penalty. Default is 0.')
    parser.add_argument('--stop', nargs='*', help='Stop sequence(s). Multiple values are allowed.')
    parser.add_argument('--api_key', type=str, required=True, help='OpenAI API key.')
    parser.add_argument('--api_base', type=str, default="https://api.openai.com", help='OpenAI API base URL. Default is the standard OpenAI API.')
    parser.add_argument('--api_version', type=str, default="v1", help='OpenAI API version. Default is "v1".')
    parser.add_argument('--api_type', type=str, default="azure", help='OpenAI API Type. Default is "Azure"')
    parser.add_argument('--prompt_path', type=str, required=True, help='Path to the input .jsonl file containing prompts.')
    parser.add_argument('--result_path', type=str, required=True, help='Path where the output .jsonl file with results will be saved.')
    parser.add_argument('--num_threads', type=int, required=False, help='if your API could be run in parallel')
    parser.add_argument('--dataset', type=str, required=True, help='name of dataset')

    args = parser.parse_args()

    # The template pair is chosen inside each branch so that every name it
    # uses is guaranteed to have been imported for that dataset. Only the
    # two names used in each pair are needed here; the other imported names
    # are currently unused in this function.
    if args.dataset == "tabmwp":
        from prompts.tabmwp_prompt_combine.prompt.tabwmp_rag import rag_gpt_sug_al_v5_final, contrastive_neg_suggestions, rag_mg_slm_code_v3, rag_final_code_gpt_sug, rag_final_code_gpt, rag_gpt_al_v6_final
        from prompts.tabmwp_prompt_combine.prompt.tabmwp_all import rag_final_code_gpt35_sug, rag_star_final_sug, rag_final_code_gpt35
        # NOTE(review): contrastive_neg_suggestions is assumed to be the
        # tabmwp counterpart of wikitq's neg_suggestions_v1 -- confirm.
        prompt_template = rag_final_code_gpt35, contrastive_neg_suggestions
    elif args.dataset == "wikitq":
        from prompts.wikitq_prompt_combine.prompt.wikitq_all import rag_phi_final_sug, neg_suggestions_v1, rag_phi_final, rag_final_code_gpt35
        prompt_template = rag_final_code_gpt35, neg_suggestions_v1
    else:
        # Fail with a clear usage message instead of a NameError on the
        # template names further down.
        parser.error(f"unsupported --dataset {args.dataset!r}; expected 'tabmwp' or 'wikitq'")

    client = AzureOpenAI(
        api_key=args.api_key,
        api_version=args.api_version,
        base_url=f"{args.api_base}/openai/deployments/{args.deployment_name}")

    # The penalty options are forwarded as well; previously they were parsed
    # but silently ignored.
    load_and_infer_from_jsonl_dataset_al_gpt(args.prompt_path, args.result_path, client, model=args.deployment_name, prompt_template=prompt_template,
                                             temperature=args.temperature, max_tokens=args.max_tokens,
                                             top_p=args.top_p, frequency_penalty=args.frequency_penalty,
                                             presence_penalty=args.presence_penalty, stop=args.stop,
                                             num_threads=args.num_threads)
 
  
# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":  
    inference()  
