import os
import sys
import torch
import numpy as np
import random
from typing import List
import multiprocessing as mp
import json
from collections import defaultdict, deque

# Reproducibility: every RNG this script can touch (Python, NumPy, torch CPU,
# and torch on all CUDA devices) is seeded with the same value, in this order.
SEED = 42
for _seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(SEED)
del _seed_fn  # don't leave the loop variable behind in the module namespace

# cuDNN: disable autotuning (benchmark mode) and request deterministic kernels
# so repeated runs pick the same algorithms.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Directory holding this script, with symlinks resolved.
current_dir = os.path.split(os.path.realpath(__file__))[0]
# The ``causal_profiler`` sources sit directly inside this script directory's
# grandparent, i.e. <current_dir>/../../causal_profiler.
project_root = os.path.abspath(
    os.path.join(os.path.dirname(os.path.dirname(current_dir)), "causal_profiler")
)
# Prepend so the project-local modules imported below (constants, scm, ...)
# are found first on the import path.
sys.path.insert(0, project_root)

from constants import (
    NoiseDistribution,
    NoiseMode,
    VariableDataType,
    MechanismFamily,
    NeuralNetworkType,
)
from scm import SCM
from variable import Variable
from sampler import Sampler
from query import Query
from space_of_interest import SpaceOfInterest
from query_estimator import QueryEstimator


def evaluate_ATE_simple_SCM():
    """Build a small linear SCM and print an estimate of its ATE.

    Graph: Z -> X, Z -> Y, X -> Y, with all three variables continuous and
    each given a ``MechanismFamily.LINEAR`` mechanism. Noise is Gaussian in
    ``NoiseMode.FUNCTIONAL`` with ``noise_args=[0, 1]``.

    The query is E[Y | do(X=1)] - E[Y | do(X=0)], evaluated by a
    ``QueryEstimator`` configured with ``n_samples=10000``. The estimate is
    printed, not returned.
    """
    estimator = QueryEstimator(n_samples=10000)

    # X is the treatment, Y the outcome, Z a common parent of both.
    treatment, outcome, confounder = (
        Variable(name=label, variable_type=VariableDataType.CONTINUOUS)
        for label in ("X", "Y", "Z")
    )
    scm = SCM(
        variables=[treatment, outcome, confounder],
        noise_distribution=NoiseDistribution.GAUSSIAN,
        noise_mode=NoiseMode.FUNCTIONAL,
        noise_args=[0, 1],
    )
    scm.add_edges(
        [(confounder, treatment), (confounder, outcome), (treatment, outcome)]
    )
    for node in (treatment, outcome, confounder):
        scm.set_function(node, MechanismFamily.LINEAR)

    ate_query = Query.createL2ATE(Y=outcome, T=treatment, T1_value=1, T0_value=0)
    print(f"Estimate: {estimator.evaluate_query(scm, ate_query)}")


def evaluate_CONDITIONAL_discrete_SCM():
    """Build a two-node binary SCM (Y -> X) and print an estimate of P(Y=1 | X=0).

    Both variables are discrete with two values. Noise is uniform
    (``noise_args=[0, 1]``) in ``NoiseMode.ADDITIVE``. Each variable gets a
    hand-written ``MechanismFamily.TABULAR`` mechanism, and the noise regions
    of the exogenous parents are pinned by hand. The estimate comes from a
    ``QueryEstimator`` with ``n_samples=10000`` and is printed, not returned.
    """
    estimator = QueryEstimator(n_samples=10000)
    x, y = (
        Variable(
            name=label,
            variable_type=VariableDataType.DISCRETE,
            num_discrete_values=2,
        )
        for label in ("X", "Y")
    )
    scm = SCM(
        variables=[x, y],
        noise_distribution=NoiseDistribution.UNIFORM,
        noise_mode=NoiseMode.ADDITIVE,
        noise_args=[0, 1],
    )

    # Pin the noise regions of each variable's exogenous parent (X first,
    # then Y). This runs before add_edges, while each variable has exactly
    # one parent. The one-element unpacking raises ValueError otherwise;
    # presumably add_edges would add Y to X's parent list.
    for endogenous, regions in ((x, [0.2]), (y, [0.1, 0.6])):
        (noise_name,) = scm.parents[endogenous.name]
        scm.variables[noise_name].noise_regions = regions

    scm.add_edges([(y, x)])

    # Y's table maps key tuples to outputs: (0,) and (2,) -> 1, (1,) -> 0.
    # The keys are presumably U_Y's discretized value (three buckets from the
    # two regions above). TODO confirm against TABULAR semantics.
    scm.set_function(
        y,
        MechanismFamily.TABULAR,
        mechanism_args=[
            ([(0,), (2,)], 1),
            ([(1,)], 0),
        ],
    )
    # X's table outputs 1 exactly when the two entries of its key differ
    # (an XOR). Assuming the key pairs Y with U_X, their order is irrelevant.
    scm.set_function(
        x,
        MechanismFamily.TABULAR,
        mechanism_args=[
            ([(0, 0), (1, 1)], 0),
            ([(1, 0), (0, 1)], 1),
        ],
    )

    conditional = Query.createL1Conditional(Y=y, X=x, Y_value=1, X_value=0)
    print(f"Estimate: {estimator.evaluate_query(scm, conditional)}")


def main():
    """Run both demos in order, printing a numbered heading before each."""
    demos = (
        ("1. Evaluating ATE on a simple SCM", evaluate_ATE_simple_SCM),
        (
            "2. Evaluate Conditional on a discrete SCM",
            evaluate_CONDITIONAL_discrete_SCM,
        ),
    )
    for heading, run_demo in demos:
        print(heading)
        run_demo()


# Run the demos only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
