library(distributions3, warn.conflicts = FALSE)
library(igraph, warn.conflicts = FALSE)

source('R/scm/base_scm.R')


#' Constructor for additive-noise SCM objects
#'
#' Creates a new SCM object from a directed acyclic graph (DAG), one node
#' function per node and one noise distribution per node. The object is built
#' with `new_scm()` and then extended with the noise distributions.
#'
#' @param dag An igraph object representing the DAG structure of the SCM.
#' @param node_functions A list of functions, one for each node in the DAG,
#' following the order returned by igraph::V(dag). Each function must inherit
#' from "additive_noise_function", "linear_root_node_function" or
#' "weighted_additive_noise_function" (e.g. as produced by
#' `get_additive_noise_function()` or `get_weighted_additive_noise_function()`).
#' @param noise_distributions A list of distributions3 distributions, one for
#' each node in the DAG, following the order returned by igraph::V(dag).
#'
#' @return The object returned by `new_scm(dag, node_functions)` with an added
#' `noise_distributions` element and "additive_noise_scm" prepended to its
#' class vector.
#'
#' @examples
#' \dontrun{
#' dag <- igraph::graph_from_literal(A -+ B)
#' node_functions <- list(get_linear_root_node_function(),
#'                        get_additive_noise_function(plogis, 2))
#' noise_distributions <- list(distributions3::Normal(0, 1),
#'                             distributions3::Normal(0, 1))
#' scm <- new_additive_noise_scm(dag, node_functions, noise_distributions)
#' }
new_additive_noise_scm <- function(dag, node_functions, noise_distributions) {
  allowed_function_classes <- c("additive_noise_function",
                                "linear_root_node_function",
                                "weighted_additive_noise_function")
  # vapply() guarantees a logical vector (sapply() may return a list);
  # inherits() with a vector `what` is TRUE when any of the classes match.
  stopifnot(all(vapply(noise_distributions, distributions3::is_distribution,
                       logical(1))))
  stopifnot(all(vapply(node_functions, inherits, logical(1),
                       what = allowed_function_classes)))

  base_scm <- new_scm(dag, node_functions)
  base_scm$noise_distributions <- noise_distributions
  class(base_scm) <- c("additive_noise_scm", class(base_scm))
  base_scm
}

#' Generate data from an additive-noise structural causal model (SCM)
#'
#' S3 method of `generate_data()` for "additive_noise_scm" objects. It does no
#' sampling of its own: every argument is forwarded unchanged to
#' `generate_data.linear_gaussian_scm()`, whose result is returned as-is.
#'
#' @param scm An "additive_noise_scm" object (see `new_additive_noise_scm()`).
#' @param n_samples The number of samples to generate.
#' @param seed Optional seed, forwarded to the linear-Gaussian sampler.
#'
#' @return The result of `generate_data.linear_gaussian_scm(scm, n_samples, seed)`;
#' per the original documentation, a matrix with one column per DAG node (named
#' after the nodes) and one row per sample.
#'
#' @examples
#' \dontrun{
#' dag <- igraph::graph_from_literal(A -+ B, B -+ C)
#' scm <- sample_additive_noise_scm_given_dag(
#'   dag,
#'   weight_function = function(n) stats::runif(n, 0.5, 2),
#'   list_of_node_functions = list(identity, plogis),
#'   list_of_noise_distributions = list(distributions3::Normal(0, 1))
#' )
#' data <- generate_data(scm, 100)
#' }
#'
generate_data.additive_noise_scm <- function(scm, n_samples, seed = NULL) {
  # Delegate to the linear-Gaussian sampler, forwarding all arguments.
  generate_data.linear_gaussian_scm(scm, n_samples, seed)
}

#' Generate data from an additive-noise SCM under interventions
#'
#' S3 method for "additive_noise_scm" objects. Forwards all arguments unchanged
#' to `generate_data_under_intervention.linear_gaussian_scm()` and returns its
#' result as-is.
#'
#' @param scm An "additive_noise_scm" object.
#' @param n_samples The number of samples to generate.
#' @param interventions Intervention specification, passed through untouched;
#' see `generate_data_under_intervention.linear_gaussian_scm()` for its format.
#' @param seed Optional seed, forwarded to the linear-Gaussian sampler.
generate_data_under_intervention.additive_noise_scm <- function(scm, n_samples, interventions, seed = NULL) {
  generate_data_under_intervention.linear_gaussian_scm(scm, n_samples,
                                                       interventions, seed)
}

#' Build an additive-noise node function `f(parent_values %*% weights) + noise`.
#'
#' @param f Function applied to the weighted combination of the parent values.
#' @param weights Numeric vector of parent weights, one per parent.
#' @return A closure `function(parent_values, noise)` of class
#' "additive_noise_function" carrying the attributes `f`, `n_parents`
#' (= length(weights)) and `weights`, which the print helpers read back.
get_additive_noise_function <- function(f, weights) {
  node_function <- function(parent_values, noise) {
    weighted_parents <- parent_values %*% weights
    f(weighted_parents) + noise
  }
  structure(node_function,
            f = f,
            n_parents = length(weights),
            weights = weights,
            class = "additive_noise_function")
}

#' Build a scaled additive-noise node function
#' `weight_scale * f(parent_values %*% weights_parents) + noise`.
#'
#' @param f Function applied to the weighted combination of the parent values.
#' @param weights_parents Numeric vector of parent weights, one per parent.
#' @param weight_scale Scalar multiplier applied to the output of `f`.
#' @return A closure `function(parent_values, noise)` of class
#' "weighted_additive_noise_function" carrying the attributes `f`, `n_parents`
#' (= length(weights_parents)), `weights_parents` and `weight_scale`.
get_weighted_additive_noise_function <- function(f, weights_parents, weight_scale) {
  node_function <- function(parent_values, noise) {
    transformed <- f(parent_values %*% weights_parents)
    weight_scale * transformed + noise
  }
  structure(node_function,
            f = f,
            n_parents = length(weights_parents),
            weights_parents = weights_parents,
            weight_scale = weight_scale,
            class = "weighted_additive_noise_function")
}


#' Sample a weighted additive-noise SCM for an existing DAG.
#'
#' Root nodes (no parents) get `get_linear_root_node_function()`. Every other
#' node gets `weight_scale * f(parents %*% weights_parents) + noise`, built with
#' `get_weighted_additive_noise_function()`, where `f` is drawn uniformly from
#' `list_of_node_functions`. Every node's noise distribution is drawn uniformly
#' from `list_of_noise_distributions`.
#'
#' @param dag An igraph DAG object.
#' @param weight_function_parents A function taking one argument, the number of
#' parents of a node, and returning that many parent weights.
#' @param weight_function_scale A function taking one argument (it is always
#' called with 1) and returning the outer scale weight of a node function.
#' @param list_of_node_functions A list of functions from which each non-root
#' node's `f` is sampled.
#' @param list_of_noise_distributions A list of distributions3 distributions
#' from which each node's noise distribution is sampled.
#' @param seed Optional seed. When not NULL, `set.seed(seed)` is called first,
#' which also changes the global RNG state seen by later code.
#' @return An "additive_noise_scm" object created by `new_additive_noise_scm()`.
#' @examples
#' \dontrun{
#' dag <- igraph::graph_from_literal(A -+ B, A -+ C, B -+ C)
#' scm <- sample_weighted_additive_noise_scm_given_dag(
#'   dag,
#'   weight_function_parents = function(n) stats::runif(n, 0.5, 2),
#'   weight_function_scale = function(n) stats::runif(n, 1, 3),
#'   list_of_node_functions = list(identity, plogis),
#'   list_of_noise_distributions = list(distributions3::Normal(0, 1)),
#'   seed = 1
#' )
#' }
sample_weighted_additive_noise_scm_given_dag <- function(dag, weight_function_parents, weight_function_scale,
                                                 list_of_node_functions, list_of_noise_distributions, seed = NULL) {
  if (!is.null(seed)) {
    set.seed(seed)
  }

  n_nodes <- igraph::vcount(dag)

  node_functions <- vector("list", length = n_nodes)
  noise_distributions <- vector("list", length = n_nodes)
  names(node_functions) <- igraph::V(dag)$name
  names(noise_distributions) <- igraph::V(dag)$name

  # NOTE: the order of the random draws below is part of the contract when a
  # seed is given; reordering them would change the sampled SCM.
  for (node_idx in seq_len(n_nodes)) {
    n_parents <- length(igraph::neighbors(dag, node_idx, mode = "in"))
    if (n_parents == 0) {
      node_functions[[node_idx]] <- get_linear_root_node_function()
      noise_distributions[[node_idx]] <- sample(list_of_noise_distributions, 1)[[1]]
    } else {
      noise_distributions[[node_idx]] <- sample(list_of_noise_distributions, 1)[[1]]
      base_function <- sample(list_of_node_functions, 1)[[1]]
      weights_parents <- weight_function_parents(n_parents)
      weight_scale <- weight_function_scale(1)
      node_functions[[node_idx]] <- get_weighted_additive_noise_function(
        base_function, weights_parents, weight_scale
      )
    }
  }
  new_additive_noise_scm(dag, node_functions, noise_distributions)
}



#' Sample an additive-noise SCM for an existing DAG.
#'
#' Root nodes (no parents) get `get_linear_root_node_function()`. Every other
#' node gets `f(parents %*% weights) + noise`, built with
#' `get_additive_noise_function()`, where `f` is drawn uniformly from
#' `list_of_node_functions`. Every node's noise distribution is drawn uniformly
#' from `list_of_noise_distributions`.
#'
#' @param dag An igraph DAG object.
#' @param weight_function A function taking one argument, the number of
#' parents of a node, and returning that many parent weights.
#' @param list_of_node_functions A list of functions from which each non-root
#' node's `f` is sampled.
#' @param list_of_noise_distributions A list of distributions3 distributions
#' from which each node's noise distribution is sampled.
#' @param seed Optional seed. When not NULL, `set.seed(seed)` is called first,
#' which also changes the global RNG state seen by later code. Defaults to
#' NULL, which leaves the RNG state untouched.
#' @return An "additive_noise_scm" object created by `new_additive_noise_scm()`.
#' @examples
#' \dontrun{
#' dag <- igraph::graph_from_literal(A -+ B, A -+ C, B -+ C)
#' scm <- sample_additive_noise_scm_given_dag(
#'   dag,
#'   weight_function = function(n) stats::runif(n, 0.5, 2),
#'   list_of_node_functions = list(identity, plogis),
#'   list_of_noise_distributions = list(distributions3::Normal(0, 1))
#' )
#' }
sample_additive_noise_scm_given_dag <- function(dag, weight_function,
                                                 list_of_node_functions, list_of_noise_distributions,
                                                 seed = NULL) {
  # Same optional seeding as sample_weighted_additive_noise_scm_given_dag().
  if (!is.null(seed)) {
    set.seed(seed)
  }

  n_nodes <- igraph::vcount(dag)

  node_functions <- vector("list", length = n_nodes)
  noise_distributions <- vector("list", length = n_nodes)
  names(node_functions) <- igraph::V(dag)$name
  names(noise_distributions) <- igraph::V(dag)$name

  # NOTE: the order of the random draws below is part of the contract when a
  # seed is given; reordering them would change the sampled SCM.
  for (node_idx in seq_len(n_nodes)) {
    n_parents <- length(igraph::neighbors(dag, node_idx, mode = "in"))
    if (n_parents == 0) {
      node_functions[[node_idx]] <- get_linear_root_node_function()
      noise_distributions[[node_idx]] <- sample(list_of_noise_distributions, 1)[[1]]
    } else {
      noise_distributions[[node_idx]] <- sample(list_of_noise_distributions, 1)[[1]]
      base_function <- sample(list_of_node_functions, 1)[[1]]
      weights <- weight_function(n_parents)
      node_functions[[node_idx]] <- get_additive_noise_function(base_function, weights)
    }
  }
  new_additive_noise_scm(dag, node_functions, noise_distributions)
}

#' Print an additive-noise SCM as one structural assignment per line.
#'
#' Root nodes are printed as `name <- <noise distribution>`. Non-root nodes are
#' delegated to `cat_node_assignment_weighted_additive_noise()`; for weighted
#' node functions the stored `weight_scale` is passed along. Nodes whose
#' function has an unrecognised class are reported as such instead of aborting.
#'
#' @param x An "additive_noise_scm" object.
#' @param ... Ignored; present for compatibility with the `print()` generic.
#' @return `x`, invisibly (standard for print methods).
print.additive_noise_scm <- function(x, ...) {
  dag <- x$dag
  node_names <- igraph::V(dag)$name
  node_functions <- x$node_functions
  noise_distributions <- x$noise_distributions

  for (node_idx in seq_along(node_functions)) {
    node_name <- node_names[[node_idx]]
    node_func <- node_functions[[node_idx]]
    # inherits() instead of identical(class, ...) so subclasses still dispatch.
    if (inherits(node_func, "linear_root_node_function")) {
      cat(sprintf("%s <- ", node_name))
      cat_distribution(noise_distributions[[node_idx]])
    } else if (inherits(node_func, "additive_noise_function")) {
      cat_node_assignment_weighted_additive_noise(x, node_idx)
    } else if (inherits(node_func, "weighted_additive_noise_function")) {
      cat_node_assignment_weighted_additive_noise(x, node_idx, attr(node_func, "weight_scale"))
    } else {
      cat(sprintf("%s <- <unsupported node function of class '%s'>",
                  node_name, paste(class(node_func), collapse = "/")))
    }
    cat("\n")
  }
  invisible(x)
}

#' Print (via `cat()`) the structural assignment of one additive-noise node.
#'
#' Output has the form `name <- [scale * ]sigmoid(p1 * w1 + p2 * w2) + <noise>`
#' when the node's `f` is `plogis`, and `name <- [scale * ](p1 * w1 + ...) + <noise>`
#' otherwise. Purely numeric parent names are prefixed with "x_". No trailing
#' newline is written.
#'
#' @param scm An "additive_noise_scm" object.
#' @param node_idx Index of the node, in the order of igraph::V(scm$dag).
#' @param weight Outer scale weight; printed only when it is not NULL and not 1.
#' @return Called for its side effect of writing to standard output.
cat_node_assignment_weighted_additive_noise <- function(scm, node_idx, weight = 1) {
  node_func <- scm$node_functions[[node_idx]]
  if (!(inherits(node_func, "additive_noise_function") ||
        inherits(node_func, "weighted_additive_noise_function"))) {
    stop("Error. Not an additive noise function.")
  }

  dag <- scm$dag
  node_names <- igraph::V(dag)$name
  noise_dist <- scm$noise_distributions[[node_idx]]

  # get_additive_noise_function() stores the parent weights as "weights", but
  # get_weighted_additive_noise_function() stores them as "weights_parents".
  # Reading only "weights" made every parent term of weighted nodes vanish.
  weights <- attr(node_func, "weights")
  if (is.null(weights)) {
    weights <- attr(node_func, "weights_parents")
  }
  parent_seq <- seq_len(attr(node_func, "n_parents"))

  parent_idxs <- as.integer(igraph::neighbors(dag, node_idx, mode = "in"))
  parent_names <- node_names[parent_idxs][parent_seq]
  numeric_names <- grepl("^[0-9]+$", parent_names)
  parent_names[numeric_names] <- paste0("x_", parent_names[numeric_names])
  terms <- sprintf("%s * %f", parent_names, weights[parent_seq])

  cat(sprintf("%s <- ", node_names[[node_idx]]))
  if (!is.null(weight) && weight != 1) {
    cat(sprintf("%f * ", weight))
  }
  # Always open a parenthesis so the closing one below is balanced; the
  # original only opened it for sigmoid nodes.
  cat(if (identical(attr(node_func, "f"), plogis)) "sigmoid(" else "(")
  cat(paste(terms, collapse = " + "))
  cat(")")
  cat(" + ")
  cat_distribution(noise_dist)
}
