import json
import os
import time
import re
from openai import OpenAI
from tqdm import tqdm

def review_answer(client, question, reasoning, answer, index, *, retries=3,
                  model="deepseek-chat"):
    """Ask the chat model whether ``answer`` is consistent with ``reasoning``.

    Args:
        client: Object exposing ``chat.completions.create(...)`` (the OpenAI
            client interface).
        question: Question text inserted into the user prompt.
        reasoning: Reasoning text the answer is checked against.
        answer: Current answer under review.
        index: Item number, used in the prompt and in the failure message.
        retries: Maximum number of request attempts; failed attempts are
            followed by an exponential back-off sleep (1s, 2s, 4s, ...).
        model: Model name passed to the completion request.

    Returns:
        A ``(modified, answer)`` tuple. ``modified`` is ``False`` and the
        original answer is returned when the model replies CONSISTENT, when
        every attempt fails, or when no usable correction can be extracted.
        Otherwise ``modified`` is ``True`` and the second element is the
        corrected answer text.
    """
    messages = [
        {
            "role": "system",
            "content": """You are a VQA (Visual Question Answering) review expert. Your task is to determine if the provided answer aligns with the question and reasoning.

                        Review criteria:
                        1. Factual accuracy - The answer must be consistent with facts in the reasoning
                        2. Logical consistency - The answer should not contradict the reasoning
                        3. Completeness - The answer should cover key information from the reasoning

                        If the answer has serious factual errors or logical contradictions, provide a corrected version that is concise yet accurate.
                        If the answer is generally consistent, simply reply "CONSISTENT".
                        """
        },
        {
            "role": "user",
            "content": f"Question: {question}\n\nReasoning: {reasoning}\n\nCurrent Answer: {answer}\n\nPlease review item #{index}: Is this answer consistent with the question and reasoning? If not, provide a corrected answer that strictly follows the reasoning content. If consistent, only reply 'CONSISTENT'."
        }
    ]

    for attempt in range(retries):
        try:
            response = client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=0.1,
                max_tokens=512,
                stream=False
            )
            content = (response.choices[0].message.content or "").strip()
            if not content:
                # An empty reply carries no verdict; treat it like a failed
                # request so it is retried instead of being stored as the
                # "corrected" answer.
                raise ValueError("empty response from model")
        except Exception as e:
            # Deliberately broad: a failure on one item must not abort the
            # whole batch, so the original answer is kept after the last try.
            if attempt < retries - 1:
                time.sleep(2 ** attempt)
                continue
            print(f"\nReview failed for item #{index}: {e}")
            return False, answer

        # The prompts quote the keyword, so the model may echo it back quoted
        # or with markdown emphasis / different casing. Erring towards
        # "consistent" is the safe direction: it leaves the answer untouched
        # rather than overwriting it with the verdict text.
        verdict = content.lstrip("\"'`*_ \t\r\n")
        if verdict.upper().startswith("CONSISTENT"):
            return False, answer

        # Keep only the text after an explicit "Corrected answer:" label.
        parts = re.split(r"Corrected answer:", content, maxsplit=1,
                         flags=re.IGNORECASE)
        corrected_answer = parts[1].strip() if len(parts) > 1 else content
        if not corrected_answer:
            # The label was present but nothing followed it.
            return False, answer
        return True, corrected_answer

    # Only reached when retries <= 0 (no request was made).
    return False, answer

def main(input_file, output_file, modifications_file, api_key):
    """Review every item of a JSON dataset and write back corrected answers.

    Loads ``input_file`` as JSON, sends each item that has a question plus a
    reasoning/answer pair to ``review_answer``, replaces answers the reviewer
    corrected, then writes the updated data to ``output_file`` and the list
    of changes to ``modifications_file``.

    Two item layouts are recognised, checked in this order:
    ``question`` + ``model_reasoning_output`` + ``model_answer_output``, and
    ``question`` + ``reasoning`` + ``answer``. Items matching neither are
    skipped with a message.

    Args:
        input_file: Path of the JSON file to read.
        output_file: Path the (possibly modified) data is written to.
        modifications_file: Path the modification records are written to.
        api_key: API key passed to the OpenAI client.
    """
    client = OpenAI(
        api_key=api_key,
        base_url="https://api.deepseek.com",
        timeout=90
    )
    with open(input_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    # (reasoning key, answer key) pairs, in order of preference.
    schemas = (
        ("model_reasoning_output", "model_answer_output"),
        ("reasoning", "answer"),
    )

    modifications = []
    for idx, item in enumerate(tqdm(data, desc="Reviewing answers", unit="item")):
        keys = next(
            ((r_key, a_key) for r_key, a_key in schemas
             if r_key in item and a_key in item),
            None,
        )
        # "question" is required by both layouts; checking it here avoids a
        # KeyError that would abort the run before anything is saved.
        if keys is None or "question" not in item:
            print(f"Skipping item #{idx}: Missing required fields")
            continue
        reasoning_key, answer_key = keys

        question = item["question"]
        reasoning = item[reasoning_key]
        answer = item[answer_key]

        modified, new_answer = review_answer(
            client,
            question,
            reasoning,
            answer,
            idx
        )
        if not modified:
            continue

        modifications.append({
            "index": idx,
            "question": question,
            "original_answer": answer,
            "new_answer": new_answer
        })
        # Write the correction to the same field the answer was read from.
        item[answer_key] = new_answer
        print(f"\nModified item #{idx}")
        # str() keeps the preview from failing on non-string answers.
        print(f"Original: {str(answer)[:100]}...")
        print(f"New: {str(new_answer)[:100]}...")

    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2, ensure_ascii=False)
    with open(modifications_file, "w", encoding="utf-8") as f:
        json.dump(modifications, f, indent=2, ensure_ascii=False)

    print(f"\nProcessed {len(data)} items")
    print(f"Modified {len(modifications)} answers")
    print(f"Results saved to {output_file}")
    print(f"Modification records saved to {modifications_file}")

if __name__ == "__main__":
    # Fill in the three paths below before running the script.
    INPUT_FILE = ""
    OUTPUT_FILE = ""
    MODIFICATIONS_FILE = ""
    # Prefer the DEEPSEEK_API_KEY environment variable so the key does not
    # have to be stored in the source; the placeholder remains the fallback.
    DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY", "API_KEY")

    # Fail early with a clear message instead of an error from open("").
    if not (INPUT_FILE and OUTPUT_FILE and MODIFICATIONS_FILE):
        raise SystemExit(
            "Set INPUT_FILE, OUTPUT_FILE and MODIFICATIONS_FILE before running."
        )
    main(INPUT_FILE, OUTPUT_FILE, MODIFICATIONS_FILE, DEEPSEEK_API_KEY)