import pandas as pd
from openai import OpenAI
import os
from prompts import PLAN_PROMPT
from llm_plan import LLMPlan
from tqdm import tqdm

# OpenAI chat model that Planner sends its plan-generation requests to.
MODEL_NAME = "gpt-4o"

class Planner:
    """Thin wrapper around the OpenAI chat API that turns text into a plan.

    Each request uses ``PLAN_PROMPT`` (from ``prompts``) as the user message
    and ``LLMPlan`` (from ``llm_plan``) as the structured ``response_format``.
    """

    def __init__(self, model_name=None):
        """Create the OpenAI client used for all requests.

        Args:
            model_name: Chat model to query. When omitted (``None``), the
                module-level ``MODEL_NAME`` is used.
        """
        # os.environ.get returns None for unset variables; that None is passed
        # straight through to the client constructor.
        self.model = OpenAI(
            api_key=os.environ.get("OPENAI_API_KEY"),
            organization=os.environ.get("OPENAI_ORGANIZATION"),
        )
        # Resolved here rather than as a default argument so the module-level
        # constant is looked up at construction time, not at class definition.
        self.model_name = MODEL_NAME if model_name is None else model_name

    def plan(self, input):
        """Ask the model for a structured plan for ``input``.

        Works for any text the caller wants planned (``main`` uses it for both
        conversations and rewritten prompts). ``input`` is substituted into
        ``PLAN_PROMPT`` with ``str.format`` and sent as a single user message
        at temperature 0, with ``LLMPlan`` as the structured response format.

        Expected plan shape, per the original note (confirm against LLMPlan):
            {
                'nodes': List[Dict]   # each with 'id' (node id), 'name' (node task)
                'edges': List[int]    # NOTE(review): element type unverified
            }

        Args:
            input: Text to plan for. The name shadows the builtin ``input`` but
                is kept so keyword callers (``plan(input=...)``) keep working.

        Returns:
            ``response.choices[0].message.content`` from the parse call,
            presumably the plan serialized as a JSON string. This can be
            ``None`` when the message carries no content (e.g. a refusal);
            TODO confirm against the SDK.
        """
        user_prompt = PLAN_PROMPT.format(input=input)
        model_config = {
            "model": self.model_name,
            "messages": [
                {
                    "role": "user",
                    "content": user_prompt,
                }
            ],
            # Deterministic-as-possible sampling so reruns give comparable plans.
            "temperature": 0,
            "response_format": LLMPlan,
        }
        response = self.model.beta.chat.completions.parse(**model_config)
        return response.choices[0].message.content

def main(input_file, output_file, generate_plan_from_conv=True):
    """Generate plans for every record in ``input_file`` and save them.

    Reads ``input_file`` with ``pd.read_json`` and asks ``Planner`` for a plan
    for each row's ``rewritten`` value, and optionally for its
    ``conversation`` value. Results go into new ``plan_from_rewrite`` and
    ``plan_from_conversation`` columns. The frame is then written to
    ``output_file`` as an indented JSON array of records.

    Args:
        input_file: Path or other source accepted by ``pd.read_json``.
        output_file: Destination path for the JSON output.
        generate_plan_from_conv: Set to False once conversation plans already
            exist, to skip those requests. The output then keeps whatever
            ``plan_from_conversation`` column (if any) ``input_file`` had.

    Raises:
        KeyError: If the input has at least one row but lacks a column that
            will be read (``rewritten``, plus ``conversation`` when
            ``generate_plan_from_conv`` is True).
    """
    df = pd.read_json(input_file)

    # Columns the loop below reads.
    required = ['rewritten']
    if generate_plan_from_conv:
        required.append('conversation')
    missing = [col for col in required if col not in df.columns]
    # Fail before any API request is issued. With at least one row the loop
    # would raise the same KeyError anyway, but only after spending requests.
    # An empty frame never reads a column, so it is left alone.
    if missing and len(df) > 0:
        raise KeyError(f"input is missing required column(s): {missing}")

    planner = Planner()
    plan_from_conversation = []
    plan_from_rewrite = []

    # Single pass over the rows; the row index from iterrows() is not needed.
    for _, row in tqdm(df.iterrows(), total=len(df)):
        if generate_plan_from_conv:
            plan_from_conversation.append(planner.plan(row['conversation']))
        plan_from_rewrite.append(planner.plan(row['rewritten']))

    if generate_plan_from_conv:
        df['plan_from_conversation'] = plan_from_conversation
    df['plan_from_rewrite'] = plan_from_rewrite
    df.to_json(output_file, orient='records', indent=2)

if __name__ == '__main__':
    # Script entry point: read rewrites and write plans for both the original
    # conversation and the rewrite. Paths are relative to the working directory.
    main(
        "data/rewrite.json",
        "data/plans.json",
        generate_plan_from_conv=True,
    )