# main_exp.py

import matplotlib.pyplot as plt
import numpy as np
from funcs import run_experiment, compute_mean_std_by_col

def _plot_sk_vs_rot(t_list, data_sk, data_rot, ylabel):
    """Plot column-wise mean +/- std of Sinkhorn vs RW2 Sinkhorn against t.

    Parameters
    ----------
    t_list : sequence
        x-axis values (the translation amounts passed to the experiment).
    data_sk, data_rot : array-like
        Per-repetition measurements for each method, reduced with
        ``compute_mean_std_by_col``. Presumably one column per entry of
        ``t_list`` -- TODO confirm against ``funcs.compute_mean_std_by_col``.
    ylabel : str
        Label for the y-axis.
    """
    mean_sk, std_sk = compute_mean_std_by_col(data_sk)
    mean_rot, std_rot = compute_mean_std_by_col(data_rot)

    plt.figure(figsize=(7, 6))
    plt.errorbar(t_list, mean_sk, yerr=std_sk, fmt='-^', label="Sinkhorn")
    plt.errorbar(t_list, mean_rot, yerr=std_rot, fmt='-D', label="RW2 Sinkhorn")
    plt.xlabel("Translation t (last dim)")
    plt.ylabel(ylabel)
    plt.legend()
    plt.show()


def main():
    """Run the Gaussian-vs-Uniform experiment and plot running time and error vs EMD.

    Raises
    ------
    RuntimeError
        If ``run_experiment`` returns no results to plot.
    """
    results = run_experiment(
        "Gaussian", "Uniform",
        dims=[10], sizes=[1000],
        t_list=[0, 1, 2],
        reg=1e-5, stopThr=1e-3,
        max_iter=2000, repeated_times=5,
    )
    if not results:
        raise RuntimeError("run_experiment returned no results to plot")

    # Only the first entry is plotted (presumably one entry per dim/size
    # configuration; dims and sizes each hold a single value above).
    res = results[next(iter(results))]
    t_list = res["t_list"]

    # Running time of each solver.
    _plot_sk_vs_rot(t_list, res["time"]["sk"], res["time"]["rot"],
                    "Running time (s)")

    # Absolute error of each solver relative to the exact EMD value.
    # np.asarray makes the subtraction valid whether the stored values are
    # plain lists or ndarrays (it is a no-op for ndarrays).
    w2_emd = np.asarray(res["w2"]["emd"])
    err_sk = np.abs(np.asarray(res["w2"]["sk"]) - w2_emd)
    err_rot = np.abs(np.asarray(res["w2"]["rot"]) - w2_emd)
    _plot_sk_vs_rot(t_list, err_sk, err_rot, "Error vs EMD")


if __name__ == "__main__":
    main()
