library(densratio)

# Mean squared error between paired prediction and ground-truth vectors.
#
# Args:
#   pred_vec: Numeric vector of predictions.
#   true_vec: Numeric vector of true values; must be the same length as
#     `pred_vec`.
#
# Returns:
#   mean((pred_vec - true_vec)^2). NA values propagate (no na.rm).
#
# The lengths are checked explicitly because R would otherwise recycle the
# shorter vector and silently return a meaningless MSE.
calc_mse <- function(pred_vec, true_vec) {
  if (length(pred_vec) != length(true_vec)) {
    stop(
      "calc_mse: `pred_vec` (length ", length(pred_vec),
      ") and `true_vec` (length ", length(true_vec),
      ") must have the same length.",
      call. = FALSE
    )
  }
  mean((pred_vec - true_vec)^2)
}

# Fit a density-ratio model on one experiment's training split and report the
# MSE of its predictions on the test split.
#
# Reads, without a header row:
#   <data_top_dir>/<experiment_dir>/train/de.csv   (passed to densratio as y)
#   <data_top_dir>/<experiment_dir>/train/nu.csv   (passed to densratio as x)
#   <data_top_dir>/<experiment_dir>/test/de.csv    (points to predict on)
#   <data_top_dir>/<experiment_dir>/test/true_rate.csv (ground-truth ratios)
#
# Args:
#   data_top_dir: Root directory containing the experiment directories.
#   experiment_dir: Name of the experiment sub-directory.
#   densratio_method: Forwarded to densratio() as `method`.
#   densratio_kernel_num: Forwarded to densratio() as `kernel_num`.
#   quiet: If FALSE, emit a warning carrying the error message when fitting or
#     prediction fails. Defaults to TRUE (failures are silent).
#
# Returns:
#   The MSE (via calc_mse) between predicted and true density ratios on the
#   test set, or NA_real_ if fitting or prediction raised an error. Errors
#   while reading the CSV files are not caught.
pred_and_calc_mse_with_densratio <- function(
    data_top_dir,
    experiment_dir,
    densratio_method,
    densratio_kernel_num,
    quiet = TRUE) {
  experiment_path <- file.path(data_top_dir, experiment_dir)
  train_data_dir <- file.path(experiment_path, "train")
  test_data_dir <- file.path(experiment_path, "test")

  train_de <- read.csv(file.path(train_data_dir, "de.csv"), header = FALSE)
  train_nu <- read.csv(file.path(train_data_dir, "nu.csv"), header = FALSE)
  test_de <- read.csv(file.path(test_data_dir, "de.csv"), header = FALSE)
  test_true_rate_df <- read.csv(
    file.path(test_data_dir, "true_rate.csv"),
    header = FALSE
  )
  true_dre <- unlist(test_true_rate_df, use.names = FALSE)

  # The result must be taken from tryCatch's return value: an assignment made
  # inside the error handler would only bind a variable local to the handler.
  tryCatch(
    {
      densratio_obj <- densratio(
        train_nu, train_de,
        method = densratio_method,
        kernel_num = densratio_kernel_num
      )
      pred_dre <- densratio_obj$compute_density_ratio(test_de)
      calc_mse(pred_dre, true_dre)
    },
    error = function(e) {
      if (!quiet) {
        warning(
          "densratio failed for ", experiment_dir, ": ",
          conditionMessage(e),
          call. = FALSE
        )
      }
      NA_real_
    }
  )
}
  
