import numpy as np
from sorcerun.git_utils import (
    is_dirty,
    get_repo,
    get_commit_hash,
    get_time_str,
    get_tree_hash,
)
from config import _compute_stability_limit, _compute_stability_limit_dx

# Repository metadata, captured once when this module is executed.
# make_config copies commit_hash, time_str and dirty into every generated
# configuration, and passes repo to get_tree_hash.
repo = get_repo()
commit_hash = get_commit_hash(repo)
time_str = get_time_str()
dirty = is_dirty(repo)


# Configuration
T = 0.1  # default total_time of make_config
sigma = 1.0
num_edges = 5
edge_length = 10.0
# Five drift values (matching num_edges), passed through float32 and stored
# as a plain list of Python floats.
drift_coeffs = np.array(
    [-10, -20, -30, -40, -50],
    dtype=np.float32,
).tolist()
# Alternative: the same drift on every edge.
# drift_coeffs = np.array([-30] * num_edges)
make_gif = False
error_norm = 2
run_fvm = False
# Per-potential settings: the dt values to sweep and the total time for the
# particle runs, plus the dt and total time used for the FVM run.
POTENTIAL_SETTINGS = {
    "linear": dict(
        dts=[5e-4, 1e-3],
        total_time=0.1,
        fvm_dt=5e-4,
        fvm_total_time=0.05,
    ),
    "quadratic": dict(
        dts=[5e-5, 1e-4],
        total_time=0.05,
        fvm_dt=5e-5,
        fvm_total_time=0.02,
    ),
}
#
#
# Every edge gets the same length.
edge_lengths = np.full(num_edges, edge_length).tolist()
# Uniform weights normalised to sum to one; computed in float32, then
# converted to a plain list.
jump_weights = np.full(num_edges, 1.0, dtype=np.float32)
jump_weights = (jump_weights / jump_weights.sum()).tolist()


# %%
def make_config(
    num_particles,
    num_bins,
    dt,
    r,
    run_f=False,
    backend="cuda",
    total_time=T,
    pot_type="linear",
):
    """Build the configuration dict for one run, or return None to skip it.

    Args:
        num_particles: Stored under "num_particles".
        num_bins: Stored under "num_bins"; also sets the grid spacing
            ``edge_length / num_bins`` passed to the stability check.
        dt: Time step, stored under "dt".
        r: Repeat index, stored under "repeat".
        run_f: Stored under "run_fvm". When true, ``dt`` is checked against
            the limit returned by ``_compute_stability_limit``.
        backend: Stored under "backend".
        total_time: Duration used to derive "steps"; defaults to ``T``.
        pot_type: Stored under "potential_type".

    Returns:
        The configuration dict, or None when the stability check applies
        and ``dt`` exceeds the limit (a message is printed in that case).
    """
    max_stable_dt = _compute_stability_limit(
        drift_coeffs,
        D=sigma**2 / 2,
        dx=edge_length / num_bins,
    )
    # The check applies when this particular config requests FVM (run_f) or
    # when the module-level run_fvm switch is on. Consulting only the global
    # switch would let an FVM config with a too-large dt slip through.
    if (run_f or run_fvm) and dt > max_stable_dt:
        print(f"Skipping dt={dt} > {max_stable_dt}")
        return None

    # Number of whole steps that fit into total_time. The ratio is rounded to
    # 9 decimals before truncating so floating-point noise cannot drop a step
    # (e.g. 0.3 / 0.1 evaluates to 2.9999999999999996, which int() turns
    # into 2).
    steps = int(round(total_time / dt, 9))

    return {
        "num_particles": num_particles,
        "num_bins": num_bins,
        "steps": steps,
        "dt": dt,
        "sigma": sigma,
        "num_edges": num_edges,
        "edge_lengths": edge_lengths,
        "drift_coeffs": drift_coeffs,
        "jump_weights": jump_weights,
        "make_gif": make_gif,
        "error_norm": error_norm,
        "run_fvm": run_f,
        "potential_type": pot_type,
        "backend": backend,
        # Bookkeeping: repeat index and repository state for this sweep.
        "repeat": r,
        "commit_hash": commit_hash,
        "main_tree_hash": get_tree_hash(repo, "main"),
        "time_str": time_str,
        "dirty": dirty,
    }


# Sweep parameters.
num_repeats = 2
# Particle counts: 2 log-spaced integer values from 10**4 to 10**5.
parts = np.logspace(4, 5, num=2, base=10, dtype=int).tolist()
num_bins = 100
FVM_NUM_BINS = 200
POTENTIAL_TYPES = [*POTENTIAL_SETTINGS]
# Every distinct dt used by any potential, in ascending order.
all_dts = sorted(
    {step for entry in POTENTIAL_SETTINGS.values() for step in entry["dts"]}
)

# Limit on dt for the particle-run grid spacing (edge_length / num_bins).
max_stable_dt = _compute_stability_limit(
    drift_coeffs, D=sigma**2 / 2, dx=edge_length / num_bins
)

# Limit on dx evaluated at the largest dt of the sweep; reported both as a
# spacing and as the corresponding number of bins per edge.
min_stable_dx = _compute_stability_limit_dx(
    drift_coeffs, D=sigma**2 / 2, dt=max(all_dts)
)
print(f"Min stable dx: {min_stable_dx}")
print(f"Max stable bins: {edge_length / min_stable_dx}")


# Particle-run sweep: one configuration per combination of potential type,
# particle count, dt (taken from that potential's "dts") and repeat index.
# run_f is left at its default (False) for these configurations.
# make_config may return None; such entries are appended as-is and removed
# when the final `configs` list is assembled.
langevin_configs = []
for pot_type in POTENTIAL_TYPES:
    pot_settings = POTENTIAL_SETTINGS[pot_type]
    for npart in parts:
        for dt in pot_settings["dts"]:
            for r in range(num_repeats):
                langevin_configs.append(
                    make_config(
                        num_particles=npart,
                        num_bins=num_bins,
                        dt=dt,
                        r=r,
                        pot_type=pot_type,
                        # Each potential type has its own total time.
                        total_time=pot_settings["total_time"],
                    )
                )

# FVM runs: one configuration per potential type, with run_f enabled, the
# smallest particle count, FVM_NUM_BINS bins, and that potential's own
# "fvm_dt" and "fvm_total_time".
fvm_configs = [
    make_config(
        num_particles=min(parts),
        num_bins=FVM_NUM_BINS,
        dt=POTENTIAL_SETTINGS[pot_type]["fvm_dt"],
        r=0,
        run_f=True,
        total_time=POTENTIAL_SETTINGS[pot_type]["fvm_total_time"],
        pot_type=pot_type,
    )
    for pot_type in POTENTIAL_TYPES
]

# Disabled: add one config per particle count for timing runs.
# The timing dt is the largest dt in dts that is smaller than the max stable dt.
# timing_dt = max(dt for dt in dts if dt < max_stable_dt)
# timing_dt = min(dts)
# timing_configs = [
#     make_config(
#         num_particles=npart,
#         num_bins=num_bins,
#         dt=timing_dt,
#         r=0,
#         backend="torch",
#         run_f=True if i == 0 else False,
#         total_time=timing_dt * 10 * 1000,
#     )
#     for i, npart in enumerate(parts)
# ]

# configs = timing_configs + configs


# Disabled: add one config with the best possible settings for FVM.
# configs = [
#     make_config(
#         num_particles=min(parts),
#         num_bins=num_bins,
#         dt=min(dts),
#         r=0,
#         run_f=True,
#     )
# ] + configs

# Final list: FVM configurations first, then the particle runs, with the
# combinations that make_config skipped (None) removed.
configs = [c for c in fvm_configs + langevin_configs if c is not None]
print(f"Generated {len(configs)} configurations")
if configs:
    # max()/min() raise ValueError on an empty sequence, so the step range is
    # only reported when at least one configuration was generated.
    print("Max steps", max(c["steps"] for c in configs))
    print("Min steps", min(c["steps"] for c in configs))
