from openai import OpenAI
import pandas as pd
import json
import os
from dotenv import load_dotenv
import argparse
import re

# Pull variables from a local .env file into os.environ; override=True lets the
# .env values take precedence over variables already set in the shell.
load_dotenv(override=True)
# None when OPENAI_API_KEY is not defined in the environment or the .env file.
API_KEY = os.environ.get("OPENAI_API_KEY")


def generate_task(caption_root, ds = 'S4', split='test', audio_prompt="generate audio caption", model="gpt-4o-mini"):
    """Build one single-caption Batch-API chat-completion request per CSV row.

    Reads ``<caption_root>/audio captions/<ds>/<split>/audio prompt = <audio_prompt>.csv``
    (columns ``name``, ``frame`` and ``text`` are used) and asks ``model`` to
    name the sound source(s) described by each row's ``text``.

    Args:
        caption_root: root directory containing the ``audio captions`` folder.
        ds: dataset sub-folder.
        split: split sub-folder.
        audio_prompt: prompt name that selects which caption CSV to read.
        model: chat model name placed in each request body.

    Returns:
        A list of request dicts (``custom_id``, ``method``, ``url``, ``body``).
        ``custom_id`` is ``task#<name>/<frame + 1>``.
    """
    # NOTE(review): this reads 'audio captions' (plural) while
    # generate_advanced_task reads 'audio caption' — confirm which exists on disk.
    csv_path = os.path.join(caption_root, 'audio captions', ds, split, f'audio prompt = {audio_prompt}.csv')
    captions = pd.read_csv(csv_path)

    # The instructions are identical for every row, so build the text once.
    system_prompt = """
You are participating in a competitive game where your goal is to identify the most likely abstract source(s) (e.g., human, instrumental, etc.) that is/are producing sound in a given audio clip.
The given frame caption corresponds to a video clip and may contain no sound-producing objects at all, or the caption text could provide misleading information.

Your task:
- Identify and output only the object(s) producing sound with the given caption text.
- Provide your guess in one line, (seperate by comma if multiple objects), enclosed in with <answer> and </answer> tag pair. (E.g., <answer>guitar</answer>).
"""

    def make_request(clip_name, frame_index, caption):
        # Frame indices are shifted by one in the id.
        task_id = clip_name + '/' + str(frame_index + 1)
        return {
            "custom_id": f"task#{task_id}",
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                # Same payload as a direct Chat Completions API call.
                "model": model,
                "messages": [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": caption},
                ],
            },
        }

    return [
        make_request(clip_name, frame_index, caption)
        for clip_name, frame_index, caption in zip(captions['name'], captions['frame'], captions['text'])
    ]


def generate_advanced_task(caption_root, ds = 'S4', split='test', prompt_consist=False, frame_consist=False, model="gpt-4o-mini"):
    """Build one multi-frame Batch-API chat-completion request per video clip.

    Reads the three caption CSVs produced with the audio prompts
    "generate audio caption", "generate metadata" and "this is a sound of"
    from ``<caption_root>/audio caption/<ds>/<split>/``, groups them by the
    ``name`` column and packs every frame of a clip into one user message made
    of ``<frameK>`` blocks, each holding the three captions as
    ``<exp1>``..``<exp3>``.

    Args:
        caption_root: root directory containing the ``audio caption`` folder.
        ds: dataset sub-folder.
        split: split sub-folder.
        prompt_consist: add the "Prompt consistency" instruction to the system prompt.
        frame_consist: add the "Frame consistency" instruction to the system prompt.
        model: chat model name placed in each request body.

    Returns:
        A list of request dicts. ``custom_id`` is ``task#<name>/<n>`` where
        ``n`` is the number of frames sent in that request; the advanced
        branch of ``save_answer`` looks for ``<answer_0>`` .. ``<answer_{n-1}>``.

    Raises:
        KeyError: if a clip name in the first CSV is missing from the others.
    """
    # NOTE(review): generate_task reads 'audio captions' (plural) while this
    # reads 'audio caption' — confirm which directory name exists on disk.
    caption_dir = os.path.join(caption_root, 'audio caption', ds, split)
    df1 = pd.read_csv(os.path.join(caption_dir, 'audio prompt = generate audio caption.csv'))
    df2 = pd.read_csv(os.path.join(caption_dir, 'audio prompt = generate metadata.csv'))
    df3 = pd.read_csv(os.path.join(caption_dir, 'audio prompt = this is a sound of.csv'))

    task_goal = ""
    if prompt_consist:
        task_goal += "\n- Analyze the outputs from all audio AIs in each frame together (i.e., Prompt consistency)."
    if frame_consist:
        task_goal += "\n- Consider the relationships among frames (i.e., Frame consistency)."

    # The system prompt depends only on the flags, so build it once for all clips.
    system_prompt = f"""
You are participating in a competitive game where your goal is to identify the most likely abstract source(s) (e.g., human, instrumental, etc.) that is/are producing sound in a given audio clip. This clip was broken down into several frames, each containing multiple audio outputs generated by different AIs, representing sounds at a specific timestamp.
Each frame corresponds to a different moment in the same video clip and some frames may contain no sound-producing objects at all, or the text output could provide misleading information.

Your task:{task_goal}
- Identify and output only the object(s) producing sound in each frame.
- For each frame (e.g., <frame0>...</frame0>), provide your guess in one line, (seperate by comma if multiple objects), enclosed in with <answer_x> and </answer_x> tag pair, where x is the frame id (e.g., <answer_0>...</answer_0>).
"""

    gp2 = df2.groupby('name')
    gp3 = df3.groupby('name')

    tasks = []
    for name, item1 in df1.groupby('name'):
        # groupby keeps the CSV row order; sort by frame so <frameK> and the
        # three captions zipped together refer to the same timestamp even if
        # a CSV is not stored in frame order.
        item1 = item1.sort_values('frame', kind='stable')
        item2 = gp2.get_group(name).sort_values('frame', kind='stable')
        item3 = gp3.get_group(name).sort_values('frame', kind='stable')

        # zip stops at the shortest of the three caption lists.
        frames = [
            f'<frame{frame_id}>\n'
            f' <exp1>{exp1}</exp1>\n'
            f' <exp2>{exp2}</exp2>\n'
            f' <exp3>{exp3}</exp3>\n'
            f'</frame{frame_id}>'
            for frame_id, (exp1, exp2, exp3) in enumerate(
                zip(item1['text'].values, item2['text'].values, item3['text'].values)
            )
        ]

        # Encode the number of frames actually sent, so the parser looks for
        # exactly one <answer_i> per <frameK> block.
        task_id = name + '/' + str(len(frames))
        tasks.append({
            "custom_id": f"task#{task_id}",
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                # Same payload as a direct Chat Completions API call.
                "model": model,
                "messages": [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": '\n'.join(frames)},
                ],
            },
        })

    return tasks


def write_task_to_jsonl(tasks, output_file):
    """Write each request dict in *tasks* as one JSON line to *output_file*.

    The file is truncated first; every line ends with a newline.
    """
    with open(output_file, 'w') as sink:
        sink.writelines(f"{json.dumps(task)}\n" for task in tasks)


def submit_task(task_file):
    """Upload a JSONL request file and start an OpenAI batch over it.

    Args:
        task_file: path to a JSONL file of Batch-API requests
            (as written by ``write_task_to_jsonl``).

    Returns:
        The id of the created batch.
    """
    client = OpenAI(api_key=API_KEY)

    # Upload inside ``with`` so the file handle is closed once the upload returns.
    with open(task_file, "rb") as handle:
        batch_input_file = client.files.create(file=handle, purpose="batch")

    batch = client.batches.create(
        input_file_id=batch_input_file.id,
        endpoint="/v1/chat/completions",
        completion_window="24h",
    )
    return batch.id


def retrive_result(file_api, result_file):
    """Download the output file of an OpenAI batch to *result_file*.

    Args:
        file_api: the batch id (as returned by ``submit_task``).
        result_file: local path the raw JSONL output bytes are written to.

    Raises:
        RuntimeError: if the batch has no output file, e.g. it is still
            running or every request in it failed.
    """
    client = OpenAI(api_key=API_KEY)
    batch = client.batches.retrieve(file_api)
    # output_file_id stays None until the batch produces output; fail with a
    # clear message instead of passing None on to the files endpoint.
    if batch.output_file_id is None:
        raise RuntimeError(f"Batch {file_api} has no output file (status: {batch.status})")
    result = client.files.content(batch.output_file_id).content

    with open(result_file, 'wb') as file:
        file.write(result)


def _save_simple_answers(result_jsonl, output_csv):
    """Write one CSV row per request from single-caption (``generate_task``) results."""
    df = pd.read_json(result_jsonl, lines=True)
    # Flatten the nested ``response`` dicts into dot-separated columns such as
    # ``body.choices`` and ``body.usage.prompt_tokens``.
    flat = pd.json_normalize(df["response"].tolist(), sep=".")
    df = pd.concat([df[["custom_id"]], flat], axis=1)

    # custom_id is "task#<name>/<frame>"; split on the last '/' so a name that
    # itself contains '/' stays intact.
    ids = df["custom_id"].str.split("task#", n=1).str[1].str.rsplit("/", n=1)
    content = df["body.choices"].str[0].str["message"].str["content"]

    df_res = pd.DataFrame({
        "name": ids.str[0],
        "frame": ids.str[1].astype(int),
        "object": content.str.extract(r"<answer>(.*?)</answer>", expand=False),
        "prompt_tokens": df["body.usage.prompt_tokens"],
        "completion_tokens": df["body.usage.completion_tokens"],
    })
    df_res.to_csv(output_csv, index=False)


def _save_advanced_answers(result_jsonl, output_csv):
    """Write one CSV row per frame from multi-frame (``generate_advanced_task``) results."""
    with open(result_jsonl, 'r') as file:
        # Skip blank lines (e.g. a trailing newline) instead of failing on them.
        results = [json.loads(line) for line in file if line.strip()]

    names, frames, objects = [], [], []
    prompt_tokens, completion_tokens = [], []
    missing_content_id = []

    for res in results:
        # Batch API output uses ``custom_id`` and nests the completion under
        # ``response.body``; also accept files that use ``task_id`` and an
        # un-nested ``response``.
        task_id = res['custom_id'] if 'custom_id' in res else res['task_id']
        response = res['response']
        body = response.get('body', response)
        answer = body['choices'][0]['message']['content']

        # Requests with no reply produce no rows; their ids are reported instead.
        if not answer:
            missing_content_id.append(task_id)
            continue

        # custom_id is "task#<name>/<n>" where n is the number of frames sent.
        name, count = task_id.split('task#', 1)[1].rsplit('/', 1)
        length = int(count)
        usage = body['usage']

        for i in range(length):
            found = re.findall(rf"<answer_{i}>(.*?)</answer_{i}>", answer)
            names.append(name)
            # NOTE(review): frames here are 0-based, while the simple path keeps
            # generate_task's frame + 1 — confirm which convention consumers expect.
            frames.append(i)
            objects.append(",".join(found))
            # Usage is per request, not per frame: record it on frame 0 only so
            # column sums still equal the real token totals.
            prompt_tokens.append(usage['prompt_tokens'] if i == 0 else 0)
            completion_tokens.append(usage['completion_tokens'] if i == 0 else 0)

    pd.DataFrame({
        'name': names,
        'frame': frames,
        'object': objects,
        'prompt_tokens': prompt_tokens,
        'completion_tokens': completion_tokens,
    }).to_csv(output_csv, index=False)

    with open("missing_content_id.txt", 'w') as file:
        file.write("\n".join(missing_content_id))


def save_answer(result_josnl, ouput_csv, advanced=False):
    """Parse an OpenAI batch result file and write the extracted answers to CSV.

    Args:
        result_josnl: path to the result JSONL, one response object per line.
        ouput_csv: path of the CSV to write (columns name, frame, object,
            prompt_tokens, completion_tokens; no index column).
        advanced: False for ``generate_task`` results (one ``<answer>`` tag per
            request, one row per request). True for ``generate_advanced_task``
            results: one ``<answer_i>`` per frame, one row per frame, and the
            ids of requests with an empty reply are written to
            ``missing_content_id.txt`` in the working directory.
    """
    if advanced:
        _save_advanced_answers(result_josnl, ouput_csv)
    else:
        _save_simple_answers(result_josnl, ouput_csv)


def parse_args():
    """Parse the command-line options of this script.

    String options (all ``type=str``): caption_root, audio_prompt, dataset,
    split, output_root, model. Boolean flags (``store_true``): prompt_consist,
    frame_consist, submit.
    """
    parser = argparse.ArgumentParser()

    string_options = (
        ('--caption_root', 'output'),
        ('--audio_prompt', 'generate audio caption'),
        ('--dataset', 'S4'),
        ('--split', 'train'),
        ('--output_root', 'translated'),
        ('--model', 'gpt-4o-mini'),
    )
    for flag, default in string_options:
        parser.add_argument(flag, type=str, default=default)

    for flag in ('--prompt_consist', '--frame_consist', '--submit'):
        parser.add_argument(flag, action='store_true')

    return parser.parse_args()


def translate_submit():
    """Build the requests selected by the CLI flags, submit them as a batch and save the batch id.

    Reads the module-level ``args`` (from ``parse_args``) and ``out_root``
    (set in the ``__main__`` block). Writes ``submit_batch - <stem>.jsonl``
    (the requests) and ``submit_batch - <stem>.batch`` (the batch id) into
    ``out_root``.
    """
    if not args.prompt_consist and not args.frame_consist:
        tasks = generate_task(args.caption_root, args.dataset, args.split, args.audio_prompt, args.model)
        stem = f"{args.audio_prompt} - {args.model}"
    else:
        tasks = generate_advanced_task(args.caption_root, args.dataset, args.split, args.prompt_consist, args.frame_consist, args.model)
        # Format the bool flags via the f-string; '+' on str and bool raises TypeError.
        stem = f"prompt_consist={args.prompt_consist} - frame_consist={args.frame_consist} - {args.model}"

    task_jsonl = os.path.join(out_root, f"submit_batch - {stem}.jsonl")
    write_task_to_jsonl(tasks, task_jsonl)
    # write_task_to_jsonl only writes the file; the upload returns the batch id.
    batch_id = submit_task(task_jsonl)
    print(f"Submitted batch {batch_id}")

    # Save the batch id so translation_fetch can find it later.
    with open(os.path.join(out_root, f"submit_batch - {stem}.batch"), 'w') as f:
        f.write(batch_id)


def translation_fetch():
    """Fetch the results of the previously submitted batch and convert them to CSV.

    Reads the module-level ``args`` (from ``parse_args``) and ``out_root``
    (set in the ``__main__`` block). The batch id is read from
    ``submit_batch - <stem>.batch`` in ``out_root``, which is where
    ``translate_submit`` writes it. The downloaded ``result - <stem>.jsonl`` and
    the parsed ``result - <stem>.csv`` are written to the same directory.
    """
    advanced = args.prompt_consist or args.frame_consist
    if advanced:
        # Format the bool flags via the f-string; '+' on str and bool raises TypeError.
        stem = f"prompt_consist={args.prompt_consist} - frame_consist={args.frame_consist} - {args.model}"
    else:
        stem = f"{args.audio_prompt} - {args.model}"
    batch_name = os.path.join(out_root, f"submit_batch - {stem}.batch")
    result_jsonl = os.path.join(out_root, f"result - {stem}.jsonl")
    result_csv = os.path.join(out_root, f"result - {stem}.csv")

    try:
        with open(batch_name, 'r') as f:
            # Strip stray whitespace/newlines so the id is usable in API calls.
            batch_id = f.read().strip()
    except FileNotFoundError:
        print('No batch id found, submit first')
        return

    # Check the status explicitly so real errors (e.g. while parsing results)
    # surface instead of being reported as an unfinished batch.
    client = OpenAI(api_key=API_KEY)
    batch = client.batches.retrieve(batch_id)
    if batch.status in ("failed", "expired", "cancelled"):
        print(batch)
        print(f"Batch {batch_id} ended with status '{batch.status}', please resubmit")
        return
    if batch.status != "completed":
        print(batch)
        print('Batch not finished, please wait')
        return

    retrive_result(batch_id, result_jsonl)
    # Multi-frame results use <answer_i> tags and need the advanced parser.
    save_answer(result_jsonl, result_csv, advanced=advanced)


if __name__ == '__main__':
    # translate_submit and translation_fetch read these two module globals.
    args = parse_args()

    out_root = os.path.join(args.caption_root, args.output_root, args.dataset, args.split)
    os.makedirs(out_root, exist_ok=True)

    # --submit uploads a new batch; otherwise fetch the results of the last one.
    # An if/else avoids the site-provided exit(), which is meant for the
    # interactive shell and is absent when Python runs with -S.
    if args.submit:
        translate_submit()
    else:
        translation_fetch()
