# Build a supervised contrastive (SupCon) loss function.
#
# Returns a closure with the Keras loss signature (y_true, y_pred): `odor`
# holds the class labels and `latents` holds the embeddings. Samples that
# share a label are treated as positive pairs for each other.
#
# Args:
#   temperature: scalar divisor applied to the pairwise dot-product
#     similarities.
#   psi: scalar multiplier applied to the log of each anchor's denominator
#     (the sum of exponentiated similarities to every other sample). With
#     psi = 1 this is the standard SupCon denominator.
# Returns:
#   function(odor, latents) returning a scalar loss tensor.
supervised_contrastive_loss <- function(temperature = 0.07, psi = 1) {
  function(odor, latents) {
    # Guards log(0) and division by zero below.
    eps <- 1e-12

    # Work in float32 from the start so the stabilization below is done in
    # full precision (casting after the subtraction had no effect).
    latents <- tf$cast(latents, tf$float32)

    # Pairwise similarity matrix: [batch, batch].
    logits <- tf$matmul(latents, latents, transpose_b = TRUE) / temperature

    # Positive-pair mask: entry (i, j) is 1 when samples i and j share a label.
    labels <- tf$reshape(odor, shape = c(-1L, 1L))
    mask_positive <- tf$cast(tf$equal(labels, tf$transpose(labels)), tf$float32)

    # Zero the diagonal so an anchor is never contrasted with itself.
    # `[1]` uses R tensorflow's one-based tensor extraction -> batch size.
    logits_mask <- 1 - tf$eye(tf$shape(labels)[1])
    mask_positive_nodiag <- mask_positive * logits_mask

    # Subtract each row's max for numerical stability. With psi = 1 the shift
    # cancels exactly; otherwise it offsets each anchor by a row constant.
    logits <- logits - tf$reduce_max(logits, axis = 1L, keepdims = TRUE)

    # Denominator for each anchor: sum over all samples except itself.
    exp_logits <- tf$exp(logits) * logits_mask
    denom <- tf$reduce_sum(exp_logits, axis = 1L, keepdims = TRUE)

    # eps belongs INSIDE the log so an empty denominator (e.g. a batch of a
    # single sample) yields a finite value instead of -Inf / NaN.
    log_prob <- logits - tf$math$log(denom + eps) * psi

    # Mean log-probability over each anchor's positives; anchors with no
    # positives get 0 (their numerator is 0).
    positives_per_anchor <- tf$reduce_sum(mask_positive_nodiag, axis = 1L)
    sum_log_prob_pos <- tf$reduce_sum(mask_positive_nodiag * log_prob, axis = 1L)
    mean_log_prob_pos <- sum_log_prob_pos / tf$maximum(positives_per_anchor, 1)

    # Final loss: mean over anchors that have > 0 positives only, so anchors
    # without positives do not dilute the average. Returns 0 if none do.
    has_positive <- tf$cast(tf$greater(positives_per_anchor, 0), tf$float32)
    n_valid <- tf$maximum(tf$reduce_sum(has_positive), 1)
    -tf$reduce_sum(mean_log_prob_pos * has_positive) / n_valid
  }
}

# Build and compile the encoder model: one independent encoder per subject.
#
# Each subject gets its own encoder (dropout followed by a linear dense
# projection). A custom switching layer routes every sample to the encoder
# selected by its subject id, and the result is L2-normalized. The model is
# compiled with supervised_contrastive_loss() and Adam (clipnorm = 1).
#
# Args:
#   hidden_dim: size of the latent embedding produced by every encoder.
#   num_subjects: number of subject-specific encoders to build (>= 1).
#   num_neurons: number of input features per sample.
#   SCL_temp: temperature passed to supervised_contrastive_loss().
#   learning_rate: Adam learning rate.
#   dropout_rate: dropout rate applied to each encoder's input.
#   l2_penalty: L2 kernel-regularization factor of each encoder's dense layer.
# Returns:
#   A compiled Keras model with inputs list(input_features, input_subjects)
#   and L2-normalized latents of width hidden_dim as output.
# NOTE(review): branches are handed to tf$switch_case as an unnamed list, so
#   they are indexed from 0; subject ids are presumably expected in
#   0:(num_subjects - 1). Per TF's switch_case docs an out-of-range id is
#   routed to the last branch rather than raising -- confirm against callers.
get_encoder_model <- function(hidden_dim, num_subjects, num_neurons,
                              SCL_temp = 0.07, learning_rate = 0.005,
                              dropout_rate = 0.2, l2_penalty = 0.2) {
  stopifnot(num_subjects >= 1, hidden_dim >= 1, num_neurons >= 1)

  # R numerics such as `64` reach Python as floats, which tf$TensorSpec's
  # shape rejects; coerce once so every downstream use sees an integer.
  hidden_dim <- as.integer(hidden_dim)

  # Define model inputs
  input_features <- layer_input(shape = num_neurons, name = "input_features")
  input_subjects <- layer_input(shape = 1, dtype = "int32", name = "input_subjects")

  # Define the subject-specific encoder networks
  subject_encoders <- lapply(seq_len(num_subjects), function(i) {
    keras_model_sequential(input_shape = c(num_neurons),
                           name = sprintf("subject_encoder_%s", i)) %>%
      layer_dropout(rate = dropout_rate) %>%
      layer_dense(units = hidden_dim, activation = "linear",
                  kernel_regularizer = regularizer_l2(l2_penalty))
  })

  # Custom layer that sends each sample through its own subject's encoder.
  layer_switch <- Layer(
    classname = "SwitchLayer",

    # `...` forwards standard layer kwargs (name, dtype, trainable, ...) to
    # the base class.
    initialize = function(subject_encoders, hidden_dim, ...) {
      super$initialize(...)
      self$subject_encoders <- subject_encoders
      self$hidden_dim <- hidden_dim
    },

    # `training` is forwarded explicitly so the encoders' dropout sees it.
    call = function(self, inputs, mask = NULL, training = NULL) {
      features <- inputs[[1]]
      subject_ids <- tf$squeeze(inputs[[2]], axis = 1L)
      encoders <- self$subject_encoders
      latent_dim <- as.integer(self$hidden_dim)

      # Per-sample loop: pick the branch (encoder) by subject id.
      tf$map_fn(
        function(args) {
          vec_batched <- tf$expand_dims(args[[1]], axis = 0L)  # (1, num_features)
          branches <- lapply(encoders, function(model) function() {
            encoded <- model(vec_batched, training = training)  # (1, latent_dim)
            tf$squeeze(encoded, axis = 0L)                      # (latent_dim)
          })
          tf$switch_case(args[[2]], branches)
        },
        elems = list(features, subject_ids),
        fn_output_signature = tf$TensorSpec(shape = list(latent_dim),
                                            dtype = tf$float32)
      )
    }
  )

  # Final layer: L2-normalize every encoder output along the latent axis.
  normalize_layer <- keras_model_sequential(name = "normalize") %>%
    layer_lambda(
      f = function(x) tf$linalg$l2_normalize(x, axis = 1L),
      output_shape = c(hidden_dim)
    )

  switched <- layer_switch(
    subject_encoders = subject_encoders, hidden_dim = hidden_dim
  )(list(input_features, input_subjects))

  latents <- normalize_layer(switched)

  encoder_model <- keras_model(
    inputs = list(input_features, input_subjects),
    outputs = latents
  )

  encoder_model %>% compile(
    loss = supervised_contrastive_loss(temperature = SCL_temp),
    optimizer = optimizer_adam(learning_rate = learning_rate, clipnorm = 1.0)
  )

  encoder_model
}

