import json
import os
from tqdm import tqdm

# === Paths ===
# Both paths are relative, so they resolve against the current working
# directory (NOTE(review): presumably the script is run from the repo root).
DATA_DIR = "data/verification_test_50/"  # input: the *_50_test.json files read by load_all_datasets
OUTPUT_DIR = "backward_results/"  # output: one result JSON per verification method (see __main__)
# Runs at import time so later checkpoint writes into OUTPUT_DIR do not need
# to create the directory themselves.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# === Project imports (chat API wrapper and prompt headers) ===
from utils.deepseek_api_utils import ChatBot
from utils.backward_prompt import (
    baseline_verification_one_shot,
    traditional_verification_one_shot,
    backward_verification_one_shot,
)

# ======================================
# Load all 250 samples from all 5 datasets
# ======================================
def load_all_datasets(data_dir=None, files=None):
    """Load and concatenate the samples from every dataset file.

    Each sample is tagged in place with a ``"source_file"`` key holding the
    name of the file it came from, so results can be grouped per dataset.

    Args:
        data_dir: Directory containing the JSON files. ``None`` (default)
            means the module-level ``DATA_DIR``.
        files: File names, relative to ``data_dir``, loaded in this order.
            ``None`` (default) means the five ``*_50_test.json`` files.

    Returns:
        One list holding the samples of all files, in file order.

    Raises:
        FileNotFoundError: If one of the files does not exist.
        json.JSONDecodeError: If a file is not valid JSON.
    """
    if data_dir is None:
        data_dir = DATA_DIR
    if files is None:
        files = (
            "addsub_50_test.json",
            "aqua_50_test.json",
            "du_50_test.json",
            "gsm8k_50_test.json",
            "lastletter_50_test.json",
        )

    all_data = []
    for f in files:
        path = os.path.join(data_dir, f)
        with open(path, "r", encoding="utf-8") as fp:
            # Assumes each file holds a JSON array of objects (dicts).
            items = json.load(fp)
        for item in items:
            item["source_file"] = f
        all_data.extend(items)
    return all_data


def count_tokens(text):
    """Approximate a token count as the number of whitespace-separated words.

    This is not a model-tokenizer count. ``str.split()`` with no argument
    splits on any run of whitespace and ignores leading and trailing
    whitespace, so an empty or all-blank string counts as 0.
    """
    return sum(1 for _word in text.split())


# ======================================
# Evaluate 2 baselines and our backward verification method
# ======================================
def build_prompt(prompt_header, question, reasoning):
    """Assemble the verification prompt sent to the model.

    Args:
        prompt_header: Instruction text placed at the top of the prompt.
        question: The original problem statement (``data["question"]``).
        reasoning: The candidate reasoning trace to verify
            (``data["candidate_answer"]``).

    Returns:
        The header, the labelled question and the labelled reasoning trace,
        separated by blank lines and ending with a newline.
    """
    sections = (
        f"{prompt_header}",
        f"Original question:\n{question}",
        f"Reasoning trace:\n{reasoning}",
    )
    return "\n\n".join(sections) + "\n"


def _load_checkpoint(save_path):
    """Return ``(results, processed_ids)`` from ``save_path``, or empty ones if absent."""
    if os.path.exists(save_path):
        print(f"[Resume] Loading existing checkpoint: {save_path}")
        with open(save_path, "r", encoding="utf-8") as fp:
            results = json.load(fp)
        processed_ids = {item["id"] for item in results}
        print(f"[Resume] Found {len(results)} processed samples.")
    else:
        print("[Start] No checkpoint found. Starting from scratch.")
        results = []
        processed_ids = set()
    return results, processed_ids


def _save_checkpoint(results, save_path):
    """Write ``results`` to ``save_path`` as indented JSON, atomically."""
    # Write a sibling temp file and swap it in. An interruption mid-write then
    # cannot leave a truncated checkpoint that would break json.load on resume.
    tmp_path = save_path + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as fp:
        json.dump(results, fp, indent=2)
    os.replace(tmp_path, save_path)


def _resolve_prompt_header(prompt_header):
    """Return the header text from a plain string or from a sequence's first element."""
    # Indexing a str with [0] would keep only its first character, so a plain
    # string is used whole. Other inputs keep the original [0] behaviour.
    if isinstance(prompt_header, str):
        return prompt_header
    return prompt_header[0]


def run_evaluation(data_list, prompt_header, save_path, step_save=10):
    """Run one verification prompt over every sample, with resumable checkpoints.

    For each sample, the header, ``item["question"]`` and
    ``item["candidate_answer"]`` are combined with ``build_prompt`` and sent
    to ``ChatBot.call_chat_deepseek``. The reply and word-count statistics are
    appended to a results list, which is periodically written to
    ``save_path``. If ``save_path`` already exists, samples whose id is
    recorded there are skipped, so an interrupted run can be restarted with
    the same arguments.

    Args:
        data_list: Samples to evaluate. Each needs ``"question"``,
            ``"candidate_answer"`` and ``"source_file"`` keys. A sample's id
            is its position in this list, so pass the list in the same order
            when resuming.
        prompt_header: Either the header string itself, or a sequence whose
            first element is the header.
        save_path: JSON file used both as the checkpoint and the final output.
        step_save: Save a checkpoint whenever the number of processed samples
            is a multiple of this. A value ``<= 0`` disables intermediate
            checkpoints; the final save still happens.

    Raises:
        Whatever the API call (or anything else in the loop) raises. Progress
        made so far is saved to ``save_path`` before the exception propagates.
    """
    # 1. Load checkpoint (if it exists).
    results, processed_ids = _load_checkpoint(save_path)
    header = _resolve_prompt_header(prompt_header)

    # 2. Continue from the checkpoint.
    try:
        for idx, item in enumerate(tqdm(data_list, desc=f"Running {save_path}")):
            # Already processed in an earlier run.
            if idx in processed_ids:
                continue

            question = item["question"]
            reasoning = item["candidate_answer"]
            full_prompt = build_prompt(header, question, reasoning)
            response = ChatBot.call_chat_deepseek(full_prompt)

            # Whitespace word counts (count_tokens), not tokenizer tokens.
            prompt_tokens = count_tokens(full_prompt)
            output_tokens = count_tokens(response)

            results.append({
                "id": idx,
                "source_file": item["source_file"],
                "question": question,
                "reasoning": reasoning,
                "model_output": response,
                "prompt_tokens": prompt_tokens,
                "output_tokens": output_tokens,
                "total_tokens": prompt_tokens + output_tokens,
            })
            processed_ids.add(idx)

            if step_save > 0 and len(processed_ids) % step_save == 0:
                _save_checkpoint(results, save_path)
                print(f"[Checkpoint saved] {save_path} ({len(processed_ids)} samples)")
    except BaseException:
        # Keep finished samples (API error, Ctrl-C, ...) so a rerun resumes here.
        _save_checkpoint(results, save_path)
        print(f"[Interrupted] Progress saved: {save_path} ({len(results)} samples)")
        raise

    # 3. Save everything.
    _save_checkpoint(results, save_path)
    print(f"[Final saved] {save_path} ({len(results)} samples)")



# ======================================
# Main
# ======================================
if __name__ == "__main__":
    data = load_all_datasets()  # expected 250 samples (5 files x 50)

    # (prompt header, output file name). The runs execute in this order, and
    # each run resumes independently from its own checkpoint file.
    experiments = [
        # 1. Full backward verification baseline
        (baseline_verification_one_shot, "full_backward_250.json"),
        # 2. Traditional backward verification
        (traditional_verification_one_shot, "traditional_backward_250.json"),
        # 3. Our backward verification
        (backward_verification_one_shot, "one_shot_backward_250.json"),
    ]

    for header, out_name in experiments:
        run_evaluation(
            data,
            header,
            os.path.join(OUTPUT_DIR, out_name),
            step_save=10,
        )


