import os
import pickle
from functools import partial
from multiprocessing import Pool, current_process
from typing import Iterable
import numpy as np
from adaptive_svgd.kernels import general_exponential
from adaptive_svgd.svgd import SVGD, SVGDAdaptive
from adaptive_svgd.examples import ode, gaussian_mixture_1D


def experiment_ode(
    run: int,
    exp_path: str,
    Nx: int = 16,
    Ny: int = 256,
    No: int = 256,
    Ms: Iterable[int] = (50, 100, 200),
    N_posterior: int = 10000,
    noise_cov_scale: float = 0.001,
    prior_cov_scale: float = 50.0,
    p: float = 1.0,
    KL_length: int = 512,
    n_steps: int = 400000,
    step_size: float = 0.0001,
    h_step_size: float = 0.00001,
    h_update_freq: int = 100,
):
    """Run Med-SVGD and adaptive SVGD on the ODE example for one seed.

    Builds the problem with ``ode(...)`` from a generator seeded by ``run`` and
    pickles ``(true_process, posterior_cov, posterior_mean)`` to
    ``{exp_path}/setup_run{run}.pkl``. Then, for each ensemble size ``M`` in
    ``Ms``, it draws one initial ensemble of ``M`` particles from
    N(0, prior_cov) and runs both samplers from that same ensemble. Each
    ``(x_hist, h_hist)`` history is pickled to
    ``{exp_path}/results_MedSVGD_M{M}_run{run}.pkl`` and
    ``{exp_path}/results_AdSVGD_M{M}_run{run}.pkl``.

    Args:
        run: Seed for ``np.random.default_rng``; also tags output file names.
        exp_path: Existing directory that receives the pickle files.
        Nx: Dimension of each particle (and of the ``h`` vector given to
            ``SVGDAdaptive``).
        Ny: Forwarded to ``ode``; meaning defined there.
        No: Size of the diagonal noise covariance passed to ``ode``
            (presumably the number of observations -- confirm in ``ode``).
        Ms: Ensemble sizes to run, one pair of samplers per value.
        N_posterior: Forwarded to ``ode``.
        noise_cov_scale: Diagonal value of the noise covariance.
        prior_cov_scale: Forwarded to ``ode``.
        p: Exponent forwarded to ``general_exponential`` and ``SVGDAdaptive``.
        KL_length: Forwarded to ``ode``.
        n_steps: Number of SVGD iterations for each sampler.
        step_size: Particle step size for each sampler.
        h_step_size: Forwarded as ``h_step_size`` to ``SVGDAdaptive.update``.
        h_update_freq: Forwarded as ``h_steps_per_x_step`` to
            ``SVGDAdaptive.update``.
    """
    current = current_process()
    # Pool worker names end in a process counter, e.g. "ForkPoolWorker-12".
    # The counter can exceed 9 and keeps growing across successive Pools, so
    # parse the whole numeric suffix rather than only its last character.
    # Outside a pool (e.g. "MainProcess") fall back to position 0.
    worker_suffix = current.name.rpartition("-")[2]
    bar_position = int(worker_suffix) - 1 if worker_suffix.isdigit() else 0

    rng = np.random.default_rng(run)
    noise_cov = np.eye(No) * noise_cov_scale
    (
        prior_cov,
        score_fn,
        _,
        _,
        true_process,
        _,
        _,
        posterior_cov,
        posterior_mean,
    ) = ode(rng, Nx, Ny, No, noise_cov, N_posterior, prior_cov_scale, KL_length)
    with open(f"{exp_path}/setup_run{run}.pkl", "wb") as f:
        pickle.dump((true_process, posterior_cov, posterior_mean), f)

    for M in Ms:
        x0 = rng.multivariate_normal(mean=np.zeros(Nx), cov=prior_cov, size=M)

        # Both samplers must start from the same ensemble. Give each its own
        # copy so that an in-place update inside SVGD (if the implementation
        # does one -- not visible here) cannot leak into the second run.
        # NOTE(review): h=-1.0 presumably selects the median-heuristic
        # bandwidth (hence "Med-SVGD"); confirm in adaptive_svgd.kernels.
        svgd_med = SVGD(
            kernel=partial(general_exponential, h=-1.0, p=p),
            score_fn=score_fn,
            x0=np.copy(x0),
        )
        x_hist, h_hist = svgd_med.update(
            n_steps=n_steps,
            step_size=step_size,
            tol=0.0,
            hist_freq=100,
            return_history=True,
            use_adadelta=True,
            progress_desc=f"{current.name}: Med-SVGD, M={M:<3}, run={run:<2}",
            progress_bar_position=bar_position,
        )
        with open(f"{exp_path}/results_MedSVGD_M{M}_run{run}.pkl", "wb") as f:
            pickle.dump((x_hist, h_hist), f)

        svgd_adaptive = SVGDAdaptive(
            h=np.ones(Nx),
            score_fn=score_fn,
            x0=np.copy(x0),
            p=p,
        )
        x_hist, h_hist = svgd_adaptive.update(
            n_steps=n_steps,
            step_size=step_size,
            tol=0.0,
            hist_freq=100,
            return_history=True,
            use_adadelta=True,
            h_step_size=h_step_size,
            h_steps_per_x_step=h_update_freq,
            progress_desc=f"{current.name}: Ad-SVGD,  M={M:<3}, run={run:<2}",
            progress_bar_position=bar_position,
        )
        with open(f"{exp_path}/results_AdSVGD_M{M}_run{run}.pkl", "wb") as f:
            pickle.dump((x_hist, h_hist), f)


def run_experiment_ode(
    exp_path: str = "experiments/ode",
    runs: Iterable[int] = range(1, 101),
):
    """Run ``experiment_ode`` once per seed in ``runs`` on a process pool.

    Args:
        exp_path: Output directory for all pickles; created if missing.
        runs: Seeds, one ``experiment_ode`` call each (default 1..100).
    """
    # exist_ok=True already makes this a no-op for an existing directory.
    # A separate isdir() pre-check is therefore redundant, and it is racy.
    os.makedirs(exp_path, exist_ok=True)
    with Pool() as pool:
        pool.map(partial(experiment_ode, exp_path=exp_path), runs)


def experiment_1d_gaussian_mixture(
    run: int,
    exp_path: str,
    Ms: Iterable[int] = (10, 20, 50, 100, 200, 500),
    h_exponents: Iterable[int] = (-3, -2, -1, 0, 1, 2, 3),
    n_steps: int = 10000,
    step_size: float = 1.0,
    p: float = 1.0,
):
    """Sweep fixed kernel bandwidths on the 1-D Gaussian-mixture example.

    For each ensemble size ``M`` in ``Ms``, the initial particles and score
    function are obtained from ``gaussian_mixture_1D(M, rng)``. Then, for
    every exponent ``e`` in ``h_exponents``, plain SVGD with kernel
    ``general_exponential(h=10**e, p=p)`` is run from those same initial
    particles. Its particle history is pickled to
    ``{exp_path}/results_h{h}_M{M}_run{run}.pkl``, where any '.' in ``h`` is
    replaced by '_'.

    Args:
        run: Seed for ``np.random.default_rng``; also tags output file names.
        exp_path: Existing directory that receives the pickle files.
        Ms: Ensemble sizes to run.
        h_exponents: Base-10 exponents of the fixed ``h`` values to sweep.
        n_steps: Number of SVGD iterations per run.
        step_size: Particle step size per run.
        p: Exponent forwarded to ``general_exponential``.
    """
    current = current_process()
    # Pool worker names end in a process counter, e.g. "ForkPoolWorker-12".
    # The counter can exceed 9 and keeps growing across successive Pools, so
    # parse the whole numeric suffix rather than only its last character.
    # Outside a pool (e.g. "MainProcess") fall back to position 0.
    worker_suffix = current.name.rpartition("-")[2]
    bar_position = int(worker_suffix) - 1 if worker_suffix.isdigit() else 0

    rng = np.random.default_rng(run)
    for M in Ms:
        x0, score_fn, _ = gaussian_mixture_1D(M, rng)
        for h_exp in h_exponents:
            # 10**e is an int for e >= 0 and a float for e < 0, which gives
            # file tags such as "h1", "h100" and "h0_001".
            h = 10**h_exp
            # Every bandwidth must start from the same particles. Pass a copy
            # so that an in-place update inside SVGD (if the implementation
            # does one -- not visible here) cannot carry over between runs.
            svgd_algo = SVGD(
                kernel=partial(general_exponential, h=h, p=p),
                score_fn=score_fn,
                x0=np.copy(x0),
            )
            x_hist, _ = svgd_algo.update(
                n_steps=n_steps,
                step_size=step_size,
                tol=0.0,
                hist_freq=100,
                return_history=True,
                use_adadelta=False,
                progress_desc=f"{current.name}: h={h:<5}, M={M:<3}, run={run:<3}",
                progress_bar_position=bar_position,
            )
            with open(
                f"{exp_path}/results_h{str(h).replace('.', '_')}_M{M}_run{run}.pkl",
                "wb",
            ) as f:
                pickle.dump(x_hist, f)


def run_experiment_1d_gaussian_mixture(
    exp_path: str = "experiments/1d_gaussian_mixture",
    runs: Iterable[int] = range(1, 101),
):
    """Run ``experiment_1d_gaussian_mixture`` once per seed on a process pool.

    Args:
        exp_path: Output directory for all pickles; created if missing.
        runs: Seeds, one ``experiment_1d_gaussian_mixture`` call each
            (default 1..100).
    """
    # exist_ok=True already makes this a no-op for an existing directory.
    # A separate isdir() pre-check is therefore redundant, and it is racy.
    os.makedirs(exp_path, exist_ok=True)
    with Pool() as pool:
        pool.map(partial(experiment_1d_gaussian_mixture, exp_path=exp_path), runs)


if __name__ == "__main__":
    # Script entry point: run the ODE study first, then the 1-D
    # Gaussian-mixture bandwidth sweep, each with its default settings.
    for launch in (run_experiment_ode, run_experiment_1d_gaussian_mixture):
        launch()
