# main_exp.py

import matplotlib.pyplot as plt
import numpy as np
import yaml
from funcs import run_experiment, compute_mean_std_by_col

def _plot_errorbar_comparison(exp, key, t_list, sk_data, rot_data, ylabel):
    """Plot column-wise mean ± std of a metric for Sinkhorn vs RW2 Sinkhorn.

    Args:
        exp: One experiment entry from the YAML config; its ``source_dist``
            and ``target_dist`` fields are used in the title.
        key: The results key for the plotted case. Items 2 and 3 are shown as
            dim and size (NOTE(review): assumes this key layout from
            ``run_experiment`` — confirm in funcs.py).
        t_list: X-axis values (translations along the last dimension).
        sk_data: Per-repetition measurements for plain Sinkhorn; reduced with
            ``compute_mean_std_by_col``.
        rot_data: Same, for RW2 Sinkhorn.
        ylabel: Y-axis label.
    """
    mean_sk, std_sk = compute_mean_std_by_col(sk_data)
    mean_rot, std_rot = compute_mean_std_by_col(rot_data)

    fig = plt.figure(figsize=(7, 6))
    plt.errorbar(t_list, mean_sk, yerr=std_sk, fmt='-^', label="Sinkhorn")
    plt.errorbar(t_list, mean_rot, yerr=std_rot, fmt='-D', label="RW2 Sinkhorn")
    plt.xlabel("Translation t (last dim)")
    plt.ylabel(ylabel)
    plt.legend()
    plt.title(f"{exp['source_dist']} vs {exp['target_dist']} (dim={key[2]}, size={key[3]})")
    plt.show()
    # Release the figure explicitly: on non-interactive backends plt.show()
    # does nothing and figures would otherwise pile up across experiments.
    plt.close(fig)


if __name__ == "__main__":
    # Load YAML config (safe_load: never construct arbitrary Python objects).
    with open("config.yaml", "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    experiments = config["experiments"]
    n_experiments = len(experiments)

    # Iterate over all experiments in config
    for exp_idx, exp in enumerate(experiments, start=1):
        print(f"\n=== Running experiment {exp_idx}/{n_experiments} ===")
        print(f"{exp['source_dist']} vs {exp['target_dist']} | dims={exp['dims']} | sizes={exp['sizes']}")

        results = run_experiment(
            exp["source_dist"],
            exp["target_dist"],
            dims=exp["dims"],
            sizes=exp["sizes"],
            t_list=exp["t_list"],
            reg=exp["reg"],
            stopThr=exp["stopThr"],
            max_iter=exp["max_iter"],
            repeated_times=exp["repeated_times"]
        )

        # Guard against an empty result dict instead of crashing on [0].
        if not results:
            print("No results returned; skipping plots.")
            continue

        # --- Plot for the first (dim, size) case (dict insertion order) ---
        key = next(iter(results))
        res = results[key]
        t_list = res["t_list"]

        # Running time
        _plot_errorbar_comparison(
            exp, key, t_list,
            res["time"]["sk"], res["time"]["rot"],
            "Running time (s)",
        )

        # Errors vs EMD. asarray makes the element-wise subtraction work even
        # if run_experiment returns nested lists instead of ndarrays.
        w2 = res["w2"]
        emd = np.asarray(w2["emd"])
        err_sk = np.abs(np.asarray(w2["sk"]) - emd)
        err_rot = np.abs(np.asarray(w2["rot"]) - emd)
        _plot_errorbar_comparison(exp, key, t_list, err_sk, err_rot, "Error vs EMD")
