#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os  
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))  
import argparse  
import json  
from openai import AzureOpenAI  
from tqdm import tqdm
import sqlite3
import pandas as pd
import pdb
import re
from tabulate import tabulate
# Project-local imports below (model_api_call, prompts, code_utils) are expected to resolve from this script's directory or from the parent directory appended to sys.path above.
from concurrent.futures import ThreadPoolExecutor, as_completed  
from model_api_call import get_chat_response_azure  
from functools import partial  
from prompts.wikitq_prompt_combine.prompt.wikitq_all import llm_case_study_contrastive_v2
from prompts.bird_prompt_combine.prompt.bird_all import llm_case_study_contrastive
from code_utils import extract_code



def wrap_up_cs_wikitq_prompt(data, prompt_template):
    """Fill the WikiTQ contrastive case-study template for one record.

    Each ``[[placeholder]]`` in ``prompt_template`` is replaced by the
    matching field of ``data``. Missing, ``None`` or otherwise falsy fields
    are substituted with an empty string.

    Args:
        data: Record dict. Reads ``question``, ``data_path``,
            ``data_overview`` and the nested dicts ``trad_wrong`` /
            ``trad_correct`` (each with ``final_code`` and ``execution``).
        prompt_template: Template string containing the placeholders.

    Returns:
        The same ``data`` dict, mutated in place with the rendered prompt
        stored under ``'case_study_prompt'``.
    """
    # `or {}` covers both a missing key and an explicit null, so the nested
    # `.get` calls below cannot fail on a str/None.
    trad_wrong = data.get('trad_wrong') or {}
    trad_correct = data.get('trad_correct') or {}

    # Kept as an ordered sequence: replacements are applied one after the
    # other, exactly in this order.
    substitutions = (
        ('[[question]]', data.get('question')),
        ('[[data_path]]', data.get('data_path')),
        ('[[data_overview]]', data.get('data_overview')),
        ('[[wrong_final_code]]', trad_wrong.get('final_code')),
        ('[[wrong_execution]]', trad_wrong.get('execution')),
        ('[[correct_final_code]]', trad_correct.get('final_code')),
        # Previously this placeholder was filled with the correct *code*
        # instead of the correct execution result.
        ('[[correct_execution]]', trad_correct.get('execution')),
    )

    prompt = prompt_template
    for placeholder, value in substitutions:
        prompt = prompt.replace(placeholder, value or '')

    data['case_study_prompt'] = prompt
    return data

def wrap_up_cs_bird_prompt(data, prompt_template):
    """Fill the BIRD contrastive case-study template for one record.

    Replaces ``[[question]]``, ``[[data_overview]]``,
    ``[[wrong_final_code]]`` and ``[[correct_final_code]]`` in
    ``prompt_template`` with the matching fields of ``data``. Missing,
    ``None`` or otherwise falsy fields are substituted with an empty string.

    Args:
        data: Record dict. Reads ``question``, ``data_overview`` and the
            nested dicts ``trad_wrong`` / ``trad_correct`` (``final_code``).
        prompt_template: Template string containing the placeholders.

    Returns:
        The same ``data`` dict, mutated in place with the rendered prompt
        stored under ``'case_study_prompt'``.
    """
    # `or {}` covers both a missing key and an explicit null, so the nested
    # `.get` calls below cannot fail on a str/None.
    trad_wrong = data.get('trad_wrong') or {}
    trad_correct = data.get('trad_correct') or {}

    # The former checks on `wrong_execution` / `correct_execution` read
    # names that were never assigned and raised UnboundLocalError on every
    # call; this function substitutes no execution placeholders, so they
    # are dropped.
    substitutions = (
        ('[[question]]', data.get('question')),
        ('[[data_overview]]', data.get('data_overview')),
        ('[[wrong_final_code]]', trad_wrong.get('final_code')),
        ('[[correct_final_code]]', trad_correct.get('final_code')),
    )

    prompt = prompt_template
    for placeholder, value in substitutions:
        prompt = prompt.replace(placeholder, value or '')

    data['case_study_prompt'] = prompt
    return data


def load_and_infer_from_jsonl_parallel(prompt_path, result_path, client, model, temperature=0.0, max_tokens=60, top_p=1, frequency_penalty=0, presence_penalty=0, stop=None, max_retries=10, num_threads=5):
    """Run the case-study prompt for every record of a JSONL file in parallel.

    Each non-blank line of ``prompt_path`` is parsed as JSON, a case-study
    prompt is rendered into it, and the prompt is sent through
    ``get_chat_response_azure`` from a thread pool. Every input record is
    written to ``result_path`` (one JSON object per line, in input order)
    with the model answer stored under ``'case_study_con'``. When a request
    raises or returns an empty response, a placeholder string is stored
    instead and, for exceptions, the error is reported on stderr.

    Args:
        prompt_path: Input ``.jsonl`` file.
        result_path: Output ``.jsonl`` file (overwritten).
        client: Client object forwarded to ``get_chat_response_azure``.
        model: Model / deployment name forwarded to the request helper.
        temperature, max_tokens, top_p, frequency_penalty, presence_penalty,
        stop, max_retries: Forwarded unchanged to the request helper.
        num_threads: ``max_workers`` of the thread pool.
    """
    fallback = 'missing how to optimize plan'

    with open(prompt_path, 'r', encoding='utf-8') as infile:
        # Blank lines (e.g. a trailing newline at EOF) are not valid JSON;
        # skip them instead of crashing the whole run.
        records = [json.loads(raw) for raw in infile if raw.strip()]

    # NOTE(review): the WikiTQ wrapper is paired with the BIRD template
    # (`llm_case_study_contrastive`) here -- kept as in the original; confirm
    # this combination is intended.
    records = [wrap_up_cs_wikitq_prompt(record, llm_case_study_contrastive) for record in records]

    def process_line(data, idx):
        """Query the model for one record and attach the answer to it."""
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": data.get('case_study_prompt', '')},
        ]
        try:
            response = get_chat_response_azure(client=client, model=model, messages=messages, temperature=temperature,
                                               max_tokens=max_tokens, top_p=top_p, frequency_penalty=frequency_penalty,
                                               presence_penalty=presence_penalty, stop=stop, max_retries=max_retries)
        except Exception as exc:
            # Best effort: one failed request must not abort the batch, but
            # the failure is surfaced rather than silently swallowed.
            print(f"request #{idx} failed: {exc!r}", file=sys.stderr)
            response = None

        data['case_study_con'] = response if response else fallback
        return data

    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        # Executor.map yields results in submission order, so the output
        # file lines up with the input file.
        answered = executor.map(process_line, records, range(len(records)))
        results = list(tqdm(answered, total=len(records), desc="running"))

    with open(result_path, 'w', encoding='utf-8') as outfile:
        for data in results:
            outfile.write(json.dumps(data, ensure_ascii=False) + '\n')

  
def inference():
    """Command-line entry point: parse arguments and run the batch inference.

    Builds an ``AzureOpenAI`` client from ``--api_key``, ``--api_version``
    and a base URL composed of ``--api_base`` and ``--deployment_name``,
    then calls ``load_and_infer_from_jsonl_parallel`` with the parsed
    sampling options. ``max_retries`` is fixed to 100 for this entry point.
    """
    parser = argparse.ArgumentParser(description='Call OpenAI API with specified parameters and configurations.')
    parser.add_argument('--deployment_name', type=str, required=True, help='Model name to use for the API call.')
    parser.add_argument('--temperature', type=float, default=0.0, help='Temperature for the response. Default is 0.0.')
    parser.add_argument('--max_tokens', type=int, default=60, help='Maximum number of tokens to generate. Default is 60.')
    parser.add_argument('--top_p', type=float, default=1, help='Top P value. Default is 1.')
    parser.add_argument('--frequency_penalty', type=float, default=0, help='Frequency penalty. Default is 0.')
    parser.add_argument('--presence_penalty', type=float, default=0, help='Presence penalty. Default is 0.')
    parser.add_argument('--stop', nargs='*', help='Stop sequence(s). Multiple values are allowed.')
    parser.add_argument('--api_key', type=str, required=True, help='OpenAI API key.')
    parser.add_argument('--api_base', type=str, default="https://api.openai.com", help='OpenAI API base URL. Default is the standard OpenAI API.')
    parser.add_argument('--api_version', type=str, default="v1", help='OpenAI API version. Default is "v1".')
    # Parsed for CLI compatibility; not read anywhere in this function.
    parser.add_argument('--api_type', type=str, default="azure", help='OpenAI API Type. Default is "Azure"')
    parser.add_argument('--prompt_path', type=str, required=True, help='Path to the input .jsonl file containing prompts.')
    parser.add_argument('--result_path', type=str, required=True, help='Path where the output .jsonl file with results will be saved.')
    # default=5 mirrors the default of load_and_infer_from_jsonl_parallel;
    # without it an omitted flag passed None and overrode that default.
    parser.add_argument('--num_threads', type=int, required=False, default=5, help='if your API could be run in parallel')

    args = parser.parse_args()

    client = AzureOpenAI(
        api_key=args.api_key,
        api_version=args.api_version,
        base_url=f"{args.api_base}/openai/deployments/{args.deployment_name}")

    print(args.num_threads)
    load_and_infer_from_jsonl_parallel(
        args.prompt_path,
        args.result_path,
        client,
        model=args.deployment_name,
        temperature=args.temperature,
        max_tokens=args.max_tokens,
        top_p=args.top_p,
        frequency_penalty=args.frequency_penalty,
        presence_penalty=args.presence_penalty,
        stop=args.stop,
        max_retries=100,
        num_threads=args.num_threads,
    )
 
  
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    inference()
