import os
import re
import ast
import pandas as pd
import numpy as np

# --- Configuration ---
# Directory scanned for results and the .tex file that gets written.
BASE_RESULTS_DIR = 'results'
LATEX_OUTPUT_FILE = 'experiment_summary_table.tex'

# Algorithms whose results are loaded ('nts_notears_and' is not in this table).
ALGORITHMS_IN_DATA = ['nts_notears', 'nts_notears_multiply']

# Metric key as stored in the data -> column header printed in the table.
METRICS_FOR_TABLE_DISPLAY = {
    'shd': 'SHD',
    'f1_score': 'F1-score',
    'edge_recovery': 'Edge Recovery',
}
# Metric keys in table column order (follows the mapping's insertion order).
METRICS_FOR_TABLE_INTERNAL = list(METRICS_FOR_TABLE_DISPLAY)

# Short SEM label -> SEM name as it appears in the data.
SEM_TYPES_FOR_TABLE = dict(
    ANM='AdditiveNoiseModel',
    AIM='AdditiveIndexModel',
)
# SEM name as it appears in the data -> long human-readable label.
SEM_TYPE_DISPLAY_IN_TABLE = {
    'AdditiveNoiseModel': 'ANM (Additive Noise Model)',
    'AdditiveIndexModel': 'AIM (Additive Index Model)',
}

SEQUENCE_LENGTHS_FOR_TABLE = [200, 1000]

# Table rows, one tuple per row:
#   (display name, algorithm name in data, exist_edges_prob used for the row).
# The baseline row uses exist_edges_prob=0.2; every 'multiply' row is labelled
# with its probability expressed as a percentage.
# Display names use '\\_' so the underscore is valid LaTeX outside math mode.
_MULTIPLY_EDGE_PERCENTAGES = (20, 40, 60, 80, 100)
ALGORITHM_ROWS_CONFIG = [('nts-notears', 'nts_notears', 0.2)] + [
    (f'nts-notears*\\_{pct}', 'nts_notears_multiply', pct / 100)
    for pct in _MULTIPLY_EDGE_PERCENTAGES
]

# Pattern for a parameter directory name, e.g.
# "sl200_d20_semANM_lags3_seed1_prob0.2". It is applied with .match(), so it
# is anchored at the start of the name only; trailing text is ignored.
PARAM_DIR_REGEX = re.compile(
    r"sl(\d+)_d(\d+)_sem(AIM|ANM)_lags(\d+)_seed(\d+)_prob([\d.]+)"
)
# Short SEM tag used in directory names -> SEM name stored in the parsed row.
SEM_SHORT_TO_LONG = {
    'ANM': 'AdditiveNoiseModel',
    'AIM': 'AdditiveIndexModel'
}
# --- Data Loading and Parsing Functions ---
def parse_param_dir_name(dir_name):
    """Extract experiment parameters from a parameter directory name.

    Args:
        dir_name: Directory base name to parse.

    Returns:
        A dict with keys 'sequence_length', 'd', 'sem_type_short',
        'sem_type', 'number_of_lags', 'seed' and 'exist_edges_prob', or
        None when the name does not start with the expected pattern or its
        probability field is not a valid number.
    """
    match = PARAM_DIR_REGEX.match(dir_name)
    if match is None:
        return None
    sl, d_val, sem_short, lags, seed_val, prob_val = match.groups()
    try:
        # The prob group accepts any run of digits and dots, so it can capture
        # text such as "0.2." (from "...prob0.2.bak"); treat that as no match
        # instead of letting ValueError abort the whole directory scan.
        prob = float(prob_val)
    except ValueError:
        return None
    return {
        'sequence_length': int(sl), 'd': int(d_val),
        'sem_type_short': sem_short, 'sem_type': SEM_SHORT_TO_LONG[sem_short],
        'number_of_lags': int(lags), 'seed': int(seed_val),
        'exist_edges_prob': prob
    }

def load_experiment_data(base_dir, d=20, number_of_lags=3):
    """Collect per-run metrics from the results tree into a DataFrame.

    For every algorithm in ALGORITHMS_IN_DATA, each sub-directory of
    ``base_dir/<algorithm>`` whose name parses via parse_param_dir_name and
    matches the ``d`` / ``number_of_lags`` filter is inspected for an
    'SHD.txt' file containing a Python dict literal of metrics.

    Args:
        base_dir: Root results directory.
        d: Only runs whose parsed 'd' equals this value are loaded.
        number_of_lags: Only runs whose parsed 'number_of_lags' equals this
            value are loaded.

    Returns:
        A DataFrame with one row per loaded run, holding 'algorithm', the
        parsed directory parameters and every key of the metrics dict. Empty
        when nothing could be loaded.
    """
    all_data = []
    for alg_name in ALGORITHMS_IN_DATA:  # Only load data for relevant algorithms
        alg_dir = os.path.join(base_dir, alg_name)
        if not os.path.isdir(alg_dir):
            continue
        # os.listdir order is arbitrary; sort so row order is reproducible.
        for param_dir_name in sorted(os.listdir(alg_dir)):
            param_dir_path = os.path.join(alg_dir, param_dir_name)
            if not os.path.isdir(param_dir_path):
                continue
            parsed_params = parse_param_dir_name(param_dir_name)
            if not parsed_params:
                continue
            if parsed_params['d'] != d or parsed_params['number_of_lags'] != number_of_lags:
                continue

            metrics_file_path = os.path.join(param_dir_path, 'SHD.txt')
            if not os.path.isfile(metrics_file_path):
                continue

            # Only the read and the literal parse are guarded, so a bug in the
            # row assembly below is not mistaken for a malformed file.
            try:
                with open(metrics_file_path, 'r') as f:
                    content = f.read().strip()
                if not content:
                    continue  # Skip empty files
                metrics_dict = ast.literal_eval(content)
            except (OSError, ValueError, SyntaxError, TypeError,
                    MemoryError, RecursionError) as e:
                print(f"Error reading/parsing {metrics_file_path}: {e}")
                continue

            if not isinstance(metrics_dict, dict):
                print(
                    f"Error reading/parsing {metrics_file_path}: "
                    f"expected a dict literal, got {type(metrics_dict).__name__}"
                )
                continue

            # Skip entries that lack any metric the table needs.
            if any(key not in metrics_dict for key in METRICS_FOR_TABLE_INTERNAL):
                continue

            all_data.append({'algorithm': alg_name, **parsed_params, **metrics_dict})
    return pd.DataFrame(all_data)

def aggregate_data(df):
    """Compute mean and standard deviation of each table metric per group.

    Rows are grouped by ('algorithm', 'sequence_length', 'sem_type',
    'exist_edges_prob').

    Args:
        df: Raw per-run DataFrame.

    Returns:
        A DataFrame indexed by the four grouping columns (MultiIndex) with
        two-level columns ``(metric, 'mean')`` and ``(metric, 'std')``. An
        empty DataFrame is returned when ``df`` is empty or a required metric
        or grouping column is missing.
    """
    if df.empty:
        return pd.DataFrame()

    # Ensure all metrics needed for aggregation are present as columns
    for metric_key in METRICS_FOR_TABLE_INTERNAL:
        if metric_key not in df.columns:
            print(f"Error: Metric column '{metric_key}' missing from DataFrame before aggregation.")
            return pd.DataFrame()

    group_by_cols = ['algorithm', 'sequence_length', 'sem_type', 'exist_edges_prob']
    missing_cols = [col for col in group_by_cols if col not in df.columns]
    if missing_cols:
        print(f"Error: Missing columns for aggregation: {missing_cols}")
        return pd.DataFrame()

    # Use the string aliases rather than np.mean / np.std: pandas deprecated
    # passing the numpy callables, and calling np.std directly would switch the
    # result from the sample (ddof=1) to the population (ddof=0) deviation.
    agg_funcs = {metric: ['mean', 'std'] for metric in METRICS_FOR_TABLE_INTERNAL}

    # Keep the MultiIndex so callers can look rows up by the full group key.
    return df.groupby(group_by_cols).agg(agg_funcs)

# --- LaTeX Table Generation ---
def format_cell_value(mean_val, std_val, metric_name_internal):
    """Render one table cell as LaTeX: the mean followed by a small '± std'.

    Args:
        mean_val: Mean of the metric.
        std_val: Standard deviation of the metric.
        metric_name_internal: Metric key; currently unused, kept so callers
            can pass it for metric-specific formatting later.

    Returns:
        "N/A" when either value is missing (NaN/None); otherwise the mean and
        the standard deviation, both with two decimal places.
    """
    if pd.isna(mean_val) or pd.isna(std_val):
        return "N/A"

    # The backslash before 'pm' is escaped explicitly; a bare backslash-p is an
    # invalid escape sequence and triggers a SyntaxWarning on recent Pythons.
    std_str = f"{{\\scriptsize $\\pm$ {std_val:.2f}}}"
    mean_str = f"{mean_val:.2f}"

    return f"{mean_str}{std_str}"

def generate_latex_table(df_agg):
    """Build the LaTeX table environment for the aggregated results.

    The table has one row per (SEM type, ALGORITHM_ROWS_CONFIG entry) and,
    for every sequence length in SEQUENCE_LENGTHS_FOR_TABLE, one column per
    metric in METRICS_FOR_TABLE_INTERNAL. Cells with no matching aggregated
    entry are rendered as "N/A".

    Args:
        df_agg: Aggregated DataFrame indexed by (algorithm, sequence_length,
            sem_type, exist_edges_prob) with (metric, 'mean') and
            (metric, 'std') columns.

    Returns:
        The table source as a newline-separated string ending in a newline,
        or a single LaTeX comment line when ``df_agg`` is empty.
    """
    if df_agg.empty:
        return "% No aggregated data to generate table.\n"

    num_metrics = len(METRICS_FOR_TABLE_INTERNAL)
    # Column spec: ll (SEM, Method) + one 'c' per metric for each sequence length
    col_spec = "ll" + ("c" * num_metrics) * len(SEQUENCE_LENGTHS_FOR_TABLE)

    tex_lines = [
        "\\begin{table*}[htbp]",
        "  \\centering",
        "  \\small",
        "  \\caption{Experiment Results Summary (Mean $\\pm$ Std. Dev.)}",
        "  \\label{tab:experiment_summary}",
        f"  \\begin{{tabular}}{{{col_spec}}}",
        "    \\toprule",
    ]

    # Header row 1: SEM, Method, then one spanning header per sequence length
    header1_parts = ["SEM", "Method"]
    for sl in SEQUENCE_LENGTHS_FOR_TABLE:
        header1_parts.append(
            f"\\multicolumn{{{num_metrics}}}{{c}}{{ER(1,1)\\_{sl}}}"
        )
    tex_lines.append("    " + " & ".join(header1_parts) + " \\\\")

    # Header row 2: metric sub-headers (empty under SEM and Method)
    header2_parts = ["", ""]
    for _ in SEQUENCE_LENGTHS_FOR_TABLE:
        header2_parts.extend(
            METRICS_FOR_TABLE_DISPLAY[metric] for metric in METRICS_FOR_TABLE_INTERNAL
        )
    tex_lines.append("    " + " & ".join(header2_parts) + " \\\\")
    tex_lines.append("    \\midrule")

    # Data rows: iterate SEM types first, then algorithm configurations
    last_sem_idx = len(SEM_TYPES_FOR_TABLE) - 1
    for sem_idx, (sem_short_key, sem_long_name_data) in enumerate(SEM_TYPES_FOR_TABLE.items()):
        for alg_display_name, alg_data_name, alg_prob in ALGORITHM_ROWS_CONFIG:
            # First two cells are the SEM type and the algorithm display name
            row_cells = [sem_short_key, alg_display_name]

            for sl_data in SEQUENCE_LENGTHS_FOR_TABLE:
                for metric_internal in METRICS_FOR_TABLE_INTERNAL:
                    try:
                        data_point = df_agg.loc[
                            (alg_data_name, sl_data, sem_long_name_data, alg_prob)
                        ]
                        mean_val = data_point[(metric_internal, 'mean')]
                        std_val = data_point[(metric_internal, 'std')]
                    except KeyError:
                        # No aggregated entry for this combination
                        row_cells.append("N/A")
                    else:
                        row_cells.append(
                            format_cell_value(mean_val, std_val, metric_internal)
                        )
            tex_lines.append("    " + " & ".join(row_cells) + " \\\\")

        # Separate SEM groups with a rule, except after the last group
        if sem_idx < last_sem_idx:
            tex_lines.append("    \\midrule")

    tex_lines.append("    \\bottomrule")
    tex_lines.append("  \\end{tabular}")
    tex_lines.append("\\end{table*}")
    # Join with a real newline; a literal backslash-n would collapse the table
    # onto one line and leave undefined control sequences in the LaTeX source.
    return "\n".join(tex_lines) + "\n"

def main():
    """Load results, aggregate them and write a standalone LaTeX document.

    Reads from BASE_RESULTS_DIR and writes to LATEX_OUTPUT_FILE. Progress and
    errors are reported on stdout; the function returns early when no data
    could be loaded or aggregated.
    """
    print("Starting LaTeX summary table generation script...")

    raw_data_df = load_experiment_data(BASE_RESULTS_DIR)
    if raw_data_df.empty:
        print("No data loaded. Exiting.")
        return

    aggregated_df = aggregate_data(raw_data_df)
    # Raw data is known to be non-empty here, so an empty result means the
    # aggregation itself failed.
    if aggregated_df.empty:
        print("Data aggregation resulted in an empty dataframe, though raw data was present. Check for missing columns or keys required for grouping/aggregation.")
        return

    latex_table_string = generate_latex_table(aggregated_df)

    # Minimal LaTeX document structure
    full_latex_doc = [
        "\\documentclass[10pt]{article}",  # 10pt base; the table itself uses \small
        "\\usepackage{booktabs}",
        "\\usepackage{amsmath}",
        "\\usepackage{multirow}",
        "\\usepackage[margin=1in]{geometry}",
        "\\usepackage{lscape}",  # For landscape if table is too wide
        "",
        "\\begin{document}",
        "",
        "% Uncomment for landscape if needed:",
        "% \\begin{landscape}",
        latex_table_string,
        "% \\end{landscape}",
        "",
        "\\end{document}"
    ]

    try:
        with open(LATEX_OUTPUT_FILE, 'w') as f:
            f.write("\n".join(full_latex_doc))
    except OSError as e:
        print(f"Error writing LaTeX file {LATEX_OUTPUT_FILE}: {e}")
    else:
        print(f"\nSuccessfully wrote LaTeX summary table to: {os.path.abspath(LATEX_OUTPUT_FILE)}")
        print("You may need 'booktabs', 'amsmath', 'multirow', 'lscape' LaTeX packages.")

    print("\nLaTeX table generation script finished.")

# Run the table generation only when executed as a script, not on import.
if __name__ == "__main__":
    main()