import pandas as pd
import json
import os
from openai import OpenAI

# Initialize OpenAI client.
# The key comes from the OPENAI_API_KEY environment variable. If it is unset,
# api_key is None and the SDK raises a clear configuration error up front,
# rather than sending a placeholder string as the credential and failing
# authentication on every single request.
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))


def run_rating(prompt, response, *, model="gpt-4"):
    """Ask an LLM judge to score *response* as a summary of *prompt*.

    The judge is instructed to return a JSON object with a single
    ``overall_score`` field on a 1-10 scale covering coherence, accuracy,
    and coverage.

    Args:
        prompt: The original text that was summarized.
        response: The candidate summary to rate.
        model: Chat model used as the judge (keyword-only; defaults to
            "gpt-4", the model this script has always used).

    Returns:
        A ``(raw_output, score)`` tuple. ``raw_output`` is the judge's raw
        reply text, or ``None`` if the request itself failed. ``score`` is
        the parsed ``overall_score`` (an int or float), or ``None`` if the
        request failed or the reply could not be parsed into a number.
        Errors are printed rather than raised, so one bad row does not
        abort a batch run.
    """
    evaluation_prompt = f"""
    You are an expert rater of summarization quality. Your role is to provide a SINGLE OVERALL SCORE from 1 to 10 that rates the coherence, accuracy, and coverage of the SUMMARY for a given ORIGINAL TEXT.

    A SCORE of 1 refers to a summary that is completely incoherent, inaccurate, or fails to cover the important information.
    A SCORE of 10 refers to a summary that is perfectly coherent, entirely accurate, and covers all important information.

    When determining the score, consider:
    - Coherence: How well-structured and understandable is the summary on its own?
    - Accuracy: Does the factual information in the summary accurately match the original text?
    - Coverage: How well does the summary cover the important information in the original text?

    ORIGINAL TEXT: {prompt}

    SUMMARY: {response}

    Please return your evaluation as a JSON object with a single field 'overall_score'. DO NOT INCLUDE ANY OTHER OUTPUT.
    """

    # Bind before the try: if the API call raises, the handler below must be
    # able to return raw_output instead of hitting an UnboundLocalError.
    raw_output = None
    try:
        chat_completion = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": evaluation_prompt}],
            temperature=0,
        )
        raw_output = chat_completion.choices[0].message.content

        text = raw_output.strip()
        if text.startswith("```"):
            # Tolerate a reply wrapped in a Markdown fence like ```json ... ```.
            text = text.strip("`").strip()
            if text.lower().startswith("json"):
                text = text[len("json"):]
        score = json.loads(text)["overall_score"]
        if not isinstance(score, (int, float)):
            # Coerce e.g. "8" -> 8.0 so callers compare numbers, not strings;
            # non-numeric values raise and are reported as a failed rating.
            score = float(score)
        return raw_output, score
    except Exception as e:
        # Deliberate best-effort boundary: any API or parsing failure yields a
        # None score so the caller can skip this item and keep going.
        print(f"Error: {e}")
        return raw_output, None


# Read the evaluation dataset. The scoring loop reads a "prompt" column and one
# "response (<model>)" column per model variant from each row.
response_data = pd.read_csv("data.csv")
# For a quick debugging run, score only a small random subset, e.g.:
# response_data = response_data.sample(n=5, random_state=42)

# Ordered (first, second) model pairs; "wins" counts rows where the first
# model's score is higher than the second's.
comparison_pairs = [
    ("base", "explainable"),
    ("base", "blackbox"),
    ("explainable", "blackbox"),
]

# Fresh win/tie/loss counters for every pair (one independent dict each).
results = {
    model_pair: dict(wins=0, ties=0, losses=0) for model_pair in comparison_pairs
}

# Score each model's summary for every row, store the raw judge replies and
# scores as new columns, and tally pairwise wins/ties/losses.
for index, row in response_data.iterrows():
    prompt = row["prompt"]
    base = row["response (base)"]
    explainable = row["response (explainable)"]
    blackbox = row["response (blackbox)"]

    base_rating, base_score = run_rating(prompt, base)
    explainable_rating, explainable_score = run_rating(prompt, explainable)
    blackbox_rating, blackbox_score = run_rating(prompt, blackbox)

    # .loc creates each column on first assignment, so this statement order
    # fixes the saved CSV layout: all *_rating columns, then all *_score ones.
    response_data.loc[index, "base_rating"] = base_rating
    response_data.loc[index, "explainable_rating"] = explainable_rating
    response_data.loc[index, "blackbox_rating"] = blackbox_rating

    response_data.loc[index, "base_score"] = base_score
    response_data.loc[index, "explainable_score"] = explainable_score
    response_data.loc[index, "blackbox_score"] = blackbox_score

    # Explicit name -> score lookup instead of eval()-ing a variable name that
    # was built at runtime.
    scores = {
        "base": base_score,
        "explainable": explainable_score,
        "blackbox": blackbox_score,
    }

    # Only tally rows where every model was scored, so all pairs are compared
    # over the same set of rows.
    if any(score is None for score in scores.values()):
        continue

    for pair in comparison_pairs:
        model1, model2 = pair
        score1 = scores[model1]
        score2 = scores[model2]

        if score1 > score2:
            results[pair]["wins"] += 1
        elif score1 < score2:
            results[pair]["losses"] += 1
        else:
            results[pair]["ties"] += 1

# Persist every row together with the *_rating / *_score columns added above.
response_data.to_csv("saved_data.csv", index=False)
print("Evaluation complete. Results saved to 'saved_data.csv'.")

# Print the head-to-head tally for each pair, from the first model's side.
for pair in comparison_pairs:
    tally = results[pair]
    wins, ties, losses = tally["wins"], tally["ties"], tally["losses"]
    print(f"\nComparison: {pair[0]} vs {pair[1]}")
    print(f"Overall: Wins: {wins}, Ties: {ties}, Losses: {losses}")
