import argparse
import os
import numpy as np

def create_prior_from_ground_truth(dataset_path: str, prior_strength: float, force_overwrite: bool = False,
                                   *, seed: "int | None" = None):
    """
    Generate a prior-knowledge edge mask from a dataset's ground-truth graph.

    Loads 'adj_matrix.npy' from ``dataset_path``, collapses it to a summary graph
    (entry (i, j) is 1 if any slice along the first axis has a non-zero entry at
    (i, j)), randomly keeps a fraction of the summary edges, and saves the result
    as a float32 0/1 matrix named 'exist_mask_<prior_strength>.npy' in the same
    directory.

    Problems with the dataset (missing or unreadable input, unsupported array
    rank, existing output without ``force_overwrite``, failed save) are reported
    on stdout and the function returns without raising, so batch runs can
    continue with the next dataset.

    Args:
        dataset_path (str): The path to the dataset directory, which must contain 'adj_matrix.npy'.
        prior_strength (float): The fraction of true edges to include in the prior mask,
            in the range [0, 1] (e.g., 0.2 for 20%).
        force_overwrite (bool): If True, overwrite an existing mask file for this
            ``prior_strength``. Defaults to False.
        seed (int | None): Optional seed for the edge selection. When None (the default),
            NumPy's global random state is used, exactly as before this parameter existed.

    Raises:
        ValueError: If ``prior_strength`` is not within [0, 1].
    """
    # A negative fraction would turn into a negative slice bound below and silently
    # select "all but the last k" edges, so reject bad values up front (NaN included).
    if not 0.0 <= prior_strength <= 1.0:
        raise ValueError(f"prior_strength must be in the range [0, 1], got {prior_strength!r}")

    print(f"Processing dataset directory: {dataset_path}")

    # 1. Build the input and output paths; the output name encodes the strength.
    adj_matrix_path = os.path.join(dataset_path, "adj_matrix.npy")
    mask_filename = f"exist_mask_{prior_strength}.npy"
    exist_mask_path = os.path.join(dataset_path, mask_filename)

    # 2. The ground-truth file must exist.
    if not os.path.exists(adj_matrix_path):
        print(f"[ERROR] 'adj_matrix.npy' not found in {dataset_path}. Skipping.")
        return

    # 3. Do not clobber an existing mask unless explicitly asked to.
    if os.path.exists(exist_mask_path) and not force_overwrite:
        print(f"[INFO] '{mask_filename}' already exists in {dataset_path}. Use --force to overwrite. Skipping.")
        return

    # 4. Load the ground-truth graph.
    try:
        temporal_graph = np.load(adj_matrix_path)
    except (OSError, ValueError, EOFError) as e:
        print(f"[ERROR] Failed to load 'adj_matrix.npy': {e}")
        return
    print(f"Successfully loaded 'adj_matrix.npy' with shape: {temporal_graph.shape}")

    # 5. Collapse a 3-D array over its first axis into a 2-D summary graph.
    #    A 2-D array is treated as already being the summary graph.
    print("Creating summary graph from temporal graph...")
    if temporal_graph.ndim == 3:
        summary_graph = np.any(temporal_graph, axis=0).astype(int)
    elif temporal_graph.ndim == 2:
        summary_graph = temporal_graph.astype(bool).astype(int)
    else:
        print(f"[ERROR] Expected a 2-D or 3-D adjacency array, got shape {temporal_graph.shape}. Skipping.")
        return

    # 6. Collect the (row, col) coordinates of every true edge.
    true_edges = np.argwhere(summary_graph == 1)
    num_true_edges = len(true_edges)
    if num_true_edges == 0:
        print(f"[WARNING] The summary graph has no edges. An all-zero '{mask_filename}' will be created.")
    else:
        print(f"Found {num_true_edges} total edges in the summary graph.")

    # 7. Shuffle the edge list in place (rows are permuted along the first axis).
    if seed is None:
        np.random.shuffle(true_edges)
    else:
        np.random.RandomState(seed).shuffle(true_edges)

    # 8. Number of edges to keep: floor of the requested fraction. The small
    #    epsilon guards against float error (e.g. 100 * 0.29 == 28.999999999999996).
    num_prior_edges = min(num_true_edges, int(num_true_edges * prior_strength + 1e-9))
    print(f"Based on prior_strength={prior_strength}, selecting {num_prior_edges} edges for the prior.")

    # 9. Take the first N shuffled edges as the prior.
    prior_edges = true_edges[:num_prior_edges]

    # 10. Build the mask: zeros everywhere, 1 at each selected edge.
    exist_mask = np.zeros_like(summary_graph, dtype=np.float32)
    if num_prior_edges > 0:
        exist_mask[prior_edges[:, 0], prior_edges[:, 1]] = 1.0

    # 11. Save the mask next to the ground-truth file.
    try:
        np.save(exist_mask_path, exist_mask)
    except OSError as e:
        print(f"[ERROR] Failed to save '{mask_filename}': {e}")
        return
    print(f"Successfully saved '{mask_filename}' to {exist_mask_path}")


if __name__ == "__main__":
    # Command-line entry point: process one dataset directory, or with --recursive
    # every directory under --path that contains 'adj_matrix.npy'.
    parser = argparse.ArgumentParser(
        description=(
            "Generate a prior knowledge mask ('exist_mask_<strength>.npy') "
            "from a ground truth graph ('adj_matrix.npy')."
        )
    )
    parser.add_argument(
        "--path",
        "-p",
        type=str,
        required=True,
        help="Path to the dataset directory. With --recursive, this can be a parent directory containing multiple datasets."
    )
    parser.add_argument(
        "--strength",
        "-s",
        type=float,
        default=0.2,
        # argparse %-formats help strings, so a literal percent sign must be
        # written as '%%'; a bare '%' makes --help fail with a ValueError.
        help="The fraction of true edges to use as prior knowledge, between 0 and 1 (e.g., 0.2 for 20%%). Default is %(default)s."
    )
    parser.add_argument(
        "--recursive",
        "-r",
        action='store_true',
        help="If the provided path is a directory, recursively find all subdirectories containing 'adj_matrix.npy' and process them."
    )
    parser.add_argument(
        "--force",
        "-f",
        action='store_true',
        help="Force overwrite of existing 'exist_mask_<strength>.npy' files."
    )

    args = parser.parse_args()

    # Reject fractions outside [0, 1] before touching any dataset; a negative
    # value would otherwise select the wrong set of edges.
    if not 0.0 <= args.strength <= 1.0:
        parser.error(f"--strength must be between 0 and 1 (got {args.strength}).")

    if args.recursive:
        print(f"Recursively searching for datasets in: {args.path}")
        num_datasets = 0
        for root, _dirs, files in os.walk(args.path):
            if 'adj_matrix.npy' in files:
                num_datasets += 1
                create_prior_from_ground_truth(root, args.strength, args.force)
        if num_datasets == 0:
            # os.walk yields nothing for a missing path, so say so explicitly.
            print(f"[WARNING] No directory containing 'adj_matrix.npy' was found under {args.path}.")
    else:
        create_prior_from_ground_truth(args.path, args.strength, args.force)

    print("\nProcess finished.")

# Example usage:
#   python create_prior.py --path data/ER_ER_lag_2_dim_20_HistDep_0.5_mlp_spline_product_con_2_inthist_3_seed_0 --strength 0.5
