import os
import subprocess
import sys
from collections import defaultdict

# Experiment grid: every seed below is run once for every subsample_time.
# 5 seeds (0-4).
# NOTE(review): a former comment here asked for 10 seeds; confirm 5 is intended.
seeds = list(range(5))
subsample_times = [50, 25, 10, 5]
# (subsample_time, seed, mmd) for every run whose MMD was parsed successfully.
results = []
# subsample_time -> MMD values of its successful runs (failed seeds are absent).
mmd_by_subsample = defaultdict(list)

# Launch main_sde.py once per (subsample_time, seed) pair and harvest its MMD.
# The child's stdout is scanned for a line containing "FINAL MMD VAL"; the line
# immediately after it is parsed as a float, and the first pair that parses wins.
for subsample in subsample_times:
    for seed in seeds:
        print(f"\n--- Running with subsample_time={subsample}, seed={seed} ---")

        # sys.executable runs the child under the same interpreter/virtualenv as
        # this driver; a bare "python" is resolved via PATH and may differ.
        run = subprocess.run(
            [sys.executable, "main_sde.py",
             "--manual_seed", str(seed),
             "--subsample_time", str(subsample)],
            capture_output=True,
            text=True,
        )

        # Split once; pair each line with its successor (the last line has none).
        lines = run.stdout.splitlines()
        mmd_val = None
        for marker_line, value_line in zip(lines, lines[1:]):
            if "FINAL MMD VAL" not in marker_line:
                continue
            try:
                mmd_val = float(value_line)
            except ValueError:
                # Not a number; keep looking for a later marker.
                continue
            break

        if mmd_val is not None:
            results.append((subsample, seed, mmd_val))
            mmd_by_subsample[subsample].append(mmd_val)
            print(f"✔️ Subsample {subsample}, Seed {seed} → MMD: {mmd_val}")
        else:
            print(f"❌ Failed to extract MMD for subsample={subsample}, seed={seed}")
            # stderr is captured above and would otherwise be discarded, hiding
            # the reason for the failure; show the exit code and its tail.
            print(f"   exit code: {run.returncode}")
            for err_line in run.stderr.strip().splitlines()[-20:]:
                print(f"   {err_line}")
            with open("failed_runs_log.txt", "a", encoding="utf-8") as f:
                f.write(f"subsample_time={subsample}, seed={seed}\n")

# Save per-subsample average MMDs and the individual per-seed MMDs.
os.makedirs("resultsSDE", exist_ok=True)
with open("resultsSDE/mmd_results.txt", "w", encoding="utf-8") as f:
    for subsample in subsample_times:
        vals = mmd_by_subsample[subsample]
        # NaN when every seed failed for this subsample_time.
        avg = sum(vals) / len(vals) if vals else float('nan')
        f.write(f"Subsample {subsample}: Average MMD = {avg:.6f}\n")
        # Take seed labels from `results`, which records (subsample, seed, mmd)
        # per successful run. Zipping `seeds` with `vals` would mislabel every
        # entry after a failed seed, since failed runs leave no entry in `vals`.
        for run_subsample, seed, mmd_val in results:
            if run_subsample == subsample:
                f.write(f"  Seed {seed}: MMD = {mmd_val}\n")

# Echo each subsample_time's average MMD to the console (NaN if no run succeeded).
print("\n✅ Summary:")
for sub_time in subsample_times:
    collected = mmd_by_subsample[sub_time]
    if collected:
        mean_mmd = sum(collected) / len(collected)
    else:
        mean_mmd = float('nan')
    print(f"Subsample {sub_time}: Average MMD = {mean_mmd:.6f}")
