"""Default configurations for TNPA model."""

# Baseline preset for 1-D Gaussian-process regression, following the
# TNP-pytorch reference hyperparameters.
TNPA_GP_CONFIG = dict(
    dim_x=1,
    dim_y=1,
    d_model=64,
    emb_depth=4,
    dim_feedforward=128,
    nhead=4,
    dropout=0.0,
    num_layers=6,
    bound_std=False,
    permute=False,  # Whether to permute target order during sampling
)

# Same hyperparameters as the 1-D GP preset, but with ``permute`` switched
# on so that the target order is permuted during sampling.
TNPA_GP_PERMUTE_CONFIG = dict(
    dim_x=1,
    dim_y=1,
    d_model=64,
    emb_depth=4,
    dim_feedforward=128,
    nhead=4,
    dropout=0.0,
    num_layers=6,
    bound_std=False,
    permute=True,
)

# Preset for 2-D inputs (dim_x=2): a wider model (d_model=128, 8 attention
# heads, 256-unit feedforward) with a deeper 8-layer stack.
TNPA_MULTIDIM_CONFIG = dict(
    dim_x=2,
    dim_y=1,
    d_model=128,
    emb_depth=4,
    dim_feedforward=256,
    nhead=8,
    dropout=0.0,
    num_layers=8,
    bound_std=False,
    permute=False,
)

def get_tnpa_config(preset: str = "gp") -> dict:
    """Look up a named TNPA preset and return a private copy of it.

    Args:
        preset: Which preset to fetch. Accepted names are "gp",
            "gp_permute" and "multidim".

    Returns:
        A shallow copy of the chosen preset dictionary. Because it is a new
        dict, callers can tweak entries without mutating the module-level
        preset it was copied from.

    Raises:
        ValueError: If ``preset`` is not one of the accepted names.
    """
    presets_by_name = {
        "gp": TNPA_GP_CONFIG,
        "gp_permute": TNPA_GP_PERMUTE_CONFIG,
        "multidim": TNPA_MULTIDIM_CONFIG,
    }

    # Success path first; an unknown name falls through to the error below.
    if preset in presets_by_name:
        return dict(presets_by_name[preset])

    valid_names = list(presets_by_name.keys())
    raise ValueError(f"Unknown preset {preset}. Options: {valid_names}")