import functools
from typing import Any, Callable, Dict, Sequence, Tuple

import d4rl
import distrax
import gym
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax import linen as nn
from flax.core import FrozenDict
from flax.training import checkpoints, train_state
from tqdm import trange
from utils import Batch, get_flow_action_dataset, target_update

#####################
# Utility Functions #
#####################
# NOTE(review): judging by the names, these look like clamp bounds for a
# log standard deviation. Nothing in this section of the file references
# them; presumably they are used elsewhere — confirm, or remove if dead.
LOG_STD_MIN = -5.
LOG_STD_MAX = 2.


def init_fn(initializer: str, gain: float = 2.0 ** 0.5):
    """Return a Flax kernel initializer selected by name.

    Args:
        initializer: ``"orthogonal"``, ``"glorot_uniform"`` or
            ``"glorot_normal"``. Any other value falls back to LeCun normal.
        gain: Scale passed to the orthogonal initializer (ignored by the
            others). Defaults to sqrt(2). This is a plain Python float, so
            defining the function does not trigger a JAX computation at
            import time.

    Returns:
        The corresponding ``flax.linen.initializers`` callable.
    """
    if initializer == "orthogonal":
        return nn.initializers.orthogonal(gain)
    if initializer == "glorot_uniform":
        return nn.initializers.glorot_uniform()
    if initializer == "glorot_normal":
        return nn.initializers.glorot_normal()
    # Unrecognized names deliberately fall back instead of raising.
    return nn.initializers.lecun_normal()


class MLP(nn.Module):
    """Feed-forward stack of Dense layers with ReLU activations.

    Attributes:
        hidden_dims: Output width of each Dense layer, applied in order.
        init_fn: Kernel initializer shared by every Dense layer.
        activate_final: If True, a ReLU also follows the last Dense layer.
    """
    hidden_dims: Sequence[int] = (256, 256)
    init_fn: Callable = nn.initializers.glorot_uniform()
    activate_final: bool = True

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        last_idx = len(self.hidden_dims) - 1
        for layer_idx, width in enumerate(self.hidden_dims):
            x = nn.Dense(width, kernel_init=self.init_fn)(x)
            # Hidden layers always get a ReLU; the output layer only on request.
            if layer_idx != last_idx or self.activate_final:
                x = nn.relu(x)
        return x


##############
# Flow Model #
##############
def get_flow_layer(flow_layers, input_dim, hidden_dim, mask, initializer):
    """Build ``flow_layers`` coupling layers with alternating masks.

    Even-indexed layers receive ``mask`` and odd-indexed layers receive
    ``1 - mask``. For a 0/1 mask, this makes consecutive layers condition
    on complementary subsets of the input dimensions.

    Returns:
        A list of ``CouplingLayer`` modules in construction order (empty when
        ``flow_layers`` is 0).
    """
    return [
        CouplingLayer(input_dim, hidden_dim,
                      (1 - mask) if layer_idx % 2 else mask,
                      initializer)
        for layer_idx in range(flow_layers)
    ]


class CouplingLayer(nn.Module):
    """Conditional affine coupling layer.

    Dimensions where ``mask == 1`` pass through unchanged. The remaining
    dimensions are scaled and shifted by ``exp(log_s)`` and ``mu``, both
    predicted from the masked input together with ``cond_inputs``.
    """
    input_dim: int
    hidden_dim: int
    mask: jnp.ndarray  # multiplied elementwise with inputs; expected 0/1 (TODO confirm)
    initializer: str = "orthogonal"

    def setup(self):
        # Encodes the conditioning input (no activation on its last layer).
        self.mlp = MLP(hidden_dims=(self.hidden_dim, ) * 2,
                       init_fn=init_fn(self.initializer),
                       activate_final=False)
        # Trunk over [masked inputs, encoded condition].
        self.shared_net = MLP(hidden_dims=(self.hidden_dim, ) * 2,
                              init_fn=init_fn(self.initializer),
                              activate_final=True)
        # Produces mu and log_s, each of size input_dim.
        self.out_net = nn.Dense(self.input_dim * 2)
        # Learned per-dimension bound for the tanh soft-clamp on log_s.
        self.scaling_factor = self.param("scaling_factor",
                                         nn.initializers.zeros,
                                         (self.input_dim, ))

    def __call__(self, inputs, cond_inputs, reverse: bool = False):
        """Apply the coupling transform or its inverse.

        forward pass: z = x * exp(log_s) + mu
        reverse pass: x = (z - mu) * exp(-log_s)

        Args:
            inputs: Values to transform; last axis has size ``input_dim``.
            cond_inputs: Conditioning values, concatenated (after encoding)
                with the masked inputs along the last axis.
            reverse: Python bool; True runs the inverse transform.

        Returns:
            ``(z, log_det)``. ``log_det`` is the log-determinant of the
            Jacobian of the applied direction, summed over the last axis.
        """
        masked_inputs = inputs * self.mask
        cond_features = self.mlp(cond_inputs)
        hidden = self.shared_net(
            jnp.concatenate([masked_inputs, cond_features], axis=-1))
        # jnp.split instead of the deprecated jax.Array.split method.
        mu, log_s = jnp.split(self.out_net(hidden), 2, axis=-1)

        # Soft-clamp log_s into (-s_fac, s_fac) to keep exp(log_s) bounded.
        s_fac = jnp.exp(self.scaling_factor)
        log_s = nn.tanh(log_s / s_fac) * s_fac

        # Only the non-masked dimensions are transformed.
        mu = mu * (1 - self.mask)
        log_s = log_s * (1 - self.mask)

        # `reverse` is a static Python flag, so branch in Python. This avoids
        # computing (and back-propagating through) both directions.
        if reverse:
            z = (inputs - mu) * jnp.exp(-log_s)
            log_det = -log_s.sum(-1)
        else:
            z = inputs * jnp.exp(log_s) + mu
            log_det = log_s.sum(-1)
        return z, log_det


class Flow(nn.Module):
    """A conditional generative model: a stack of coupling layers over a
    standard-normal base distribution.

    Calling the module returns log p(x | y); ``sample`` draws x given y.
    """
    flows: Sequence[nn.Module]  # coupling layers, applied in order for x -> z
    input_dim: int  # dimensionality of x (and of the latent z)
    hidden_dim: int  # not used directly here; kept for interface compatibility
    mask: jnp.ndarray  # not used directly here; each layer carries its own mask
    initializer: str = "orthogonal"  # not used directly here

    def __call__(self, inputs, cond_inputs):
        """Compute the log likelihood log p(inputs | cond_inputs).

        Args:
            inputs: Points x; last axis has size ``input_dim``.
            cond_inputs: Conditioning values y, passed to every layer.

        Returns:
            Log density with shape ``inputs.shape[:-1]``.
        """
        z, y = inputs, cond_inputs
        # One accumulator per sample (all leading axes), not just axis 0.
        log_dets = jnp.zeros(inputs.shape[:-1])

        # forward pass (x ==> z)
        for flow in self.flows:
            z, log_det = flow(z, y, reverse=False)
            log_dets += log_det
        logp_z = jax.scipy.stats.norm.logpdf(z).sum(axis=-1)
        return logp_z + log_dets

    def sample(self, rng, cond_inputs):
        """Draw x ~ p(x | cond_inputs) by inverting the flow.

        Args:
            rng: JAX PRNG key.
            cond_inputs: Conditioning values y; their leading axes set the
                batch shape of the samples.

        Returns:
            Samples with shape ``cond_inputs.shape[:-1] + (input_dim,)``.
        """
        # The key is split and the first sub-key drives the noise, so a given
        # `rng` always produces the same samples; the second sub-key is unused.
        noise_key, _ = jax.random.split(rng, 2)
        # Latent noise lives in x-space (input_dim), which need not equal the
        # conditioning dimension.
        noise_shape = tuple(cond_inputs.shape[:-1]) + (self.input_dim, )
        z = jax.random.normal(noise_key, shape=noise_shape)

        # reverse pass (z ==> x)
        for flow in reversed(self.flows):
            z, _ = flow(z, cond_inputs, reverse=True)
        return z


class Prior:
    """Conditional normalizing-flow model of p(next_action | action).

    A ``Flow`` is fit by maximum likelihood to (action, next_action) pairs
    taken from the D4RL dataset of ``env_name``. ``sample`` then draws next
    actions conditioned on given actions.
    """

    def __init__(self,
                 env_name: str,
                 flow_layers: int = 6,
                 hidden_dim: int = 128,
                 seed: int = 42,
                 lr: float = 3e-4,
                 batch_size: int = 512,
                 epochs: int = 100,
                 initializer: str = "orthogonal"):
        """Build the flow model and its Adam train state.

        Args:
            env_name: Gym/D4RL environment id. The action dimension is read
                from its action space.
            flow_layers: Number of coupling layers.
            hidden_dim: Hidden width of the coupling-layer MLPs.
            seed: Seed for the JAX PRNG used for parameter initialization.
            lr: Adam learning rate.
            batch_size: Training mini-batch size.
            epochs: Maximum number of training epochs.
            initializer: Kernel initializer name for the coupling layers.
        """
        # params
        self.batch_size = batch_size
        self.epochs = epochs

        # d4rl env
        self.env_name = env_name
        self.env = gym.make(env_name)
        input_dim = self.env.action_space.shape[0]
        cond_dim = input_dim  # condition on one-step action

        # random seed
        self.rng = jax.random.PRNGKey(seed)
        self.rng, model_key = jax.random.split(self.rng, 2)

        # create the Flow model; the mask alternates 0, 1, 0, 1, ... over
        # action dimensions
        mask = jnp.arange(0, input_dim, dtype=float) % 2
        flows = get_flow_layer(flow_layers, input_dim, hidden_dim, mask,
                               initializer)
        self.model = Flow(flows, input_dim, hidden_dim, mask, initializer)

        # initialize model parameters
        dummy_inputs = jnp.ones([1, input_dim])
        dummy_cond_inputs = jnp.ones([1, cond_dim])
        model_params = self.model.init(model_key, dummy_inputs,
                                       dummy_cond_inputs)["params"]
        self.model_state = train_state.TrainState.create(
            apply_fn=self.model.apply,
            params=model_params,
            tx=optax.adam(learning_rate=lr))
        # Best state by validation loss; this is what `save` writes.
        self.optimal_state = self.model_state

    @functools.partial(jax.jit, static_argnames=("self", ))
    def _sample(self, params, rng, cond_inputs):
        """Jitted sampler; `self` is a static argument, so the model is fixed
        at trace time."""
        return self.model.apply({"params": params},
                                rng,
                                cond_inputs,
                                method=Flow.sample)

    def sample(self, rng, cond_inputs):
        """Sample next actions conditioned on `cond_inputs`.

        NOTE(review): this uses `self.model_state`, i.e. the last training
        step or the state restored by `load`, not `optimal_state`.
        """
        return self._sample(self.model_state.params, rng, cond_inputs)

    def get_action_data(self):
        """Load (action, next_action) pairs and split them 80/20 train/valid.

        The split uses NumPy's global RNG; it is not controlled by `seed`.

        Returns:
            (train_actions, valid_actions, train_next_actions,
            valid_next_actions)
        """
        env = gym.make(self.env_name)
        dataset = d4rl.qlearning_dataset(env)
        actions, next_actions = get_flow_action_dataset(dataset)
        train_num = int(0.8 * len(actions))
        idx = np.random.permutation(np.arange(len(actions)))
        train_idx = idx[:train_num]
        valid_idx = idx[train_num:]
        train_actions = actions[train_idx]
        valid_actions = actions[valid_idx]
        train_next_actions = next_actions[train_idx]
        valid_next_actions = next_actions[valid_idx]
        return train_actions, valid_actions, train_next_actions, valid_next_actions

    def train(self):
        """Fit the flow by minimizing -log p(next_action | action).

        Validation loss is computed after every epoch. The best state is kept
        in `self.optimal_state`. Training stops early once validation loss
        fails to improve for more than `max_patience` consecutive epochs.
        """

        @jax.jit
        def train_step(model_state, batch_actions, batch_next_actions):
            def loss_fn(params):
                logp = self.model.apply({"params": params}, batch_next_actions,
                                        batch_actions)
                loss = -logp.mean()
                return loss, {"train_loss": loss}

            (_, log_info), grads = jax.value_and_grad(loss_fn, has_aux=True)(
                model_state.params)
            new_model_state = model_state.apply_gradients(grads=grads)
            return new_model_state, log_info

        @jax.jit
        def eval_step(model_state, batch_actions, batch_next_actions):
            logp = self.model.apply({"params": model_state.params},
                                    batch_next_actions, batch_actions)
            return -logp.mean()

        (train_actions, valid_actions, train_next_actions,
         valid_next_actions) = self.get_action_data()
        patience, max_patience, min_valid_loss = 0, 5, np.inf
        valid_batch_size = 1000
        batch_num = int(np.ceil(len(train_actions) / self.batch_size))
        valid_batch_num = int(np.ceil(len(valid_actions) / valid_batch_size))

        for epoch in trange(self.epochs):
            train_losses = []
            shuffled_idxs = np.random.permutation(np.arange(
                len(train_actions)))
            for i in range(batch_num):
                batch_idxs = shuffled_idxs[i * self.batch_size:(i + 1) *
                                           self.batch_size]
                self.model_state, train_log = train_step(
                    self.model_state, train_actions[batch_idxs],
                    train_next_actions[batch_idxs])
                train_losses.append(train_log["train_loss"].item())

            valid_losses = []
            for i in range(valid_batch_num):
                lo, hi = i * valid_batch_size, (i + 1) * valid_batch_size
                valid_loss = eval_step(self.model_state, valid_actions[lo:hi],
                                       valid_next_actions[lo:hi])
                valid_losses.append(valid_loss.item())

            mean_train_loss = np.mean(train_losses)
            mean_valid_loss = np.mean(valid_losses)
            print(
                f"Epoch #{epoch+1}: train_loss={mean_train_loss:.3f}, valid_loss={mean_valid_loss:.3f}"
            )
            if mean_valid_loss < min_valid_loss:
                self.optimal_state = self.model_state
                min_valid_loss = mean_valid_loss
                patience = 0
            else:
                patience += 1

            if patience > max_patience:
                break

    def save(self, fname: str, cnt: int):
        """Write `optimal_state` to checkpoint directory `fname`.

        NOTE(review): `cnt` is ignored. The checkpoint is always written at
        step 1 with prefix "prior_", overwriting any earlier one. This is kept
        as-is so it matches `load`.
        """
        checkpoints.save_checkpoint(fname,
                                    self.optimal_state,
                                    1,
                                    prefix="prior_",
                                    keep=20,
                                    overwrite=True)

    def load(self, ckpt_dir, step):
        """Restore the step-1 "prior_" checkpoint from `ckpt_dir`.

        NOTE(review): `step` is ignored; step 1 is always restored, matching
        `save`.
        """
        self.model_state = checkpoints.restore_checkpoint(
            ckpt_dir=ckpt_dir,
            target=self.model_state,
            step=1,
            prefix="prior_")
        # Keep the best-state pointer in sync. Without this, a later `save`
        # would write the stale, freshly-initialized state.
        self.optimal_state = self.model_state
