# Copyright (c) 2025-present, Royal Bank of Canada.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
##################################################################

import inspect

import numpy as np


class SCM:
    """
    A minimal structural causal model (SCM).

    Every variable is stored either as a plain value or as an equation
    (a callable). The names of an equation's positional arguments are the
    names of the variable's parents, so the causal graph is implied by the
    equations' signatures.
    """

    def __init__(self):
        # variable name -> constant value, or callable fed with the parents' values
        self.variables = {}

    def add_variable(self, variable, value_or_equation):
        """
        Define (or redefine) `variable` as a constant value or as an equation.
        """
        self.variables[variable] = value_or_equation

    def do(self, variable, value):
        """
        Intervene on an already defined `variable` by fixing it to `value`.

        This changes the graph, too, since functions are not allowed:
        After intervention on a variable, the connection between its parents and itself is removed.

        Raises ValueError if `variable` is not defined or if `value` is callable.
        """
        self._require_defined(variable)

        if callable(value):
            raise ValueError("The value to intervene with must be a value, not a function.")

        self.add_variable(variable, value)

    def sample(self, variable):
        """
        apply the SCM function that generates `variable`

        Each ancestor is evaluated exactly once per call, so a variable that is
        reached through several paths (e.g. a shared noise source) contributes
        one consistent value to all of its descendants.

        Raises ValueError if `variable`, or any variable named by an equation
        on the way, is not defined.
        """
        return self._sample(variable, {})

    def _sample(self, variable, cache):
        """
        Evaluate `variable`, storing results in `cache` (name -> value) so
        that shared ancestors are not evaluated again within the same draw.
        """
        if variable in cache:
            return cache[variable]

        self._require_defined(variable)
        value_or_equation = self.variables[variable]

        if callable(value_or_equation):
            parent_values = [self._sample(parent, cache) for parent in self.get_parents(variable)]
            value = value_or_equation(*parent_values)
        else:
            value = value_or_equation

        cache[variable] = value
        return value

    def get_parents(self, child):
        """
        Return the names of the direct parents of `child` as a list.

        A variable holding a plain value has no parents.
        Raises ValueError if `child` is not defined.
        """
        self._require_defined(child)

        value_or_equation = self.variables[child]
        if callable(value_or_equation):
            return get_function_arguments(value_or_equation)

        return []

    def get_ancestors(self, descendant):
        """
        Return the names of all direct and indirect parents of `descendant`.

        The result contains each ancestor once, in no particular order.
        Raises ValueError if `descendant`, or a variable named by one of the
        visited equations, is not defined.
        """
        self._require_defined(descendant)

        ancestors = set()
        pending = list(self.get_parents(descendant))

        # Visit every ancestor once; this keeps shared ancestors cheap and
        # lets the walk terminate even if the equations form a cycle.
        while pending:
            current = pending.pop()
            if current in ancestors:
                continue

            ancestors.add(current)
            pending.extend(self.get_parents(current))

        return list(ancestors)

    def get_children(self, parent):
        """
        Return the names of the variables whose equation takes `parent` as an argument.

        Raises ValueError if `parent` is not defined.
        """
        self._require_defined(parent)

        # if a variable has no parents (plain value), then it is not a child of any variable.
        return [
            var
            for var, value_or_equation in self.variables.items()
            if callable(value_or_equation) and parent in get_function_arguments(value_or_equation)
        ]

    def _require_defined(self, variable):
        """
        Raise ValueError unless `variable` has been added to the model.

        Membership is tested on the key, so a variable whose value is `None`
        still counts as defined.
        """
        if variable not in self.variables:
            raise ValueError("The variable is not defined: {}".format(variable))


def get_function_arguments(scm_equation):
    """
    Return the names of the positional arguments of `scm_equation` as a list.

    For plain functions and lambdas the names are read from the code object:
    any local variables inside the function are also included in `co_varnames`,
    therefore, ignore them and only take the arguments.

    Other callables (bound methods, `functools.partial` objects, instances
    defining `__call__`) are resolved through `inspect.signature`, so an
    already bound `self` or pre-filled positional argument is not reported.
    In both cases `*args`, keyword-only arguments and `**kwargs` are excluded.
    """
    if inspect.isfunction(scm_equation):
        code = scm_equation.__code__
        return list(code.co_varnames[:code.co_argcount])

    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    parameters = inspect.signature(scm_equation).parameters
    return [name for name, parameter in parameters.items() if parameter.kind in positional_kinds]
