import os
import glob
import numpy as np
import pandas as pd
from typing import List


def find_files_recursively(root: str, filename: str) -> List[str]:
    """Return the sorted paths under ``root`` whose basename matches ``filename``.

    The search covers ``root`` itself and every directory below it that the
    recursive ``**`` glob pattern reaches.

    Args:
        root: Directory to search. It is made absolute and treated as a
            literal path: glob metacharacters in it (``[``, ``]``, ``*``,
            ``?``) are escaped, so a directory such as ``run[1]`` is searched
            rather than being read as a wildcard.
        filename: Basename to look for. It is passed to ``glob`` as-is, so it
            may itself contain wildcards (for example ``"*.csv"``).

    Returns:
        Matching paths, sorted lexicographically. Empty if nothing matches.
    """
    # Escape only the root: wildcards in ``filename`` stay meaningful.
    literal_root = glob.escape(os.path.abspath(root))
    pattern = os.path.join(literal_root, "**", filename)
    return sorted(glob.glob(pattern, recursive=True))

def load_results_from_csv(csv_path: str) -> pd.DataFrame:
    """Read a results CSV and normalize the dtypes of its core columns.

    ``theta_value`` is cast to float and ``n`` to int. The two p-value
    columns (``mmd_p_value`` and ``sink_p_value``) are parsed as numbers,
    with unparseable entries turned into NaN.

    Args:
        csv_path: Path of the CSV file to read.

    Returns:
        The loaded frame with the four columns above normalized; any other
        columns are left as ``pandas.read_csv`` produced them.
    """
    frame = pd.read_csv(csv_path)

    # Strict casts: these raise if a value cannot be converted.
    strict_casts = (("theta_value", float), ("n", int))
    for column, dtype in strict_casts:
        frame[column] = frame[column].astype(dtype)

    # Lenient casts: bad p-values become NaN instead of raising.
    for column in ("mmd_p_value", "sink_p_value"):
        frame[column] = pd.to_numeric(frame[column], errors="coerce")

    return frame

def compute_rejection_probability_table(df: pd.DataFrame, alpha: float) -> pd.DataFrame:
    """Estimate the rejection probability per ``(theta_value, n)`` and method.

    A simulation counts as a rejection when its p-value is strictly below
    ``alpha``. Missing (NaN) p-values are excluded from the estimate rather
    than being counted as "not rejected", which would bias the rejection
    rate downward.

    Args:
        df: Frame with columns ``theta_value``, ``n``, ``mmd_p_value`` and
            ``sink_p_value``.
        alpha: Significance level the p-values are compared against.

    Returns:
        A tidy table with columns ``theta_value``, ``n``, ``power``,
        ``n_sims`` and ``method`` (``"MMD"`` or ``"Sinkhorn"``), sorted by
        method, then ``theta_value``, then ``n``. ``power`` is the mean of
        the reject indicator over the non-missing p-values of that method,
        and ``n_sims`` is the number of those p-values. When no p-value is
        missing, ``n_sims`` equals the number of rows in the group.
    """
    group_keys = ["theta_value", "n"]
    p_value_columns = (("MMD", "mmd_p_value"), ("Sinkhorn", "sink_p_value"))

    per_method = []
    for method, p_col in p_value_columns:
        p_values = df[p_col]
        work = df[group_keys].copy()
        # float64 indicator; rows without a p-value become NaN so that both
        # "mean" and "count" skip them.
        work["reject"] = (p_values < alpha).astype(float).where(p_values.notna())

        summary = work.groupby(group_keys, as_index=False).agg(
            power=("reject", "mean"),
            n_sims=("reject", "count"),
        )
        summary["method"] = method
        per_method.append(summary)

    tidy = pd.concat(per_method, ignore_index=True)
    tidy = tidy.sort_values(["method", "theta_value", "n"]).reset_index(drop=True)
    return tidy

def compute_coverage_table(df: pd.DataFrame) -> pd.DataFrame:
    """Compute empirical confidence-interval coverage per ``(theta_value, n)``.

    A simulation "covers" for a method when that method's true value lies
    inside its interval, i.e. ``ci_low <= true <= ci_high`` (both bounds
    inclusive). Coverage is the mean of that indicator over the rows of a
    group.

    Rows with a missing or non-numeric value in any required column are
    dropped before aggregation, including rows where ``theta_value``, ``n``
    or ``sim_idx`` is missing.

    Args:
        df: Frame with columns ``theta_value``, ``n``, ``sim_idx``,
            ``mmd_true``, ``mmd_ci_low``, ``mmd_ci_high``, ``sink_true``,
            ``sink_ci_low`` and ``sink_ci_high``. It is not modified.

    Returns:
        A tidy table with columns ``theta_value``, ``n``, ``coverage``,
        ``n_sims`` and ``method`` (``"MMD"`` or ``"Sinkhorn"``), sorted by
        method, then ``theta_value``, then ``n``. ``n_sims`` is the number
        of distinct ``sim_idx`` values in the group.

    Raises:
        RuntimeError: If no row is left after dropping incomplete rows.
    """
    key_cols = ["theta_value", "n", "sim_idx"]
    num_cols = [
        "mmd_true", "mmd_ci_low", "mmd_ci_high",
        "sink_true", "sink_ci_low", "sink_ci_high",
    ]

    d = df.copy()

    # Coerce first and cast afterwards: casting a column that still holds
    # NaN to int raises, which would make the dropna below unreachable for
    # missing keys.
    for c in key_cols + num_cols:
        d[c] = pd.to_numeric(d[c], errors="coerce")

    d = d.dropna(subset=num_cols + key_cols)
    if d.empty:
        raise RuntimeError("No valid rows after dropping NaNs in coverage columns.")

    d["theta_value"] = d["theta_value"].astype(float)
    d["n"] = d["n"].astype(int)
    d["sim_idx"] = d["sim_idx"].astype(int)

    # Coverage indicators: is the true value inside the reported interval?
    d["mmd_cover"] = ((d["mmd_true"] >= d["mmd_ci_low"]) & (d["mmd_true"] <= d["mmd_ci_high"])).astype(float)
    d["sink_cover"] = ((d["sink_true"] >= d["sink_ci_low"]) & (d["sink_true"] <= d["sink_ci_high"])).astype(float)

    # Average the indicators over the simulations of each group.
    cov = d.groupby(["theta_value", "n"], as_index=False).agg(
        coverage_mmd=("mmd_cover", "mean"),
        coverage_sinkhorn=("sink_cover", "mean"),
        n_sims=("sim_idx", "nunique"),
    )

    # Tidy: one row per (method, theta_value, n).
    per_method = []
    for method, cov_col in (("MMD", "coverage_mmd"), ("Sinkhorn", "coverage_sinkhorn")):
        part = cov[["theta_value", "n", cov_col, "n_sims"]].rename(columns={cov_col: "coverage"})
        part["method"] = method
        per_method.append(part)

    tidy = pd.concat(per_method, ignore_index=True)
    tidy = tidy.sort_values(["method", "theta_value", "n"]).reset_index(drop=True)
    return tidy

def compute_mse_table(df: pd.DataFrame) -> pd.DataFrame:
    """Compute the mean squared error of both estimators per ``(theta_value, n)``.

    For each method the squared error of the plug-in estimate and of the
    one-step estimate against the true value is averaged over the rows of
    each group.

    Rows with a missing or non-numeric value in any required column are
    dropped before aggregation, including rows where ``theta_value``, ``n``
    or ``sim_idx`` is missing.

    Args:
        df: Frame with columns ``theta_value``, ``n``, ``sim_idx``,
            ``mmd_true``, ``mmd_plugin``, ``mmd_one_step``, ``sink_true``,
            ``sink_plugin`` and ``sink_one_step``. It is not modified.

    Returns:
        A tidy table with columns ``theta_value``, ``n``, ``plugin_mse``,
        ``one_step_mse``, ``n_sims`` and ``method`` (``"MMD"`` or
        ``"Sinkhorn"``), sorted by method, then ``theta_value``, then ``n``.
        ``n_sims`` is the number of distinct ``sim_idx`` values in the group.

    Raises:
        RuntimeError: If no row is left after dropping incomplete rows.
    """
    key_cols = ["theta_value", "n", "sim_idx"]
    num_cols = [
        "mmd_true", "mmd_plugin", "mmd_one_step",
        "sink_true", "sink_plugin", "sink_one_step",
    ]

    d = df.copy()

    # Coerce first and cast afterwards: casting a column that still holds
    # NaN to int raises, which would make the dropna below unreachable for
    # missing keys.
    for c in key_cols + num_cols:
        d[c] = pd.to_numeric(d[c], errors="coerce")

    d = d.dropna(subset=num_cols + key_cols)
    if d.empty:
        raise RuntimeError("No valid rows after dropping NaNs in MSE columns.")

    d["theta_value"] = d["theta_value"].astype(float)
    d["n"] = d["n"].astype(int)
    d["sim_idx"] = d["sim_idx"].astype(int)

    # Per-simulation squared errors of each estimator against the truth.
    d["mmd_plugin_mse"] = ((d["mmd_true"] - d["mmd_plugin"]) ** 2).astype(float)
    d["sink_plugin_mse"] = ((d["sink_true"] - d["sink_plugin"]) ** 2).astype(float)
    d["mmd_one_step_mse"] = ((d["mmd_true"] - d["mmd_one_step"]) ** 2).astype(float)
    d["sink_one_step_mse"] = ((d["sink_true"] - d["sink_one_step"]) ** 2).astype(float)

    # Average the squared errors over the simulations of each group.
    mse = d.groupby(["theta_value", "n"], as_index=False).agg(
        mmd_plugin_mse=("mmd_plugin_mse", "mean"),
        sink_plugin_mse=("sink_plugin_mse", "mean"),
        mmd_one_step_mse=("mmd_one_step_mse", "mean"),
        sink_one_step_mse=("sink_one_step_mse", "mean"),
        n_sims=("sim_idx", "nunique"),
    )

    # Tidy: one row per (method, theta_value, n).
    per_method = []
    for method, prefix in (("MMD", "mmd"), ("Sinkhorn", "sink")):
        plugin_col = prefix + "_plugin_mse"
        one_step_col = prefix + "_one_step_mse"
        part = mse[["theta_value", "n", plugin_col, one_step_col, "n_sims"]].rename(
            columns={plugin_col: "plugin_mse", one_step_col: "one_step_mse"}
        )
        part["method"] = method
        per_method.append(part)

    tidy = pd.concat(per_method, ignore_index=True)
    tidy = tidy.sort_values(["method", "theta_value", "n"]).reset_index(drop=True)
    return tidy