import glob
import itertools
import os
import subprocess
import sys

import numpy as np

# Directory containing this script; the training script and the results
# directory are both resolved relative to it, not to the current working dir.
here = os.path.dirname(os.path.abspath(__file__))

# Training entry point launched once per grid point. This file name must match
# the training script that lives next to this one.
main_file = os.path.join(here, "main_sde_matching.py")

# Hyper-parameter grid: every (sigma_a, sigma_o) combination is trained.
# NOTE(review): sigma_a values are negative, so the training script presumably
# treats sigma_a as a log-scale parameter — confirm there.
sigmas_a = [-3.0, -2.5, -2.0]
sigmas_o = [0.01, 0.03, 0.1]

# One tuple per run that produced usable MMD values:
# (sigma_a, sigma_o, best_mmd, best_epoch, mmd_file)
results = []

# Train one model per (sigma_a, sigma_o) pair and record its best MMD epoch.
for sa, so in itertools.product(sigmas_a, sigmas_o):
    run_path = os.path.join(here, f"results/sa{sa}_so{so}")
    os.makedirs(run_path, exist_ok=True)

    print(f"\n>>> Running with sigma_a={sa}, sigma_o={so}")
    cmd = [
        # Use the interpreter running this script rather than whatever
        # "python" resolves to on PATH (may be a different env or missing).
        sys.executable, main_file,
        "--no_epochs", "15",
        "--sigma_a", str(sa),
        "--sigma_o", str(so),
        "--path", run_path,
    ]
    proc = subprocess.run(cmd)
    if proc.returncode != 0:
        # run_path persists across grid searches, so a failed run could
        # otherwise silently report a stale mmd_list.txt from an older run.
        print(
            f"   -> Warning: training exited with code {proc.returncode}; "
            f"skipping {run_path}"
        )
        continue

    # After training, find mmd_list.txt (possibly in a subfolder).
    matches = glob.glob(os.path.join(run_path, "**", "mmd_list.txt"), recursive=True)
    if not matches:
        print(f"   -> Warning: no mmd_list.txt found under {run_path}")
        continue
    # glob order is arbitrary; the most recently written file is the one the
    # run that just finished produced.
    mmd_file = max(matches, key=os.path.getmtime)
    if len(matches) > 1:
        print(f"   -> Warning: {len(matches)} mmd_list.txt files found; using newest")
    print(f"   -> Found {mmd_file}")

    # atleast_1d handles a file holding a single value (0-d result).
    mmds = np.atleast_1d(np.loadtxt(mmd_file))
    if mmds.size == 0 or np.all(np.isnan(mmds)):
        # argmin on an empty/all-NaN array would raise or yield NaN and
        # abort or corrupt the remaining grid search.
        print(f"   -> Warning: no valid MMD values in {mmd_file}")
        continue
    # NaN-aware reductions so a diverged epoch cannot be reported as "best"
    # or poison the later min() over results.
    best_epoch = int(np.nanargmin(mmds))
    best_mmd = float(np.nanmin(mmds))
    results.append((sa, so, best_mmd, best_epoch, mmd_file))
    print(f"   -> Best MMD {best_mmd:.4f} at epoch {best_epoch}")

# -------------------------------
# Print and save summary
# -------------------------------
# The header line is only written to the summary file; per-run rows and the
# best configuration are both printed and saved.
summary_lines = ["=== Grid Search Results ==="]
for run in results:
    sa, so, mmd, epoch, mmd_file = run
    row = (
        f"sigma_a={sa:4}, sigma_o={so:5} | "
        f"Best MMD={mmd:.4f} at epoch {epoch} | File: {mmd_file}"
    )
    print(row)
    summary_lines.append(row)

if results:
    # Lowest best-MMD wins; min() keeps the earliest run on ties.
    best_run = min(results, key=lambda run: run[2])
    sa, so, mmd, epoch, mmd_file = best_run
    # Build the best-config line once so the printed and saved text cannot drift.
    best_line = (
        f"sigma_a={sa}, sigma_o={so} | "
        f"Best MMD={mmd:.4f} at epoch {epoch} | File: {mmd_file}"
    )
    summary_lines.extend(["\n>>> BEST CONFIG <<<", best_line])

    print("\n>>> BEST CONFIG <<<")
    print(best_line)

# Save to file
summary_file = os.path.join(here, "results", "grid_search_summary.txt")
os.makedirs(os.path.dirname(summary_file), exist_ok=True)
# Explicit UTF-8 so paths with non-ASCII characters can be written regardless
# of the platform's default locale encoding; the trailing newline makes the
# file a well-formed text file.
with open(summary_file, "w", encoding="utf-8") as f:
    f.write("\n".join(summary_lines) + "\n")

print(f"\nSummary saved to {summary_file}")
