from dataclasses import dataclass
from pathlib import Path

from lib_dl.analysis.experiment import ExperimentConfig


PROJECT_BASE_DIR = Path(__file__).parent.parent.resolve()
ARTIFACTS_DIR = PROJECT_BASE_DIR / "artifacts"

BASE_MODEL_DIR_VEDANT = PROJECT_BASE_DIR / "base_models"
BASE_MODEL_DIR_TILL = PROJECT_BASE_DIR / "base_models_tspeicher"
MODELS = {
    "pyt-70m": ("pythia-70m", BASE_MODEL_DIR_VEDANT),
    "pyt-160m": ("pythia-160m", BASE_MODEL_DIR_VEDANT),
    "pyt-410m": ("pythia-410m", BASE_MODEL_DIR_VEDANT),
    "pyt-1b": ("pythia-1b", BASE_MODEL_DIR_VEDANT),
    "pyt-1.4b": ("pythia-1.4b", BASE_MODEL_DIR_VEDANT),
    # "pyt-2.8b": ("pythia-2.8b", BASE_MODEL_DIR_VEDANT),
    "pyt-2.8b": ("EleutherAI/pythia-2.8b", None),
    "pyt-6.9b": ("pythia-6.9b", BASE_MODEL_DIR_VEDANT),
    "pyt-12b": ("pythia-12b", BASE_MODEL_DIR_VEDANT),
    "llama2-7b": ("Llama-2-7b-hf", BASE_MODEL_DIR_VEDANT),
    "llama2-13b": ("Llama-2-13b-hf", BASE_MODEL_DIR_VEDANT),
    "gpt2": ("gpt2", None),
    "gpt2-xl": ("gpt2-xl", None),
}

RANDOM_STRING_MODEL_DIR = PROJECT_BASE_DIR / "artifacts" / "random_strings"

BASE_FIGURE_DIR = Path("/home/exp/figures/")


@dataclass
class BaseConfigArgs:
    """Arguments that select a model by its id in the ``MODELS`` table.

    Both properties look ``model_id`` up in the module-level ``MODELS``
    mapping on every access and raise ``KeyError`` for an unknown id.
    """

    # Key into MODELS, e.g. "pyt-70m" or "gpt2".
    model_id: str

    @property
    def model_name(self) -> str:
        """Return the name stored first in the ``MODELS`` entry."""
        entry = MODELS[self.model_id]
        return entry[0]

    @property
    def model_dir(self) -> Path | None:
        """Return the directory stored second in the entry (may be None)."""
        entry = MODELS[self.model_id]
        return entry[1]


@dataclass
class LLMExperimentConfig(ExperimentConfig):
    """Experiment configuration that adds a ``local_rank`` field to ``ExperimentConfig``."""

    # Presumably this process's rank on its node in distributed runs --
    # TODO confirm against the launcher/callers.
    local_rank: int
