#This script simulates data with varying degrees of signal and fits the MuSCH model to that data
#Set CUDA_VISIBLE_DEVICES to an empty string before keras3/tensorflow are loaded
#NOTE(review): presumably this hides all GPUs so the run stays on the CPU - confirm
Sys.setenv(CUDA_VISIBLE_DEVICES = "")

source("utils.R")
source("MuSCH.R")
library(tidyr)
library(nnet)
library(plotly)
library(keras3)
library(tensorflow)

#Simulation dimensions: number of neurons, classes and subjects.
#n_obs_per_class holds one entry per subject; entry i is passed to
#simulate_with_keras() as the n_obs_per_class argument for subject i.
n_neurons <- 25
n_classes <- 5
n_subjects <- 5
n_obs_per_class <- c(60,20,20,20,20)

#Fail early when there are fewer entries than subjects; otherwise
#n_obs_per_class[i] would silently be NA for the trailing subjects
stopifnot(length(n_obs_per_class) >= n_subjects)

#specify values for the proportion of non-encoding columns that will be investigated
percent_nonencodings <- c(0,0.2,0.4,0.6,0.8,0.9)

#Seed R's random number generator (used by sample() below)
#NOTE(review): whether simulate_with_keras() draws from R's RNG or from the
#keras/tensorflow RNG is not visible here - confirm the simulation is reproducible
set.seed(123)

#Generate the data that will be modified for varying degrees of signal.
#Each subject is simulated with percent_nonencoding = 0 and reduced to its
#data (x), labels (y) and subject id (rat_id); list_transpose() then turns the
#per-subject list into a list with one x, one y and one rat_id component.
base_data <- purrr::list_transpose(
  lapply(seq_len(n_subjects), function(i){
    simulated <- simulate_with_keras(subject_id = i, n_classes = n_classes,
                                     n_obs_per_class = n_obs_per_class[i],
                                     input_dim = 10, output_dim = n_neurons,
                                     mean_sd = 1, noise_sd = 2,
                                     percent_nonencoding = 0)
    list(x = simulated$data,
         y = simulated$labels,
         rat_id = simulated$subject)
  })
)

#One random permutation of the neuron (column) indices per subject; these are
#handed to scramble_columns() for every level of percent_nonencoding
column_indices <- lapply(seq_len(n_subjects), function(i){
  sample(seq_len(n_neurons))
})

#Settings for the decoding-accuracy runs
#NOTE(review): subjects are simulated with subject_id 1..n_subjects while
#target_rat_id is 0 - confirm that the ids stored in rat_id are zero-based
target_rat_id <- 0
hidden_dim <- 15L
encoder_epochs <- 100
iterations <- 60
#Derived from the first two per-subject counts (1 - 20/60 with the values above)
test_percentage <- 1 - (n_obs_per_class[2]/n_obs_per_class[1])

#Column labels of every result matrix: one letter per class followed by an
#overall "total" column (A-E plus "total" for n_classes = 5)
stopifnot(n_classes >= 1, n_classes <= length(LETTERS))
result_columns <- c(LETTERS[seq_len(n_classes)], "total")

#Names of the four result variants. Each one is read from the columns
#accuracy_<name> and cce_<name> of the object returned by get_decoding_accuracy()
result_names <- c("raw_data_single_rat", "raw_data_all_rats",
                  "single_rat", "all_rats")

#Build a named list with one NA-filled (iterations x classes + total) matrix
#per result variant
empty_metric_matrices <- function(){
  lapply(setNames(result_names, result_names), function(result_name){
    matrix(NA, ncol = length(result_columns), nrow = iterations,
           dimnames = list(NULL, result_columns))
  })
}

#For every proportion of non-encoding columns: scramble the base data, run the
#split/fit/evaluate cycle `iterations` times and save the collected metrics
for(percent_nonencoding in percent_nonencodings){
  #create the results data structure
  results <- list(
    decoding_accuracy = empty_metric_matrices(),
    cce = empty_metric_matrices()
  )
  
  #get all_data from base_data; scramble some proportion of the columns
  all_data <- scramble_columns(base_data, column_indices, percent_nonencoding)
  
  #get decoding accuracy a bunch of times
  for(iteration in seq_len(iterations)){
    print(sprintf("Percent Non-encoding: %s, iteration %s of %s",
                  percent_nonencoding, iteration, iterations))
    
    #first split your data into training and testing
    data <- split_data(all_data, target_rat_id = target_rat_id, 
                       test_percentage = test_percentage)
    
    #enrich data with decoding accuracy measurements
    data <- get_decoding_accuracy(data, hidden_dim = hidden_dim, 
                                  encoder_epochs = encoder_epochs,
                                  SCL_temp = 0.07, learning_rate = 0.005)
    
    #get the test rows for each odor (coded 0 .. n_classes - 1) and finally
    #all test rows; the explicit == TRUE comparison is kept from the original
    #so a non-logical test column keeps working
    test_rows <- c(
      lapply(seq_len(n_classes) - 1L, function(class_id){
        which(data$test == TRUE & data$odor == class_id)
      }),
      list(which(data$test == TRUE))
    )
    
    #mean decoding accuracy and mean categorical cross-entropy over each set
    #of test rows, stored as row `iteration` of the matching result matrix
    for(result_name in result_names){
      accuracy_values <- data[[paste0("accuracy_", result_name)]]
      cce_values <- data[[paste0("cce_", result_name)]]
      
      results$decoding_accuracy[[result_name]][iteration, ] <-
        vapply(test_rows, function(rows){
          mean(accuracy_values[rows])
        }, numeric(1))
      results$cce[[result_name]][iteration, ] <-
        vapply(test_rows, function(rows){
          mean(cce_values[rows])
        }, numeric(1))
    }
    
    #free the per-iteration data before the next fit
    rm(data)
    gc()
  }
  
  #one .Rdata file per proportion of non-encoding columns
  file <- sprintf("results_percent_nonencoding_%s.Rdata", percent_nonencoding)
  print(sprintf("Saving to file: %s", file))
  save(results, file = file)
}


