#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import json
import re
import datetime
import requests
import random
import time
from collections import defaultdict
import os

# Judge model used for open-ended scoring, and the base delay (seconds)
# used when backing off between failed API attempts.
GPT_EVAL_MODEL_NAME = "gpt-4o-2024-11-20"
NUM_SECONDS_TO_SLEEP = 5

# Backend selector, read from the environment; "openai" when unset.
API_TYPE = os.getenv("API_TYPE", "openai")

if API_TYPE == "azure":
    # Azure: the key is sent in an "api-key" header.
    API_KEY = os.getenv("AZURE_API_KEY", "YOUR_API_KEY")
    API_URL = os.getenv("AZURE_ENDPOINT", "https://api.cognitive.microsoft.com/sts/v1.0/issueToken")
    headers = {"api-key": API_KEY, "Content-Type": "application/json"}
elif API_TYPE == "openai":
    # OpenAI: the key is sent as a bearer token.
    API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_API_KEY")
    API_URL = os.getenv("OPENAI_API_URL", "https://api.openai.com/v1/chat/completions")
    headers = {"Authorization": f"Bearer {API_KEY}", "Content-Type": "application/json"}
# NOTE: any other API_TYPE value leaves API_URL, API_KEY and headers undefined.

# Prompt template for the judge model. It is filled with str.format using the
# fields {question}, {answer} (the ground truth) and {candidate}, and asks for
# a single integer score from 0 to 4.
# NOTE(review): the literal is delimited by ''' so the three double quotes
# after "Your response:" are part of the prompt text that gets sent --
# confirm that this is intended.
eval_prompt = '''
As an AI assistant, your task is to evaluate a candidate answer in comparison to a given correct answer for a video reasoning QA task. Both the groundtruth and candidate answers contain reasoning steps and supporting evidence. 

The **final answer** refers to the direct response to the question, excluding any intermediate reasoning steps. If the candidate’s final answer matches the groundtruth’s final answer, it is considered correct; otherwise, it is incorrect.

Your assessment should range from 0 to 4, based on both the correctness of the final answer and the similarity of the reasoning process:

- If the final answer differs from the groundtruth, the score must be 0, 1, or 2:
  - **0**: The final answer is incorrect, and the provided reasoning/evidence is either missing or entirely dissimilar to the groundtruth.
  - **1**: The final answer is incorrect, but some visual details, reasoning steps, or evidence partially overlap with the groundtruth; however, most of the reasoning is incorrect.
  - **2**: The final answer is incorrect, but the majority of the reasoning process, including key visual evidence and logical steps, aligns with the groundtruth, with only minor deviations causing an incorrect conclusion.

- If the final answer matches the groundtruth, the score must be 3 or 4:
  - **3**: The final answer is correct, but the reasoning process or supporting evidence significantly differs from the groundtruth.
  - **4**: The final answer is correct, and the reasoning process, including supporting evidence, closely aligns with the groundtruth without major inconsistencies.

Your response should be a single integer: 0, 1, 2, 3, or 4.

Question: {question}
Groundtruth answer: {answer}
Candidate answer: {candidate}
Your response:"""
'''

def extract_characters_regex(s):
    """Extract the chosen option letter (A-E) from a free-form model answer.

    Known lead-in phrases ("The answer is", "Final answer", ...) are removed
    one after another, each time keeping only the text after the last
    occurrence of the phrase. The first capital A-E that is not part of a
    longer word (e.g. "(C)", "B.", "D") is then returned. If there is no such
    standalone letter, the first capital A-E anywhere in the remaining text is
    used as a fallback.

    Args:
        s: Raw model prediction text. Non-string values (e.g. ``None`` from a
            JSON ``null``) are treated as "no answer".

    Returns:
        A single character in "ABCDE", or "" when no option letter is found.
    """
    if not isinstance(s, str):
        return ""
    s = s.strip()
    answer_prefixes = [
        "The best answer is", "The correct answer is", "The answer is", "The answer",
        "Best answer:", "Answer:", "The best option is", "The correct option is",
        "Best option:", "Option:", "Therefore, the final answer is:",
        "The final answer is:", "final answer is:", "final answer is",
        "final answer", "Final answer", "FINAL ANSWER:", "ANSWER:"
    ]
    for prefix in answer_prefixes:
        s = s.split(prefix)[-1]
    # Prefer a letter that stands on its own so that capitals inside ordinary
    # words ("Based", "Clearly") are not mistaken for the selected option.
    m = re.search(r"(?<![A-Za-z])[ABCDE](?![A-Za-z])", s)
    if m is None:
        # Fallback: first capital A-E anywhere in the text.
        m = re.search(r"[ABCDE]", s)
    return m[0] if m else ""

def process_results_mc(doc, results):
    """Score a single multiple-choice sample.

    Args:
        doc: Sample record. Must contain "correct_option"; "task",
            "dimension" and "level" are copied through when present.
        results: Sequence whose first element is the raw model prediction.

    Returns:
        dict with "correct" (case-insensitive match between the extracted
        option and ``doc["correct_option"]``) plus the sample's "task",
        "dimension" and "level" (``None`` when the key is absent).
    """
    predicted = extract_characters_regex(results[0])
    expected = doc["correct_option"]
    outcome = {"correct": predicted.upper() == expected.upper()}
    for field in ("task", "dimension", "level"):
        outcome[field] = doc.get(field)
    return outcome

def aggregate_results_mc(results):
    """Aggregate per-sample multiple-choice outcomes and print a report.

    Accuracy is computed per task first. Dimension, level and overall
    figures are unweighted means of the task accuracies they contain.

    Args:
        results: Iterable of dicts with keys "task", "dimension", "level"
            and "correct".

    Returns:
        dict with "task_accuracy", "dimension_accuracy" and "level_accuracy"
        (all in percent) and "overall_accuracy" (as a fraction in [0, 1]).
    """
    outcomes_by_task = defaultdict(list)
    tasks_by_dimension = defaultdict(set)
    tasks_by_level = defaultdict(set)

    for record in results:
        task = record["task"]
        dimension = record["dimension"]
        level = record["level"]
        outcome = record["correct"]

        outcomes_by_task[task].append(outcome)
        tasks_by_dimension[dimension].add(task)
        tasks_by_level[level].add(task)

    print("="*30)
    print("Task Accuracies:")
    task_acc = {}
    for task, outcomes in outcomes_by_task.items():
        n_samples = len(outcomes)
        task_acc[task] = sum(outcomes) / n_samples * 100
        print(f"  Task: {task}, Acc: {task_acc[task]:.2f}% ({n_samples} samples)")

    def report_groups(label, groups):
        # Mean of task accuracies for every group, printed as it is computed.
        averaged = {}
        for key, members in groups.items():
            mean_acc = sum(task_acc[name] for name in members) / len(members)
            print(f"  {label}: {key}, Acc: {mean_acc:.2f}% ({len(members)} tasks)")
            averaged[key] = mean_acc
        return averaged

    print("\nDimension Accuracies:")
    dim_acc = report_groups("Dimension", tasks_by_dimension)

    print("\nLevel Accuracies:")
    lvl_acc = report_groups("Level", tasks_by_level)

    if task_acc:
        overall_acc = sum(task_acc.values()) / len(task_acc)
    else:
        overall_acc = 0
    print(f"\nOverall Accuracy (by tasks): {overall_acc:.2f}%")
    print("="*30)

    return {
        "task_accuracy": task_acc,
        "dimension_accuracy": dim_acc,
        "level_accuracy": lvl_acc,
        # Unlike the entries above, this one is reported as a fraction.
        "overall_accuracy": overall_acc / 100.0,
    }

def get_eval(question: str, ground_truth: str, candidate: str, max_tokens: int, retries: int = 5):
    """Ask the judge model to grade a candidate answer.

    The prompt is built from ``eval_prompt`` and POSTed to ``API_URL`` as a
    single user message with temperature 0. Any exception raised while
    sending the request or reading the response triggers a retry after a
    delay of ``NUM_SECONDS_TO_SLEEP * 3 ** attempt`` seconds plus a random
    jitter of up to ``NUM_SECONDS_TO_SLEEP`` seconds.

    Args:
        question: The question text.
        ground_truth: Reference answer.
        candidate: Model answer to be graded.
        max_tokens: Token limit passed to the API.
        retries: Maximum number of attempts.

    Returns:
        ``(reply, model)`` where ``reply`` is the stripped response text and
        ``model`` is the "model" field of the response. ``("", "")`` is
        returned when the reply is empty (this case is not retried), when
        every attempt raised, or when ``retries`` allows no attempt.
    """
    prompt = eval_prompt.format(question=question, answer=ground_truth, candidate=candidate)
    payload = {
        "model": GPT_EVAL_MODEL_NAME,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0,
        "max_tokens": max_tokens,
    }
    final_attempt = retries - 1

    for attempt in range(retries):
        try:
            response = requests.post(API_URL, headers=headers, json=payload, timeout=60)
            response.raise_for_status()
            body = response.json()

            reply = body["choices"][0]["message"]["content"].strip()
            if reply:
                return reply, body["model"]
            # The request went through but the reply is empty: stop trying.
            break

        except Exception as e:
            print(f"Attempt {attempt + 1} failed with error: {e}")
            if attempt >= final_attempt:
                print(f"All {retries} attempts failed. Last error message: {e}")
                return "", ""

            # Exponential backoff with a little random jitter on top.
            delay = NUM_SECONDS_TO_SLEEP * (3 ** attempt)
            delay += random.uniform(0, NUM_SECONDS_TO_SLEEP)
            print(f"Sleeping for {delay:.2f} seconds before next retry.")
            time.sleep(delay)

    return "", ""

def process_results_oe(doc, results):
    """Bundle one open-ended sample with its prediction for later judging.

    Args:
        doc: Sample record; "question" and "answer" default to "" if absent.
        results: Sequence whose first element is the model prediction.

    Returns:
        dict with keys "question", "answer" and "pred".
    """
    prediction = results[0]
    record = {field: doc.get(field, "") for field in ("question", "answer")}
    record["pred"] = prediction
    return record
    

def extract_eval_score(eval_answer: str, max_score = 4) -> float:
    """Parse the judge's reply into a numeric score.

    The last unsigned number (integer or decimal) found in the text is used.

    Args:
        eval_answer: Raw text returned by the judge model.
        max_score: Largest value accepted as a valid score.

    Returns:
        The parsed number when it lies in ``[0, max_score]``; otherwise 0.0
        (also when the text contains no number at all).
    """
    found = re.findall(r'\d+(?:\.\d+)?', eval_answer)
    if not found:
        return 0.0
    value = float(found[-1])
    return value if 0 <= value <= max_score else 0.0

def oe_gpt_eval(results):
    """Grade open-ended predictions with the judge model and print the mean.

    Every entry is sent to ``get_eval``; the score parsed from the reply by
    ``extract_eval_score`` is stored back on the entry under "eval_score"
    (the dicts are modified in place). A failed or unparsable judge reply
    yields a score of 0.0.

    Args:
        results: Iterable of dicts with keys "question", "answer" and "pred".

    Returns:
        None. The mean score is printed; when there is nothing to score a
        notice is printed instead of dividing by zero.
    """
    scores = []
    for result in results:
        eval_answer, _model_name = get_eval(
            question=result["question"],
            ground_truth=result["answer"],
            candidate=result["pred"],
            max_tokens=1024,
        )
        # extract_eval_score always returns a float, so no conversion needed.
        eval_score = extract_eval_score(eval_answer)
        result["eval_score"] = eval_score
        scores.append(eval_score)

    if not scores:
        print("No open-ended results to evaluate.")
        return

    print(f"The open-ended score is {sum(scores)/len(scores):.2f}.")


def _write_jsonl(path, records):
    """Write each record in ``records`` as one JSON line to ``path`` (UTF-8)."""
    with open(path, 'w', encoding='utf-8') as fw:
        for rec in records:
            fw.write(json.dumps(rec, ensure_ascii=False) + "\n")


def main():
    """Command-line entry point.

    Reads a JSON list of samples (each carrying a "model_prediction") from
    ``--input`` and evaluates it according to ``--type``:

    * ``mc``: extracts the chosen option for every sample, writes the
      extracted predictions to ``--output`` and prints accuracy tables.
    * ``oe``: sends every prediction to the judge model, prints the mean
      score and writes the per-sample records (including "eval_score") to
      ``--output``.
    """
    parser = argparse.ArgumentParser(
        description="HumanPCR Evaluation Script Preview"
    )
    parser.add_argument(
        '--input', '-i', type=str, required=True,
        help='Input JSON file containing original samples and model predictions'
    )
    parser.add_argument(
        '--type', '-t', choices=['mc','oe'], required=True,
        help='Evaluation task type: mc (multiple choice) or oe (open-ended question)'
    )
    # The output path is used by both branches below, so it must be declared.
    parser.add_argument(
        '--output', '-o', type=str, default='submission.jsonl',
        help='Output JSONL file: extracted options for mc, per-sample judge scores for oe'
    )
    args = parser.parse_args()

    with open(args.input, 'r', encoding='utf-8') as fr:
        data = json.load(fr)

    if args.type == 'mc':
        results = []
        submissions = []
        for doc in data:
            # A JSON null prediction is treated like a missing one.
            pred = doc.get('model_prediction') or ''
            results.append(process_results_mc(doc, [pred]))
            submissions.append({
                **({"id": doc["id"]} if "id" in doc else {}),
                "prediction": extract_characters_regex(pred)
            })
        _write_jsonl(args.output, submissions)
        aggregate_results_mc(results)

    else:
        results = [
            process_results_oe(doc, [doc.get('model_prediction') or ''])
            for doc in data
        ]
        # oe_gpt_eval stores "eval_score" on every record in place.
        oe_gpt_eval(results)
        _write_jsonl(args.output, results)

if __name__ == "__main__":
    main()
