library(tidyverse)

# Extract and clean the baseline variables used by the experiments.
#
# Args:
#   tibble_raw: data frame containing at least the columns AGE, PTGENDER,
#     PTEDUCAT, FDG, ABETA, PTAU, APOE4 and DX_bl.
#
# Returns:
#   A data frame with the columns AGE, SEX, EDU, FDG, ABETA, PTAU, APOE4 and
#   DX (in that order), restricted to rows without missing values whose
#   diagnosis is not 'SMC', and with ABETA and PTAU converted to numeric.
get_data_base <- function(tibble_raw) {
  # Censored measurements are stored as text such as ">1700", ">120" or "<8".
  # They are replaced by the bound itself. Stripping the leading sign handles
  # every bound uniformly, so a lower-censored ABETA value is converted the
  # same way as a lower-censored PTAU value instead of turning into NA after
  # the NA rows have already been dropped.
  parse_censored <- function(values) {
    if (is.numeric(values)) {
      # Column was already parsed as numeric: nothing to strip.
      return(as.numeric(values))
    }
    as.numeric(sub("^[<>]", "", as.character(values)))
  }

  tibble_raw |>
    # keep age, sex, education, fdg, abeta, ptau, apoe4, dx (renaming three)
    dplyr::select(
      AGE,
      SEX = PTGENDER,
      EDU = PTEDUCAT,
      FDG, ABETA, PTAU, APOE4,
      DX = DX_bl
    ) |>
    # remove rows with na values
    tidyr::drop_na() |>
    # remove rows with diagnosis SMC
    dplyr::filter(DX != "SMC") |>
    dplyr::mutate(
      ABETA = parse_censored(ABETA),
      PTAU = parse_censored(PTAU)
    )
}

# Recode the DX column as a 0/1 indicator.
#
# Args:
#   tibble: data frame with a DX column of diagnosis labels.
#   only_diagnosis: when FALSE (default), 'CN' becomes 0 and every other
#     label becomes 1; when TRUE, 'AD' becomes 1 and every other label
#     becomes 0.
#
# Returns:
#   The input with DX replaced by the numeric indicator.
binarize_diagnosis <- function(tibble, only_diagnosis = FALSE) {
  if (!only_diagnosis) {
    # Control vs. everything else (EMCI, LMCI and AD grouped as 1).
    return(dplyr::mutate(tibble, DX = ifelse(DX == "CN", 0, 1)))
  }

  # AD vs. everything else.
  dplyr::mutate(tibble, DX = ifelse(DX == "AD", 1, 0))
}

# Recode the DX column as a numeric class label.
#
# Mapping: 'CN' -> 0, 'LMCI' -> 1, 'EMCI' -> 2, any other label -> 3
# (AD in the cleaned data).
#
# Args:
#   tibble: data frame with a DX column of diagnosis labels.
#
# Returns:
#   The input with DX replaced by the numeric code, returned invisibly.
discretize_diagnosis <- function(tibble) {
  encode_label <- function(label) {
    ifelse(
      label == "CN", 0,
      ifelse(
        label == "LMCI", 1,
        ifelse(label == "EMCI", 2, 3)
      )
    )
  }

  invisible(dplyr::mutate(tibble, DX = encode_label(DX)))
}

# Replace APOE4 by two 0/1 indicator columns.
#
# Args:
#   tibble: data frame with an APOE4 column.
#
# Returns:
#   The input with APOE41 (1 where APOE4 equals '1') and APOE42 (1 where
#   APOE4 equals '2') appended and the original APOE4 column removed.
binarize_apoe <- function(tibble) {
  with_indicators <- dplyr::mutate(
    tibble,
    APOE41 = ifelse(APOE4 == "1", 1, 0),
    APOE42 = ifelse(APOE4 == "2", 1, 0)
  )

  # The two indicators replace the original column.
  dplyr::select(with_indicators, -APOE4)
}

# Recode SEX as a 0/1 indicator: 'Female' -> 1, any other value -> 0.
#
# Args:
#   tibble: data frame with a SEX column.
#
# Returns:
#   The input with SEX replaced by the numeric indicator.
binarize_sex <- function(tibble) {
  dplyr::mutate(tibble, SEX = ifelse(SEX == "Female", 1, 0))
}

# Discretize every numeric column into at most four equal-width bins.
#
# The number of bins per column follows the Freedman-Diaconis bin width
# (2 * IQR / n^(1/3)), capped at `max_bins`. Bins that end up empty are
# dropped, so the codes of a column are the contiguous integers 0, 1, ...
# and they increase with the underlying value.
#
# Args:
#   tibble: data frame; non-numeric columns are left untouched.
#
# Returns:
#   The input with every numeric column replaced by integer bin codes.
discretize_data <- function(tibble) {
  # Number of equal-width bins for one column.
  calculate_bins <- function(data_vector, max_bins = 4) {
    value_range <- max(data_vector) - min(data_vector)
    if (value_range == 0) {
      # Constant column: a single bin (avoids 0 / 0 = NaN below).
      return(1L)
    }
    bin_width <- 2 * IQR(data_vector) / (length(data_vector)^(1 / 3))
    # A zero IQR gives an infinite ratio, which min() caps at max_bins.
    num_bins <- ceiling(value_range / bin_width)
    # Restrict the number of bins to the maximum specified
    as.integer(min(num_bins, max_bins))
  }

  # Integer bin codes (starting at 0) for one column.
  bin_column <- function(values) {
    if (length(values) == 0L) {
      return(integer(0))
    }
    if (anyNA(values)) {
      stop("discretize_data() requires numeric columns without NA values.",
           call. = FALSE)
    }
    num_bins <- calculate_bins(values)
    if (num_bins < 2L) {
      # cut() rejects fewer than two intervals; one bin means one code.
      return(rep(0L, length(values)))
    }
    binned <- cut(values, breaks = num_bins, include.lowest = TRUE)
    # droplevels() removes empty bins but keeps the levels in value order,
    # so codes are contiguous and do not depend on the order of the rows.
    as.integer(droplevels(binned)) - 1L
  }

  tibble |>
    dplyr::mutate(dplyr::across(where(is.numeric), bin_column))
}


# Build the gold-standard graph over the columns of `data`.
#
# Args:
#   data: data frame whose column names become the vertices (in column
#     order). Its columns must include every vertex named in the edge list
#     below, including DX.
#
# Returns:
#   The graph returned by remove_node() after the DX vertex has been passed
#   to it.
get_gold_standard_dag <- function(data) {
  # Add one named vertex per column, in column order.
  dag <- Reduce(
    function(graph, vertex_name) {
      igraph::add_vertices(graph, 1, name = vertex_name)
    },
    colnames(data),
    igraph::make_empty_graph()
  )

  # One row per directed edge: from -> to.
  edge_pairs <- rbind(
    c("AGE", "ABETA"),
    c("APOE41", "ABETA"),
    c("APOE42", "ABETA"),
    c("ABETA", "FDG"),
    c("ABETA", "PTAU"),
    c("EDU", "DX"),
    c("FDG", "DX"),
    c("PTAU", "DX")
  )
  # add_edges() expects a flat from, to, from, to, ... vector.
  dag <- igraph::add_edges(dag, as.vector(t(edge_pairs)))

  # NOTE(review): remove_node() is defined outside this file; how it treats
  # the edges that point into DX is not visible here.
  dx_node <- igraph::V(dag)["DX"]
  dag <- remove_node(dag, dx_node)
  dag
}

# Training idxs
p_train <- 0.8

# The saved index file doubles as the "already generated" marker, so it is
# written last: if anything fails part-way through, no marker is left behind
# and the next run starts over instead of skipping the export.
train_idxs_path <- paste0("R/experiments/alzheimers/train_idxs_", p_train, ".rds")

if (!file.exists(train_idxs_path)) {
  n_data <- 1500 # expected number of rows in the preprocessed data

  # Discretize data and save to file for Tetrad
  merge_csv <- 'R/experiments/alzheimers/ADNIMERGE_15May2024.csv'
  tibble_initial <- read_csv(merge_csv)
  tibble_relevant <- get_data_base(tibble_initial)
  tibble_relevant <- discretize_diagnosis(tibble_relevant)
  tibble_relevant <- binarize_sex(tibble_relevant)

  tibble_relevant <- binarize_apoe(tibble_relevant)
  tibble_discretized <- discretize_data(tibble_relevant)

  # dplyr::slice() silently ignores indices beyond the last row, so a stale
  # n_data would quietly shrink or bias the training split. Fail before
  # anything is written.
  if (nrow(tibble_discretized) != n_data) {
    stop(
      "Expected ", n_data, " preprocessed rows but found ",
      nrow(tibble_discretized), "; update n_data before sampling.",
      call. = FALSE
    )
  }

  # Seed immediately before sampling so the split is reproducible.
  set.seed(123)
  n_train <- as.integer(p_train * n_data)
  train_idxs <- sample(seq_len(n_data), size = n_train)

  # Keep only train idxs and remove the DX column
  tibble_discretized <- tibble_discretized |>
    dplyr::slice(train_idxs) |>
    dplyr::select(-DX)

  write_delim(tibble_discretized, 'R/experiments/alzheimers/tibble_discretized_apoe4142.csv', delim = ';')
  saveRDS(train_idxs, train_idxs_path)
}


