#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os  
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))  
import json
import time
import argparse
import pickle
from openai import AzureOpenAI  
from prompts.bird_prompt_combine.prompt.bird_all import llm_pos_opt_sum_body, llm_pos_opt_sum_final
# from bird_prompt_combine.new.bird_baseline import llm_neg_opt_sum_body, llm_neg_opt_sum_final
import pdb
from tqdm import tqdm

# Replace with your Azure OpenAI client initialization
# from azure_openai import AzureOpenAIClient
# client = AzureOpenAIClient(api_key='your_azure_api_key')

# Send a chat-completion request through the client, retrying when the call raises
def get_chat_response_azure(client, model, messages, temperature, max_tokens, top_p, frequency_penalty, presence_penalty, stop=None, max_retries=10):
    """Request a chat completion, retrying whenever the call raises.

    Args:
        client: Object exposing ``chat.completions.create(...)``.
        model: Model name forwarded to the API.
        messages: Chat messages forwarded to the API.
        temperature: Forwarded unchanged.
        max_tokens: Forwarded unchanged.
        top_p: Forwarded unchanged.
        frequency_penalty: Forwarded unchanged.
        presence_penalty: Forwarded unchanged.
        stop: Optional stop sequence(s), forwarded unchanged.
        max_retries: Maximum number of attempts before giving up.

    Returns:
        The content of the first choice on success. ``None`` when the error
        text contains both "400" and "repetitive patterns" (no further
        retries), or when ``max_retries`` allows no attempt at all. Once every
        attempt has raised, a "Failed after ..." string is returned instead
        of raising -- note that this string is truthy.
    """
    request_kwargs = {
        "model": model,
        "messages": messages,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "top_p": top_p,
        "frequency_penalty": frequency_penalty,
        "presence_penalty": presence_penalty,
        "stop": stop,
    }
    failures = 0
    while failures < max_retries:
        try:
            completion = client.chat.completions.create(**request_kwargs)
            return completion.choices[0].message.content
        except Exception as e:
            failures += 1
            reason = str(e)
            # This particular 400 is treated as non-retryable: give up at once.
            if "400" in reason and "repetitive patterns" in reason:
                return None
            print(f"Attempt {failures} failed: {e}")
            if failures >= max_retries:
                return f"Failed after {max_retries} attempts: {e}"
            # Fixed pause before the next attempt.
            time.sleep(10)

# Read a JSONL file and extract the 'case_study' (or 'case_study_con') field from each record
def read_jsonl_file(file_path, opinion='pos'):
    """Collect one case-study field from every record of a JSON Lines file.

    Args:
        file_path: Path to a file holding one JSON value per line.
        opinion: ``"pos"`` selects the ``"case_study"`` key; any other value
            selects ``"case_study_con"``.

    Returns:
        The selected field's values in file order. Records that do not
        contain the key are skipped.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    field = "case_study" if opinion == "pos" else "case_study_con"
    values = []
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            # Blank lines (e.g. a trailing newline at end of file) are not
            # records; json.loads would raise on them.
            if not line.strip():
                continue
            data = json.loads(line)
            if field in data:
                values.append(data[field])
    return values

# Divide data into batches
def divide_into_batches(data, batch_size):
    """Lazily yield consecutive slices of ``data``, each at most ``batch_size`` long.

    The final slice is shorter when ``len(data)`` is not a multiple of
    ``batch_size``; an empty ``data`` yields nothing.
    """
    offsets = range(0, len(data), batch_size)
    yield from (data[begin:begin + batch_size] for begin in offsets)

# Process individual batches and return results
def process_batches(file_path, client, model, batch_prompt, batch_size, temperature=0.7, max_tokens=100, top_p=1, frequency_penalty=0, presence_penalty=0, stop=None, opinion="pos"):
    """Send the case studies of a JSONL file to the model one batch at a time.

    Args:
        file_path: JSONL file passed to ``read_jsonl_file``.
        client: Chat client passed to ``get_chat_response_azure``.
        model: Model name passed to ``get_chat_response_azure``.
        batch_prompt: Prompt template; its ``[[case_study_batch]]`` placeholder
            is replaced by the labeled cases of each batch.
        batch_size: Number of case studies per request.
        temperature, max_tokens, top_p, frequency_penalty, presence_penalty,
        stop: Forwarded unchanged to ``get_chat_response_azure``.
        opinion: Forwarded to ``read_jsonl_file`` to pick the field to read.

    Returns:
        List of the truthy responses, in batch order. Falsy responses
        (``None`` or an empty string) are dropped.
    """
    case_studies = read_jsonl_file(file_path, opinion)
    chunks = list(divide_into_batches(case_studies, batch_size))
    summaries = []

    for chunk_no, chunk in enumerate(tqdm(chunks, desc="Processing batches")):
        # Case numbers are 1-based and keep counting across batches.
        first_case_no = chunk_no * batch_size + 1
        numbered_cases = [
            f"Case {case_no}:\n{case_text}"
            for case_no, case_text in enumerate(chunk, start=first_case_no)
        ]
        prompt_string = batch_prompt.replace('[[case_study_batch]]', '\n\n'.join(numbered_cases))

        system_message = {"role": "system", "content": "You are a helpful assistant."}
        user_message = {"role": "user", "content": prompt_string}
        reply = get_chat_response_azure(
            client=client,
            model=model,
            messages=[system_message, user_message],
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            stop=stop,
        )

        # Echo the exchange to stdout.
        print("=============================== prompt: ===============================")
        print(prompt_string)
        print("=============================== response: ===============================")
        print(reply)

        if not reply:
            continue
        summaries.append(reply)

    return summaries

# Unify results from individual batches
def unify_results(batch_results, client, model, prompt_template, temperature=0.7, max_tokens=100, top_p=1, frequency_penalty=0, presence_penalty=0, stop=None):
    """Merge the per-batch responses into one final response with a single request.

    Args:
        batch_results: Responses collected from the individual batches.
        client: Chat client passed to ``get_chat_response_azure``.
        model: Model name passed to ``get_chat_response_azure``.
        prompt_template: Prompt template; its ``[[batched_sum]]`` placeholder is
            replaced by the labeled batch responses.
        temperature, max_tokens, top_p, frequency_penalty, presence_penalty,
        stop: Forwarded unchanged to ``get_chat_response_azure``.

    Returns:
        Whatever ``get_chat_response_azure`` returns for the unified prompt, or
        a fixed notice string when ``batch_results`` is empty or ``None``.
    """
    # Nothing to merge: skip the request entirely.
    if not batch_results:
        return "No valid results from batches to unify."

    # Entries are labeled with their 0-based position in batch_results.
    numbered_summaries = []
    for position, summary in enumerate(batch_results):
        numbered_summaries.append(f"Case {position}:\n{summary}")
    unified_prompt = prompt_template.replace('[[batched_sum]]', '\n\n'.join(numbered_summaries))

    system_message = {"role": "system", "content": "You are a helpful assistant."}
    user_message = {"role": "user", "content": unified_prompt}
    return get_chat_response_azure(
        client=client,
        model=model,
        messages=[system_message, user_message],
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
        stop=stop,
    )

# Save batch results to a pickle file
def save_results(file_path, results):
    """Pickle ``results`` into the file at ``file_path``, replacing any existing content."""
    with open(file_path, mode='wb') as sink:
        pickle.Pickler(sink).dump(results)

# Load batch results from a pickle file
def load_results(file_path):
    """Return the object unpickled from the file at ``file_path``.

    NOTE: unpickling can execute arbitrary code, so only load files that this
    program (or another trusted source) wrote.
    """
    with open(file_path, mode='rb') as source:
        restored = pickle.Unpickler(source).load()
    return restored

# Entry point with argument parsing
def main():
    """Command-line entry point: summarize case studies in batches, then unify.

    Steps:
        1. Parse the command-line arguments.
        2. Pick the batch ("body") and final prompt templates for ``--dataset``.
        3. Create the ``AzureOpenAI`` client.
        4. Summarize the case studies in batches of 20 and pickle the batch
           results to ``--pickle_file``.
        5. Reload the pickled results, unify them with one more request, and
           print the final output.

    Exits with status 2 (via ``parser.error``) when ``--dataset`` is not one of
    the datasets that have summary prompts wired up here, instead of failing
    later with a ``NameError`` on the unbound prompt variables.
    """
    parser = argparse.ArgumentParser(description='Call OpenAI API with specified parameters and configurations.')
    parser.add_argument('--deployment_name', type=str, required=True, help='Model name to use for the API call.')
    parser.add_argument('--temperature', type=float, default=0.0, help='Temperature for the response. Default is 0.0.')
    parser.add_argument('--max_tokens', type=int, default=60, help='Maximum number of tokens to generate. Default is 60.')
    parser.add_argument('--top_p', type=float, default=1, help='Top P value. Default is 1.')
    parser.add_argument('--frequency_penalty', type=float, default=0, help='Frequency penalty. Default is 0.')
    parser.add_argument('--presence_penalty', type=float, default=0, help='Presence penalty. Default is 0.')
    parser.add_argument('--stop', nargs='*', help='Stop sequence(s). Multiple values are allowed.')
    parser.add_argument('--api_key', type=str, required=True, help='OpenAI API key.')
    parser.add_argument('--api_base', type=str, default="https://api.openai.com", help='OpenAI API base URL. Default is the standard OpenAI API.')
    parser.add_argument('--api_version', type=str, default="v1", help='OpenAI API version. Default is "v1".')
    parser.add_argument('--api_type', type=str, default="azure", help='OpenAI API Type. Default is "Azure"')
    parser.add_argument('--prompt_path', type=str, required=True, help='Path to the input .jsonl file containing prompts.')
    parser.add_argument('--result_path', type=str, required=True, help='Path where the output .jsonl file with results will be saved.')
    parser.add_argument('--pickle_file', type=str, required=True, help='sved intermediate batch result')
    parser.add_argument('--num_threads', type=int, required=False, help='if your API could be run in parallel')
    parser.add_argument('--dataset', type=str, required=False, help='Dataset type for loading specific prompts.')
    # The help text used to be a copy of the --dataset help; it now describes
    # what the flag selects in read_jsonl_file.
    parser.add_argument('--opinion', type=str, default="pos", required=False, help='Which case-study field to read: "pos" reads "case_study", any other value reads "case_study_con". Default is "pos".')

    args = parser.parse_args()

    # NOTE(review): --result_path, --api_type and --num_threads are parsed but
    # not used anywhere in this function; the final output is only printed.

    # Prompt modules are imported lazily so that only the selected dataset's
    # module has to be importable.
    # NOTE(review): the "neg" prompt templates are loaded whatever --opinion
    # is -- confirm that this is intended.
    if args.dataset == "wikitq":
        from prompts.wikitq_prompt_combine.prompt.wikitq_all import llm_neg_opt_sum_body, llm_neg_opt_sum_final
        body_prompt, final_prompt = llm_neg_opt_sum_body, llm_neg_opt_sum_final
    elif args.dataset == "bird":
        from prompts.bird_prompt_combine.prompt.bird_all import llm_neg_opt_sum_body, llm_neg_opt_sum_final
        body_prompt, final_prompt = llm_neg_opt_sum_body, llm_neg_opt_sum_final
    else:
        # Any other value (including "tabmwp" and an omitted --dataset) never
        # bound body_prompt/final_prompt, so fail fast before any API call.
        parser.error(
            f"--dataset must be 'wikitq' or 'bird' to select the summary prompts (got {args.dataset!r})"
        )

    client = AzureOpenAI(
        api_key=args.api_key,
        api_version=args.api_version,
        base_url=f"{args.api_base}/openai/deployments/{args.deployment_name}")

    # Process batches
    batch_results = process_batches(
        file_path=args.prompt_path,
        client=client,
        model=args.deployment_name,
        batch_prompt=body_prompt,
        batch_size=20,
        temperature=args.temperature,
        max_tokens=args.max_tokens,
        top_p=args.top_p,
        frequency_penalty=args.frequency_penalty,
        presence_penalty=args.presence_penalty,
        stop=args.stop,
        # Forward the flag; previously it was parsed but silently ignored.
        opinion=args.opinion,
    )

    # Persist the intermediate batch results, then read them back so the
    # unify step consumes exactly what was written to disk.
    save_results(args.pickle_file, batch_results)
    loaded_results = load_results(args.pickle_file)

    # Unify results
    final_output = unify_results(
        batch_results=loaded_results,
        client=client,
        model=args.deployment_name,
        prompt_template=final_prompt,
        temperature=args.temperature,
        max_tokens=args.max_tokens,
        top_p=args.top_p,
        frequency_penalty=args.frequency_penalty,
        presence_penalty=args.presence_penalty,
        stop=args.stop
    )

    print(final_output)

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
