from pathlib import Path

import torch

# Project root: the directory two levels above the one containing this file
# (this file's path, resolved to an absolute path, walked up three times).
BASE_DIR = Path(__file__).resolve().parent.parent.parent

# Top-level directories directly under the project root.
DATA_DIR = BASE_DIR.joinpath("data")
MODEL_DIR = BASE_DIR.joinpath("models")
RESULTS_DIR = BASE_DIR.joinpath("results")

# Sub-directories of MODEL_DIR used for saved checkpoints.
MODEL_CHECKPOINTS_DIR = MODEL_DIR.joinpath("checkpoints")
SAE_CHECKPOINTS_DIR = MODEL_DIR.joinpath("sae")


def get_device():
    """Return the torch device to run on: CUDA first, then Apple MPS, then CPU.

    Returns:
        torch.device: ``cuda`` if CUDA is available; otherwise ``mps`` if this
        torch build has the MPS backend and it is available; otherwise ``cpu``.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    # ``torch.backends.mps`` only exists in newer torch releases (added in 1.12).
    # Look it up defensively so older installs fall back to CPU instead of
    # raising AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")


# Device resolved once, when this module is imported.
DEVICE = get_device()

# Module-wide verbosity flag; presumably toggles extra output in callers — confirm.
VERBOSE = False


# Hand-picked palette of 20 hex colours. Several entries replace standard
# shades (e.g. yellow, light red) with ones that are easier to tell apart.
# Only the hex codes are kept; the names are there for readers.
BASE_COLOURS = [
    hex_code
    for hex_code, _name in (
        ("#FF0000", "Red"),
        ("#00A000", "Green"),
        ("#0000FF", "Blue"),
        ("#FF00FF", "Magenta"),
        ("#FFD700", "Gold"),
        ("#00FFFF", "Cyan"),
        ("#FF8000", "Orange"),
        ("#8B008B", "Dark Magenta"),
        ("#32CD32", "Lime Green"),
        ("#00CED1", "Dark Turquoise"),
        ("#4169E1", "Royal Blue"),
        ("#9400D3", "Dark Violet"),
        ("#FF6347", "Tomato"),
        ("#2E8B57", "Sea Green"),
        ("#6A5ACD", "Slate Blue"),
        ("#C71585", "Medium Violet Red"),
        ("#20B2AA", "Light Sea Green"),
        ("#1E90FF", "Dodger Blue"),
        ("#FF4500", "Orange Red"),
        ("#7CFC00", "Lawn Green"),
    )
]

# Default modulus (113, a prime); presumably the modulus of a modular-arithmetic
# task — confirm against the dataset code.
DEFAULT_MODULO = 113
# Active modulus. Derived from DEFAULT_MODULO instead of repeating the literal,
# so the two cannot silently drift apart; override here if a different value is needed.
MODULO = DEFAULT_MODULO
# Fraction of the data used for training — presumably the train/test split
# ratio; confirm against the data loader.
TRAIN_FRACTION = 0.3

# Bit width -> torch floating-point dtype (keys in ascending order: 16, 32, 64).
FLOAT_PRECISION_MAP = {bits: getattr(torch, f"float{bits}") for bits in (16, 32, 64)}

# Default floating-point dtype: the 64-bit entry of the map above.
FLOAT_PRECISION = FLOAT_PRECISION_MAP[64]
