import pandas as pd
import numpy as np

# Format a mean/std pair as a LaTeX table cell; a space always separates the mean from the \scriptstyle +/- part.
def format_value(mean_val, std_dev_val, is_best, *, decimals=2):
    r"""Format a mean/std pair as an inline LaTeX math cell.

    The cell looks like ``$1.23 {\scriptstyle \pm 0.45}$``, with a space between the
    mean and the ``\scriptstyle`` part. Best entries are wrapped in ``\bm{...}``
    (requires ``\usepackage{bm}``).

    Args:
        mean_val: Mean value; NaN/None means "no data".
        std_dev_val: Standard deviation shown after ``\pm``; NaN/None means "no data".
        is_best: Whether to render the cell in bold.
        decimals: Number of decimal places for both numbers (default 2).

    Returns:
        str: The LaTeX cell, or the plain text "N/A" when either value is missing.
    """
    if pd.isna(mean_val) or pd.isna(std_dev_val):
        return "N/A"

    mean_str = f"{mean_val:.{decimals}f}"
    std_dev_str = f"{std_dev_val:.{decimals}f}"

    # Core content string, including the separating space.
    core_content = f"{mean_str} {{\\scriptstyle \\pm {std_dev_str}}}"

    if is_best:
        return f"$\\bm{{{core_content}}}$"
    return f"${core_content}$"

# Function to load data from an actual CSV file
def load_real_csv_data(file_path):
    """
    Load a CSV file and strip leading/trailing whitespace from its string cells.

    Only text columns are modified. Inside an object-dtype column, non-string cells
    (numbers, NaN) are kept unchanged; ``Series.str.strip()`` would turn them into NaN.
    Internal spaces (e.g. "New York") are preserved.

    Args:
        file_path (str): The path to the CSV file.

    Returns:
        pandas.DataFrame: The loaded and cleaned DataFrame.
                          Returns an empty DataFrame (after printing an error message)
                          if the file is not found or cannot be read.
    """
    try:
        df = pd.read_csv(file_path)

        for col in df.columns:
            series = df[col]
            if isinstance(series.dtype, pd.StringDtype):
                # Dedicated string dtype: every non-missing value is a str, so the
                # vectorised accessor is safe and keeps missing values missing.
                df[col] = series.str.strip()
            elif series.dtype == 'object':
                # Object columns can mix strings with other values; strip only the
                # actual strings so numbers and NaN pass through untouched.
                df[col] = series.map(lambda v: v.strip() if isinstance(v, str) else v)

        return df
    except FileNotFoundError:
        print(f"Error: File not found {file_path}")
        return pd.DataFrame()
    except Exception as e:
        # Best-effort loader: callers treat an empty DataFrame as "no data".
        print(f"Error reading file {file_path}: {e}")
        return pd.DataFrame()

# --- Configuration Information ---
# Root directory holding one result sub-directory per algorithm variant.
BASE_PATH = "result/result_absence"

# Method label -> result sub-directory under BASE_PATH.
ALGO_DIRS = {
    "DYNOTEARS-ABS": "absence",
    "DYNOTEARS& (Init 0)": "and_init0",
    "DYNOTEARS& (Init Data)": "and_initdata",
    "DYNOTEARS* (Init 0)": "multiply_init0",
    "DYNOTEARS* (Init Data)": "multiply_initdata",
}

# Method label -> text written into the LaTeX table ("&" must be escaped as "\&").
LATEX_METHOD_NAMES = {
    "Baseline": "Baseline",
    "DYNOTEARS-ABS": "DYNOTEARS-ABS",
    "DYNOTEARS& (Init 0)": r"DYNOTEARS\& (Init 0)",
    "DYNOTEARS& (Init Data)": r"DYNOTEARS\& (Init Data)",
    "DYNOTEARS* (Init 0)": "DYNOTEARS* (Init 0)",
    "DYNOTEARS* (Init Data)": "DYNOTEARS* (Init Data)",
}

# Table row order: the baseline first, then every algorithm variant in ALGO_DIRS order.
METHOD_ORDER = ["Baseline", *ALGO_DIRS]

# Settings held constant for every table cell.
FIXED_PORDERS = 3
FIXED_NOISE_TYPE = "noisegauss"
FIXED_EDGE_PRIOR_PROB = 0.8
FIXED_NAME = "timeseries"

# Settings that vary across the table.
# ER label -> multiplier; the filtered edge count is node * multiplier.
ER_TYPES = {"ER2": 2, "ER4": 4}
T_VALUES = [250, 1000]
NODE_VALUES = [20, 30, 50]
METRICS = ["SHD", "F1"]

# --- Data Collection and Processing ---
# results[er][metric][method][T][node] holds a (mean, std) tuple. The initial
# (nan, nan) marks a cell without usable data and is rendered as "N/A" in the table.
results = {
    er: {
        m: {meth: {t: {n: (np.nan, np.nan) for n in NODE_VALUES} for t in T_VALUES} for meth in METHOD_ORDER}
        for m in METRICS
    }
    for er in ER_TYPES
}

# Columns identifying one dataset; algorithm rows are paired with baseline rows on these.
MERGE_ON_COLS = ['node', 'edge', 'porders', 'T', 'noise_type', 'dataset_index', 'edge_prior_prob', 'name']


def _mean_std(values):
    """Return (mean, std) of a Series.

    The sample std (ddof=1) of a single value is NaN, so ddof=0 is used in that
    case to report 0.0 instead.
    """
    return values.mean(), values.std(ddof=0 if len(values) == 1 else 1)


def _select_condition(df, node_val, edge_val, T_val):
    """Return the rows of df for one (node, edge, T) cell under the FIXED_* settings.

    Raises:
        KeyError: if df is missing one of the filtered columns.
    """
    mask = (
        (df['node'] == node_val)
        & (df['edge'] == edge_val)
        & (df['porders'] == FIXED_PORDERS)
        & (df['T'] == T_val)
        & (df['noise_type'] == FIXED_NOISE_TYPE)
        & (df['edge_prior_prob'] == FIXED_EDGE_PRIOR_PROB)
        & (df['name'] == FIXED_NAME)
    )
    return df[mask]


# The baseline summary is read from the first algorithm's directory and shared by all methods.
baseline_file_for_all = f"{BASE_PATH}/{ALGO_DIRS[next(iter(ALGO_DIRS))]}/merged_base_summary.csv"
baseline_df_full = load_real_csv_data(baseline_file_for_all)
if baseline_df_full.empty:
    print(f"Error: Failed to load baseline data file {baseline_file_for_all}. Aborting.")
    # Non-zero exit status so shells/pipelines see the failure (a bare exit() reports success).
    raise SystemExit(1)

# Load every algorithm's summary once, rather than once per table cell.
# An empty DataFrame (missing/unreadable file) leaves that method's cells as N/A.
algo_diff_dfs = {
    method_name_key: load_real_csv_data(f"{BASE_PATH}/{algo_dir}/merged_all_summary.csv")
    for method_name_key, algo_dir in ALGO_DIRS.items()
}

for er_name, er_multiplier in ER_TYPES.items():
    for T_val in T_VALUES:
        for node_val in NODE_VALUES:
            edge_val = node_val * er_multiplier

            try:
                baseline_filtered_cond = _select_condition(baseline_df_full, node_val, edge_val, T_val)
            except KeyError as e:
                print(f"Error: Missing column in baseline data file: {e}.")
                baseline_filtered_cond = pd.DataFrame()

            if baseline_filtered_cond.empty:
                # Without baseline rows, neither the baseline nor the algorithm values
                # can be computed for this cell; all stay (nan, nan).
                continue

            results[er_name]["SHD"]["Baseline"][T_val][node_val] = _mean_std(baseline_filtered_cond['shd'])
            results[er_name]["F1"]["Baseline"][T_val][node_val] = _mean_std(baseline_filtered_cond['f1'])

            for method_name_key, algo_diff_df in algo_diff_dfs.items():
                if algo_diff_df.empty:
                    continue

                try:
                    algo_diff_filtered = _select_condition(algo_diff_df, node_val, edge_val, T_val)
                except KeyError as e:
                    print(f"Error: Missing column in difference data file for algorithm {method_name_key}: {e}.")
                    continue

                if algo_diff_filtered.empty:
                    continue

                try:
                    merged_for_calc = pd.merge(
                        algo_diff_filtered[MERGE_ON_COLS + ['f1', 'shd']],
                        baseline_filtered_cond[MERGE_ON_COLS + ['f1', 'shd']],
                        on=MERGE_ON_COLS,
                        suffixes=('_diff', '_base'),
                    )
                except KeyError as e:
                    print(f"Error: Missing column when merging data: {e}. Algorithm: {method_name_key}, T={T_val}, Node={node_val}, ER={er_name}")
                    continue

                if merged_for_calc.empty:
                    continue

                # Algorithm files hold per-dataset values relative to the baseline
                # (the "_diff" columns); adding the paired baseline value recovers the
                # algorithm's own score.
                true_shd_values = merged_for_calc['shd_diff'] + merged_for_calc['shd_base']
                true_f1_values = merged_for_calc['f1_diff'] + merged_for_calc['f1_base']

                results[er_name]["SHD"][method_name_key][T_val][node_val] = _mean_std(true_shd_values)
                results[er_name]["F1"][method_name_key][T_val][node_val] = _mean_std(true_f1_values)

# --- Determine best results for bolding ---
# best_flags[er][metric][T][node][method] becomes True when that method's mean is the
# best one in its cell: lowest for SHD, highest for F1. Every method within
# np.isclose tolerance of the best is flagged, so ties are all bolded.
best_flags = {
    er: {m: {t: {n: dict.fromkeys(METHOD_ORDER, False) for n in NODE_VALUES} for t in T_VALUES} for m in METRICS}
    for er in ER_TYPES
}

for er_name in ER_TYPES:
    for T_val in T_VALUES:
        for node_val in NODE_VALUES:
            for metric_name in METRICS:
                # Collect only the methods that actually have a mean for this cell.
                valid_means = {}
                for method_key in METHOD_ORDER:
                    mean_val = results[er_name][metric_name][method_key][T_val][node_val][0]
                    if not pd.isna(mean_val):
                        valid_means[method_key] = mean_val

                if not valid_means:
                    continue

                pick_best = min if metric_name == "SHD" else max
                best_mean_val = pick_best(valid_means.values())

                cell_flags = best_flags[er_name][metric_name][T_val][node_val]
                for method_key, mean_val_iter in valid_means.items():
                    # np.isclose tolerates floating-point noise when comparing means.
                    if np.isclose(mean_val_iter, best_mean_val):
                        cell_flags[method_key] = True

# --- Generate LaTeX Table ---
# Required LaTeX packages: \usepackage{multirow} \usepackage{booktabs} \usepackage{bm}
# Column layout: ER Type | Metric | Method | one column per (T, node) pair, grouped by T.
# The layout is derived from T_VALUES / NODE_VALUES so it stays valid if they change.
n_nodes = len(NODE_VALUES)
n_value_cols = len(T_VALUES) * n_nodes
n_cols = 3 + n_value_cols
er_row_span = len(METRICS) * len(METHOD_ORDER)
metric_row_span = len(METHOD_ORDER)

latex_lines = [
    # @{} removes the extra padding on the outer edges of the table.
    f"\\begin{{tabular}}{{@{{}}lll{'c' * n_value_cols}@{{}}}}",
    "\\toprule",
    "ER Type & Metric & Method & "
    + " & ".join(f"\\multicolumn{{{n_nodes}}}{{c}}{{T={t}}}" for t in T_VALUES)
    + " \\\\",
    # One rule under each T group header; value columns start at column 4.
    " ".join(
        f"\\cmidrule(lr){{{4 + i * n_nodes}-{3 + (i + 1) * n_nodes}}}"
        for i in range(len(T_VALUES))
    ),
    "& & & " + " & ".join(f"Node{n}" for _ in T_VALUES for n in NODE_VALUES) + " \\\\",
    "\\midrule",
]

for er_idx, er_name in enumerate(ER_TYPES):
    if er_idx > 0:
        latex_lines.append("\\midrule")

    for metric_idx, metric_name in enumerate(METRICS):
        for method_idx, method_key in enumerate(METHOD_ORDER):
            # A \multirow cell is written only on the first row it spans. The rows below
            # leave that column empty, so every row has exactly n_cols cells.
            first_row_of_er = metric_idx == 0 and method_idx == 0
            er_cell = f"\\multirow{{{er_row_span}}}{{*}}{{{er_name}}}" if first_row_of_er else ""
            metric_cell = f"\\multirow{{{metric_row_span}}}{{*}}{{{metric_name}}}" if method_idx == 0 else ""

            value_cells = []
            for T_val_loop in T_VALUES:
                for node_val_loop in NODE_VALUES:
                    mean_val, std_val = results[er_name][metric_name][method_key][T_val_loop][node_val_loop]
                    is_best = best_flags[er_name][metric_name][T_val_loop][node_val_loop][method_key]
                    value_cells.append(format_value(mean_val, std_val, is_best))

            row_cells = [er_cell, metric_cell, LATEX_METHOD_NAMES[method_key], *value_cells]
            latex_lines.append(" & ".join(row_cells) + " \\\\")

        # Separate the metric sub-blocks inside one ER block. ER blocks are separated by
        # \midrule, so no rule is added after the last metric.
        if metric_idx < len(METRICS) - 1:
            latex_lines.append(f"\\cmidrule(lr){{3-{n_cols}}}")

latex_lines += ["\\bottomrule", "\\end{tabular}"]
latex_str = "\n".join(latex_lines) + "\n"

print("--- LaTeX Code ---")
print(latex_str)
print("\n--- Required LaTeX Packages ---")
print("\\usepackage{multirow}")
print("\\usepackage{booktabs}")
print("\\usepackage{bm}")