import pandas as pd
import re
import os

# --- CONFIGURATION ---
# Text log that parse_fad_file() reads the FAD scores from.
INPUT_FILE = "/home/wmar/wmar_audio/outputs/fad.txt"
# Per-directory CSV files summarised by get_mos_stats().
NISQA_FILENAME = "NISQA_results.csv"
DNSMOS_FILENAME = "DNSMOSPro_results.csv"

# Map: directory-path substring -> full readable method name.
# Keys are tested in insertion order and the first substring hit wins.
# Values ending in '$h=X$' are split into (group, method) by
# parse_group_method(); a value without that tag (e.g. "Clean") ends up
# with an empty group.
METHOD_MAP = {
    "unwatermarked_1/audio_standard": "Clean",
    **{
        key_template.format(h=h): f"{label} $h={h}$"
        for h in (0, 1, 2)
        for key_template, label in (
            ("clusters1234_h{h}_delta2/audio_standard", "Base"),
            ("meta_aug_clusters1234_h{h}", "WMAR (aug)"),
            ("meta_noaug_clusters1234_h{h}", "WMAR"),
            ("clusters1234_h{h}_delta2/audio_selected", "Ours"),
        )
    },
}

# --- HELPER FUNCTIONS ---

def get_mos_stats(dir_path, filename):
    r"""Summarise the scores stored in the 2nd column of a results CSV.

    Args:
        dir_path: Directory expected to contain the CSV.
        filename: Name of the CSV file inside ``dir_path``.

    Returns:
        A LaTeX-ready string ``"<mean> $\pm$ <std>"`` (two decimals each), or
        ``None`` when the file is missing, cannot be read or parsed, has
        fewer than two columns, or has no numeric value in its 2nd column.
    """
    full_path = os.path.join(dir_path, filename)
    if not os.path.exists(full_path):
        return None

    try:
        df = pd.read_csv(full_path)
    except (OSError, ValueError):
        # Best effort: an unreadable or malformed CSV simply yields no score.
        # pandas' EmptyDataError / ParserError and UnicodeDecodeError are all
        # ValueError subclasses. Anything else (KeyboardInterrupt, genuine
        # programming errors) is deliberately left to propagate.
        return None

    if df.shape[1] < 2:
        return None

    # The score column is taken by position, whatever its header is;
    # non-numeric cells are coerced to NaN and dropped.
    scores = pd.to_numeric(df.iloc[:, 1], errors='coerce').dropna()
    if scores.empty:
        return None

    # Raw f-string so that "\pm" is a literal backslash rather than an
    # invalid escape sequence.
    return rf"{scores.mean():.2f} $\pm$ {scores.std():.2f}"

def parse_group_method(full_name):
    """
    Separate the '$h=N$' group tag from a readable method name.

    'WMAR (aug) $h=0$' -> ('$h=0$', 'WMAR (aug)'). A name carrying no such
    tag is returned unchanged, paired with an empty group string.
    """
    tag_match = re.search(r'(\$h=\d+\$)', full_name)
    if tag_match is None:
        return "", full_name

    tag = tag_match.group(1)
    # Drop every occurrence of the tag, then trim the leftover whitespace.
    method = full_name.replace(tag, "").strip()
    return tag, method

def parse_fad_file(filepath):
    """
    Parse a FAD log into a DataFrame with one row per recognised score line.

    A line consisting only of 'vggish' or 'clap' (any case) selects the
    embedding model for the score lines that follow. Score lines have the
    form '<dir_path> FAD: <number>'; those seen before any model line, and
    those whose path matches no METHOD_MAP key, are skipped.

    Returns a DataFrame with columns dataset, full_method, dir_path, model
    and fad_score. It is empty (no columns) when the file does not exist or
    nothing was recognised.
    """
    if not os.path.exists(filepath):
        print(f"File not found: {filepath}")
        return pd.DataFrame()

    score_re = re.compile(r'(\S+)\s+FAD:\s+([0-9.]+)')
    dataset_markers = (
        ("audio_prompts", "Audio Prompts"),
        ("librispeech", "Librispeech"),
    )
    records = []
    active_model = None

    with open(filepath, 'r') as handle:
        for raw_line in handle:
            text = raw_line.strip()
            if not text:
                continue

            lowered = text.lower()
            if lowered in ("vggish", "clap"):
                # Model header: applies to every score line until the next one.
                active_model = lowered
                continue
            if "throw an exception" in text:
                continue

            hit = score_re.search(text)
            if hit is None or not active_model:
                continue

            location, raw_score = hit.groups()
            score = float(raw_score)

            dataset = next(
                (label for marker, label in dataset_markers if marker in location),
                "Unknown",
            )

            # First METHOD_MAP key contained in the path wins; paths that
            # match none are not of interest.
            full_method = next(
                (name for key, name in METHOD_MAP.items() if key in location),
                "Other",
            )
            if full_method == "Other":
                continue

            records.append({
                "dataset": dataset,
                "full_method": full_method,
                "dir_path": location,
                "model": active_model,
                "fad_score": score
            })

    return pd.DataFrame(records)

# --- EXECUTION ---

df = parse_fad_file(INPUT_FILE)

# parse_fad_file() returns a column-less DataFrame when the log is missing or
# contains no recognised score line. Selecting 'dir_path' from it would raise
# an opaque KeyError, so stop here with a clear message instead.
if df.empty:
    raise SystemExit(f"No usable FAD records found in {INPUT_FILE}; nothing to tabulate.")

# 1. Fetch MOS (Merge by directory)
# Each directory appears once per embedding model in the FAD log, so the MOS
# CSVs are read once per distinct directory and merged back onto every row.
unique_dirs = df[['dir_path']].drop_duplicates()
mos_data = [
    {
        "dir_path": dir_path,
        "NISQA": get_mos_stats(dir_path, NISQA_FILENAME),
        "DNSMOS": get_mos_stats(dir_path, DNSMOS_FILENAME),
    }
    for dir_path in unique_dirs['dir_path']
]
df = pd.merge(df, pd.DataFrame(mos_data), on="dir_path", how="left")

# 2. Process per Dataset
# One LaTeX table is printed to stdout for each distinct 'dataset' value.
for ds in df['dataset'].unique():
    df_ds = df[df['dataset'] == ds].copy()

    # 3. Split Name into ($h$-gram, Method)
    # parse_group_method returns a (group, method) tuple for each 'full_method';
    # names without a '$h=N$' tag get an empty group.
    parsed = df_ds['full_method'].apply(parse_group_method)
    df_ds['$h$-gram'] = [p[0] for p in parsed]
    df_ds['Method'] = [p[1] for p in parsed]

    # 4. Pivot FAD scores -> Index becomes [$h$-gram, Method], one column per model.
    # Rows sharing the same ($h$-gram, Method, model) are combined by
    # pivot_table's default aggregation (mean).
    table_fad = df_ds.pivot_table(
        index=["$h$-gram", "Method"], 
        columns="model", 
        values="fad_score"
    )

    # 5. Get MOS scores (drop duplicates since they are same for vggish/clap)
    table_mos = df_ds[['$h$-gram', 'Method', 'NISQA', 'DNSMOS']].drop_duplicates().set_index(['$h$-gram', 'Method'])

    # 6. Join & Format
    final = table_fad.join(table_mos)
    final = final.rename(columns={"vggish": "FAD (VGGish)", "clap": "FAD (CLAP)"})
    
    # Select and Order Columns (columns absent from this dataset are skipped)
    cols = ["FAD (VGGish)", "FAD (CLAP)", "NISQA", "DNSMOS"]
    final = final[[c for c in cols if c in final.columns]]

    # Create a MultiIndex to force LaTeX multicolumns
    # Structure: (Top Header, Bottom Header)
    new_headers = []
    for col in final.columns:
        if "FAD" in col:
            # Turns "FAD (VGGish)" -> ("FAD", "VGGish")
            # Turns "FAD (CLAP)"   -> ("FAD", "CLAP")
            sub_header = "VGGish" if "VGGish" in col else "CLAP"
            new_headers.append(("FAD", sub_header))
        else:
            # Turns "NISQA"  -> ("MOS", "NISQA")
            # Turns "DNSMOS" -> ("MOS", "DNSMOS")
            new_headers.append(("MOS", col))
    
    final.columns = pd.MultiIndex.from_tuples(new_headers)
    
    # Sort Index to ensure $h=0$, $h=1$, $h=2$ come in order
    # (rows with an empty $h$-gram string sort before them).
    final = final.sort_index()

    # --- LATEX GENERATION ---
    # Clear the index names here; the header labels are put back by the
    # string replacement on the generated LaTeX further down.
    final.index.names = [None, None]




    
    # Define formats: Left Left (for index) then Center Center Center Center
    col_fmt = "ll" + "c" * len(final.columns)
    
    # NOTE: the trailing .replace("_", "-") rewrites every underscore in the
    # generated LaTeX, including the ones in the label built below.
    latex = final.to_latex(
        multirow=True,     # Merge repeated '$h$-gram' index values into \multirow cells
        column_format=col_fmt,
        float_format="{:.3f}".format,
        caption=f"Audio quality scores for {ds}.",
        label=f"tab:fad_{ds.lower().replace(' ', '_')}",
        multicolumn_format="c"
    ).replace("_", "-")

    # Label the two index columns in the header row. This relies on the exact
    # spacing to_latex emits before "VGGish" and is a no-op if that text is
    # absent -- NOTE(review): confirm against the installed pandas version.
    latex = latex.replace(" &  & VGGish", "$h$-gram & Method & VGGish")

    # Post-process the table environment. The order matters: "[ht]" is
    # appended to \begin{table} before "{table}" is rewritten to "{table*}".
    latex = latex.replace("multirow[t]", "multirow")
    latex = latex.replace("\\begin{table}", "\\begin{table}[ht]")
    latex = latex.replace("\\begin{tabular}", "\\centering\n\\begin{tabular}")
    latex = latex.replace("\\end{tabular}", "\\end{tabular}\n\\vskip -0.1in")
    latex = latex.replace("{table}", "{table*}") # Use table* for wide tables
    # Remove the unnecessary cline appearing right before bottomrule
    latex = re.sub(r'\\cline\{[0-9-]+\}\n\\bottomrule', r'\\bottomrule', latex)
    
    print("\n" + latex)
