# -*- coding: utf-8 -*-
"""
Model & experiment configuration for ICLR runs.
This module is lightweight and does NOT import any heavy libraries.
"""
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Dict

class ModelFamily(Enum):
    """Closed set of model families that a ``ModelConfig`` can belong to.

    Each member's value is the lowercase family name as a plain string, so a
    member can be recovered from text with ``ModelFamily("qwen")``.
    """

    # Definition order is the iteration order of the enum; keep it stable.
    LLAMA = "llama"
    QWEN = "qwen"
    MIXTRAL = "mixtral"

def _default_lora_targets() -> List[str]:
    """Return a fresh list of the default LoRA target module names."""
    return [
        "q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"
    ]


@dataclass
class ModelConfig:
    """Static description of a single model used in the experiments.

    Every field except ``lora_target_modules`` and ``notes`` is required.
    ``lora_target_modules`` defaults to a new list per instance, so mutating
    one config's list never affects another.
    """

    # Short identifier for the model.
    model_id: str
    family: ModelFamily
    # Size in billions (presumably parameter count -- confirm with callers).
    size_billion: float
    # Either "dense" or "moe".
    architecture: str
    context_length: int
    multilingual: bool
    instruction_tuned: bool
    license: str
    # Identifier of the model on Hugging Face.
    huggingface_id: str
    # Module names targeted by LoRA; the default is built fresh per instance.
    lora_target_modules: List[str] = field(default_factory=_default_lora_targets)
    # Free-form remarks about the model.
    notes: str = ""

# Registry of every known model, keyed by its ``model_id``.
# The key is derived from the config itself so the two can never drift apart;
# insertion order follows the order of the tuple below.
MODEL_REGISTRY: Dict[str, ModelConfig] = {
    entry.model_id: entry
    for entry in (
        # --- Llama family -------------------------------------------------
        ModelConfig(
            model_id="llama3-8b",
            family=ModelFamily.LLAMA,
            size_billion=8,
            architecture="dense",
            context_length=8192,
            multilingual=False,
            instruction_tuned=True,
            license="Meta custom",
            huggingface_id="meta-llama/Meta-Llama-3-8B-Instruct",
            notes="Baseline model used in original experiments.",
        ),
        ModelConfig(
            model_id="llama3.1-70b",
            family=ModelFamily.LLAMA,
            size_billion=70,
            architecture="dense",
            context_length=128000,
            multilingual=False,
            instruction_tuned=True,
            license="Meta custom",
            huggingface_id="meta-llama/Llama-3.1-70B-Instruct",
            notes="Same family scale-up for size effect.",
        ),
        # --- Qwen family --------------------------------------------------
        ModelConfig(
            model_id="qwen2.5-72b",
            family=ModelFamily.QWEN,
            size_billion=72,
            architecture="dense",
            context_length=32768,
            multilingual=True,
            instruction_tuned=True,
            license="Apache-2.0",
            huggingface_id="Qwen/Qwen2.5-72B-Instruct",
            notes="Multilingual, diverse corpus.",
        ),
        # --- Mixtral family (mixture of experts) --------------------------
        ModelConfig(
            model_id="mixtral-8x7b",
            family=ModelFamily.MIXTRAL,
            # total; ~13B active
            # NOTE(review): 56 is the nominal 8 x 7B figure; Mistral reports
            # ~46.7B total parameters for this model -- confirm which is wanted.
            size_billion=56,
            architecture="moe",
            context_length=32768,
            multilingual=True,
            instruction_tuned=True,
            license="Apache-2.0",
            huggingface_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
            # Overrides the default targets: w1/w2/w3 replace gate/up/down_proj.
            lora_target_modules=[
                "q_proj","k_proj","v_proj","o_proj","w1","w2","w3"
            ],
            notes="MoE; ~13B active per forward.",
        ),
    )
}

def _default_model_ids() -> List[str]:
    """Return a fresh list of the model ids run by default."""
    return ["llama3-8b", "llama3.1-70b", "qwen2.5-72b", "mixtral-8x7b"]


def _default_guard_levels() -> List[str]:
    """Return a fresh list of the guard levels run by default."""
    return ["L1_custom","L2_human","L3_all"]


@dataclass
class ExperimentConfig:
    """Top-level settings for an experiment run.

    All fields have defaults, so ``ExperimentConfig()`` is a complete
    configuration. The two list fields are created per instance and are
    therefore safe to mutate.
    """

    # Ids of the models to run.
    model_ids: List[str] = field(default_factory=_default_model_ids)
    # Names of the guard levels to evaluate.
    guard_levels: List[str] = field(default_factory=_default_guard_levels)

    # Filesystem locations (relative paths by default).
    out_root: str = "./iclr_results"
    data_root: str = "./data"

    # Model-loading options; presumably passed to the loader -- confirm there.
    use_quantization: bool = True
    quant_bits: int = 4
    device_map: str = "auto"

    # Off by default; its effect is defined by the code consuming this config.
    fast_mode: bool = False

def get_models(cfg: ExperimentConfig) -> List[ModelConfig]:
    """Resolve ``cfg.model_ids`` to their ``MODEL_REGISTRY`` entries.

    Ids are looked up in order; duplicates in ``cfg.model_ids`` yield duplicate
    entries in the result. Ids that are not registered are left out of the
    result, and a single ``UserWarning`` naming them is emitted so that a
    misspelled id does not silently remove a model from a run.

    Args:
        cfg: Experiment configuration whose ``model_ids`` should be resolved.

    Returns:
        The registered ``ModelConfig`` objects, in the order requested.
    """
    # Function-scope import: stdlib only, and it leaves the module's import
    # block (and its "no heavy libraries" promise) untouched.
    import warnings

    resolved: List[ModelConfig] = []
    unknown: List[str] = []
    for model_id in cfg.model_ids:
        if model_id in MODEL_REGISTRY:
            resolved.append(MODEL_REGISTRY[model_id])
        else:
            unknown.append(model_id)

    if unknown:
        # Still best-effort (no exception), but no longer silent.
        warnings.warn(
            f"get_models: ignoring model ids not in MODEL_REGISTRY: {unknown}",
            stacklevel=2,
        )
    return resolved
