#' Set the column names of a data frame (pipe-friendly setter).
#'
#' This definition masks `base::colnames()` for everything that runs after it,
#' so the one-argument getter form is forwarded to base R instead of failing
#' on the missing `names` argument.
#'
#' @param df A data.frame (or matrix) whose columns are to be named.
#' @param names Vector of new column names. When omitted, nothing is renamed
#'   and the current column names of `df` are returned.
#' @return `df` with its column names replaced by `names`, or the existing
#'   column names of `df` when `names` is not supplied.
colnames <- function(df, names) {
  if (missing(names)) {
    # Getter form: behave exactly like the base function we are masking.
    return(base::colnames(df))
  }
  base::`colnames<-`(df, value = names)
}

#' Build an untrained network that gives simulated data a more complex
#' topography.
#'
#' The layers are stacked in this order: dense(64, relu), dropout(0.2),
#' dense(32, tanh), dropout(0.2), dense(output_dim, linear). The model is
#' neither compiled nor fitted here; it is returned with the weights it was
#' created with, which is all the simulation needs.
#'
#' @param input_dim Passed as `input_shape` to the first dense layer.
#' @param output_dim Number of units in the final (linear) dense layer.
#' @return The sequential Keras model.
create_keras_transform <- function(input_dim, output_dim) {
  net <- keras_model_sequential()

  # First hidden block
  net <- layer_dense(net, units = 64, activation = "relu",
                     input_shape = input_dim)
  net <- layer_dropout(net, 0.2)

  # Second hidden block
  net <- layer_dense(net, units = 32, activation = "tanh")
  net <- layer_dropout(net, 0.2)

  # Linear read-out
  net <- layer_dense(net, units = output_dim, activation = "linear")

  net
}

#' Simulate `n_obs_per_class` observations from each class for one subject.
#'
#' For every class a mean vector is drawn from N(0, mean_sd^2), Gaussian
#' observations are drawn around it, and the observations are pushed through a
#' subject-specific network from `create_keras_transform()`. A random subset of
#' the output columns is then permuted across rows so those columns no longer
#' encode the class.
#'
#' @param subject_id Identifier repeated once per simulated observation.
#' @param n_classes Number of classes (at least 1).
#' @param n_obs_per_class Observations simulated per class (at least 1).
#' @param input_dim Dimension of the Gaussian data fed to the network.
#' @param output_dim Number of columns produced by the network.
#' @param mean_sd Standard deviation used when drawing the class means.
#' @param noise_sd Standard deviation of the within-class noise in every input
#'   dimension.
#' @param percent_nonencoding Fraction of output columns to shuffle;
#'   `floor(output_dim * percent_nonencoding)` columns are affected.
#' @return A list with `data` (matrix, one row per observation), `labels`
#'   (class index per row) and `subject` (`subject_id` repeated per row).
simulate_with_keras <- function(subject_id, n_classes = 5, n_obs_per_class = 30, 
                                input_dim = 3, output_dim = 3, mean_sd = 3, noise_sd = 1,
                                percent_nonencoding = 0.8) {
  stopifnot(n_classes >= 1, n_obs_per_class >= 1)

  # Create subject-specific transformation
  transform_model <- create_keras_transform(input_dim = input_dim, output_dim = output_dim)

  # mvrnorm() expects a covariance matrix, so the standard deviation has to be
  # squared (the previous code passed the SD itself as the variance).
  noise_cov <- diag(noise_sd^2, nrow = input_dim)

  # Collect per-class results and bind once instead of growing in the loop.
  class_blocks <- vector("list", n_classes)
  for (class_i in seq_len(n_classes)) {
    class_mean <- rnorm(input_dim, mean = 0, sd = mean_sd)

    # mvrnorm() returns a plain vector when n == 1; force a matrix so the
    # network always receives (observations x input_dim).
    raw_data <- matrix(
      MASS::mvrnorm(n_obs_per_class, class_mean, noise_cov),
      nrow = n_obs_per_class, ncol = input_dim
    )

    # Transform through neural network
    class_blocks[[class_i]] <- predict(transform_model, raw_data, verbose = 0)
  }

  all_data <- do.call(rbind, class_blocks)
  all_labels <- rep(seq_len(n_classes), each = n_obs_per_class)

  # Jumble up some of the columns so they don't encode class. Rows are permuted
  # by index because sample(x) treats a single number x as 1:x.
  column_indices <- sample.int(output_dim, floor(output_dim * percent_nonencoding))
  n_rows <- nrow(all_data)
  for (i in column_indices) {
    all_data[, i] <- all_data[sample.int(n_rows), i]
  }

  list(data = all_data, labels = all_labels,
       subject = rep(subject_id, length(all_labels)))
}

#' Stack every rat's data into one data set and hold out a stratified test
#' sample from the target rat.
#'
#' Each rat's matrix is right-padded with zero columns so all rats share the
#' same number of columns, then the matrices, labels and rat ids are
#' concatenated in list order. Within the target rat,
#' `floor(n_k * test_percentage)` observations of every odor `k` are drawn at
#' random as test rows; every other row (including all rows of the other rats)
#' is a training row.
#'
#' @param all_data List with elements `x` (list of matrices, one per rat), `y`
#'   (list of 1-based odor labels, one vector per rat) and `rat_id` (list of
#'   1-based rat ids, one vector per rat).
#' @param target_rat_id 0-based position of the target rat in `all_data`.
#' @param test_percentage Fraction (between 0 and 1) of each odor's
#'   observations in the target rat to hold out.
#' @return A list with `x` (padded matrix with columns `X1..Xmax`), `odor` and
#'   `rat_id` (both shifted to be 0-based), `augmented` (all `FALSE`), the
#'   logical masks `train` and `test`, and `target_rat_id`.
split_data <- function(all_data, target_rat_id = 0, test_percentage = 0.2) {
  n_rats <- length(all_data$x)
  target_pos <- target_rat_id + 1
  stopifnot(
    length(target_pos) == 1,
    target_pos >= 1,
    target_pos <= n_rats,
    test_percentage >= 0,
    test_percentage <= 1
  )

  max_neurons <- max(vapply(all_data$x, ncol, integer(1)))

  # Stratified test rows, expressed as positions within the target rat.
  target_y <- all_data$y[[target_pos]]
  test_idx <- unlist(lapply(unique(target_y), function(k) {
    candidates <- which(target_y == k)
    n_test <- floor(length(candidates) * test_percentage)
    # Draw positions rather than values: sample(x, n) would treat a single
    # candidate index x as the range 1:x.
    candidates[sample.int(length(candidates), n_test)]
  }), use.names = FALSE)

  # Pad each rat's data with 0s so they all have the same number of neurons,
  # then stack the rats on top of each other.
  x <- do.call(rbind, lapply(unname(all_data$x), function(rat_x) {
    cbind(rat_x,
          matrix(0, nrow = nrow(rat_x), ncol = max_neurons - ncol(rat_x)))
  }))

  odor <- do.call(c, unname(all_data$y))
  rat_id <- do.call(c, unname(all_data$rat_id))

  augmented <- rep(FALSE, length(odor))

  # Shift the target rat's positions by the number of rows stacked before it.
  n_per_rat <- lengths(unname(all_data$y))
  offset <- sum(n_per_rat[seq_len(target_pos - 1)])
  test <- rep(FALSE, length(odor))
  test[offset + test_idx] <- TRUE

  # Everything that is not test is train.
  train <- !test

  # Now make sure the odor and rat_id are 0 indexed.
  odor <- odor - 1
  rat_id <- rat_id - 1

  # Rename the columns of x.
  colnames(x) <- sprintf("X%s", seq_len(max_neurons))

  list(x = x, odor = odor, rat_id = rat_id, augmented = augmented,
       train = train, test = test,
       target_rat_id = target_rat_id)
}

#' Scramble selected columns so they no longer encode odor.
#'
#' For every subject, the first `floor(n_columns * percent_nonencoding)`
#' entries of that subject's `column_indices` are taken, and each of those
#' columns of the subject's matrix is permuted across rows. The number of
#' columns is read from the first subject's matrix.
#'
#' @param base_data List with `rat_id` (one entry per subject) and `x` (list
#'   of matrices, one per subject).
#' @param column_indices List with one vector of column indices per subject,
#'   in the order in which columns should be scrambled.
#' @param percent_nonencoding Fraction of columns to scramble; 0 leaves the
#'   data untouched.
#' @return A copy of `base_data` in which the chosen columns of every `x`
#'   matrix have been shuffled.
scramble_columns <- function(base_data, column_indices, percent_nonencoding) {
  n_subjects <- length(base_data$rat_id)
  if (n_subjects == 0) {
    return(base_data)
  }

  n_columns <- ncol(base_data$x[[1]])
  n_columns_to_scramble <- floor(n_columns * percent_nonencoding)

  all_data <- base_data

  for (i in seq_len(n_subjects)) {
    # seq_len() yields no positions for 0; 1:0 would have selected (and
    # scrambled) the first listed column even when nothing should change.
    targets <- column_indices[[i]][seq_len(n_columns_to_scramble)]
    n_rows <- nrow(all_data$x[[i]])
    for (j in targets) {
      # Permute by row index: sample(x) treats a single number x as 1:x.
      all_data$x[[i]][, j] <- all_data$x[[i]][sample.int(n_rows), j]
    }
  }

  all_data
}

#' Process the raw neural data from the rat experiment.
#'
#' The interval from `start` to `end` is cut into `partitions` equal windows.
#' Event times are measured from each trial's smallest `TimeBin` (rounded to 3
#' decimals); rows whose event time falls in a window `[start_time, end_time)`
#' are kept and tagged with that window's partition number. The remaining
#' columns are then averaged per trial/partition pair.
#'
#' @param raw_data Data frame with the columns `trial_id`, `TimeBin`, `y` and
#'   `seq_id`; every other column is treated as one neuron's activity.
#' @param start,end Bounds of the analysis interval, in the units of `TimeBin`
#'   relative to the start of the trial.
#' @param partitions Number of equal-width windows between `start` and `end`.
#' @return A list with `y` (label per trial/partition), `x` (matrix of mean
#'   activity, one row per trial/partition, columns `Y1..Yn`; no row names),
#'   `trial_id`, and `trial_details` (distinct `y`/`trial_id` pairs).
process_raw_data <- function(raw_data, start = 0.250, end = 0.500, partitions = 1) {
  # One row per partition: the window of event times it covers.
  breaks <- seq(from = start, to = end, length.out = partitions + 1)
  partition_windows <- data.frame(
    partition = seq_len(partitions),
    start_time = breaks[seq_len(partitions)],
    end_time = breaks[seq_len(partitions) + 1]
  )

  # NOTE: `temp` and `event_details` are referenced by name inside the SQL
  # strings below, so these two locals must keep their names.
  temp <- raw_data %>% 
    group_by(trial_id) %>% 
    mutate(event_time = round(TimeBin - min(TimeBin), 3)) %>% 
    ungroup() %>% 
    inner_join(partition_windows, 
               by = join_by(event_time >= start_time, event_time < end_time))

  trial_details <- sqldf("select distinct y, trial_id from temp")

  event_details <- temp %>% 
    select(trial_id, partition, label = y, event_time)

  # Drop the bookkeeping columns; what is left is one column per neuron.
  data <- temp %>% 
    select(-c("y","TimeBin","seq_id","trial_id","event_time","partition","start_time","end_time"))

  n_neurons <- ncol(data)
  colnames(data) <- sprintf("Y%s", seq_len(n_neurons))

  # Aggregate the spike activity for each trial_id/partition pair.
  trial_id_partition <- sqldf("select distinct trial_id, partition from event_details")
  if (nrow(trial_id_partition) == 0) {
    # 1:nrow() used to iterate over c(1, 0) here and return meaningless rows.
    stop("no observations fall inside the [start, end) window", call. = FALSE)
  }

  result <- lapply(seq_len(nrow(trial_id_partition)), function(i) {
    trial_id <- trial_id_partition$trial_id[i]
    partition <- trial_id_partition$partition[i]

    indices <- event_details$trial_id == trial_id & event_details$partition == partition

    y <- trial_details$y[trial_details$trial_id == trial_id][1]
    # Column-wise mean without coercing the data frame to a matrix first.
    x <- vapply(data[indices, , drop = FALSE], mean, numeric(1))

    list(y = y, x = x, trial_id = trial_id)
  }) %>% purrr::list_transpose()

  # Bind all rows in one call so a single trial still yields a 1-row matrix.
  x <- do.call(rbind, unname(as.list(result$x)))

  list(y = result$y, x = x, trial_id = result$trial_id,
       trial_details = trial_details)
}

#' Standardize the columns of a data matrix.
#'
#' Every column is centered on its mean and divided by its standard deviation
#' (missing values are ignored when computing both). A column with no spread,
#' meaning an SD of exactly 0 or an SD that cannot be computed because fewer
#' than two values are available, is only centered, so a constant column
#' becomes all 0.
#'
#' @param X Numeric matrix with observations in rows.
#' @return An object shaped like `X` holding the standardized values.
standardize_matrix <- function(X) {
  # Center
  centered <- sweep(X, 2, colMeans(X, na.rm = TRUE), "-")

  # Compute SD for each column
  sds <- apply(centered, 2, sd, na.rm = TRUE)

  # Divide by 1 where there is no usable SD. sd() is NA for a column with a
  # single value; leaving it would turn the whole column into NA.
  sds[is.na(sds) | sds == 0] <- 1

  # Scale
  sweep(centered, 2, sds, "/")
}

#' Generate the extra (upsampled) observations needed to balance the odors.
#'
#' For every odor with fewer observations than the most frequent one, new rows
#' are created until the counts match. Each column of a new row is drawn
#' independently, with replacement, from that odor's observed values in the
#' same column. Only the new rows are returned; the original observations are
#' not included.
#'
#' @param x Matrix of observations (rows) by features (columns).
#' @param odor Vector with the odor label of every row of `x`.
#' @param rat_id Rat identifier to attach to every generated row.
#' @return A list with `x` (matrix of generated rows, same columns as the
#'   input), `odor` (label per generated row) and `rat_id` (`rat_id` repeated
#'   per generated row). Rows are grouped by odor in order of first appearance.
augment_data <- function(x, odor, rat_id) {
  max_count <- max(table(odor))
  n_col <- ncol(x)
  unique_odor <- unique(odor)

  per_odor <- lapply(unique_odor, function(k) {
    rows <- which(odor == k)
    n_to_sample <- max_count - length(rows)

    columns <- lapply(seq_len(n_col), function(j) {
      # Draw row positions, not values: sample(v, n) would treat a single
      # observed value v as the range 1:v.
      x[rows[sample.int(length(rows), n_to_sample, replace = TRUE)], j]
    })

    list(
      x = matrix(unlist(columns, use.names = FALSE),
                 nrow = n_to_sample, ncol = n_col,
                 dimnames = list(NULL, colnames(x))),
      odor = rep(k, n_to_sample),
      rat_id = rep(rat_id, n_to_sample)
    )
  })

  # Combine by list position. Looking the pieces up by odor value fails for
  # 0-based labels and misorders any labels that are not exactly 1..K.
  list(
    x = do.call(rbind, lapply(per_odor, function(p) p$x)),
    odor = do.call(c, lapply(per_odor, function(p) p$odor)),
    rat_id = do.call(c, lapply(per_odor, function(p) p$rat_id))
  )
}

#' Get the decoding accuracy and categorical cross-entropy for each of the
#' four model setups.
#'
#' Two encoders are built by `get_encoder_model()` and start from identical
#' weights: one is trained on the training rows of all rats, the other on the
#' training rows of the target rat only. Four `nnet` classifiers are then
#' fitted (raw data or encoder latents, crossed with single rat or all rats)
#' and scored on every row of `data`.
#'
#' @param data List as returned by `split_data()`: `x`, `odor` (0-based),
#'   `rat_id` (0-based), the logical masks `train`/`test` and `target_rat_id`.
#' @param hidden_dim Passed to `get_encoder_model()` and used as the hidden
#'   layer size of every `nnet` classifier.
#' @param encoder_epochs Maximum number of epochs for each encoder fit.
#' @param SCL_temp,learning_rate Passed to `get_encoder_model()`.
#' @return `data` with these elements added, each with one row/entry per
#'   observation: `latents_*` (encoder outputs), `probs_*` (class
#'   probabilities), `cce_*` (negative log probability of the true odor) and
#'   `accuracy_*` (logical, most probable class equals the true odor). The
#'   suffixes are `raw_data_single_rat`, `raw_data_all_rats`, `single_rat` and
#'   `all_rats` (latents exist only for the last two).
get_decoding_accuracy <- function(data, hidden_dim = 5L, 
                                  encoder_epochs = 100, SCL_temp = 0.07, 
                                  learning_rate = 0.005){
  
  num_subjects <- length(unique(data$rat_id))
  num_neurons <- ncol(data$x)
  
  # Both encoders are created with the same configuration.
  build_encoder <- function() {
    get_encoder_model(hidden_dim = hidden_dim, num_subjects = num_subjects,
                      num_neurons = num_neurons, SCL_temp = SCL_temp, 
                      learning_rate = learning_rate)
  }
  
  # Build the all-rats encoder and remember its randomized starting weights,
  # then give the single-rat encoder the very same starting point.
  encoder_model_all_rats <- build_encoder()
  weights <- get_weights(encoder_model_all_rats)
  encoder_model_single_rat <- build_encoder()
  set_weights(encoder_model_single_rat, weights)
  
  all_rats_train <- which(data$train)
  single_rat_train <- which(data$train & data$rat_id == data$target_rat_id)
  
  # Stop training when loss stops improving
  early_stop <- callback_early_stopping(
    monitor = "loss",
    patience = 10,
    min_delta = 0.3,
    restore_best_weights = TRUE
  )
  
  # Fit an encoder on the given rows, then return its latents for ALL rows.
  train_encoder <- function(model, rows) {
    model %>% fit(
      x = list(input_features = data$x[rows, , drop = FALSE], 
               input_subjects = data$rat_id[rows]),
      y = data$odor[rows],
      batch_size = 50,
      epochs = encoder_epochs,
      callbacks = list(early_stop),
      verbose = 1
    )
    predict(model, list(data$x, data$rat_id))
  }
  
  # Fit an nnet classifier of odor on the given rows of a feature matrix.
  # drop = FALSE keeps a one-row training subset as a matrix.
  train_classifier <- function(features, rows, max_weights = 1000) {
    nnet(odor ~ ., data = data.frame(odor = as.factor(data$odor[rows]),
                                     X = features[rows, , drop = FALSE]), 
         size = hidden_dim, decay = 0.1, maxit = 500, MaxNWts = max_weights)
  }
  
  # Class probabilities for every row of a feature matrix.
  class_probs <- function(classifier, features) {
    predict(classifier, newdata = data.frame(X = features), type = "raw")
  }
  
  # NOTE(review): the two scorers below assume `probs` has one column per odor
  # and that column `odor + 1` belongs to odor `odor` (more than two odors,
  # all of them present in the training rows) - TODO confirm against callers.
  
  # Negative log probability assigned to the true odor of every observation.
  cross_entropy <- function(probs) {
    -log(probs[cbind(seq_along(data$odor), data$odor + 1)])
  }
  
  # TRUE where the most probable class is the true odor.
  is_correct <- function(probs) {
    apply(probs, 1, which.max) == data$odor + 1
  }
  
  # Train the encoders and extract the latent representations
  data$latents_all_rats <- train_encoder(encoder_model_all_rats, all_rats_train)
  data$latents_single_rat <- train_encoder(encoder_model_single_rat,
                                           single_rat_train)
  
  # Train the classification networks (the raw-data ones get a larger weight
  # budget than the nnet default of 1000)
  classifier_model_from_raw_data_single_rat <- 
    train_classifier(data$x, single_rat_train, max_weights = 1700)
  classifier_model_from_raw_data_all_rats <- 
    train_classifier(data$x, all_rats_train, max_weights = 1700)
  classifier_model_single_rat <- 
    train_classifier(data$latents_single_rat, single_rat_train)
  classifier_model_all_rats <- 
    train_classifier(data$latents_all_rats, all_rats_train)
  
  # Get class probabilities
  data$probs_raw_data_single_rat <- 
    class_probs(classifier_model_from_raw_data_single_rat, data$x)
  data$probs_raw_data_all_rats <- 
    class_probs(classifier_model_from_raw_data_all_rats, data$x)
  data$probs_all_rats <- 
    class_probs(classifier_model_all_rats, data$latents_all_rats)
  data$probs_single_rat <- 
    class_probs(classifier_model_single_rat, data$latents_single_rat)
  
  # Get categorical crossentropy. Each score is computed from its own
  # probability matrix; the all-rats and single-rat entries used to be swapped.
  data$cce_raw_data_single_rat <- cross_entropy(data$probs_raw_data_single_rat)
  data$cce_raw_data_all_rats <- cross_entropy(data$probs_raw_data_all_rats)
  data$cce_all_rats <- cross_entropy(data$probs_all_rats)
  data$cce_single_rat <- cross_entropy(data$probs_single_rat)
  
  # Get classification accuracy
  data$accuracy_raw_data_single_rat <- is_correct(data$probs_raw_data_single_rat)
  data$accuracy_raw_data_all_rats <- is_correct(data$probs_raw_data_all_rats)
  data$accuracy_single_rat <- is_correct(data$probs_single_rat)
  data$accuracy_all_rats <- is_correct(data$probs_all_rats)
  
  data
}

