# Compute decoding accuracy (and categorical cross-entropy) for several
# different subsets of the data.
#
# CUDA_VISIBLE_DEVICES is set to the empty string before the Keras/TensorFlow
# packages are loaded below. An empty value hides every GPU from CUDA, so the
# backend is expected to run on the CPU only.
Sys.setenv(CUDA_VISIBLE_DEVICES = "")

source("utils.R")
source("MuSCH.R")
library(sqldf)
library(tidyr)
library(nnet)
library(plotly)
library(keras3)
library(tensorflow)

# Seed R's random number generator so R-side randomness is repeatable.
# NOTE(review): this seeds R only; confirm whether the sourced helpers also
# seed the TensorFlow/Keras backend if fully reproducible training is needed.
set.seed(123)

# Load the experimental data. The file is expected to define `train_data`,
# which is indexed by rat name below.
load("train_data.RData")

rat_names <- c("Barat", "Buchanan", "Mitt", "Stella", "SuperChris")
# Not referenced in this part of the script; kept in case the sourced helpers
# or later code read it.
odor_names <- 1:5

# Number of partitions (observations) taken from each 250-500 ms window after
# odor presentation.
n_partitions <- 1

# Process the recordings of every rat listed in `rat_names`. The index range is
# derived from `rat_names` so the two can never get out of sync.
# purrr::list_transpose() then turns the per-rat list of fields into a
# per-field list with one entry per rat.
all_data <- lapply(seq_along(rat_names), function(i) {
  temp <- process_raw_data(raw_data = train_data[[rat_names[i]]],
                           start = 0.25, end = 0.5, partitions = n_partitions)
  # standardize the data in each column
  temp$x <- standardize_matrix(temp$x)
  # Tag every observation with the 1-based index of the rat it came from.
  # NOTE(review): the main loop below passes a zero-based target_rat_id to
  # split_data(); confirm that split_data() accounts for this offset.
  temp$rat_id <- rep(i, length(temp$y))
  temp
}) %>% purrr::list_transpose()

# The raw recordings are no longer needed once processed; release the memory.
rm(train_data)
gc()

# Dimension of the latent representation.
hidden_dim <- 15L
# Forwarded to get_decoding_accuracy() as `encoder_epochs`.
encoder_epochs <- 100
# Number of repeated train/test splits evaluated per target rat.
iterations <- 60
# Forwarded to split_data() as `test_percentage`.
test_percentage <- 0.3

# Column layout shared by every results matrix: one column per odor (A-E)
# followed by the aggregate over all test rows.
result_columns <- c("A", "B", "C", "D", "E", "total")
# Odor codes as compared against data$odor, in the same order as columns A-E.
odor_codes <- 0:4
# Models being compared. Each name is both a key in `results` and the suffix of
# the per-observation column returned by get_decoding_accuracy().
model_names <- c("raw_data_single_rat", "raw_data_all_rats",
                 "single_rat", "all_rats")
# Maps each section of `results` to the prefix of the per-observation columns
# it summarises (e.g. decoding_accuracy -> "accuracy_single_rat").
metric_prefixes <- c(decoding_accuracy = "accuracy", cce = "cce")

# Build one empty (iterations x odors+total) results matrix.
new_results_matrix <- function(n_rows) {
  matrix(NA_real_, nrow = n_rows, ncol = length(result_columns),
         dimnames = list(NULL, result_columns))
}

# Evaluate every rat in turn as the target rat and save one results file per
# rat. target_rat_id is zero-based, hence the `+ 1` when indexing rat_names.
for (target_rat_id in seq_along(rat_names) - 1L) {
  rat_name <- rat_names[target_rat_id + 1]

  # create the results data structure:
  # results[[metric]][[model]] is an iterations x 6 matrix
  results <- lapply(metric_prefixes, function(prefix) {
    setNames(
      lapply(model_names, function(model) new_results_matrix(iterations)),
      model_names
    )
  })

  for (i in seq_len(iterations)) {
    print(sprintf("Rat: %s, iteration %s of %s", rat_name, i, iterations))

    # first split your data into training and testing
    data <- split_data(all_data, target_rat_id = target_rat_id,
                       test_percentage = test_percentage)

    # enrich data with decoding accuracy measurements
    data <- get_decoding_accuracy(data, hidden_dim = hidden_dim,
                                  encoder_epochs = encoder_epochs,
                                  SCL_temp = 0.07, learning_rate = 0.005)

    # get the test rows for each odor and finally all test rows; the list
    # order matches `result_columns`
    is_test <- data$test == TRUE
    test_rows <- c(
      lapply(odor_codes, function(odor) which(is_test & data$odor == odor)),
      list(which(is_test))
    )

    # Average every per-observation metric over each group of test rows.
    # An odor with no test rows yields NaN for that cell.
    for (metric in names(metric_prefixes)) {
      for (model in model_names) {
        values <- data[[paste0(metric_prefixes[[metric]], "_", model)]]
        results[[metric]][[model]][i, ] <- vapply(
          test_rows,
          function(rows) mean(values[rows]),
          numeric(1)
        )
      }
    }

    # Drop the per-iteration data before the next split to limit memory use.
    rm(data)
    gc()
  }

  out_file <- sprintf("results_%s.Rdata", rat_name)
  print(sprintf("Saving to file: %s", out_file))
  save(results, file = out_file)
}
