import os
import re
import pickle
import pandas as pd
import numpy as np
from tqdm import tqdm # Import tqdm for progress bar
import contextlib
import argparse

# Assuming these imports exist and are correct
from dynotears import from_pandas_dynamic as from_pandas_dynamic
from dynotears_and import from_pandas_dynamic as from_pandas_dynamic_and
from dynotears_and_init import from_pandas_dynamic as from_pandas_dynamic_and_initdata
from dynotears_multiply import from_pandas_dynamic as from_pandas_dynamic_multiply
from dynotears_multiply_init import from_pandas_dynamic as from_pandas_dynamic_multiply_initdata

from utils import get_solve_matrix, get_gt_matrix, get_result_matrix, read_markdown_table

def save_weights_to_csv(w_matrix, a_matrix, p_orders, d_vars, filename):
    """
    Stack the W and A matrices into one table and write it as CSV.

    Rows of ``w_matrix`` come first and are tagged with lag 0. Rows of
    ``a_matrix`` follow and are tagged with lags ``1..p_orders`` in
    consecutive groups of ``d_vars`` rows. Every row carries the identifier
    columns ``matrix_type``, ``lag`` and ``source_node`` followed by one
    ``dest_<j>`` weight column per variable.

    Args:
        w_matrix (np.ndarray): The estimated W matrix.
        a_matrix (np.ndarray): The estimated A matrix.
        p_orders (int): The number of lags (p).
        d_vars (int): The number of variables.
        filename (str): The output CSV filename.
    """
    dest_columns = [f'dest_{col}' for col in range(d_vars)]

    def _tagged(matrix, kind, lag_ids, source_ids):
        # Identifier columns are inserted at position 0 in reverse order so
        # they end up as matrix_type, lag, source_node ahead of the weights.
        frame = pd.DataFrame(matrix, columns=dest_columns)
        frame.insert(0, 'source_node', source_ids)
        frame.insert(0, 'lag', lag_ids)
        frame.insert(0, 'matrix_type', kind)
        return frame

    w_rows = _tagged(w_matrix, 'W', 0, range(d_vars))
    a_rows = _tagged(
        a_matrix,
        'A',
        np.repeat(np.arange(1, p_orders + 1), d_vars),
        np.tile(np.arange(d_vars), p_orders),
    )

    stacked = pd.concat([w_rows, a_rows], ignore_index=True)
    stacked.to_csv(filename, index=False, encoding='utf-8-sig')
    # Only the base name is printed to keep the log tidy.
    print(f"  Weights saved to '{os.path.basename(filename)}'")

def get_result_matrix_df(result_matrix):
    """
    Wrap a result matrix in a DataFrame with named rows and metric columns.

    The first two rows are labelled ``timeseries`` and ``total``; every
    further row is labelled ``timeseries_<k>`` counting from 0. The matrix
    must have exactly eight columns, one per metric name below.
    """
    metric_names = [
        "accuracy",
        "recall",
        "f1",
        "shd",
        "edge_loss",
        "edge_recovery",
        "weak_correct_edge_num",
        "weak_edge_num",
    ]
    extra_rows = result_matrix.shape[0] - 2
    row_labels = ["timeseries", "total"]
    row_labels.extend(f"timeseries_{k}" for k in range(extra_rows))
    return pd.DataFrame(result_matrix, index=row_labels, columns=metric_names)

def _load_result_matrix(result_file, label, run_index):
    """
    Load the numeric part of one per-run markdown result table.

    Args:
        result_file: Path of the markdown table to read.
        label: "Baseline" or "Constrained"; only used in warning messages.
        run_index: Index of the run, only used in warning messages.

    Returns:
        The table without its first column as a float array, or None when the
        file does not exist or cannot be used (a warning is printed for
        unusable files, a missing file is skipped silently).
    """
    if not os.path.exists(result_file):
        return None
    try:
        table = read_markdown_table(result_file)
        if table.shape[1] <= 1:
            print(f"  Warning: {label} file {result_file} has too few columns. Skipping run {run_index}.")
            return None
        # The first column is assumed to hold the row labels and is dropped;
        # the rest is cast to float so matrices can be subtracted and averaged.
        return table.to_numpy()[:, 1:].astype(float)
    except pd.errors.EmptyDataError:
        print(f"  Warning: {label} file {result_file} is empty. Skipping run {run_index}.")
    except Exception as e:
        print(f"  Warning: Could not read/process {label.lower()} {result_file}: {e}. Skipping run {run_index}.")
    return None


def _save_average_table(matrices, output_file, what, exp_dir):
    """Average a list of equally shaped matrices and save the result as markdown."""
    if not matrices:
        print(f"    Warning: No valid runs found to calculate {what} for {exp_dir}.")
        return
    print(f"    Calculating average {what} from {len(matrices)} valid runs for {exp_dir}...")
    try:
        # Stacking happens inside the try block: matrices of inconsistent
        # shape must not abort the merge of the remaining experiments.
        averaged = np.mean(np.array(matrices), axis=0)
        get_result_matrix_df(averaged).to_markdown(output_file)
        print(f"    Saved average {what} to {output_file}")
    except Exception as e:
        print(f"  Error: Failed to convert/save average {what} for {exp_dir}: {e}")


def merge_results(data_dir: str, repeat_num: int = 6):
    """
    Average the per-run results of every 'exist_edges_prob_X' experiment.

    For each experiment subdirectory two files are written:

    * ``result_constrained_mean.md``: the mean of all readable
      ``result_baseline_<i>.md`` matrices (despite the file name, this is the
      baseline average; a run counts even if its constrained file is missing).
    * ``result_constrained_diff.md``: the mean of
      ``result_constrained_<i> - result_baseline_<i>`` over the runs for which
      both files could be read and have the same shape.

    Problems are reported on stdout; the function never raises for missing or
    unreadable files.

    Args:
        data_dir: Path to the directory containing the 'exist_edges_prob_X'
                  subdirectories with baseline and constrained results.
        repeat_num: Number of repeated constrained/baseline runs.
    """
    try:
        exps_dirs = sorted(
            d for d in os.listdir(data_dir)
            if os.path.isdir(os.path.join(data_dir, d)) and d.startswith('exist_edges_prob_')
        )
    except FileNotFoundError:
        print(f"  Error: Data directory not found: {data_dir}")
        return
    except Exception as e:
        print(f"  Error listing directories in {data_dir}: {e}")
        return
    if not exps_dirs:
        print(f"  Info: No 'exist_edges_prob_' subdirectories found in {data_dir} for merging.")
        return

    print(f"Merging results (calculating diffs) in {os.path.basename(data_dir)}...")
    for exp_dir in exps_dirs:
        full_exp_path = os.path.join(data_dir, exp_dir)
        print(f"  Processing experiment: {exp_dir}")

        baseline_matrices = []  # one entry per readable baseline run
        diff_matrices = []      # constrained - baseline, one entry per complete run

        for i in range(repeat_num):
            baseline_file = os.path.join(full_exp_path, f"result_baseline_{i}.md")
            baseline_matrix = _load_result_matrix(baseline_file, "Baseline", i)
            if baseline_matrix is None:
                continue
            baseline_matrices.append(baseline_matrix)

            constrained_file = os.path.join(full_exp_path, f"result_constrained_{i}.md")
            constrained_matrix = _load_result_matrix(constrained_file, "Constrained", i)
            if constrained_matrix is None:
                continue

            if baseline_matrix.shape == constrained_matrix.shape:
                diff_matrices.append(constrained_matrix - baseline_matrix)
            else:
                print(f"  Warning: Shape mismatch for run {i}. Baseline: {baseline_matrix.shape}, Constrained: {constrained_matrix.shape}. Skipping.")

        _save_average_table(
            baseline_matrices,
            os.path.join(full_exp_path, "result_constrained_mean.md"),
            "mean",
            exp_dir,
        )
        _save_average_table(
            diff_matrices,
            os.path.join(full_exp_path, "result_constrained_diff.md"),
            "difference",
            exp_dir,
        )

    print("Merging process finished.")

def _read_table_with_prob(table_path, prob_value):
    """Read one markdown table and prepend an 'edge_prior_prob' column; None if unavailable."""
    if not os.path.exists(table_path):
        # The merge step writes nothing when no valid runs were found.
        print(f"  Warning: Result file not found, skipping: {table_path}")
        return None
    table = read_markdown_table(table_path)
    if table.empty:
        return None
    table.insert(0, 'edge_prior_prob', prob_value)
    return table


def _save_consolidated(frames, data_dir, prefixed_name):
    """Concatenate frames and write them to data_dir/prefixed_name as CSV or markdown."""
    if not frames:
        # pd.concat raises ValueError on an empty list, so bail out explicitly.
        print(f"  Warning: No results available for '{prefixed_name}'. File not saved.")
        return
    combined = pd.concat(frames, ignore_index=True)
    output_path = os.path.join(data_dir, prefixed_name)
    output_format = prefixed_name.split('.')[-1].lower()
    if output_format == "csv":
        combined.to_csv(output_path, index=False, float_format='%.4f')
        print(f"  Consolidation successful. Saved to CSV: {prefixed_name}")
    elif output_format == "md":
        # For Markdown output, the default pandas row index is usually not desired
        combined.to_markdown(output_path, index=False, floatfmt=".4f")
        print(f"  Consolidation successful. Saved to Markdown: {prefixed_name}")
    else:
        print(f"  Warning: Unsupported output file format '{prefixed_name}'. Please use .csv or .md. File not saved.")


def consolidate_results(data_dir: str, output_filename: str = "results_consolidated.md"):
    """
    Consolidate the averaged results of all probability subdirectories.

    For every 'exist_edges_prob_X' subdirectory of ``data_dir`` the files
    ``result_constrained_mean.md`` and ``result_constrained_diff.md`` are read,
    tagged with the probability parsed from the directory name, and collected.
    The mean tables are written to ``base_<output_filename>`` and the diff
    tables to ``all_<output_filename>`` inside ``data_dir``. An output whose
    input tables are all missing or empty is skipped with a warning.

    Args:
        data_dir: Directory for a single experimental repeat (e.g., '.../repeat0').
        output_filename: The name for the consolidated output file (e.g., 'results_consolidated.md' or '.csv').
    """
    print(f"  Starting consolidation for directory: {os.path.basename(data_dir)}")
    base_results_dfs = []  # averaged baseline tables, one per probability
    all_results_dfs = []   # averaged difference tables, one per probability

    exps_dirs = sorted(
        d for d in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, d)) and d.startswith('exist_edges_prob_')
    )
    if not exps_dirs:
        print(f"  Info: No 'exist_edges_prob_' subdirectories found in {os.path.basename(data_dir)} for consolidation.")

    for exp_dir in exps_dirs:
        # Parse probability from directory name
        match = re.search(r'exist_edges_prob_(\d+(\.\d+)?)', exp_dir)
        if not match:
            continue
        prob_value = float(match.group(1))
        exp_path = os.path.join(data_dir, exp_dir)

        baseline_df = _read_table_with_prob(
            os.path.join(exp_path, "result_constrained_mean.md"), prob_value
        )
        if baseline_df is not None:
            base_results_dfs.append(baseline_df)

        diff_df = _read_table_with_prob(
            os.path.join(exp_path, "result_constrained_diff.md"), prob_value
        )
        if diff_df is not None:
            all_results_dfs.append(diff_df)

    _save_consolidated(base_results_dfs, data_dir, f'base_{output_filename}')
    _save_consolidated(all_results_dfs, data_dir, f'all_{output_filename}')
        
def _write_final_table(table, output_dir, prefix, final_output_filename):
    """Write one aggregated table as CSV or markdown and return the path used."""
    output_path = os.path.join(output_dir, f'{prefix}_{final_output_filename}')
    output_format = final_output_filename.split('.')[-1].lower()
    if output_format == "csv":
        table.to_csv(output_path, index=False, float_format='%.4f')
    elif output_format == "md":
        table.to_markdown(output_path, index=False, floatfmt=".4f")
    else:
        print(f"Warning: Unsupported final output format '{output_format}'. Saving as CSV.")
        # Defaulting to CSV under the standard summary name
        output_path = os.path.join(output_dir, f"{prefix}_final_aggregated_summary.csv")
        table.to_csv(output_path, index=False, float_format='%.4f')
    return output_path


def aggregate_all_consolidated(
    data_path: str,
    data_info: str,
    name: str,
    consolidated_filename: str,
    final_output_filename: str = "final_aggregated_summary.csv"
):
    """
    Aggregate consolidated files from all repeatX directories under the specified data_path.

    Each ``repeatX/base_<consolidated_filename>`` and
    ``repeatX/all_<consolidated_filename>`` is read, tagged with a
    ``dataset_index`` column holding X, and concatenated. The two resulting
    tables are written to ``result/result_absence/<name>/<data_info>/``
    (relative to the current working directory, created if necessary) as
    ``base_<final_output_filename>`` and ``all_<final_output_filename>``.

    Args:
        data_path: The main path containing all repeatX directories (e.g., '.../nodeX_edgeY_pZ_T W').
        data_info: Name of the output subdirectory for this data configuration.
        name: Name of the output subdirectory above ``data_info``
              (the script passes the algorithm name).
        consolidated_filename: The name of the file to aggregate within each repeatX directory
                               (e.g., 'results_consolidated.csv').
        final_output_filename: The name for the final aggregated file.
    """
    print(f"\nStarting aggregation of all consolidated files in: {os.path.basename(data_path)}")
    print(f"Looking for files named: {consolidated_filename}")

    # Find and sort all repeatX directories
    repeat_dirs = sorted(
        d for d in os.listdir(data_path)
        if os.path.isdir(os.path.join(data_path, d)) and d.startswith('repeat')
    )
    if not repeat_dirs:
        print(f"Error: No 'repeatX' subdirectories found in {data_path}. Cannot aggregate.")
        return

    # Select read method based on file extension (once, it is the same for every file)
    if consolidated_filename.endswith(".csv"):
        read_table = pd.read_csv
    elif consolidated_filename.endswith(".md"):
        read_table = read_markdown_table
    else:
        print(f"  Warning: Unsupported file type {consolidated_filename}, skipping.")
        return

    # Collected frames per output prefix; 'base' is processed before 'all'.
    frames_by_prefix = {'base': [], 'all': []}

    for dir_name in repeat_dirs:
        match = re.search(r'repeat(\d+)', dir_name)
        if not match:
            continue
        dataset_index = int(match.group(1))

        for prefix, collected in frames_by_prefix.items():
            file_path = os.path.join(data_path, dir_name, f'{prefix}_{consolidated_filename}')
            if not os.path.exists(file_path):
                continue
            try:
                df = read_table(file_path)
                if not df.empty:
                    # Insert dataset_index column at the beginning
                    df.insert(0, 'dataset_index', dataset_index)
                    collected.append(df)
            except Exception as e:
                print(f"  Warning: Failed to read file {file_path}: {e}")

    if not any(frames_by_prefix.values()):
        print(f"Error: No readable '{consolidated_filename}' files found under {data_path}. Nothing to aggregate.")
        return

    # makedirs also creates the missing parents, which plain os.mkdir does not.
    output_dir = os.path.join('result/result_absence', name, data_info)
    os.makedirs(output_dir, exist_ok=True)

    for prefix, collected in frames_by_prefix.items():
        if not collected:
            # pd.concat raises ValueError on an empty list.
            print(f"Warning: No '{prefix}_{consolidated_filename}' files could be read. Skipping '{prefix}' summary.")
            continue
        combined = pd.concat(collected, ignore_index=True)
        print(f"Successfully combined data from {len(collected)} repeat directories.")
        saved_path = _write_final_table(combined, output_dir, prefix, final_output_filename)
        print(f"Successfully saved final aggregated results to: {saved_path}")

def solve(data_dir: str, args, exist_edges_prob: float = 0.1, repeat_num: int = 6):
    """
    Run the baseline and the selected constrained solver on one dataset.

    The ground truth is written to ``<data_dir>/ground_truth.csv``. For each
    of ``repeat_num`` runs a random "existing edge" mask and a random "absent
    edge" mask are drawn (seeded with the run index, so runs are
    reproducible), saved as CSV, and used to solve the problem twice: once
    with ``from_pandas_dynamic`` ("baseline") and once with the solver
    selected by ``args.algorithm``. Weights and result tables of both solves
    are written to ``<data_dir>/exist_edges_prob_<exist_edges_prob>/``.

    Args:
        data_dir: Path to the directory containing data.pkl and links_infos.pkl.
                  The path must contain 'porders<N>', from which the lag order is parsed.
        args: Parsed command-line namespace; only ``args.algorithm`` is read.
        exist_edges_prob: Fraction of the candidate edges put into each mask.
        repeat_num: Number of times to repeat the solving with different random masks.

    Raises:
        IndexError: If ``data_dir`` does not contain 'porders<N>'.
        KeyError: If ``args.algorithm`` is not a known algorithm name.
    """
    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # point this at data directories you trust.
    with open(os.path.join(data_dir, "data.pkl"), "rb") as data_file:
        data = pickle.load(data_file)
    with open(os.path.join(data_dir, "links_infos.pkl"), "rb") as links_file:
        links_infos = pickle.load(links_file)
    p = int(re.findall(r'porders(\d+)', data_dir)[0])

    d_vars = data.shape[1]
    gt = get_gt_matrix(links_infos=links_infos, p=p)
    # Flatten the ground truth to 2D for CSV export.
    gt_2D = gt.copy().reshape((p + 1) * d_vars, d_vars)
    gt_filename = os.path.join(data_dir, "ground_truth.csv") # Save in the main repeat directory
    df_gt = pd.DataFrame(gt_2D)
    df_gt.to_csv(gt_filename, index=False, header=False, encoding='utf-8-sig')

    # --- Output directory for this prior probability ---
    constrained_dir = os.path.join(data_dir, f"exist_edges_prob_{exist_edges_prob}")
    os.makedirs(constrained_dir, exist_ok=True)

    # Candidate "existing" edges: entries that are non-zero somewhere along
    # axis 0 of the ground truth (presumably the lag axis -- confirm against
    # get_gt_matrix).
    constrain_between_nodes = np.any(gt, axis=0)
    possible_edge_indices = np.where(constrain_between_nodes) # Index arrays (row, col) of candidate edges
    num_possible_edges = len(possible_edge_indices[0])
    num_edges_to_select = int(np.ceil(exist_edges_prob * num_possible_edges))

    if num_possible_edges == 0:
        print(f"  Warning: No potential edges found based on GT for {os.path.basename(data_dir)}. Skipping constrained solve for prob {exist_edges_prob}.")
        return
    if num_edges_to_select > num_possible_edges:
        print(f"  Warning: Requested more prior edges ({num_edges_to_select}) than possible ({num_possible_edges}). Using all possible edges.")
        num_edges_to_select = num_possible_edges

    # Candidate "absent" edges: entries that are zero everywhere along axis 0
    # of the ground truth, i.e. the complement of the candidates above.
    gt_absence = gt==0
    constrain_between_nodes_absence = np.all(gt_absence, axis=0)
    possible_edge_indices_absence = np.where(constrain_between_nodes_absence) # Index arrays (row, col) of candidate absent edges
    num_possible_edges_absence = len(possible_edge_indices_absence[0])
    num_edges_to_select_absence = int(np.ceil(exist_edges_prob * num_possible_edges_absence))

    if num_possible_edges_absence == 0:
        print(f"  Warning: No potential edges found based on GT for {os.path.basename(data_dir)}. Skipping constrained solve for prob {exist_edges_prob}.")
        return
    if num_edges_to_select_absence > num_possible_edges_absence:
        print(f"  Warning: Requested more prior edges ({num_edges_to_select_absence}) than possible ({num_possible_edges_absence}). Using all possible edges.")
        num_edges_to_select_absence = num_possible_edges_absence

    alg_mapping={
        'and_init0':from_pandas_dynamic_and,
        'and_initdata':from_pandas_dynamic_and_initdata,
        'multiply_init0':from_pandas_dynamic_multiply,
        'multiply_initdata':from_pandas_dynamic_multiply_initdata,
        'absence':from_pandas_dynamic
    }

    selected_alg = alg_mapping[args.algorithm]

    for i in tqdm(range(repeat_num), desc=f"    Runs (p={exist_edges_prob})", unit="run", leave=False):
        # Seeding with the run index makes every run's masks reproducible.
        # The two rng.choice calls must stay in this order to keep the masks
        # identical to earlier experiments.
        rng = np.random.default_rng(seed=i)
        exist_edges_mask = np.zeros((d_vars, d_vars))
        selected_indices_flat = rng.choice(np.arange(num_possible_edges), size=num_edges_to_select, replace=False)
        selected_rows = possible_edge_indices[0][selected_indices_flat]
        selected_cols = possible_edge_indices[1][selected_indices_flat]
        exist_edges_mask[selected_rows, selected_cols] = 1
        absence_edges_mask = np.zeros((d_vars, d_vars))
        selected_indices_flat_absence = rng.choice(np.arange(num_possible_edges_absence), size=num_edges_to_select_absence, replace=False)
        selected_rows_absence = possible_edge_indices_absence[0][selected_indices_flat_absence]
        selected_cols_absence = possible_edge_indices_absence[1][selected_indices_flat_absence]
        absence_edges_mask[selected_rows_absence, selected_cols_absence] = 1

        # Persist both masks next to the results of this run.
        mask_filename_csv = os.path.join(constrained_dir, f"exist_edges_mask_{i}.csv")
        mask_filename_csv_absence = os.path.join(constrained_dir, f"absence_edges_mask_{i}.csv")
        df_mask = pd.DataFrame(exist_edges_mask)
        df_mask.to_csv(mask_filename_csv, index=False, header=False, encoding='utf-8-sig')
        df_mask_absence = pd.DataFrame(absence_edges_mask)
        df_mask_absence.to_csv(mask_filename_csv_absence, index=False, header=False, encoding='utf-8-sig')

        common_params = {
            'p': p,
            'w_threshold': 0.1,
            'exist_edges_mask': exist_edges_mask
        }

        # Algorithm-specific solver arguments.
        if 'multiply' in args.algorithm:
            common_params['lambda_e'] = 0.5
            common_params['lambda_p'] = 10
        elif 'and' in args.algorithm:
            common_params['lambda_e'] = 0.5
        elif 'absence' in args.algorithm:
            common_params['lambda_e'] = 0.5
            # Every selected absent edge is forbidden at every lag 0..p.
            nonzero_row_indices, nonzero_col_indices = np.nonzero(absence_edges_mask)
            common_params['tabu_edges'] = [
                (lag, r, c)
                for r, c in zip(nonzero_row_indices, nonzero_col_indices)
                for lag in range(p + 1)
            ]

        # --- Baseline Solve ---
        print(f"  Solving baseline for {os.path.basename(data_dir)}...")
        # Note: the baseline also receives this run's exist_edges_mask; it
        # differs from the constrained solve only in solver and extra arguments.
        sm_baseline, eloss, (w_est, a_est)  = from_pandas_dynamic(pd.DataFrame(data), p=p, lambda_e=0.5, w_threshold=0.1, exist_edges_mask=exist_edges_mask)
        weights_baseline_filename = os.path.join(constrained_dir, f"baseline_weights_{i}.csv")
        save_weights_to_csv(w_est, a_est, p, d_vars, weights_baseline_filename)
        solve_baseline = get_solve_matrix(sm=sm_baseline, p=p, d_vars=d_vars)
        result_matrix_baseline = get_result_matrix_df(get_result_matrix(solve_baseline, gt, eloss, exist_edges_mask,w_est, a_est, 0.1, p, d_vars))
        result_matrix_baseline.to_markdown(os.path.join(constrained_dir, f"result_baseline_{i}.md"))
        print(f"  Baseline results saved to result_baseline_{i}.md")

        # --- Constrained Solve ---
        print(f"  Solving & constrained (prob={exist_edges_prob}) for {os.path.basename(data_dir)}...")
        sm_constrained, eloss, (w_est, a_est) = selected_alg(pd.DataFrame(data), **common_params)
        # The file name says "multiply" for every algorithm; kept unchanged so
        # existing result folders and downstream readers stay compatible.
        weights_constrained_filename = os.path.join(constrained_dir, f"constrained_multiply_weights_{i}.csv")
        save_weights_to_csv(w_est, a_est, p, d_vars, weights_constrained_filename)
        solve_constrained = get_solve_matrix(sm=sm_constrained, p=p, d_vars=d_vars)
        result_matrix_constrained = get_result_matrix_df(get_result_matrix(solve_constrained, gt, eloss, exist_edges_mask,w_est, a_est, 0.1, p, d_vars))
        result_filename = os.path.join(constrained_dir, f"result_constrained_{i}.md")
        result_matrix_constrained.to_markdown(result_filename)
        print(f"  & Constrained solving (prob={exist_edges_prob}) finished.") # Printed once per run


def main():
    """
    Command-line entry point: solve, merge, consolidate and aggregate.

    For every 'repeat*' directory below the selected simulation folder the
    solver is run for each requested edge prior probability, the per-run
    results are merged and consolidated, and finally the consolidated files of
    all repeats are aggregated into one summary.
    """
    parser = argparse.ArgumentParser(description="Run causal discovery experiments and aggregate results.")
    parser.add_argument("--data_root", type=str, default="simulated_data", help="Root directory containing simulation parameter folders.")
    parser.add_argument("--repeat_num", type=int, default=6, help="Number of random mask repetitions for constrained solving.")
    parser.add_argument('--node_num', type=int, default=10, help="Number of nodes in the simulation.")
    parser.add_argument('--edge_num', type=int, default=15, help="Number of edges in the simulation.")
    parser.add_argument('--p_orders', type=int, default=3, help="Number of time lags (p orders).")
    parser.add_argument('--T', type=int, default=500, help="Length of the time series.")
    parser.add_argument("--noises_type", type=str, default="gauss", help="Noise type used in the simulation folder name.")
    parser.add_argument('--edge_probs', nargs='+', type=float, default=[0.8], help='List of edge prior probabilities to test.')
    parser.add_argument('--algorithm', type=str, default='and_init0', help="Name of the constrained algorithm to run.")
    parser.add_argument('--consolidated_format', type=str, default="csv", choices=['csv', 'md'], help='Format for the consolidated results file per repeat.')
    parser.add_argument('--final_format', type=str, default="csv", choices=['csv', 'md'], help='Format for the final aggregated summary file.')
    args = parser.parse_args()
    np.random.seed(42)

    # The simulation folder name encodes all simulation parameters.
    data_info = (
        f'node{str(args.node_num).zfill(3)}'
        f'_edge{str(args.edge_num).zfill(3)}'
        f'_porders{args.p_orders}'
        f'_T{str(args.T).zfill(4)}'
        f'_noise{args.noises_type}'
    )
    data_path = os.path.join(args.data_root, data_info)
    if not os.path.isdir(data_path):
        parser.error(f"Data directory not found: {data_path}")

    repeat_dirs = sorted(
        d for d in os.listdir(data_path)
        if os.path.isdir(os.path.join(data_path, d)) and d.startswith('repeat')
    )

    # Honour the requested formats instead of hard-coding '.csv'.
    consolidated_filename = f"results_consolidated.{args.consolidated_format}"
    final_output_filename = f"final_aggregated_summary.{args.final_format}"

    for data_dir_basename in repeat_dirs:
        data_dir_path = os.path.join(data_path, data_dir_basename)
        print(f"\nProcessing repeat directory: {data_dir_basename}")

        # Run baseline and constrained solves for each probability
        for edge_prob in args.edge_probs:
            solve(data_dir_path, args, exist_edges_prob=edge_prob, repeat_num=args.repeat_num)

        merge_results(data_dir_path, repeat_num=args.repeat_num)

        # Consolidate results for this repeat directory
        consolidate_results(data_dir_path, output_filename=consolidated_filename)

    # --- Final Aggregation ---
    # Aggregate results across all repeat directories
    aggregate_all_consolidated(
        data_path=data_path,
        data_info=data_info,
        name=args.algorithm,
        consolidated_filename=consolidated_filename,
        final_output_filename=final_output_filename
    )
    print("\nScript finished.")


if __name__ == "__main__":
    main()
