"""Default configurations for the TNPD model."""

# Preset for 1D Gaussian-process regression with TNPD; the values follow
# TNP-pytorch. One input dimension, one output dimension, and a 6-layer
# transformer of width 64.
TNPD_GP_CONFIG = dict(
    dim_x=1,
    dim_y=1,
    d_model=64,
    emb_depth=4,
    dim_feedforward=128,
    nhead=4,
    dropout=0.0,
    num_layers=6,
    bound_std=False,
)

# Preset for higher-dimensional problems: two input dimensions, and a
# wider (d_model=128, 8 heads) and deeper (8 layers) transformer.
TNPD_MULTIDIM_CONFIG = dict(
    dim_x=2,
    dim_y=1,
    d_model=128,
    emb_depth=4,
    dim_feedforward=256,
    nhead=8,
    dropout=0.0,
    num_layers=8,
    bound_std=False,
)

def get_tnpd_config(preset: str = "gp") -> dict:
    """Return a fresh copy of a named TNPD configuration preset.

    Args:
        preset: Name of the preset to fetch; either "gp" or "multidim".

    Returns:
        A new (shallow-copied) dictionary of configuration parameters, so
        callers can set or replace keys without altering the module-level
        preset.

    Raises:
        ValueError: If ``preset`` is not one of the known preset names.
    """
    presets = {"gp": TNPD_GP_CONFIG, "multidim": TNPD_MULTIDIM_CONFIG}

    if preset in presets:
        return dict(presets[preset])

    raise ValueError(f"Unknown preset {preset}. Options: {list(presets)}")