import json
import numpy as np

def verify_json_format(raw_data):
    """Parse each item's raw model response into a ``rubric_tree``, in place.

    For every item, the Markdown code-fence markers are removed from
    ``item['model_response']`` and the remaining text is parsed as JSON. On
    success the parsed value is stored under ``item['rubric_tree']`` and the
    ``'model_response'`` key is removed. On failure the item is left
    untouched, a message is printed, and processing moves on to the next
    item.

    Args:
        raw_data: Collection of dict items, each expected to carry a
            ``'model_response'`` string and, for error reporting, a
            ``'question_id'``.

    Returns:
        None. Items are mutated in place.
    """
    for index, item in enumerate(raw_data):
        try:
            response = item['model_response']
            cleaned = response.replace("```json", "").replace("```", "").strip()
            item['rubric_tree'] = json.loads(cleaned)
        except (KeyError, TypeError, AttributeError, ValueError, RecursionError) as e:
            # KeyError: no 'model_response'; TypeError/AttributeError: the item
            # or response has the wrong type; ValueError: invalid JSON
            # (json.JSONDecodeError is a subclass); RecursionError: nesting
            # too deep for the parser.
            # Look the id up defensively so that a missing 'question_id'
            # cannot raise from inside the handler and abort the whole pass.
            question_id = item.get('question_id') if isinstance(item, dict) else None
            print(f"Error processing {index+1}/{len(raw_data)}, id: {question_id}, error: {e}")
            continue
        # Only drop the raw text once the parsed tree is safely stored.
        item.pop('model_response')
    
def verify_tree_structure(raw_data):
    """Check that every item's ``rubric_tree`` has the expected shape.

    A tree is accepted when all of the following hold:

    * it is a dict whose keys are exactly ``intention``, ``static`` and
      ``dynamic``;
    * each of those branches is a node: a dict containing a non-empty string
      ``"description"`` and a ``"children"`` value that is either ``None`` or
      a list of nodes (checked recursively);
    * each top-level branch has a non-empty ``"children"`` list;
    * the ``dynamic`` branch has exactly two children.

    Items that fail any check are reported with ``print``. An item that has
    no ``rubric_tree`` at all, or whose tree is not a dict, is reported the
    same way instead of raising. Nothing is removed from ``raw_data``.

    Args:
        raw_data: Collection of dict items, each expected to carry a
            ``'rubric_tree'`` and, for error reporting, a ``'question_id'``.

    Returns:
        None.
    """
    basic_keys = ('intention', 'static', 'dynamic')

    def _verify_node_structure(node):
        # Each node must be a dictionary with "description" and "children".
        if not isinstance(node, dict) or "description" not in node or "children" not in node:
            return False
        # Description must be a non-empty string.
        description = node["description"]
        if not isinstance(description, str) or not description:
            return False
        # Children must be None (nothing further to check) or a list.
        children = node["children"]
        if children is None:
            return True
        if not isinstance(children, list):
            return False
        # Every child must itself be a well-formed node.
        return all(_verify_node_structure(child) for child in children)

    def verify_tree_structure_helper(tree_item):
        # 0. A missing tree (None) or a non-dict JSON value can never be valid.
        if not isinstance(tree_item, dict):
            return False

        # 1. Verify top-level keys: exactly the three expected branches.
        if set(tree_item) != set(basic_keys):
            return False

        # 2. Verify structure for each main branch.
        for key in basic_keys:
            branch = tree_item[key]
            if not _verify_node_structure(branch):
                return False
            # For top-level branches, 'children' must be a non-empty list.
            children = branch["children"]
            if not isinstance(children, list) or not children:
                return False

        # 3. Specific check for the 'dynamic' branch: it must have exactly
        #    two children (one for "basic" and one for "complex"
        #    interactions). Only the count is verified here, not which child
        #    is which.
        return len(tree_item['dynamic']['children']) == 2

    for index, item in enumerate(raw_data):
        is_dict = isinstance(item, dict)
        # Items whose JSON failed to parse earlier have no 'rubric_tree';
        # treat that as an invalid tree rather than raising KeyError.
        tree = item.get('rubric_tree') if is_dict else None
        if not verify_tree_structure_helper(tree):
            question_id = item.get('question_id') if is_dict else None
            print(f"Error processing {index+1}/{len(raw_data)}, id: {question_id}")


def statistics(raw_data):
    """Print depth and leaf-count statistics for the rubric trees.

    For every item, ``item['rubric_tree']`` is read and the depth and number
    of leaves are measured for the whole tree and for each of its
    ``static``, ``intention`` and ``dynamic`` branches. The mean, maximum and
    minimum of each measurement over all items are then printed, one line
    per measurement.

    A leaf is a node whose ``"children"`` is ``None`` or an empty list, which
    is the same notion of "no children" that the depth computation uses.

    Args:
        raw_data: Iterable of dict items, each with a ``'rubric_tree'`` dict
            containing the ``static``, ``intention`` and ``dynamic``
            branches.

    Returns:
        None. Results are printed. If ``raw_data`` is empty a short notice
        is printed instead, because the mean/max/min of nothing is undefined.

    Raises:
        KeyError: If an item has no ``'rubric_tree'`` or a tree lacks one of
            the three branches.
    """
    branch_names = ('static', 'intention', 'dynamic')

    def _compute_depth(tree_item):
        # Anything that is not a node with at least one child has depth 1.
        if not isinstance(tree_item, dict):
            return 1
        children = tree_item.get("children")
        if not children:
            return 1
        return 1 + max(_compute_depth(child) for child in children)

    def _compute_num_leaves(tree_item):
        # Walk the structure instead of string-matching the JSON dump, so
        # that leaves written as '"children": []' are counted as well as
        # those written as '"children": null'.
        if isinstance(tree_item, dict):
            own = 0
            if "children" in tree_item:
                children = tree_item["children"]
                if children is None or (isinstance(children, list) and not children):
                    own = 1
            return own + sum(_compute_num_leaves(value) for value in tree_item.values())
        if isinstance(tree_item, (list, tuple)):
            return sum(_compute_num_leaves(value) for value in tree_item)
        return 0

    def _report(label, values):
        print(f"Mean {label}: {np.mean(values)}. Max {label}: {np.max(values)}. Min {label}: {np.min(values)}")

    depth_count = []
    num_leaves_count = []
    branch_depths = {name: [] for name in branch_names}
    branch_leaves = {name: [] for name in branch_names}

    for item in raw_data:
        tree = item['rubric_tree']
        num_leaves_count.append(_compute_num_leaves(tree))
        for name in branch_names:
            branch_depths[name].append(_compute_depth(tree[name]))
            branch_leaves[name].append(_compute_num_leaves(tree[name]))
        # The implicit root above the three branches adds one level.
        depth_count.append(1 + max(branch_depths[name][-1] for name in branch_names))

    if not depth_count:
        # np.max / np.min raise ValueError on an empty sequence.
        print("No rubric trees to summarize.")
        return

    _report("depth", depth_count)
    _report("num leaves", num_leaves_count)
    for name in branch_names:
        _report(f"{name} depth", branch_depths[name])
    for name in branch_names:
        _report(f"{name} num leaves", branch_leaves[name])


if __name__ == "__main__":
    # Script entry point: load the rubric annotations (one JSON object per
    # line) and print summary statistics for them.
    # Read as UTF-8 explicitly so the result does not depend on the platform's
    # default locale encoding, and skip blank lines (e.g. a trailing newline
    # at the end of the file), which json.loads would reject.
    with open('data/annotations/rubric.jsonl', 'r', encoding='utf-8') as f:
        data = [json.loads(line) for line in f if line.strip()]
    statistics(data)