# type: ignore

import json
from collections import defaultdict
import re
import sys
import os


def validate_sudoku(question, response):
    """Check whether *response* is a correct, completed solution of *question*.

    Both arguments are strings of cell values in row-major order; ``'0'``
    marks an empty cell in *question*.

    Args:
        question: The puzzle string. Every non-``'0'`` character must appear
            unchanged at the same position in *response*. If it is shorter
            than 81 characters, only the overlapping prefix is compared.
        response: The proposed 81-character solution.

    Returns:
        tuple[bool, str]: ``(True, "Valid sudoku solution")`` on success,
        otherwise ``(False, reason)``. ``analyze_sudoku_results`` classifies
        failures by substrings of ``reason``, so keep those messages stable.
    """
    valid_digits = set("123456789")

    if len(response) != 81:
        return False, "Response must be exactly 81 digits long"

    if '0' in response:
        return False, "Response contains zeros"

    # Every given (non-zero) digit of the puzzle must be kept as-is.
    for q, r in zip(question, response):
        if q != '0' and q != r:
            return False, "Response modifies original question digits"

    # Only ASCII 1-9 are legal cell values. Without this check a non-digit
    # character would make the grid checks crash (int() raised ValueError in
    # the original), and a Unicode digit such as fullwidth '０' slips past the
    # '0' test above.
    if not set(response) <= valid_digits:
        return False, "Response contains characters other than digits 1-9"

    rows = [response[i * 9:(i + 1) * 9] for i in range(9)]

    # With every cell in 1-9, a unit of 9 cells has no duplicates exactly when
    # its set of values is the full digit set.
    for i in range(9):
        if set(rows[i]) != valid_digits:
            return False, f"Row {i+1} has duplicates"

        if {row[i] for row in rows} != valid_digits:
            return False, f"Column {i+1} has duplicates"

        # Boxes are numbered 1-9 left-to-right, top-to-bottom.
        start_row, start_col = 3 * (i // 3), 3 * (i % 3)
        box = {
            rows[x][y]
            for x in range(start_row, start_row + 3)
            for y in range(start_col, start_col + 3)
        }
        if box != valid_digits:
            return False, f"Subgrid {i+1} has duplicates"

    return True, "Valid sudoku solution"


def get_difficulty(question):
    """Label a puzzle's difficulty from how many of its cells are blank ('0').

    Buckets: fewer than 9 blanks -> "Easy OOD", 9-35 -> "Easy ID",
    36-53 -> "Hard ID", 54 or more -> "Hard OOD". An empty string counts as
    0 blanks.
    """
    blank_count = question.count('0')
    # (exclusive upper bound, label), checked in ascending order.
    for upper_bound, label in ((9, "Easy OOD"), (36, "Easy ID"), (54, "Hard ID")):
        if blank_count < upper_bound:
            return label
    return "Hard OOD"

def analyze_sudoku_results(data):
    """Aggregate validation statistics per model and difficulty bucket.

    Args:
        data: Iterable of entry dicts. Keys read: ``"model"`` (default
            ``"unknown"``), ``"question_extracted"`` and
            ``"response_extracted"`` (both default ``""``), and whether an
            ``"error"`` key is present.

    Returns:
        list[dict]: One record per (model, difficulty) pair seen, in
        first-seen order. Each record holds the entry counts, an
        ``"accuracy"`` string (valid / total * 100, where total also counts
        entries with an ``"error"`` key) and an ``"error_breakdown"`` of
        invalid-solution causes.
    """
    # Substring of validate_sudoku's failure message -> error_types counter.
    # Checked in order; the first match wins.
    error_categories = (
        ("contains zeros", 'contains_zero'),
        ("81 digits", 'length_error'),
        ("modifies original", 'digit_mismatch'),
        ("duplicates", 'duplicates'),
    )

    stats = defaultdict(lambda: defaultdict(lambda: {
        'total': 0,
        'valid': 0,
        'invalid': 0,
        'no_response': 0,
        'error_types': {
            'contains_zero': 0,
            'length_error': 0,
            'digit_mismatch': 0,
            'duplicates': 0
        }
    }))

    for entry in data:
        model = entry.get("model", "unknown")  # Default to "unknown" if model field is missing
        question = entry.get("question_extracted", "")
        bucket = stats[model][get_difficulty(question)]
        bucket['total'] += 1

        # Entries carrying an "error" key are tallied as no_response and
        # are not validated.
        if "error" in entry:
            bucket['no_response'] += 1
            continue

        response = entry.get("response_extracted", "")
        is_valid, message = validate_sudoku(question, response)

        if is_valid:
            bucket['valid'] += 1
            continue

        bucket['invalid'] += 1
        for needle, category in error_categories:
            if needle in message:
                bucket['error_types'][category] += 1
                break

    # Flatten the nested counters into one summary record per bucket.
    results = []
    for model, by_difficulty in stats.items():
        for difficulty, counts in by_difficulty.items():
            total = counts['total']
            valid = counts['valid']
            accuracy = (valid / total * 100) if total > 0 else 0

            results.append({
                "difficulty": difficulty,
                "model": model,
                "total_entries": total,
                "valid_solutions": valid,
                "invalid_solutions": counts['invalid'],
                "no_response": counts['no_response'],
                "accuracy": f"{accuracy:.2f}%",
                "error_breakdown": counts['error_types'],
            })

    return results



def extract_sudoku_solution(response, input_type):
    """
    Extracts up to 81 ASCII digits representing a Sudoku grid from text.

    Args:
        response (str): Text to scan (a model response or a puzzle question).
        input_type (str): ``'response'`` scans only the text after the first
            ``result:`` marker. The marker is case-insensitive and allows an
            optional quote before the colon, e.g. ``"result": ...``. Any other
            value scans all of *response*.

    Returns:
        str: The first 81 ASCII digits found, in order (fewer if the text
        has fewer). Returns an empty string if *response* is not a string or,
        for ``'response'``, if no ``result:`` marker is found.
    """
    # A missing/null value (e.g. JSON null) yields no solution instead of a
    # TypeError from re.search that would abort the caller's whole loop.
    if not isinstance(response, str):
        return ""

    # Find the "result" section (case insensitive)
    if input_type == 'response':
        result_match = re.search(r'result["\']?\s*:\s*(.*)', response, re.IGNORECASE | re.DOTALL)
        if not result_match:
            return ""
        result_content = result_match.group(1)
    else:
        result_content = response

    # Use [0-9] rather than \d: in str patterns \d also matches non-ASCII
    # decimal digits (e.g. fullwidth '１'). Downstream, validate_sudoku
    # compares cells character by character, so those would never match the
    # ASCII puzzle digits.
    digits = re.findall(r'[0-9]', result_content)

    return ''.join(digits[:81])


def process_json_file(input_file, output_file):
    """
    Processes a JSON file containing Sudoku responses, extracts solutions,
    and saves results to a new JSON file.

    Reads a JSON array of entry dicts from *input_file*. For each entry it
    adds ``question_extracted`` when a ``question`` key is present and
    ``response_extracted`` when a ``response`` key is present. It then writes
    the whole array to *output_file*. On any failure it prints the error to
    stderr and exits with status 1.
    """
    try:
        # JSON text is UTF-8 by spec; don't depend on the platform locale.
        with open(input_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        for entry in data:
            if 'question' in entry:
                entry['question_extracted'] = extract_sudoku_solution(entry['question'], 'question')
            if 'response' in entry:
                entry['response_extracted'] = extract_sudoku_solution(entry['response'], 'response')

        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2)

        print(f"Successfully processed {len(data)} entries. Results saved to {output_file}")

    except Exception as e:
        # Top-level boundary of this CLI helper: report on stderr, exit non-zero.
        print(f"Error processing file: {str(e)}", file=sys.stderr)
        sys.exit(1)


def main():
    """Command-line entry point: extract, validate and summarize sudoku results.

    Usage: ``python process_sudoku_results.py <input_file.json> <output_file.json>``

    Steps:
    1. Load the JSON array from the input file.
    2. Add ``question_extracted``/``response_extracted`` to each entry that
       has ``question``/``response``.
    3. Print the ``analyze_sudoku_results`` summary to stdout and write it to
       the output file.

    Usage and error messages go to stderr. The process exits with status 1 on
    wrong usage, unreadable input, or a failure writing the output.
    """
    if len(sys.argv) != 3:
        print("Usage: python process_sudoku_results.py <input_file.json> <output_file.json>", file=sys.stderr)
        sys.exit(1)
    input_file = sys.argv[1]
    output_file = sys.argv[2]

    try:
        # JSON text is UTF-8 by spec; don't depend on the platform locale.
        with open(input_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
        for entry in data:
            if 'question' in entry:
                entry['question_extracted'] = extract_sudoku_solution(entry['question'], 'question')
            if 'response' in entry:
                entry['response_extracted'] = extract_sudoku_solution(entry['response'], 'response')
    except Exception as e:
        print(f"Error processing file: {str(e)}", file=sys.stderr)
        sys.exit(1)

    results = analyze_sudoku_results(data)
    print(json.dumps(results, indent=2))

    try:
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2)
    except OSError as e:
        print(f"Error writing {output_file}: {e}", file=sys.stderr)
        sys.exit(1)
    print(f"Results saved to {output_file}")
