import re
import pandas as pd


# Root directory that every results CSV referenced below lives under.
OUTPUT_ROOT = "/home/wmar/wmar_audio/outputs/wm_generations_new"

# Scenario name -> {method label: path to that method's results CSV}.
# The insertion order of the methods is the column order of the generated
# LaTeX table (the main loop iterates ``paths`` in order).
SCENARIOS = {
    "clusters1234_h0_delta2_audio_prompts": {
        "Base": f"{OUTPUT_ROOT}/audio_prompts/clusters1234_h0_delta2_aug/results_standard.csv",
        "WMAR": f"{OUTPUT_ROOT}/audio_prompts/meta_noaug_clusters1234_h0_delta2_aug/results_standard.csv",
        "WMAR (aug)": f"{OUTPUT_ROOT}/audio_prompts/meta_aug_clusters1234_h0_delta2_aug/results_standard.csv",
        "Ours": f"{OUTPUT_ROOT}/audio_prompts/clusters1234_h0_delta2_aug/results_selected.csv",
    },
    "clusters1234_h0_delta2_librispeech": {
        "Base": f"{OUTPUT_ROOT}/librispeech/clusters1234_h0_delta2_aug/results_standard.csv",
        "WMAR": f"{OUTPUT_ROOT}/librispeech/meta_noaug_clusters1234_h0_delta2_aug/results_standard.csv",
        "WMAR (aug)": f"{OUTPUT_ROOT}/librispeech/meta_aug_clusters1234_h0_delta2_aug/results_standard.csv",
        "Ours": f"{OUTPUT_ROOT}/librispeech/clusters1234_h0_delta2_aug/results_selected.csv",
    },
    # NOTE(review): Base/Ours here point at the ``final_clusters0123`` run while
    # WMAR uses ``clusters1234`` (and the scenario key says clusters1234) —
    # confirm this mismatch is intended.
    "clusters1234_h0_delta2_music": {
        "Base": f"{OUTPUT_ROOT}/music/final_clusters0123_h0_delta2_aug/results_standard.csv",
        "WMAR": f"{OUTPUT_ROOT}/music/meta_noaug_clusters1234_h0_delta2_aug/results_standard.csv",
        "Ours": f"{OUTPUT_ROOT}/music/final_clusters0123_h0_delta2_aug/results_selected.csv",
    },
}


# Raw augmentation names (as found in the CSVs' ``aug_name`` column) -> short
# display names used in the tables. Names absent from this map are kept as-is.
AUG_MAPPING = {
    "BandpassFilter": "Bandpass",
    "DacCompression": "DAC",
    "EncodecCompression": "EnCodec",
    "HighpassFilter": "Highpass",
    "LowpassFilter": "Lowpass",
    "MP3Compression": "MP3",
    "NoiseInjection": "Noise",
    "TemporalCrop": "Crop",
    "TimeShift": "Shift",
    "Speechtokenizer": "SpeechTok",
    "Speed": "Speedup",
}

# Table sections, in display order, each listing its display names in row order.
# NOTE(review): "Bandpass" is produced by AUG_MAPPING but belongs to no group,
# so it is absent from SORT_ORDER and the table rows are restricted to
# SORT_ORDER — confirm dropping it is intended. "Smooth"/"FaCodec" are not
# produced by AUG_MAPPING; presumably they appear verbatim in the CSVs.
AUG_GROUPS = {
    "Baseline": ["Identity"],
    "Signal Proc.": ["Smooth", "Lowpass", "Highpass", "Noise"],
    "Compression": ["MP3", "DAC", "EnCodec", "SpeechTok", "FaCodec"],
    "Temporal": ["Crop", "Shift", "Speedup"],
}

# Row order of the table: every display name, section by section.
SORT_ORDER = [name for members in AUG_GROUPS.values() for name in members]

# Display name -> the section (group) it is listed under.
AUG_TO_GROUP = dict(
    (name, section) for section, members in AUG_GROUPS.items() for name in members
)


# LaTeX header of the per-method log p-value sub-column.
_LOGP_HEADER = r"$\log(p)$"


def _load_long_table(paths):
    """Average ``logpval`` per augmentation for every method of one scenario.

    Args:
        paths: Mapping of method label -> path to that method's results CSV.
            Each CSV is read for its ``aug_name`` and ``logpval`` columns.

    Returns:
        A long-format DataFrame with columns ``aug_name``, ``mean_logpval`` and
        ``method`` (one row per method/augmentation pair), or ``None`` when no
        CSV yielded any rows.
    """
    frames = []
    for method_name, path in paths.items():
        # Best effort: a missing or malformed CSV only drops that method.
        try:
            df = pd.read_csv(path)
            # Shorten raw augmentation names; unmapped names are kept verbatim.
            df["aug_name"] = df["aug_name"].map(AUG_MAPPING).fillna(df["aug_name"])
            per_aug = (
                df.groupby("aug_name")
                .agg(mean_logpval=("logpval", "mean"))
                .reset_index()
            )
        except Exception as e:
            print(f"Error reading {method_name} ({path}): {e}")
            continue
        if not per_aug.empty:
            frames.append(per_aug.assign(method=method_name))
    return pd.concat(frames, ignore_index=True) if frames else None


def _build_table(df_wide, method_order):
    """Build the (Group, Transformation) x (method, value type) results table.

    Args:
        df_wide: Mean log p-values indexed by augmentation name, one column per
            method. Must contain an ``"Identity"`` row.
        method_order: Method columns to include, in output order.

    Returns:
        A DataFrame whose columns are (method, log p-value header) and
        (method, "Loss") pairs, where Loss = Identity value - value. Rows follow
        ``SORT_ORDER``; augmentations not listed there are dropped.
    """
    identity = df_wide.loc["Identity"]
    blocks = []
    for method in method_order:
        values = df_wide[method]
        # Deliberately unclipped: Loss is negative wherever a transformation's
        # mean is above the Identity mean.
        loss = identity[method] - values
        block = pd.concat([values, loss], axis=1)
        block.columns = pd.MultiIndex.from_product([[method], [_LOGP_HEADER, "Loss"]])
        blocks.append(block)
    table = pd.concat(blocks, axis=1)

    # Keep only augmentations listed in SORT_ORDER, in that order, and label each
    # row with its group so to_latex can render the group as a multirow cell.
    present = set(table.index)
    order = [aug for aug in SORT_ORDER if aug in present]
    table = table.reindex(order)
    table.index = pd.MultiIndex.from_tuples(
        [(AUG_TO_GROUP[aug], aug) for aug in order], names=[None, None]
    )
    return table


def _render_latex(table, scenario, n_methods):
    r"""Render *table* as a ``table*`` LaTeX float preceded by a ``\TableMax`` update.

    The prefix redefines ``\TableMax`` to the largest Loss in the table (1.0 when
    that is missing or 0). The ``E`` column type used for the Loss columns is
    presumably defined in the document preamble in terms of ``\TableMax`` —
    confirm against the paper source.
    """
    max_loss = table.xs("Loss", level=1, axis=1).max().max()
    if pd.isna(max_loss) or max_loss == 0:
        max_loss = 1.0

    two_decimals = "{:.2f}".format
    latex = table.to_latex(
        index=True,  # keep the Group/Transformation index
        multirow=True,  # merge repeated Group cells vertically
        # 'll' for the two index columns, then (value, Loss) = 'cE' per method.
        column_format="ll" + "cE" * n_methods,
        caption=f"Mean log p-values and Loss (relative to Identity) for {scenario}.",
        label=f"tab:rob-{scenario}",
        formatters={col: two_decimals for col in table.columns},
        multicolumn_format="c",  # center each method name over its two columns
    ).replace("_", "-")

    # Fill the empty index-column headers on the sub-header row.
    latex = latex.replace(r" &  & $\log(p)$", r"Group & Transformation & $\log(p)$")
    latex = latex.replace("multirow[t]", "multirow")
    latex = latex.replace(r"\begin{table}", r"\begin{table}[ht]")
    latex = latex.replace(r"\begin{tabular}", "\\centering\n\\begin{tabular}")
    latex = latex.replace(r"\end{tabular}", "\\end{tabular}\n\\vskip -0.1in")
    latex = latex.replace("{table}", "{table*}")  # full-width float
    # Drop the redundant \cline emitted right before \bottomrule.
    latex = re.sub(r"\\cline\{[0-9-]+\}\n\\bottomrule", r"\\bottomrule", latex)

    return f"\\renewcommand{{\\TableMax}}{{{max_loss:.2f}}}\n" + latex


for scenario, paths in SCENARIOS.items():
    df_long = _load_long_table(paths)
    if df_long is None:
        continue

    # Index = augmentation name, one column per method.
    df_wide = df_long.pivot(index="aug_name", columns="method", values="mean_logpval")
    if "Identity" not in df_wide.index:
        # Loss is defined relative to Identity; skip rather than abort every
        # remaining scenario with a KeyError.
        print(f"Skipping {scenario}: no Identity rows to compute Loss against.")
        continue

    # Columns follow the method order declared in SCENARIOS.
    method_order = [m for m in paths if m in df_wide.columns]
    table = _build_table(df_wide, method_order)
    print("\n" + _render_latex(table, scenario, len(method_order)))
