import json
import multiprocessing
import os
import sys
from argparse import ArgumentParser
from functools import partial
from typing import Any
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))

from tqdm import tqdm
from openai import OpenAI

from src.utils.io_utils import ensure_output_path_not_exists, save_evaluation_results
from src.utils.generation_utils import generate_text_chat
from src.utils.prompts_utils import EVAL_PROMPTS_FOR_OPENAUDIOBENCH
from src.utils.utils import get_next_level, OPENAUDIOBENCH_ATTRIBUTES, OPENAI_API_KEY, OPENAI_BASE_URL

"""
python eval/openaudiobench/evaluate_openaudiobench_gpt4.py --result_path path/to/model_outputs.jsonl
"""


# LLM client used for evaluation. It is constructed once when this module is
# imported and read as a module-level global by generate(); the API key and
# base URL are imported from src.utils.utils.
client = OpenAI(
    api_key=OPENAI_API_KEY,
    base_url=OPENAI_BASE_URL,
    # NOTE(review): custom header passed as a client default; its purpose is
    # not visible in this file -- confirm it is still required.
    default_headers={"x-foo": "true"},
)


def generate(item, template: str, dataset_name: str, attributes: Any):
    """Ask the judge model to evaluate one model output and attach its reply.

    Args:
        item: One record from the results file, indexed by string keys. It
            must contain ``infer_response`` and the keys named in
            ``attributes``.
        template: Evaluation prompt filled with ``str.format``. It receives
            ``question`` and ``answer``, plus ``gt_answer`` for every dataset
            except ``alpaca_eval``.
        dataset_name: Dataset identifier; only ``alpaca_eval`` is treated
            specially (no reference answer is passed to the template).
        attributes: Indexable collection of key names for ``item``: index 0 is
            the question key and index 1 the reference-answer key.

    Returns:
        The same ``item`` object, mutated in place with an ``eval_response``
        key holding the judge's reply with surrounding whitespace stripped
        (an empty string when the reply carries no text).
    """
    fields = {"question": item[attributes[0]], "answer": item['infer_response']}
    if dataset_name != 'alpaca_eval':
        # Every other dataset's template also compares against a reference answer.
        fields["gt_answer"] = item[attributes[1]]
    prompt = template.format(**fields)

    rtn = generate_text_chat(
        client=client,
        model="gpt-4o-2024-08-06",
        messages=[{"role": "user", "content": prompt}],
        max_tokens=8192,
        temperature=0.0
    )
    # The message content may be None (e.g. a refusal or filtered reply);
    # fall back to an empty string so one such reply does not abort the run.
    content = rtn.choices[0].message.content
    item['eval_response'] = (content or "").strip()
    return item


def main():
    """Evaluate a JSONL file of model outputs with the judge model and save the results.

    Command-line arguments:
        --result_path: Path to the JSONL file of model outputs (required).
        --num_workers: Number of worker processes used to query the judge
            model in parallel (default 4).

    The evaluated records are written next to the input file as
    ``<input stem>_gpt4_evaluated.jsonl``. The process exits with status 0
    without evaluating anything when the input file is missing or holds no
    records.
    """
    parser = ArgumentParser()
    # Required: without a path there is nothing to evaluate, and argparse
    # reports the omission with a usage message instead of a TypeError later.
    parser.add_argument('--result_path', type=str, required=True)
    parser.add_argument('--num_workers', type=int, default=4,
                        help='number of worker processes querying the judge model')

    args = parser.parse_args()
    result_path = args.result_path

    # load all results
    if not os.path.exists(result_path):
        print(f"{result_path} does not exist, exiting...")
        sys.exit(0)

    with open(result_path, 'r', encoding='utf-8') as file:
        # Skip blank lines (e.g. a trailing newline) that json.loads would reject.
        all_results = [json.loads(line) for line in file if line.strip()]

    if not all_results:
        print(f"{result_path} is empty, exiting...")
        sys.exit(0)

    # get evaluation template
    # NOTE(review): get_next_level presumably returns the path component that
    # follows "OpenAudioBench" in result_path -- confirm against src.utils.utils.
    dataset_name = get_next_level(result_path, "OpenAudioBench")
    TEMPLATE = EVAL_PROMPTS_FOR_OPENAUDIOBENCH[dataset_name]
    stem = os.path.splitext(os.path.basename(result_path))[0]
    filename = f"{stem}_gpt4_evaluated.jsonl"

    # ensure output path not exists
    eval_output_path = os.path.join(os.path.dirname(result_path), filename)
    ensure_output_path_not_exists(eval_output_path)

    # evaluate data
    attributes = OPENAUDIOBENCH_ATTRIBUTES[dataset_name]
    _generate = partial(generate, template=TEMPLATE, dataset_name=dataset_name, attributes=attributes)
    with multiprocessing.Pool(processes=args.num_workers) as pool:
        # imap yields results in input order, so tqdm tracks completed items.
        evaluated_results = list(tqdm(pool.imap(_generate, all_results), total=len(all_results)))

    # sort with uuid
    evaluated_results = sorted(evaluated_results, key=lambda x: x['uuid'])
    prompts = [item['infer_response'] for item in evaluated_results]
    outputs = [item['eval_response'] for item in evaluated_results]

    # save results
    # NOTE(review): the meaning of attributes[2] is defined by
    # save_evaluation_results, which is not visible here.
    save_evaluation_results(eval_output_path, evaluated_results, prompts, outputs, attributes[2])

    # sleep to inspect the results
    # NOTE(review): this blocks the process for roughly 11.5 days after a
    # web_questions run; it looks like a debugging aid -- confirm it is still wanted.
    if dataset_name == 'web_questions':
        print("Sleeping for 1000000 seconds...")
        import time
        time.sleep(1000000)


if __name__ == "__main__":
    main()
