import os
from pathlib import Path

import torch

# Directory four levels above this file; every path below hangs off it.
BASE_DIR = Path(__file__).resolve().parent.parent.parent.parent

# Top-level directories and files directly under BASE_DIR.
DATA_DIR = BASE_DIR.joinpath("data")
RESULTS_DIR = BASE_DIR.joinpath("results")
MODEL_DIR = BASE_DIR.joinpath("models")
ENV_FILE_PATH = BASE_DIR.joinpath(".env")
BENCHMARK_RESULTS_DIR = BASE_DIR.joinpath(
    "src", "eliciting_contexts", "benchmark", "results"
)

# DCT dashboard locations under external/dashboard/dct_cache.
DCT_DASHBOARD_CACHE = BASE_DIR.joinpath("external", "dashboard", "dct_cache", "cache")
DCT_DASHBOARDS = BASE_DIR.joinpath("external", "dashboard", "dct_cache", "dashboards")
DCT_DB_CACHE = BASE_DIR.joinpath("external", "dashboard", "dct_cache", "db")

# SAE dashboard locations under external/dashboard/sae_cache (same layout as DCT).
SAE_DASHBOARD_CACHE = BASE_DIR.joinpath("external", "dashboard", "sae_cache", "cache")
SAE_DASHBOARDS = BASE_DIR.joinpath("external", "dashboard", "sae_cache", "dashboards")
SAE_DB_CACHE = BASE_DIR.joinpath("external", "dashboard", "sae_cache", "db")


def get_device() -> torch.device:
    """Pick the torch device to run on.

    CUDA is preferred when ``torch.cuda.is_available()`` is true; otherwise
    the MPS backend is used if it reports itself available; otherwise the
    CPU is the fallback.

    Returns:
        A ``torch.device`` of type ``"cuda"``, ``"mps"`` or ``"cpu"``.
    """
    # The MPS check is only reached when CUDA is unavailable.
    if torch.cuda.is_available():
        backend = "cuda"
    elif torch.backends.mps.is_available():
        backend = "mps"
    else:
        backend = "cpu"
    return torch.device(backend)


# Device chosen once, at import time of this module.
DEVICE = get_device()

# NOTE(review): presumably the Weights & Biases entity name; empty here — confirm.
WANDB_ENTITY = ""

# DCT storage locations. Each can be overridden through the environment
# variable of the same name as its key below; both values are plain strings.
DEFAULT_DB_PATH = os.getenv(
    "DCT_DB_PATH",
    str(DATA_DIR.joinpath("dct", "dct_results.db")),
)
DEFAULT_TENSOR_DIR = os.getenv(
    "DCT_TENSOR_DIR",
    str(DATA_DIR.joinpath("dct", "tensors")),
)
