import openai
import ast
import time
import json 
import csv 
from tqdm import tqdm

def get_chatgpt_answer(user_prompt, max_tokens=2048, temperature=0.1, system_prompt="You are a helpful assistant!"):
    """Send one system + user message pair to the chat model and return the reply text.

    Args:
        user_prompt: Content of the ``user`` message.
        max_tokens: Forwarded to the API as ``max_tokens``.
        temperature: Forwarded to the API as ``temperature``.
        system_prompt: Content of the ``system`` message.

    Returns:
        ``response.choices[0].message.content`` of the completion.
    """
    # Optionally hard-code a key here. When left empty the module-level
    # ``openai.api_key`` is NOT overwritten, so a key configured elsewhere
    # (the openai client reads the OPENAI_API_KEY environment variable by
    # default) keeps working instead of being clobbered by an empty string.
    api_token = ""
    model_name = "gpt-4o"
    if api_token:
        openai.api_key = api_token
    response = openai.chat.completions.create(
        model=model_name,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
    )
    return response.choices[0].message.content

def extract_ratings(text):
    """Parse the ratings dictionary enclosed by the rating markers in *text*.

    The text must contain ``[The Start of Ratings]`` followed later by
    ``[The End of Ratings]``; whatever sits between them is parsed as a
    Python literal and must evaluate to a ``dict``.

    Args:
        text: Raw judge response.

    Returns:
        The parsed ratings dictionary.

    Raises:
        ValueError: If either marker is missing, or the enclosed content is
            not a valid dictionary literal.
    """
    start_marker = "[The Start of Ratings]"
    end_marker = "[The End of Ratings]"

    # str.find returns -1 on a miss; without these checks a missing marker
    # would silently produce a wrong slice instead of an error.
    marker_pos = text.find(start_marker)
    if marker_pos == -1:
        raise ValueError(f"Marker {start_marker!r} not found in the text.")
    start_index = marker_pos + len(start_marker)

    # Only accept an end marker that comes after the start marker.
    end_index = text.find(end_marker, start_index)
    if end_index == -1:
        raise ValueError(f"Marker {end_marker!r} not found after {start_marker!r}.")

    ratings_str = text[start_index:end_index].strip()

    try:
        ratings_dict = ast.literal_eval(ratings_str)
    except (ValueError, SyntaxError, TypeError) as err:
        # TypeError covers literals such as ``{[1]: 2}`` (unhashable key).
        raise ValueError("The content between the markers is not a valid dictionary format.") from err

    if not isinstance(ratings_dict, dict):
        raise ValueError("The content between the markers is not a valid dictionary format.")
    return ratings_dict

def generate_judge_prompt(task_instruction, input, reference, generated):
    """Build the judge prompt for one (reference, generated) answer pair.

    The ``{{user_prompt}}``, ``{{answer_ref}}`` and ``{{answer_a}}``
    placeholders of the template are filled in a single pass, so placeholder
    look-alikes that occur inside the supplied texts are left untouched.

    Args:
        task_instruction: Instruction line placed first in the user prompt.
        input: Task input, appended to the instruction after a newline.
            (The name shadows the builtin but is kept for keyword callers.)
        reference: Reference answer text.
        generated: Assistant answer text to be judged.

    Returns:
        The fully substituted prompt string.
    """
    import re

    # NOTE: the template wording is kept verbatim (including its typos) so
    # that scores remain comparable with earlier runs.
    JUDGE_PROMPT = """Please act as an impartial judge and evaluate how well an assistant's answer aligns with the reference answer and the quality of the assistant's answer. You will be given a user prompt, a reference answer and an assistant's answer. 
Your evaluation must consider the following criteria: 
- Format consistency: ensuring the generated response matches the length and structure of the reference.
- Content completeness: evaluating whether all key points present in the reference are included in the assistant's answer.
- Factuality: checking for factual correctness of the assistant's answer.
- Style adherence: ensuring that the tone, style, and level of detail of the of the assistant's answer match the reference.
- Assistant's answer quality: assessing how well the response satisfies the user's requirements.

Begin your evaluation by providing a short explanation for each. Be as objective as possible. After providing your explanation, please rate the response on all the criterion on a scale of 1 to 10 by strictly following this format: 

[The Start of Explanation]
...
[The End of Explanation]

[The Start of Ratings]
{
"Format": 1-10,
"Content": 1-10,
"Factuality": 1-10,
"Style": 1-10,
"Quality": 1-10,
}
[The End of Ratings]

[User Prompt]
{{user_prompt}}

[The Start of Reference Answer]
{{answer_ref}}
[The End of Reference Answer]

[The Start of Assistant’s Answer]
{{answer_a}}
[The End of Assistant’s Answer]
"""
    values = {
        "user_prompt": task_instruction + "\n" + input,
        "answer_ref": reference,
        "answer_a": generated,
    }
    # A callable replacement is used so the inserted text is taken literally
    # (no backslash/group-reference processing) and is never re-scanned.
    return re.sub(
        r"\{\{(user_prompt|answer_ref|answer_a)\}\}",
        lambda match: values[match.group(1)],
        JUDGE_PROMPT,
    )

def get_llm_answer_with_retry(input_prompt, max_retries=3, retry_delay=2):
    """Query the judge model and parse its ratings, retrying on any failure.

    Each attempt calls ``get_chatgpt_answer`` and then ``extract_ratings`` on
    the reply. A failure in either step is printed and retried after a delay
    of ``retry_delay * attempt`` seconds; no delay is spent after the last
    attempt.

    Args:
        input_prompt: Prompt passed to ``get_chatgpt_answer``.
        max_retries: Maximum number of attempts.
        retry_delay: Base delay in seconds, multiplied by the attempt number.

    Returns:
        A ``(raw_answer, ratings)`` tuple from the first successful attempt.

    Raises:
        Exception: If no attempt succeeded; the last underlying error, if
            any, is attached as ``__cause__``.
    """
    last_error = None
    for attempt in range(1, max_retries + 1):
        try:
            chatgpt_ans = get_chatgpt_answer(input_prompt)
            extracted_answer = extract_ratings(chatgpt_ans)
        except Exception as e:  # retry boundary: any failure triggers a retry
            last_error = e
            print(f"Error in llm answer (attempt {attempt}/{max_retries}): {str(e)}")
            # Sleeping after the final attempt would only delay the failure.
            if attempt < max_retries:
                time.sleep(retry_delay * attempt)
        else:
            return chatgpt_ans, extracted_answer
    raise Exception("Max retries reached. Unable to get a valid response from LLM.") from last_error

# Input CSV with the generated answers, and output CSV for the judge scores.
FILE_TO_JUDGE = "output/cnn_longguide.csv"
SAVED_SCORES = "output/cnn_llmjudge_scores.csv"

# Divisor used to average one score dict; the judge prompt asks for five
# ratings (Format, Content, Factuality, Style, Quality).
NUM_CRITERIA = 5

cnt = 0

our_sum = 0
baseline_sum = 0

saved_rows = []
# newline="" lets the csv module handle line endings itself, which matters
# for quoted fields containing newlines.
with open(FILE_TO_JUDGE, newline="") as file:
    csvreader = csv.reader(file)
    header = list(next(csvreader))

    # Column positions do not change between rows, so resolve them once.
    article_idx = header.index("article")
    highlights_idx = header.index("highlights")
    zs_idx = header.index("zero_shot_answer")
    fs_idx = header.index("few_shot_answer")
    only_linguistics_idx = header.index("only_linguistics")
    fs_our_idx = header.index("full_attributes_few_shot")

    task_instruction = "Summarize the highlights from the following article:"

    for row in tqdm(csvreader):
        article = row[article_idx]
        highlights = row[highlights_idx]
        zs = row[zs_idx]
        fs = row[fs_idx]

        only_linguistics = row[only_linguistics_idx]
        fs_our = row[fs_our_idx]

        # Four judge calls per row: ours (zero-shot), zero-shot baseline,
        # ours (few-shot), few-shot baseline.
        our_judge_prompt = generate_judge_prompt(task_instruction, article, highlights, only_linguistics)
        our_raw_judge, our_scores = get_llm_answer_with_retry(our_judge_prompt)

        baseline_judge_prompt = generate_judge_prompt(task_instruction, article, highlights, zs)
        baseline_raw_judge, baseline_scores = get_llm_answer_with_retry(baseline_judge_prompt)

        our_fs_judge_prompt = generate_judge_prompt(task_instruction, article, highlights, fs_our)
        our_fs_raw_judge, our_fs_scores = get_llm_answer_with_retry(our_fs_judge_prompt)

        fs_judge_prompt = generate_judge_prompt(task_instruction, article, highlights, fs)
        fs_raw_judge, fs_scores = get_llm_answer_with_retry(fs_judge_prompt)

        # Only the two zero-shot variants feed the printed averages.
        our_sum += sum(our_scores.values()) / NUM_CRITERIA
        baseline_sum += sum(baseline_scores.values()) / NUM_CRITERIA

        saved_rows.append([
            article, highlights, zs, fs, baseline_scores, fs_scores, our_scores, our_fs_scores
        ])

        cnt += 1

# Guard against an input file that has a header but no data rows.
if cnt:
    print(f"Ours: {our_sum/cnt}")
    print(f"Baselines: {baseline_sum/cnt}")
else:
    print(f"No rows found in {FILE_TO_JUDGE}; averages are not available.")

with open(SAVED_SCORES, "w", newline="") as file:
    csvwriter = csv.writer(file)
    csvwriter.writerow([
        "article", "highlights", "zs", "fs", "zs_scores",
        "fs_scores", "our_zs_scores", "our_fs_scores"
    ])
    csvwriter.writerows(saved_rows)