import pandas as pd
import tqdm
from openai import OpenAI
import os
from prompts import REWRITE_PROMPTS
from tqdm import tqdm
import ast

MODEL_NAME = "gpt-4o"

# Shared OpenAI client used by rewrite_conversation(). Credentials are read
# from the environment; os.getenv yields None for any variable that is unset.
client = OpenAI(
    api_key=os.getenv("OPENAI_API_KEY"),
    organization=os.getenv("OPENAI_ORGANIZATION"),
)

def rewrite_conversation(conversation, prompt_template, rtype):
    '''
    Ask the chat model to rewrite a conversation as a planner instruction.

    Args:
        conversation: Value substituted for the ``{conversation}`` field of
            ``prompt_template`` via ``str.format`` (so its ``str()`` form is
            what the model sees).
        prompt_template: Format string containing a ``{conversation}`` field.
        rtype: Name of the rewrite type; accepted for the caller's
            convenience but not used by this function.

    Returns:
        The ``message.content`` of the first choice in the completion,
        requested deterministically (temperature 0) from ``MODEL_NAME``.
    '''
    filled_prompt = prompt_template.format(conversation=conversation)
    system_instruction = (
        "Your output should be in tone of a instruction-style sentence which will "
        "be used as input to a planner module. eg. Build a..., design a..., "
        "plan a..., etc. Do not use first person nouns or pronouns."
    )
    chat_messages = [
        {"role": "system", "content": system_instruction},
        {"role": "user", "content": filled_prompt},
    ]
    completion = client.chat.completions.create(
        model=MODEL_NAME,
        messages=chat_messages,
        temperature=0,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content

def _parse_conversation(raw):
    '''
    Turn a serialized conversation cell into a Python object.

    Strings are parsed with ``ast.literal_eval`` (Python literals only, no
    code execution). Values that are not strings, e.g. a list the reader
    already decoded, are returned unchanged instead of crashing literal_eval.
    '''
    if isinstance(raw, str):
        return ast.literal_eval(raw)
    return raw


def _rewrite_all_types(conv_id, conversation):
    '''
    Return one output record per rewrite type in REWRITE_PROMPTS.

    The "dummy" type is the conversation's elements joined with newlines, so
    no API call is made (this assumes the elements are strings). Every other
    type is produced by rewrite_conversation(). A failing request is reported
    and recorded as an empty rewrite, so that one bad call does not abort the
    whole batch.
    '''
    records = []
    for rtype in REWRITE_PROMPTS:
        if rtype == "dummy":
            rewrite = "\n".join(conversation)
        else:
            prompt_template = REWRITE_PROMPTS[rtype]
            try:
                rewrite = rewrite_conversation(conversation, prompt_template, rtype)
            except Exception as e:  # deliberate best-effort at the API boundary
                # tqdm.write keeps the progress bar intact, unlike print().
                tqdm.write(f"ERROR: conv_id={conv_id} rewrite_type={rtype}: {e}")
                rewrite = ""
        records.append({
            "conv_id": conv_id,
            "rewrite_type": rtype,
            "rewrite": rewrite,
        })
    return records


def _write_rows(rows, output_file):
    '''
    Write the records to output_file as an indented JSON array.

    The explicit columns keep the schema stable and make an empty run write
    "[]" rather than nothing.
    '''
    frame = pd.DataFrame(rows, columns=["conv_id", "rewrite_type", "rewrite"])
    frame.to_json(output_file, orient='records', indent=2)


def main(input_file, output_file, *, checkpoint_every=1):
    '''
    Generate every rewrite type listed in REWRITE_PROMPTS for each conversation.

    Args:
        input_file: Table read with ``pd.read_csv``. It must have a
            ``conv_id`` column and a ``conversation`` column whose cells are
            Python-literal strings (parsed with ``ast.literal_eval``).
            NOTE(review): the default path in __main__ ends in ``.json`` but
            is parsed as CSV. Confirm that the file really is CSV-formatted.
        output_file: Destination for a JSON array of
            ``{"conv_id", "rewrite_type", "rewrite"}`` records.
        checkpoint_every: Rewrite the output file after every N
            conversations, so that progress survives a crash midway through
            a long run of API calls. The default of 1 matches the original
            per-conversation behaviour. Use a larger N to avoid rewriting the
            whole file each time, or 0/None to write only once at the end.

    The output file is always written once after the loop, so it exists
    (as "[]") even when the input has no rows.
    '''
    df = pd.read_csv(input_file)
    df['conversation'] = df['conversation'].apply(_parse_conversation)

    rows = []
    # Zip the two needed columns instead of iterrows(): faster, and it avoids
    # the per-row Series dtype coercion of conv_id.
    pairs = zip(df['conv_id'], df['conversation'])
    for n_done, (conv_id, conversation) in enumerate(tqdm(pairs, total=len(df)), start=1):
        rows.extend(_rewrite_all_types(conv_id, conversation))
        if checkpoint_every and n_done % checkpoint_every == 0:
            _write_rows(rows, output_file)

    # Final flush: covers empty input and rows produced since the last checkpoint.
    _write_rows(rows, output_file)

if __name__ == '__main__':
    # NOTE(review): the input path has a .json extension, but main() parses it
    # with pd.read_csv. Confirm that the file is really CSV-formatted.
    main(
        input_file="data/synthetic_conversations.json",
        output_file="data/rewrite.json",
    )