# type: ignore

import json
import re
from collections import defaultdict


def count_digits(question):
    """Return the digit lengths of the two operands in a multiplication question.

    Every run of decimal digits in ``question`` (e.g. '461 * 603') is read as
    an operand. Lengths are measured after integer conversion, so leading
    zeros do not count towards the length.

    Returns:
        A ``(first_len, second_len)`` tuple, or ``(None, None)`` when the
        question does not contain exactly two digit runs.
    """
    operands = list(map(int, re.findall(r'\d+', question)))
    if len(operands) == 2:
        first, second = operands
        return len(str(first)), len(str(second))
    return None, None


def get_difficulty(x_digit, y_digit):
    """Map operand digit counts to a difficulty label.

    The label depends only on the larger of the two digit counts:
    up to 5 digits is "Easy ID", up to 8 digits is "Hard ID", and anything
    larger is "Hard OOD". If either count is ``None`` the label is "Unknown".
    """
    if x_digit is None or y_digit is None:
        return "Unknown"

    largest = max(x_digit, y_digit)
    # Thresholds are checked in ascending order; the first one that fits wins.
    for upper_bound, label in ((5, "Easy ID"), (8, "Hard ID")):
        if largest <= upper_bound:
            return label
    return "Hard OOD"


def extract_result(response):
    """Extract the reported integer result from a response text.

    Two strategies are tried in order:

    1. A JSON-like ``{"result": <number>`` fragment; the number may be
       wrapped in double quotes.
    2. Otherwise, the last standalone number anywhere in the text.

    In both cases a number may be written with comma thousands separators
    (e.g. ``277,983``); the commas are stripped before conversion so the
    value is not truncated to its last three digits.

    Args:
        response: The response text. Non-string values (e.g. ``None``) are
            treated as containing no result.

    Returns:
        The extracted integer, or ``None`` if no number is found.
    """
    if not isinstance(response, str):
        return None

    # Either a comma-grouped number (groups of exactly three digits, not
    # followed by a further digit) or a plain run of digits. The grouped form
    # must come first so it wins over the shorter plain prefix.
    number = r'\d{1,3}(?:,\d{3})+(?!\d)|\d+'

    # Try to find JSON result first
    json_match = re.search(r'\{\s*"result"\s*:\s*"?(' + number + r')', response)
    if json_match:
        return int(json_match.group(1).replace(',', ''))

    # Fall back to the last standalone number in the response
    numbers = re.findall(r'\b(?:' + number + r')\b', response)
    if numbers:
        return int(numbers[-1].replace(',', ''))

    return None


def verify_result(question, response_text):
    """Check whether the response reports the correct product for the question.

    The two operands are the digit runs found in ``question``; the reported
    value is obtained from ``response_text`` via ``extract_result``.

    Returns:
        ``True`` only when the question holds exactly two operands, a result
        could be extracted, and it equals their product; ``False`` otherwise.
    """
    factors = [int(token) for token in re.findall(r'\d+', question)]
    if len(factors) != 2:
        return False

    answer = extract_result(response_text)
    if answer is None:
        return False

    left, right = factors
    return left * right == answer

def analyze_results(input_file):
    """Summarise multiplication results per model and difficulty bucket.

    The input file must contain a JSON list of entries. Each entry is read
    for its "question" and "model" keys; an entry that has an "error" key is
    counted as an error, otherwise its "response" text is checked against the
    true product. Entries whose question does not contain exactly two numbers
    are skipped entirely.

    Args:
        input_file: Path to the JSON file, read as UTF-8.

    Returns:
        A list of dicts, one per (model, difficulty) pair in first-seen
        order, with the keys "model", "difficulty", "total_entries",
        "responses", "errors", "correct", "incorrect", "accuracy" and
        "response_rate". The last two are percentage strings with two
        decimals; accuracy is relative to responses, response rate to
        total entries.

    Raises:
        KeyError: If an entry lacks "question" or "model", or lacks
            "response" when it has no "error" key.
    """
    # Explicit encoding: the platform default is locale-dependent and may not
    # be able to decode a UTF-8 JSON file.
    with open(input_file, encoding="utf-8") as f:
        data = json.load(f)

    def new_bucket():
        # Counters for one (model, difficulty) pair.
        return {'responses': 0, 'errors': 0, 'correct': 0, 'incorrect': 0}

    stats = defaultdict(lambda: defaultdict(new_bucket))

    for entry in data:
        question = entry["question"]
        model = entry["model"]
        x_digit, y_digit = count_digits(question)
        if x_digit is None or y_digit is None:
            # Not a two-operand question; it contributes to no bucket.
            continue

        bucket = stats[model][get_difficulty(x_digit, y_digit)]

        if "error" in entry:
            bucket['errors'] += 1
            continue

        bucket['responses'] += 1
        if verify_result(question, entry["response"]):
            bucket['correct'] += 1
        else:
            bucket['incorrect'] += 1

    # Turn the raw counters into report rows.
    results = []
    for model, by_difficulty in stats.items():
        for difficulty, bucket in by_difficulty.items():
            responses = bucket['responses']
            correct = bucket['correct']
            total = responses + bucket['errors']
            # A bucket made only of errors has no responses to score.
            accuracy = (correct / responses * 100) if responses > 0 else 0

            results.append({
                "model": model,
                "difficulty": difficulty,
                "total_entries": total,
                "responses": responses,
                "errors": bucket['errors'],
                "correct": correct,
                "incorrect": bucket['incorrect'],
                "accuracy": f"{accuracy:.2f}%",
                "response_rate": f"{(responses/total*100):.2f}%" if total > 0 else "0.00%"
            })

    return results


def main():
    """Command-line entry point.

    Expects one required argument (the input JSON file) and one optional
    argument (an output JSON file). Prints a usage message and exits with
    status 1 on any other argument count. The analysis is always printed to
    stdout as indented JSON and additionally written to the output file when
    one is given.
    """
    import sys

    arg_count = len(sys.argv)
    if not 2 <= arg_count <= 3:
        print("Usage: python process_multiplication_results.py <input_file.json> [output_file.json]")
        sys.exit(1)

    summary = analyze_results(sys.argv[1])
    print(json.dumps(summary, indent=2))

    # Persist the same summary when an output path was supplied
    if arg_count == 3:
        with open(sys.argv[2], 'w') as out_file:
            json.dump(summary, out_file, indent=2)
