import os

# NOTE(review): presumably the vector size shared by every embedding in the
# project -- confirm against the embedding code.
GLOBAL_EMBEDDING_DIM = 1024

# Base directory derived from this file's location
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# OpenAI-compatible API settings.
# The base URL targets the DashScope compatible-mode endpoint, so the key is
# taken from the DASHSCOPE_API_KEY environment variable when that variable is
# set and non-empty; this keeps real credentials out of the source tree.
# The placeholder stays as the fallback so the module imports and behaves
# exactly as before when the variable is unset or empty.
OPENAI_CONFIG = {
    "api_key": os.environ.get("DASHSCOPE_API_KEY") or "YOUR_OPENAI_API_KEY",
    "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
}

# LLM model identifiers (Qwen3 family, listed from smallest to largest)
LLM_MODELS = [
    "qwen3-0.6b",
    "qwen3-1.7b",
    "qwen3-4b",
    "qwen3-8b",
    "qwen3-14b",
    "qwen3-32b",
]

# default verifier model
VERIFIER_MODEL = "skywork-o1-prm-1.5b"


# TTS configurations
def generate_tts_combinations():
    """Enumerate every valid (QP, CP, BS) triple.

    QP and CP each range over the powers of two from 1 to 64, and a pair is
    kept only when ``QP * CP <= 64``. BS is not searched independently: it is
    looked up from CP through a fixed table. Per the original note, this grid
    is meant to stay consistent with precompute_embeddings.py.

    Returns:
        list[tuple[int, int, int]]: ``(qp, cp, bs)`` tuples ordered by
        ascending QP, then ascending CP.
    """
    max_product = 64
    powers_of_two = (1, 2, 4, 8, 16, 32, 64)
    bs_for_cp = {1: 1, 2: 1, 4: 2, 8: 2, 16: 4, 32: 4, 64: 4}

    return [
        (qp, cp, bs_for_cp[cp])
        for qp in powers_of_two
        for cp in powers_of_two
        if qp * cp <= max_product
    ]

# Built once at import time; holds the (qp, cp, bs) tuples returned by
# generate_tts_combinations().
TTS_COMBINATIONS = generate_tts_combinations()

# Cache and file paths (everything lives under <BASE_DIR>/cache)
CACHE_DIR = os.path.join(BASE_DIR, "cache")
# NOTE(review): presumably the JSON file holding precomputed action
# embeddings -- confirm against the code that writes it.
EMBEDDING_FILE = os.path.join(CACHE_DIR, "action_embeddings_origin.json")
# NOTE(review): presumably a CSV log of training loss -- confirm.
LOSS_LOG_PATH = os.path.join(CACHE_DIR, "loss_log.csv")

# ---- eFLOPs related model parameters and constants ----
# Each table below is keyed by model identifier and has one entry for every
# Qwen3 model plus the Skywork verifier model.

# Parameter count P per model, written as <billions>e9 (0.6e9 = 0.6 billion).
MODEL_PARAM_P = {
    "qwen3-0.6b": 0.6e9,
    "qwen3-1.7b": 1.7e9,
    "qwen3-4b": 4e9,
    "qwen3-8b": 8.2e9,
    "qwen3-14b": 14.8e9,
    "qwen3-32b": 32.8e9,
    "skywork-o1-prm-1.5b": 1.54e9,
}
# GQA ratio R per model.
# NOTE(review): R is not defined in this file. The first five values would be
# consistent with (KV heads / query heads) for the published Qwen3 configs,
# but the 1.0 and 2.0 entries would not follow that pattern -- confirm these
# two against the formula that consumes R.
MODEL_GQA_RATIO_R = {
    "qwen3-0.6b": 0.5,
    "qwen3-1.7b": 0.5,
    "qwen3-4b": 0.25,
    "qwen3-8b": 0.25,
    "qwen3-14b": 0.2,
    "qwen3-32b": 1.0,
    "skywork-o1-prm-1.5b": 2.0,
}

# KV size D per model (described in the original note as the KV dimension per
# token).
# NOTE(review): the original note also said "default unified to 128", yet no
# entry equals 128 -- presumably 128 is the per-head dimension these values
# were derived from; confirm the unit and the derivation.
MODEL_KV_SIZE_D = {
    "qwen3-0.6b": 57344,
    "qwen3-1.7b": 114688,
    "qwen3-4b": 73728,
    "qwen3-8b": 147456,
    "qwen3-14b": 196608,
    "qwen3-32b": 262144,
    "skywork-o1-prm-1.5b": 28672,
}

# Arithmetic intensity I. The value is hardware-specific: the original note
# cites 562.5 for a B200.
# NOTE(review): 156 is therefore not the B200 figure; presumably it targets a
# different accelerator (156 would match e.g. 312 TFLOPS / 2 TB/s) -- confirm
# which hardware this is meant to model.
ARITHMETIC_INTENSITY_I = 156