import tensorflow as tf
import tree

from .base_algorithm import BaseAlgorithm

# Optimizer spec used when no `optimizer_params` are supplied. The
# 'learning_rate' entry is deliberately absent from 'config'.
DEFAULT_OPTIMIZER = {'class_name': 'Adam', 'config': {}}


class GTDBase(BaseAlgorithm):
    """Common state and feature computation for the gradient-TD learners.

    Stores the value function, a second set of variables (`weights`) with
    the same structure and shapes as the value function's trainable
    variables, and one optimizer for each of the two.
    """

    def __init__(self,
                 V_theta,
                 alpha=0.5,
                 beta=0.05,
                 gamma=0.9,
                 num_weight_steps=1,
                 optimizer_params=None):
        """Create the auxiliary weights and both optimizers.

        Args:
            V_theta: Value function object. This class reads its
                `trainable_variables` and calls its `values(inputs)`.
            alpha: Learning rate given to the optimizer for `V_theta`.
            beta: Learning rate given to the optimizer for `weights`.
            gamma: Stored as `_gamma`; not used by this base class.
            num_weight_steps: Stored as `_num_weight_steps`; not used by
                this base class.
            optimizer_params: Optional dict of the form
                `{'class_name': ..., 'config': {...}}` as accepted by
                `tf.optimizers.get`. `DEFAULT_OPTIMIZER` is used when it
                is falsy. 'config' may be omitted. It must not specify
                'learning_rate'; that comes from `alpha` and `beta`.

        Raises:
            AssertionError: If 'learning_rate' appears in
                `optimizer_params` or in its 'config'.
        """
        # NOTE(review): BaseAlgorithm.__init__ is not invoked here --
        # confirm that this is intended.
        self._alpha = alpha
        self._beta = beta
        self._gamma = gamma
        self._num_weight_steps = num_weight_steps
        optimizer_params = optimizer_params or DEFAULT_OPTIMIZER
        optimizer_config = optimizer_params.get('config') or {}
        # The learning rate lives inside 'config', so checking only the
        # outer dict would let a conflicting value through and silently
        # replace it with `alpha` / `beta`.
        assert 'learning_rate' not in optimizer_params, (
            "Set the learning rates through `alpha` and `beta`.")
        assert 'learning_rate' not in optimizer_config, (
            "Set the learning rates through `alpha` and `beta`, not "
            "through optimizer_params['config'].")

        self.V_theta = V_theta
        self._V_theta_optimizer = self._build_optimizer(
            optimizer_params, alpha)

        # One uniformly initialised variable per trainable variable of V,
        # with the same shape.
        self.weights = tree.map_structure(
            lambda x: tf.Variable(tf.random.uniform(tf.shape(x))),
            self.V.trainable_variables)

        self._weight_optimizer = self._build_optimizer(
            optimizer_params, beta)

    @staticmethod
    def _build_optimizer(optimizer_params, learning_rate):
        """Instantiate the configured optimizer with `learning_rate`."""
        return tf.optimizers.get({
            'class_name': optimizer_params['class_name'],
            'config': {
                **(optimizer_params.get('config') or {}),
                'learning_rate': learning_rate,
            },
        })

    @property
    def V(self):
        """The value function passed to the constructor as `V_theta`."""
        return self.V_theta

    @tf.function(experimental_relax_shapes=True)
    def phi(self, inputs):
        """Compute linearized features.

        Differentiates `self.V.values(inputs)` with respect to every
        trainable variable of `V` and sums out all output axes except the
        first.

        Args:
            inputs: Passed unchanged to `self.V.values`.

        Returns:
            A structure matching `self.V.trainable_variables`; each entry
            has shape `(outputs.shape[0], *variable.shape)`.
        """
        with tf.GradientTape() as tape:
            outputs = self.V.values(inputs)

        # Variables that do not influence the outputs yield zeros.
        jacobians = tape.jacobian(
            outputs,
            self.V.trainable_variables,
            unconnected_gradients=tf.UnconnectedGradients.ZERO)

        # Each Jacobian has the output axes followed by the variable axes.
        # The reduced range covers the output axes after the first one.
        # presumably the first output axis is the batch axis -- TODO confirm
        phi = tree.map_structure(
            lambda jacobian, variable: tf.reduce_sum(
                jacobian,
                axis=tf.range(1, tf.rank(jacobian) - tf.rank(variable))),
            jacobians,
            self.V.trainable_variables)

        return phi


class GTD2(GTDBase):
    """Gradient-TD learner using the GTD2 form of the slow `V` update."""

    @tf.function(experimental_relax_shapes=True)
    def update_V(self, state_0s, actions, state_1s, rewards, terminals, rhos):
        """Apply the fast `weights` update(s), then one slow `V` update.

        Args:
            state_0s: Inputs for `V.values` / `phi` at the current step.
            actions: Not read by this method.
            state_1s: Inputs for `V.values` / `phi` at the next step.
            rewards: Cast to `self.V.model.dtype` and added to the
                discounted next-step value.
            terminals: Not read by this method.
            rhos: Not read by this method.

        Returns:
            Dict with 'weight_loss' (mean over the weight steps of the
            mean entry of each weight update vector) and 'V_loss' (mean
            entry of the `V` update vector).

        NOTE(review): because `terminals` is ignored, the target
        bootstraps from `V(state_1s)` for every transition -- confirm
        that this is intended.
        """
        def per_row(structure):
            """Reshape each leaf to (leading_dim, -1); join on last axis."""
            def as_matrix(leaf):
                return tf.reshape(leaf, (tf.shape(leaf)[0], -1))
            return tf.concat(
                tree.map_structure(as_matrix, structure), axis=-1)

        def weight_vector():
            """Current `self.weights` joined into a single 1-D tensor."""
            def as_vector(leaf):
                return tf.reshape(leaf, [-1])
            return tf.concat(
                tree.map_structure(as_vector, self.weights), axis=-1)

        def shaped_like(flat, variables):
            """Cut `flat` into pieces sized and shaped like `variables`."""
            pieces = tf.split(flat, tree.map_structure(tf.size, variables))
            shapes = tree.map_structure(tf.shape, variables)
            return tree.map_structure(tf.reshape, pieces, shapes)

        rewards = tf.cast(rewards, self.V.model.dtype)

        bootstrapped = rewards + self._gamma * self.V.values(state_1s)
        td_error = bootstrapped - self.V.values(state_0s)

        # Fast update: the features are computed once and reused by every
        # weight step.
        features_0 = per_row(self.phi(tf.convert_to_tensor(state_0s)))

        def weight_step(_):
            # assumes td_error is (batch, 1) so that it broadcasts against
            # the (batch, num_params) features -- TODO confirm
            projection = tf.einsum(
                'bi,i->b', features_0, weight_vector())[..., None]
            step = -1.0 * tf.reduce_sum(
                (td_error - projection) * features_0, axis=0)

            self._weight_optimizer.apply_gradients(zip(
                shaped_like(step, self.weights), self.weights))

            return tf.reduce_mean(step)

        # parallel_iterations=1 keeps the weight steps sequential.
        weight_losses = tf.map_fn(
            weight_step,
            tf.range(self._num_weight_steps),
            dtype=tf.float32,
            parallel_iterations=1)

        # Slow update: recompute phi under a forward accumulator to get
        # its directional derivative w.r.t. V's variables along `weights`.
        with tf.autodiff.ForwardAccumulator(
                self.V.trainable_variables, self.weights) as accumulator:
            phi_0 = self.phi(tf.convert_to_tensor(state_0s))

        phi_0_jvp = accumulator.jvp(
            phi_0, unconnected_gradients=tf.UnconnectedGradients.ZERO)
        del accumulator
        phi_0_jvp_rows = per_row(phi_0_jvp)

        phi_0_rows = per_row(phi_0)
        # Uses the weights as they are after the fast update above.
        projection_0 = tf.einsum(
            'bi,i->b', phi_0_rows, weight_vector())[..., None]

        phi_1_rows = per_row(self.phi(state_1s))

        h_0 = (td_error - projection_0) * phi_0_jvp_rows

        V_updates = -1.0 * tf.reduce_sum((
            (phi_0_rows - self._gamma * phi_1_rows) * projection_0 - h_0
        ), axis=0)

        self._V_theta_optimizer.apply_gradients(zip(
            shaped_like(V_updates, self.V.trainable_variables),
            self.V.trainable_variables))

        # Fail if either set of variables picked up NaN/Inf values.
        tree.map_structure(
            lambda variable: tf.debugging.check_numerics(variable, 'weights'),
            self.weights)
        tree.map_structure(
            lambda variable: tf.debugging.check_numerics(variable, 'V'),
            self.V.trainable_variables)

        return {
            'weight_loss': tf.reduce_mean(weight_losses),
            'V_loss': tf.reduce_mean(V_updates),
        }


class TDC(GTDBase):
    """Gradient-TD learner using the TDC form of the slow `V` update."""

    @tf.function(experimental_relax_shapes=True)
    def update_V(self, state_0s, actions, state_1s, rewards, terminals, rhos):
        """Apply the fast `weights` update(s), then one slow `V` update.

        Args:
            state_0s: Inputs for `V.values` / `phi` at the current step.
            actions: Not read by this method.
            state_1s: Inputs for `V.values` / `phi` at the next step.
            rewards: Cast to `self.V.model.dtype` and added to the
                discounted next-step value.
            terminals: Not read by this method.
            rhos: Not read by this method.

        Returns:
            Dict with 'weight_loss' (mean over the weight steps of the
            mean entry of each weight update vector) and 'V_loss' (mean
            entry of the `V` update vector).

        NOTE(review): because `terminals` is ignored, the target
        bootstraps from `V(state_1s)` for every transition -- confirm
        that this is intended.
        """
        def per_row(structure):
            """Reshape each leaf to (leading_dim, -1); join on last axis."""
            def as_matrix(leaf):
                return tf.reshape(leaf, (tf.shape(leaf)[0], -1))
            return tf.concat(
                tree.map_structure(as_matrix, structure), axis=-1)

        def weight_vector():
            """Current `self.weights` joined into a single 1-D tensor."""
            def as_vector(leaf):
                return tf.reshape(leaf, [-1])
            return tf.concat(
                tree.map_structure(as_vector, self.weights), axis=-1)

        def shaped_like(flat, variables):
            """Cut `flat` into pieces sized and shaped like `variables`."""
            pieces = tf.split(flat, tree.map_structure(tf.size, variables))
            shapes = tree.map_structure(tf.shape, variables)
            return tree.map_structure(tf.reshape, pieces, shapes)

        rewards = tf.cast(rewards, self.V.model.dtype)

        bootstrapped = rewards + self._gamma * self.V.values(state_1s)
        td_error = bootstrapped - self.V.values(state_0s)

        # Fast update: the features are computed once and reused by every
        # weight step.
        features_0 = per_row(self.phi(tf.convert_to_tensor(state_0s)))

        def weight_step(_):
            # assumes td_error is (batch, 1) so that it broadcasts against
            # the (batch, num_params) features -- TODO confirm
            projection = tf.einsum(
                'bi,i->b', features_0, weight_vector())[..., None]
            step = -1.0 * tf.reduce_sum(
                (td_error - projection) * features_0, axis=0)

            self._weight_optimizer.apply_gradients(zip(
                shaped_like(step, self.weights), self.weights))

            return tf.reduce_mean(step)

        # parallel_iterations=1 keeps the weight steps sequential.
        weight_losses = tf.map_fn(
            weight_step,
            tf.range(self._num_weight_steps),
            dtype=tf.float32,
            parallel_iterations=1)

        # Slow update: recompute phi under a forward accumulator to get
        # its directional derivative w.r.t. V's variables along `weights`.
        with tf.autodiff.ForwardAccumulator(
                self.V.trainable_variables, self.weights) as accumulator:
            phi_0 = self.phi(tf.convert_to_tensor(state_0s))

        phi_0_jvp = accumulator.jvp(
            phi_0, unconnected_gradients=tf.UnconnectedGradients.ZERO)
        del accumulator
        phi_0_jvp_rows = per_row(phi_0_jvp)

        phi_0_rows = per_row(phi_0)
        # Uses the weights as they are after the fast update above.
        projection_0 = tf.einsum(
            'bi,i->b', phi_0_rows, weight_vector())[..., None]

        phi_1_rows = per_row(self.phi(state_1s))

        h_0 = (td_error - projection_0) * phi_0_jvp_rows

        V_updates = -1.0 * tf.reduce_sum((
            td_error * phi_0_rows
            - self._gamma * phi_1_rows * projection_0
            - h_0
        ), axis=0)

        self._V_theta_optimizer.apply_gradients(zip(
            shaped_like(V_updates, self.V.trainable_variables),
            self.V.trainable_variables))

        # Fail if either set of variables picked up NaN/Inf values.
        tree.map_structure(
            lambda variable: tf.debugging.check_numerics(variable, 'weights'),
            self.weights)
        tree.map_structure(
            lambda variable: tf.debugging.check_numerics(variable, 'V'),
            self.V.trainable_variables)

        return {
            'weight_loss': tf.reduce_mean(weight_losses),
            'V_loss': tf.reduce_mean(V_updates),
        }
