import argparse
from collections import namedtuple
from functools import partial
from pathlib import Path
from time import time
from typing import Callable, Dict, Tuple

import numpy as np
from sklearn.model_selection import train_test_split

from jax import random
import jax.numpy as jnp

from experiments.result_io import write_result
import numpyro
from numpyro.contrib.einstein import RBFKernel, SteinVI
from numpyro.contrib.einstein.kernels import SteinKernel
from numpyro.distributions import Gamma, Normal, biject_to
from numpyro.infer import (
    Predictive,
    RenyiELBO,
    Trace_ELBO,
    init_to_uniform,
    log_likelihood,
)
from numpyro.infer.autoguide import AutoDelta, AutoNormal
from numpyro.optim import Adagrad, Adam

# Directory that holds the "<dataset>.txt" files; it sits next to this module.
DATADIR = Path(__file__).parent.joinpath("data")
# One train/test split: inputs (xtr, xte) and targets (ytr, yte).
DataState = namedtuple("data", "xtr xte ytr yte")


class PPK(SteinKernel):
    """Probability product kernel between particles of an ``AutoNormal`` guide.

    Each particle is read as the parameters (locations and unconstrained
    scales) of a factorized normal. The kernel value is computed per
    coordinate in closed form and the per-coordinate values are reduced to a
    scalar with a vector norm.

    :param guide: the ``AutoNormal`` guide whose parameters form the particles.
    :param rho: exponent applied to both densities in the product kernel.
    """

    def __init__(self, guide, rho=1.0):
        assert isinstance(guide, AutoNormal), "PPK only implemented for AutoNormal"
        self.guide = guide
        self.rho = rho
        self._mode = "norm"

    @staticmethod
    def _gather_indices(particle_info, suffix):
        """Concatenate the index ranges of all parameters whose name ends with ``suffix``."""
        ranges = [
            jnp.arange(*bounds)
            for param_name, bounds in particle_info.items()
            if param_name.endswith(suffix)
        ]
        return jnp.concatenate(ranges)

    def compute(
        self,
        particles: jnp.ndarray,
        particle_info: Dict[str, Tuple[int, int]],
        loss_fn: Callable[[jnp.ndarray], float],
    ):
        """Build the pairwise kernel function for the current particle layout.

        :param particles: not used by this kernel.
        :param particle_info: maps each parameter name to the arguments of
            ``jnp.arange`` that select its entries in a particle.
        :param loss_fn: not used by this kernel.
        :return: a function ``kernel(x, y)`` returning a scalar for two particles.
        """
        loc_idx = self._gather_indices(particle_info, f"{self.guide.prefix}_loc")
        scale_idx = self._gather_indices(particle_info, f"{self.guide.prefix}_scale")

        def kernel(x, y):
            # Scales are stored unconstrained; map them to the guide's support.
            to_scale = biject_to(self.guide.scale_constraint)

            mu_x = x[loc_idx]
            sigma_x = to_scale(x[scale_idx])
            mu_y = y[loc_idx]
            sigma_y = to_scale(y[scale_idx])

            x_quad = (mu_x / sigma_x) ** 2
            y_quad = (mu_y / sigma_y) ** 2

            # Precision-weighted location and variance of the product density.
            cross_loc = mu_x * sigma_x**-2 + mu_y * sigma_y**-2
            cross_var = 1 / (sigma_y**-2 + sigma_x**-2)
            cross_quad = cross_loc**2 * cross_var

            exponent = -self.rho / 2 * (x_quad + y_quad - cross_quad)
            quad = jnp.exp(exponent)

            norm = (
                (2 * jnp.pi) ** ((1 - 2 * self.rho) * 1 / 2)
                * self.rho ** (-1 / 2)
                * cross_var ** (1 / 2)
                * sigma_x ** (-self.rho)
                * sigma_y ** (-self.rho)
            )

            # Collapse the per-coordinate kernel values into a single scalar.
            return jnp.linalg.norm(norm * quad)

        return kernel

    @property
    def mode(self):
        """Kernel output mode; set to ``"norm"`` in ``__init__``."""
        return self._mode


def load_data(name: str, rng_seed=None) -> DataState:
    """Load ``<name>.txt`` from ``DATADIR`` and split it 90/10 into train and test.

    The last column of the file is taken as the target and all preceding
    columns as the inputs.

    :param name: dataset file name without the ``.txt`` extension.
    :param rng_seed: forwarded to ``train_test_split`` as ``random_state``.
    :return: a ``DataState`` whose four fields are float JAX arrays.
    """
    table = np.loadtxt(DATADIR / f"{name}.txt")
    inputs = table[:, :-1]
    targets = table[:, -1]
    parts = train_test_split(
        inputs, targets, train_size=0.90, random_state=rng_seed
    )

    # train_test_split yields (xtr, xte, ytr, yte), matching the DataState field order.
    return DataState(*(jnp.array(part, dtype=float) for part in parts))


def normalize(val, mean=None, std=None):
    """Standardize ``val`` along its first axis.

    Any statistic that is not supplied is estimated from ``val`` itself, so
    the function can be called with both, either one, or neither of ``mean``
    and ``std``.

    :param val: array to standardize; statistics are taken over axis 0.
    :param mean: value to subtract; computed from ``val`` when ``None``.
    :param std: value to divide by; computed from ``val`` when ``None``. A
        computed standard deviation of zero is replaced by 1 so constant
        columns do not divide by zero. A supplied ``std`` is used as given.
    :return: tuple ``(normalized, mean, std)`` with the statistics actually used.
    """
    if std is None:
        std = jnp.std(val, 0, keepdims=True)
        # Constant columns have zero spread; leave them unscaled.
        std = jnp.where(std == 0, 1.0, std)
    if mean is None:
        mean = jnp.mean(val, 0, keepdims=True)
    return (val - mean) / std, mean, std


def model(x, y=None, hidden_dim=50, subsample_size=100):
    """BNN described in section 5 of [1].

    A network with one ReLU hidden layer. All weights and biases share a
    zero-mean normal prior whose precision has a Gamma hyper-prior, and the
    observation noise precision has its own Gamma prior.

    :param x: two-dimensional input array (rows are observations).
    :param y: targets to condition on, or ``None`` to leave site ``"y"`` unobserved.
    :param hidden_dim: number of hidden units.
    :param subsample_size: mini-batch size of the ``"data"`` plate.

    **References:**
        1. *Stein variational gradient descent: A general purpose bayesian inference algorithm*
        Qiang Liu and Dilin Wang (2016).
    """

    # hyper prior for precision of nn weights and biases
    prec_nn = numpyro.sample("prec_nn", Gamma(1.0, 0.1))
    weight_scale = 1.0 / jnp.sqrt(prec_nn)

    num_obs, num_features = x.shape

    with numpyro.plate("l1_hidden", hidden_dim, dim=-1):
        # prior on the hidden-layer bias
        b1 = numpyro.sample("nn_b1", Normal(0.0, weight_scale))
        assert b1.shape == (hidden_dim,)

        with numpyro.plate("l1_feat", num_features, dim=-2):
            # prior on the hidden-layer weights
            w1 = numpyro.sample("nn_w1", Normal(0.0, weight_scale))
            assert w1.shape == (num_features, hidden_dim)

    with numpyro.plate("l2_hidden", hidden_dim, dim=-1):
        # prior on the output weights
        w2 = numpyro.sample("nn_w2", Normal(0.0, weight_scale))

    # prior on the output bias
    b2 = numpyro.sample("nn_b2", Normal(0.0, weight_scale))

    # precision prior on observations
    prec_obs = numpyro.sample("prec_obs", Gamma(1.0, 0.1))
    obs_scale = 1.0 / jnp.sqrt(prec_obs)

    with numpyro.plate("data", num_obs, subsample_size=subsample_size, dim=-1):
        batch_x = numpyro.subsample(x, event_dim=1)
        batch_y = numpyro.subsample(y, event_dim=0) if y is not None else None

        # one hidden layer with ReLU activation
        hidden = jnp.maximum(batch_x @ w1 + b1, 0)
        numpyro.sample("y", Normal(hidden @ w2 + b2, obs_scale), obs=batch_y)


def run(dataset, config: argparse.Namespace, verbose=False):
    data = [load_data(dataset, rng_seed=i + 1) for i in range(config.num_repeats)]

    inf_key, pred_key, data_key = random.split(random.PRNGKey(config.rng_seed + 1), 3)
    x, xtr_mean, xtr_std = zip(
        *[normalize(data[i].xtr) for i in range(config.num_repeats)]
    )
    y, ytr_mean, ytr_std = zip(
        *[normalize(data[i].ytr) for i in range(config.num_repeats)]
    )
    n_tr = x[0].shape[0]

    times = []
    states = []

    match config.div_order:
        case 1.0:
            loss = Trace_ELBO(num_particles=config.num_elbo_particles)
        case _:
            loss = RenyiELBO(
                alpha=config.div_order, num_particles=config.num_elbo_particles
            )

    match config.guide:
        case "normal":
            guide = AutoNormal(model)
            kernel = PPK(guide)
        case "delta":
            guide = AutoDelta(model, init_loc_fn=partial(init_to_uniform, radius=0.1))
            kernel = RBFKernel()

    for i in range(config.num_repeats):
        rng_key, inf_key = random.split(inf_key)

        stein = SteinVI(
            model,
            guide,
            Adagrad(config.step_size),
            loss,
            kernel,
            repulsion_temperature=config.repulsion,
            num_particles=config.num_particles,
        )
        start = time()
        # use keyword params for static (shape etc.)!
        result = stein.run(
            rng_key,
            n_tr * config.epochs // config.subsample_size,
            x[i],
            y[i],
            progress_bar=verbose,
            hidden_dim=50,
            subsample_size=config.subsample_size,
        )
        times.append(time() - start)
        states.append(result.state)

    rmses = []
    log_likes = []
    xte, _, _ = zip(
        *[
            normalize(data[i].xte, xtr_mean[i], xtr_std[i])
            for i in range(config.num_repeats)
        ]
    )
    yte, _, _ = zip(*[normalize(data[i].yte) for i in range(config.num_repeats)])
    for i, state in enumerate(states):
        n_te = xte[i].shape[0]
        pred = Predictive(
            model,
            guide=stein.guide,
            params=stein.get_params(state),
            return_sites=[
                name
                for name, site in stein.guide.prototype_trace.items()
                if site["type"] == "sample"
            ],
            num_samples=100,
            batch_ndims=1,
        )
        rng_key, pred_key = random.split(pred_key)
        posterior_preds = pred(rng_key, xte[i], subsample_size=n_te)

        y_pred = posterior_preds.pop("y")
        y_pred = (
            jnp.mean(y_pred.reshape(-1, xte[i].shape[0]), 0) * ytr_std[i] + ytr_mean[i]
        )

        rmses.append(jnp.sqrt(jnp.mean((y_pred - data[i].yte) ** 2)))

        log_likes.append(
            log_likelihood(
                model,
                posterior_preds,
                xte[i],
                yte[i],
                subsample_size=n_te,
                batch_ndims=2,
            )["y"].mean()
        )

    return {"time": times, "rmse": rmses, "log_like": log_likes, "config": vars(config)}


def main(dataset, config, verbose=False, store_result=False):
    """Run the experiment on ``dataset``, optionally storing and printing results.

    :param dataset: dataset name forwarded to ``run``.
    :param config: experiment configuration forwarded to ``run``.
    :param verbose: print the dataset name and a mean±std summary of timing,
        RMSE and log-likelihood.
    :param store_result: persist the results through ``write_result``.
    """
    if verbose:
        print(dataset)
    results = run(dataset, config, verbose=verbose)

    if store_result:
        write_result("bnn", dataset, div_order=str(config.div_order), **results)

    if not verbose:
        return

    # Convert the per-repeat lists to arrays before summarising them.
    elapsed = np.array(results["time"])
    errors = np.array(results["rmse"])
    scores = np.array(results["log_like"])

    print(dataset)
    print(f"timing {np.mean(elapsed): .3f}±{np.std(elapsed):.3f}")
    print(f"rmse {np.mean(errors):.3f}±{np.std(errors):.3f}")
    print(f"log likelihood {np.mean(scores):.3f}±{np.std(scores):.3f}")


def str_to_bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``type=bool`` must not be used with argparse: it calls ``bool()`` on the
    raw string, so every non-empty value, including ``"False"``, becomes
    ``True``.

    :param value: raw argument string (a ``bool`` is returned unchanged).
    :return: the parsed boolean. The empty string maps to ``False``, as it
        did under ``type=bool``.
    :raises argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in {"true", "t", "yes", "y", "1"}:
        return True
    if text in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset",
        choices=[
            "boston_housing",
            "concrete",
            "energy_heating_load",
            "kin8nm",
            "naval_compressor_decay",
            "power",
            "protein",
            "wine",
            "yacht",
            "year_prediction_msd",
        ],
        default="concrete",
    )

    parser.add_argument("--subsample-size", type=int, default=100)
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--repulsion", type=float, default=1.0)
    parser.add_argument("--div-order", type=float, default=1.0)
    parser.add_argument("--num-repeats", type=int, default=5)
    parser.add_argument("--verbose", type=str_to_bool, default=True)
    parser.add_argument("--num-particles", type=int, default=5)
    parser.add_argument("--num-elbo-particles", type=int, default=100)
    parser.add_argument("--rng-seed", type=int, default=42)
    parser.add_argument("--device", type=str, default="cpu", choices=["gpu", "cpu"])
    parser.add_argument("--store-result", type=str_to_bool, default=False)
    parser.add_argument("--step-size", type=float, default=0.05)
    parser.add_argument(
        "--guide", type=str, default="normal", choices=["delta", "normal"]
    )

    args = parser.parse_args()

    numpyro.set_platform(args.device)
    # Remove the options that are not experiment configuration, so the namespace
    # handed to main() (and stored with the results) holds only the config.
    options = vars(args)
    dataset = options.pop("dataset")
    verbose = options.pop("verbose")
    store_result = options.pop("store_result")

    main(dataset, args, verbose=verbose, store_result=store_result)
