import os
import json
from collections import defaultdict
from pathlib import Path

def extract_response_components(llm_response):
    """
    Extracts the strategy and answer components from the LLM response.

    Accepts either the raw string, a Choice-like object with .message.content,
    or a Candidate-like object with .text.

    The expected layout is "[Strategy]: ... [Answer]: ...". The "[Answer]:"
    marker is searched for after the "[Strategy]:" marker first, so a stray
    "[Answer]:" earlier in the text does not truncate the strategy. If no
    answer marker follows the strategy, the reversed layout
    "[Answer]: ... [Strategy]: ..." is accepted as well.

    Returns:
        A (strategy, answer) tuple of whitespace-stripped strings, or
        ("No strategy", "No answer") when the response type is not
        recognised, the text is empty/None, or a marker is missing.
    """

    strategy_tag = "[Strategy]:"
    answer_tag = "[Answer]:"
    fallback = ("No strategy", "No answer")

    if isinstance(llm_response, str):
        text = llm_response
    elif hasattr(llm_response, "message") and hasattr(llm_response.message, "content"):
        text = llm_response.message.content
    elif hasattr(llm_response, "text"):
        text = llm_response.text
    else:
        print("\nUnknown response format.\n")
        return fallback

    # empty or missing content (e.g. content is None) is silently treated as "no result"
    if not text:
        return fallback

    # a non-string payload cannot be searched for the markers
    if not isinstance(text, str):
        print("\nUnknown response format.\n")
        return fallback

    strategy_start = text.find(strategy_tag)
    if strategy_start != -1:
        strategy_end = strategy_start + len(strategy_tag)

        # normal layout: the answer section follows the strategy section
        answer_start = text.find(answer_tag, strategy_end)
        if answer_start != -1:
            strategy = text[strategy_end:answer_start].strip()
            answer = text[answer_start + len(answer_tag):].strip()
            return strategy, answer

        # reversed layout: the answer section precedes the strategy section
        answer_start = text.find(answer_tag, 0, strategy_start)
        if answer_start != -1:
            answer = text[answer_start + len(answer_tag):strategy_start].strip()
            strategy = text[strategy_end:].strip()
            return strategy, answer

    print("\nResponse format is incorrect. Expected format: [Strategy]: ... [Answer]: ...\n")
    return fallback

def safe_update_data(problem_ind, new_data, file_dir):
    """
    Safely updates one entry of the JSON results file.

    The file is loaded, the entry at ``problem_ind`` is merged with
    ``new_data`` (keys in ``new_data`` win), and the whole document is written
    back through a temporary file that then replaces the original.

    Args:
        problem_ind: Index/key of the entry to update; it must already exist
            in the loaded data and hold a mapping.
        new_data: Mapping of fields to add or overwrite on that entry.
        file_dir: Path of the JSON file to update (despite the name, a file
            path, not a directory).

    Raises:
        KeyError / IndexError: If ``problem_ind`` is not present in the data.
        OSError, json.JSONDecodeError: If the file cannot be read or parsed.
    """

    file_path = Path(file_dir)
    # Append ".tmp" to the full name instead of swapping the suffix: with
    # with_suffix(".tmp") a target already ending in ".tmp" would be its own
    # temp file (and get deleted by the cleanup below), and "x.json" would
    # clobber an unrelated sibling "x.tmp".
    temp_path = file_path.with_name(file_path.name + ".tmp")

    # read the current contents of the file
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # update the data
    data[problem_ind] = {**data[problem_ind], **new_data}

    # write the updated data to a temporary file next to the target and then
    # rename it over the target; the rename is what makes the update safe
    # against interruption (os.replace is atomic on POSIX file systems)
    try:
        with open(temp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=4, ensure_ascii=False)
        os.replace(temp_path, file_path)
    finally:
        # only still present if writing or replacing failed
        if temp_path.exists():
            temp_path.unlink()

def compute_accuracy(dir, key):
    """
    Computes the accuracy of the LLM based on the results in the file.

    Args:
        dir: Path of the JSON results file (a file path, despite the name).
            The file is expected to hold a sequence of entries that can be
            indexed with ``key``.
        key: Field whose truthiness marks an entry as correct. Every entry
            must contain it; a missing key raises KeyError.

    Returns:
        A tuple ``(accuracy, correct, total)`` where ``accuracy`` is
        ``correct / total``, or ``0.0`` when the file contains no entries.
    """

    # read as UTF-8 explicitly rather than relying on the platform default encoding
    with open(dir, "r", encoding="utf-8") as f:
        data = json.load(f)

    total = len(data)
    correct = sum(1 for item in data if item[key])
    # an empty results file has no meaningful ratio; report 0.0 instead of dividing by zero
    accuracy = correct / total if total else 0.0

    return accuracy, correct, total

def _load_results(path, n=None):
    """Loads the JSON results file and optionally keeps only the first n entries."""

    # read as UTF-8 explicitly rather than relying on the platform default encoding
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    # a falsy n (None or 0) means "no limit"
    if n:
        data = data[:n]
    return data

def _compute_stats(items):
    """Counts the pipeline outcome flags over a list of result entries."""

    def count(*keys):
        # number of entries for which every given field is present and truthy
        return sum(1 for item in items if all(item.get(k, 0) for k in keys))

    return {
        "total_problems": count("question"),
        "weak_helped_by_strategy": count("weak_llm_helped_by_strategy"),
        "weak_total_correct": count("weak_llm_accuracy"),
        "weak_llm_used": count("weak_llm_used"),
        "weak_llm_used_correct": count("weak_llm_used", "weak_llm_accuracy"),
        "strong_total_correct": count("strong_llm_accuracy"),
        "strong_llm_used": count("strong_llm_used"),
        "strong_llm_used_correct": count("strong_llm_used", "strong_llm_accuracy"),
        # entries for which neither model's answer was used
        "unaccept_problems": sum(
            1 for item in items
            if not item.get("strong_llm_used", 0) and not item.get("weak_llm_used", 0)
        ),
    }

def _print_stats(stats):
    """Prints one statistics dictionary as a human-readable report."""

    print(f"Total number of problems: {stats['total_problems']}")
    print(f"Uncover problems: {stats['unaccept_problems']}\n")
    print(f"Accept weak llm: {stats['weak_llm_used']}")
    print(f"Number of correct answers of weak llm: {stats['weak_total_correct']}")
    print(f"Number of correct answers of weak llm that are used: {stats['weak_llm_used_correct']}")
    print(f"Number of weak llm helped by strategy: {stats['weak_helped_by_strategy']}\n")
    print(f"Accept strong llm: {stats['strong_llm_used']}")
    print(f"Number of correct answers of strong llm: {stats['strong_total_correct']}")
    print(f"Number of correct answers of strong llm that are used: {stats['strong_llm_used_correct']}")
    print()

def stats_pipeline(dir, n=None):
    """
    Computes the statistics of the pipeline based on the results in the file.

    Args:
        dir: Path of the JSON results file (a file path, despite the name).
            Entries are read with ``.get``, so missing fields count as false.
        n: If truthy, only the first ``n`` entries are considered; ``None``
            or ``0`` means all entries.

    Returns:
        A dict of integer counters (``total_problems``, ``weak_llm_used``,
        ``strong_llm_used_correct``, ``unaccept_problems``, ...). The same
        numbers are also printed to stdout.
    """

    stats = _compute_stats(_load_results(dir, n))
    _print_stats(stats)

    return stats

def stats_pipeline_category(dir, category_name="category", n=None):
    """
    Computes the statistics of the pipeline grouped by category.

    Args:
        dir: Path of the JSON results file (a file path, despite the name).
        category_name: Field of each entry that holds its category; entries
            without it are grouped under "Unknown".
        n: If truthy, only the first ``n`` entries are considered; ``None``
            or ``0`` means all entries.

    Returns:
        A dict mapping each category to the same counter dict that
        ``stats_pipeline`` returns. One report per category is also printed.
    """

    grouped = defaultdict(list)
    for item in _load_results(dir, n):
        grouped[item.get(category_name, "Unknown")].append(item)

    results = {cat: _compute_stats(items) for cat, items in grouped.items()}

    for cat, stats in results.items():
        print(f"===== Category: {cat} =====")
        _print_stats(stats)

    return results

