import pandas as pd
from openai import OpenAI
import os
from prompts import PAIRWISE_EVALUATION_PROMPT
import json
import random
import argparse

# Shared OpenAI client used for every judge request made by this script.
# Credentials come from the environment; os.environ.get yields None for unset variables.
model = OpenAI(
    organization=os.environ.get("OPENAI_ORGANIZATION"),
    api_key=os.environ.get("OPENAI_API_KEY"),
)

def pairwise_eval_helper(conversation, planA, planB, model_name="gpt-4o-mini", temperature=0):
    '''
    Ask the LLM judge to compare two plans for one conversation.

    Fills PAIRWISE_EVALUATION_PROMPT with the conversation and both plans,
    sends it through the module-level ``model`` client with a JSON-object
    response format, and returns the decoded reply.

    Args:
        conversation: Text substituted for ``{conversation}`` in the prompt.
        planA: Text substituted for ``{planA}`` in the prompt.
        planB: Text substituted for ``{planB}`` in the prompt.
        model_name: Chat model to query. Defaults to "gpt-4o-mini".
        temperature: Sampling temperature. Defaults to 0.

    Returns:
        dict: The decoded JSON reply. The system message asks for a "choice"
        key holding "A", "B" or "tie"; the value is NOT validated here.

    Raises:
        ValueError: If the model returned no message content.
        json.JSONDecodeError: If the reply is not valid JSON.
    '''
    user_prompt = PAIRWISE_EVALUATION_PROMPT.format(
        conversation=conversation,
        planA=planA,
        planB=planB,
    )
    model_config = {
        "model": model_name,
        "messages": [
            {
                "role": "system",
                # JSON mode (response_format below) requires the messages to mention JSON.
                "content": "Your answer should be in JSON format with keys \"choice\" : which is either \"A\" for Plan A or \"B\" for Plan B or \"tie\"",
            },
            {
                "role": "user",
                "content": user_prompt,
            },
        ],
        "temperature": temperature,
        "response_format": {"type": "json_object"},
    }
    response = model.chat.completions.create(**model_config)
    content = response.choices[0].message.content
    if content is None:
        # The SDK types message content as optional; fail with a clear error
        # instead of an opaque TypeError from json.loads(None).
        raise ValueError("Judge model returned no message content")
    return json.loads(content)

def pairwise_eval(conversation, plan_from_conv, plan_from_rewrite):
    '''
    Compare two plans with the LLM judge, presenting them in random order.

    Args:
        conversation: Conversation the plans were produced for.
        plan_from_conv: Plan derived from the conversation.
        plan_from_rewrite: Plan derived from the rewrite.

    Returns:
        str: "plan_from_rewrite", "plan_from_conv" or "tie".

    Raises:
        KeyError: If the judge's reply has no "choice" key.
        ValueError: If "choice" is not "A", "B" or "tie" (case-insensitive,
            surrounding whitespace ignored).
    '''
    # Mitigate positional bias: randomly decide which slot the rewrite plan occupies.
    if random.choice([True, False]):
        planA, planB, rewrite_assignment = plan_from_conv, plan_from_rewrite, "B"
    else:
        planA, planB, rewrite_assignment = plan_from_rewrite, plan_from_conv, "A"

    res = pairwise_eval_helper(conversation, planA=planA, planB=planB)
    choice = str(res["choice"]).strip()

    if choice.lower() == "tie":
        return "tie"

    letter = choice.upper()
    if letter not in ("A", "B"):
        # Reject unrecognized labels rather than counting them as a plan_from_conv win.
        raise ValueError(f"Unexpected judge choice: {res['choice']!r}")
    if letter == rewrite_assignment:
        return "plan_from_rewrite"
    return "plan_from_conv"

def main(input_file, output_file, eval_type, plan_type=None):
    df = pd.read_json(input_file)
    eval = []

    for i, row in df.iterrows():
        print(i)
        conversation = row['conversation']
        try:
            if eval_type == 'pairwise':
                plan_from_conv = row['plan_from_conversation']
                plan_from_rewrite = row['plan_from_rewrite']
                eval.append(pairwise_eval(conversation, plan_from_conv, plan_from_rewrite))
        except:
            eval.append({})
            
    if eval_type == "pairwise":
        print("Chose Plan_from_conv")
        print(eval.count("plan_from_conv")/len(df))
        print("Chose Plan_from_rewrite")
        print(eval.count("plan_from_rewrite")/len(df))
        print("Tied")
        print(eval.count("tie")/len(df))

    df['llm_pref'] = eval
    df.to_json(output_file, orient = 'records', indent = 2)

if __name__ == '__main__':
    # Command-line entry point: main(input_file, output_file, eval_type).
    parser = argparse.ArgumentParser(
        description="Run an LLM evaluation over a JSON dataset and write the labels back to JSON."
    )
    parser.add_argument('input_file', help="JSON file containing the rows to evaluate")
    parser.add_argument('output_file', help="where to write the annotated records (JSON)")
    # Restrict to supported modes so a typo is rejected before any work is done.
    parser.add_argument('eval_type', choices=['pairwise'], help="evaluation mode")
    args = parser.parse_args()

    main(input_file=args.input_file, output_file=args.output_file, eval_type=args.eval_type)