"""Default configurations for PFN model."""

# Default TNPA config for 1D GP regression (from TNP-pytorch).
# All values are scalars; get_pfn_config("gp") hands out a copy of this dict.
PFN_GP_CONFIG = dict(
    dim_x=1,
    dim_y=1,
    d_model=512,
    dim_feedforward=1024,
    nhead=4,
    dropout=0.0,
    num_layers=6,
    # The original authors trained with both 1000 and 10000 buckets.
    head_num_buckets=1000,
)

# Config for higher-dimensional problems (2D inputs, scalar outputs).
# All values are scalars; get_pfn_config("multidim") hands out a copy of this dict.
PFN_MULTIDIM_CONFIG = dict(
    dim_x=2,
    dim_y=1,
    d_model=512,
    dim_feedforward=1024,
    nhead=8,
    dropout=0.0,
    num_layers=8,
    head_num_buckets=1000,
)

def get_pfn_config(preset: str = "gp") -> dict:
    """Return a fresh copy of a named PFN configuration preset.

    Args:
        preset: Configuration preset name. Options: "gp", "multidim".

    Returns:
        A new dictionary of configuration parameters. It is a shallow copy of
        the module-level preset, so adding, removing or reassigning keys does
        not affect the shared preset.

    Raises:
        ValueError: If ``preset`` is not one of the known preset names.
    """
    configs = {
        "gp": PFN_GP_CONFIG,
        "multidim": PFN_MULTIDIM_CONFIG,
    }

    try:
        base = configs[preset]
    except KeyError:
        # Suppress the internal KeyError so callers see a single, clear error.
        raise ValueError(
            f"Unknown preset {preset}. Options: {list(configs.keys())}"
        ) from None

    # Copy so callers tweaking hyperparameters never mutate the shared
    # module-level preset dictionaries.
    return base.copy()