import os
import sys
import numpy as np
import random
import torch

# Reproducibility: seed the Python, NumPy and PyTorch (CPU plus every CUDA
# device) random number generators with one shared value, then pin cuDNN to
# deterministic kernels and turn off its benchmarking autotuner.
SEED = 42
for _seed_rng in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_rng(SEED)
del _seed_rng
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Make the `causal_profiler` package directory importable. It is resolved as
# `<this file's directory>/../../causal_profiler` and prepended to sys.path so
# its modules (constants, space_of_interest, sampler) load as top-level
# modules in the imports that follow.
current_dir = os.path.dirname(os.path.realpath(__file__))
_grandparent_dir = os.path.dirname(os.path.dirname(current_dir))
project_root = os.path.abspath(os.path.join(_grandparent_dir, "causal_profiler"))
del _grandparent_dir
sys.path.insert(0, project_root)

from constants import (
    MechanismFamily,
    NoiseMode,
    NoiseDistribution,
    VariableDataType,
    QueryType,
    ErrorMetric,
)
from space_of_interest import SpaceOfInterest
from sampler import Sampler


def continuous_scm():
    """Sample a continuous, linear SCM and print a preview of the results.

    Builds a ``SpaceOfInterest`` describing Markovian SCMs with 5-10 continuous
    variables (each of dimensionality 1-5), linear mechanisms and additive
    Gaussian noise. It then asks a ``Sampler`` for one SCM together with 3 ATE
    queries and 100 data points. Finally it prints the first 5 sampled values
    of every variable, each query with its estimate, and the causal graph as
    an adjacency list.
    """
    # Step 1: Define a space of interest
    space1 = SpaceOfInterest(
        number_of_nodes=(5, 10),  # SCM will have between 5 and 10 variables
        variable_dimensionality=(1, 5),  # Variables have dimensionality between 1 and 5
        mechanism_family=MechanismFamily.LINEAR,  # Linear mechanisms for relationships
        # Expected edge count given as an expression; N is presumably the
        # number of nodes (confirm against SpaceOfInterest).
        expected_edges="2 * N",
        noise_mode=NoiseMode.ADDITIVE,  # Additive noise
        noise_distribution=NoiseDistribution.GAUSSIAN,  # Gaussian noise
        noise_args=[0, 1],  # Mean 0, Std 1
        variable_type=VariableDataType.CONTINUOUS,  # Variables are continuous
        markovian=True,  # Markovian structure
        semi_markovian=False,  # Not semi-Markovian
        number_of_queries=3,  # Sample 3 queries
        query_type=QueryType.ATE,  # Queries are Average Treatment Effects
        number_of_data_points=100,  # Sample 100 data points from the SCM
    )

    # Step 2: Initialize the sampler with the space of interest. An adjacency
    # list (not a matrix) is requested, so `graph` below is a mapping.
    sampler = Sampler(space1, return_adjacency_matrix=False)

    # Step 3: Sample an SCM and get the results
    data, (queries, estimates), (graph, index_to_variable) = (
        sampler.generate_samples_and_queries()
    )

    # Step 4: Print the results
    print("Sampled Data (first 5 rows):")
    for var_id, values in data.items():
        print(f"{var_id}: {values[:5]}")  # Show the first 5 values for each variable

    print("\nSampled Queries and Their Estimates:")
    for query, estimate in zip(queries, estimates):
        print(f"Query: {query}, Estimate: {estimate}")

    print("\nGraph (Adjacency List):")
    # `graph` maps each parent index to its child indices; translate both
    # sides to variable names via `index_to_variable` for readable output.
    for parent, children in graph.items():
        child_names = [index_to_variable[child] for child in children]
        print(f"{index_to_variable[parent]} -> {child_names}")


def discrete_scm():
    """Sample a small discrete SCM with tabular mechanisms and print the results.

    The space of interest covers Markovian SCMs with 3-5 one-dimensional
    discrete variables (2-3 categories each). Mechanisms are tabular and the
    noise is additive and uniform. One SCM is drawn together with 2
    conditional queries and 50 data points. The function prints the first 5
    sampled values of every variable, each query with its estimate, and the
    raw adjacency matrix alongside the index-to-variable mapping.
    """
    discrete_space = SpaceOfInterest(
        number_of_nodes=(3, 5),  # a deliberately small graph
        variable_dimensionality=(1, 1),  # every variable is one-dimensional
        mechanism_family=MechanismFamily.TABULAR,  # lookup-table mechanisms
        # Low connectivity. NOTE(review): a 3-node acyclic graph has at most
        # 3 edges, so 4 may exceed what small draws can hold; confirm how
        # the sampler handles that.
        expected_edges=4,
        number_of_categories=(2, 3),
        noise_mode=NoiseMode.ADDITIVE,
        noise_distribution=NoiseDistribution.UNIFORM,
        noise_args=[-1, 1],  # uniform noise on [-1, 1]
        variable_type=VariableDataType.DISCRETE,
        markovian=True,
        semi_markovian=False,
        number_of_queries=2,
        query_type=QueryType.CONDITIONAL,
        number_of_data_points=50,
        number_of_noise_regions="V",
    )

    # Ask for a dense adjacency matrix this time rather than a list.
    matrix_sampler = Sampler(discrete_space, return_adjacency_matrix=True)
    samples, query_pair, graph_pair = matrix_sampler.generate_samples_and_queries()
    sampled_queries, query_estimates = query_pair
    adjacency, idx_to_var = graph_pair

    print("\nSampled Data from Space 2 (first 5 rows):")
    for variable, column in samples.items():
        print(f"{variable}: {column[:5]}")

    print("\nSampled Queries and Their Estimates:")
    for q, est in zip(sampled_queries, query_estimates):
        print(f"Query: {q}, Estimate: {est}")

    print("\nGraph (Adjacency Matrix):")
    print(adjacency, idx_to_var)


if __name__ == "__main__":
    # Run each demo in turn, announcing it with a banner first.
    for banner, run_example in (
        ("Example 1: Continuous SCM", continuous_scm),
        ("Example 2: Discrete SCM", discrete_scm),
    ):
        print(banner)
        run_example()
