"""Bayesian Logistic regression.
Adapted from https://github.com/langosco/neural-variational-gradient-descent/
blob/master/nvgd/experiments/bayesian_logistic_regression.py
"""
from functools import partial
from typing import Optional, Literal
from pathlib import Path

import time
import chex
import jax
import jax.numpy as jnp
from evosax import FitnessShaper, OpenES
from jax.scipy import stats, special
from matplotlib import pyplot as plt
from pandas import read_csv
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm
from ucimlrepo import fetch_ucirepo

from sves.benchmarks import Benchmark
from sves.kernels import RBF
from sves.strategies import ParallelCMAES, OG_SVGD, GF_SVGD, MC_SVGD, ParallelOpenES
from sves.strategies.svgd_cma_es import SV_CMA_BB


# covtype can be loaded from sklearn
# The German credit data is read from a local CSV instead: it is expected at
# ``data/German_credit.csv`` under the parent of the directory holding this file.
CREDIT_PATH = Path(__file__).parent.parent.resolve() / Path('data/German_credit.csv')


def ravel(w, log_alpha):
    """Pack ``w`` and ``log_alpha`` into a single flat parameter array.

    ``log_alpha`` is given a trailing axis of length one and stacked after
    ``w`` with ``jnp.hstack``, so it becomes the last entry of the result.
    """
    log_alpha_column = jnp.expand_dims(log_alpha, axis=-1)
    return jnp.hstack((w, log_alpha_column))


def unravel(params):
    """Split flat parameters into ``(w, log_alpha)``; the inverse of ``ravel``.

    Args:
        params: Array of shape ``(dim,)`` or ``(n, dim)``. The last entry
            along the final axis is ``log_alpha``; everything before it is ``w``.

    Returns:
        Tuple ``(w, log_alpha)``. For 1-D input ``w`` has shape ``(dim - 1,)``
        and ``log_alpha`` is a scalar; for 2-D input ``w`` has shape
        ``(n, dim - 1)`` and ``log_alpha`` has shape ``(n,)``.

    Raises:
        ValueError: If ``params`` is neither 1-D nor 2-D (previously this
            case silently returned ``None``).
    """
    if params.ndim not in (1, 2):
        raise ValueError(f"params must be 1-D or 2-D, but has {params.ndim} dimensions")
    # Index the last axis directly rather than squeezing: squeezing turned the
    # log_alpha of a single-row batch (n == 1) into a scalar, which no longer
    # lines up with the (1, dim - 1) weights.
    return params[..., :-1], params[..., -1]


class SVGDLogisticRegression(Benchmark):
    def __init__(
        self,
        dataset: Literal["covtype", "credit", "spam"] = "covtype",
        batch_size: int = 128,
        a0: float = 1.,
        b0: float = 1e-2
    ) -> None:
        """Load one dataset, standardise it and set up the benchmark.

        Args:
            dataset: ``"covtype"`` (scikit-learn loader), ``"spam"`` (UCI
                repository dataset with id 94) or ``"credit"`` (local CSV).
            batch_size: Default minibatch size; also determines ``n_batches``.
            a0: First parameter of the gamma distribution used for ``alpha``.
            b0: Rate of that gamma distribution (it is used as ``scale=1 / b0``).

        Raises:
            ValueError: If ``dataset`` is not one of the three known names.
        """
        if dataset == "covtype":
            splits = self.load_covtype_data()
        elif dataset == "spam":
            splits = self.load_uci(94)
        elif dataset == "credit":
            splits = self.load_credit()
        else:
            raise ValueError(f"Dataset name must be 'covtype', 'credit', or 'spam', but is {dataset}")
        (self.X_train, self.X_val, self.X_test,
         self.y_train, self.y_val, self.y_test) = splits

        # Fit the scaler on the training split only, then apply it to every split.
        scaler = StandardScaler().fit(self.X_train)
        self.X_train = scaler.transform(self.X_train)
        self.X_val = scaler.transform(self.X_val)
        self.X_test = scaler.transform(self.X_test)

        self.n_train, self.n_features = self.X_train.shape
        self.batch_size = batch_size
        self.n_batches = self.n_train // batch_size
        self.a0 = a0
        self.b0 = b0
        self.max_w = 1e3
        # One extra dimension on top of the feature weights holds log_alpha.
        super().__init__(lb=-100, ub=100., dim=self.n_features+1, fglob=1., name=dataset)

    @staticmethod
    def load_covtype_data():
        """Fetch the covertype data via scikit-learn and split it.

        Labels are binarised: 1 where the original target equals 2, else 0.

        Returns:
            ``(X_train, X_val, X_test, y_train, y_val, y_test)``. The test
            split is 20% of all rows; the validation split is 12.5% of the
            remaining rows. Both splits use ``random_state=42``.
        """
        from sklearn.datasets import fetch_covtype

        covtype = fetch_covtype()
        features = covtype.data
        labels = (covtype.target == 2).astype(int)

        # Hold out the test rows first, then carve validation out of the rest.
        X_rest, X_test, y_rest, y_test = train_test_split(features, labels, test_size=0.2, random_state=42)
        X_train, X_val, y_train, y_val = train_test_split(X_rest, y_rest, test_size=0.125, random_state=42)
        return X_train, X_val, X_test, y_train, y_val, y_test

    @staticmethod
    def load_credit():
        """Read the credit CSV at ``CREDIT_PATH`` and split it.

        The last column is taken as the label, all other columns as features.

        Returns:
            ``(X_train, X_val, X_test, y_train, y_val, y_test)``. The test
            split is 20% of all rows; the validation split is 12.5% of the
            remaining rows. Both splits use ``random_state=42``.
        """
        table = read_csv(CREDIT_PATH).values
        features = table[:, :-1]
        labels = table[:, -1].flatten()

        # Hold out the test rows first, then carve validation out of the rest.
        X_rest, X_test, y_rest, y_test = train_test_split(features, labels, test_size=0.2, random_state=42)
        X_train, X_val, y_train, y_val = train_test_split(X_rest, y_rest, test_size=0.125, random_state=42)
        return X_train, X_val, X_test, y_train, y_val, y_test

    @staticmethod
    def load_uci(uci_id: int):
        """Fetch a dataset from the UCI repository by id and split it.

        Args:
            uci_id: Identifier passed to ``fetch_ucirepo``.

        Returns:
            ``(X_train, X_val, X_test, y_train, y_val, y_test)``. The test
            split is 20% of all rows; the validation split is 12.5% of the
            remaining rows. Both splits use ``random_state=42``.
        """
        repo_data = fetch_ucirepo(id=uci_id).data
        features = repo_data.features.to_numpy()
        labels = repo_data.targets.to_numpy().flatten()

        # Hold out the test rows first, then carve validation out of the rest.
        X_rest, X_test, y_rest, y_test = train_test_split(features, labels, test_size=0.2, random_state=42)
        X_train, X_val, y_train, y_val = train_test_split(X_rest, y_rest, test_size=0.125, random_state=42)
        return X_train, X_val, X_test, y_train, y_val, y_test

    def get_batches(self, rng: chex.PRNGKey, x: chex.Array, y: chex.Array, n_steps: Optional[int] = None, batch_size: Optional[int] = None):
        """Yield random minibatches of ``x`` and ``y``.

        NOTE: this can absolutely not be called inside some jitted
         code because then the same batch will be returned over and over.

        Args:
            rng: PRNG key used to draw all batch indices up front.
            x: Inputs; must have more dimensions than ``y``.
            y: Labels with the same length as ``x``.
            n_steps: Number of batches to yield; defaults to ``2 * batch_size``.
            batch_size: Rows per batch; defaults to ``self.batch_size``.

        Yields:
            ``(x_batch, y_batch)`` pairs gathered along axis 0 with the same
            row indices for both arrays.
        """
        assert len(x) == len(y)
        assert x.ndim > y.ndim
        batch_size = self.batch_size if batch_size is None else batch_size
        n_steps = batch_size * 2 if n_steps is None else n_steps

        n = len(x)
        # Use the resolved ``batch_size``: the index draw previously used
        # ``self.batch_size`` and therefore ignored the argument.
        idxs = jax.random.choice(rng, n, shape=(n_steps, batch_size))
        for idx in idxs:
            yield jnp.take(x, idx, axis=0), jnp.take(y, idx, axis=0)

    def sample_from_prior(self, rng, num=100):
        """Draw ``num`` initial ``(w, log_alpha)`` pairs.

        ``alpha`` is a gamma draw with parameter ``a0`` divided by ``b0``;
        ``w`` is standard normal with shape ``(num, n_features)``.

        Returns:
            ``(w, log_alpha)`` where ``log_alpha`` has shape ``(num,)``.
        """
        key_alpha, key_w = jax.random.split(rng)
        gamma_draws = jax.random.gamma(key_alpha, self.a0, shape=(num,))
        alpha = gamma_draws / self.b0
        weights = jax.random.normal(key_w, shape=(num, self.n_features))
        return weights, jnp.log(alpha)

    def prior_logp(self, w, log_alpha):
        """Sum the log prior density over all entries of ``w`` and ``alpha``.

        ``alpha = exp(log_alpha)`` is clipped to ``[1e-6, 1e6]``. It is scored
        under a gamma density (``a0``, ``scale=1 / b0``), and ``w`` under a
        zero-mean normal with standard deviation ``1 / sqrt(alpha)``.

        ``w`` has shape ``(num_features,)`` or ``(n, num_features)``;
        ``log_alpha`` correspondingly has shape ``()`` or ``(n,)``.

        Raises:
            ValueError: If ``w`` is neither 1-D nor 2-D.
        """
        if log_alpha.ndim == 0:
            assert w.ndim == 1
        elif log_alpha.ndim == 1:
            assert log_alpha.shape[0] == w.shape[0]

        alpha = jnp.clip(jnp.exp(log_alpha), 1e-6, 1e6)
        gamma_term = jnp.sum(stats.gamma.logpdf(alpha, self.a0, scale=1 / self.b0))

        if w.ndim == 1:
            inv_std = jnp.maximum(jnp.sqrt(alpha), 1e-6)
            normal_term = jnp.sum(stats.norm.logpdf(w, scale=1 / inv_std))
        elif w.ndim == 2:
            # Pair every row of w with its own alpha.
            def row_logpdf(w_row, alpha_row):
                return stats.norm.logpdf(w_row, scale=1 / jnp.sqrt(alpha_row))

            normal_term = jnp.sum(jax.vmap(row_logpdf)(w, alpha))
        else:
            raise ValueError

        return gamma_term + normal_term
    
    def loglikelihood(self, y, x, w):
        """Return the summed log-probability of the labels ``y`` given ``x`` and ``w``.

        y and x are data batches. w is a single parameter
        array of shape (num_features,). The per-example probabilities come
        from ``compute_probs`` and are therefore clipped from below at 1e-6
        for numerical stability before the log is taken.
        """
        clipped_probs = self.compute_probs(y, x, w)
        return jnp.sum(jnp.log(clipped_probs))

    def compute_probs(self, y, x, w):
        """Return the per-example probability assigned to the observed label.

        y and x are data batches. w is a single parameter
        array of shape (num_features,). Labels are mapped through
        ``(y - 1/2) * 2`` and cast to int32 (so 0 -> -1 and 1 -> 1), and the
        sigmoid of ``label * (x @ w)`` is clipped from below at 1e-6.
        """
        signed_labels = ((y - 1 / 2) * 2).astype(jnp.int32)
        margins = (x @ w) * signed_labels
        return jnp.maximum(special.expit(margins), 1e-6)

    @partial(jax.jit, static_argnums=(0,))
    def compute_accuracy(self, params, x, y):
        """Return the fraction of examples whose observed label gets probability > 0.5.

        ``params`` is a single flat parameter vector; only its weight part is used.
        """
        weights, _ = unravel(params)
        is_correct = self.compute_probs(y, x, weights) > 0.5
        return jnp.mean(is_correct)

    @partial(jax.jit, static_argnums=(0,))
    def minibatch_logp(self, params, x, y):
        """Evaluate the unnormalized target log-density on one minibatch.

        ``params`` are raveled (flat) parameters with shape (num_features+1,)
        or shape (n, num_features+1); y, x are minibatches of data.

        Returns:
            ``prior_logp / n_train + loglikelihood``. For 2-D ``params`` the
            likelihood term is the mean over the ``n`` rows.

        Raises:
            ValueError: If the unraveled weights are neither 1-D nor 2-D.
        """
        assert len(x) == len(y)
        assert x.ndim > y.ndim

        w, log_alpha = unravel(params)
        prior_term = self.prior_logp(w, log_alpha)

        if w.ndim == 1:
            data_term = self.loglikelihood(y, x, w)
        elif w.ndim == 2:
            # Score every parameter row on the same batch, then average.
            per_row = jax.vmap(self.loglikelihood, in_axes=(None, None, 0))(y, x, w)
            data_term = jnp.mean(per_row)
        else:
            raise ValueError

        # The prior is divided by the training-set size; the likelihood is the
        # plain sum over the minibatch (no rescaling by the batch size).
        return prior_term / self.n_train + data_term

    @partial(jax.jit, static_argnums=0)
    def compute_average_nll(self, params: jnp.ndarray, X: jnp.ndarray, y: jnp.ndarray) -> float:
        """Return the negative log-likelihood averaged over the rows of ``X``.

        ``params`` is a single flat parameter vector; only its weight part is used.
        """
        weights, _ = unravel(params)
        n_examples = X.shape[0]
        return -self.loglikelihood(y, X, weights) / n_examples

    def train(
        self,
        rng,
        strategy,
        device,
        mode: Literal["test", "val"] = "val",
        n_iter: Optional[int] = None,
        n_val: int = 40,
        progress_bar: bool = False
    ):
        """Run ``strategy`` on random minibatches and periodically evaluate it.

        Args:
            rng: PRNG key for batch sampling, initialisation and ``ask`` calls.
            strategy: Search strategy exposing ``npop``, ``default_params``,
                ``initialize``, ``ask`` and ``tell``.
            device: Not used by this method.
            mode: ``"test"`` evaluates on the test split; anything else on
                the validation split.
            n_iter: Number of training steps; falls back to ``self.n_batches``
                when ``None`` or 0.
            n_val: Target number of evaluations; one is run every
                ``n_iter // n_val`` steps.
            progress_bar: Show a tqdm progress bar.

        Returns:
            Dict with keys ``"accuracy"``, ``"test_logp"``, ``"test_nll"``
            (means over particles, one entry per evaluation), ``"samples"``
            (particles at each evaluation) and ``"times"`` (elapsed seconds;
            starts with an extra leading ``0.``).

        Raises:
            ValueError: If ``n_val`` is not positive (this used to surface as
                a ``ZeroDivisionError`` for ``n_val == 0``).
        """
        n_iter = n_iter if n_iter else self.n_batches
        if n_val <= 0:
            raise ValueError(f"n_val must be positive, but is {n_val}")
        assert n_iter > n_val
        # Loop-invariant; at least 1 because n_iter > n_val >= 1.
        eval_every = n_iter // n_val

        # Get data
        if mode == "test":
            self.eval_data = self.X_test
            self.eval_labels = self.y_test
        else:
            self.eval_data = self.X_val
            self.eval_labels = self.y_val
        # The third key is unused; the three-way split is kept so the random
        # stream (and hence results for a given seed) stays unchanged.
        rng, rng_train, _ = jax.random.split(rng, 3)
        train_batches = self.get_batches(rng_train, self.X_train, self.y_train, n_iter)

        # Get strategy
        rng, rng_init1, rng_init2 = jax.random.split(rng, 3)
        init_particles = ravel(*self.sample_from_prior(rng_init1, strategy.npop))
        es_params = strategy.default_params.replace(clip_min=-self.max_w, clip_max=self.max_w)
        es_state = strategy.initialize(rng_init2, es_params).replace(particles=init_particles)
        if hasattr(es_state, "mean"):
            es_state = es_state.replace(mean=init_particles)
        shaper = FitnessShaper(centered_rank=True)

        uses_gradient = isinstance(strategy, OG_SVGD)
        uses_shaping = isinstance(strategy, (MC_SVGD, OpenES, ParallelOpenES))

        test_acc_hist, test_logp_hist, test_nll_hist = [], [], []
        start = time.time()
        times = [0.]
        samples = []
        for i, batch in tqdm(enumerate(train_batches), total=n_iter, disable=not progress_bar):
            rng, rng_a = jax.random.split(rng)
            w, es_state = strategy.ask(rng_a, es_state, es_params)

            if uses_gradient:
                # Gradient-based SVGD consumes grad(log p) for every candidate.
                fitness = jax.vmap(jax.grad(lambda x: self.minibatch_logp(x, *batch)))(w)
                # NOTE(review): NaN gradient entries keep the original -1e6
                # replacement; a neutral 0 may be preferable — confirm.
                nan_fill = -1e6
            else:
                # Fitness is the negated log-density, i.e. lower is better.
                fitness = jax.vmap(lambda x: -self.minibatch_logp(x, *batch))(w)
                # A NaN must therefore become a large (bad) value. The former
                # -1e6 made numerically broken candidates look like the best.
                nan_fill = 1e6

            fitness = jnp.nan_to_num(fitness, nan=nan_fill)  # nan can come from numerical errors
            shaped_fitness = shaper.apply(w, fitness) if uses_shaping else fitness
            es_state = strategy.tell(w, shaped_fitness, es_state, es_params)

            if i % eval_every == 0:
                particles = es_state.particles
                test_logp = jax.vmap(self.minibatch_logp, (0, None, None))(particles, self.eval_data, self.eval_labels)
                test_nll = jax.vmap(self.compute_average_nll, (0, None, None))(particles, self.eval_data, self.eval_labels)
                test_acc = jax.vmap(self.compute_accuracy, (0, None, None))(particles, self.eval_data, self.eval_labels)
                print(i, test_acc.mean(), test_logp.mean(), (-fitness).mean())
                test_acc_hist.append(test_acc.mean())
                test_logp_hist.append(test_logp.mean())
                test_nll_hist.append(test_nll.mean())
                elapsed = time.time() - start
                times.append(elapsed)
                samples.append(particles)

        return {
            "accuracy": test_acc_hist,
            "test_logp": test_logp_hist,
            "test_nll": test_nll_hist,
            "samples": samples,
            "times": times
        }


def run_cfg_cma(
    key: chex.PRNGKey,
    npart: int,
    subpopsize: int,
    er: float,
    sig: float,
    kw: float,
    nrep: int,
    num_generations: int,
    bench: SVGDLogisticRegression,
    device: chex.Device
):
    """Wrapper to run the logistic-regression benchmark with the SV-CMA-ES algorithm.

    Args:
        key: Random number generator key; split into one key per repetition.
        npart: Number of particles / ES populations.
        subpopsize: Size of each subpopulation.
        er: Elite ratio for CMA-ES.
        sig: Initial sigma for CMA-ES.
        kw: Kernel bandwidth of the RBF kernel.
        nrep: Number of repetitions of the experiment; they are run through ``jax.vmap``.
        num_generations: Number of generations (training steps) per repetition.
        bench: The benchmark instance whose ``train`` method is called.
        device: Device for experiments. Deprecated.

    Returns:
        The output of ``bench.train`` in ``"val"`` mode, mapped over the repetitions.
    """
    strategy = SV_CMA_BB(npart, subpopsize, RBF(kw), num_dims=bench.dim, elite_ratio=er, sigma_init=sig, num_iters=num_generations)

    def single_run(seed):
        return bench.train(seed, strategy, device=device, mode="val", n_iter=num_generations)

    rep_keys = jax.random.split(key, nrep)
    return jax.vmap(single_run)(rep_keys)


def run_cfg_oes(key, npart, subpopsize, lr, sig, kw, nrep, num_generations, bench, device):
    """Run the benchmark with an ``MC_SVGD`` strategy for ``nrep`` repetitions.

    ``key`` is split into one key per repetition and ``bench.train`` is
    mapped over them with ``jax.vmap`` in ``"val"`` mode. ``lr`` is the
    initial learning rate, ``sig`` the initial sigma and ``kw`` the RBF
    kernel bandwidth.
    """
    strategy = MC_SVGD(npart, subpopsize, kernel=RBF(kw), num_iters=num_generations, num_dims=bench.dim, sigma_init=sig, lrate_init=lr)

    def single_run(seed):
        return bench.train(seed, strategy, device=device, mode="val", n_iter=num_generations)

    rep_keys = jax.random.split(key, nrep)
    return jax.vmap(single_run)(rep_keys)


def run_cfg_og(key, npart, subpopsize, lr, kw, nrep, num_generations, bench, device):
    """Run the benchmark with an ``OG_SVGD`` strategy for ``nrep`` repetitions.

    The strategy is built with ``npart * subpopsize`` as its first argument
    and the ``"adam"`` optimizer. ``key`` is split into one key per
    repetition and ``bench.train`` is mapped over them with ``jax.vmap`` in
    ``"val"`` mode. ``lr`` is the initial learning rate and ``kw`` the RBF
    kernel bandwidth.
    """
    total_particles = npart * subpopsize
    strategy = OG_SVGD(total_particles, RBF(kw), num_iters=num_generations, num_dims=bench.dim, opt_name="adam", lrate_init=lr)

    def single_run(seed):
        return bench.train(seed, strategy, device=device, mode="val", n_iter=num_generations)

    rep_keys = jax.random.split(key, nrep)
    return jax.vmap(single_run)(rep_keys)


def run_cfg_gf(key, npart, subpopsize, sig, lr, kw, nrep, num_generations, bench, device):
    """Run the benchmark with a ``GF_SVGD`` strategy for ``nrep`` repetitions.

    The strategy is built with ``npart * subpopsize`` as its first argument
    and the ``"adam"`` optimizer. ``key`` is split into one key per
    repetition and ``bench.train`` is mapped over them with ``jax.vmap`` in
    ``"val"`` mode. ``sig`` is the initial sigma, ``lr`` the initial
    learning rate and ``kw`` the RBF kernel bandwidth.
    """
    total_particles = npart * subpopsize
    strategy = GF_SVGD(total_particles, RBF(kw), sigma_init=sig, num_iters=num_generations, num_dims=bench.dim, opt_name="adam", lrate_init=lr)

    def single_run(seed):
        return bench.train(seed, strategy, device=device, mode="val", n_iter=num_generations)

    rep_keys = jax.random.split(key, nrep)
    return jax.vmap(single_run)(rep_keys)
