import torch
import numpy as np
from torch import nn
from loguru import logger
from numpy.typing import NDArray
from typing import Optional, Literal, Callable, Dict
from sklearn.model_selection import BaseCrossValidator


# Names of the model architectures that ``MODELS`` can build.
Model = Literal['linear', '2-layer']


# When True, ``device()`` returns the CPU even if an accelerator is available.
CPU_ONLY: bool = True


def _build_linear(input_dim: int) -> nn.Sequential:
    """Build a single bias-free linear layer mapping ``input_dim`` inputs to 1 output."""
    projection = nn.Linear(input_dim, 1, bias=False)
    return nn.Sequential(projection)


def _build_two_layer(input_dim: int) -> nn.Sequential:
    """Build an MLP: Linear(input_dim -> 20), LeakyReLU(0.2), Linear(20 -> 1)."""
    hidden_width = 20
    return nn.Sequential(
        nn.Linear(input_dim, hidden_width),
        nn.LeakyReLU(0.2),
        nn.Linear(hidden_width, 1),
    )


# Factory per architecture name: call with the input feature count to get a
# freshly initialised model.
MODELS: Dict[Model, Callable[[int], nn.Sequential]] = {
    'linear': _build_linear,
    '2-layer': _build_two_layer,
}


def device(cpu_only: Optional[bool] = None) -> torch.device:
    """Choose the torch device to run on and log the choice.

    Preference order is CUDA, then Apple MPS, then CPU, unless CPU-only mode
    is in effect, in which case the CPU is returned without probing for
    accelerators.

    Args:
        cpu_only: Force the CPU device when True. ``None`` (the default)
            falls back to the module-level ``CPU_ONLY`` flag, so a plain
            ``device()`` call behaves as before.

    Returns:
        ``torch.device('cuda')``, ``torch.device('mps')`` or
        ``torch.device('cpu')``.
    """
    if cpu_only is None:
        cpu_only = CPU_ONLY

    if cpu_only:
        # Skip accelerator probing entirely: its answer would be discarded.
        name = 'cpu'
    elif torch.cuda.is_available():
        name = 'cuda'
    elif (
        # Older torch builds lack the ``mps`` backend module entirely.
        getattr(torch.backends, 'mps', None) is not None
        and torch.backends.mps.is_available()
    ):
        name = 'mps'
    else:
        name = 'cpu'

    logger.info(f'Using {name} device.')
    return torch.device(name)