## Script to apply the FGES algorithm (via py-tetrad) to generated data
## The output CPDAG is enumerated, and the CPDAG and the DAGs it contains are saved

source('R/scm/graph_functions.R')
source('R/scm/node_functions.R')
source('R/scm/binary_scm.R')
source('R/scm/additive_noise_scm.R')
source('R/scm/linear_gaussian_scm.R')

source('R/causal_discovery.R')
source('R/cpdag_enumerate.R')

source('R/experiments/helper_functions.R')

library(pcalg)
library(kpcalg)

library(reticulate)
# Bind reticulate to Python 3.9.7 before any Python code is loaded.
reticulate::use_python_version("3.9.7", required = TRUE)

# TetradSearch.py is sourced from inside the py-tetrad checkout.
# NOTE(review): presumably it relies on the working directory for its own
# imports / resources -- confirm before removing the setwd() dance.
cur_wd <- getwd()
setwd("~/py-tetrad/pytetrad/")
# Restore the original working directory even when sourcing fails, so an
# error here does not leave the session stranded in the py-tetrad folder.
# invisible() keeps Rscript from auto-printing the returned value.
invisible(tryCatch(
  reticulate::source_python("tools/TetradSearch.py"),
  finally = setwd(cur_wd)
))

# Convert the endpoint-coded matrix returned by TetradSearch into a 0/1
# adjacency matrix.
#
# Rules (applied to every ordered pair (i, j)):
#   * [i, j] == 3 and [j, i] == 3  ->  [i, j] = 1 and [j, i] = 1
#   * [i, j] == 3 and [j, i] == 2  ->  [i, j] = 0 and [j, i] = 1
# Every other entry is left unchanged.
#
# NOTE(review): presumably 2 = arrowhead and 3 = tail (py-tetrad's default
# endpoint codes), i.e. 3/3 is an undirected edge i - j and 3/2 is the
# directed edge j -> i, stored as [from, to] = 1 -- confirm against
# py-tetrad's graph_to_matrix.
#
# endpoint_mat: square numeric matrix of endpoint codes.
# Returns a matrix of the same shape (dimnames preserved).
fges_endpoints_to_cpdag_adj_mat <- function(endpoint_mat) {
  # All masks are computed from the untouched input, so the result does not
  # depend on the order in which pairs are rewritten.
  is_three <- endpoint_mat == 3
  is_two <- endpoint_mat == 2
  both_three <- is_three & t(is_three)
  three_vs_two <- is_three & t(is_two)

  adj_mat <- endpoint_mat
  adj_mat[both_three] <- 1
  adj_mat[three_vs_two] <- 0
  adj_mat[t(three_vs_two)] <- 1
  adj_mat
}

# Run FGES on the observational data stored at `data_path`, enumerate the DAGs
# in the resulting CPDAG and save everything via save_cpdag().
#
# data_path:        path to an RDS file holding a list with (at least)
#                   data_observational, scm_type, n_nodes,
#                   n_expected_neighbors, n_expected_parents_y and seed.
# n_observational:  number of leading rows of data_observational to use;
#                   must lie between 1 and the number of available rows.
# standardize:      if TRUE, the selected data is passed through scale()
#                   before the search.
#
# Called for its side effects (logging and saving); stops with an error when
# n_observational is outside the valid range.
apply_fges <- function(data_path, n_observational, standardize) {
  log_to_shell(paste0("Running FGES with ", n_observational, " observational data points. Standardized = ", standardize))
  data_and_parameters <- readRDS(data_path)
  data_observational <- data_and_parameters$data_observational
  scm_type <- data_and_parameters$scm_type
  n_nodes <- data_and_parameters$n_nodes
  n_expected_neighbors <- data_and_parameters$n_expected_neighbors
  n_expected_parents_y <- data_and_parameters$n_expected_parents_y
  seed <- data_and_parameters$seed

  # Fail early instead of silently using 1 row (seq(0) is c(1, 0)) or padding
  # a data.frame with NA rows when more rows are requested than exist.
  n_available <- nrow(data_observational)
  if (length(n_observational) != 1 || is.na(n_observational) ||
        n_observational < 1 || n_observational > n_available) {
    stop("n_observational must be a single number between 1 and ",
         n_available, ", got: ", paste(n_observational, collapse = ", "),
         call. = FALSE)
  }

  # Limit amount of observational data and remove Y (the last of the n_nodes
  # columns). drop = FALSE keeps a 2-d object even with a single covariate.
  data_observational <- data_observational[seq_len(n_observational),
                                           seq_len(n_nodes - 1),
                                           drop = FALSE]

  if (standardize) {
    data_observational <- scale(data_observational)
  }

  log_to_shell("Running FGES...")
  ts <- TetradSearch(as.data.frame(data_observational))
  ts$use_sem_bic(penalty_discount = 1)
  ts$run_fges()
  log_to_shell("Finished running FGES. Enumerating DAGs in MEC.")

  # Convert the endpoint-coded output into a 0/1 CPDAG adjacency matrix
  cpdag_adj_mat <- fges_endpoints_to_cpdag_adj_mat(
    as.matrix(ts$get_graph_to_matrix())
  )

  cpdag_and_dags <- adj_mat_cpdag_to_dags_in_cpdag(cpdag_adj_mat)
  cpdag <- cpdag_and_dags$cpdag
  dags <- cpdag_and_dags$dags

  log_to_shell("Finished enumerating DAGs. Saving results.")
  # Save the dags
  save_name <- paste0("fges_", n_observational)
  if (standardize) {
    save_name <- paste0(save_name, "_standardized")
  }
  save_cpdag(cpdag,
             dags,
             save_name,
             scm_type,
             n_nodes,
             n_expected_neighbors,
             n_expected_parents_y,
             seed,
             data_path = data_path,
             n_observational = n_observational,
             method = "fges",
             standardize = standardize
             )

  log_to_shell("Saved results.")
}


# Command-line entry point:
#   Rscript causal_discovery_fges.R <data_path> <n_observational> <standardize>
args <- commandArgs(trailingOnly = TRUE)

log_to_shell(paste0("causal_discovery_fges.R ", paste(args, collapse = " ")))

if (length(args) != 3) {
  stop("Usage: Rscript causal_discovery_fges.R <data_path> <n_observational> <standardize>")
}

data_path <- args[1]

# as.numeric() yields NA (plus a warning) for non-numeric text; turn that into
# a clear error here instead of an obscure failure inside apply_fges().
n_observational <- suppressWarnings(as.numeric(args[2]))
if (is.na(n_observational)) {
  stop("<n_observational> must be a number, got '", args[2], "'",
       call. = FALSE)
}

# as.logical() on a string only understands spellings such as TRUE/FALSE,
# true/false, T/F; anything else (e.g. "yes", "1") becomes NA.
standardize <- as.logical(args[3])
if (is.na(standardize)) {
  stop("<standardize> must be TRUE or FALSE, got '", args[3], "'",
       call. = FALSE)
}


apply_fges(data_path, n_observational, standardize)
