from pathlib import Path

import torch

# Project directories. Both roots are siblings of the current working
# directory (resolved at import time), so they follow wherever the process
# was launched from.
DATA_DIR = Path.cwd().parent.joinpath('data')
OUTS_DIR = Path.cwd().parent.joinpath('outs')

# One sub-directory of OUTS_DIR per kind of output.
LOG_OUTS_DIR = OUTS_DIR.joinpath('log')
GEN_OUTS_DIR = OUTS_DIR.joinpath('gen')
FIG_OUTS_DIR = OUTS_DIR.joinpath('fig')
STATS_OUTS_DIR = OUTS_DIR.joinpath('stats')
REPORT_OUTS_DIR = OUTS_DIR.joinpath('report')

# Dry-run settings.
CFG_MODE_DRYRUN = True
# Presumably the number of test items processed in a dry run -- confirm at
# the usage site.
DRYRUN_TEST_NUM = 10

# File names produced in dry-run mode carry a '_.' prefix so they can be
# told apart from the outputs of a real run.
if CFG_MODE_DRYRUN:
    dryrun_mark = '_.'
else:
    dryrun_mark = ''

# File-name templates; the {placeholders} are left unfilled here.
LOG_TEMP = ''.join((dryrun_mark, 'me.{data_name}.{model_name}.log'))
GEN_TEMP = ''.join((dryrun_mark, 'me.{data_name}.{gen_mark}.jsonl'))

# ========== EXPERIMENT CONFIGURATION ==========
# Compute device: CUDA when torch reports it as available, CPU otherwise.
# Alternative kept from the original notes (not enabled): a plain string
# that also tries 'mps' (torch.backends.mps.is_available()) before 'cpu'.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

# Short model alias -> model identifier (the values look like Hugging Face
# Hub repo ids -- confirm against the loader that consumes this table).
# Each family is generated from its list of sizes; insertion order is kept
# family by family, size by size. Per-size notes are the original author's.
MODEL_REGISTRY = {
    # GPT-2: alias and identifier are the same string.
    **{name: name for name in ('gpt2', 'gpt2-xl')},
    # ...
    # Qwen3 2025.05
    **{
        f'qwen3-{size.lower()}': f'Qwen/Qwen3-{size}'
        for size in (
            '0.6B',  # 28 layers
            '1.7B',  # 28 layers
            '4B',  # 36 layers
            '8B',  # 36 layers
            '14B',  # 40 layers
            '32B',
        )
    },
    # Gemma3 2025.03
    **{
        f'gemma3-{size}': f'google/gemma-3-{size}-pt'
        for size in ('1b', '4b', '12b', '27b')
    },
    # Phi4 2024.12
    'phi4-14b': 'microsoft/phi-4',
    # Llama3.2 ...
    **{
        f'llama3.2-{size.lower()}': f'meta-llama/Llama-3.2-{size}'
        for size in ('1B', '3B')
    },
    # ...
    # code llama * 32016 32016 32000 32016 (BF16)
    **{
        f'codellama-{size}': f'meta-llama/CodeLlama-{size}-hf'
        for size in (
            '7b',  # 32 layers * (4096 * 4) neurons/layer *
            '13b',  # 40 layers * (5120 * 4) neurons/layer !!!
            '34b',  # 48 layers * (8192 * 4) neurons/layer !!!
            '70b',  # 80 layers * (8192 * 4) neurons/layer !!!
        )
    },
    # starcoder2 * 49152 (F32, BF16, F32)
    **{
        f'starcoder2-{size}': f'bigcode/starcoder2-{size}'
        for size in (
            '3b',  # 30 layers * (3072 * 4) neurons/layer
            '7b',  # 32 layers * (4608 * 4) neurons/layer
            '15b',  # 40 layers * (6144 * 4) neurons/layer !!!
        )
    },
    # opencoder * (F32 + F16)
    **{
        f'opencoder-{size.lower()}': f'infly/OpenCoder-{size}-Base'
        for size in ('1.5B', '8B')
    },
    # qwen2.5-coder 151936 * 3 152064 * 3 (BF16)
    # (qwen2.5-coder is at the same period with opencoder ...)
    **{
        f'qwencoder-{size.lower()}': f'Qwen/Qwen2.5-Coder-{size}'
        for size in (
            '0.5B',  # 24 layers * (896 * 4) neurons/layer !!!
            '1.5B',  # 28 layers * (1536 * 4) neurons/layer !!!
            '3B',  # 36 layers * (2048 * 4) neurons/layer !!!
            '7B',  # 28 layers * (3584 * 4) neurons/layer !!!
            '14B',  # 48 layers * (5120 * 4) neurons/layer !!!
            '32B',  # 64 layers * (5120 * 4) neurons/layer !!!
        )
    },
}
