"""Default configurations for TNPND model."""

from src.enums.model_enums import CovApprox

# Default TNPND config for 1D GP regression (from TNP-pytorch)
TNPND_GP_CONFIG = dict(
    # Input / output sizes.
    dim_x=1,
    dim_y=1,
    # Model width / depth hyperparameters.
    d_model=64,
    emb_depth=4,
    dim_feedforward=128,
    nhead=4,
    dropout=0.0,
    num_layers=6,
    num_std_layers=2,
    # Covariance parameterization: CovApprox.CHOLESKY or CovApprox.LOWRANK.
    cov_approx=CovApprox.CHOLESKY,
    prj_dim=20,
    prj_depth=4,
    diag_depth=4,  # only consulted by the low-rank parameterization option
    bound_std=False,
)

# TNPND config for higher-dimensional problems (2D inputs, larger model, low-rank covariance)
TNPND_MULTIDIM_CONFIG = dict(
    # Input / output sizes.
    dim_x=2,
    dim_y=1,
    # Model width / depth hyperparameters (larger than the 1D GP preset).
    d_model=128,
    emb_depth=4,
    dim_feedforward=256,
    nhead=8,
    dropout=0.0,
    num_layers=8,
    num_std_layers=3,
    # Low-rank covariance parameterization; diag_depth applies to this option.
    cov_approx=CovApprox.LOWRANK,
    prj_dim=20,
    prj_depth=4,
    diag_depth=4,
    bound_std=False,
)

def get_tnpnd_config(preset: str = "gp") -> dict:
    """Return a fresh copy of a named TNPND configuration preset.

    Args:
        preset: Name of the preset to load. Options: "gp" (TNPND_GP_CONFIG)
            and "multidim" (TNPND_MULTIDIM_CONFIG).

    Returns:
        A new dict holding the preset's parameters. It is a shallow copy, so
        callers may add, remove, or reassign keys without affecting the
        module-level preset.

    Raises:
        ValueError: If ``preset`` is not one of the known preset names.
    """
    available = {
        "gp": TNPND_GP_CONFIG,
        "multidim": TNPND_MULTIDIM_CONFIG,
    }

    # Preset values are always dicts, so None unambiguously means "unknown name".
    chosen = available.get(preset)
    if chosen is None:
        raise ValueError(f"Unknown preset {preset}. Options: {list(available)}")

    # Copy so per-run overrides never leak back into the shared module-level preset.
    return dict(chosen)