import os
import re

import tabulate

from benchmarks.scorers.gaia_scorer import check_close_call, question_scorer

# Result directories to analyze. Replace the placeholder paths below with the
# directories you want to analyze before running this script.
# main() processes each entry in order, skips entries that do not exist (after
# printing a message), and uses the last path component as the row name in the
# printed table. Each directory is walked recursively; a folder contributes a
# question when it contains both "expected_answer.txt" and "console_log.txt".
result_path = [
    "/path/to/your/autogen/python/packages/agbench/benchmarks/GAIA/Results/gaia_validation_level_1__MagenticOne",
    "/path/to/your/autogen/python/packages/agbench/benchmarks/GAIA/Results/gaia_validation_level_2__MagenticOne",
    "/path/to/your/autogen/python/packages/agbench/benchmarks/GAIA/Results/gaia_validation_level_3__MagenticOne",
]

def extract_cost_from_log(log_content):
    """Return the total cost reported in a console log.

    The log must contain exactly one occurrence of
    ``Total Cost: $<digits>.<digits>``.

    Args:
        log_content: Full text of a console log.

    Returns:
        The reported cost as a float.

    Raises:
        ValueError: If the log contains no cost entry or more than one.
    """
    cost_matches = re.findall(r"Total Cost: \$(\d+\.\d+)", log_content)
    # Raise explicitly instead of using ``assert``: assertions are stripped
    # under ``python -O``, which would turn a missing entry into a bare
    # IndexError and let a log with several entries silently return the first.
    if len(cost_matches) != 1:
        raise ValueError(f"Expected exactly one cost line, got {len(cost_matches)}")
    return float(cost_matches[0])

def extract_final_answer(log_content):
    """Return the text after the last line-leading "FINAL ANSWER:" marker.

    Only lines that begin with the marker (no leading whitespace) count. The
    returned text is stripped of surrounding whitespace. Returns None when no
    line carries the marker.
    """
    marker = "FINAL ANSWER:"
    # Scan from the end so the first hit is the last answer in the log.
    for candidate in reversed(log_content.split('\n')):
        if candidate.startswith(marker):
            return candidate[len(marker):].strip()
    return None

def check_answer(model_answer, expected_answer_path):
    """Score a model answer against the ground truth stored in a file.

    The expected answer is read from ``expected_answer_path`` (UTF-8) and
    stripped of surrounding whitespace. If the file cannot be read, the error
    is printed and the answer counts as neither correct nor close. A
    ``model_answer`` of None also counts as neither.

    Returns:
        A pair ``(is_correct, is_close)``: the result of ``question_scorer``
        and the result of ``check_close_call`` for the same answer pair.
    """
    no_match = (False, False)

    # The file is read before the None check so that an unreadable expected
    # answer is always reported, whatever the model answered.
    try:
        with open(expected_answer_path, 'r', encoding='utf-8') as answer_file:
            ground_truth = answer_file.read().strip()
    except Exception as e:
        print(f"Error reading expected answer: {e}")
        return no_match

    if model_answer is None:
        return no_match

    is_correct = question_scorer(model_answer, ground_truth)
    return is_correct, check_close_call(model_answer, ground_truth, is_correct)

def process_directory(directory):
    """Aggregate cost and answer statistics for every run under ``directory``.

    The tree is walked recursively. A folder is counted when it contains both
    "expected_answer.txt" and "console_log.txt". For each such folder the cost
    is extracted from the console log and the final answer is scored against
    the expected answer. Any error while handling a log is printed and the
    rest of that log's processing is skipped; the log still counts toward the
    number of logs found.

    Returns:
        A tuple ``(total_cost, console_log_count, correct_answers,
        close_calls)``.
    """
    cost_sum = 0.0
    logs_seen = 0
    num_correct = 0
    num_close = 0

    for folder, _, filenames in os.walk(directory):
        answer_path = os.path.join(folder, "expected_answer.txt")
        # Skip folders that lack either the expected answer or a console log.
        if not os.path.exists(answer_path) or "console_log.txt" not in filenames:
            continue

        # Counted before parsing, so a log that fails below is still a question.
        logs_seen += 1
        log_path = os.path.join(folder, "console_log.txt")
        try:
            with open(log_path, 'r', encoding='utf-8') as log_file:
                log_text = log_file.read()

            cost = extract_cost_from_log(log_text)
            cost_sum += cost

            # Score the answer reported in the same log.
            model_answer = extract_final_answer(log_text)
            is_correct, is_close = check_answer(model_answer, answer_path)
            num_correct += 1 if is_correct else 0
            num_close += 1 if is_close else 0

            print(f"Found cost in {log_path}: ${cost}")
            print(f"Answer: {model_answer} - Correct: {is_correct}, Close call: {is_close}")
        except Exception as e:
            print(f"Error reading {log_path}: {str(e)}")

    return cost_sum, logs_seen, num_correct, num_close

def main():
    """Analyze every directory in ``result_path`` and print a results table.

    Directories that do not exist are reported and skipped. For each remaining
    directory one row is produced: test name, total cost, number of questions
    (console logs found), correct answers, close calls, accuracy, and cost per
    solved question. A final "Summary" row aggregates all rows. Ratios are 0
    when their denominator is 0.
    """
    result_list = []
    for path in result_path:
        if not os.path.exists(path):
            print(f"Directory not found: {path}")
            continue

        print(f"\nProcessing directory: {path}")
        total_cost, console_log_count, correct_answers, close_calls = process_directory(path)
        # normpath drops a trailing separator and basename understands the
        # platform's separators, so the name is neither empty for "dir/" nor
        # the whole path for backslash-separated Windows paths.
        test_name = os.path.basename(os.path.normpath(path))
        accuracy = (correct_answers / console_log_count) if console_log_count > 0 else 0
        cost_per_solved_question = total_cost / correct_answers if correct_answers > 0 else 0
        result_list.append(
            (
                test_name,
                total_cost,
                console_log_count,
                correct_answers,
                close_calls,
                accuracy,
                cost_per_solved_question,
            )
        )

    # Append a summary row aggregated over all processed directories.
    total_cost = sum(result[1] for result in result_list)
    total_questions = sum(result[2] for result in result_list)
    total_correct = sum(result[3] for result in result_list)
    total_close_calls = sum(result[4] for result in result_list)
    total_accuracy = (total_correct / total_questions) if total_questions > 0 else 0
    total_cost_per_solved_question = total_cost / total_correct if total_correct > 0 else 0
    result_list.append(
        (
            "Summary",
            total_cost,
            total_questions,
            total_correct,
            total_close_calls,
            total_accuracy,
            total_cost_per_solved_question,
        )
    )

    # Print the per-directory rows and the summary as a plain-text table.
    print()
    print(tabulate.tabulate(
        result_list,
        headers=["Test Name", "Total Cost [$]", "Total Questions", "Correct", "Close Calls", "Accuracy", "Cost/Solved [$]"],
        tablefmt="simple",
        floatfmt=".4f"
    ))

# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
