import os
import sys
import numpy as np
import random
import torch

# Pin every random number generator this script touches (Python's ``random``,
# NumPy, and PyTorch on CPU and on all CUDA devices) to one shared seed so
# repeated runs draw the same values.
SEED = 42
for _seed_rng in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_rng(SEED)

# cuDNN settings recommended by the PyTorch reproducibility notes: disable the
# benchmark autotuner and request deterministic algorithms.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Directory that holds this script, with symlinks resolved.
current_dir = os.path.dirname(os.path.realpath(__file__))
# Walk two levels up from the script and select the ``causal_profiler``
# directory found there.
_two_levels_up = os.path.join(current_dir, os.pardir, os.pardir)
project_root = os.path.abspath(os.path.join(_two_levels_up, "causal_profiler"))
# Put that directory first on the module search path so the bare imports
# below (``scm``, ``variable``, ``constants``, ``mechanism``) are looked up
# there before anywhere else.
sys.path.insert(0, project_root)

from scm import SCM
from variable import Variable
from constants import (
    VariableDataType,
    MechanismFamily,
    NoiseDistribution,
    NoiseMode,
    NeuralNetworkType,
)
from mechanism import *


def additive_multidim_noise_to_mechanism():
    """Demo: additive noise from variables of different dimensionality.

    Draws a random (10, 3) array as the deterministic part of a mechanism,
    wraps a (10, 1) and a (10, 3) random array in ``Variable`` objects, and
    passes all three to ``add_noise_to_mechanism`` in ``NoiseMode.ADDITIVE``.
    The shape of the returned value is printed; the original author's note
    expects ``(10, 3)``.
    """
    n_samples, n_dims = 10, 3

    # Deterministic part of the mechanism, drawn before any noise values so
    # the order of RNG draws stays fixed.
    deterministic_part = np.random.randn(n_samples, n_dims)

    noise_vars = [
        # One noise value per sample.
        Variable(
            name="Noise1",
            value=np.random.randn(n_samples, 1),
            dimensionality=1,
        ),
        # One noise value per sample and per output dimension.
        Variable(
            name="Noise2",
            value=np.random.randn(n_samples, n_dims),
            dimensionality=n_dims,
        ),
    ]

    noisy = add_noise_to_mechanism(
        deterministic_part, noise_vars, NoiseMode.ADDITIVE
    )
    print(noisy.shape)  # Expected per the original note: (10, 3)


def multiplicative_noise_with_scalar_and_multidim_vars():
    """Demo: multiplicative noise from a 1-D and a 4-D noise variable.

    Draws a random (5, 4) array as the mechanism expression, wraps a (5, 1)
    and a (5, 4) random array in ``Variable`` objects, and passes them to
    ``add_noise_to_mechanism`` in ``NoiseMode.MULTIPLICATIVE``. The shape of
    the returned value is printed; the original author's note expects
    ``(5, 4)``.
    """
    rows, cols = 5, 4
    base = np.random.randn(rows, cols)

    # NOTE(review): the (rows, 1) noise is presumably broadcast across the
    # columns inside add_noise_to_mechanism -- confirm against that helper.
    narrow_noise = Variable(
        name="Noise1", value=np.random.randn(rows, 1), dimensionality=1
    )
    wide_noise = Variable(
        name="Noise2", value=np.random.randn(rows, cols), dimensionality=cols
    )

    combined = add_noise_to_mechanism(
        base, [narrow_noise, wide_noise], NoiseMode.MULTIPLICATIVE
    )
    print(combined.shape)  # Expected per the original note: (5, 4)


def no_noise_vars():
    """Demo: an empty noise list must leave the expression untouched.

    Calls ``add_noise_to_mechanism`` on a (10, 3) array of ones with no noise
    variables in ``NoiseMode.ADDITIVE`` and asserts that the returned array
    equals the input element for element. Prints a confirmation message when
    the check passes.

    Raises:
        AssertionError: if the returned array differs from the input.
    """
    baseline = np.ones(shape=(10, 3))
    untouched = add_noise_to_mechanism(baseline, [], NoiseMode.ADDITIVE)

    # With nothing to add, the output has to match the input exactly.
    assert np.array_equal(untouched, baseline)
    print("Expression correctly unchanged!")


# Script entry point: run each demo in turn, preceded by a numbered banner.
if __name__ == "__main__":
    print("1. Additive Noise with Multi-Dimensional Variables")
    additive_multidim_noise_to_mechanism()
    print("2. Multiplicative Noise with Scalar and Multi-Dimensional Variables")
    multiplicative_noise_with_scalar_and_multidim_vars()
    # no_noise_vars() was defined above but never invoked, so its self-check
    # (an empty noise list leaves the expression unchanged) never ran.
    print("3. No Noise Variables")
    no_noise_vars()
