#!/usr/bin/env python3
"""Generate normalized op names in examples/dataset.py.

This script is intended to be run once to materialize normalized_name fields.
"""
import os
import re

# The operator implementation file name and the kernel function name must be
# identical: both are the operator type converted to underscore (snake_case)
# naming. The conversion from operator type to that name works as follows:
# 1. Convert an uppercase character at the first position to lowercase.
#    For example: Abc -> abc.
# 2. If an uppercase character is preceded by a lowercase character or a digit,
#    insert an underscore "_" before it and convert it to lowercase.
#    For example: AbcDef -> abc_def.
# 3. If an uppercase character is preceded by an uppercase character and
#    followed by a lowercase character, insert an underscore "_" before it and
#    convert it to lowercase. For example: AbcAAc -> abc_a_ac.
# 4. Convert all other uppercase characters to lowercase; keep lowercase
#    characters unchanged.

def camel_to_snake(name: str) -> str:
    """Convert a CamelCase operator type to lowercase snake_case.

    Activation acronyms (``ReLU``, ``GELU``, ``SELU`` and a standalone
    ``ELU``) are first rewritten in capitalized form so they stay one word.
    Underscores are then inserted at lowercase-to-uppercase transitions,
    before the last capital of an uppercase run that is followed by a
    lowercase letter, and between a digit and the letter after it. A
    dimension suffix such as ``2d`` / ``3D`` stays attached to its digit when
    it ends the name or is followed by ``_`` or an uppercase letter.
    """

    def _digit_letter_boundary(m: re.Match) -> str:
        # Decide whether the digit/letter pair is a dimension suffix ("2d")
        # that must stay glued together, or a boundary needing an underscore.
        digit, letter = m.groups()
        follower = m.string[m.end():m.end() + 1]
        is_dim_suffix = letter.lower() == "d" and (
            not follower or follower == "_" or follower.isupper()
        )
        return digit + letter if is_dim_suffix else f"{digit}_{letter}"

    # Collapse well-known acronyms so they are not split letter by letter.
    text = name.replace("ReLU", "Relu").replace("GELU", "Gelu").replace("SELU", "Selu")
    text = re.sub(r"\bELU\b", "Elu", text)

    # Word boundaries: "aB" -> "a_B", then "ABc" -> "A_Bc".
    for pattern in (r"([a-z])([A-Z])", r"([A-Z]+)([A-Z][a-z])"):
        text = re.sub(pattern, r"\1_\2", text)

    return re.sub(r"(\d)([a-zA-Z])", _digit_letter_boundary, text).lower()


def normalize_op_name(op_name: str) -> str:
    """Turn a raw dataset op name into a snake_case identifier.

    Steps, in order: drop a leading ``<digits>_`` prefix, replace runs of
    characters outside ``[0-9A-Za-z_]`` with ``_``, spell out a leading
    dimension token (``2D`` -> ``two_dim``; numbers without a word become
    ``dim<number>``), glue later ``_<n>D`` tokens onto the preceding word as
    ``<n>d``, apply ``camel_to_snake``, rejoin any ``<digit>_d`` produced by
    that conversion, trim and collapse underscores, and finally prefix
    ``op_`` when the result does not start with an ASCII letter.
    """
    spelled_out = {
        "0": "zero",
        "1": "one",
        "2": "two",
        "3": "three",
        "4": "four",
        "5": "five",
        "6": "six",
        "7": "seven",
        "8": "eight",
        "9": "nine",
        "10": "ten",
    }

    def _leading_dim_to_words(m: re.Match) -> str:
        count = m.group(1)
        as_word = spelled_out.get(count)
        if not as_word:
            return f"dim{count}"
        return f"{as_word}_dim"

    # Substitutions applied before the CamelCase conversion; order matters.
    pre_steps = (
        (r"^\d+_", ""),
        (r"[^0-9A-Za-z_]+", "_"),
        (r"^(\d+)[dD](?=_|$)", _leading_dim_to_words),
        (r"_(\d+)[dD](?=_|$)", r"\g<1>d"),
    )
    text = op_name
    for pattern, replacement in pre_steps:
        text = re.sub(pattern, replacement, text)

    text = camel_to_snake(text)
    # Rejoin dimension suffixes split by the conversion ("2_d" -> "2d").
    text = re.sub(r"(\d)_d(_|$)", r"\1d\2", text)
    text = re.sub(r"_+", "_", text.strip("_"))

    if text and re.match(r"[A-Za-z]", text) is None:
        return f"op_{text}"
    return text


def update_dataset_file(dataset_path: str) -> bool:
    """Add a ``normalized_name`` field to single-line entries of a dataset file.

    Every line of the form ``"<op>": {"category": "<c>", "level": "<l>"}``
    (optionally followed by a comma) is rewritten to also carry
    ``"normalized_name": "<normalize_op_name(op)>"``. Lines that already
    contain ``"normalized_name"`` and lines that do not match the entry
    pattern are kept unchanged, so running the function again on its own
    output rewrites nothing.

    The file is read and written explicitly as UTF-8 instead of the
    platform's locale encoding, so non-ASCII content is handled the same way
    on every platform.

    Args:
        dataset_path: Path of the file to update in place.

    Returns:
        True if at least one entry was rewritten (the file is written back),
        False if nothing matched (the file is left untouched).
    """
    entry_re = re.compile(
        r'^(\s*)"([^"]+)": \{"category": "([^"]+)", "level": "([^"]+)"\}(,?)\s*$'
    )

    with open(dataset_path, "r", encoding="utf-8") as f:
        lines = f.readlines()

    out_lines = []
    changed = False
    for line in lines:
        # Entries that were already processed are passed through verbatim.
        if '"normalized_name"' in line:
            out_lines.append(line)
            continue

        match = entry_re.match(line)
        if match is None:
            out_lines.append(line)
            continue

        indent, op_name, category, level, comma = match.groups()
        normalized = normalize_op_name(op_name)
        # The pattern swallows trailing whitespace, so the rewritten line
        # always ends with a single "\n".
        out_lines.append(
            f'{indent}"{op_name}": {{"category": "{category}", "level": "{level}", '
            f'"normalized_name": "{normalized}"}}{comma}\n'
        )
        changed = True

    # Only touch the file on disk when something was actually rewritten.
    if changed:
        with open(dataset_path, "w", encoding="utf-8") as f:
            f.writelines(out_lines)
    return changed


def main() -> None:
    """Update ``examples/dataset.py`` under the repository root and report it.

    The repository root is taken to be the parent of the directory that
    contains this script.

    Raises:
        FileNotFoundError: If the dataset file does not exist.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    dataset_path = os.path.join(
        os.path.dirname(script_dir), "examples", "dataset.py"
    )
    if not os.path.exists(dataset_path):
        raise FileNotFoundError(dataset_path)

    if update_dataset_file(dataset_path):
        message = f"Updated {dataset_path}"
    else:
        message = f"No changes needed in {dataset_path}"
    print(message)


# Run the update only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
