"""
KernelBench Dataset Definition.

This module defines the operator dataset for KernelBench benchmarks.
Each operator has a category (for few-shot example selection) and a level.

Levels:
- level1: Basic operators (matmul, activation, pooling, normalization, etc.)
- level2: Fused operators (Conv+BN+ReLU, Matmul+GELU, etc.)
- level3: Architecture components (ResNet, VGG, Transformer blocks, etc.)
- level4: Complete HuggingFace models (GPT2, BART, etc.)
"""

import json
import os
import random
# Few-shot example lookup: operator category -> name of the example operator
# shown to the model for that category.
# The values name existing AscendC few-shot examples under
# /prompts/ascendc_new_model_*.py. Examples currently available:
# add, leaky_relu, matmul_add, mse_loss, reduce_sum, sigmoid.
category2exampleop = dict(
    matmul="matmul_add",         # closest match: matmul fused with a bias add
    activation="leaky_relu",     # an activation-function example
    normalization="reduce_sum",  # treated as reduction-based (layer_norm example was removed)
    pooling="reduce_sum",        # pooling behaves like a reduction
    reduce="reduce_sum",         # exact match
    convolution="matmul_add",    # compute-heavy pattern
    loss="mse_loss",             # a loss-function example
    fuse="matmul_add",           # fused operators
    arch="matmul_add",           # complex architecture components
    transformer="matmul_add",    # transformer blocks
)

# KernelBench dataset with category and level info.
# Loaded from dataset.json (located next to this module), which contains
# normalized and PascalCase names.
_dataset_path = os.path.join(os.path.dirname(__file__), "dataset.json")

try:
    # Pin the encoding: JSON text is UTF-8, and relying on the platform's
    # locale default can mis-decode non-ASCII content on some systems.
    with open(_dataset_path, "r", encoding="utf-8") as _f:
        dataset = json.load(_f)
except (FileNotFoundError, NotADirectoryError):
    # Fall back to an empty dataset when the file is not there, though
    # verification should ensure it exists. Attempting the open directly
    # (instead of checking for existence first) avoids a race between the
    # check and the read. Malformed JSON or an unreadable file still raises.
    dataset = {}

# Ablation-study subset: a fixed-seed random sample of at most 30 level1
# operators and at most 20 level2 operators taken from `dataset`.
L1_items = [entry for entry in dataset.items() if entry[1].get("level") == "level1"]
L2_items = [entry for entry in dataset.items() if entry[1].get("level") == "level2"]

# The fixed seed keeps the sampled subset reproducible across runs.
# NOTE: this reseeds the module-global `random` generator as an import-time
# side effect, so it also affects later `random` calls made elsewhere.
random.seed(42)
# level1 is shuffled first, then level2, both drawing from the same seeded stream.
random.shuffle(L1_items)
random.shuffle(L2_items)

# Cap each sample size at the number of operators actually available.
L1_take = min(len(L1_items), 30)
L2_take = min(len(L2_items), 20)

# Selected level1 entries come first, followed by the selected level2 entries.
ablation_sub_dataset = dict(L1_items[:L1_take] + L2_items[:L2_take])
