import math
import pathlib
import pickle
import re

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers as tfkl
from tensorflow.keras.mixed_precision import experimental as prec
from tensorflow_probability import distributions as tfd

from utils.utils_tools import parse_layers
from utils.utils_tools import static_scan

# Small constant added inside the log of the tanh-squashing correction in
# BasicStochasticActor.action_log_prob so the log never sees exactly 0.
epsilon = 1e-6
# Upper / lower bounds used to clip the stochastic policy's log standard
# deviation before exponentiating it.
LOG_MAX = 2
LOG_MIN = -20


class BasicModel(tf.Module):
    """``tf.Module`` base class adding pickle-based save/load of its variables.

    The file holds the NumPy values of ``self.variables`` in the same
    structure, so it can only be restored into a module whose
    ``self.variables`` has a matching structure.
    """

    def save(self, filename):
        """Write the current value of every variable in ``self.variables`` to ``filename``."""
        snapshot = tf.nest.map_structure(lambda var: var.numpy(), self.variables)
        with pathlib.Path(filename).open('wb') as handle:
            pickle.dump(snapshot, handle)

    def load(self, filename):
        """Assign values previously written by ``save`` back into ``self.variables``.

        NOTE(review): the file is unpickled, so only load files you trust.
        """
        with pathlib.Path(filename).open('rb') as handle:
            snapshot = pickle.load(handle)

        def _restore(var, value):
            return var.assign(value)

        tf.nest.map_structure(_restore, self.variables, snapshot)


class Adam(tf.Module):
    """Adam optimiser over the variables of several modules.

    Supports optional global-norm gradient clipping and optional weight decay.

    Args:
      name: label for this optimiser; prefixed (as ``'<name>/'``) to variable
        names when they are matched against ``wdpattern``.
      modules: modules whose ``.variables`` are optimised.
      lr: learning rate passed to ``tf.optimizers.Adam``.
      clip: if truthy, gradients are clipped to this global norm.
      wd: if truthy, every matching variable is multiplied by ``(1 - wd)``
        on each step, before the Adam update is applied.
      wdpattern: regular expression searched in ``'<name>/<variable name>'``
        to select the variables that weight decay applies to (default: all).
    """

    def __init__(self, name, modules, lr, clip=None, wd=None, wdpattern=r'.*'):
        self._name = name
        self._modules = modules
        self._clip = clip
        self._wd = wd
        self._wdpattern = wdpattern
        self._opt = tf.optimizers.Adam(lr)

    @property
    def variables(self):
        """Variables owned by the wrapped Keras Adam optimiser."""
        return self._opt.variables()

    def __call__(self, tape, loss):
        """Apply one optimisation step.

        Args:
          tape: the ``tf.GradientTape`` that recorded ``loss``.
          loss: scalar loss tensor.

        Returns:
          The global norm of the gradients, computed before clipping.
        """
        variables = [module.variables for module in self._modules]
        self._variables = tf.nest.flatten(variables)
        assert len(loss.shape) == 0, loss.shape
        grads = tape.gradient(loss, self._variables)
        norm = tf.linalg.global_norm(grads)
        if self._clip:
            grads, _ = tf.clip_by_global_norm(grads, self._clip, norm)
        if self._wd:
            # Previously `wd` was accepted but silently ignored.
            self._apply_weight_decay(self._variables)
        self._opt.apply_gradients(zip(grads, self._variables))
        return norm

    def _apply_weight_decay(self, variables):
        """Multiply each variable matching ``wdpattern`` by ``(1 - wd)`` in place."""
        pattern = re.compile(self._wdpattern)
        for var in variables:
            if pattern.search(self._name + '/' + var.name):
                var.assign((1 - self._wd) * var)


class ConvEncoder(BasicModel):
    """Image encoder: four 4x4, stride-2 ``Conv2D`` layers followed by a flatten.

    ``obs['image']`` is expected as ``[*batch_dims, H, W, C]`` with static
    ``H, W, C``. The result has shape ``[*batch_dims, features]``, where
    ``features`` is the flattened size of the last feature map
    (``32 * depth`` for 64x64 inputs: 64 -> 31 -> 14 -> 6 -> 2, times
    ``8 * depth`` channels).
    """

    def __init__(self, depth=32, act=tf.nn.relu):
        super().__init__()
        self._act = act
        self._depth = depth
        # Channel widths depth, 2*depth, 4*depth, 8*depth (creation order kept).
        self._layers = [
            tfkl.Conv2D(mult * self._depth, 4, strides=2, activation=self._act)
            for mult in (1, 2, 4, 8)
        ]

    def __call__(self, obs):
        image = obs['image']
        # Fold all leading batch dimensions into one for the 2-D convolutions.
        x = tf.reshape(image, (-1,) + tuple(image.shape[-3:]))
        for layer in self._layers:
            x = layer(x)
        # Take the flattened size from the static feature-map shape instead of
        # hard-coding 32 * depth, so resolutions other than 64x64 also work.
        features = x.shape[1:].num_elements()
        shape = tf.concat([tf.shape(image)[:-3], [features]], 0)
        return tf.reshape(x, shape)


class ConvDecoder(BasicModel):
    """Image decoder: Dense projection then four ``Conv2DTranspose`` layers.

    Maps ``features`` of shape ``[*batch_dims, F]`` to an ``Independent``
    unit-variance Normal over images of shape ``[*batch_dims, *shape]``.
    With the default kernels and strides the spatial size grows
    1 -> 5 -> 13 -> 30 -> 64, so ``shape`` is expected to be
    ``(64, 64, channels)``.
    """

    def __init__(self, depth=32, act=tf.nn.relu, shape=(64, 64, 3)):
        super().__init__()
        self._act = act
        self._depth = depth
        self._shape = shape

        # layers
        self._h1 = tfkl.Dense(32 * self._depth, None)
        self._h2 = tfkl.Conv2DTranspose(4 * self._depth, 5, strides=2, activation=self._act)
        self._h3 = tfkl.Conv2DTranspose(2 * self._depth, 5, strides=2, activation=self._act)
        self._h4 = tfkl.Conv2DTranspose(self._depth, 6, strides=2, activation=self._act)
        # Output layer is linear: it produces the Normal's mean.
        self._h5 = tfkl.Conv2DTranspose(self._shape[-1], 6, strides=2, activation=None)

    def __call__(self, features):
        x = self._h1(features)
        # Treat the projection as a 1x1 spatial map, flattening batch dims.
        x = tf.reshape(x, [-1, 1, 1, 32 * self._depth])
        x = self._h2(x)
        x = self._h3(x)
        x = self._h4(x)
        x = self._h5(x)

        # Restore the leading batch dimensions of `features`.
        mean = tf.reshape(x, tf.concat([tf.shape(features)[:-1], self._shape], axis=0))
        return tfd.Independent(tfd.Normal(mean, 1), len(self._shape))


class DenseDecoder(BasicModel):
    """MLP head returning a distribution over outputs of shape ``shape``.

    ``num_layers - 1`` hidden Dense layers with ``act`` are followed by a
    linear layer with ``prod(shape)`` units. Its output is reshaped to
    ``[*leading dims of features, *shape]``.

    dist:
      'normal' -> ``Independent(Normal(output, 1))``.
      'binary' -> ``Independent(Bernoulli(output))`` (output used as logits).
      anything else -> ``NotImplementedError`` when called.
    """

    def __init__(self, shape, num_layers, units, dist='normal', act=tf.nn.elu):
        super().__init__()
        self._shape = shape
        self._num_layers = num_layers
        self._units = units
        self._dist = dist
        self._act = act
        self._layers = []
        for _ in range(num_layers - 1):
            self._layers.append(tfkl.Dense(self._units, self._act))
        # int(): np.prod returns a NumPy scalar (the float 1.0 for shape == ()).
        self._layers.append(tfkl.Dense(int(np.prod(self._shape))))

    def __call__(self, features):
        x = features
        for layer in self._layers:
            x = layer(x)
        x = tf.reshape(x, tf.concat([tf.shape(features)[:-1], self._shape], 0))
        if self._dist == 'normal':
            return tfd.Independent(tfd.Normal(x, 1), len(self._shape))
        if self._dist == 'binary':
            return tfd.Independent(tfd.Bernoulli(x), len(self._shape))
        raise NotImplementedError(self._dist)


class BasicStochasticActor(BasicModel):
    """Tanh-squashed Gaussian policy.

    An MLP with ``num_layers - 1`` hidden layers outputs ``2 * shape`` units.
    These are split into the mean and log-std of a Gaussian over pre-squash
    actions. Samples are squashed with ``tanh`` and scaled by ``max_action``.
    """

    def __init__(self, shape, num_layers, units, max_action=1, act=tf.nn.relu):
        super().__init__()
        self._shape = shape     # output_dim
        self._num_layers = num_layers
        self._units = units
        self._act = act
        self._max_action = max_action
        self._layers = []
        for _ in range(self._num_layers - 1):
            self._layers.append(tfkl.Dense(self._units, activation=act))
        self._layers.append(tfkl.Dense(2 * self._shape, activation=None))

    def _mean_std(self, x):
        """Run the MLP and return ``(mean, std)`` of the pre-tanh Gaussian."""
        for layer in self._layers:
            x = layer(x)
        mean, log_std = tf.split(x, num_or_size_splits=2, axis=-1)
        # Bound log-std so exp() stays numerically well-behaved.
        log_std = tf.clip_by_value(log_std, clip_value_min=LOG_MIN, clip_value_max=LOG_MAX)
        return mean, tf.math.exp(log_std)

    @staticmethod
    def _sample_pre_tanh(mean, std):
        """Reparameterised sample ``mean + std * eps``, with ``eps ~ N(0, 1)``."""
        # tf.shape (dynamic) rather than std.shape (static), so this also works
        # when the batch dimension is unknown, e.g. while tracing a tf.function.
        noise = tf.random.normal(tf.shape(std), dtype=std.dtype)
        return mean + std * noise

    def __call__(self, x):
        """Return a sampled action in ``[-max_action, max_action]``."""
        mean, std = self._mean_std(x)
        return self._max_action * tf.tanh(self._sample_pre_tanh(mean, std))

    def action_log_prob(self, x):
        """Return ``(action, log_prob)`` for an action sampled from the policy.

        ``log_prob`` has the last axis summed out (kept as size 1).
        """
        mean, std = self._mean_std(x)
        x_t = self._sample_pre_tanh(mean, std)
        y_t = tf.tanh(x_t)
        action = self._max_action * y_t

        log_prob = tfd.Normal(mean, std).log_prob(x_t)
        # Change of variables for action = max_action * tanh(x_t):
        # log|d action / d x_t| = log(max_action * (1 - tanh(x_t)^2)).
        # epsilon keeps the log finite when tanh saturates.
        log_prob -= tf.math.log(self._max_action * (1 - tf.math.pow(y_t, 2)) + epsilon)
        # Sum over the action dimension (last axis; identical to axis=1 for
        # rank-2 inputs, and also correct for extra leading dims).
        log_prob = tf.reduce_sum(log_prob, axis=-1, keepdims=True)

        return action, log_prob


class BasicActor(BasicModel):
    """Deterministic MLP policy returning ``max_action * tanh(MLP(x))``.

    The MLP has ``num_layers - 1`` hidden Dense layers with ``act`` and a
    linear output layer of width ``shape``.
    """

    def __init__(self, shape, num_layers, units, max_action=1, act=tf.nn.relu):
        super().__init__()
        self._shape = shape     # output_dim
        self._num_layers = num_layers
        self._units = units
        self._act = act
        self._max_action = max_action
        self._layers = []
        for _ in range(self._num_layers - 1):
            self._layers.append(tfkl.Dense(self._units, activation=act))
        self._layers.append(tfkl.Dense(self._shape, activation=None))

    def __call__(self, x):
        for layer in self._layers:
            x = layer(x)
        # Squash into [-max_action, max_action].
        return self._max_action * tf.tanh(x)


class DenseNetwork(BasicModel):
    """Plain MLP with a linear output layer.

    It has ``num_layers - 1`` hidden Dense layers of width ``units`` with
    activation ``act``, followed by a linear Dense layer of width ``shape``.
    """

    def __init__(self, shape, num_layers, units, act=tf.nn.relu):
        super().__init__()
        self._shape = shape     # output dim
        self._num_layers = num_layers
        self._units = units
        self._act = act         # activation function
        self._layers = []
        for _ in range(self._num_layers - 1):
            self._layers.append(tfkl.Dense(self._units, activation=act))
        self._layers.append(tfkl.Dense(self._shape))

    def __call__(self, x):
        for layer in self._layers:
            x = layer(x)
        return x


class RSSME(BasicModel):
    """Recurrent state-space model with an ensemble of prior heads.

    A latent state is a dict with keys ``'mean'``, ``'std'``, ``'stoch'``
    (each ``[batch, stoch]``) and ``'deter'`` (the GRU state). One GRU cell
    carries the deterministic part. ``num_models`` separate two-layer heads
    produce the prior over the stochastic part, and one head is used per
    imagination step.
    """

    def __init__(self, stoch=30, deter=200, hidden=200, num_models=7, act=tf.nn.elu):
        super().__init__()
        self._activation = act
        self._stoch_size = stoch
        self._deter_size = deter
        self._hidden_size = hidden
        self._cell = tfkl.GRUCell(self._deter_size)
        self._k = num_models

        # Posterior head: concat(deter, embed) -> (mean, raw std).
        self._obs1 = tfkl.Dense(self._hidden_size, self._activation)
        self._obs2 = tfkl.Dense(2 * self._stoch_size, None)

        # Shared prior input layer, then one two-layer head per ensemble member
        # (heads are created interleaved so Keras layer naming is unchanged).
        self._img1 = tfkl.Dense(self._hidden_size, self._activation)
        self._img2 = []
        self._img3 = []
        for _ in range(self._k):
            self._img2.append(tfkl.Dense(self._hidden_size, self._activation))
            self._img3.append(tfkl.Dense(2 * self._stoch_size, None))

    def initial(self, batch_size):
        """Return an all-zero latent state for ``batch_size`` sequences."""
        dtype = prec.global_policy().compute_dtype
        latent_shape = [batch_size, self._stoch_size]
        state = {key: tf.zeros(latent_shape, dtype) for key in ('mean', 'std', 'stoch')}
        state['deter'] = self._cell.get_initial_state(None, batch_size, dtype)
        return state

    def observe(self, embed, action, state=None):
        """Filter a sequence and return ``(posterior, prior)`` state sequences.

        Inputs are batch-major (axis 0 is the batch); the first two axes are
        swapped for ``static_scan`` and swapped back on the outputs.
        """
        if state is None:
            state = self.initial(tf.shape(action)[0])

        def swap(tensor):
            return tf.transpose(tensor, [1, 0, 2])

        def step(carry, inputs):
            step_action, step_embed = inputs
            return self.obs_step(carry[0], step_action, step_embed)

        post, prior = static_scan(step, (swap(action), swap(embed)), (state, state))
        post = {key: swap(value) for key, value in post.items()}
        prior = {key: swap(value) for key, value in prior.items()}
        return post, prior

    def obs_step(self, prev_state, prev_action, embed):
        """One filtering step: prior from the dynamics, posterior from ``embed``."""
        prior = self.img_step(prev_state, prev_action)
        deter = prior['deter']
        hidden = self._obs2(self._obs1(tf.concat([deter, embed], -1)))
        mean, raw_std = tf.split(hidden, 2, -1)
        std = tf.nn.softplus(raw_std) + 0.1  # std strictly above 0.1
        stoch = self.get_dist({'mean': mean, 'std': std}).sample()
        post = {'mean': mean, 'std': std, 'stoch': stoch, 'deter': deter}
        return post, prior

    def img_step(self, prev_state, prev_action, k=None):
        """One prior step using ensemble head ``k`` (chosen at random if None).

        NOTE(review): ``np.random.choice`` runs in Python, so if this is traced
        inside ``tf.function`` the head index is fixed at trace time rather
        than re-drawn per call -- confirm that is intended.
        """
        if k is None:
            k = np.random.choice(self._k)
        hidden = self._img1(tf.concat([prev_state['stoch'], prev_action], -1))
        hidden, deter = self._cell(hidden, [prev_state['deter']])
        deter = deter[0]  # Keras wraps the GRU state in a list.
        hidden = self._img3[k](self._img2[k](hidden))
        mean, raw_std = tf.split(hidden, 2, -1)
        std = tf.nn.softplus(raw_std) + 0.1
        stoch = self.get_dist({'mean': mean, 'std': std}).sample()
        return {'mean': mean, 'std': std, 'stoch': stoch, 'deter': deter}

    def get_dist(self, state):
        """Diagonal Gaussian over the stochastic latent of ``state``."""
        loc, scale = state['mean'], state['std']
        return tfd.MultivariateNormalDiag(loc, scale)

    def imagine(self, action, state=None):
        """Roll the prior forward over a batch-major action sequence."""
        if state is None:
            state = self.initial(tf.shape(action)[0])
        assert isinstance(state, dict), state
        time_major = tf.transpose(action, [1, 0, 2])
        prior = static_scan(self.img_step, time_major, state)
        return {key: tf.transpose(value, [1, 0, 2]) for key, value in prior.items()}

    def get_feat(self, state):
        """Concatenate the stochastic and deterministic parts of ``state``."""
        parts = [state['stoch'], state['deter']]
        return tf.concat(parts, axis=-1)

    def get_feat_size(self):
        """Width of the vector returned by ``get_feat``."""
        total = self._stoch_size
        total += self._deter_size
        return total


class DeterministicNN_IQN(BasicModel):
    """Deterministic NN implementation of an Implicit Quantile Network
    with continuous actions.

    Returns the quantile value Z for the triplet (state, action, tau), where
    tau is the quantile (confidence) level.

    Parameters
    ----------
    dim_state: int
        Dimension of the state input to the network.
    dim_action: int
        Dimension of the action input to the network.
    layers_*: list of int, optional
        Widths of hidden layers, forwarded to ``parse_layers`` together with
        ``non_linearity``.
        *==state: layers mapping the state input
        *==action: layers mapping the action input
        *==f: layers mapping the merged (state-action x tau) embedding
    embedding_dim:
        Width that concat(state, action) features and tau are each mapped to
        before they are multiplied element-wise.
    tau_embed_dim: int, optional, default 1
        If > 1, tau is first expanded into ``tau_embed_dim`` cosine basis
        functions cos(pi * i * tau), i = 0, ..., tau_embed_dim - 1 (as in the
        paper), and that vector is then mapped to ``embedding_dim``.
    biased_head: bool, optional, default True
        Whether the final 1-unit output layer has a bias term.
    non_linearity: str, optional, default 'ReLU'
        Type of non-linearity between layers (forwarded to ``parse_layers``).
    tau: float in [0, 1], optional, default 1.0
        Soft-update coefficient used by ``update_param``: the fraction of the
        other network's weights mixed into this network's weights.

    References
    ----------
    Will Dabney, Georg Ostrovski, David Silver and Rémi Munos.
    Implicit Quantile Networks for Distributional Reinforcement Learning.
    2018.
    """

    def __init__(self, dim_state, dim_action,
                 layers_state: list = None,
                 layers_action: list = None,
                 layers_f: list = None,
                 embedding_dim=None,
                 tau_embed_dim=1,
                 biased_head=True,
                 non_linearity='ReLU',
                 tau=1.0):

        super().__init__()
        self.dim_state = dim_state
        self.dim_action = dim_action
        self.layers_state = layers_state or list()
        self.layers_action = layers_action or list()
        self.layers_f = layers_f or list()
        self.embedding_dim = embedding_dim
        self.tau_embed_dim = tau_embed_dim
        self.tau = tau

        # Map state:
        self.fc_state, state_out_dim = parse_layers(
            layers_state, self.dim_state, non_linearity, normalized=True)
        # Map action:
        self.fc_action, action_out_dim = parse_layers(
            layers_action, self.dim_action, non_linearity, normalized=True)
        # Map concat(state features, action features) to embedding_dim:
        self.fc_state_action, _ = parse_layers(
            self.embedding_dim,
            state_out_dim + action_out_dim,
            non_linearity,
            normalized=True)

        # Indices i = 0 .. tau_embed_dim - 1 of the cosine basis cos(pi*i*tau).
        if self.tau_embed_dim > 1:
            self.i_ = tf.constant(np.arange(tau_embed_dim), dtype=tf.float32)

        # Map tau (or its cosine features) to embedding_dim.
        self.head_tau, _ = parse_layers(self.embedding_dim,
                                        tau_embed_dim, non_linearity,
                                        normalized=True)

        self.hidden_layers_f, _ = parse_layers(
            layers_f, self.embedding_dim, non_linearity, normalized=True)
        # Layer mapping to 1-dim value function. No non-linearity added.
        self.head = tf.keras.layers.Dense(1, use_bias=biased_head)

    def __call__(self, state, tau_quantile, action=None):
        """Execute the forward computation of the network.

        Parameters
        ----------
        state: tf.Tensor
            Tensor of size [batch_size x dim_state].
        tau_quantile: tf.Tensor
            Tensor of size [batch_size x 1].
        action: tf.Tensor
            Tensor of size [batch_size x dim_action]. Required in practice:
            it is passed straight to the action layers.

        Returns
        -------
        output: tf.Tensor
            [batch_size x 1], the value for the triplet (state, action, tau).
        """
        state_output = self.fc_state(state)  # [batch_size x state_layer]
        action_output = self.fc_action(action)  # [batch_size x action_layer]
        state_action_output = self.fc_state_action(
            tf.concat((state_output, action_output), axis=-1))
        # [batch_size x embedding_dim]

        # Cosine basis functions of the form cos(pi * i * tau).
        if self.tau_embed_dim > 1:
            a = tf.math.cos(tf.constant([math.pi]) * self.i_ * tau_quantile)
        else:
            a = tau_quantile
        tau_output = self.head_tau(a)  # [batch_size x embedding_dim]
        output = self.hidden_layers_f(tf.math.multiply(state_action_output, tau_output))
        output = tf.reshape(self.head(output), shape=(-1, 1))

        return output

    def get_sampled_Z(self, state, confidences, action):
        """Run the IQN for K different confidence levels.

        Parameters
        ----------
        state: tf.Tensor [batch_size x dim_state] (or [dim_state])
        confidences: tf.Tensor with K elements ([K], [K x 1] or [1 x K])
        action: tf.Tensor [batch_size x dim_action] (or [dim_action])

        Returns
        -------
        Z_tau_K: tf.Tensor [batch_size x K], with
            Z_tau_K[b, k] = Z(state[b], action[b], confidences[k]).
        """
        batch_size = state.shape[0] if len(state.shape) > 1 else 1
        # Accept [K], [K, 1] or [1, K]; previously K was read from shape[0],
        # which made a [1 x K] input look like a single level.
        confidences = tf.reshape(confidences, (1, -1))
        K = confidences.shape[1]  # number of confidence levels to evaluate
        # Build every flattened input batch-major: row b * K + k holds
        # (state[b], action[b], confidences[k]). The old code repeated
        # states/actions K-major but confidences batch-major, so rows paired
        # the wrong state with each confidence whenever batch_size, K > 1.
        x = tf.repeat(tf.reshape(state, (batch_size, self.dim_state)), K, axis=0)
        a = tf.repeat(tf.reshape(action, (batch_size, self.dim_action)), K, axis=0)
        y = tf.reshape(tf.tile(confidences, (batch_size, 1)), shape=(K * batch_size, 1))
        Z_tau_K = tf.reshape(self(state=x, tau_quantile=y, action=a),
                             shape=(batch_size, K))
        return Z_tau_K

    @property
    def params(self):
        """Flat list of the network's trainable variables, in a fixed order."""
        variables = [self.fc_state.trainable_variables, self.fc_action.trainable_variables,
                     self.fc_state_action.trainable_variables, self.head_tau.trainable_variables,
                     self.hidden_layers_f.trainable_variables, self.head.trainable_variables]
        return tf.nest.flatten(variables)

    def update_param(self, target_net):
        """Soft-update this network's weights towards ``target_net``'s.

        Each weight becomes ``tau * w_target_net + (1 - tau) * w_self``.
        Weights are paired by position in ``params``, so both networks must
        share the same architecture.
        """
        for source_weight, target_weight in zip(self.params, target_net.params):
            new_param_ = self.tau * tf.stop_gradient(target_weight) + \
                         (1 - self.tau) * tf.stop_gradient(source_weight)
            source_weight.assign(new_param_)


class QuantileMlp(BasicModel):
    """Quantile critic Z(state, action, tau) with a cosine embedding of tau.

    Parameters
    ----------
    hidden_sizes: sequence of int
        ``hidden_sizes[:-1]`` are the widths of the state-action trunk, each
        layer being Dense (+ LayerNormalization) + ReLU. ``hidden_sizes[-1]``
        is the width of the tau embedding and of the merge layer. The trunk
        output is multiplied element-wise with the tau embedding, so their
        last dimensions must be broadcast-compatible.
    tau: float
        Soft-update coefficient used by ``update_param`` (not a quantile).
    embedding_size: int
        Number of cosine features cos(pi * i * tau), i = 1 .. embedding_size.
    num_quantiles: int
        Number of fixed quantile fractions used by ``penultimate_layer``.
    layer_norm: bool
        Whether to add LayerNormalization after each Dense layer except the last.
    **kwargs:
        Accepted and ignored.
    """

    def __init__(
            self,
            hidden_sizes,
            tau=0.005,
            embedding_size=64,
            num_quantiles=32,
            layer_norm=True,
            **kwargs,
    ):
        super().__init__()
        self.tau = tau
        self.layer_norm = layer_norm

        self.base_fc = tf.keras.Sequential()
        for next_size in hidden_sizes[:-1]:
            self.base_fc.add(tfkl.Dense(next_size))
            if layer_norm:
                self.base_fc.add(tfkl.LayerNormalization())
            self.base_fc.add(tfkl.ReLU())

        self.num_quantiles = num_quantiles
        self.embedding_size = embedding_size

        self.tau_fc = tf.keras.Sequential()
        self.tau_fc.add(tfkl.Dense(hidden_sizes[-1]))
        if self.layer_norm:
            self.tau_fc.add(tfkl.LayerNormalization())
        self.tau_fc.add(tfkl.Activation('sigmoid'))

        self.merge_fc = tf.keras.Sequential()
        self.merge_fc.add(tfkl.Dense(hidden_sizes[-1]))
        if self.layer_norm:
            self.merge_fc.add(tfkl.LayerNormalization())
        self.merge_fc.add(tfkl.ReLU())

        self.last_fc = tfkl.Dense(1)
        # Cosine frequencies i = 1 .. embedding_size.
        self.const_vec = tf.convert_to_tensor(np.arange(1, 1 + self.embedding_size), dtype=float)

    def _merged_features(self, state, action, tau):
        """Shared trunk: return merge-layer features of shape (N, T, C).

        tau: quantile fractions, (N, T).
        """
        h = self.base_fc(tf.concat([state, action], axis=-1))  # (N, C)
        x = tf.math.cos(tf.expand_dims(tau, axis=-1) * self.const_vec * np.pi)  # (N, T, E)
        x = self.tau_fc(x)  # (N, T, C)
        h = tf.math.multiply(x, tf.expand_dims(h, axis=-2))  # broadcast over T
        return self.merge_fc(h)  # (N, T, C)

    def __call__(self, state, action, tau):
        """Calculate quantile values in batch.

        tau: quantile fractions, (N, T). Returns a tensor of shape (N, T).
        """
        h = self._merged_features(state, action, tau)
        # Squeeze only the unit output axis; a bare squeeze also dropped the
        # N or T axis whenever it had size 1.
        return tf.squeeze(self.last_fc(h), axis=-1)  # (N, T)

    def penultimate_layer(self, state, action):
        """Return merge-layer features (N, T, C), with T = ``num_quantiles``.

        The features are evaluated at the fixed fractions (2i + 1) / (2T),
        i = 0 .. T - 1, i.e. the midpoints of a uniform grid on [0, 1].
        (The previous body was untranslated PyTorch code and always raised.)
        """
        num_q = self.num_quantiles
        batch_size = tf.shape(action)[0]
        presum_tau = tf.fill([batch_size, num_q], 1. / num_q)
        tau = tf.math.cumsum(presum_tau, axis=-1)
        # Midpoints of consecutive cumulative fractions (the first one is tau_0 / 2).
        tau_hat = tf.concat(
            [tau[:, :1] / 2., (tau[:, 1:] + tau[:, :-1]) / 2.], axis=-1)
        tau_hat = tf.stop_gradient(tau_hat)
        return self._merged_features(state, action, tau_hat)

    @property
    def params(self):  # do not call it 'parameters' (already a default func)
        """Flat list of all trainable variables of this module."""
        variables = self.trainable_variables
        return tf.nest.flatten(variables)

    def update_param(self, target_net):
        """Soft-update this network's weights towards ``target_net``'s.

        Each weight becomes ``tau * w_target_net + (1 - tau) * w_self``.
        Weights are paired by position in ``params``.
        """
        for source_weight, target_weight in zip(self.params, target_net.params):
            new_param_ = self.tau * tf.stop_gradient(target_weight) + \
                         (1 - self.tau) * tf.stop_gradient(source_weight)
            source_weight.assign(new_param_)


