import os
import sys
import numpy as np
import random
import torch

# Global seed shared by every random number generator the examples may touch,
# so that repeated runs print identical values.
SEED = 42


def _seed_everything(seed):
    """Seed Python's, NumPy's and PyTorch's (CPU and all CUDA devices) RNGs
    with *seed*, and put cuDNN into deterministic, non-benchmarking mode."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Deterministic kernels and no auto-tuning: slower, but reproducible.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


_seed_everything(SEED)

# Directory that contains this script (symlinks resolved).
current_dir = os.path.dirname(os.path.realpath(__file__))
# The library sources live two levels up, in ``causal_profiler``; put that
# directory first on the import path so ``scm``, ``variable`` and
# ``constants`` can be imported below.
project_root = os.path.abspath(
    os.path.join(current_dir, os.pardir, os.pardir, "causal_profiler")
)
sys.path.insert(0, project_root)

from scm import SCM
from variable import Variable
from constants import (
    VariableDataType,
    MechanismFamily,
    NoiseDistribution,
    NoiseMode,
    NeuralNetworkType,
)


def example_linear_scm():
    """Chain X -> Y -> Z with linear mechanisms and additive Gaussian noise.

    Builds the SCM and draws one realisation of the exogenous noise. It then
    computes Z with ``backwards=True``, which presumably evaluates Z's
    ancestors first, and prints X, Y and Z.
    """
    X, Y, Z = (
        Variable(name=label, variable_type=VariableDataType.CONTINUOUS)
        for label in ("X", "Y", "Z")
    )

    scm = SCM(
        variables=[X, Y, Z],
        noise_distribution=NoiseDistribution.GAUSSIAN,
        noise_mode=NoiseMode.ADDITIVE,
        noise_args=[0, 1],
    )

    # Wire up the chain, then give every node a linear mechanism.
    for parent, child in ((X, Y), (Y, Z)):
        scm.add_edge(parent, child)
    for node in (X, Y, Z):
        scm.set_function(node, MechanismFamily.LINEAR)

    scm.sample_noise_variables()
    scm.compute_variable(Z, backwards=True)

    for label, node in (("X", X), ("Y", Y), ("Z", Z)):
        print(f"{label}: {node.value}")


def example_nn_scm_multiplicative():
    """Chain A -> B -> C with neural-network mechanisms and multiplicative noise.

    A and B use feed-forward neural-network mechanisms and C a linear one.
    The Gaussian noise (``noise_args=[0, 1]``, presumably ``[mean, std]``) is
    applied in multiplicative mode. One noise realisation is drawn, C is
    computed with ``backwards=True`` and the values of A, B and C are printed.
    """
    # Define variables
    A = Variable(name="A", variable_type=VariableDataType.CONTINUOUS)
    B = Variable(name="B", variable_type=VariableDataType.CONTINUOUS)
    C = Variable(name="C", variable_type=VariableDataType.CONTINUOUS)

    # Initialize SCM
    scm = SCM(
        variables=[A, B, C],
        noise_distribution=NoiseDistribution.GAUSSIAN,
        noise_mode=NoiseMode.MULTIPLICATIVE,
        noise_args=[0, 1],
    )

    # Add edges
    scm.add_edge(A, B)
    scm.add_edge(B, C)

    # Set mechanisms. The integers after the network type are presumably
    # hidden-layer sizes (one layer of 3 for A; layers of 10 and 5 for B).
    # TODO confirm against the NEURAL_NETWORK mechanism's documentation.
    scm.set_function(
        A,
        MechanismFamily.NEURAL_NETWORK,
        mechanism_args=[NeuralNetworkType.FEEDFORWARD, 3],
    )
    scm.set_function(
        B,
        MechanismFamily.NEURAL_NETWORK,
        mechanism_args=[NeuralNetworkType.FEEDFORWARD, 10, 5],
    )
    scm.set_function(C, MechanismFamily.LINEAR)

    # Sample noise variables
    scm.sample_noise_variables()

    # Compute C; backwards=True presumably evaluates its ancestors (A, B) first.
    scm.compute_variable(C, backwards=True)

    # Print values
    print(f"A: {A.value}")
    print(f"B: {B.value}")
    print(f"C: {C.value}")


def example_discrete_tabular_scm():
    """Discrete SCM X -> Z <- Y in which Z is defined by a lookup table.

    X, Y and Z are discrete variables declared with two values each, under
    uniform additive noise. Every exogenous noise variable is restricted to a
    single noise region, and X and Y are fixed to 1 by hard interventions.
    Z is then evaluated from its table and printed together with X and Y.
    """
    X, Y, Z = (
        Variable(
            name=label,
            variable_type=VariableDataType.DISCRETE,
            num_discrete_values=2,
        )
        for label in ("X", "Y", "Z")
    )

    scm = SCM(
        variables=[X, Y, Z],
        noise_distribution=NoiseDistribution.UNIFORM,
        noise_mode=NoiseMode.ADDITIVE,
        noise_args=[0, 1],
    )

    # Must run BEFORE any edges are added. At this point each variable has
    # exactly one parent, its own exogenous noise variable, so the
    # single-element unpacking succeeds. Once add_edges runs, Z would
    # presumably have several parents.
    for endogenous in (X, Y, Z):
        [noise_name] = scm.parents[endogenous.name]
        scm.variables[noise_name].noise_regions = [1.0]

    scm.add_edges([(X, Z), (Y, Z)])

    # Lookup table for Z. Each entry pairs a list of input tuples with the
    # output they map to; the tuples are presumably ordered (X, Y, U_Z), so
    # TODO confirm against the TABULAR mechanism. Only rows with U_Z == 0 are
    # listed. The table is deliberately incomplete, on the assumption that a
    # single noise region ([1.0]) always yields U_Z == 0.
    # NOTE(review): Z is declared with num_discrete_values=2, yet the table
    # produces the outputs 1, 2 and 3. Confirm this is intended.
    table = [
        ([(0, 0, 0)], 1),
        ([(0, 1, 0), (1, 0, 0)], 2),
        ([(1, 1, 0)], 3),
    ]
    scm.set_function(Z, MechanismFamily.TABULAR, mechanism_args=table)

    scm.sample_noise_variables()
    # Pin both parents of Z to 1 via hard interventions, X first, then Y.
    for parent in (X, Y):
        scm.do_hard_intervention(parent, 1)
    scm.compute_variable(Z, backwards=True)

    for label, var in (("X", X), ("Y", Y), ("Z", Z)):
        print(f"{label}: {var.value}")


def example_gaussian_mixture_scm():
    """Edge D -> E with linear mechanisms and Gaussian-mixture additive noise.

    The noise is a two-component mixture: means -2 and 2, std 1 each, and
    weights 0.3 and 0.7. One realisation is drawn, E is computed with
    ``backwards=True`` and both values are printed.
    """
    D, E = (
        Variable(name=label, variable_type=VariableDataType.CONTINUOUS)
        for label in ("D", "E")
    )

    # noise_args layout: [mean1, std1, weight1, mean2, std2, weight2]
    first_component = [-2, 1, 0.3]
    second_component = [2, 1, 0.7]
    scm = SCM(
        variables=[D, E],
        noise_distribution=NoiseDistribution.GAUSSIAN_MIXTURE,
        noise_mode=NoiseMode.ADDITIVE,
        noise_args=first_component + second_component,
    )

    scm.add_edge(D, E)
    for node in (D, E):
        scm.set_function(node, MechanismFamily.LINEAR)

    scm.sample_noise_variables()
    scm.compute_variable(E, backwards=True)

    for label, node in (("D", D), ("E", E)):
        print(f"{label}: {node.value}")


def example_nn_scm_functional():
    """Chain A -> B -> C with neural-network mechanisms and functional noise.

    This mirrors the multiplicative-noise example; only ``noise_mode``
    differs (``NoiseMode.FUNCTIONAL``). A and B use feed-forward
    neural-network mechanisms and C a linear one. One noise realisation is
    drawn, C is computed with ``backwards=True`` and A, B and C are printed.
    """
    # Define variables
    A = Variable(name="A", variable_type=VariableDataType.CONTINUOUS)
    B = Variable(name="B", variable_type=VariableDataType.CONTINUOUS)
    C = Variable(name="C", variable_type=VariableDataType.CONTINUOUS)

    # Initialize SCM
    scm = SCM(
        variables=[A, B, C],
        noise_distribution=NoiseDistribution.GAUSSIAN,
        noise_mode=NoiseMode.FUNCTIONAL,
        noise_args=[0, 1],
    )

    # Add edges
    scm.add_edge(A, B)
    scm.add_edge(B, C)

    # Set mechanisms. The integers after the network type are presumably
    # hidden-layer sizes. TODO confirm against the NEURAL_NETWORK docs.
    scm.set_function(
        A,
        MechanismFamily.NEURAL_NETWORK,
        mechanism_args=[NeuralNetworkType.FEEDFORWARD, 3],
    )
    scm.set_function(
        B,
        MechanismFamily.NEURAL_NETWORK,
        mechanism_args=[NeuralNetworkType.FEEDFORWARD, 10, 5],
    )
    scm.set_function(C, MechanismFamily.LINEAR)

    # Sample noise variables
    scm.sample_noise_variables()

    # Compute C; backwards=True presumably evaluates its ancestors (A, B) first.
    scm.compute_variable(C, backwards=True)

    # Print values
    print(f"A: {A.value}")
    print(f"B: {B.value}")
    print(f"C: {C.value}")


def example_multi_dim_variable():
    """SCM B -> A whose variables are multi-dimensional.

    A has dimensionality 2 and B has dimensionality 3. Both use feed-forward
    neural-network mechanisms with Gaussian noise in functional mode. A is
    computed with ``backwards=True``, and the values of both variables are
    printed, as in the other examples.
    """
    A = Variable(name="A", variable_type=VariableDataType.CONTINUOUS, dimensionality=2)
    B = Variable(name="B", variable_type=VariableDataType.CONTINUOUS, dimensionality=3)
    scm = SCM(
        variables=[A, B],
        noise_distribution=NoiseDistribution.GAUSSIAN,
        noise_mode=NoiseMode.FUNCTIONAL,
        noise_args=[0, 1],
    )
    # B is the cause; its 3-D value feeds A's 2-D mechanism.
    scm.add_edge(B, A)
    scm.set_function(
        A,
        MechanismFamily.NEURAL_NETWORK,
        mechanism_args=[NeuralNetworkType.FEEDFORWARD, 3],
    )
    scm.set_function(
        B,
        MechanismFamily.NEURAL_NETWORK,
        mechanism_args=[NeuralNetworkType.FEEDFORWARD, 3],
    )
    scm.sample_noise_variables()
    scm.compute_variable(A, backwards=True)
    print(f"A: {A.value}")
    # B is A's parent, so it is presumably populated by the backwards pass.
    print(f"B: {B.value}")


def example_sample_linear_scm():
    """Draw a batch of samples from the linear chain X -> Y -> Z.

    This is the same structure as the linear additive-Gaussian example, but
    it calls ``SCM.sample_data`` for 1000 samples in a single batch of 1000
    instead of drawing one noise realisation.

    Returns:
        The value returned by ``SCM.sample_data``, so callers can inspect the
        samples. Its type is defined by the library, not by this script.
    """
    # Define variables
    X = Variable(name="X", variable_type=VariableDataType.CONTINUOUS)
    Y = Variable(name="Y", variable_type=VariableDataType.CONTINUOUS)
    Z = Variable(name="Z", variable_type=VariableDataType.CONTINUOUS)
    scm = SCM(
        variables=[X, Y, Z],
        noise_distribution=NoiseDistribution.GAUSSIAN,
        noise_mode=NoiseMode.ADDITIVE,
        noise_args=[0, 1],
    )
    scm.add_edge(X, Y)
    scm.add_edge(Y, Z)
    scm.set_function(X, MechanismFamily.LINEAR)
    scm.set_function(Y, MechanismFamily.LINEAR)
    scm.set_function(Z, MechanismFamily.LINEAR)
    # For a larger run, raise total_samples and batch_size (e.g. to 100000).
    return scm.sample_data(total_samples=1000, batch_size=1000)


# Run the examples
if __name__ == "__main__":
    # (banner, example) pairs, run in the order listed.
    examples = (
        ("Example 1: Linear SCM with Additive Noise", example_linear_scm),
        (
            "Example 2: SCM with Neural Network Mechanism and Multiplicative Noise",
            example_nn_scm_multiplicative,
        ),
        (
            "Example 3: Discrete SCM with Tabular Mechanism",
            example_discrete_tabular_scm,
        ),
        (
            "Example 4: SCM with Gaussian Mixture Noise",
            example_gaussian_mixture_scm,
        ),
        (
            "Example 5: SCM with Neural Network Mechanism and Functional Noise",
            example_nn_scm_functional,
        ),
        (
            "Example 6: Including multidimensional variables",
            example_multi_dim_variable,
        ),
        ("Example 7: Sample data from linear SCM", example_sample_linear_scm),
    )
    for banner, run_example in examples:
        print(banner)
        run_example()
