# %%
from contextlib import redirect_stdout
import math
import sys
import os
from sorcerun.git_utils import (
    get_commit_hash,
    get_repo,
    get_tree_hash,
    freeze_notebook,
    branches_pointing_to,
)
import numpy as np
from sorcerun.incense_utils import (
    exps_to_xarray,
    load_filesystem_expts_by_config_keys,
    filter_by_config,
    get_latest_single_and_grid_exps,
)
from tqdm import tqdm
from gifify import gifify
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt

# %%
# Global matplotlib defaults: every figure in this script uses a 22pt font.
plt.rcParams["font.size"] = 22
# %%
# Line width and marker size used when drawing curves.
LINEWIDTH, MARKERSIZE = 3, 10

# %% time plots
STEPS_PER_KERNEL = 1000

# Locate the repository and add its working directory to the import search path.
repo = get_repo()
REPO_PATH = repo.working_dir
sys.path.append(str(REPO_PATH))

# %%
# Report the current commit, the local branches pointing at it, and the tree hash
# obtained from get_tree_hash(repo, "main").
COMMIT_HASH = get_commit_hash(repo)
local_branches = branches_pointing_to(repo, COMMIT_HASH, include_remote=False)
print(local_branches)
SUBTREE_HASH = get_tree_hash(repo, "main")
print(f"COMMIT_HASH: {COMMIT_HASH}", f"SUBTREE_HASH: {SUBTREE_HASH}", sep="\n")

# Figures are written under <repo>/figs/; create the directory if it is missing.
FIG_PATH = f"{REPO_PATH}/figs/"
os.makedirs(FIG_PATH, exist_ok=True)

# Load stored runs, passing dirty=False and the tree hash computed above.
RUNS_DIR = f"{REPO_PATH}/file_storage/runs/"
ALL_EXPS = load_filesystem_expts_by_config_keys(
    runs_dir=RUNS_DIR,
    dirty=False,
    main_tree_hash=SUBTREE_HASH,
)
# %%
# Split the loaded runs into the latest single and grid experiments.
single, grid = get_latest_single_and_grid_exps(ALL_EXPS)
grid_count = len(grid)
print(f"Number of grid exps: {grid_count}")
if not grid:
    # Nothing to plot: say so and stop with a success exit status.
    no_grid_message = (
        "No grid experiments found for the current config criteria; skipping plots."
    )
    print(no_grid_message)
    sys.exit(0)

# One figure per potential type: Langevin ("Algorithm 1") steady-state error versus
# number of particles, one curve per dt, with the FVM error drawn as a horizontal
# reference line and band. Each figure is saved as PNG and PDF under FIG_PATH.
potential_types = sorted({exp.config["potential_type"] for exp in grid})
for potential in potential_types:
    subset = [exp for exp in grid if exp.config["potential_type"] == potential]
    # A potential is skipped entirely as soon as any of its runs used the "torch"
    # backend.
    backends = {e.config["backend"] for e in subset}
    if "torch" in backends:
        continue

    fvm_exps = list(filter_by_config(subset, run_fvm=True))
    print(f"[{potential}] Number of FVM exps: {len(fvm_exps)}")

    if not fvm_exps:
        print(f"[{potential}] No FVM experiments available; skipping plot.")
        continue

    # Smallest recorded FVM error of each run, summarised across runs.
    fvm_errs = np.array([e.metrics["fvm_ss_error"].min() for e in fvm_exps])
    fvm_err_mean = np.mean(fvm_errs).item()
    fvm_err_std = np.std(fvm_errs).item()

    print(
        f"[{potential}] FVM error with best settings: {fvm_err_mean:.3e} +/- {fvm_err_std:.3e}"
    )

    langevin_exps = list(filter_by_config(subset, run_fvm=False))
    print(f"[{potential}] Number of Langevin exps: {len(langevin_exps)}")
    if not langevin_exps:
        continue

    _, langevin_array = exps_to_xarray(langevin_exps)

    # Forward-fill missing values along "step", average over "steps", keep the
    # last "step" entry and select the Langevin error metric.
    # NOTE(review): "steps" and "step" are treated as two different dimensions
    # here -- confirm "steps" really exists and is not a typo for "step".
    reduced = (
        langevin_array.ffill("step", limit=None)
        .mean("steps")
        .isel(step=-1)
        .sel(metric="langevin_ss_error")
    )

    means = reduced.mean("repeat")
    stds = reduced.std("repeat")
    # +/- 2 standard deviations across repeats; computed once, sliced per dt below.
    band_lower = means - 2 * stds
    band_upper = means + 2 * stds
    x = reduced.coords["num_particles"].values
    pot_title = f"{potential.title()} potential"

    # Keep a handle on the figure so it can be closed once it has been saved.
    fig = plt.figure(figsize=(12, 9))
    plt.title(f"L2 Error vs Steady State Density, {pot_title}")
    plt.axhline(
        fvm_err_mean,
        color="red",
        label="FVM Scheme (best error)",
        linewidth=LINEWIDTH,
    )
    plt.fill_between(
        x,
        fvm_err_mean - fvm_err_std,
        fvm_err_mean + fvm_err_std,
        color="red",
        alpha=0.2,
    )
    for dt in reduced.coords["dt"].values:
        (line,) = plt.plot(
            x,
            means.sel(dt=dt).values,
            label=f"Algorithm 1, $\\Delta t$={dt:.1e}",
            linewidth=LINEWIDTH,
            marker="o",
            markersize=MARKERSIZE,
        )
        # Pin the band to its curve's colour instead of relying on the line and
        # fill colour cycles happening to advance in lockstep.
        plt.fill_between(
            x,
            band_lower.sel(dt=dt).values,
            band_upper.sel(dt=dt).values,
            facecolor=line.get_color(),
            alpha=0.2,
        )
    plt.xlabel("Number of particles")
    plt.ylabel("L2 Error")
    plt.xscale("log")
    plt.yscale("log")
    plt.legend()
    plt.tight_layout()
    filename_prefix = f"{potential}_potential_error"
    # FIG_PATH already ends with a separator; os.path.join avoids doubling it.
    for extension in ("png", "pdf"):
        fig.savefig(os.path.join(FIG_PATH, f"{filename_prefix}.{extension}"))
    # pyplot keeps every figure alive until it is closed explicitly; release it
    # so memory does not grow with the number of potentials.
    plt.close(fig)

# # %%
# Hand this script's filename, the repo handle, the current commit hash and the
# tree hash to sorcerun's freeze_notebook.
# NOTE(review): what freeze_notebook does with these (and with the uppercase
# SUBTREE_HASH keyword) is not visible in this file -- confirm against
# sorcerun.git_utils. The filename is hard-coded and must match this file's name.
freeze_notebook(
    filename="plot.py",
    repo=repo,
    commit_hash=COMMIT_HASH,
    SUBTREE_HASH=SUBTREE_HASH,
)
