import pandas as pd
import os
import re
import numpy as np # Import numpy for np.nan

# Metadata and metric columns written first (in this order) to the merged CSVs, when present.
_FINAL_COLUMN_ORDER = [
    'node', 'edge', 'porders', 'T', 'noise_type', 'dataset_index', 'edge_prior_prob', 'name',
    'accuracy', 'recall', 'f1',
    'edge_recovery',
    'weak_correct_edge_num', 'weak_edge_num', 'weak_recall',
]
# Known optional metric columns, placed right after _FINAL_COLUMN_ORDER when present.
_OPTIONAL_COLS_TO_KEEP = ['prior_recall', 'near_threshold_gt_recall', 'near_threshold_prior_recall']
# Column deliberately dropped from every output file.
_EXCLUDED_COL = 'edge_loss'
# Keys identifying one configuration when averaging.
_GROUPING_COLUMNS = ['node', 'edge', 'porders', 'T', 'noise_type', 'edge_prior_prob', 'name']
# Grouping keys that must exist for averaging to run ('edge_prior_prob' is optional).
_ESSENTIAL_GROUPING_COLUMNS = ['node', 'edge', 'porders', 'T', 'noise_type', 'name']
# Metric columns averaged per group, when present.
_POTENTIAL_AGG_COLUMNS = [
    'accuracy', 'recall', 'f1',
    'edge_recovery',
    'weak_correct_edge_num', 'weak_edge_num', 'weak_recall',
    'prior_recall', 'near_threshold_gt_recall', 'near_threshold_prior_recall',
]
# Result folder names: node<N>_edge<E>_porders<P>_T<T>, optionally followed by _<noise type>.
_FOLDER_PATTERN = re.compile(r'node(\d+)_edge(\d+)_porders(\d+)_T(\d+)(?:_(.*))?')


def _folder_metadata(match, folder_name):
    """Return the metadata columns encoded in a matched result-folder name.

    A missing noise-type suffix is recorded as 'unknown'; a present suffix is stripped.
    """
    noise_type = match.group(5)
    if noise_type is None:
        noise_type = 'unknown'
        print(f"    Warning: Noise type suffix not found in folder name '{folder_name}'. Assigning 'unknown'.")
    else:
        noise_type = noise_type.strip()
        print(f"    Extracted noise type: {noise_type}")
    return {
        'node': int(match.group(1)),
        'edge': int(match.group(2)),
        'porders': int(match.group(3)),
        'T': int(match.group(4)),
        'noise_type': noise_type,
    }


def _read_summary(csv_path, metadata):
    """Read one summary CSV and tag every row with *metadata*.

    Returns the DataFrame, or None (after printing why) when the file is
    missing, unreadable, or has no rows.
    """
    if not os.path.exists(csv_path):
        print(f"    Warning: {csv_path} not found.")
        return None
    try:
        df = pd.read_csv(csv_path)
    except pd.errors.EmptyDataError:
        print(f"    Warning: {csv_path} is empty or invalid, skipped.")
        return None
    except Exception as e:  # best effort: one bad file must not abort the whole merge
        print(f"    Error reading {csv_path}: {e}")
        return None
    if df.empty:
        print(f"    Warning: {csv_path} is empty, skipped.")
        return None
    for column, value in metadata.items():
        df[column] = value
    print(f"    Read and processed: {os.path.basename(csv_path)}")
    return df


def _add_weak_recall(df, df_name):
    """Add a float 'weak_recall' = weak_correct_edge_num / weak_edge_num column to *df*.

    Both inputs are coerced to float (unparseable values become NaN). The ratio
    is NaN where the denominator is 0 or NaN. If either input column is absent,
    'weak_recall' is created as all-NaN unless it already exists. Mutates and
    returns *df*.
    """
    if 'weak_correct_edge_num' in df.columns and 'weak_edge_num' in df.columns:
        df['weak_correct_edge_num'] = pd.to_numeric(df['weak_correct_edge_num'], errors='coerce').astype(float)
        df['weak_edge_num'] = pd.to_numeric(df['weak_edge_num'], errors='coerce').astype(float)
        total = df['weak_edge_num']
        # Masking zero denominators to NaN makes the division yield NaN there instead of inf;
        # NaN denominators already propagate NaN.
        df['weak_recall'] = (df['weak_correct_edge_num'] / total.where(total != 0)).astype(float)
        print(f"Calculated 'weak_recall' for merged '{df_name}' data (using NaN for zero denominator).")
    else:
        if 'weak_recall' not in df.columns:
            df['weak_recall'] = np.nan
        print(f"Warning: 'weak_correct_edge_num' or 'weak_edge_num' not found in merged '{df_name}' data. Cannot calculate 'weak_recall'. Setting to NaN.")
    return df


def _reorder_columns(df):
    """Return *df* with canonical columns first, known optionals next, the rest last, minus 'edge_loss'."""
    leading = [col for col in _FINAL_COLUMN_ORDER if col in df.columns]
    others = [col for col in df.columns if col not in leading]
    optional = [col for col in _OPTIONAL_COLS_TO_KEEP if col in others]
    rest = [col for col in others if col not in _OPTIONAL_COLS_TO_KEEP and col != _EXCLUDED_COL]
    return df[leading + optional + rest]


def _merge_and_save(dfs, output_path, df_name, source_file):
    """Concatenate *dfs*, add weak_recall, reorder columns and write to *output_path*.

    Returns the merged DataFrame, or an empty DataFrame when *dfs* is empty
    (in which case nothing is written).
    """
    if not dfs:
        print(f"No '{source_file}' files found to merge.")
        return pd.DataFrame()
    merged = pd.concat(dfs, ignore_index=True)
    merged = _add_weak_recall(merged, df_name)
    merged = _reorder_columns(merged)
    merged.to_csv(output_path, index=False)
    print(f"Merged '{df_name}' files saved to: {output_path}")
    return merged


def _save_group_averages(df, grouping_cols, potential_agg_cols, output_path, df_name):
    """Write the per-group mean of the metric columns of *df* to *output_path*.

    Skipped (with a message) when *df* is empty, an essential grouping key is
    missing, or none of *potential_agg_cols* is present. Errors raised while
    averaging or writing are reported, not raised.
    """
    if df.empty:
        print(f"Cannot calculate averages for '{df_name}' file as merged data is empty.")
        return

    final_grouping_cols = [col for col in grouping_cols if col in df.columns]
    final_agg_cols = [col for col in potential_agg_cols if col in df.columns and col != _EXCLUDED_COL]

    missing_essential_grouping = [col for col in _ESSENTIAL_GROUPING_COLUMNS if col not in final_grouping_cols]
    if missing_essential_grouping:
        print(f"Error: Cannot calculate averages for '{df_name}' file. Missing essential grouping columns: {missing_essential_grouping}")
        print(f"Available columns: {df.columns.tolist()}")
        return
    if not final_agg_cols:
        print(f"Error: Cannot calculate averages for '{df_name}' file. No aggregation columns available for averaging (excluding edge_loss).")
        print(f"Available columns: {df.columns.tolist()}")
        return

    try:
        # mean() skips NaN values. dropna=False keeps groups whose key contains NaN
        # (e.g. rows lacking edge_prior_prob); observed=False only matters for categorical keys.
        averaged_df = (
            df.groupby(final_grouping_cols, observed=False, dropna=False)[final_agg_cols]
            .mean()
            .reset_index()
            .round(4)
        )
        # Grouping keys first, then metrics in canonical order, then anything else.
        ordered = [col for col in grouping_cols if col in averaged_df.columns]
        ordered += [col for col in potential_agg_cols
                    if col in averaged_df.columns and col != _EXCLUDED_COL and col not in grouping_cols]
        ordered += [col for col in averaged_df.columns if col not in ordered]
        averaged_df = averaged_df[ordered]

        averaged_df.to_csv(output_path, index=False)
        print(f"Calculated group averages for '{df_name}' files saved to: {output_path}")
    except Exception as e:  # best effort: report and continue with the next output
        print(f"Error during group averaging for '{df_name}': {e}")
        print(f"Grouping by: {final_grouping_cols}")
        print(f"Aggregating on: {final_agg_cols}")
        print(f"Available columns: {df.columns.tolist()}")


def merge(parent_directory):
    """Merge the per-folder summary CSVs under *parent_directory* and write group averages.

    Every sub-directory whose name fully matches ``_FOLDER_PATTERN`` is scanned
    for ``all_final_aggregated_summary.csv`` and ``base_final_aggregated_summary.csv``.
    Each file's rows are tagged with node/edge/porders/T/noise_type parsed from
    the folder name. Files of the same kind are then concatenated, given a
    ``weak_recall`` column, reordered (``edge_loss`` dropped) and written to
    ``merged_{all,base}_summary.csv``. Per-configuration means go to
    ``averaged_{all,base}_summary.csv``. All outputs are written into
    *parent_directory*.

    Problems with individual files are printed and skipped. A missing
    *parent_directory* is reported and nothing is written.
    """
    # --- Configuration: output filenames ---
    merged_all_file = 'merged_all_summary.csv'
    merged_base_file = 'merged_base_summary.csv'
    averaged_all_file = 'averaged_all_summary.csv'
    averaged_base_file = 'averaged_base_summary.csv'

    print(f"Starting processing in parent directory: {os.path.abspath(parent_directory)}")
    if not os.path.isdir(parent_directory):
        print(f"Error: parent directory not found: {os.path.abspath(parent_directory)}")
        return

    all_dfs = []
    base_dfs = []
    # Sorted so the merged row order is reproducible (os.listdir order is arbitrary).
    for item in sorted(os.listdir(parent_directory)):
        item_path = os.path.join(parent_directory, item)
        if not os.path.isdir(item_path):
            continue
        # fullmatch: a suffix not introduced by '_' (e.g. '..._T3stale') is not a result folder.
        match = _FOLDER_PATTERN.fullmatch(item)
        if match is None:
            continue
        print(f"  Found matching folder: {item}")
        metadata = _folder_metadata(match, item)

        df_all = _read_summary(os.path.join(item_path, 'all_final_aggregated_summary.csv'), metadata)
        if df_all is not None:
            all_dfs.append(df_all)
        df_base = _read_summary(os.path.join(item_path, 'base_final_aggregated_summary.csv'), metadata)
        if df_base is not None:
            base_dfs.append(df_base)

    print("\n--- Step 1: Merging files and calculating weak_recall ---")
    merged_all_df = _merge_and_save(
        all_dfs, os.path.join(parent_directory, merged_all_file), 'all', 'all_final_aggregated_summary.csv')
    merged_base_df = _merge_and_save(
        base_dfs, os.path.join(parent_directory, merged_base_file), 'base', 'base_final_aggregated_summary.csv')

    print("\n--- Step 2: Calculating group averages ---")
    _save_group_averages(merged_all_df, _GROUPING_COLUMNS, _POTENTIAL_AGG_COLUMNS,
                         os.path.join(parent_directory, averaged_all_file), 'all')
    _save_group_averages(merged_base_df, _GROUPING_COLUMNS, _POTENTIAL_AGG_COLUMNS,
                         os.path.join(parent_directory, averaged_base_file), 'base')

    print("\nProcessing finished!")


# --- Run the merge for each result directory, in order ---
# Adjust these paths to match your local layout.
for result_dir in (
    'result/result_lambda_e/and_initdata',
    'result/result_lambda_e/and_init0',
    'result/result_lambda_e/multiply_initdata',
    'result/result_lambda_e/multiply_init0',
    # 'result/exp_and',  # Uncomment if needed
):
    merge(result_dir)

