# @package _global_
# Epiformer Model Configuration


model:
  name: "epiformer"  # Options: "epiformer" (TODO: list the other supported model names)
  
  # Antigen (ag) encoder parameters
  ag_encoder:
    # Block enabling/disabling
    atommp_enabled: true
    edgemp_enabled: true 
    resmp_enabled: true
    
    # GNN type selection
    atom_mp_type: "egnn"  # Options: "egnn", "et", "gcn", "gat", "gin"
    edgemp_type: "egnn"   # Options: "egnn", "gcn", "gat", "gin"
    resmp_type: "egnn"    # Options: "egnn", "regnn", "gcn", "gat", "gin", "rgcn"
    
    # AtomMP parameters
    atom_in_nf: 28
    atom_hidden_nf: 64
    atom_out_nf: 32
    atom_layers: 3
    atom2res_inj: "ca_only" # Options: "mean", "ca_only", "sum"

    # EdgeMP parameters  
    edge_input_dim: 100
    edge_hidden_dims: [64, 64]
    edge_layers: 3

    # Feature fusion
    feature_fusion_type: "concat"  # Options: "gated", "concat"
    
    # ResMP parameters
    residue_dim: 105  # Should match ab_encoder.residue_dim
    residue_hidden_dim: 128
    residue_layers: 4
    num_relations: 4
    
    # PLM parameters
    plm_in_dim: 480 # ESM-2 dimension
    plm_dim: 128

  # Antibody (ab) encoder parameters
  ab_encoder:
    # Block enabling/disabling (typically the same as ag_encoder)
    atommp_enabled: true
    edgemp_enabled: true
    resmp_enabled: true
    
    # GNN type selection (can differ from ag_encoder)
    atom_mp_type: "egnn"  # Options: "egnn", "et", "gcn", "gat", "gin"
    edgemp_type: "egnn"   # Options: "egnn", "gcn", "gat", "gin"
    resmp_type: "egnn"    # Options: "egnn", "regnn", "gcn", "gat", "gin", "rgcn"
    
    # AtomMP parameters
    atom_in_nf: 28
    atom_hidden_nf: 64
    atom_out_nf: 32
    atom_layers: 3
    atom2res_inj: "ca_only" # Options: "mean", "ca_only", "sum"

    # EdgeMP parameters
    edge_input_dim: 100
    edge_hidden_dims: [64, 64]
    edge_layers: 3

    # Feature fusion
    feature_fusion_type: "concat"  # Options: "gated", "concat"
    
    # ResMP parameters
    residue_dim: 105  # Should match ag_encoder.residue_dim
    residue_hidden_dim: 128
    residue_layers: 4
    num_relations: 4
    
    # PLM parameters
    plm_in_dim: 512   # AntiBERTy/AbLang dimension
    plm_dim: 128

  geo_dim: 105
  activation: "silu"  # relu, gelu, silu/swish, leaky_relu

  # Epiformer-specific parameters (used when model.name = "epiformer")
  epiformer:
    ag_resmp_type: "egnn"
    ab_resmp_type: "egnn"
    # Basic parameters 
    geo_dim: 105
    residue_dim: 105
    residue_hidden_dim: 128
    residue_layers: 4
    edge_dim: 100
    num_relations: 4
    
    # PLM parameters
    plm_dim: 128
    ag_plm_in_dim: 480  # ESM-2 dimension for antigen
    ab_plm_in_dim: 512  # AntiBERTy dimension for antibody
    
    # Cross-attention parameters
    n_heads: 3  # Number of attention heads; residue_dim (105) must be divisible by n_heads
    attention_dropout: 0.1
    
    # Feedforward network parameters
    ffn_expansion_factor: 4  # FFN hidden dim = expansion_factor * residue_dim
    
    # Pair representation parameters
    use_pair_repr: false  # Enable pair representation with triangle updates
    pair_dim: 64  # Dimension of pair representation

    ab_feature_fusion_type: "gated"   # Options: "gated", "concat"
    ag_feature_fusion_type: "concat"  # Options: "gated", "concat"
    
    # General parameters
    activation: "silu"  # relu, gelu, silu/swish, leaky_relu
    dropout: 0.1
    use_layer_norm: true
    
    # Memory optimization
    use_gradient_checkpointing: false
    checkpoint_segments: 2

  # Decoder parameters
  decoder:
    type: "cross_attention"  # "cross_attention", "dot_product", "enhanced_bilinear", "dual", or "walle"
    sampling_strat: "mean_row" # Available strategies:
    # - "max_row": Row-wise maximum (original)
    # - "mean_row": Row-wise mean (current default)
    # - "top_k_mean_2": Mean of top-2 interactions (biologically motivated)
    # - "top_k_mean_3": Mean of top-3 interactions (biologically motivated)
    # - "softmax_attention": Learned attention weights for interactions
    # - "edge_budget_aware": WALLE-style sparse, high-confidence predictions
    # - "epiformer_pooling": Combines local specificity with global context
    d_model: 128
    n_heads: 8
    decoder_layers: 3
    d_ff: 512
    d_k: 64
    num_rbf: 16              # Number of RBF centers for the enhanced_bilinear decoder (TODO)
    predict_distances: False

  # General
  dropout: 0.1
  use_layer_norm: true
  
  # Specific dropout rates for different components
  dropout_rates:
    atom_mp: 0.2
    edge_mp: 0.2
    res_mp: 0.1
    decoder: 0.1
    projections: 0.1

  # Fixed thresholds, as in v4; they are not optimized.
  # NOTE(review): presumably epi = epitope and para = paratope predictions — confirm.
  epi_threshold: 0.3
  para_threshold: 0.3

  threshold: 0.5
  
  # Memory optimization settings
  use_gradient_checkpointing: false
  checkpoint_segments: 2

  # Hybrid model configuration (BaseModel + lightweight cross-attention)
  hybrid_n_heads: 8                    # Number of attention heads for cross-attention
  hybrid_attention_dropout: 0.1       # Dropout for cross-attention layers  
  hybrid_cross_attn_weight: 0.1       # Initial weight for residual connection (learnable)

