import tensorflow as tf
import numpy as np
from tensorflow_probability import distributions as tfd
from tensorflow.keras.mixed_precision import experimental as prec

from collections import deque 

import tools

class reverseAction(tools.Module):
    """Predict an action from a pair of latent features (a "reverse" actor).

    The two feature tensors are concatenated along the last axis and passed
    through two ``Dense(joint_dim, act)`` + ``Dropout(0.5)`` layers, followed by
    an output head selected by ``dist``.

    Args:
        hidden_depth, hidden_dim: stored on the instance but not used by
            ``__call__``.
        joint_dim: width of the two shared hidden Dense layers.
        action_size: dimensionality of the predicted action.
        dist: ``None`` for a plain Dense output of size ``action_size``;
            ``'tanh'`` for a tanh-squashed Normal distribution wrapped in
            ``tools.SampleDist``.
        act: activation of the hidden Dense layers.
        min_std, init_std, mean_scale: parameters of the ``'tanh'`` head.
    """

    def __init__(self, hidden_depth, hidden_dim, joint_dim, action_size, dist=None, act=tf.nn.relu, min_std=1e-4, init_std=5, mean_scale=5):
        self._hidden_depth = hidden_depth
        self._hidden_dim = hidden_dim
        self._joint_dim = joint_dim
        self._action_size = action_size
        self._dist = dist
        self._act = act
        self._min_std = min_std
        self._init_std = init_std
        self._mean_scale = mean_scale

    def __call__(self, feature_1, feature_2, training=None):
        """Predict the action (or action distribution) linking two features.

        Args:
            feature_1, feature_2: tensors concatenated along the last axis.
            training: forwarded to both Dropout layers. ``None`` (the default,
                matching the previous implicit behaviour) lets Keras resolve the
                mode, which in a custom training loop typically leaves dropout
                inactive; pass ``True`` while training to actually apply it.

        Returns:
            A tensor of shape ``[..., action_size]`` when ``dist`` is ``None``,
            or a ``tools.SampleDist`` wrapping an ``Independent`` tanh-transformed
            Normal when ``dist == 'tanh'``.

        Raises:
            NotImplementedError: for any other ``dist`` value.
        """
        x = self.get('concat', tf.keras.layers.Concatenate)([feature_1, feature_2])
        x = self.get('h1', tf.keras.layers.Dense, self._joint_dim, self._act)(x)
        x = self.get('dropout1', tf.keras.layers.Dropout, rate=0.5)(x, training=training)
        x = self.get('h2', tf.keras.layers.Dense, self._joint_dim, self._act)(x)
        x = self.get('dropout2', tf.keras.layers.Dropout, rate=0.5)(x, training=training)
        if self._dist is None:
            return self.get('hout', tf.keras.layers.Dense, self._action_size)(x)
        if self._dist == 'tanh':
            # Inverse softplus of init_std: softplus(0 + raw_init_std) == init_std,
            # so a zero pre-activation yields roughly the initial std.
            raw_init_std = np.log(np.exp(self._init_std) - 1)
            x = self.get('hout', tf.keras.layers.Dense, 2 * self._action_size)(x)
            mean, std = tf.split(x, 2, -1)
            # Soft-clip the mean to (-mean_scale, mean_scale).
            mean = self._mean_scale * tf.tanh(mean / self._mean_scale)
            std = tf.nn.softplus(std + raw_init_std) + self._min_std
            dist = tfd.Normal(mean, std)
            dist = tfd.TransformedDistribution(dist, tools.TanhBijector())
            dist = tfd.Independent(dist, 1)
            return tools.SampleDist(dist)
        raise NotImplementedError(self._dist)

class latentRSSM(tools.Module):
    """Recurrent state-space model with a GRU deterministic path and Gaussian latents.

    A state is a dict with keys ``'mean'``, ``'std'``, ``'stoch'`` (width
    ``stoch_size``) and ``'internal'`` (the GRU state, ``deter_size`` units).
    Priors are computed from the previous state and an action; posteriors
    additionally condition on an observation embedding.

    Args:
        stoch_size: width of the stochastic latent.
        deter_size: number of GRU units.
        hidden_size: width of the intermediate Dense layers.
        activation: activation of the intermediate Dense layers.
        num_sample_points: number of samples averaged by
            ``imagine_step_multi_sample`` when called from ``obs_step`` and the
            k-step / retrace paths.
    """

    def __init__(self, stoch_size, deter_size=200, hidden_size=200, activation=tf.nn.relu, num_sample_points=1):
        super().__init__()
        self._activation = activation
        self._stoch_size = stoch_size
        self._deter_size = deter_size
        self._hidden_size = hidden_size
        self._gru = tf.keras.layers.GRUCell(self._deter_size)
        self._num_sample_points = num_sample_points

    def initial(self, batch_size):
        """Return a starting state: zero mean/std/stoch plus the GRU cell's initial state."""
        dtype = prec.global_policy().compute_dtype
        return dict(
            mean=tf.zeros([batch_size, self._stoch_size], dtype),
            std=tf.zeros([batch_size, self._stoch_size], dtype),
            stoch=tf.zeros([batch_size, self._stoch_size], dtype),
            internal=self._gru.get_initial_state(None, batch_size, dtype))

    @tf.function
    def observe(self, embed, action, k=1, state=None, num_samples=1):
        """Roll the model over a sequence, producing posteriors and priors.

        ``embed`` and ``action`` are transposed with ``[1, 0, 2]`` on entry (the
        inline comments assume batch-major ``(N, L, .)`` input).

        Args:
            embed: observation embeddings.
            action: actions.
            k: imagination horizon. ``k == 1`` is the original Dreamer scan;
                ``k > 1`` is the experimental SPR-like multi-step path.
            state: optional starting state; defaults to ``self.initial(N)``.
            num_samples: unused (``self._num_sample_points`` is used instead).

        Returns:
            ``(post, prior, last_prior)``. For ``k == 1``, ``post``/``prior`` are
            dicts of batch-major tensors and ``last_prior`` is ``None``. For
            ``k > 1``, ``prior`` is a list (one entry per valid state) of lists
            of per-step prior dicts.
        """
        if state is None:
            state = self.initial(tf.shape(action)[0])
        embed = tf.transpose(embed, [1, 0, 2])
        action = tf.transpose(action, [1, 0, 2])  # (L, N, A)
        assert isinstance(k, int)
        if k == 1:
            # Original Dreamer path: one (posterior, prior) pair per time step.
            post, prior = tools.static_scan(
                lambda prev, inputs: self.obs_step(prev[0], *inputs),
                (action, embed),
                (state, state)
            )
            post = {key: tf.transpose(v, [1, 0, 2]) for key, v in post.items()}
            prior = {key: tf.transpose(v, [1, 0, 2]) for key, v in prior.items()}
            last_prior = None
        else:
            # k > 1 (similar to SPR): time-major actions (L, N, A) are regrouped
            # into (L-K, N, K, A) so that entry l holds the K actions used for a
            # K-step imagination starting at state l.
            embed = embed[k:]
            action = tf.numpy_function(tools.k_steps_action, [action, k], prec.global_policy().compute_dtype)  # (L-K, N, K, A)
            action = tf.stack(action, axis=0)
            # NOTE(review): len() of a tensor whose leading dim comes from
            # numpy_function is unknown at trace time inside tf.function — confirm.
            num_valid_states = len(action)
            prior_list, post_list, last_prior_list = [], [], []
            for i in range(num_valid_states):
                post, priors, last_prior = self.obs_at_k_step_multi_sample(state, action[i], embed, self._num_sample_points)

                if i >= k:
                    # Use already-computed posteriors as the start of later rollouts.
                    # NOTE(review): this runs *after* the call above and picks a
                    # transposed posterior, so it only affects the next iteration —
                    # confirm the intended ordering and layout.
                    state = post_list[-k]
                post = {key: tf.transpose(v, [1, 0, 2]) for key, v in post.items()}
                # The k prior distributions of the i-th state's k-step imagination.
                priors = [{key: tf.transpose(v, [1, 0, 2]) for key, v in p.items()} for p in priors]
                last_prior = {key: tf.transpose(v, [1, 0, 2]) for key, v in last_prior.items()}
                prior_list.append(priors)
                post_list.append(post)
                last_prior_list.append(last_prior)

            post = {key: tf.stack([p[key] for p in post_list], axis=1) for key in post_list[0].keys()}
            prior = prior_list
            # Stack the collected last priors (previously this stacked post_list by mistake).
            last_prior = {key: tf.stack([p[key] for p in last_prior_list], axis=1) for key in last_prior_list[0].keys()}
        return post, prior, last_prior
        # TODO: create tools.k_steps_action(action)

    def retrace_k_step_multi_sample(self, reverse_actor, priors, posts):
        """Walk back through each k-step prior chain using the reverse actor.

        For each valid state ``i``, this starts from the ``i``-th slice of
        ``posts``. For ``j = K-1 .. 1`` it predicts a reverse action from the
        features of priors ``j`` and ``j-1``, and applies
        ``imagine_step_multi_sample`` with it.

        Returns:
            A dict of retraced states stacked along axis 1.
        """
        num_valid_states = len(priors)
        K = len(priors[0])  # model-learning horizon + 1
        retraced_states = deque(maxlen=(num_valid_states))
        for i in range(num_valid_states):
            prior = priors[i]
            # NOTE(review): observe() returns each priors[i] as a *list* of
            # per-step dicts, but the line below expects a dict of tensors with
            # time on axis 0 after transposing — confirm the expected format.
            prior = {key: tf.transpose(v, [1, 0, 2]) for key, v in prior.items()}
            # Take the i-th posterior (keeping a length-1 axis), then transpose it.
            prev_state = {key: v[:, [i], :] for key, v in posts.items()}
            prev_state = {key: tf.transpose(v, [1, 0, 2]) for key, v in prev_state.items()}
            for j in reversed(range(1, K)):
                next_prior = {key: v[[j]] for key, v in prior.items()}
                prev_prior = {key: v[[j - 1]] for key, v in prior.items()}
                next_feature, prev_feature = self.get_feature(next_prior), self.get_feature(prev_prior)
                reverse_action = reverse_actor(next_feature, prev_feature)
                prev_state = self.imagine_step_multi_sample(prev_state, reverse_action, self._num_sample_points)

            retraced_states.append({key: tf.transpose(v, [1, 0, 2]) for key, v in prev_state.items()})
        return {key: tf.stack([p[key] for p in retraced_states], axis=1) for key in retraced_states[0].keys()}

    def retrace(self, prev_state, next_state, reverse_actor):
        """One-step retrace: predict each previous state from the next one.

        States are sliced along axis 1 so that ``next_state[:, t+1]`` is paired
        with ``prev_state[:, t]``. A reverse action predicted from their features
        is applied to ``next_state`` with ``imagine_step``.
        """
        # Only one-step inference and retrace are handled here.
        next_state = {key: v[:, 1:, :] for key, v in next_state.items()}
        prev_state = {key: v[:, :-1, :] for key, v in prev_state.items()}

        # Uses the full feature (stoch + internal); the sampled stoch alone is an alternative.
        next_feature, prev_feature = self.get_feature(next_state), self.get_feature(prev_state)
        reverse_action = reverse_actor(next_feature, prev_feature)
        pred_prev_state = self.imagine_step(next_state, reverse_action)
        return pred_prev_state

    @tf.function
    def imagine(self, action, state=None):
        """Open-loop rollout of priors over an action sequence (transposed [1, 0, 2] in and out)."""
        if state is None:
            state = self.initial(tf.shape(action)[0])
        assert isinstance(state, dict), state
        action = tf.transpose(action, [1, 0, 2])
        prior = tools.static_scan(self.imagine_step, action, state)
        prior = {key: tf.transpose(v, [1, 0, 2]) for key, v in prior.items()}
        return prior

    @tf.function
    def obs_at_k_step(self, prev_state, actions, embed):
        """Imagine over ``actions`` with ``imagine_step``, then compute a posterior from the last prior and ``embed``."""
        priors = self.imagine_k_step_forward(prev_state, actions)
        last_prior = priors[-1]
        x = tf.concat([last_prior['internal'], embed], -1)
        x = self.get('obs1', tf.keras.layers.Dense, self._hidden_size, self._activation)(x)
        x = self.get('obs2', tf.keras.layers.Dense, 2*self._stoch_size, None)(x)
        mean, std = tf.split(x, 2, -1)
        std = tf.nn.softplus(std) + 0.1
        stoch = self.get_distribution({'mean': mean, 'std': std}).sample()
        post = {'mean': mean, 'std': std, 'stoch': stoch, 'internal': last_prior['internal']}
        return post, priors

    @tf.function
    def obs_at_k_step_multi_sample(self, prev_state, actions, embed, num_samples=20):
        """Like ``obs_at_k_step`` but uses multi-sample priors; also returns the last prior."""
        # The posterior itself uses a single pass: the embedding is an external
        # ground-truth input, although sampling could still refine it.
        priors = self.imagine_k_step_forward_multi_sample(prev_state, actions, num_samples)
        last_prior = priors[-1]
        x = tf.concat([last_prior['internal'], embed], -1)
        x = self.get('obs1', tf.keras.layers.Dense, self._hidden_size, self._activation)(x)
        x = self.get('obs2', tf.keras.layers.Dense, 2*self._stoch_size, None)(x)
        mean, std = tf.split(x, 2, -1)
        std = tf.nn.softplus(std) + 0.1
        stoch = self.get_distribution({'mean': mean, 'std': std}).sample()
        post = {'mean': mean, 'std': std, 'stoch': stoch, 'internal': last_prior['internal']}
        return post, priors, last_prior

    @tf.function
    def imagine_k_step_forward_multi_sample(self, prev_state, actions, num_samples=20):
        """Return ``[prev_state, prior_1, ..., prior_K]`` for the K actions along axis 0."""
        # NOTE(review): K is a tensor here, so autograph turns this into a
        # tf.while_loop, in which Python list appends are not supported — confirm.
        K = tf.shape(actions)[0]  # model-learning imagination horizon
        priors = []
        prior = prev_state
        priors.append(prior)  # include the starting state itself
        for i in range(K):
            prior = self.imagine_step_multi_sample(prior, actions[i], num_samples)
            priors.append(prior)
        return priors

    @tf.function
    def imagine_step_multi_sample(self, prev_state, prev_action, num_samples):
        """One prior step averaged over ``num_samples`` draws from ``prev_state``'s distribution.

        Each draw is passed with ``prev_action`` through the same layers as
        ``imagine_step``. The per-draw means, stds and GRU states are averaged,
        and a fresh ``stoch`` is sampled from the averaged Gaussian.
        """
        dist = self.get_distribution(prev_state)
        prev_internal = [prev_state['internal']]
        mean, std = 0, 0
        internal = 0
        for _ in range(num_samples):
            x = tf.concat([dist.sample(), prev_action], -1)
            # Only an unbatched (rank-1) input needs a leading batch axis. An
            # unconditional reshape to (1, -1) would flatten the whole batch into a
            # single row, mixing batch elements and changing 'imagine1''s input width.
            if x.shape.rank == 1:
                x = tf.expand_dims(x, 0)
            x = self.get('imagine1', tf.keras.layers.Dense, self._hidden_size, self._activation)(x)
            x, internal_temp = self._gru(x, prev_internal)
            internal += internal_temp[0]
            x = self.get('imagine2', tf.keras.layers.Dense, self._hidden_size, self._activation)(x)
            x = self.get('imagine3', tf.keras.layers.Dense, self._stoch_size*2, None)(x)
            mean_temp, std_temp = tf.split(x, 2, -1)
            mean += mean_temp
            std += tf.nn.softplus(std_temp)+0.1
        mean = mean / num_samples
        std = std / num_samples
        internal = internal / num_samples
        stoch = self.get_distribution({'mean': mean, 'std': std}).sample()
        prior = {'mean': mean, 'std': std, 'stoch': stoch, 'internal': internal}
        return prior

    @tf.function
    def imagine_k_step_forward(self, prev_state, actions):
        """Return ``[prior_1, ..., prior_K]`` from ``imagine_step`` over the K actions along axis 0."""
        # NOTE(review): same tensor-K / list-append concern as the multi-sample variant.
        K = tf.shape(actions)[0]
        priors = []
        prior = prev_state
        for i in range(K):
            prior = self.imagine_step(prior, actions[i])
            priors.append(prior)
        return priors

    @tf.function
    def obs_step(self, prev_state, prev_action, embed):
        """Return ``(post, prior)`` for one step; the prior averages ``self._num_sample_points`` samples."""
        num_samples = self._num_sample_points
        prior = self.imagine_step_multi_sample(prev_state, prev_action, num_samples)
        x = tf.concat([prior['internal'], embed], -1)
        x = self.get('obs1', tf.keras.layers.Dense, self._hidden_size, self._activation)(x)
        x = self.get('obs2', tf.keras.layers.Dense, self._stoch_size*2, None)(x)
        mean, std = tf.split(x, 2, -1)
        std = tf.nn.softplus(std) + 0.1
        stoch = self.get_distribution({'mean': mean, 'std': std}).sample()
        post = {'mean': mean, 'std': std, 'stoch': stoch, 'internal': prior['internal']}
        return post, prior

    @tf.function
    def imagine_step(self, prev_state, prev_action):
        """One prior step from ``prev_state['stoch']`` and ``prev_action``."""
        x = tf.concat([prev_state['stoch'], prev_action], -1)
        x = self.get('imagine1', tf.keras.layers.Dense, self._hidden_size, self._activation)(x)
        x, internal = self._gru(x, [prev_state['internal']])
        internal = internal[0]
        x = self.get('imagine2', tf.keras.layers.Dense, self._hidden_size, self._activation)(x)
        x = self.get('imagine3', tf.keras.layers.Dense, self._stoch_size*2, None)(x)
        mean, std = tf.split(x, 2, -1)
        std = tf.nn.softplus(std) + 0.1
        stoch = self.get_distribution({'mean': mean, 'std': std}).sample()
        prior = {'mean': mean, 'std': std, 'stoch': stoch, 'internal': internal}
        return prior

    def get_distribution(self, state):
        """Diagonal Gaussian with ``state['mean']`` and ``state['std']``."""
        return tfd.MultivariateNormalDiag(state['mean'], state['std'])

    def get_feature(self, state):
        """Concatenate ``stoch`` and ``internal`` along the last axis."""
        return tf.concat([state['stoch'], state['internal']], -1)

def bisimulation_metric(state, pred_state, c=1, discount=0.99):
    """Squared mismatch between a latent L1 distance and a bisimulation-style target.

    The result is ``(z_dist - reward_dist - discount * transition_dist) ** 2``,
    where each term is a batch mean:

    * ``z_dist``: L1 distance between ``stoch`` latents, summed over the last axis.
    * ``reward_dist``: KL divergence between the two reward distributions.
    * ``transition_dist``: summed squared differences of the dynamics means plus
      those of the stds. For diagonal Gaussians this is the *squared*
      2-Wasserstein distance.

    Args:
        state, pred_state: dicts with keys ``'stoch'``, ``'reward_dist'``,
            ``'dynamics_mean'`` and ``'dynamics_std'``. Both reward and
            transition models are assumed to be diagonal Gaussians.
        c: currently unused; kept for backward compatibility.
        discount: weight of the transition term (formerly hard-coded to 0.99).

    Returns:
        A scalar tensor.
    """
    z1, z2 = state['stoch'], pred_state['stoch']
    rdist_1, rdist_2 = state['reward_dist'], pred_state['reward_dist']
    mean_1, mean_2 = state['dynamics_mean'], pred_state['dynamics_mean']
    std_1, std_2 = state['dynamics_std'], pred_state['dynamics_std']
    W2_dist = tf.reduce_mean(tf.reduce_sum(tf.math.square(mean_1 - mean_2), axis=-1)) + tf.reduce_mean(tf.reduce_sum(tf.math.square(std_1 - std_2), axis=-1))
    z_dist = tf.reduce_mean(tf.reduce_sum(tf.abs(z1 - z2), axis=-1))
    reward_dist = tf.reduce_mean(tfd.kl_divergence(rdist_1, rdist_2))
    # NOTE(review): batch means are taken *before* squaring, i.e. this is the
    # square of averaged distances rather than the mean of per-sample squared
    # errors — confirm that is intended.
    bisim_metric = tf.square(z_dist - reward_dist - discount * W2_dist)
    return bisim_metric

def gaussian_logprob(noise, log_std):
    """Log-density of a diagonal Gaussian sample given its standardized noise.

    For ``x = mu + exp(log_std) * noise`` with ``D`` independent dimensions::

        log p(x) = sum_i(-0.5 * noise_i**2 - log_std_i) - 0.5 * D * log(2 * pi)

    Args:
        noise: standard-normal noise used to draw the sample.
        log_std: log standard deviations, broadcast-compatible with ``noise``.

    Returns:
        Log-probabilities with the last axis reduced and kept (size 1).
    """
    residual = tf.reduce_sum(-0.5 * tf.math.square(noise) - log_std, axis=-1, keepdims=True)
    # D must come from `noise`: `residual` always has last dim 1 after keepdims.
    # Cast both factors so the constant matches residual's float dtype.
    event_dim = tf.cast(tf.shape(noise)[-1], residual.dtype)
    log_norm = tf.cast(0.5 * np.log(2.0 * np.pi), residual.dtype)
    return residual - log_norm * event_dim

def squash(mu, pi, log_pi):
    """Push Gaussian policy outputs through tanh and fix up the log-probability.

    A Gaussian has unbounded support, so actions are squashed with ``a = tanh(u)``.
    By the change of variables, with ``da/du = diag(1 - tanh(u)**2)``::

        log pi(a|s) = log mu(u|s) - sum_i log(1 - tanh(u_i)**2)

    Args:
        mu: pre-squash mean.
        pi: pre-squash sample, or ``None``.
        log_pi: log-probability of ``pi`` before squashing, or ``None``. When
            given, ``pi`` must be given too.

    Returns:
        ``(tanh(mu), tanh(pi) or None, corrected log_pi or None)``.
    """
    squashed_mu = tf.math.tanh(mu)
    if pi is not None:
        pi = tf.math.tanh(pi)
    if log_pi is not None:
        # relu clamps 1 - tanh^2 at zero against round-off; 1e-6 keeps log finite.
        jacobian_diag = tf.nn.relu(1 - tf.math.square(pi)) + 1e-6
        log_pi -= tf.reduce_sum(tf.math.log(jacobian_diag), axis=-1, keepdims=True)
    return squashed_mu, pi, log_pi

class SAC_actor(tools.Module):
    """Tanh-squashed Gaussian policy head for SAC.

    Args:
        hidden_depth: number of hidden Dense layers (orthogonal init).
        hidden_dim: width of each hidden layer.
        action_size: action dimensionality.
        log_std_min, log_std_max: bounds for the predicted log std.
        activation: activation of the hidden layers.
    """

    def __init__(self, hidden_depth, hidden_dim, action_size, log_std_min, log_std_max, activation=tf.nn.elu):
        self._hidden_depth = hidden_depth
        self._hidden_dim = hidden_dim
        self._action_size = action_size
        self._activation = activation
        self._log_std_min = log_std_min
        self._log_std_max = log_std_max
        self._outputs = dict()

    def __call__(self, features, compute_pi=True, compute_log_pi=True):
        """Compute the policy mean, a sample, and its log-probability.

        The pre-squash mean and std are also stored in ``self._outputs['mu']``
        and ``self._outputs['std']``.

        Returns:
            ``(mu, pi, log_pi, log_std)``:
            * ``mu`` and ``pi`` are tanh-squashed; ``pi`` is ``None`` unless
              ``compute_pi``.
            * ``log_pi`` includes the tanh correction and is ``None`` unless
              ``compute_log_pi``.
            * ``log_std`` is the bounded pre-squash log std.

        Raises:
            ValueError: if ``compute_log_pi`` is requested without ``compute_pi``
                (the tanh correction needs the sample).
        """
        if compute_log_pi and not compute_pi:
            raise ValueError('compute_log_pi=True requires compute_pi=True')
        x = features
        for i in range(self._hidden_depth):
            x = self.get(f'h{i}', tf.keras.layers.Dense, self._hidden_dim, self._activation, kernel_initializer=tf.keras.initializers.Orthogonal())(x)
        x = self.get('hout', tf.keras.layers.Dense, self._action_size * 2, None)(x)
        mu, log_std = tf.split(x, 2, -1)
        # Squash to (-1, 1), then rescale into (log_std_min, log_std_max).
        log_std = tf.math.tanh(log_std)
        log_std = self._log_std_min + 0.5 * (self._log_std_max - self._log_std_min) * (log_std + 1)
        std = tf.math.exp(log_std)
        self._outputs['mu'] = mu
        self._outputs['std'] = std
        # Match log_std's dtype so the arithmetic below also works under mixed precision.
        noise = tf.random.normal(tf.shape(log_std), dtype=log_std.dtype)
        pi = noise * std + mu if compute_pi else None
        log_pi = gaussian_logprob(noise, log_std) if compute_log_pi else None
        mu, pi, log_pi = squash(mu, pi, log_pi)
        return mu, pi, log_pi, log_std

class Qnetwork(tools.Module):
    """MLP mapping features to one Q-value per row.

    Args:
        hidden_depth: number of hidden Dense layers.
        hidden_dim: width of each hidden layer.
        activation: activation of the hidden layers.
    """

    def __init__(self, hidden_depth, hidden_dim, activation=tf.nn.relu):
        self._hidden_depth, self._hidden_dim, self._activation = hidden_depth, hidden_dim, activation

    def __call__(self, features):
        """Return a tensor whose last dimension has size 1 (linear output)."""
        out = features
        for layer_idx in range(self._hidden_depth):
            # Hidden layers use orthogonal kernel initialisation.
            hidden = self.get(
                f'h{layer_idx}',
                tf.keras.layers.Dense,
                self._hidden_dim,
                self._activation,
                kernel_initializer=tf.keras.initializers.Orthogonal(),
            )
            out = hidden(out)
        head = self.get('hout', tf.keras.layers.Dense, 1, None)
        return head(out)

class SAC_critic(tools.Module):
    """Twin-Q critic: two independent ``Qnetwork`` heads on the same input.

    Args:
        hidden_depth: number of hidden layers in each Q head.
        hidden_dim: hidden width of each Q head.
        activation: activation of the hidden layers.
    """

    def __init__(self, hidden_depth, hidden_dim, activation=tf.nn.relu):
        self._hidden_depth = hidden_depth
        self._hidden_dim = hidden_dim
        self._activation = activation
        self._Q1 = Qnetwork(hidden_depth, hidden_dim, activation)
        self._Q2 = Qnetwork(hidden_depth, hidden_dim, activation)
        self._outputs = dict()

    def __call__(self, features, actions):
        """Evaluate both Q heads on ``concat([features, actions])``.

        The outputs are stored in ``self._outputs['q1']`` and
        ``self._outputs['q2']`` and also returned. The original returned
        ``None``; callers that ignore the return value are unaffected.

        Returns:
            ``(q1, q2)``.
        """
        x = self.get('concat', tf.keras.layers.Concatenate)([features, actions])
        q1 = self._Q1(x)
        q2 = self._Q2(x)
        self._outputs['q1'] = q1
        self._outputs['q2'] = q2
        return q1, q2

def soft_update_params(net, target_net, tau):
    """Blend ``net``'s trainable variables into ``target_net`` in place.

    Each target variable is assigned ``tau * source + (1 - tau) * target``, so
    ``tau=1`` copies the source and ``tau=0`` leaves the target unchanged.
    Variables are paired positionally with ``zip``, so both networks must list
    them in the same order; any unmatched extras are ignored.
    """
    paired = zip(net.trainable_variables, target_net.trainable_variables)
    for source, target in paired:
        blended = tau * source + (1 - tau) * target
        target.assign(blended)
