from functools import partial
import os
import pickle
from copy import deepcopy

import numpy as np
import jax
import jax.numpy as jnp
import optax
import flax
import flax.linen as nn
from flax.training.train_state import TrainState


class NeuralNetwork(nn.Module):
    """Plain MLP: ``num_hidden_layers`` ReLU layers of width ``hidden_dim``, then a linear head.

    Every ``nn.Dense`` kernel, including the output layer's, is initialised with
    an orthogonal initializer of scale ``sqrt(2)``. ``input_dim`` is stored as a
    field but is not read by ``__call__``.
    """
    input_dim: int
    output_dim: int
    hidden_dim: int = 256
    num_hidden_layers: int = 4

    @nn.compact
    def __call__(self, x: jnp.ndarray):
        """Run the MLP on ``x`` and return the final layer's un-activated output."""
        kernel_init = nn.initializers.orthogonal(scale=jnp.sqrt(2))
        # Hidden widths followed by the output width; layers are created in this order.
        widths = [self.hidden_dim] * self.num_hidden_layers + [self.output_dim]
        final_layer = len(widths) - 1
        for layer_idx, width in enumerate(widths):
            x = nn.Dense(width, kernel_init=kernel_init)(x)
            if layer_idx != final_layer:  # no activation on the output layer
                x = nn.relu(x)
        return x

class Ensemble(nn.Module):
    """Stack of ``num_models`` independent ``NeuralNetwork`` members with bounded log-variance.

    Each member emits ``2 * output_dim`` values, split into a mean and a raw
    log-variance. The log-variance is then softly bounded above by
    ``max_logvar`` and below by ``min_logvar`` (two softplus steps). Both bounds
    are trainable parameters of shape ``(1, output_dim)``.

    Attributes:
        input_dim: Forwarded to every member.
        output_dim: Size of the predicted mean and of the log-variance.
        hidden_dim: Hidden width of every member.
        num_hidden_layers: Number of hidden layers per member.
        num_models: Number of members; ``x`` is mapped over its axis 0, which
            must have this length.
    """
    input_dim: int
    output_dim: int
    hidden_dim: int = 256
    num_hidden_layers: int = 4
    num_models: int = 5

    @nn.compact
    def __call__(self, x: jnp.ndarray, return_maxmin_logvar: bool = False):
        """Evaluate every member on its slice of ``x``.

        Args:
            x: Input whose axis 0 indexes the members (length ``num_models``).
            return_maxmin_logvar: If True, also return the two log-variance bounds.

        Returns:
            ``(mean, logvar)``, or ``(mean, logvar, max_logvar, min_logvar)`` when
            ``return_maxmin_logvar`` is True.
        """
        # variable_axes: each member has its own parameters;
        # split_rngs: each member gets its own initialisation RNG.
        StackedNetwork = nn.vmap(
            NeuralNetwork,
            variable_axes={"params": 0},
            split_rngs={"params": True},
            in_axes=0,
            axis_size=self.num_models,
        )
        members = StackedNetwork(
            input_dim=self.input_dim,
            output_dim=2 * self.output_dim,
            hidden_dim=self.hidden_dim,
            num_hidden_layers=self.num_hidden_layers,
        )
        raw = members(x)

        # Created with self.param, so both bounds are model parameters (returned by
        # init alongside the network weights and differentiable like them).
        max_logvar = self.param("max_logvar", nn.initializers.constant(0.5), (1, self.output_dim))
        min_logvar = self.param("min_logvar", nn.initializers.constant(-10.0), (1, self.output_dim))

        mean = raw[..., : self.output_dim]
        unbounded_logvar = raw[..., self.output_dim :]
        # Smooth upper bound first, then smooth lower bound.
        capped_logvar = max_logvar - nn.softplus(max_logvar - unbounded_logvar)
        logvar = min_logvar + nn.softplus(capped_logvar - min_logvar)

        if not return_maxmin_logvar:
            return mean, logvar
        return mean, logvar, max_logvar, min_logvar


class ProbabilisticEnsemble:
    """Training and sampling wrapper around the ``Ensemble`` Flax module.

    Parameters live in a Flax ``TrainState`` with an AdamW optimizer (weight
    decay 0.1). Inputs are optionally standardised. Every member trains on its
    own bootstrap resample with holdout-based early stopping. Predictions are
    drawn from randomly chosen "elite" members, i.e. those with the lowest
    holdout MSE.

    Call :meth:`build` before :meth:`sample`, :meth:`fit` or :meth:`get_params`.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        ensemble_size=5,
        arch=None,
        learning_rate=0.001,
        num_elites=2,
        normalize_inputs=True,
    ):
        """Record hyper-parameters; network parameters are created by :meth:`build`.

        Args:
            input_dim: Size of the last input axis.
            output_dim: Size of the predicted mean (the network emits a mean and a
                log-variance of this size each).
            ensemble_size: Number of ensemble members.
            arch: Hidden-layer widths, default ``[200, 200, 200, 200]``. Only
                ``arch[0]`` (width of every hidden layer) and ``len(arch)``
                (number of hidden layers) are used by :meth:`build`.
            learning_rate: AdamW learning rate.
            num_elites: Number of members kept in ``self.elites`` after each
                :meth:`fit` epoch.
            normalize_inputs: Standardise inputs with the statistics computed by
                :meth:`fit_input_stats`.
        """
        if arch is None:
            # Fresh list per instance instead of one mutable default shared by all.
            arch = [200, 200, 200, 200]
        self.ensemble_size = ensemble_size
        self.input_dim = input_dim
        self.output_dim = output_dim  # the network emits 2 * output_dim (mean, logvar)
        self.arch = arch
        self.num_elites = num_elites
        # Until the first holdout evaluation in ``fit`` every member is an elite.
        self.elites = np.arange(self.ensemble_size)
        self.normalize_inputs = normalize_inputs
        self.learning_rate = learning_rate

    def build(self, key):
        """Create the Flax module, its ``TrainState`` and identity input statistics.

        Args:
            key: ``jax.random`` key. It is split three ways: one key is returned,
                one is stored as ``self.key`` for sampling, and one initialises
                the parameters.

        Returns:
            The first of the three split keys, for the caller's further use.
        """
        key, self.key, subkey = jax.random.split(key, 3)

        self.ensemble = Ensemble(self.input_dim, self.output_dim, self.arch[0], len(self.arch), self.ensemble_size)
        variables = self.ensemble.init(
            {"params": subkey},
            # Dummy input laid out as (members, batch, input_dim).
            jnp.ones((self.ensemble_size, 1, self.input_dim)),
        )
        self.ensemble_state = TrainState.create(
            apply_fn=self.ensemble.apply,
            params=variables,
            tx=optax.adamw(self.learning_rate, weight_decay=0.1),
        )
        # Jit-compiled ``apply`` for direct callers; the static methods below
        # already invoke it from inside their own ``jax.jit`` traces.
        self.ensemble.apply = jax.jit(self.ensemble.apply, static_argnames=["return_maxmin_logvar"])
        # Snapshot of the freshly initialised optimizer state (see ``reset_opt_state``).
        self.opt_state = deepcopy(self.ensemble_state.opt_state)

        # Identity normalisation until ``fit_input_stats`` is called.
        self.inputs_mu = jnp.zeros((1, self.input_dim))
        self.inputs_sigma = jnp.ones((1, self.input_dim))

        return key

    def get_params(self):
        """Return the train state, input statistics and elite indices as a dict."""
        return {"ensemble_state": self.ensemble_state,
                "inputs_mu": self.inputs_mu,
                "inputs_sigma": self.inputs_sigma,
                "elites": self.elites
        }

    @staticmethod
    @partial(jax.jit, static_argnames=["ensemble", "normalize_inputs", "deterministic", "return_dist"])
    def forward(ensemble, ensemble_state, input, inputs_mu, inputs_sigma, normalize_inputs, deterministic, return_dist, key):
        """Run every member on ``input`` (jit-compiled; module and flags are static).

        Rank-1 inputs become ``(1, 1, d)`` and rank-2 inputs ``(1, batch, d)``;
        both are then repeated ``ensemble.num_models`` times along axis 0.
        Rank-3 inputs are used unchanged.

        Returns:
            deterministic, return_dist:  ``(mean, logvar)``
            deterministic only:          ``mean``
            stochastic, return_dist:     ``(new_key, samples, mean, logvar)``
            stochastic only:             ``(new_key, samples)``

            Here ``samples = mean + exp(logvar / 2) * N(0, 1)``. ``key`` is only
            used when ``deterministic`` is False and may be ``None`` otherwise.
        """
        dim = len(input.shape)
        if dim < 3:
            input = input.reshape(1, *input.shape)
            if dim == 1:
                input = input.reshape(1, *input.shape)
            input = input.repeat(ensemble.num_models, axis=0)

        # input normalization
        if normalize_inputs:
            h = (input - inputs_mu) / inputs_sigma
        else:
            h = input

        mean, logvar = ensemble.apply(ensemble_state.params, h)

        if deterministic:
            if return_dist:
                return mean, logvar
            else:
                return mean
        else:
            key, subkey = jax.random.split(key)
            std = jnp.exp(0.5*logvar)      # exp(0.5*logvar) = sqrt(exp(logvar))
            samples = mean + std * jax.random.normal(subkey, std.shape)
            if return_dist:
                return key, samples, mean, logvar
            else:
                return key, samples

    def sample(self, input, deterministic=False):
        """Predict each batch row with a randomly chosen elite member.

        Args:
            input: Model input, handled as in :meth:`forward`.
            deterministic: If True, return the chosen members' means. Otherwise
                return Gaussian samples drawn with ``self.key``; the advanced key
                is stored back in ``self.key``.

        Returns:
            ``(prediction, variance, uncertainty)``:

            - ``prediction`` and ``variance`` are indexed per batch row from the
              chosen member.
            - ``uncertainty[b]`` is the sum over output dimensions of the standard
              deviation of the equally weighted mixture of *all* members, not
              only the elites.
        """
        if deterministic:
            means, logvar = ProbabilisticEnsemble.forward(
                                                    self.ensemble,
                                                    self.ensemble_state,
                                                    input,
                                                    self.inputs_mu,
                                                    self.inputs_sigma,
                                                    normalize_inputs=self.normalize_inputs,
                                                    deterministic=True,
                                                    return_dist=True,
                                                    key=self.key)
        else:
            self.key, samples, means, logvar = ProbabilisticEnsemble.forward(
                                                    self.ensemble,
                                                    self.ensemble_state,
                                                    input,
                                                    self.inputs_mu,
                                                    self.inputs_sigma,
                                                    normalize_inputs=self.normalize_inputs,
                                                    deterministic=False,
                                                    return_dist=True,
                                                    key=self.key)
            samples = jax.device_get(samples)
        means = jax.device_get(means)
        logvar = jax.device_get(logvar)

        variances = np.exp(logvar)
        _, batch_size, _ = means.shape
        batch_inds = np.arange(0, batch_size)
        # Member used for each row; drawn with NumPy's global RNG, not ``self.key``.
        model_inds = np.random.choice(self.elites, size=batch_size)

        # Ensemble standard deviation / variance (Lakshminarayanan et al., 2017):
        # Var = E[mu^2 + sigma^2] - E[mu]^2 over members.
        mean_ensemble = means.mean(axis=0)
        var_ensemble = (means**2 + variances).mean(axis=0) - mean_ensemble**2
        std_ensemble = np.sqrt(var_ensemble + 1e-12)
        uncertainties = std_ensemble.sum(-1)

        predictions = means if deterministic else samples
        return predictions[model_inds, batch_inds], variances[model_inds, batch_inds], uncertainties

    @staticmethod
    @partial(jax.jit, static_argnames=["ensemble", "normalize_inputs"])
    def compute_mse_losses(ensemble, ensemble_state, x, y, inputs_mu, inputs_sigma, normalize_inputs):
        """Return the per-member mean squared error of the predicted means on ``(x, y)``."""
        mean = ProbabilisticEnsemble.forward(
                                    ensemble,
                                    ensemble_state,
                                    x,
                                    inputs_mu,
                                    inputs_sigma,
                                    normalize_inputs=normalize_inputs,
                                    deterministic=True,
                                    return_dist=False,
                                    key=None
        )
        mse_losses = (mean - y) ** 2  # (num_models, batch_size, output_dim)
        return mse_losses.mean(-1).mean(-1)  # (num_models)

    def fit_input_stats(self, data):
        """Set ``inputs_mu`` / ``inputs_sigma`` to the mean / std of ``data`` over axis 0.

        Statistics keep the reduced axis (``keepdims=True``). Standard deviations
        below 1e-12 are replaced by 1.0 so constant features are not divided by ~0.

        Args:
            data: Array-like of inputs. It is converted to a NumPy array first so
                that immutable JAX arrays are accepted too.
        """
        data = np.asarray(data)
        mu = np.mean(data, axis=0, keepdims=True)
        sigma = np.std(data, axis=0, keepdims=True)
        sigma[sigma < 1e-12] = 1.0
        self.inputs_mu = mu
        self.inputs_sigma = sigma

    @staticmethod
    @partial(jax.jit, static_argnames=["ensemble", "normalize_inputs"])
    def update(ensemble, ensemble_state, x, y, inputs_mu, inputs_sigma, normalize_inputs):
        """Take one optimizer step and return ``(new_state, loss)``.

        The loss is the Gaussian negative log-likelihood without its constant
        term, ``0.5 * (logvar + (mean - y)^2 / exp(logvar))``, averaged over all
        elements. A penalty ``0.01 * (sum(max_logvar) - sum(min_logvar))`` is
        added, which favours tighter log-variance bounds.
        """
        if normalize_inputs:
            h = (x - inputs_mu) / inputs_sigma
        else:
            h = x

        def loss_fn(params):
            mean, logvar, max_logvar, min_logvar = ensemble.apply(params, h, return_maxmin_logvar=True)

            inv_var = jnp.exp(-logvar)
            loss = 0.5 * (logvar + (mean - y)**2 * inv_var)
            total_losses = loss.mean()
            total_losses += 0.01 * (max_logvar.sum() - min_logvar.sum())
            return total_losses

        loss, grad = jax.value_and_grad(loss_fn)(ensemble_state.params)
        ensemble_state = ensemble_state.apply_gradients(grads=grad)
        return ensemble_state, loss

    @staticmethod
    @jax.jit
    def reset_opt_state(ensemble_state, opt_state):
        """Return ``ensemble_state`` with its optimizer state replaced by ``opt_state``."""
        ensemble_state = ensemble_state.replace(opt_state=opt_state)
        return ensemble_state

    def fit(
        self,
        X,
        Y,
        batch_size=256,
        holdout_ratio=0.1,
        max_holdout_size=5000,
        max_epochs_no_improvement=5,
        max_epochs=200,
    ):
        """Train all members on bootstrap resamples with holdout early stopping.

        Holdout split:

        - A random ``holdout_ratio`` fraction of the rows, capped at
          ``max_holdout_size``, is held out for evaluation.
        - When that fraction rounds to zero rows, the training rows are scored
          instead, so holdout losses are not NaN.

        Training:

        - Each member trains on its own bootstrap resample (drawn with
          replacement) of the remaining rows.
        - After each epoch, ``self.elites`` becomes the ``num_elites`` members
          with the lowest holdout MSE.

        Stopping: training ends once no member has improved its best holdout
        MSE by more than 1% (relative) for ``max_epochs_no_improvement``
        consecutive epochs, or after ``max_epochs`` epochs.

        Args:
            X: Inputs indexed by row along axis 0 (presumably ``(N, input_dim)``).
            Y: Targets aligned with ``X`` (presumably ``(N, output_dim)``).
            batch_size: Rows per member per optimizer step.
            holdout_ratio: Fraction of rows held out for evaluation.
            max_holdout_size: Upper bound on the number of holdout rows.
            max_epochs_no_improvement: Patience of the early-stopping rule.
            max_epochs: Hard limit on the number of epochs.

        Returns:
            ``(mean training loss over epochs, mean final holdout MSE over members)``.

        Raises:
            ValueError: If ``max_epochs`` or ``max_epochs_no_improvement`` is below
                1, or if no rows remain for training after the holdout split.
        """
        # Without at least one epoch there is no loss to report.
        if max_epochs < 1 or max_epochs_no_improvement < 1:
            raise ValueError("max_epochs and max_epochs_no_improvement must both be >= 1")

        num_holdout = min(int(X.shape[0] * holdout_ratio), max_holdout_size)
        if X.shape[0] - num_holdout < 1:
            raise ValueError("fit() needs at least one training row after the holdout split")

        if self.normalize_inputs:
            self.fit_input_stats(X)

        permutation = np.random.permutation(X.shape[0])
        train_rows, holdout_rows = permutation[num_holdout:], permutation[:num_holdout]
        inputs, targets = X[train_rows], Y[train_rows]
        if num_holdout > 0:
            holdout_inputs, holdout_targets = X[holdout_rows], Y[holdout_rows]
        else:
            # An empty holdout would yield NaN losses, breaking elite selection and
            # early stopping, so score on the training rows instead.
            holdout_inputs, holdout_targets = inputs, targets

        # One bootstrap resample (row indices drawn with replacement) per member.
        idxs = np.random.randint(inputs.shape[0], size=[self.ensemble_size, inputs.shape[0]])
        num_batches = int(np.ceil(idxs.shape[-1] / batch_size))

        def shuffle_rows(arr):
            # Independently permute each member's row of indices.
            order = np.argsort(np.random.uniform(size=arr.shape), axis=-1)
            return arr[np.arange(arr.shape[0])[:, None], order]

        num_epochs_no_improvement = 0
        epoch = 0
        best_holdout_losses = None
        mean_losses = []
        while num_epochs_no_improvement < max_epochs_no_improvement and epoch < max_epochs:
            mean_loss = 0.0
            for batch_num in range(num_batches):
                batch_idxs = idxs[:, batch_num * batch_size : (batch_num + 1) * batch_size]
                batch_x, batch_y = inputs[batch_idxs], targets[batch_idxs]

                self.ensemble_state, loss = ProbabilisticEnsemble.update(self.ensemble, self.ensemble_state, batch_x, batch_y, self.inputs_mu, self.inputs_sigma, self.normalize_inputs)
                mean_loss += loss

            mean_loss /= num_batches
            mean_losses.append(mean_loss)

            idxs = shuffle_rows(idxs)

            # Single device-to-host transfer for all members' holdout losses.
            holdout_losses = np.asarray(jax.device_get(
                ProbabilisticEnsemble.compute_mse_losses(self.ensemble, self.ensemble_state, holdout_inputs, holdout_targets, self.inputs_mu, self.inputs_sigma, self.normalize_inputs)
            ))

            self.elites = np.argsort(holdout_losses)[: self.num_elites]

            if epoch == 0:
                best_holdout_losses = holdout_losses.copy()
                any_improved = True
            else:
                # "> 1% relative improvement", written multiplicatively: equivalent
                # for non-negative MSE and free of division by a zero best loss.
                improved = (best_holdout_losses - holdout_losses) > 0.01 * best_holdout_losses
                best_holdout_losses = np.where(improved, holdout_losses, best_holdout_losses)
                any_improved = bool(improved.any())
            num_epochs_no_improvement = 0 if any_improved else num_epochs_no_improvement + 1

            epoch += 1

        print("Epoch:", epoch, "Mean loss", mean_loss, "Holdout losses:", ", ".join(["%.4f" % hl for hl in holdout_losses]))
        return np.mean(mean_losses), np.mean(holdout_losses)
