#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os  
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))  
import argparse  
import json  
from openai import AzureOpenAI  
from tqdm import tqdm  
import pdb  
import re  
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
import torch
  
# Update the import statements to use absolute imports  
from concurrent.futures import ThreadPoolExecutor, as_completed  
from model_api_call import get_chat_response_azure  
from functools import partial  
import sqlite3
import pandas as pd
import gc
import json
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
from code_utils import extract_code, create_message, extract_code_plan, mask_code_with_comments

def process_batch_data(data, batch_size):
    """Split ``data`` into consecutive chunks of at most ``batch_size`` items.

    Args:
        data: A sliceable sequence (e.g. a list of samples).
        batch_size: Maximum chunk length; must be >= 1.

    Returns:
        A list of slices of ``data`` in original order; the last chunk may be
        shorter.  An empty ``data`` yields ``[]``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1 (a negative step would
            otherwise silently produce no batches at all).
    """
    if batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size!r}')
    return [data[i:i + batch_size] for i in range(0, len(data), batch_size)]


def wrap_up_code_prompt(data, prompt_template):
    """Fill ``prompt_template`` from a sample and store the result on the sample.

    Placeholders are substituted in this order: ``[[question]]``,
    ``[[data_path]]``, ``[[data_overview]]``, ``[[data_title]]`` (from the
    ``table_title`` key) and ``[[successful_plan_suggestions]]`` (from
    ``rag_mg``).  Missing or ``None`` fields are substituted as empty strings;
    an empty or missing ``rag_mg`` falls back to a default sentence.

    Args:
        data: A sample dict; it is mutated in place.
        prompt_template: Template string containing the placeholders above.

    Returns:
        The same ``data`` dict with ``'rag_final_cot_prompt'`` set.
    """
    substitutions = (
        ('[[question]]', 'question'),
        ('[[data_path]]', 'data_path'),
        ('[[data_overview]]', 'data_overview'),
        ('[[data_title]]', 'table_title'),
    )
    prompt = prompt_template
    for placeholder, key in substitutions:
        # JSON ``null`` arrives as None, which str.replace rejects.
        value = data.get(key)
        prompt = prompt.replace(placeholder, '' if value is None else value)

    success_sug = data.get('rag_mg') or 'No suggestions, just follow your thoughts.'
    prompt = prompt.replace('[[successful_plan_suggestions]]', success_sug)

    data['rag_final_cot_prompt'] = prompt
    return data


def wrap_up_code_few_prompt(data):
    """Fill the sample's own few-shot template and store the resulting prompt.

    The template is read from ``data['rag_few_shot_prompt']``.  Its
    ``[[question]]``, ``[[data_path]]``, ``[[data_overview]]`` and
    ``[[data_title]]`` (from the ``table_title`` key) placeholders are
    substituted from ``data`` in that order.  Missing or ``None`` fields are
    substituted as empty strings.

    Args:
        data: A sample dict; it is mutated in place.

    Returns:
        The same ``data`` dict with ``'rag_final_cot_prompt'`` set.

    Raises:
        AttributeError: If ``'rag_few_shot_prompt'`` is missing or None.
    """
    substitutions = (
        ('[[question]]', 'question'),
        ('[[data_path]]', 'data_path'),
        ('[[data_overview]]', 'data_overview'),
        ('[[data_title]]', 'table_title'),
    )
    prompt = data.get('rag_few_shot_prompt')
    for placeholder, key in substitutions:
        # JSON ``null`` arrives as None, which str.replace rejects.
        value = data.get(key)
        prompt = prompt.replace(placeholder, '' if value is None else value)

    data['rag_final_cot_prompt'] = prompt
    return data


def wrap_up_code_ins_prompt(data, prompt_template_sketch, prompt_suggesitions):
    """Fill the instruction template from a sample and store the result on it.

    Placeholders are substituted in this order: ``[[question]]``,
    ``[[data_path]]``, ``[[data_overview]]``, ``[[data_title]]`` (from the
    ``table_title`` key) and finally ``[[successful_plan_suggestions]]`` with
    ``prompt_suggesitions``.  Missing or ``None`` sample fields are
    substituted as empty strings.

    Args:
        data: A sample dict; it is mutated in place.
        prompt_template_sketch: Template string containing the placeholders.
        prompt_suggesitions: Text inserted for ``[[successful_plan_suggestions]]``.

    Returns:
        The same ``data`` dict with ``'rag_final_cot_prompt'`` set.
    """
    substitutions = (
        ('[[question]]', 'question'),
        ('[[data_path]]', 'data_path'),
        ('[[data_overview]]', 'data_overview'),
        ('[[data_title]]', 'table_title'),
    )
    prompt = prompt_template_sketch
    for placeholder, key in substitutions:
        # JSON ``null`` arrives as None, which str.replace rejects.
        value = data.get(key)
        prompt = prompt.replace(placeholder, '' if value is None else value)
    prompt = prompt.replace('[[successful_plan_suggestions]]', prompt_suggesitions)

    data['rag_final_cot_prompt'] = prompt
    return data


def load_and_infer_from_jsonl_dataset_al(prompt_path, result_path, model_name, prompt_template, temperature=0.0, max_tokens=60, top_p=1, frequency_penalty=0, presence_penalty=0, stop=None, max_retries=10, batch_size=20, min_batch_size=4, num_gpus=1):
    """Generate code plans for every sample of a .jsonl file with a local vLLM model.

    Each non-blank input line is parsed as JSON.  Its prompt is built with
    ``wrap_up_code_ins_prompt`` and stored under ``'rag_final_cot_prompt'``.
    Samples are then sent to ``LLM.generate`` in batches.  The completion
    text, prefixed with ``"```code_plan"``, is parsed by ``extract_code_plan``
    and saved under ``'rag_final_cot'`` ('' when nothing is extracted).  Each
    sample, with all of its original keys, is written as one JSON line to
    ``result_path`` in input order.

    Args:
        prompt_path: Input .jsonl file (read as UTF-8; blank lines skipped).
        result_path: Output .jsonl file (overwritten, UTF-8).
        model_name: Model name/path passed to vLLM and ``AutoTokenizer``.
        prompt_template: ``(sketch_template, suggestions)`` pair forwarded to
            ``wrap_up_code_ins_prompt``.
        temperature, top_p, max_tokens, frequency_penalty, presence_penalty:
            Forwarded to ``SamplingParams``.
        stop: Optional stop string(s) forwarded to ``SamplingParams``.
        max_retries, min_batch_size: Unused; kept for interface compatibility.
        batch_size: Prompts per ``LLM.generate`` call; ``None`` means 20.
        num_gpus: ``tensor_parallel_size`` for vLLM.
    """
    prompt_template_sketch, sug_prompt = prompt_template
    if batch_size is None:
        # The CLI may pass None explicitly; fall back to this function's default.
        batch_size = 20

    with open(prompt_path, 'r', encoding='utf-8') as f:
        # Skip blank lines (e.g. a trailing newline) instead of failing in json.loads.
        dataset = [json.loads(line) for line in f if line.strip()]
    dataset = [wrap_up_code_ins_prompt(data, prompt_template_sketch, sug_prompt) for data in dataset]
    # Few-shot alternative:
    # dataset = [wrap_up_code_few_prompt(data) for data in dataset]
    batch_data = process_batch_data(dataset, batch_size=batch_size)

    llm = LLM(model=model_name, max_model_len=100000, gpu_memory_utilization=0.7, trust_remote_code=True, tensor_parallel_size=num_gpus)
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    # Some tokenizers define no EOS token; [None] would be an invalid id list.
    eos_token_id = tokenizer.eos_token_id
    stop_token_ids = [eos_token_id] if eos_token_id is not None else None

    sampling_params = SamplingParams(
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
        presence_penalty=float(presence_penalty),
        frequency_penalty=float(frequency_penalty),
        stop=stop,
        stop_token_ids=stop_token_ids,
    )

    with open(result_path, 'w', encoding='utf-8') as f:
        for idx, batch_sample in enumerate(tqdm(batch_data)):
            print(f'Infering {idx}th batch...', flush=True)
            prompts = [sample['rag_final_cot_prompt'] for sample in batch_sample]
            with torch.no_grad():
                completions = llm.generate(prompts, sampling_params)
            # LLM.generate returns outputs in the same order as the input prompts.
            for data_w, completion in zip(batch_sample, completions):
                # Presumably the prompt ends with a "```code_plan" fence opener;
                # re-prepend it so extract_code_plan sees a complete fence.
                cot = extract_code_plan("```code_plan" + completion.outputs[0].text)
                data_w['rag_final_cot'] = cot if cot else ''
                f.write(json.dumps(data_w, ensure_ascii=False) + '\n')
            torch.cuda.empty_cache()
            gc.collect()
                       
def inference():
    """Parse command-line arguments and run batched vLLM code-plan inference.

    Only the ``tabmwp`` dataset has prompt templates wired up.  Any other
    ``--dataset`` value exits via ``parser.error`` instead of failing later
    with a ``NameError``.
    """
    parser = argparse.ArgumentParser(description='Run batched vLLM inference with the specified parameters and configurations.')
    parser.add_argument('--deployment_name', type=str, required=True, help='Model name to use for the API call.')
    parser.add_argument('--temperature', type=float, default=0.0, help='Temperature for the response. Default is 0.0.')
    parser.add_argument('--max_tokens', type=int, default=60, help='Maximum number of tokens to generate. Default is 60.')
    parser.add_argument('--top_p', type=float, default=1, help='Top P value. Default is 1.')
    parser.add_argument('--frequency_penalty', type=float, default=0, help='Frequency penalty. Default is 0.')
    parser.add_argument('--presence_penalty', type=float, default=0, help='Presence penalty. Default is 0.')
    parser.add_argument('--stop', nargs='*', help='Stop sequence(s). Multiple values are allowed.')
    parser.add_argument('--prompt_path', type=str, required=True, help='Path to the input .jsonl file containing prompts.')
    parser.add_argument('--result_path', type=str, required=True, help='Path where the output .jsonl file with results will be saved.')
    # Without a default, None reached process_batch_data and crashed range().
    parser.add_argument('--batch_size', type=int, default=20, required=False, help='Number of prompts per generation batch. Default is 20.')
    parser.add_argument('--dataset', type=str, required=True, help='dataset name')
    parser.add_argument('--num_gpus', type=int, required=True, help='num of gpus')

    args = parser.parse_args()

    if args.dataset == "tabmwp":
        from prompts.tabmwp_prompt_combine.prompt.tabmwp_all import rag_llama_final_cot, llama_pos_summary_suggestions
        prompt_template = (rag_llama_final_cot, llama_pos_summary_suggestions)
    else:
        parser.error(f"unsupported --dataset {args.dataset!r}; only 'tabmwp' is available")

    load_and_infer_from_jsonl_dataset_al(
        args.prompt_path,
        args.result_path,
        model_name=args.deployment_name,
        prompt_template=prompt_template,
        temperature=args.temperature,
        max_tokens=args.max_tokens,
        top_p=args.top_p,
        frequency_penalty=args.frequency_penalty,
        presence_penalty=args.presence_penalty,
        stop=args.stop,
        batch_size=args.batch_size,
        num_gpus=args.num_gpus,
    )
      
  
if __name__ == '__main__':
    # Entry point: run the CLI only when executed as a script, not on import.
    inference()
