import numpy as np
from sorcerun.git_utils import (
    is_dirty,
    get_repo,
    get_commit_hash,
    get_time_str,
    get_tree_hash,
)
from config import _compute_stability_limit, _compute_stability_limit_dx

repo = get_repo()
# Provenance attached to every generated config: the current commit, a
# timestamp string, and the is_dirty() flag (presumably "uncommitted changes
# present" — see sorcerun.git_utils).
commit_hash, time_str, dirty = get_commit_hash(repo), get_time_str(), is_dirty(repo)


# Configuration
T = 1.0  # simulated horizon; also the default total_time of make_config
sigma = 1.0  # noise scale; the FVM stability checks use D = sigma**2 / 2
num_edges = 5
edge_length = 10.0
drift_coeffs = np.array([-10, -20, -30, -40, -50], dtype=np.float32).tolist()
# drift_coeffs = np.array([-30] * num_edges)
make_gif = False
error_norm = 2
run_fvm = False
# Alternative potential: "quadratic"
potential_type = "linear"

# One drift coefficient is needed per edge; fail early instead of producing
# inconsistent configs.
if len(drift_coeffs) != num_edges:
    raise ValueError(
        f"expected {num_edges} drift coefficients, got {len(drift_coeffs)}"
    )

# Derived per-edge quantities.
edge_lengths = np.array([edge_length] * num_edges).tolist()
# Uniform jump weights. Normalise in float64: in float32 each weight becomes
# 0.20000000298023224 after .tolist(), and the exported weights no longer
# sum to 1.
jump_weights = np.ones(num_edges, dtype=np.float64)
jump_weights /= jump_weights.sum()
jump_weights = jump_weights.tolist()


# %%
def make_config(
    num_particles,
    num_bins,
    dt,
    r,
    run_f=False,
    backend="cuda",
    total_time=T,
):
    """Build one run configuration for the sweep, or ``None`` to skip it.

    Parameters
    ----------
    num_particles : int
        Number of particles, stored under ``"num_particles"``.
    num_bins : int
        Bins per edge; sets the grid spacing ``edge_length / num_bins`` that
        is passed to the stability check.
    dt : float
        Time step.
    r : int
        Repeat index, stored under ``"repeat"``.
    run_f : bool
        Stored under ``"run_fvm"``. When true, a ``dt`` above the stability
        limit causes the config to be skipped.
    backend : str
        Stored under ``"backend"``.
    total_time : float
        Simulated horizon. Defaults to the module-level ``T``, which is bound
        when the function is defined.

    Returns
    -------
    dict or None
        The configuration dict. Returns ``None`` when ``dt`` exceeds the
        limit from ``_compute_stability_limit`` and either this config runs
        the FVM (``run_f``) or the module-level ``run_fvm`` switch is on.
    """
    dt_limit = _compute_stability_limit(
        drift_coeffs,
        D=sigma**2 / 2,
        dx=edge_length / num_bins,
    )
    # Check the per-config flag that is actually recorded as "run_fvm".
    # The module-level run_fvm still forces the check for the whole sweep,
    # as before.
    if dt > dt_limit and (run_f or run_fvm):
        print(f"Skipping dt={dt} > {dt_limit}")
        return None

    # Truncating a float ratio is fragile: 1.0 / 1e-5 evaluates to
    # 99999.99999999999, so a bare int() silently drops a step. Snap to the
    # nearest integer when the ratio is within rounding error of it.
    # Otherwise keep the original truncation.
    ratio = total_time / dt
    nearest = int(round(ratio))
    if abs(ratio - nearest) <= 1e-9 * abs(ratio):
        steps = nearest
    else:
        steps = int(ratio)

    c = {
        "num_particles": num_particles,
        "num_bins": num_bins,
        "steps": steps,
        "dt": dt,
        "sigma": sigma,
        "num_edges": num_edges,
        "edge_lengths": edge_lengths,
        "drift_coeffs": drift_coeffs,
        "jump_weights": jump_weights,
        "make_gif": make_gif,
        "error_norm": error_norm,
        "run_fvm": run_f,
        "potential_type": potential_type,
        "backend": backend,
        # Run bookkeeping / provenance.
        "repeat": r,
        "commit_hash": commit_hash,
        "main_tree_hash": get_tree_hash(repo, "main"),
        "time_str": time_str,
        "dirty": dirty,
    }
    return c


# Sweep grid: particle counts 1e4 .. 1e8 (one per decade), time steps
# 1e-6, 1e-5, 1e-4, and a fixed number of bins per edge.
num_repeats = 5
parts = np.logspace(4, 8, num=5, base=10, dtype=int).tolist()
# parts = [int(1e8)]
# bins = np.logspace(2, 6, 5, base=10, dtype=int).tolist()
# bins = [100, 1000]
dts = np.logspace(-6, -4, num=3, base=10).tolist()
# dts = [1e-5]
# particles_per_bin = int(1e5)
num_bins = 1000

# Stability diagnostics. Both helpers receive D = sigma**2 / 2. The first is
# evaluated at this bin count; the second at the coarsest dt in the sweep.
# Presumably they give the largest stable dt and the smallest stable dx.
max_stable_dt = _compute_stability_limit(
    drift_coeffs, D=sigma**2 / 2, dx=edge_length / num_bins
)
min_stable_dx = _compute_stability_limit_dx(
    drift_coeffs, D=sigma**2 / 2, dt=max(dts)
)
print(f"Min stable dx: {min_stable_dx}")
print(f"Max stable bins: {edge_length / min_stable_dx}")


# Full factorial sweep over particle count x time step x repeat, in that
# nesting order. Entries may be None when make_config skips a dt.
configs = [
    make_config(num_particles=n_particles, num_bins=num_bins, dt=step, r=rep)
    for n_particles in parts
    for step in dts
    for rep in range(num_repeats)
]

# add one config for timing each particle count
# timing dt is the max dt in dts that is smaller than the max stable dt
# timing_dt = max(dt for dt in dts if dt < max_stable_dt)
# timing_dt = min(dts)
# timing_configs = [
#     make_config(
#         num_particles=npart,
#         num_bins=num_bins,
#         dt=timing_dt,
#         r=0,
#         backend="torch",
#         run_f=True if i == 0 else False,
#         total_time=timing_dt * 10 * 1000,
#     )
#     for i, npart in enumerate(parts)
# ]

# configs = timing_configs + configs


# add one config with best possible settings for fvm
# configs = [
#     make_config(
#         num_particles=min(parts),
#         num_bins=num_bins,
#         dt=min(dts),
#         r=0,
#         run_f=True,
#     )
# ] + configs

# Drop the entries make_config rejected (it returns None for skipped dts).
configs = [c for c in configs if c is not None]
print(f"Generated {len(configs)} configurations")
if configs:
    print("Max steps", max(c["steps"] for c in configs))
    print("Min steps", min(c["steps"] for c in configs))
else:
    # max()/min() raise ValueError on an empty sequence; report instead.
    print("No configurations generated; skipping step-count summary")
