

from collections import OrderedDict
import tensorflow as tf
import tree


@tf.function(experimental_relax_shapes=True)
def batch_quadratic_form(W, inputs):
    """Evaluate a quadratic form for every vector in a batch.

    Both arguments are converted to float32 tensors before the contraction.

    Args:
        W: Rank-2 array-like, indexed as ``W[j, i]`` in the contraction.
        inputs: Array-like with at least two dimensions, ``(..., b, i)``;
            the last axis is contracted against both axes of ``W``.

    Returns:
        A float32 tensor of shape ``(..., b, 1)`` whose entries are
        ``sum_{i, j} inputs[..., b, i] * W[j, i] * inputs[..., b, j]``.
    """
    weights = tf.cast(tf.convert_to_tensor(W), tf.float32)
    vectors = tf.cast(tf.convert_to_tensor(inputs), tf.float32)
    quadratic = tf.einsum('...bi,ji,...bj->...b', vectors, weights, vectors)
    # Restore a trailing singleton axis so the result stays column-shaped.
    return tf.expand_dims(quadratic, axis=-1)


class OnlineUncertaintyModelV2(tf.keras.Model):
    """Online uncertainty model, labelled "Algorithm 2" by the author.

    State created in `build` (``D`` is the summed last dimension of the
    input shapes):

    * ``C_inverse``: ``(D, D)`` matrix initialised to ``-sigma_0 * I`` and
      corrected by a rank-one term on every `online_update`.
    * ``rho``: ``(1, D)`` running sum of ``iw * b_N * reward``.
    * ``Delta_N``: ``(D, D)`` zero-initialised matrix. It is allocated and
      reset here but not otherwise read or written by this class.
    * ``N``: int32 scalar counting the samples seen by `online_update`.

    `call` is a stub that always returns ``0.0``.
    """

    def __init__(self, sigma_0=1.0):
        """Store ``sigma_0``, the scale of the initial ``C_inverse`` diagonal.

        Args:
            sigma_0: ``C_inverse`` starts as ``-sigma_0`` times the identity.
        """
        super(OnlineUncertaintyModelV2, self).__init__()
        self.sigma_0 = sigma_0

    def build(self, input_shapes):
        """Create the state variables.

        Args:
            input_shapes: A (possibly nested) structure of shapes; the last
                entry of every flattened leaf is summed to obtain ``D``.
        """
        D = sum(input_shape[-1] for input_shape in tree.flatten(input_shapes))
        # The weight name is passed by keyword: the position of `name` in
        # `add_weight` differs between tf.keras 2.x (first) and Keras 3
        # (where `shape` is first), so a positional name is not portable.
        self.C_inverse = self.add_weight(
            name='C_inverse',
            shape=(D, D),
            initializer=tf.initializers.identity(gain=-1.0 * self.sigma_0))
        self.rho = self.add_weight(
            name='rho', shape=(1, D), initializer=tf.initializers.zeros)
        self.Delta_N = self.add_weight(
            name='Delta_N', shape=(D, D), initializer=tf.initializers.zeros)
        self.N = self.add_weight(
            name='N', shape=(), dtype=tf.int32,
            initializer=tf.initializers.zeros)

    def reset(self):
        """Restore every state variable to the value it had after `build`.

        Must be called after the model has been built, since it reads the
        variables created there.
        """
        self.C_inverse.assign(-1.0 * self.sigma_0 * tf.eye(*self.C_inverse.shape))
        self.rho.assign(tf.zeros_like(self.rho))
        self.Delta_N.assign(tf.zeros_like(self.Delta_N))
        self.N.assign(tf.zeros_like(self.N))

    @tf.function(experimental_relax_shapes=True)
    def call(self, inputs):
        """Stub forward pass: ignores ``inputs`` and returns ``0.0``."""
        return 0.0

    @tf.function(experimental_relax_shapes=True)
    def online_update(self, b_N, b_hat, b_N_next, gamma, reward, iw=1.0):
        """Fold a single sample into ``C_inverse``, ``rho`` and ``N``.

        With ``Delta = iw * (gamma * b_N_next - b_N)`` the new matrix is

            C_inverse - (C_inverse b_N^T)(Delta C_inverse)
                        / (1 + Delta C_inverse b_N^T),

        which has the form of the Sherman-Morrison identity for adding the
        outer product ``b_N^T Delta`` to the matrix whose inverse is
        ``C_inverse``.

        Args:
            b_N: Tensor of shape ``(1, D)``; the leading dimension is
                asserted to be 1.
            b_hat: Unused by this method.
            b_N_next: Tensor combined elementwise with ``b_N``, so it must
                broadcast against ``(1, D)``.
            gamma: Factor multiplying ``b_N_next`` (presumably a discount
                factor -- confirm with the caller).
            reward: Factor multiplying ``b_N`` in the ``rho`` increment.
            iw: Weight applied to both ``Delta`` and the ``rho`` increment
                (presumably an importance weight -- confirm with the
                caller).

        Returns:
            ``True``.

        Raises:
            tf.errors.InvalidArgumentError: If the leading dimension of
                ``b_N`` is not 1.
        """
        N = tf.shape(b_N)[0]
        tf.debugging.assert_equal(N, 1)

        Delta = iw * (gamma * b_N_next - b_N)

        def update_C_inverse():
            # (1, D) row vector shared by the numerator and the denominator.
            Delta_C_inverse = tf.matmul(Delta, self.C_inverse)
            # The denominator is a (1, 1) tensor that broadcasts over (D, D).
            C_inverse_delta = - 1.0 * tf.matmul(
                tf.matmul(self.C_inverse, b_N, transpose_b=True),
                Delta_C_inverse
            ) / (1.0 + tf.matmul(Delta_C_inverse, b_N, transpose_b=True))

            C_inverse = self.C_inverse + C_inverse_delta
            return C_inverse

        C_inverse = update_C_inverse()
        self.C_inverse.assign(C_inverse)

        rho_delta = iw * b_N * reward
        self.rho.assign_add(rho_delta)

        self.N.assign_add(N)

        return True

    def get_diagnostics(self):
        """Return the sample count and the model output as NumPy values.

        Uses ``.numpy()``, so it has to run eagerly. The model is invoked
        with ``True`` as its input; that value appears to be a placeholder,
        since `call` ignores its argument.

        Returns:
            An ``OrderedDict`` with keys ``'N'`` and
            ``'epistemic_uncertainty'``.
        """
        diagnostics = OrderedDict((
            ('N', self.N.numpy()),
            ('epistemic_uncertainty', self(True).numpy()),
        ))
        return diagnostics


class OnlineUncertaintyModelV3(tf.keras.Model):
    """Online uncertainty model, labelled "Algorithm 3" by the author.

    State created in `build` (``D`` is the summed last dimension of the
    input shapes):

    * ``C_inverse``: ``(D, D)`` matrix initialised to ``-sigma_0 * I``. Each
      `online_update` applies one rank-one correction built from the
      sample, and, while ``N < D``, a second one built from the N-th column
      and row of the matrix itself.
    * ``rho``: ``(1, D)`` running sum of ``b_N * reward``.
    * ``N``: int32 scalar counting the samples seen by `online_update`.

    `call` is a stub that always returns ``0.0``.
    """

    def __init__(self, sigma_0=1.0):
        """Store ``sigma_0``.

        Args:
            sigma_0: ``C_inverse`` starts as ``-sigma_0`` times the
                identity; the same value scales the second correction in
                `online_update`.
        """
        super(OnlineUncertaintyModelV3, self).__init__()
        self.sigma_0 = sigma_0

    def build(self, input_shapes):
        """Create the state variables.

        Args:
            input_shapes: A (possibly nested) structure of shapes; the last
                entry of every flattened leaf is summed to obtain ``D``.
        """
        D = sum(input_shape[-1] for input_shape in tree.flatten(input_shapes))
        # The weight name is passed by keyword: the position of `name` in
        # `add_weight` differs between tf.keras 2.x (first) and Keras 3
        # (where `shape` is first), so a positional name is not portable.
        self.C_inverse = self.add_weight(
            name='C_inverse',
            shape=(D, D),
            initializer=tf.initializers.identity(gain=-1.0 * self.sigma_0))
        self.rho = self.add_weight(
            name='rho', shape=(1, D), initializer=tf.initializers.zeros)
        self.N = self.add_weight(
            name='N', shape=(), dtype=tf.int32,
            initializer=tf.initializers.zeros)

    def reset(self):
        """Restore every state variable to the value it had after `build`.

        Must be called after the model has been built, since it reads the
        variables created there.
        """
        self.C_inverse.assign(tf.eye(*self.C_inverse.shape) * -1.0 * self.sigma_0)
        self.rho.assign(tf.zeros_like(self.rho))
        self.N.assign(tf.zeros_like(self.N))

    @tf.function(experimental_relax_shapes=True)
    def call(self, inputs):
        """Stub forward pass: ignores ``inputs`` and returns ``0.0``."""
        return 0.0

    @tf.function(experimental_relax_shapes=True)
    def online_update(self, b_N, b_hat, b_N_next, gamma, reward):
        """Fold a single sample into ``C_inverse``, ``rho`` and ``N``.

        Step 1: with ``Delta = gamma * b_N_next - b_N`` assign

            C_inverse - (C_inverse b_N^T)(Delta C_inverse)
                        / (1 + Delta C_inverse b_N^T),

        which has the form of the Sherman-Morrison identity for adding the
        outer product ``b_N^T Delta``.

        Step 2: only while ``N < D``, using the matrix from step 1, assign

            C_inverse - sigma_0 * col_N row_N
                        / (1 + sigma_0 * C_inverse[N, N]),

        where ``col_N`` / ``row_N`` are the N-th column / row of
        ``C_inverse``. This has the Sherman-Morrison form for adding
        ``sigma_0`` to the N-th diagonal entry of the inverted matrix.

        ``rho`` is incremented by ``b_N * reward`` and ``N`` by 1.

        Args:
            b_N: Tensor of shape ``(1, D)``; the leading dimension is
                asserted to be 1.
            b_hat: Unused by this method.
            b_N_next: Tensor combined elementwise with ``b_N``, so it must
                broadcast against ``(1, D)``.
            gamma: Factor multiplying ``b_N_next`` (presumably a discount
                factor -- confirm with the caller).
            reward: Factor multiplying ``b_N`` in the ``rho`` increment.

        Returns:
            ``True``.

        Raises:
            tf.errors.InvalidArgumentError: If the leading dimension of
                ``b_N`` is not 1, or if ``check_numerics`` finds NaN/Inf in
                ``C_inverse`` or ``rho``.
        """
        N = tf.shape(b_N)[0]
        tf.debugging.assert_equal(N, 1)

        Delta = gamma * b_N_next - b_N

        def update_C_inverse():
            # (1, D) row vector shared by the numerator and the denominator.
            Delta_C_inverse = tf.matmul(Delta, self.C_inverse)
            # The denominator is a (1, 1) tensor that broadcasts over (D, D).
            C_inverse_delta = - 1.0 * tf.matmul(
                tf.matmul(self.C_inverse, b_N, transpose_b=True),
                Delta_C_inverse
            ) / (1.0 + tf.matmul(Delta_C_inverse, b_N, transpose_b=True))

            C_inverse = self.C_inverse + C_inverse_delta
            return C_inverse

        C_inverse = update_C_inverse()
        C_inverse_assign = self.C_inverse.assign(C_inverse)

        def update_C_inverse_2():
            C_inverse_nth_col = self.C_inverse[:, self.N]
            C_inverse_nth_row = self.C_inverse[self.N, :]

            # Column (outer) row, i.e. entry [i, j] = col[i] * row[j].
            C_inverse_nth_column_row_outer_product = tf.einsum(
                'i,j->ij', C_inverse_nth_col, C_inverse_nth_row)
            C_inverse_nth_diagonal = self.C_inverse[self.N, self.N]
            C_inverse_delta = - 1.0 * (
                (self.sigma_0 * C_inverse_nth_column_row_outer_product)
                / (1.0 + self.sigma_0 * C_inverse_nth_diagonal))

            C_inverse = self.C_inverse + C_inverse_delta
            return C_inverse

        rho_delta = b_N * reward
        self.rho.assign_add(rho_delta)

        # Step 2 is built under a dependency on the step-1 assignment so
        # that it works from the already-updated C_inverse.
        with tf.control_dependencies([C_inverse_assign]):
            C_inverse = tf.cond(
                tf.less(self.N, tf.shape(self.C_inverse)[0]),
                update_C_inverse_2,
                lambda: self.C_inverse)

            C_inverse_assign = self.C_inverse.assign(C_inverse)

            # N indexes the column/row used above, so it is advanced only
            # after the second assignment.
            with tf.control_dependencies([C_inverse_assign]):
                self.N.assign_add(N)

        # Report NaN/Inf in the updated state (the returned tensors are
        # not used).
        tf.debugging.check_numerics(self.C_inverse, "C_inverse")
        tf.debugging.check_numerics(self.rho, "rho")

        return True

    def get_diagnostics(self):
        """Return the sample count and the model output as NumPy values.

        Uses ``.numpy()``, so it has to run eagerly. The model is invoked
        with ``True`` as its input; that value appears to be a placeholder,
        since `call` ignores its argument.

        Returns:
            An ``OrderedDict`` with keys ``'N'`` and
            ``'epistemic_uncertainty'``.
        """
        diagnostics = OrderedDict((
            ('N', self.N.numpy()),
            ('epistemic_uncertainty', self(True).numpy()),
        ))
        return diagnostics
