library(distributions3, warn.conflicts = FALSE)
library(igraph, warn.conflicts = FALSE)

# source('R/scm/base_scm.R')


#' Constructor for linear-gaussian SCM objects
#'
#' Validates the node functions and noise distributions, builds the base SCM
#' with `new_scm(dag, node_functions)`, attaches the noise distributions and
#' prepends the class "linear_gaussian_scm".
#'
#' @param dag An igraph object representing the DAG structure of the SCM.
#' @param node_functions A list of functions, one for each node in the DAG,
#' following the order returned by igraph::V(dag). Every function must inherit
#' from "linear_node_function" or "linear_root_node_function" (see
#' `get_linear_function()` and `get_linear_root_node_function()`).
#' @param noise_distributions A list of distributions3 distributions, one for
#' each node in the DAG, following the order returned by igraph::V(dag).
#' Distributions must be of type Normal.
#'
#' @return The object returned by `new_scm()` with an additional
#' `noise_distributions` element and "linear_gaussian_scm" prepended to its
#' class vector.
#'
#' @examples
#' \dontrun{
#' dag <- igraph::graph_from_literal(A -+ B, A -+ C, B -+ D, C -+ D)
#' node_functions <- list(get_linear_root_node_function(),
#'                        get_linear_function(-2),
#'                        get_linear_function(3),
#'                        get_linear_function(c(-1, 2)))
#' noise_distributions <- list(distributions3::Normal(0, 1),
#'                             distributions3::Normal(0, 1),
#'                             distributions3::Normal(0, 1),
#'                             distributions3::Normal(0, 1))
#' scm <- new_linear_gaussian_scm(dag, node_functions, noise_distributions)
#' }
new_linear_gaussian_scm <- function(dag, node_functions, noise_distributions) {
  # The data generators index both lists by vertex index, so each list needs
  # exactly one entry per vertex.
  n_nodes <- igraph::vcount(dag)
  stopifnot(length(node_functions) == n_nodes)
  stopifnot(length(noise_distributions) == n_nodes)

  # vapply() keeps the result a logical vector even for empty input.
  stopifnot(all(vapply(noise_distributions, distributions3::is_distribution, logical(1))))
  stopifnot(all(vapply(noise_distributions, inherits, logical(1), what = "Normal")))

  # inherits() with several candidate classes is TRUE when any of them matches.
  allowed_classes <- c("linear_node_function", "linear_root_node_function")
  stopifnot(all(vapply(node_functions, inherits, logical(1), what = allowed_classes)))

  base_scm <- new_scm(dag, node_functions)
  base_scm$noise_distributions <- noise_distributions
  class(base_scm) <- c("linear_gaussian_scm", class(base_scm))
  base_scm
}

#' Generate data from a linear-Gaussian SCM
#'
#' Visits the nodes of the SCM's DAG in topological order. For every node it
#' draws `n_samples` noise values from the node's noise distribution and passes
#' them, together with the already generated columns of the node's parents, to
#' the node's function.
#'
#' @param scm An SCM object providing the elements `dag`, `node_functions` and
#' `noise_distributions`.
#' @param n_samples The number of samples to generate.
#' @param seed Optional. When not `NULL` it is passed to `set.seed()` before any
#' noise is drawn.
#'
#' @return A matrix with one row per sample and one column per node of the DAG.
#' The column names are the vertex names of the DAG.
#'
#' @examples
#' \dontrun{
#' dag <- igraph::graph_from_literal(A -+ B, B -+ C)
#' node_functions <- list(get_linear_root_node_function(),
#'                        get_linear_function(1),
#'                        get_linear_function(3))
#' noise_distributions <- list(distributions3::Normal(0, 1),
#'                             distributions3::Normal(0, 1),
#'                             distributions3::Normal(0, 1))
#' scm <- new_linear_gaussian_scm(dag, node_functions, noise_distributions)
#' data <- generate_data(scm, 100)
#' }
#'
generate_data.linear_gaussian_scm <- function(scm, n_samples, seed = NULL) {
  if (!is.null(seed)) {
    set.seed(seed)
  }

  graph <- scm$dag
  functions <- scm$node_functions
  noises <- scm$noise_distributions

  samples <- matrix(0, nrow = n_samples, ncol = igraph::vcount(graph))
  colnames(samples) <- igraph::V(graph)$name

  # Topological order guarantees that parent columns are filled before a child
  # reads them.
  for (v in as.numeric(igraph::topo_sort(graph))) {
    parent_cols <- as.numeric(igraph::neighbors(graph, v, mode = "in"))
    # drop = FALSE keeps a matrix even for zero or one parent.
    parent_values <- samples[, parent_cols, drop = FALSE]

    eps <- distributions3::random(noises[[v]], n_samples)
    samples[, v] <- functions[[v]](parent_values, eps)
  }

  samples
}

#' Generate data from a linear-Gaussian SCM under interventions
#'
#' Works like the observational generator (topological traversal, one noise
#' draw per node), except that intervened nodes get their column set directly
#' from `interventions`; no noise is drawn and the node function is not called
#' for them.
#'
#' @param scm An SCM object providing the elements `dag`, `node_functions` and
#' `noise_distributions`.
#' @param n_samples The number of samples to generate.
#' @param interventions A named list. Names are vertex indices converted with
#' `as.character()` (e.g. `"2"`), values are assigned to the corresponding
#' column of the result. Names that do not match a vertex index are ignored.
#' @param seed Optional. When not `NULL` it is passed to `set.seed()` before any
#' noise is drawn.
#'
#' @return A matrix with one row per sample and one column per node of the DAG.
#' The column names are the vertex names of the DAG.
#'
#' @export
generate_data_under_intervention.linear_gaussian_scm <- function(scm, n_samples, interventions, seed = NULL) {
  if (!is.null(seed)) {
    set.seed(seed)
  }

  graph <- scm$dag
  functions <- scm$node_functions
  noises <- scm$noise_distributions

  samples <- matrix(0, nrow = n_samples, ncol = igraph::vcount(graph))
  colnames(samples) <- igraph::V(graph)$name

  for (v in as.numeric(igraph::topo_sort(graph))) {
    key <- as.character(v)

    if (key %in% names(interventions)) {
      # Intervened node: overwrite the column, ignore parents and noise.
      samples[, v] <- interventions[[key]]
    } else {
      parent_cols <- as.numeric(igraph::neighbors(graph, v, mode = "in"))
      # drop = FALSE keeps a matrix even for zero or one parent.
      parent_values <- samples[, parent_cols, drop = FALSE]

      eps <- distributions3::random(noises[[v]], n_samples)
      samples[, v] <- functions[[v]](parent_values, eps)
    }
  }

  samples
}

#' Create the node function of a root (parentless) node
#'
#' @return A function `(parent_data, noise)` that returns `noise` unchanged,
#' carrying the class "linear_root_node_function".
get_linear_root_node_function <- function() {
  structure(
    function(parent_data, noise) {
      # A root node has no parents: its value is its noise term.
      noise
    },
    class = "linear_root_node_function"
  )
}


#' Create a linear node function with fixed parent weights
#'
#' @param weights Weights applied to the parent columns via `%*%`.
#'
#' @return A function `(parent_values, noise)` returning
#' `as.vector(parent_values %*% weights) + noise`. It carries the class
#' "linear_node_function" and the attributes `n_parents` (`length(weights)`)
#' and `weights`.
get_linear_function <- function(weights) {
  structure(
    function(parent_values, noise) {
      signal <- parent_values %*% weights
      as.vector(signal) + noise
    },
    n_parents = length(weights),
    weights = weights,
    class = "linear_node_function"
  )
}


#' Function to create an SCM given an existing DAG.
#'
#' Root nodes get the identity-on-noise node function and standard normal
#' noise. Every other node gets a linear function of its parents with sampled
#' weights and zero-mean normal noise with a sampled variance.
#'
#' @param dag An igraph DAG object.
#' @param weight_function A function to sample edge weights from: the weight
#' each parent gets in each node's linear function. The function should take
#' a single parameter: the number of weights to return.
#' @param noise_variance_function A function to sample the noise variance of a
#' non-root node. It is called as `noise_variance_function(1)` and must return
#' a single non-negative number. When `standardized = TRUE` the value must not
#' exceed 1 (values between 0.33 and 0.66 give a balanced signal-to-noise
#' ratio).
#' @param standardized If `TRUE` (default), the sampled weights are rescaled so
#' that each non-root node has approximately unit variance. If `FALSE`, the
#' sampled weights and noise variance are used as they are.
#' @param seed Optional. When not `NULL` it is passed to `set.seed()` before
#' anything is sampled.
#' @return An SCM object created by `new_linear_gaussian_scm()`.
#' @examples
#' \dontrun{
#' dag <- igraph::graph_from_literal(A -+ B, A -+ C, B -+ D, C -+ D)
#' scm <- sample_linear_gaussian_scm_given_dag(
#'   dag,
#'   weight_function = function(n) stats::runif(n, 0.5, 2),
#'   noise_variance_function = function(n) stats::runif(n, 0.33, 0.66),
#'   seed = 1
#' )
#' }
sample_linear_gaussian_scm_given_dag <- function(dag, weight_function,
                                                 noise_variance_function, standardized = TRUE, seed = NULL) {
  if (!is.null(seed)) {
    set.seed(seed)
  }

  n_nodes <- igraph::vcount(dag)
  node_names <- igraph::V(dag)$name

  node_functions <- vector("list", length = n_nodes)
  noise_distributions <- vector("list", length = n_nodes)
  names(node_functions) <- node_names
  names(noise_distributions) <- node_names

  for (node_idx in seq_len(n_nodes)) {
    n_parents <- length(igraph::neighbors(dag, node_idx, mode = "in"))
    if (n_parents == 0) {
      node_functions[[node_idx]] <- get_linear_root_node_function()
      noise_distributions[[node_idx]] <- distributions3::Normal(0, 1)
      next
    }

    # Sample in a fixed order (weights first, then noise variance) regardless of
    # `standardized`, so both modes are defined and a seed is reproducible.
    weights <- weight_function(n_parents)
    noise_variance <- noise_variance_function(1)

    if (!is.numeric(noise_variance) || length(noise_variance) != 1L ||
        is.na(noise_variance) || noise_variance < 0) {
      stop("noise_variance_function(1) must return a single non-negative number.",
           call. = FALSE)
    }

    # Standardized means each node will have (approximately) unit variance
    if (standardized) {
      # We want variance(signal) + variance(noise) = 1, so the noise variance
      # cannot exceed 1 (otherwise the signal standard deviation is NaN).
      if (noise_variance > 1) {
        stop("With standardized = TRUE the noise variance must not exceed 1.",
             call. = FALSE)
      }

      # We scale the weights such that this node will have approximately unit variance
      # It is an approximation because we assume independence between parents
      # See Reisach et al. (2021) and Mooij et al. (2020) for more details
      signal_std <- sqrt(1 - noise_variance)
      weights_norm <- sqrt(sum(weights^2))
      if (weights_norm > 0) {
        weights <- weights * (1 / weights_norm) * signal_std
      }
      # All-zero weights cannot be rescaled (0 / 0); they are left at zero.
    }

    node_functions[[node_idx]] <- get_linear_function(weights)
    noise_distributions[[node_idx]] <- distributions3::Normal(0, sqrt(noise_variance))
  }

  new_linear_gaussian_scm(dag, node_functions, noise_distributions)
}

#' Print a linear-Gaussian SCM
#'
#' Writes one line per node function: root nodes as `<name> <- <noise>`,
#' linear nodes via `cat_node_assignment_linear()`.
#'
#' @param x A linear-Gaussian SCM object providing the elements `dag`,
#' `node_functions` and `noise_distributions`.
#' @param ... Unused; present for compatibility with the `print()` generic.
#'
#' @return `x`, invisibly, as is conventional for print methods.
print.linear_gaussian_scm <- function(x, ...) {
  node_functions <- x$node_functions
  noise_distributions <- x$noise_distributions
  # Look the vertex names up once instead of once per node.
  node_names <- igraph::V(x$dag)$name

  for (node_idx in seq_along(node_functions)) {
    # Dispatch on the exact class attribute; anything else takes the fallback
    # branch below.
    node_class <- attr(node_functions[[node_idx]], "class")

    if (identical(node_class, "linear_root_node_function")) {
      cat(sprintf("%s <- ", node_names[[node_idx]]))
      cat_distribution(noise_distributions[[node_idx]])
    } else if (identical(node_class, "linear_node_function")) {
      cat_node_assignment_linear(x, node_idx)
    } else {
      print("error")
    }
    cat("\n")
  }

  invisible(x)
}

#' Function to cat/print the assignment of a node in an SCM with a linear node function.
#'
#' Writes `<node> <- <parent> * <weight> + ... + ` followed by the output of
#' `cat_distribution()` for the node's noise distribution. Parent names that
#' consist only of digits are prefixed with `x_`. For a node function with zero
#' parents only the noise term is written.
#'
#' @param scm An SCM object providing the elements `dag`, `node_functions` and
#' `noise_distributions`.
#' @param node_idx Index of the node, in the order of `igraph::V(scm$dag)`.
#'
#' @return The value returned by `cat_distribution()`; called for its output.
cat_node_assignment_linear <- function(scm, node_idx) {
  node_func <- scm$node_functions[[node_idx]]
  # inherits() matches the class check used when the SCM is constructed.
  if (!inherits(node_func, "linear_node_function")) {
    stop("Error. Not a linear node function.")
  }

  dag <- scm$dag
  node_names <- igraph::V(dag)$name
  n_parents <- attr(node_func, "n_parents")
  weights <- attr(node_func, "weights")

  # Convert the vertex sequence to plain indices before using it as a subscript.
  parent_idxs <- as.numeric(igraph::neighbors(dag, node_idx, mode = "in"))
  shown <- seq_len(n_parents)

  parent_names <- node_names[parent_idxs[shown]]
  is_numeric_name <- grepl("^[0-9]+$", parent_names)
  parent_names[is_numeric_name] <- paste0("x_", parent_names[is_numeric_name])

  cat(sprintf("%s <- ", node_names[[node_idx]]))

  terms <- sprintf("%s * %f", parent_names, weights[shown])
  if (length(terms) > 0) {
    cat(paste(terms, collapse = " + "))
    # Separator between the parent terms and the noise term.
    cat(" + ")
  }
  cat_distribution(scm$noise_distributions[[node_idx]])
}
