from dataclasses import dataclass, field
from pathlib import Path

from lib_project.experiment import ExperimentConfig


# Project root: the parent of the directory that contains this file.
# Resolve the file path *before* walking up. For a relative ``__file__`` such
# as "config.py", ``Path("config.py").parent.parent`` is just ``Path(".")``
# (the parent of "." is "."), which would make the root the cwd instead of its
# parent. Resolving first also follows a symlinked file to its real location.
PROJECT_BASE_DIR = Path(__file__).resolve().parent.parent
ARTIFACTS_DIR = PROJECT_BASE_DIR / "artifacts"
# BASE_FIGURE_DIR = Path("/home/exp/figures/")
BASE_FIGURE_DIR = (
    Path.home() / "projects/llm_memorization/neurips_2024/paper/figures/"
)

# Local base-model directories under the project root (the ``MODELS`` table
# points its local entries at ``BASE_MODEL_DIR_VEDANT``).
BASE_MODEL_DIR_VEDANT = PROJECT_BASE_DIR / "base_models"
BASE_MODEL_DIR_TILL = PROJECT_BASE_DIR / "base_models_ANONYMOUS"


# Registry of supported models, keyed by short model id. Each value is a
# ``(model_name, model_dir)`` pair, as read by ``BaseConfigArgs``:
# ``model_dir`` is a local base directory, or ``None`` when the entry carries
# no local directory. NOTE(review): ``None`` entries presumably get resolved
# by name elsewhere (e.g. hub ids like "EleutherAI/pythia-2.8b") -- confirm
# against the loader. Rows are listed in the registry's iteration order.
MODELS: dict[str, tuple[str, Path | None]] = {
    model_id: (model_name, model_dir)
    for model_id, model_name, model_dir in (
        ("pyt-70m", "pythia-70m", BASE_MODEL_DIR_VEDANT),
        # ("pyt-70m-hf", "EleutherAI/pythia-70m", None),
        ("pyt-160m", "pythia-160m", BASE_MODEL_DIR_VEDANT),
        ("pyt-410m", "pythia-410m", BASE_MODEL_DIR_VEDANT),
        ("pyt-1b", "pythia-1b", BASE_MODEL_DIR_VEDANT),
        # ("pyt-1b-hf", "EleutherAI/pythia-1b", None),
        ("pyt-1.4b", "pythia-1.4b", BASE_MODEL_DIR_VEDANT),
        # ("pyt-1.4b-hf", "EleutherAI/pythia-1.4b", None),
        # ("pyt-2.8b", "pythia-2.8b", BASE_MODEL_DIR_VEDANT),
        ("pyt-2.8b", "EleutherAI/pythia-2.8b", None),
        ("pyt-6.9b", "pythia-6.9b", BASE_MODEL_DIR_VEDANT),
        ("pyt-12b", "pythia-12b", BASE_MODEL_DIR_VEDANT),
        ("llama2-7b", "Llama-2-7b-hf", BASE_MODEL_DIR_VEDANT),
        ("llama2-13b", "Llama-2-13b-hf", BASE_MODEL_DIR_VEDANT),
        ("llama2-70b", "Llama-2-70b-hf", BASE_MODEL_DIR_VEDANT),
        ("llama2-7b-chat", "Llama-2-7b-chat-hf", BASE_MODEL_DIR_VEDANT),
        ("llama2-13b-chat", "Llama-2-13b-chat-hf", BASE_MODEL_DIR_VEDANT),
        ("llama2-70b-chat", "Llama-2-70b-chat-hf", BASE_MODEL_DIR_VEDANT),
        ("gpt2-124m", "gpt2", None),
        ("gpt2-medium", "gpt2-medium", None),
        ("gpt2-large", "gpt2-large", None),
        ("gpt2-1.5b", "gpt2-xl", None),
        ("phi-1", "phi-1", None),
        ("phi-1.3b", "phi-1_5", None),
        ("phi-2.7b", "phi-2", None),
        ("opt-350m", "opt-350m", None),
        ("pyt-1l-1h", "pythia-1l-1h", None),
    )
}

# The fixed list of experiment seeds.
EXPERIMENT_SEEDS = [1670, 3151, 6174, 7655, 6369, 4406, 7268, 8937, 6679, 8791]

# Every seed is drawn from [0, 10_000). Adding this offset moves a seed out of
# that range, so the shifted value is a fresh seed that cannot collide with
# any originally sampled one.
SEED_OFFSET = 10_000


@dataclass
class BaseConfigArgs:
    model_id: str
    base_dir: Path | None = field(default=None, kw_only=True)

    @property
    def model_name(self) -> str:
        return MODELS[self.model_id][0]

    @property
    def model_dir(self) -> Path | None:
        if self.base_dir is None:
            return MODELS[self.model_id][1]
        return self.base_dir


@dataclass
class LLMExperimentConfig(ExperimentConfig):
    """Experiment configuration for LLM experiments.

    Declares no fields or behavior of its own; everything is inherited from
    ``ExperimentConfig``.
    """
