import os
import json


def get_question_type(question: str) -> str:
    """Classify a question by the template phrases it contains.

    Templates are tried in a fixed order; the first one whose marker
    phrases all occur in ``question`` (case-sensitive substring match)
    determines the result.

    Args:
        question: The question text to classify.

    Returns:
        One of ``'char_exists'``, ``'char_count'``, ``'char_position'``,
        ``'char_at_pos'``, ``'string_length'``, or ``'unknown'`` when no
        template matches.
    """
    # (label, phrases that must all be present), checked top to bottom.
    templates = (
        ('char_exists', ("Does the character", "appear in the text?")),
        ('char_count', ("How many times does the character", "appear in the text?")),
        ('char_position', ("At what position(s) does the character", "appear in the text?")),
        ('char_at_pos', ("What is the character at position", "in the text?")),
        ('string_length', ("How many characters are there in the text?",)),
    )
    for label, markers in templates:
        if all(marker in question for marker in markers):
            return label
    return 'unknown'


def separate_and_evaluate(input_file_path: str):
    """Print the accuracy of one JSON results file, broken down by question type.

    The file is loaded with ``json.load`` and iterated; each record is
    expected to provide ``'question'``, ``'predict'`` and ``'answers'``.
    Records are grouped by ``get_question_type(record['question'])``.

    A record counts as correct when the last character of the last line of
    ``'predict'`` (spaces removed) equals ``answers[0]`` upper-cased, compared
    case-insensitively. Records with an empty prediction or with missing or
    malformed fields stay in the denominator and therefore count as wrong.

    Args:
        input_file_path: Path of the JSON results file to evaluate.

    Returns:
        None. Results and error messages are printed to stdout; no exception
        is propagated to the caller.
    """
    try:
        with open(input_file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        type_results = {
            'char_exists': [],
            'char_count': [],
            'char_position': [],
            'char_at_pos': [],
            'string_length': [],
        }

        # Group samples by question type
        for item in data:
            q_type = get_question_type(item['question'])
            # get_question_type may return 'unknown', which is not
            # pre-registered above; setdefault gives such questions their own
            # bucket instead of aborting the whole file with a KeyError.
            type_results.setdefault(q_type, []).append(item)

        # Calculate accuracy per type
        for q_type, items in type_results.items():
            if not items:
                print(f"{q_type}: No data available.")
                continue

            correct = 0
            for item in items:
                try:
                    prediction = item.get('predict', '').split("\n")[-1].replace(' ', '')
                    if not prediction:
                        continue
                    correct_label = item['answers'][0].upper()
                except (AttributeError, IndexError, KeyError, TypeError):
                    # Malformed record (missing/empty 'answers', non-string
                    # fields, ...): best-effort, score it as incorrect.
                    continue
                # Only the last character of the prediction's last line is checked
                if correct_label == prediction[-1].upper():
                    correct += 1

            accuracy = correct / len(items)
            print(f"{q_type} Accuracy: {accuracy * 100:.2f}%")

    except FileNotFoundError:
        print(f"File not found: {input_file_path}")
    except Exception as e:
        # Top-level boundary: report and move on so one bad file does not
        # stop a batch run.
        print(f"Error while processing file: {e}")


if __name__ == '__main__':
    # Batch-evaluate every matching results file in json_dir.
    json_dir = "/xxx/xxx/output"  # Replace with your actual directory

    for file in sorted(os.listdir(json_dir)):
        # Only file names containing '3b' and not containing 'ocrbench'.
        if '3b' not in file or 'ocrbench' in file:
            continue
        file_path = os.path.join(json_dir, file)

        # 'instruct' files get the per-question-type breakdown.
        if 'instruct' in file:
            separate_and_evaluate(file_path)
            continue

        # Everything else gets a single overall accuracy figure.
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            print(f"[Error reading file] {file_path}")
            continue

        if not data:
            # Avoid ZeroDivisionError below and keep processing other files.
            print(f"{file}: No data available.")
            continue

        correct = 0
        for item in data:
            try:
                prediction = item.get('predict', '').split("\n")[-1].replace(' ', '')
                if not prediction:
                    continue
                correct_label = item['answers'][0].upper()
            except (AttributeError, IndexError, KeyError, TypeError):
                # Show the offending prediction; fall back to the raw item
                # when it is not a dict so the handler itself cannot raise.
                print(item.get('predict', '') if isinstance(item, dict) else item)
                continue
            # Only the last character of the prediction's last line is checked
            if correct_label == prediction[-1].upper():
                correct += 1

        print(f"{file} Accuracy: {correct / len(data):.4f}")
