import numpy as np
import tensorflow as tf
from tensorflow.keras import layers as tfkl
from tensorflow_probability import distributions as tfd
from tensorflow.keras.mixed_precision import experimental as prec

import tools
from tfdiffeq import odeint

class OdeFunc(tfkl.Layer):
    """Vector field dy/dt = f(y) integrated by ``tfdiffeq.odeint``.

    f is a 3-layer MLP (tanh, tanh, linear) whose output width is
    ``output_dim``; for use with ``odeint`` that width must equal the width of
    the state ``y`` being integrated.
    """
    def __init__(self, hidden_dim = 300, output_dim = 200):
        super(OdeFunc, self).__init__()
        self._hidden_dim = hidden_dim
        self._output_dim = output_dim
    
        # Layers actually used by call().
        self._layer1 = tfkl.Dense(self._hidden_dim, 'tanh')
        self._layer2 = tfkl.Dense(self._hidden_dim, 'tanh')
        self._layer3 = tfkl.Dense(self._output_dim)

        # NOTE(review): _layer_compat and _model are never used by call(). They
        # are left in place because Keras assigns layer names from construction
        # order, and removing/reordering them could rename saved variables.
        self._layer_compat = tfkl.Dense(self._output_dim, 'tanh')


        self._model = tf.keras.Sequential([
            tfkl.Dense(self._hidden_dim, 'tanh'),
            tfkl.Dense(self._hidden_dim, 'tanh'),
            tfkl.Dense(self._output_dim)
        ])

    #@tf.function
    def call(self, t, y):
        """Return f(y); ``t`` is accepted for the odeint signature but unused
        (the dynamics are time-independent)."""
        y = self._layer1(y)
        y = self._layer2(y)
        y = self._layer3(y)
        return y

class RSSM(tools.Module):
    """Recurrent state-space model with an ODE-based deterministic transition.

    A state is a dict with keys 'mean', 'std', 'stoch' (the sampled stochastic
    part) and 'deter' (the deterministic part). The deterministic update in
    ``img_step`` integrates ``OdeFunc`` from t=0 to t=1.
    """

    def __init__(self, stoch=30, deter=200, hidden=200, act=tf.nn.elu, actdim=None):
        super().__init__()
        self._activation = act
        self._stoch_size = stoch
        self._deter_size = deter
        self._hidden_size = hidden
        self._actdim = actdim
        # NOTE(review): not used by any method below; kept so construction order
        # (and therefore Keras auto-generated layer names) stays unchanged.
        self._cell = tfkl.GRUCell(self._deter_size)
        # ODE vector field: output width deter_size, so the integrated state
        # must also have width deter_size (see img_step).
        self._odehelper = OdeFunc(300, self._deter_size)

    def initial(self, batch_size, calprob=False):
        """Return an all-zeros state for ``batch_size`` sequences.

        Uses the global mixed-precision compute dtype. With ``calprob=True``, a
        zero 'log_prob' entry of shape [batch_size] is added.
        """
        dtype = prec.global_policy().compute_dtype
        state = dict(
            mean=tf.zeros([batch_size, self._stoch_size], dtype),
            std=tf.zeros([batch_size, self._stoch_size], dtype),
            stoch=tf.zeros([batch_size, self._stoch_size], dtype),
            deter=tf.zeros([batch_size, self._deter_size], dtype))
        if calprob:
            # NOTE(review): img_step(calprob=True) emits the key 'prob', not
            # 'log_prob'; a scan seeded with this state would see mismatched
            # structures — confirm which key callers expect.
            state['log_prob'] = tf.zeros([batch_size], dtype)
        return state

    #@tf.function
    def observe(self, embed, action, state=None):
        """Filter a sequence of embeddings and actions.

        ``embed`` and ``action`` are rank-3 with axis 0 as batch (the batch size
        is read from ``action``'s axis 0). Axes 0 and 1 are swapped so that
        ``tools.static_scan`` (assumed to iterate over axis 0) walks axis 1, and
        are swapped back on the results.

        Returns:
            (post, prior): dicts of stacked posterior and prior states.
        """
        if state is None:
            state = self.initial(tf.shape(action)[0])
        embed = tf.transpose(embed, [1, 0, 2])
        action = tf.transpose(action, [1, 0, 2])
        post, prior = tools.static_scan(
            lambda prev, inputs: self.obs_step(prev[0], *inputs),
            (action, embed), (state, state))
        post = {k: tf.transpose(v, [1, 0, 2]) for k, v in post.items()}
        prior = {k: tf.transpose(v, [1, 0, 2]) for k, v in prior.items()}
        return post, prior

    #@tf.function
    def imagine(self, action, state=None):
        """Roll the prior forward over a rank-3 action sequence (no observations).

        Axes 0 and 1 are swapped around ``tools.static_scan`` exactly as in
        ``observe``. Returns a dict of stacked prior states.
        """
        if state is None:
            state = self.initial(tf.shape(action)[0])
        assert isinstance(state, dict), state
        action = tf.transpose(action, [1, 0, 2])
        prior = tools.static_scan(self.img_step, action, state)
        prior = {k: tf.transpose(v, [1, 0, 2]) for k, v in prior.items()}
        return prior

    def get_feat(self, state):
        """Concatenate the stochastic and deterministic parts on the last axis."""
        return tf.concat([state['stoch'], state['deter']], -1)

    def get_dist(self, state):
        """Diagonal Gaussian with loc ``state['mean']`` and scale ``state['std']``."""
        return tfd.MultivariateNormalDiag(state['mean'], state['std'])

    #@tf.function
    def obs_step(self, prev_state, prev_action, embed):
        """One filtering step: compute the prior, then a posterior conditioned on
        ``embed``. The posterior shares the prior's 'deter'.

        Returns:
            (post, prior) state dicts.
        """
        prior = self.img_step(prev_state, prev_action)
        x = tf.concat([prior['deter'], embed], -1)
        x = self.get('obs1', tfkl.Dense, self._hidden_size, self._activation)(x)
        x = self.get('obs2', tfkl.Dense, 2 * self._stoch_size, None)(x)
        mean, std = tf.split(x, 2, -1)
        # Softplus plus a floor keeps the std strictly positive.
        std = tf.nn.softplus(std) + 0.1
        stoch = self.get_dist({'mean': mean, 'std': std}).sample()
        post = {'mean': mean, 'std': std, 'stoch': stoch, 'deter': prior['deter']}
        return post, prior

    #@tf.function
    def img_step(self, prev_state, prev_action, calprob=False):
        """One prior (imagination) step.

        Args:
            prev_state: state dict; its 'stoch' and 'deter' entries are read.
            prev_action: action tensor, concatenated with 'stoch' on the last axis.
            calprob: if True, also return the log-probability of the sampled
                'stoch' under the prior under the key 'prob'.

        Returns:
            dict with 'mean', 'std', 'stoch', 'deter' (and 'prob' if calprob).
        """
        x = tf.concat([prev_state['stoch'], prev_action], -1)
        x = self.get('img1', tfkl.Dense, self._hidden_size, self._activation)(x)
        x = tf.concat([x, prev_state['deter']], -1)
        # This tensor is the ODE's initial state. OdeFunc returns derivatives of
        # width deter_size, and the result becomes 'deter', so project to
        # deter_size. (Using hidden_size only worked when hidden == deter.)
        x = self.get('img1_2', tfkl.Dense, self._deter_size, self._activation)(x)
        # NOTE(review): time is fixed to float32; under a float16 mixed-precision
        # policy x may be float16 — confirm odeint accepts the combination.
        integration_time = tf.constant([0., 1.], dtype=tf.float32)
        # odeint stacks the solution at each requested time on axis 0; [1] is t=1.
        deter = odeint(self._odehelper, x, integration_time, rtol=1e-5, atol=1e-5)[1]
        x = self.get('img2', tfkl.Dense, self._hidden_size, self._activation)(deter)
        x = self.get('img3', tfkl.Dense, 2 * self._stoch_size, None)(x)
        mean, std = tf.split(x, 2, -1)
        std = tf.nn.softplus(std) + 0.1
        dist = self.get_dist({'mean': mean, 'std': std})
        stoch = dist.sample()
        prior = {'mean': mean, 'std': std, 'stoch': stoch, 'deter': deter}
        if calprob:
            # Only evaluated when requested (previously always computed).
            prior['prob'] = dist.log_prob(stoch)
        return prior

class ConvEncoder(tools.Module):
    """Image encoder: four 4x4, stride-2 Conv2D layers, then flatten.

    ``obs['image']`` may have any number of leading axes; the result keeps
    those leading axes and replaces the trailing three (H, W, C) with a single
    flattened feature axis.
    """

    def __init__(self, depth=32, act=tf.nn.relu):
        self._act = act
        self._depth = depth

    def __call__(self, obs):
        kwargs = dict(strides=2, activation=self._act)
        # Fold all leading axes into one so Conv2D receives a 4-D tensor; this
        # requires the last three image dims to be statically known.
        x = tf.reshape(obs['image'], (-1,) + tuple(obs['image'].shape[-3:]))
        x = self.get('h1', tfkl.Conv2D, 1 * self._depth, 4, **kwargs)(x)
        x = self.get('h2', tfkl.Conv2D, 2 * self._depth, 4, **kwargs)(x)
        x = self.get('h3', tfkl.Conv2D, 4 * self._depth, 4, **kwargs)(x)
        x = self.get('h4', tfkl.Conv2D, 8 * self._depth, 4, **kwargs)(x)
        # Flattened size taken from the static conv output shape. For 64x64
        # inputs the output is 2x2x(8*depth) = 32*depth (the former hard-coded
        # value); deriving it keeps the encoder correct for other image sizes.
        feat_size = x.shape[1:].num_elements()
        shape = tf.concat([tf.shape(obs['image'])[:-3], [feat_size]], 0)
        return tf.reshape(x, shape)


class ConvDecoder(tools.Module):
    """Transposed-convolution image decoder.

    Maps features (any leading axes) to an ``Independent`` unit-variance
    Normal over images of shape ``shape``.
    """

    def __init__(self, depth=32, act=tf.nn.relu, shape=(64, 64, 3)):
        self._act = act
        self._depth = depth
        self._shape = shape

    def __call__(self, features):
        width = 32 * self._depth
        deconv_kwargs = {'strides': 2, 'activation': self._act}
        # Dense projection reshaped to a 1x1 spatial map with `width` channels.
        hidden = self.get('h1', tfkl.Dense, width, None)(features)
        hidden = tf.reshape(hidden, [-1, 1, 1, width])
        # Three activated upsampling stages: (layer name, channel multiplier, kernel).
        for layer_name, mult, kernel in (('h2', 4, 5), ('h3', 2, 5), ('h4', 1, 6)):
            deconv = self.get(layer_name, tfkl.Conv2DTranspose,
                              mult * self._depth, kernel, **deconv_kwargs)
            hidden = deconv(hidden)
        # Final linear stage producing the image channels.
        hidden = self.get('h5', tfkl.Conv2DTranspose, self._shape[-1], 6, strides=2)(hidden)
        out_shape = tf.concat([tf.shape(features)[:-1], self._shape], 0)
        mean = tf.reshape(hidden, out_shape)
        return tfd.Independent(tfd.Normal(mean, 1), len(self._shape))


class DenseDecoder(tools.Module):
    """MLP head producing an ``Independent`` distribution of event shape ``shape``.

    ``dist='normal'`` gives a unit-variance Normal whose mean is the MLP output;
    ``dist='binary'`` gives a Bernoulli parameterized by the MLP output as
    logits. Any other value raises ``NotImplementedError``.
    """

    def __init__(self, shape, layers, units, dist='normal', act=tf.nn.elu):
        self._shape = shape
        self._layers = layers
        self._units = units
        self._dist = dist
        self._act = act

    def __call__(self, features):
        hidden = features
        for i in range(self._layers):
            dense = self.get(f'h{i}', tfkl.Dense, self._units, self._act)
            hidden = dense(hidden)
        hidden = self.get('hout', tfkl.Dense, np.prod(self._shape))(hidden)
        # Restore the leading axes of `features` and expand to the event shape.
        target_shape = tf.concat([tf.shape(features)[:-1], self._shape], 0)
        params = tf.reshape(hidden, target_shape)
        if self._dist == 'normal':
            base = tfd.Normal(params, 1)
        elif self._dist == 'binary':
            base = tfd.Bernoulli(params)
        else:
            raise NotImplementedError(self._dist)
        return tfd.Independent(base, len(self._shape))


class DenseDecoderDouble(tools.Module):
    """Like ``DenseDecoder``, but ``entropy`` is added to the decoded parameters.

    The addition happens after reshaping to the event shape and before the
    distribution is built (presumably ``entropy`` broadcasts against that
    shape — TODO confirm at call sites).
    """

    def __init__(self, shape, layers, units, dist='normal', act=tf.nn.elu):
        self._shape = shape
        self._layers = layers
        self._units = units
        self._dist = dist
        self._act = act

    def __call__(self, features, entropy):
        hidden = features
        for i in range(self._layers):
            dense = self.get(f'h{i}', tfkl.Dense, self._units, self._act)
            hidden = dense(hidden)
        hidden = self.get('hout', tfkl.Dense, np.prod(self._shape))(hidden)
        target_shape = tf.concat([tf.shape(features)[:-1], self._shape], 0)
        params = tf.reshape(hidden, target_shape) + entropy
        if self._dist == 'normal':
            base = tfd.Normal(params, 1)
        elif self._dist == 'binary':
            base = tfd.Bernoulli(params)
        else:
            raise NotImplementedError(self._dist)
        return tfd.Independent(base, len(self._shape))


class ActionDecoder(tools.Module):
    """MLP policy head.

    ``dist='tanh_normal'``: a Normal squashed by ``tools.TanhBijector``, wrapped
    in ``tools.SampleDist``. ``dist='onehot'``: ``tools.OneHotDist`` over the
    output logits. Any other value raises ``NotImplementedError``.
    """

    def __init__(
            self, size, layers, units, dist='tanh_normal', act=tf.nn.elu,
            min_std=1e-4, init_std=5, mean_scale=5):
        self._size = size
        self._layers = layers
        self._units = units
        self._dist = dist
        self._act = act
        self._min_std = min_std
        self._init_std = init_std
        self._mean_scale = mean_scale

    def __call__(self, features):
        # Inverse softplus of init_std: softplus(0 + raw_init_std) == init_std,
        # so a zero pre-activation yields the configured initial std.
        raw_init_std = np.log(np.exp(self._init_std) - 1)
        hidden = features
        for i in range(self._layers):
            hidden = self.get(f'h{i}', tfkl.Dense, self._units, self._act)(hidden)

        if self._dist == 'onehot':
            logits = self.get('hout', tfkl.Dense, self._size)(hidden)
            return tools.OneHotDist(logits)
        if self._dist != 'tanh_normal':
            raise NotImplementedError(self._dist)

        # https://www.desmos.com/calculator/rcmcf5jwe7
        params = self.get('hout', tfkl.Dense, 2 * self._size)(hidden)
        mean, std = tf.split(params, 2, -1)
        # Soft-clip the mean to (-mean_scale, mean_scale).
        mean = self._mean_scale * tf.tanh(mean / self._mean_scale)
        std = tf.nn.softplus(std + raw_init_std) + self._min_std
        squashed = tfd.TransformedDistribution(tfd.Normal(mean, std), tools.TanhBijector())
        return tools.SampleDist(tfd.Independent(squashed, 1))


class DictEncoder(tools.Module):
    """Two stacked Dense layers, each of width ``key_size`` with activation ``act``."""

    def __init__(self, hidden_size=300, key_size=128, act=tf.nn.relu):
        self._act = act
        self._key_size = key_size
        # NOTE(review): stored but unused; both layers below use key_size.
        self._hidden_size = hidden_size

    def __call__(self, features):
        out = features
        for layer_name in ('h1', 'h2'):
            out = self.get(layer_name, tfkl.Dense, self._key_size, self._act)(out)
        return out


class Cache(tools.Module):
    """Per-replica store of (embedding, value) tensors held in float32 variables.

    For each of ``replica_num`` replicas there is one embedding variable of
    shape (total_batch_size, number_nn, key_size) and one value variable of
    shape (total_batch_size, number_nn). They are created on the first call;
    every call overwrites the pair belonging to ``replica_ind``.
    """

    def __init__(self, total_batch_size, number_nn, key_size, replica_num):
        self.total_batch_size = total_batch_size
        self.number_nn = number_nn
        self.key_size = key_size
        self.replica_num = replica_num
        # Filled once, lazily, by _create_variables(): index i -> replica i.
        self.embed = []
        self.value = []

    def _create_variables(self):
        """Create (via the cached ``self.get``) one variable pair per replica."""
        for i in range(self.replica_num):
            self.embed.append(self.get(
                'dndembed' + str(i), tf.Variable,
                tf.zeros((self.total_batch_size, self.number_nn, self.key_size),
                         dtype=tf.dtypes.float32),
                name='dndembed_' + str(i),
                aggregation=tf.compat.v1.VariableAggregation.ONLY_FIRST_REPLICA,
                trainable=True,
                dtype=tf.dtypes.float32))
            self.value.append(self.get(
                'dndvalue' + str(i), tf.Variable,
                tf.zeros((self.total_batch_size, self.number_nn),
                         dtype=tf.dtypes.float32),
                name='dndvalue_' + str(i),
                aggregation=tf.compat.v1.VariableAggregation.ONLY_FIRST_REPLICA,
                trainable=True,
                dtype=tf.dtypes.float32))

    def __call__(self, embed, value, replica_ind):
        """Write ``embed``/``value`` into replica ``replica_ind``'s variables.

        When ``number_nn == 1`` and ``embed`` is rank 2, a singleton axis is
        inserted at position 1 of both inputs to match the variable shapes.

        Returns:
            The (embedding, value) variables for ``replica_ind``.
        """
        if self.number_nn == 1 and len(embed.shape) == 2:
            embed = np.expand_dims(embed, 1)
            value = np.expand_dims(value, 1)
        # Collect the variables only once. Previously every call re-appended
        # replica_num entries, so the lists grew without bound.
        if not self.embed:
            self._create_variables()
        self.embed[replica_ind].assign(embed)
        self.value[replica_ind].assign(value)

        return self.embed[replica_ind], self.value[replica_ind]
