#!/usr/bin/env python
# coding: utf-8

import numpy as np
import networkx as nx
import argparse
import os
import pickle

def generate_random_dag(n, p, graph_type="er"):
    """Generates a random Directed Acyclic Graph (DAG) on nodes 0..n-1.

    Args:
        n (int): Number of nodes.
        p (float): For ER graphs, the probability that each ordered pair (i, j),
                   i != j, receives an edge before cycles are broken. For SF
                   graphs, interpreted as m, the number of edges attached from
                   each new node in the Barabási–Albert process (truncated to
                   int and clamped to at least 1).
        graph_type (str): "er" for Erdős–Rényi, "sf" for Scale-Free/Barabási–Albert.

    Returns:
        networkx.DiGraph: A random DAG.

    Raises:
        ValueError: If ``graph_type`` is neither "er" nor "sf".
    """
    if graph_type == "er":
        # Directed G(n, p). This generally contains cycles, which are broken
        # below, so the final edge density is lower than p alone suggests.
        G = nx.DiGraph(nx.gnp_random_graph(n=n, p=p, directed=True))
    elif graph_type == "sf":
        # Note: The guide mentions m=1-3 for SF. p is reused as m here for simplicity in args.
        m = max(1, int(p))  # Interpret p as m for SF, ensure m >= 1
        G = nx.DiGraph()
        G.add_nodes_from(range(n))
        if n < 2:
            # Barabási–Albert requires 1 <= m < n; with fewer than two nodes
            # there are no edges to create, so return the isolated node(s).
            return G
        if m >= n:
            print(f"Warning: m ({m}) >= n ({n}) for SF graph. Setting m = {max(1, n // 2)}.")
            m = max(1, n // 2)
        G_undir = nx.barabasi_albert_graph(n=n, m=m)
        # Orient every edge from the lower to the higher index: index order is
        # then a topological order, so the result is acyclic by construction.
        G.add_edges_from((min(i, j), max(i, j)) for i, j in G_undir.edges())
    else:
        raise ValueError(f"Unknown graph type: {graph_type}")

    # Break remaining cycles (only the ER branch can produce them) by deleting
    # one edge of a cycle at a time. nx.find_cycle locates a single cycle in
    # O(V + E); enumerating all simple cycles instead is exponential on dense
    # directed graphs. Removing edges may disconnect the graph.
    while not nx.is_directed_acyclic_graph(G):
        u, v = nx.find_cycle(G)[0][:2]
        G.remove_edge(u, v)

    # Connectivity is not enforced; callers needing a connected DAG must add
    # edges between components themselves (respecting acyclicity).
    return G

def generate_weights(G):
    """Draw a random nonzero-magnitude weight for every edge of ``G``.

    For each edge a fair coin picks the sign, then the magnitude is drawn
    uniformly so that weights lie in [-2.0, -0.5] U [0.5, 2.0]. Sampling uses
    NumPy's global RNG, one coin flip followed by one uniform draw per edge,
    in ``G.edges()`` order.

    Args:
        G (networkx.DiGraph): The DAG whose edges need weights.

    Returns:
        dict: Mapping ``(u, v)`` edge tuples to float weights.
    """
    def _draw_weight():
        # Coin flip first, then the uniform draw from the chosen half-interval.
        is_positive = np.random.random() < 0.5
        low, high = (0.5, 2.0) if is_positive else (-2.0, -0.5)
        return np.random.uniform(low, high)

    return {(src, dst): _draw_weight() for src, dst in G.edges()}

def generate_scm_data(G, weights, n_samples, scm_type="linear", noise_scale=0.1, intervention_details=None):
    """Generates data from an additive-noise Structural Causal Model (SCM).

    Nodes are visited in topological order. A root node is set to pure noise.
    Every other node is set to f(sum_p w[(p, node)] * X[:, p]) + noise, where
    f is the identity ("linear") or tanh ("nonlinear"). Noise is drawn i.i.d.
    from N(0, noise_scale**2) for each node and sample. Under an intervention,
    the intervened node is clamped to a constant for every sample; its parents
    and noise are ignored, and its descendants see the clamped value.

    Args:
        G (networkx.DiGraph): The causal DAG. Node labels are used directly as
            column indices of the output, so they are assumed to be 0..n-1.
        weights (dict): Edge weights keyed by ``(parent, child)``.
        n_samples (int): Number of samples to generate.
        scm_type (str): Type of SCM ("linear" or "nonlinear").
        noise_scale (float): Standard deviation of the Gaussian noise.
        intervention_details (dict, optional): If provided, generates interventional data.
                                             Example: {"node": int, "value": float}.

    Returns:
        np.ndarray: Generated data matrix of shape (n_samples, n_nodes).

    Raises:
        ValueError: If ``scm_type`` is unknown, or the intervened node is not in ``G``.
    """
    # Validate before sampling so an unknown type is rejected even for graphs
    # without edges, where no node would ever reach the SCM function below.
    if scm_type not in ("linear", "nonlinear"):
        raise ValueError(f"Unknown SCM type: {scm_type}")

    n = G.number_of_nodes()
    X = np.zeros((n_samples, n))

    if intervention_details:
        intervened_node = intervention_details["node"]
        intervention_value = intervention_details["value"]
        # Without this check, an unknown node would silently yield observational data.
        if intervened_node not in G:
            raise ValueError(f"Intervention node {intervened_node} is not in the graph")
    else:
        intervened_node = None
        intervention_value = None

    for node in nx.topological_sort(G):
        if node == intervened_node:
            # Hard intervention: constant value, no noise drawn for this node.
            X[:, node] = intervention_value
            continue

        parents = list(G.predecessors(node))
        noise = np.random.normal(0, noise_scale, n_samples)

        if not parents:
            # Root node (no parents): pure noise.
            X[:, node] = noise
            continue

        parent_contribution = sum(weights[(p, node)] * X[:, p] for p in parents)
        if scm_type == "linear":
            X[:, node] = parent_contribution + noise
        else:  # "nonlinear" (validated above): tanh of the weighted parent sum
            X[:, node] = np.tanh(parent_contribution) + noise

    return X

def main(args):
    """Generate every requested graph and write its artifacts to disk.

    For each of ``args.num_graphs`` iterations (files numbered from 1), this:
      * samples a DAG and its edge weights and pickles them as
        ``graph_<k>.pkl`` (a dict with keys "graph" and "weights");
      * saves an observational sample matrix as ``obs_data_<k>.npy``;
      * if ``args.num_samples_interv`` is positive, runs one hard intervention
        per node with a value drawn uniformly from [-2, 2] and pickles the list
        of records as ``interv_data_<k>.pkl``.

    Args:
        args (argparse.Namespace): Parsed command-line options.
    """
    print(f"Generating synthetic data with config: {args}")
    out_dir = args.output_dir
    os.makedirs(out_dir, exist_ok=True)

    def _dump(obj, path):
        # Binary pickle of obj at path; the file handle is closed on exit.
        with open(path, "wb") as handle:
            pickle.dump(obj, handle)

    total = args.num_graphs
    for idx in range(1, total + 1):
        print(f"\nGenerating graph {idx}/{total}...")
        dag = generate_random_dag(args.num_nodes, args.edge_prob, args.graph_type)
        edge_weights = generate_weights(dag)

        graph_path = os.path.join(out_dir, f"graph_{idx}.pkl")
        _dump({"graph": dag, "weights": edge_weights}, graph_path)
        print(f"Saved graph structure to {graph_path}")

        print("Generating observational data...")
        observational = generate_scm_data(
            dag, edge_weights, args.num_samples_obs, args.scm_type, args.noise_scale
        )
        obs_path = os.path.join(out_dir, f"obs_data_{idx}.npy")
        np.save(obs_path, observational)
        print(f"Saved observational data to {obs_path}")

        if args.num_samples_interv <= 0:
            continue  # interventional data disabled

        print("Generating interventional data...")
        records = []
        for target in range(args.num_nodes):
            # Fixed range, independent of the observed data's scale.
            value = np.random.uniform(-2, 2)
            samples = generate_scm_data(
                dag,
                edge_weights,
                args.num_samples_interv,
                args.scm_type,
                args.noise_scale,
                intervention_details={"node": target, "value": value},
            )
            records.append(
                {"intervened_node": target, "intervened_value": value, "data": samples}
            )

        interv_path = os.path.join(out_dir, f"interv_data_{idx}.pkl")
        _dump(records, interv_path)
        print(f"Saved interventional data to {interv_path}")

    print("\nSynthetic data generation complete.")

if __name__ == "__main__":
    import random  # only needed here, to seed the stdlib RNG used by networkx

    parser = argparse.ArgumentParser(description="Generate Synthetic Causal Data")
    # NOTE: the default output directory is an absolute, machine-specific path;
    # pass --output_dir on other machines.
    parser.add_argument("--output_dir", type=str, default="/home/ubuntu/ecam_project/data/synthetic", 
                        help="Directory to save generated data")
    parser.add_argument("--num_graphs", type=int, default=10, 
                        help="Number of different graph structures to generate")
    parser.add_argument("--num_nodes", type=int, default=20, 
                        help="Number of nodes in each graph")
    parser.add_argument("--graph_type", type=str, default="er", choices=["er", "sf"],
                        help="Type of random graph to generate (er or sf)")
    parser.add_argument("--edge_prob", type=float, default=0.2,
                        help="Edge probability for ER graphs, or m for SF graphs")
    parser.add_argument("--scm_type", type=str, default="linear", choices=["linear", "nonlinear"],
                        help="Type of Structural Causal Model (linear or nonlinear)")
    parser.add_argument("--noise_scale", type=float, default=0.5,
                        help="Scale (std dev) of the additive Gaussian noise")
    parser.add_argument("--num_samples_obs", type=int, default=5000,
                        help="Number of observational samples per graph")
    parser.add_argument("--num_samples_interv", type=int, default=100,
                        help="Number of interventional samples per node per graph (0 to disable)")
    parser.add_argument("--seed", type=int, default=None,
                        help="Random seed for reproducible generation (default: unseeded)")

    args = parser.parse_args()
    if args.seed is not None:
        # Graph structure comes from networkx generators called without a seed,
        # which draw from the stdlib `random` module; weights, noise and
        # intervention values come from NumPy's global RNG. Seed both.
        random.seed(args.seed)
        np.random.seed(args.seed)
    main(args)

