import os
import json
import jsonlines

from tqdm import tqdm
from openai import OpenAI

def write_json(datas, path):
    """Write ``datas`` to ``path`` as UTF-8 JSON indented by four spaces.

    Args:
        datas: Any JSON-serializable object.
        path: Destination file; it is created or overwritten.
    """
    with open(path, mode='w', encoding='utf-8') as out_file:
        out_file.write(json.dumps(datas, indent=4))

def build_prompt(query, gt_answer, gt_facts, answer1, answer2):
    """Build the system and user messages for the pairwise LLM judge.

    Args:
        query: The question both answers respond to.
        gt_answer: Ground-truth summary used as the reference narrative.
        gt_facts: Canonical GT facts; interpolated into the prompt as-is.
        answer1: Text placed between the "Answer 1" markers.
        answer2: Text placed between the "Answer 2" markers.

    Returns:
        A ``(sys_prompt, prompt)`` tuple of strings.
    """
    # Plain (non-f) template: ``{name}`` fields are filled by ``str.format``
    # below, and doubled braces render as the literal JSON braces.
    judge_template = """
    You will evaluate two answers to the same question using two criteria: **SemanticSimilarity** and **Faithfulness**.
    Treat the **GT summary** as the reference narrative and the **GT facts** as the canonical checklist.

    Definitions and rubric:

    - **SemanticSimilarity**: How closely does the answer's meaning align with the GT summary?
      *Credit paraphrases and synonyms; do not require verbatim overlap. Penalize semantic drift.*

    - **Faithfulness**:
      1) **No contradictions** with GT facts or the GT summary.
      2) **Coverage of GT facts** — Enumerate GT facts as [F1]..[Fn] and assess how many are clearly present. **Omissions reduce Faithfulness**, but weigh them **less than** any material contradiction.

      **Decision rule for Faithfulness**:
      - Score with less contradiction and greater coverages.

    Instructions:
    1) Base judgments **only** on the Question, GT summary, and GT facts (ignore external knowledge).
    2) In explanations, **cite fact indices** like [F2], [F4] when referring to specific GT facts, and note coverage (e.g., "Covered: F1,F3,F4; Missing: F2").
    3) For each criterion, pick a winner (**Answer 1** or **Answer 2**) and explain why, referencing fact indices where relevant.

    Here are the GT facts (treat as canonical; enumerate internally as [F1]..[Fn]):
    {gt_facts}

    Here is the question:
    {query}

    Here is the ground truth (GT) answer summary:
    {gt_answer}

    Here are the two answers:

    **Answer 1 Start**
    {answer1}
    **Answer 1 End**

    **Answer 2 Start**
    {answer2}
    **Answer 2 End**

    Evaluate both answers using the two criteria above and provide detailed explanations for each criterion. In explanations, reference GT facts by index (e.g., [F3]) whenever applicable.

    Output your evaluation in the following JSON format (do not include additional fields):

    {{
        "SemanticSimilarity": {{
            "Winner": "[Answer 1 or Answer 2]",
            "Explanation": "[Explain which answer more closely matches the GT meaning. Refer to specific GT phrases; cite facts like [F1], [F2] when helpful.]"
        }},
        "Faithfulness": {{
            "Winner": "[Answer 1 or Answer 2]",
            "Explanation": "[Report any **material** contradictions (if any) and compare **coverage** (e.g., Covered: F1,F3; Missing: F2,F4). Cite facts such as [F2], [F5].]"
        }}
    }}
    """

    # The system message is static: it fixes the judge's role and ground rules.
    role_message = """
    ---Role---
    You are an expert judge evaluating two answers to the same question **against a ground-truth (GT) summary** and a canonical list of **GT facts**.
    Judge on two criteria: **SemanticSimilarity** and **Faithfulness** (broad, fact-anchored).
    Use only the **Question**, the **GT summary**, and the **GT facts**; do not rely on outside knowledge.
    When in doubt, favor the answer that better matches the GT summary and GT facts.
    """

    user_message = judge_template.format(
        gt_facts=gt_facts,
        query=query,
        gt_answer=gt_answer,
        answer1=answer1,
        answer2=answer2,
    )
    return role_message, user_message

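# Positional bias is countered by the driver loop at the bottom of this file, which submits every pair twice: once in the original answer order and once swapped.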
def batch_eval(result1_file, result2_file, output_file_path):
    """Create and submit an OpenAI Batch job that judges two result files pairwise.

    Both inputs are JSON files holding parallel sequences of records, paired
    by position. The ground-truth summary (``'gt_answer'``) and the GT facts
    (``'topics'``) are read from the first file's record; each file's
    ``'result'`` becomes Answer 1 / Answer 2 respectively in the judge prompt.

    Args:
        result1_file: Path to the JSON results presented as Answer 1.
        result2_file: Path to the JSON results presented as Answer 2.
        output_file_path: Where the batch request lines (JSON Lines) are
            written before being uploaded.

    Returns:
        A ``(batch_id, batch)`` tuple: the id of the created batch and the
        batch object returned by ``client.batches.create``.

    Raises:
        AssertionError: If two paired records carry different queries.
    """
    client = OpenAI()

    # Read as UTF-8 explicitly (matching write_json) instead of relying on
    # the platform's default encoding.
    with open(result1_file, "r", encoding="utf-8") as f:
        answers1 = json.load(f)
    with open(result2_file, "r", encoding="utf-8") as f:
        answers2 = json.load(f)
    print(len(answers1) , len(answers2))

    requests = []
    # zip() stops at the shorter input; the lengths printed above let the
    # operator spot a mismatch between the two files.
    for i, (a1, a2) in enumerate(zip(answers1, answers2)):
        query = a1['query']
        # Raised explicitly (rather than via ``assert``) so the check still
        # runs under ``python -O``; the exception type is unchanged.
        if query != a2['query']:
            raise AssertionError(
                f"Query mismatch at index {i}: {query!r} != {a2['query']!r}"
            )

        sys_prompt, prompt = build_prompt(
            query, a1['gt_answer'], a1['topics'], a1['result'], a2['result']
        )

        requests.append({
            "custom_id": f"request-{i+1}",
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": "o4-mini",
                "messages": [
                    {"role": "system", "content": sys_prompt},
                    {"role": "user", "content": prompt},
                ],
            },
        })

    with jsonlines.open(output_file_path, mode="w") as writer:
        for request in requests:
            writer.write(request)
    print(f"Batch API requests written to {output_file_path}")

    # Upload inside a ``with`` block so the file handle is closed once the
    # upload call returns (the original left it open).
    with open(output_file_path, "rb") as batch_file:
        batch_input_file = client.files.create(file=batch_file, purpose="batch")

    batch = client.batches.create(
        input_file_id=batch_input_file.id,
        endpoint="/v1/chat/completions",
        completion_window="24h",
        metadata={"description": "nightly eval job"},
    )
    print(f"Batch {batch.id} has been created.")

    return batch.id, batch

# Evaluation plan: subset name -> path of our results ("ours") plus a mapping
# of baseline name -> path of that baseline's results for the same subset.
pairwise_compares = dict(
    world_history=dict(
        ours="./...json",
        baselines=dict(
            # simple baseline
            without_context="./...json",
            # embedding methods
            naiverag="./...json",
            visrag="./...json",
            gme="./...json",
            colqwen="...json",
            # graph methods
            graphrag="...json",
            lightrag="...json",
            hipporag2="...json",
        ),
    ),
    dlcv=dict(
        ours="./...json",
        baselines=dict(
            # simple baseline
            without_context="./...json",
            # embedding methods
            naiverag="./...json",
            visrag="./...json",
            gme="./...json",
            colqwen="...json",
            # graph methods
            graphrag="...json",
            lightrag="...json",
            hipporag2="...json",
        ),
    ),
    environmental_reports=dict(
        ours="./...json",
        baselines=dict(
            # simple baseline
            without_context="./...json",
            # embedding methods
            naiverag="./...json",
            visrag="./...json",
            gme="./...json",
            colqwen="...json",
            # graph methods
            graphrag="...json",
            lightrag="...json",
            hipporag2="...json",
        ),
    ),
    picture_books=dict(
        ours="./...json",
        baselines=dict(
            # simple baseline
            without_context="./...json",
            # embedding methods
            naiverag="./...json",
            visrag="./...json",
            gme="./...json",
            colqwen="...json",
            # graph methods
            graphrag="...json",
            lightrag="...json",
            hipporag2="...json",
        ),
    ),
)

def _submit_comparison(answer1_path, answer2_path, filename, output_dir, results):
    """Submit one ordered comparison and record its batch info in ``results``.

    Args:
        answer1_path: Result file judged as Answer 1.
        answer2_path: Result file judged as Answer 2.
        filename: Base name (no extension) for this comparison's files and
            the key under which the batch info is stored.
        output_dir: Directory that holds the per-comparison files.
        results: Dict updated in place with the submitted batch's details.
    """
    output_path = os.path.join(output_dir, f'{filename}.json')
    output_jsonl_path = os.path.join(output_dir, f'{filename}.jsonl')

    # NOTE(review): the skip check looks at the .jsonl path while batch_eval
    # writes its requests to the .json path -- presumably the .jsonl is
    # produced by a later step; confirm this is intended.
    if os.path.exists(output_jsonl_path):
        print(f'Skip: {filename}')
        return

    batch_id, batch = batch_eval(
        result1_file=answer1_path,
        result2_file=answer2_path,
        output_file_path=output_path
    )
    results[filename] = {
        'input_file': output_path,
        'batch_id': batch_id,
        # model_dump() converts the batch and its nested models to plain
        # dicts without mutating the batch object, and tolerates optional
        # nested fields being None.
        'info': batch.model_dump()
    }


# Submit every (ours, baseline) pair in both answer orders and keep a record
# of the created batches, keyed by comparison name.
results = {}
for subset_name, subset in tqdm(pairwise_compares.items()):
    ours_path = subset['ours']
    baselines = subset['baselines']

    # Per-comparison files live in a "batch_eval" folder next to our results.
    output_dir = os.path.join(os.path.dirname(ours_path), 'batch_eval')
    os.makedirs(output_dir, exist_ok=True)

    for baseline_name, baseline_path in baselines.items():
        print(f'Evaluating: ours vs {baseline_name}.')
        filename = f'{subset_name}_ours_vs_{baseline_name}'

        # Original (Ours vs Others)
        _submit_comparison(ours_path, baseline_path, filename, output_dir, results)
        # Swap (Others vs Ours)
        _submit_comparison(baseline_path, ours_path, f'{filename}_swap', output_dir, results)

    # Saved after each subset, one directory above our results file; the
    # dict is cumulative, so each write contains all batches submitted so far.
    output_result_path = os.path.join(
        os.path.dirname(os.path.dirname(ours_path)), 'batchcall.json'
    )
    write_json(results, output_result_path)
