# %%
from contextlib import redirect_stdout
import math
import sys
import os
from sorcerun.git_utils import (
    get_commit_hash,
    get_repo,
    get_tree_hash,
    freeze_notebook,
    branches_pointing_to,
)
from pathlib import Path
import numpy as np
import incense
from sorcerun.incense_utils import (
    exps_to_xarray,
    load_filesystem_expts_by_config_keys,
    filter_by_config,
    get_latest_single_and_grid_exps,
)
from tqdm import tqdm
from gifify import gifify
import matplotlib.pyplot as plt

# %%
# rcparams: global matplotlib defaults applied to every figure in this script
plt.rcParams["font.size"] = 22
# %%
# Shared line / marker styling for every plotted curve.
LINEWIDTH, MARKERSIZE = 3, 10

# %% time plots
# NOTE(review): not referenced elsewhere in this script; the torch/time-plot
# branch is currently a stub.
STEPS_PER_KERNEL = 1_000

# %%
# Run ids (directories under file_storage/runs/) of archived FVM experiments.
# Hardcoded so the slow FVM solves are not re-run; update if those runs are regenerated.
FVM_LINEAR_ID, FVM_QUADRATIC_ID = 192, 102

# %%
repo = get_repo()
# Root of the git working tree; appended to sys.path so repo-local modules are importable.
REPO_PATH = repo.working_dir
# Guard the append so re-running this cell interactively does not keep
# growing sys.path with duplicate entries.
if f"{REPO_PATH}" not in sys.path:
    sys.path.append(f"{REPO_PATH}")

# %%
# Identify the code version these plots are built from.
COMMIT_HASH = get_commit_hash(repo)
print(branches_pointing_to(repo, COMMIT_HASH, include_remote=False))
SUBTREE_HASH = get_tree_hash(repo, "main")
print(f"COMMIT_HASH: {COMMIT_HASH}")
print(f"SUBTREE_HASH: {SUBTREE_HASH}")

# Output directory for figures, created if missing.
FIG_PATH = f"{REPO_PATH}/figs/"
Path(FIG_PATH).mkdir(parents=True, exist_ok=True)

# Load stored runs; per the keyword names, restricted to non-dirty runs whose
# "main" tree hash equals SUBTREE_HASH.
_runs_dir = f"{REPO_PATH}/file_storage/runs/"
ALL_EXPS = load_filesystem_expts_by_config_keys(
    dirty=False,
    main_tree_hash=SUBTREE_HASH,
    runs_dir=_runs_dir,
)
# %%
# Split the loaded runs into single experiments and grid experiments
# (presumably the latest of each, per the helper's name).
single, grid = get_latest_single_and_grid_exps(ALL_EXPS)
n_grid_exps = len(grid)
print(f"Number of grid exps: {n_grid_exps}")
# %%
# grid[0] below would raise a bare IndexError when no grid experiments matched
# the loading filters; fail early with an actionable message instead.
if len(grid) == 0:
    raise RuntimeError(
        "No grid experiments found; check that runs exist for the current "
        "main tree hash and were recorded from a clean (non-dirty) tree."
    )

# Human-readable potential label, used in the plot title and figure file names.
# NOTE(review): derived from the first grid experiment only — assumes the whole
# grid shares a single potential_type.
pot_type = (
    "Linear potential"
    if "linear" in grid[0].config["potential_type"]
    else "Quadratic potential"
)

backends = {e.config["backend"] for e in grid}
if "torch" in backends:
    # Time plots for the torch backend are not implemented yet.
    # NOTE(review): STEPS_PER_KERNEL is presumably intended for this branch.
    pass
else:
    # ---- Error plots: Langevin (Algorithm 1) vs. the FVM baseline ----

    # FVM runs that belong to this grid (if any).
    fvm_exps = list(filter_by_config(grid, run_fvm=True))
    print(f"Number of FVM exps: {len(fvm_exps)}")

    # The slow FVM solves are not re-run. The archived run that matches this
    # grid's potential is loaded from disk. Only that one run is read, so a
    # missing archive for the *other* potential does not abort the plot.
    fvm_run_id = {
        "Linear potential": FVM_LINEAR_ID,
        "Quadratic potential": FVM_QUADRATIC_ID,
    }[pot_type]
    fvm_exps.append(
        incense.experiment.FileSystemExperiment.from_run_dir(
            Path(f"{REPO_PATH}/file_storage/runs/{fvm_run_id}")
        )
    )

    # Best error per FVM run = minimum of its logged "fvm_ss_error" series;
    # these are then summarised by mean and population std (ddof=0) across runs.
    fvm_errs = np.array([e.metrics["fvm_ss_error"].min() for e in fvm_exps])
    fvm_err_mean = np.mean(fvm_errs).item()
    fvm_err_std = np.std(fvm_errs).item()

    print(f"FVM error with best settings: {fvm_err_mean:.3e} +/- {fvm_err_std:.3e}")

    # Langevin errors, arranged by the grid's config dimensions.
    langevin_exps = list(filter_by_config(grid, run_fvm=False))
    print(f"Number of Langevin exps: {len(langevin_exps)}")
    _, langevin_array = exps_to_xarray(langevin_exps)

    reduced = (
        # Forward-fill NaNs along "step" so each run carries its last logged
        # value forward to later steps.
        langevin_array.ffill("step", limit=None)
        # NOTE(review): "steps" (plural) is a different dimension from "step";
        # presumably a grid config key — confirm this is not a typo.
        .mean("steps")
        # Error at the final step; use .min("step") instead for the best-step error.
        .isel(step=-1)
        .sel(metric="langevin_ss_error")
    )

    # Statistics across repeats; the remaining dims are assumed to be
    # (num_particles, dt) — TODO confirm.
    means = reduced.mean("repeat")
    stds = reduced.std("repeat")
    x = reduced.coords["num_particles"].values

    plt.figure(figsize=(12, 9))
    plt.title(f"L2 Error vs Steady State Density, {pot_type}")

    # FVM baseline: horizontal line at the mean best error with a +/-1 std band.
    # NOTE(review): the Langevin bands below are +/-2 std, so the two bands are
    # not on the same scale.
    if fvm_exps:
        plt.axhline(
            fvm_err_mean,
            color="red",
            label="FVM Scheme (best error)",
            linewidth=LINEWIDTH,
        )
        plt.fill_between(
            x,
            fvm_err_mean - fvm_err_std,
            fvm_err_mean + fvm_err_std,
            color="red",
            alpha=0.2,
        )

    # One curve per time step dt: mean error vs particle count, +/-2 std band.
    for dt in reduced.coords["dt"].values:
        mean_dt = means.sel(dt=dt).values
        std_dt = stds.sel(dt=dt).values
        (line,) = plt.plot(
            x,
            mean_dt,
            label=f"Algorithm 1, $\\Delta t$={dt:.1e}",
            linewidth=LINEWIDTH,
            marker="o",
            markersize=MARKERSIZE,
        )
        # Tie the band colour to its curve explicitly rather than relying on
        # matplotlib's separate patch colour cycle staying in sync.
        plt.fill_between(
            x,
            mean_dt - 2 * std_dt,
            mean_dt + 2 * std_dt,
            color=line.get_color(),
            alpha=0.2,
        )

    plt.xlabel("Number of particles")
    plt.ylabel("L2 Error")
    plt.xscale("log")
    plt.yscale("log")
    plt.legend()
    plt.tight_layout()
    # FIG_PATH ends with "/", so join with Path to avoid "figs//..." file names.
    for ext in ("png", "pdf"):
        plt.savefig(Path(FIG_PATH) / f"{pot_type}_error.{ext}")
    plt.show()


# # %%
# Hand this script and the commit / tree hashes it ran against to sorcerun's
# freeze_notebook (presumably snapshots the script for reproducibility).
_freeze_kwargs = {
    "filename": "plot.py",
    "repo": repo,
    "commit_hash": COMMIT_HASH,
    "SUBTREE_HASH": SUBTREE_HASH,
}
freeze_notebook(**_freeze_kwargs)
