import argparse
from functools import partial

import numpy as np
from sklearn.model_selection import train_test_split

from jax import numpy as jnp, random

from experiments.bnn import DATADIR, DataState
from experiments.result_io import write_result
import numpyro
from numpyro.contrib.einstein import RBFKernel, SteinVI
from numpyro.infer import Predictive, RenyiELBO, Trace_ELBO, log_likelihood
from numpyro.optim import Adam


def load_omniglot(name: str, rng_seed=None) -> DataState:
    """Load ``DATADIR / f"{name}.txt"`` and split it 90/10 into train and test.

    Args:
        name: Stem of the text file inside ``DATADIR`` (without ``.txt``).
        rng_seed: Forwarded to ``train_test_split`` as ``random_state`` so the
            split is reproducible; ``None`` gives a non-deterministic split.

    Returns:
        A ``DataState`` whose first two entries are the float-typed train and
        test arrays. The remaining two entries are ``None``, the same
        convention ``load_bmnist`` uses.
    """
    x = np.loadtxt(DATADIR / f"{name}.txt")
    xtr, xte = train_test_split(x, train_size=0.90, random_state=rng_seed)

    to_array = partial(jnp.array, dtype=float)
    # Only the data arrays are converted. ``None`` must not be routed through
    # ``jnp.array``: older JAX silently turned it into NaN and newer releases
    # raise a TypeError.
    return DataState(to_array(xtr), to_array(xte), None, None)


def load_bmnist():
    """Load the binarized MNIST train and test splits as integer arrays.

    Reads ``binarized_mnist_train.txt`` and ``binarized_mnist_test.txt`` from
    ``DATADIR`` (space-delimited integers).

    Returns:
        A ``DataState`` holding the train array, the test array, and two
        ``None`` entries.
    """
    # Generator keeps the original order of work: each split is read and
    # converted before the next file is touched.
    train, test = (
        jnp.array(
            np.loadtxt(
                DATADIR / f"binarized_mnist_{split}.txt", dtype=int, delimiter=" "
            ),
            dtype=int,
        )
        for split in ("train", "test")
    )
    return DataState(train, test, None, None)


def run(dataset, config, verbose=False):
    match dataset:
        case "mnist":
            from experiments.mnist_vae import guide, model

            data = load_bmnist()
        case "omniglot":
            from experiments.omniglot_vae import guide, model

            data = load_omniglot("omniglot", rng_seed=1)
        case _:
            raise NotImplemented()

    img_dim = data.xtr.shape[1]

    inf_key, pred_key, data_key = random.split(random.PRNGKey(config.rng_seed + 1), 3)

    match config.div_order:
        case 1.0:
            loss = Trace_ELBO(num_particles=config.num_elbo_particles)
        case _:
            loss = RenyiELBO(
                alpha=config.div_order, num_particles=config.num_elbo_particles
            )

    rng_key, inf_key = random.split(inf_key)

    stein = SteinVI(
        model,
        guide,
        Adam(0.0005),
        loss,
        RBFKernel(),
        repulsion_temperature=config.repulsion,
        num_particles=config.num_particles,
    )
    # use keyword params for static (shape etc.)!
    result = stein.run(
        rng_key,
        data.xtr.shape[0] * config.epochs // config.subsample_size,
        data.xtr,
        subsample_size=config.subsample_size,
        img_dim=img_dim,
        num_data=data.xtr.shape[0],
    )
    state = result.state

    n_te = data.xte.shape[0]
    log_likes = []

    params = stein.get_params(state)
    pred = Predictive(
        model,
        guide=guide,
        params=params,
        return_sites=["h1", "h2", "img"]
        + [
            name
            for name, param in params.items()
            if param.shape[0] != config.num_particles
        ],
        num_samples=1,
        batch_ndims=1,
    )
    rng_key, pred_key = random.split(pred_key)
    posterior_preds = pred(
        pred_key, data.xte, subsample_size=n_te, num_data=n_te, img_dim=img_dim
    )

    log_likes.append(
        jnp.mean(
            log_likelihood(
                model,
                posterior_preds,
                data.xte,
                subsample_size=n_te,
                batch_ndims=2,
                num_data=n_te,
                img_dim=img_dim,
            )["img"]
        )
    )

    return {"time": [], "log_like": log_likes, "rmse": []}


def main(dataset, config, verbose=False, store_result=False):
    """Run one experiment and optionally persist and/or print its result.

    Args:
        dataset: Dataset name forwarded to ``run``.
        config: Experiment configuration forwarded to ``run``; its
            ``div_order`` is also recorded when the result is stored.
        verbose: When true, print the dataset name before and after the run,
            followed by the mean log-likelihood.
        store_result: When true, hand the result dict to ``write_result``.
    """
    if verbose:
        print(dataset)

    results = run(dataset, config, verbose)

    if store_result:
        write_result("vae", dataset, div_order=str(config.div_order), **results)

    # Nothing left to do unless a summary was requested.
    if not verbose:
        return

    mean_log_like = np.mean(np.array(results["log_like"]))
    print(dataset)
    print(f"log likelihood {mean_log_like:.3f}")


if __name__ == "__main__":

    def _str2bool(text):
        """Parse a command-line boolean such as ``true`` or ``False``.

        ``type=bool`` would treat every non-empty string (including
        ``"False"``) as true, so the value is interpreted explicitly.
        """
        token = text.strip().lower()
        if token in {"true", "t", "yes", "y", "1"}:
            return True
        if token in {"false", "f", "no", "n", "0"}:
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {text!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset",
        choices=["mnist", "omniglot"],
        default="mnist",
    )
    parser.add_argument("--subsample-size", type=int, default=20)
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--repulsion", type=float, default=1.0)
    parser.add_argument("--div-order", type=float, default=1.0)
    parser.add_argument("--verbose", type=_str2bool, default=True)
    parser.add_argument("--store-result", type=_str2bool, default=True)
    parser.add_argument("--num-particles", type=int, default=5)
    parser.add_argument("--num-elbo-particles", type=int, default=20)
    parser.add_argument("--progress-bar", type=_str2bool, default=True)
    parser.add_argument("--rng-seed", type=int, default=42)
    parser.add_argument("--device", type=str, default="gpu", choices=["gpu", "cpu"])

    args = parser.parse_args()

    numpyro.set_platform(args.device)

    # Pull the options that are arguments of ``main`` out of the namespace;
    # what remains is passed on as the experiment config.
    options = vars(args)
    dataset = options.pop("dataset")
    verbose = options.pop("verbose")
    store_result = options.pop("store_result")

    main(dataset, args, verbose=verbose, store_result=store_result)
