#!/usr/bin/env python3
"""
Generate/Update dataset.json with normalized names based on Ascend C naming rules.

This script treats `dataset.json` as the source of truth for:
- Operator ID (key)
- Category
- Level
- PascalCase Name (pascal_name)

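It regenerates `normalized_name` from `pascal_name` (after stripping a trailing
"Custom" suffix, if present) using the msopgen conversion logic:
1. PascalCase -> snake_case
2. Digits: "2d" -> "2d" (no underscore) vs "2D" -> "2_d".

The file is always rewritten, with entries sorted by level and then by the
numeric prefix of the operator ID.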

Usage:
    python3 scripts/generate_dataset_json.py
"""

import json
import os
import re

def msopgen_convert(pascal_str: str) -> str:
    """
    Convert a PascalCase identifier to snake_case the way msopgen does.

    Every uppercase character is lowercased. An underscore is inserted in
    front of it (never at position 0) when either:
    - the previous character is a lowercase letter or a digit (a new word
      starts here), or
    - the previous character is uppercase and the next one is lowercase
      (the last capital of an uppercase run starts a new word).

    All other characters (lowercase letters, digits, symbols) are copied
    unchanged, so "2d" stays "2d" while "2D" becomes "2_d".

    Args:
        pascal_str: The PascalCase name to convert.

    Returns:
        The snake_case form, or "" for an empty input.
    """
    if not pascal_str:
        return ""

    last_index = len(pascal_str) - 1
    pieces = []

    for idx, ch in enumerate(pascal_str):
        if not ch.isupper():
            # Lowercase letters, digits and anything else pass through.
            pieces.append(ch)
            continue

        if idx > 0:
            before = pascal_str[idx - 1]
            # Empty string stands in for "no next character"; "".islower() is False.
            after = pascal_str[idx + 1] if idx < last_index else ""

            starts_word = before.islower() or before.isdigit()
            ends_upper_run = before.isupper() and after.islower()
            if starts_word or ends_upper_run:
                pieces.append("_")

        pieces.append(ch.lower())

    return "".join(pieces)


def main():
    """
    Regenerate `normalized_name` for every entry in examples/dataset.json.

    The dataset path is resolved relative to this script: the parent of the
    script's directory is taken as the repository root. For each entry that
    has a `pascal_name`, a trailing "Custom" suffix is stripped and the
    remainder is converted with `msopgen_convert`; the result replaces
    `normalized_name` when it differs. Entries without a `pascal_name` are
    reported and left untouched.

    The file is always rewritten (indent 4), sorted by level, then by the
    numeric prefix of the operator ID, then by the ID itself.

    Returns:
        None. If the dataset file does not exist, an error is printed and
        nothing is written.
    """
    repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    dataset_path = os.path.join(repo_root, "examples", "dataset.json")

    if not os.path.exists(dataset_path):
        print(f"Error: {dataset_path} not found.")
        return

    print(f"Loading {dataset_path}...")
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(dataset_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Convention: `pascal_name` carries the kernel-class suffix "Custom"
    # (e.g. "ConvStandard2d...Custom"), while `normalized_name` is the base
    # file name without "_custom" (e.g. "conv_standard2d..."). The suffix is
    # therefore stripped before conversion.
    custom_suffix = "Custom"

    changed_count = 0
    total_count = 0

    for op_name, info in data.items():
        total_count += 1
        pascal_name = info.get("pascal_name")
        current_normalized = info.get("normalized_name")

        if not pascal_name:
            print(f"[WARN] No pascal_name for {op_name}")
            continue

        if pascal_name.endswith(custom_suffix):
            base_pascal = pascal_name[:-len(custom_suffix)]
        else:
            base_pascal = pascal_name

        new_normalized = msopgen_convert(base_pascal)

        if new_normalized != current_normalized:
            print(f"Updating {op_name}: {current_normalized} -> {new_normalized}")
            info["normalized_name"] = new_normalized
            changed_count += 1

    print(f"Updated {changed_count} of {total_count} entries.")

    # Sort by Level (level1 -> level2) then by Natural Key (1_... -> 2_...)
    def sort_key(item):
        """Return the (level rank, numeric prefix, operator ID) ordering key."""
        key = item[0]
        data_dict = item[1]

        # 1. Level rank. str() keeps a non-string level (e.g. null or a
        # number) from raising in re.search; it simply fails to match and is
        # ranked last, the same as a missing level.
        level_str = str(data_dict.get('level', ''))
        level_match = re.search(r'level(\d+)', level_str)
        level_rank = int(level_match.group(1)) if level_match else 999

        # 2. Natural key (numeric prefix); IDs without one sort after all
        # numbered IDs of the same level.
        prefix_match = re.match(r'^(\d+)', key)
        prefix_rank = int(prefix_match.group(1)) if prefix_match else float('inf')

        # The ID itself is the final tie-breaker so the order is deterministic.
        return (level_rank, prefix_rank, key)

    sorted_data = dict(sorted(data.items(), key=sort_key))

    # Always write to ensure sorting is applied, even if no content changed
    print(f"Sorting and saving {total_count} entries...")
    with open(dataset_path, 'w', encoding='utf-8') as f:
        json.dump(sorted_data, f, indent=4)
    print("Done.")

# Run the dataset update only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
