import os
import sys
import torch
import numpy as np
import random
from typing import List
import multiprocessing as mp
import json
from collections import defaultdict, deque

# Seed every random number generator this script imports (Python's `random`,
# NumPy, and torch on CPU and all CUDA devices) with one fixed value so that
# repeated runs draw the same numbers.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
# cuDNN flags set alongside the seeds for reproducibility (see the PyTorch
# reproducibility notes for their exact effect).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Directory that contains this script (symlinks resolved via realpath).
current_dir = os.path.dirname(os.path.realpath(__file__))
# The `causal_profiler` directory located two levels above this script.
# Presumably this is where the `constants`, `scm` and `variable` modules
# imported below live -- confirm against the repository layout.
project_root = os.path.abspath(
    os.path.join(current_dir, os.pardir, os.pardir, "causal_profiler")
)
# Put that directory first on the module search path so the project imports
# below are resolved from it.
sys.path.insert(0, project_root)

from constants import NoiseDistribution, NoiseMode, VariableDataType, MechanismFamily
from scm import SCM
from variable import Variable

# Full-scale configuration, kept for reference (swap in to run the big job):
#   NUM_SCMS = 100, NUM_VARIABLES_PER_SCM = 50, NUM_SAMPLES = 100_000
#
# Active (small) configuration.
NUM_SCMS = 1
NUM_VARIABLES_PER_SCM = 20
NUM_SAMPLES = 100_000  # 100k samples


# Function to check if adding an edge introduces a cycle
def introduces_cycle(graph, start, end):
    """Return True if ``end`` can be reached from ``start`` in ``graph``.

    To test whether adding an edge ``u -> v`` would close a cycle, call
    ``introduces_cycle(graph, v, u)``: the new edge creates a cycle exactly
    when ``u`` is already reachable from ``v``.

    Args:
        graph: Mapping from a node to an iterable of its successors. Nodes
            without outgoing edges may be missing from the mapping.
        start: Node the search starts from.
        end: Node being searched for.

    Returns:
        True if a directed path from ``start`` to ``end`` exists (this
        includes the trivial case ``start == end``), otherwise False.
    """
    visited = set()
    stack = [start]

    # Iterative depth-first search.
    while stack:
        node = stack.pop()
        if node == end:  # A cycle is formed
            return True
        if node in visited:
            continue
        visited.add(node)
        # .get() rather than indexing: a plain dict no longer raises KeyError
        # for sink nodes, and a defaultdict is not silently grown with empty
        # entries as a side effect of a read-only query.
        stack.extend(graph.get(node, ()))
    return False


def generate_random_scm(
    num_variables: int,
    *,
    edge_probability: float = 0.3,
    dimensionality: int = 10,
) -> SCM:
    """Build an SCM with a random acyclic graph over ``num_variables`` nodes.

    Variables are named ``X0 .. X{num_variables - 1}``, are continuous and
    non-exogenous, and each receives a ``MechanismFamily.LINEAR`` mechanism.
    The SCM is configured with additive Gaussian noise.

    Args:
        num_variables: Number of variables to create.
        edge_probability: Chance of adding each candidate edge that would not
            close a cycle. Defaults to 0.3, the value previously hard-coded.
        dimensionality: ``dimensionality`` passed to every ``Variable``.
            Defaults to 10, the value previously hard-coded.

    Returns:
        The populated ``SCM`` instance.
    """
    variables = [
        Variable(
            name=f"X{i}",
            dimensionality=dimensionality,
            exogenous=False,
            variable_type=VariableDataType.CONTINUOUS,
        )
        for i in range(num_variables)
    ]

    scm = SCM(
        variables=variables,
        noise_distribution=NoiseDistribution.GAUSSIAN,
        noise_mode=NoiseMode.ADDITIVE,
        noise_args=[0, 1],  # mean=0, std=1
    )

    # Adjacency list (by variable name) mirroring the edges added to the SCM;
    # used only for the cycle check.
    graph = defaultdict(list)
    for from_var in variables:
        for to_var in variables:
            # The cycle check deliberately runs before the random draw so the
            # sequence of random.random() calls (and therefore the graph
            # produced for a given seed) is the same as before. Self-pairs are
            # rejected here as well, because the search starts at its target.
            if introduces_cycle(graph, to_var.name, from_var.name):
                continue
            if from_var != to_var and random.random() < edge_probability:
                scm.add_edge(from_var, to_var)
                graph[from_var.name].append(to_var.name)

    # Set mechanisms for variables
    for var in variables:
        scm.set_function(variable=var, mechanism_family=MechanismFamily.LINEAR)

    return scm


def sample_from_scm_new(scm_index):
    """Generate one random SCM and sample ``NUM_SAMPLES`` rows from it.

    Intended to run inside a worker process (see the pool-based caller).

    Args:
        scm_index: Index of the SCM being generated; also used to derive the
            random seed for this task.

    Returns:
        Tuple ``(scm_index, data)`` where ``data`` is whatever
        ``SCM.sample_data`` returns.
    """
    # Every pool worker begins with the same RNG state: a forked worker copies
    # the parent's state, and a spawned worker re-imports this module and
    # re-runs the module-level seeding. Without reseeding, tasks handled by
    # different workers would generate identical SCMs. Seeding from the task
    # index makes each SCM distinct and independent of worker scheduling.
    # The modulo keeps the value inside the range np.random.seed accepts.
    task_seed = (SEED + scm_index) % (2**32)
    random.seed(task_seed)
    np.random.seed(task_seed)
    torch.manual_seed(task_seed)

    scm = generate_random_scm(NUM_VARIABLES_PER_SCM)
    data = scm.sample_data(total_samples=NUM_SAMPLES, batch_size=NUM_SAMPLES)
    return scm_index, data


def sample_from_scm(scm_index):
    """Generate one random SCM and sample it via the step-by-step SCM API.

    Alternative to ``sample_from_scm_new`` that drives the SCM manually
    (reset values, sample noise, compute variables) instead of calling
    ``SCM.sample_data``.

    Args:
        scm_index: Index of the SCM being generated; also used to derive the
            random seed for this task.

    Returns:
        Tuple ``(scm_index, data)`` where ``data`` maps each key of
        ``scm.variables`` to that variable's ``value``.
    """
    # Every pool worker begins with the same RNG state: a forked worker copies
    # the parent's state, and a spawned worker re-imports this module and
    # re-runs the module-level seeding. Without reseeding, tasks handled by
    # different workers would generate identical SCMs. Seeding from the task
    # index makes each SCM distinct and independent of worker scheduling.
    # The modulo keeps the value inside the range np.random.seed accepts.
    task_seed = (SEED + scm_index) % (2**32)
    random.seed(task_seed)
    np.random.seed(task_seed)
    torch.manual_seed(task_seed)

    scm = generate_random_scm(NUM_VARIABLES_PER_SCM)
    scm.n_samples = NUM_SAMPLES
    # Sample data.
    scm.reset_values()
    scm.sample_noise_variables()
    scm.compute_variables()
    data = {var_id: var.value for var_id, var in scm.variables.items()}
    return scm_index, data


def sample_data_from_scm_list(num_scms: int):
    """Generate and sample ``num_scms`` random SCMs in a process pool.

    Args:
        num_scms: Number of SCMs to generate; tasks are indexed
            ``0 .. num_scms - 1``.

    Returns:
        Dict mapping each SCM index to the data returned for it by
        ``sample_from_scm_new``.
    """
    # Never start more workers than there are tasks (previously a full
    # cpu_count() pool was created even for a single SCM). Pool needs at
    # least one process, hence the lower bound for num_scms == 0.
    num_workers = max(1, min(mp.cpu_count(), num_scms))
    with mp.Pool(processes=num_workers) as pool:
        # sample_from_scm is a drop-in alternative worker function.
        results = pool.map(sample_from_scm_new, range(num_scms))
    # Each result is an (scm_index, data) pair.
    return dict(results)


def main():
    """Sample data for ``NUM_SCMS`` random SCMs and report progress on stdout."""
    # Generation runs in parallel; the result is a dictionary keyed by SCM
    # index, one entry of sampled data per SCM.
    data_all_scms = sample_data_from_scm_list(NUM_SCMS)
    print("Generated data")

    print("Saving dictionary of data")
    # NOTE: writing to disk is currently disabled, so nothing is persisted
    # after the message above. The intended approach was:
    #   data_all_scms = {str(k): v for k, v in data_all_scms.items()}
    #   np.savez_compressed("data_all_scms.npz", **data_all_scms)
    # and to read it back:
    #   np.load(filename, allow_pickle=True)


# Run only when executed as a script. The guard also matters for the process
# pool used above: worker processes may import this module, and must not
# start the whole job again when they do.
if __name__ == "__main__":
    main()
