import networkx as nx
import numpy as np
import random
import os
import json


class CausalGraphGenerator:
    """Flexible causal graph (DAG) generator with control of various network parameters

    Graphs are ``networkx.DiGraph`` objects whose nodes are the integers
    ``0 .. n-1``. All random draws go through the module-level ``random``
    generator, so results are reproducible only if nothing else consumes that
    stream between calls.
    """

    def __init__(self, seed=None):
        """Initialize generator with optional random seed

        Parameters:
        -----------
        seed : int, optional
            When not None, seeds both the global ``random`` module and
            ``numpy.random``. This is a process-wide side effect.
        """
        if seed is not None:
            np.random.seed(seed)
            random.seed(seed)

    def generate_custom_dag(self,
                            n,
                            expected_density=0.2,
                            max_parents=None,
                            max_children=None,
                            confounders_ratio=0.1,
                            colliders_ratio=0.1,
                            mediator_chains=0,
                            verbose=False):
        """Generate DAG with custom properties

        A random base graph is built by considering every ordered pair
        ``i < j`` and adding edge ``i -> j`` with a probability derived from
        ``expected_density``. Confounder, collider and mediator-chain
        structures are then layered on top.

        Parameters:
        -----------
        n : int
            Number of nodes.
        expected_density : float
            Target fraction of the ``n * (n - 1) / 2`` possible edges.
        max_parents, max_children : int, optional
            In-/out-degree caps applied while sampling the base edges. The
            structures added afterwards are not checked against these caps.
        confounders_ratio, colliders_ratio : float
            When > 0, the number of structures attempted is
            ``max(1, int(n * ratio))``.
        mediator_chains : int
            Number of mediator chains to attempt (lengths 2 to 4).
        verbose : bool
            Currently unused; kept for interface compatibility.

        Returns:
        --------
        networkx.DiGraph
            The generated graph. The added structures may remove or add
            edges, so the final density can differ from ``expected_density``.
        """
        # Calculate expected number of edges
        max_possible_edges = n * (n - 1) // 2
        expected_edges = int(expected_density * max_possible_edges)

        # Create empty directed graph
        G = nx.DiGraph()
        G.add_nodes_from(range(n))

        # Edge probability. Graphs with fewer than two nodes have no candidate
        # pairs, so guard the division instead of raising ZeroDivisionError.
        p = expected_edges / max_possible_edges if max_possible_edges else 0.0

        # Only edges i -> j with i < j are sampled, which keeps the base graph acyclic.
        for i in range(n):
            for j in range(i + 1, n):
                # Draw before checking the caps so the random stream consumed
                # per pair does not depend on the constraints.
                if random.random() >= p:
                    continue
                if max_parents is not None and G.in_degree(j) >= max_parents:
                    continue
                if max_children is not None and G.out_degree(i) >= max_children:
                    continue
                G.add_edge(i, j)

        # Add confounder structures
        if confounders_ratio > 0:
            self._add_confounders(G, confounders_ratio)

        # Add collider structures
        if colliders_ratio > 0:
            self._add_colliders(G, colliders_ratio)

        # Add mediator chains
        if mediator_chains > 0:
            self._add_mediator_chains(G, mediator_chains, min_length=2, max_length=4)

        return G

    def _add_confounders(self, G, ratio):
        """Add confounders to graph

        Each confounder is a node with edges to two other nodes; any direct
        edge between those two children is removed. ``G`` is modified in
        place and is expected to be a DAG on nodes ``0 .. n-1``.

        Returns the number of confounder structures actually added. Each
        structure gets up to 10 random attempts.
        """
        n = G.number_of_nodes()
        if n < 3:
            return 0

        num_confounders = max(1, int(n * ratio))
        confounders_added = 0

        for _ in range(num_confounders):
            for _attempt in range(10):
                # The lowest-numbered node of the triple acts as the confounder.
                confounder, child1, child2 = sorted(random.sample(range(n), 3))

                # A path back from either child would close a cycle
                # (has_path also covers the direct-edge case).
                if (nx.has_path(G, child1, confounder) or
                        nx.has_path(G, child2, confounder)):
                    continue

                if G.has_edge(child1, child2):
                    G.remove_edge(child1, child2)
                elif G.has_edge(child2, child1):
                    G.remove_edge(child2, child1)

                G.add_edge(confounder, child1)
                G.add_edge(confounder, child2)
                confounders_added += 1
                break

        return confounders_added

    def _add_colliders(self, G, ratio):
        """Add collider structures to graph

        Each collider is a node receiving edges from two other nodes; any
        direct edge between those two parents is removed. ``G`` is modified
        in place and is expected to be a DAG on nodes ``0 .. n-1``.

        Returns the number of collider structures actually added. Each
        structure gets up to 10 random attempts.
        """
        n = G.number_of_nodes()
        if n < 3:
            return 0

        num_colliders = max(1, int(n * ratio))
        colliders_added = 0

        for _ in range(num_colliders):
            for _attempt in range(10):
                parent1, parent2, collider = random.sample(range(n), 3)

                # A path from the collider back to either parent would close
                # a cycle (has_path also covers the direct-edge case).
                if (nx.has_path(G, collider, parent1) or
                        nx.has_path(G, collider, parent2)):
                    continue

                if G.has_edge(parent1, parent2):
                    G.remove_edge(parent1, parent2)
                elif G.has_edge(parent2, parent1):
                    G.remove_edge(parent2, parent1)

                G.add_edge(parent1, collider)
                G.add_edge(parent2, collider)
                colliders_added += 1
                break

        return colliders_added

    def _add_mediator_chains(self, G, num_chains, min_length=2, max_length=4):
        """Add mediator chains to graph

        A chain of length ``L`` is a directed path ``start -> m1 -> ... -> end``
        with ``L`` edges and ``L - 1`` mediators; a direct ``start -> end``
        edge is removed when the chain is committed. ``G`` is modified in
        place and is expected to be a DAG on nodes ``0 .. n-1``.

        Returns the number of chains actually added. Each chain gets up to
        10 random attempts, and a rejected attempt leaves ``G`` untouched.
        """
        n = G.number_of_nodes()
        chains_added = 0

        for _ in range(num_chains):
            chain_length = random.randint(min_length, max_length)

            # A chain with this many edges needs chain_length + 1 distinct
            # nodes. Skip only this chain; a later draw may be short enough.
            if n < chain_length + 1:
                continue

            for _attempt in range(10):
                start, end = random.sample(range(n), 2)

                # Cheap early reject: end already reaches start.
                if nx.has_path(G, end, start):
                    continue

                available_nodes = [node for node in range(n) if node != start and node != end]
                mediators = random.sample(available_nodes, chain_length - 1)
                chain = [start] + mediators + [end]
                chain_edges = list(zip(chain, chain[1:]))

                # Validate the chain as a whole on a scratch copy. Checking
                # each edge separately misses cycles that only appear once
                # several chain edges exist together (e.g. an existing
                # m2 -> start path combined with start -> m1 -> m2).
                trial = G.copy()
                if trial.has_edge(start, end):
                    trial.remove_edge(start, end)
                trial.add_edges_from(chain_edges)
                if not nx.is_directed_acyclic_graph(trial):
                    continue

                # Commit only after validation succeeded.
                if G.has_edge(start, end):
                    G.remove_edge(start, end)
                G.add_edges_from(chain_edges)

                chains_added += 1
                break

        return chains_added


def batch_generate_custom_dags(num_graphs=10000, seed=None,
                               min_nodes=3, max_nodes=10,
                               expected_density=0.3,
                               max_parents_ratio=0.5,
                               max_children_ratio=0.6,
                               confounders_ratio=0.2,
                               colliders_ratio=0.2,
                               mediator_chains_ratio=0.2):
    """Generate a batch of custom DAGs that differ only in their node count

    Parameters:
    -----------
    num_graphs : int
        How many graphs to produce
    seed : int, optional
        Seed applied to ``random`` and ``numpy.random`` for reproducibility
    min_nodes, max_nodes : int
        Inclusive bounds for the node count drawn for each graph
    expected_density : float
        Edge density passed unchanged to every graph
    max_parents_ratio, max_children_ratio : float
        Parent/child caps as a fraction of the node count (at least 1)
    confounders_ratio, colliders_ratio : float
        Structural ratios passed unchanged to every graph
    mediator_chains_ratio : float
        Number of mediator chains as a fraction of the node count

    Returns:
    --------
    list of tuple
        One ``(graph, params, adjacency_matrix)`` tuple per graph, where
        ``params`` is the dict of arguments used for that graph and
        ``adjacency_matrix`` is an integer NumPy array.
    """
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)

    dag_factory = CausalGraphGenerator(seed=seed)
    batch = []

    for _ in range(num_graphs):
        # Only the node count varies from graph to graph
        node_count = random.randint(min_nodes, max_nodes)

        # Degree caps and chain count scale with the graph size
        settings = {
            "n": node_count,
            "expected_density": expected_density,
            "max_parents": max(1, int(node_count * max_parents_ratio)),
            "max_children": max(1, int(node_count * max_children_ratio)),
            "confounders_ratio": confounders_ratio,
            "colliders_ratio": colliders_ratio,
            "mediator_chains": max(0, int(node_count * mediator_chains_ratio))
        }

        # The dict keys match the generator's keyword names exactly
        dag = dag_factory.generate_custom_dag(verbose=False, **settings)
        adjacency = nx.to_numpy_array(dag, dtype=int)

        batch.append((dag, settings, adjacency))

    return batch


def save_results(results, output_dir="dag_data"):
    """Save generated DAG data to files

    Writes into ``output_dir`` (created if missing):

    * ``graph_<idx>.json`` - index, parameters and adjacency matrix (as
      nested lists) for each graph
    * ``all_graph_params.json`` - index and parameters of every graph
    * ``all_adjacency_matrices.npz`` - all adjacency matrices under the key
      ``matrices``, always as a 1-D object array with one matrix per entry.
      Object arrays are stored pickled, so reading the file back needs
      ``np.load(..., allow_pickle=True)``.

    Parameters:
    -----------
    results : iterable of tuple
        ``(graph, params, adjacency_matrix)`` tuples. The graph element is
        not used here; ``params`` must be JSON-serializable and
        ``adjacency_matrix`` must provide ``tolist()``.
    output_dir : str
        Target directory
    """
    os.makedirs(output_dir, exist_ok=True)

    # Collected in a single pass so ``results`` may be any iterable
    all_params = []
    adjacency_matrices = []

    # Save each graph
    for idx, (_graph, params, adj_matrix) in enumerate(results):
        graph_data = {
            "idx": idx,
            "params": params,
            "adjacency_matrix": adj_matrix.tolist()
        }

        # Save individual graph as JSON
        single_graph_file = os.path.join(output_dir, f"graph_{idx}.json")
        with open(single_graph_file, 'w', encoding='utf-8') as f:
            json.dump(graph_data, f, indent=2, ensure_ascii=False)

        # Add to overall parameters list
        all_params.append({
            "idx": idx,
            "params": params
        })
        adjacency_matrices.append(adj_matrix)

    # Save all graph parameters to one JSON file
    params_file = os.path.join(output_dir, "all_graph_params.json")
    with open(params_file, 'w', encoding='utf-8') as f:
        json.dump(all_params, f, indent=2, ensure_ascii=False)

    # Fill a pre-allocated 1-D object array element by element.
    # np.array(list_of_matrices, dtype=object) would collapse into a 3-D
    # object array whenever all matrices share a shape, making the saved
    # layout depend on the data.
    matrices = np.empty(len(adjacency_matrices), dtype=object)
    for idx, adj_matrix in enumerate(adjacency_matrices):
        matrices[idx] = adj_matrix

    np.savez_compressed(
        os.path.join(output_dir, "all_adjacency_matrices.npz"),
        matrices=matrices
    )

    print(f"Saved all data to {output_dir} directory")


def main():
    """Generate the default batch of 10000 DAGs and write it to ``dag_data``"""
    # Fixed seed so repeated runs produce the same batch
    seed_value = 42
    target_dir = "dag_data"

    # Every parameter is fixed except the node count, drawn per graph from 3..10
    dags = batch_generate_custom_dags(
        num_graphs=10000,
        seed=seed_value,
        min_nodes=3,
        max_nodes=10,
        expected_density=0.3,
        max_parents_ratio=0.5,
        max_children_ratio=0.6,
        confounders_ratio=0.2,
        colliders_ratio=0.2,
        mediator_chains_ratio=0.2
    )

    # Persist graphs, parameters and adjacency matrices
    save_results(dags, output_dir=target_dir)


# Run the generation pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()