import os
import sys
import numpy as np
import random
import torch

# Seed every random number generator the examples may touch (Python, NumPy,
# PyTorch CPU and all CUDA devices) so repeated runs sample the same SCMs,
# data and queries. This runs before the causal_profiler imports below.
SEED = 42
for _seed_rng in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_rng(SEED)
del _seed_rng  # keep the module namespace identical to a plain sequence of calls

# Pin cuDNN to deterministic kernels and turn off benchmark autotuning, which
# may otherwise select different algorithms from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

from causal_profiler.constants import (
    MechanismFamily,
    NoiseMode,
    NoiseDistribution,
    VariableDataType,
    QueryType,
    ErrorMetric,
)
from causal_profiler.space_of_interest import SpaceOfInterest
from causal_profiler.causal_profiler import CausalProfiler


def continuous_scm():
    """Run the continuous-variable demo end to end and print the results.

    Samples an SCM from a space of linear mechanisms with additive Gaussian
    noise. It then prints a preview of the sampled data, the sampled ATE
    queries with their estimates, and the causal graph as an adjacency list.
    Finally it scores a set of dummy "user" estimates against the sampled
    estimates with the L2 metric.
    """
    # Step 1: Define a space of interest
    space1 = SpaceOfInterest(
        number_of_nodes=(5, 10),  # SCM will have between 5 to 10 variables
        variable_dimensionality=(1, 5),  # Variables have dimensionality between 1 and 5
        mechanism_family=MechanismFamily.LINEAR,  # Linear mechanisms for relationships
        expected_edges="2 * N",  # Total expected number of edges
        noise_mode=NoiseMode.ADDITIVE,  # Additive noise
        noise_distribution=NoiseDistribution.GAUSSIAN,  # Gaussian noise
        noise_args=[0, 1],  # Mean 0, Std 1
        variable_type=VariableDataType.CONTINUOUS,  # Variables are continuous
        number_of_queries=3,  # Sample 3 queries
        query_type=QueryType.ATE,  # Queries are Average Treatment Effects
        number_of_data_points=100,  # Sample 100 data points from the SCM
    )

    # Step 2: Initialize the causal profiler with the space of interest.
    # return_adjacency_matrix=False: the graph is consumed below as a mapping
    # parent -> children.
    profiler = CausalProfiler(
        space_of_interest=space1, metric=ErrorMetric.L2, return_adjacency_matrix=False
    )

    # Step 3: Sample an SCM and get the results
    data, (queries, estimates), (graph, index_to_variable) = (
        profiler.generate_samples_and_queries()
    )

    # Step 4: Print the results
    print("Sampled Data (first 5 rows):")
    for var_id, values in data.items():
        print(f"{var_id}: {values[:5]}")  # Show the first 5 values for each variable

    print("\nSampled Queries and Their Estimates:")
    for query, estimate in zip(queries, estimates):
        print(f"Query: {query}, Estimate: {estimate}")

    # Only the formatted listing is printed here. A raw dump of `graph` and
    # `index_to_variable` used to precede it and duplicated the same
    # information in unreadable form. index_to_variable presumably maps graph
    # node indices to variable names (it is indexed by both parents and
    # children below).
    print("\nGraph (Adjacency List):")
    for parent, children in graph.items():
        print(
            f"{index_to_variable[parent]} -> {[index_to_variable[child] for child in children]}"
        )

    # Step 5: Assume user provides their own estimates and evaluate error.
    # The dummy perturbation alternates -0.05 (even i) and +0.05 (odd i).
    user_estimates = [
        estimate + 0.1 * (i % 2 - 0.5) for i, estimate in enumerate(estimates)
    ]
    error_l2, num_failed = profiler.evaluate_error(
        estimated=user_estimates, target=estimates
    )
    print(f"\nL2 Error between sampled and user-provided estimates: {error_l2}")
    print(f"Number of failed to estimate queries: {num_failed}")


def discrete_scm():
    """Run the discrete-variable (tabular) demo and print the results.

    Samples a small SCM with tabular mechanisms and discrete variables. It
    prints a preview of the sampled data and the graph in adjacency-matrix
    form. It then scores estimates shifted by +0.05 against the sampled
    estimates using MAPE.
    """
    # A small space: 3-5 single-dimensional discrete variables with 2-3
    # categories each, roughly 4 edges, and uniform additive noise on [-1, 1].
    discrete_space = SpaceOfInterest(
        number_of_nodes=(3, 5),
        variable_dimensionality=(1, 1),
        mechanism_family=MechanismFamily.TABULAR,
        expected_edges=4,
        number_of_categories=(2, 3),
        noise_mode=NoiseMode.ADDITIVE,
        noise_distribution=NoiseDistribution.UNIFORM,
        noise_args=[-1, 1],
        variable_type=VariableDataType.DISCRETE,
        number_of_queries=2,  # two conditional queries
        query_type=QueryType.CONDITIONAL,
        number_of_data_points=50,
    )

    # This time the graph is requested as an adjacency matrix.
    mape_profiler = CausalProfiler(
        discrete_space, metric=ErrorMetric.MAPE, return_adjacency_matrix=True
    )
    samples, (_queries, targets), (adjacency, index_to_variable2) = (
        mape_profiler.generate_samples_and_queries()
    )

    print("\nSampled Data from Space 2 (first 5 rows):")
    for name, column in samples.items():
        print(f"{name}: {column[:5]}")

    print("\nGraph (Adjacency Matrix):")
    print(adjacency, index_to_variable2)

    # Pretend the user's estimates are each off by +0.05. To exercise the
    # failure counter, the original author's variant replaced the last
    # estimate with float("NaN"), which is presumably counted as a failed query.
    shifted = [target + 0.05 for target in targets]
    error_mape, num_failed = mape_profiler.evaluate_error(
        estimated=shifted,
        target=targets,
    )
    print(
        f"\nMAPE Error between sampled and user-provided estimates (Space 2): {error_mape}"
    )
    print(f"Number of failed to estimate queries: {num_failed}")


if __name__ == "__main__":
    # Run each demo in order, announcing it with a banner first.
    for banner, run_demo in (
        ("Example 1: Continuous SCM", continuous_scm),
        ("Example 2: Discrete SCM", discrete_scm),
    ):
        print(banner)
        run_demo()
