import random
import time
import tensorflow as tf
import larq as lq
import numpy as np
import itertools
import math

from keras import backend as K
# from keras.utils.generic_utils import get_custom_objects
import scipy.stats
from copy import deepcopy
from interval_models import IPOMDP
from models import MDPWrapper, POMDPWrapper
import utils as ut
from fsc import FiniteMemoryPolicy

from sklearn.cluster import KMeans

@tf.function
def flatter_tanh(x):
    """Return ``1.5 * tanh(x) + 0.5 * tanh(3 * x)``.

    The output saturates at +/-2 (1.5 + 0.5) and has slope 3 at the origin.
    """
    wide_part = 1.5 * K.tanh(x)
    steep_part = 0.5 * K.tanh(3 * x)
    return wide_part + steep_part

class Net(tf.keras.Model):
    """Recurrent policy network: a (possibly quantized) GRU memory plus a dense softmax actor.

    The memory is a ``GRUActor``. For a Moore machine the actor sees only the memory output;
    for a Mealy machine it sees the (embedded) observation together with the memory.
    The network can be rolled out on (interval) POMDPs and extracted into a
    ``FiniteMemoryPolicy``.
    """

    def __init__(self, instance, cfg, state_machine = 'moore'):
        """
        param: instance      :   problem instance; this class reads ``input_dim``, ``output_dim``,
                                 ``label_to_reach`` and ``pomdp.initial_observation`` from it.
        param: cfg           :   configuration dict (``a_memory_dim``, ``bottleneck_dim``,
                                 ``one_hot_obs``, ``a_loss``, ``a_lr``, ``quantization``, ...).
        param: state_machine :   ``'moore'`` or ``'mealy'`` (case-insensitive).
        """
        super().__init__(name='NET')
        assert state_machine.lower() in ['mealy', 'moore']
        self.mealy = not state_machine.lower() == 'moore'
        self.instance, self.cfg = instance, cfg
        self.input_dim, self.output_dim = instance.input_dim, instance.output_dim
        self.memory_dim = cfg['a_memory_dim']
        self.bottleneck_dim = cfg['bottleneck_dim']
        self.qbn_gru_rnn = GRUActor(instance, cfg)
        self.actor = tf.keras.Sequential([
            tf.keras.layers.Dense(32, 'tanh', name='policy1'),
            tf.keras.layers.Dense(32, 'tanh', name='policy2'),
            tf.keras.layers.Dense(self.output_dim, 'softmax', name='policy3')
        ], name='actor')
        self.build(input_shape = (None, None, instance.input_dim if cfg['one_hot_obs'] else 1))

        loss_name = cfg['a_loss'].lower()
        if loss_name == 'cce':
            self.loss_f = tf.keras.losses.CategoricalCrossentropy(from_logits=False,label_smoothing=0)
        elif loss_name == 'kld':
            self.loss_f = tf.keras.losses.KLDivergence()
        else:
            if loss_name != 'mse':
                print("Unkown action loss specified, defaulting to MSE.")
            self.loss_f = 'mse'
        self.compile(loss = self.loss_f, optimizer = self._build_actor_optimizer())

        try:
            # Snapshot used by improve_r(reset_weights=True).
            self.initial_qbn_weights = self.qbn_gru_rnn.qbn_gru.hx_qbn.get_weights()
        except Exception:
            # Best effort: leave the snapshot unset if the HxQBN weights are unavailable.
            pass
        self.initial_weights = self.get_weights()

        quantization = cfg['quantization'].lower()
        if quantization == 'tern':
            quant_levels = [-1, 0, 1]
        elif quantization == 'sign':
            quant_levels = [-1, 1]
        else:
            raise ValueError("Unknown quantization function!")

        # Every possible bottleneck code, plus a reverse lookup from code to its index.
        self.hqs = list(itertools.product(quant_levels, repeat=self.bottleneck_dim))
        self.hqs_idxs = {hq : idx for idx, hq in enumerate(self.hqs)}

        self.summary()
        try:
            self.qbn_gru_rnn.summary()
        except Exception:
            # GRUActor is a layer and may not expose summary(); printing it is optional.
            pass

    def _build_actor_optimizer(self):
        """Create the Adam optimizer for the actor from ``cfg`` (learning rate, clipping, decay)."""
        return tf.keras.optimizers.Adam(learning_rate = self.cfg['a_lr'], clipnorm=self.cfg['clipnorm'], clipvalue=self.cfg['clipvalue'], decay=self.cfg['weight_decay'])

    def reset_model(self):
        """Restore the weights captured at construction time and recompile with a fresh optimizer."""
        self.build(input_shape = (None, None, self.instance.input_dim if self.cfg['one_hot_obs'] else 1))
        self.set_weights(self.initial_weights)
        self.compile(loss = self.loss_f, optimizer = self._build_actor_optimizer())

    def call(self, o, **kwargs):
        """Run the network over an observation sequence ``o`` and return per-step action distributions."""
        h = self.qbn_gru_rnn(o, **kwargs)
        if self.mealy:
            ox = self.qbn_gru_rnn.embed(o, **kwargs)
            actor_in = tf.concat([ox, h], axis=-1)
        else:
            actor_in = h
        return self.actor(actor_in)

    @staticmethod
    def _sample_interval_distribution(trans_dict):
        """Draw a random successor distribution that respects the given probability intervals.

        param: trans_dict : maps each successor state to its ``(lower, upper)`` probability bounds.
        returns: ``(possible_states, probs)`` where ``probs`` sums to 1.

        Each probability is drawn uniformly from its interval, narrowed so that the mass still
        to be distributed fits within the remaining intervals; the last successor takes what is
        left. Overwriting the last entry with ``1 - sum(others)`` instead could produce a
        negative or out-of-interval weight.
        """
        possible_states, intervals = zip(*trans_dict.items())
        lowers = [lb for lb, _ in intervals]
        uppers = [ub for _, ub in intervals]
        probs = []
        remaining = 1.0
        for i in range(len(intervals) - 1):
            # Leave enough (but not too much) mass for the successors still to come.
            low = max(lowers[i], remaining - sum(uppers[i + 1:]))
            high = min(uppers[i], remaining - sum(lowers[i + 1:]))
            # max(low, high) guards against infeasible intervals / float round-off.
            prob = random.uniform(low, max(low, high))
            probs.append(prob)
            remaining -= prob
        probs.append(max(remaining, 0.0))
        return possible_states, probs

    def _rollout(self, ipomdp, pomdp, transition, batch_dim, greedy, length, quantize, inspect, empirical_rewards_only, collect_hs):
        """Shared simulation loop behind the public ``simulate_with_*`` methods.

        ``transition(s, a)`` must return ``(successor_states, probabilities)`` for action ``a``
        in state ``s``. It is queried once per alive batch element to sample the true successor,
        and once per supported belief state for the belief update.

        Returns ``(beliefs, states, hs, observations, policies, actions, rewards, dones)``.
        Everything except ``rewards`` and ``dones`` is ``None`` when ``empirical_rewards_only``
        is set, except ``hs``, which is kept when ``collect_hs`` is set.
        """
        batch_dim = batch_dim or self.cfg['batch_dim']
        length = length or self.cfg['length']
        keep_history = not empirical_rewards_only
        keep_hs = collect_hs or keep_history

        rewards = np.zeros((batch_dim, length, pomdp.num_reward_models), dtype = 'float32')
        dones = np.ones((batch_dim, length), dtype=bool)
        batch_done = np.zeros((batch_dim), dtype=bool)

        beliefs = states = observations = policies = actions = hs = None
        if keep_history:
            beliefs = np.zeros((batch_dim, length, pomdp.nS))
            belief = np.zeros((batch_dim, pomdp.nS))
            belief[:, pomdp.initial_state] = 1
            states = np.zeros((batch_dim, length), dtype = 'int32')
            observations = np.zeros((batch_dim, length), dtype = 'int32') - 1
            policies = np.zeros((batch_dim, length, pomdp.nA), dtype = 'float32')
            actions = np.zeros((batch_dim, length), dtype = 'int32')
        if keep_hs:
            hs = np.zeros((batch_dim, length, self.memory_dim), dtype = 'float32')

        state = np.array([np.squeeze(pomdp.initial_state) for _ in range(batch_dim)], dtype = 'int32')
        observation = np.array([np.squeeze(pomdp.initial_observation) for _ in range(batch_dim)], dtype = 'int32')

        reset = self.qbn_gru_rnn.reset(batch_dim, quantize)
        h = reset[0] if quantize else reset

        for t in range(length):
            dones[:, t] = batch_done
            # batch_done only changes in the belief loop below, so this snapshot is valid
            # for every use within the step.
            alive = ~batch_done
            alive_idx = alive.nonzero()[0]

            if keep_history:
                beliefs[alive, t] = belief[alive]
                states[alive, t] = state[alive]
                observations[alive, t] = observation[alive]
            if keep_hs:
                hs[:, t] = tf.squeeze(h)

            if self.cfg['one_hot_obs']:
                x = np.reshape(ut.one_hot_encode(observation, pomdp.nO, dtype = 'float32'), (batch_dim, 1, pomdp.nO))
            else:
                x = np.reshape(observation, (batch_dim, 1, 1))

            a, action, h = self._action(x, inspect = inspect, greedy = greedy, quantize=quantize, states = h, mask = pomdp.policy_mask[observation])

            if keep_history:
                policies[alive, t] = a.numpy()[alive]
                actions[alive, t] = action[alive]

            if ipomdp.state_action_rewards:
                step_rewards = ipomdp.R[state[alive], action[alive]]
            else:
                step_rewards = ipomdp.R[state[alive]]
            rewards[alive, t, :] = step_rewards[..., np.newaxis]

            for b in alive_idx:
                successors, probs = transition(state[b], action[b])
                state[b] = random.choices(successors, weights=probs, k=1)[0]

            observation = pomdp.O[state]

            next_belief = np.zeros((batch_dim, ipomdp.nS))
            for b in alive_idx:
                if self.instance.label_to_reach in ipomdp.pPOMDP.observation_labels[observation[b]]:
                    batch_done[b] = True
                if keep_history:
                    for s in np.where(belief[b] > 0)[0]:
                        successors, probs = transition(s, actions[b, t])
                        for next_s, prob in zip(successors, probs):
                            if ipomdp.pPOMDP.O[next_s] == observation[b]:
                                next_belief[b, next_s] += prob * belief[b, s]
                    if not math.isclose(next_belief[b].sum(), 1):
                        next_belief[b] = ut.normalize(next_belief[b])

            belief = next_belief

            if batch_done.all():
                break

        return beliefs, states, hs, observations, policies, actions, rewards, dones

    def simulate_with_random_uncertainty(self, ipomdp : IPOMDP, pomdp : POMDPWrapper, batch_dim = None, greedy = False, length = None, quantize = False, inspect = False, T_has_memory_dep = False, empirical_rewards_only = False, collect_hs = False):
        """ Simulates this policy network on the RPOMDP, randomly picking POMDP transitions.

        Every time a transition distribution is needed, a fresh one is drawn from the interval
        bounds in ``ipomdp.T[s, a]``. This happens both for sampling the successor and for each
        term of the belief update. ``T_has_memory_dep`` is accepted for API compatibility and
        is not used.

        Returns ``(beliefs, states, hs, observations, policies, actions, rewards, dones)``.
        """
        def transition(s, a):
            return self._sample_interval_distribution(ipomdp.T[s, a])

        return self._rollout(ipomdp, pomdp, transition, batch_dim = batch_dim, greedy = greedy, length = length, quantize = quantize, inspect = inspect, empirical_rewards_only = empirical_rewards_only, collect_hs = collect_hs)

    def simulate_with_dynamic_uncertainty(self, ipomdp : IPOMDP, pomdp : POMDPWrapper, T : dict[tuple[int, int], dict[int, float]],  FSC : FiniteMemoryPolicy, batch_dim = None, greedy = False, length = None, quantize = False, inspect = False, T_has_memory_dep = False, empirical_rewards_only = False, collect_hs = False):
        """ Simulates an interaction of this HxQBN-GRU-RNN with a POMDP model application.

        Transitions are taken from the fixed instantiation ``T``, indexed as ``T[s, a]`` and
        mapping successor states to probabilities. ``FSC`` and ``T_has_memory_dep`` are
        accepted for API compatibility and are not used.

        Returns ``(beliefs, states, hs, observations, policies, actions, rewards, dones)``.
        """
        def transition(s, a):
            successors = T[s, a]
            return list(successors.keys()), list(successors.values())

        return self._rollout(ipomdp, pomdp, transition, batch_dim = batch_dim, greedy = greedy, length = length, quantize = quantize, inspect = inspect, empirical_rewards_only = empirical_rewards_only, collect_hs = collect_hs)

    def set_train(self, value):
        """Record the training flag on the model."""
        self.training = value

    def _action(self, x, inspect = False, greedy = False, quantize = False, states = None, mask = None):
        """Select actions for a single time step.

        param: x        :   observations for one step; more than one step raises ``ValueError``.
        param: inspect  :   also return the quantized code (requires the cell to support inspection).
        param: greedy   :   take the most likely action instead of sampling.
        param: quantize :   not used here; quantization is controlled by the GRU cell itself.
        param: states   :   previous memory, reshaped to ``(batch, 1, memory_dim)``.
        param: mask     :   action mask forwarded to the ``ut`` selection helpers.

        returns: ``(a, actions, h)``, plus ``hq`` when ``inspect``. ``a`` is the action
        distribution and ``h`` the GRU cell output.
        """
        if len(x.shape) > 2 and x.shape[1] > 1:
            raise ValueError('Actions can only be determined for a single time-step.')

        batch_dim = x.shape[0]

        if states is not None:
            states = tf.reshape(states, shape=(batch_dim, 1, self.memory_dim))
            assert x.shape[:2] == states.shape[:2], (x.shape, states.shape)

        if not self.cfg['one_hot_obs']:
            x = self.qbn_gru_rnn.embed(x, training=False)

        if inspect:
            # In inspect mode the cell output is expected to be (decoded memory, quantized code);
            # the decoded memory plays the role of ``h`` (previously ``h`` was left unbound here).
            (h, hq), _ = self.qbn_gru_rnn.qbn_gru(x, states, training=False)
        else:
            h, _ = self.qbn_gru_rnn.qbn_gru(x, states, training=False)

        if self.mealy:
            # Mealy: condition on the observation and the memory from *before* this step.
            actor_in = tf.concat([x, states], axis=-1)
        else:
            actor_in = h

        a = tf.squeeze(self.actor(actor_in, training=False))

        if greedy:
            actions = ut.argmax_from_md(a.numpy(), batch_dim, mask = mask)
        else:
            actions = ut.choice_from_md(a.numpy(), batch_dim, mask = mask)

        if inspect:
            return a, actions, h, hq
        return a, actions, h

    def improve_r(self, hs, reset_weights=False):
        """Train the active HxQBN to reconstruct hidden states ``hs``; returns the per-epoch loss."""
        if reset_weights:
            self.qbn_gru_rnn.qbn_gru.hx_qbn.set_weights(self.initial_qbn_weights)
        train_result = self.qbn_gru_rnn.qbn_gru.hx_qbn.fit(x = hs, y = hs, batch_size = self.cfg['r_batch_size'], epochs = self.cfg['r_epochs'], shuffle=True, verbose = 0)
        return train_result.history['loss']

    def improve_a(self, inputs, labels, quantize = False, mask = None, reset_weights = False):
        """
        Trains the GRU actor on inputs and labels.

        param: inputs        :   the observation inputs: ``(batch, time, input_dim)`` one-hot
                                 encodings, or ``(batch, time)`` integer observations
                                 (presumably -1 for padding, since ``embed`` shifts by one
                                 into a zero-masked embedding).
        param: labels        :   the labels.
        param: quantize      :   not used by this method (kept for API compatibility).
        param: mask          :   not applied currently (kept for API compatibility).
        param: reset_weights :   restore the initial weights before training.

        returns: the per-epoch training loss history.
        """
        self.set_train(True)

        if reset_weights:
            self.reset_model()

        if self.cfg['one_hot_obs']:
            if len(inputs.shape) != 3:
                raise ValueError(f"Expected one-hot inputs of shape (batch, time, input_dim), got {inputs.shape}.")
        else:
            batch_dim, time_dim = inputs.shape
            inputs = inputs.reshape((batch_dim, time_dim, 1))
            # embed() shifts observations by +1 into an Embedding of size input_dim + 1,
            # so the largest valid raw observation is input_dim - 1.
            assert inputs.max() < self.input_dim, (inputs.max(), self.input_dim)

        train_result = self.fit(x = inputs, y = labels, batch_size = self.cfg['a_batch_size'], epochs = self.cfg['a_epochs'], verbose = 0, shuffle=True)
        return train_result.history['loss']

    def create_obs_input_for_extraction(self):
        """Embed every integer observation ``0..input_dim-1`` as a single-step batch."""
        obs = tf.reshape(tf.range(0, self.input_dim, dtype=tf.int32), (self.input_dim, 1, 1))
        return self.qbn_gru_rnn.embed(obs, training=False)

    def _extraction_observations(self):
        """Return every observation once, shaped ``(input_dim, 1, features)`` for one GRU-cell step.

        One-hot encodings get an explicit time axis so that they have the same rank as the
        repeated memory ``hx_``. A 2-D batch would broadcast against the 3-D state and could
        not be concatenated with it for the Mealy actor.
        """
        if self.cfg['one_hot_obs']:
            return tf.expand_dims(tf.one_hot(np.arange(self.input_dim), self.input_dim), axis=1)
        return self.create_obs_input_for_extraction()

    def extract_fsc_with_kmeans(self, kmeans : KMeans, k, make_greedy = True, reshape = True):
        """Extract a finite-state controller whose memory nodes are the first ``k`` KMeans clusters.

        For every cluster centre and every observation, the GRU cell is stepped once. The new
        hidden state is assigned to its nearest cluster (the next memory node), and the actor
        output gives the action distribution.
        """
        obs = self._extraction_observations()
        next_memories = np.full((k, self.input_dim), -1)
        action_distributions = np.zeros((k, self.input_dim, self.output_dim))

        for hq_idx, centre in zip(range(k), kmeans.cluster_centers_):
            # The same memory for every observation: (input_dim, 1, n_features).
            hx_ = tf.expand_dims(tf.repeat(tf.expand_dims(centre, axis = 0), self.input_dim, 0), axis=1)

            next_h_, _ = self.qbn_gru_rnn.qbn_gru(obs, hx_, training=False)

            # reshape (rather than squeeze) keeps a 2-D batch even when input_dim == 1.
            next_hq_idx = kmeans.predict(tf.reshape(next_h_, (self.input_dim, -1)).numpy())

            if self.mealy:
                actor_in = tf.concat([obs, hx_], axis=-1)
            else:
                actor_in = next_h_

            distributions = self.actor(actor_in).numpy()
            action_distributions[hq_idx] = np.reshape(distributions, (self.input_dim, self.output_dim))
            next_memories[hq_idx] = next_hq_idx

        return FiniteMemoryPolicy(action_distributions, next_memories,
            make_greedy = make_greedy, reshape = reshape,
            initial_observation = self.instance.pomdp.initial_observation)

    def extract_fsc(self, make_greedy = True, reshape = True):
        """Extract a finite-state controller whose memory nodes are all quantized bottleneck codes."""
        action_distributions, next_memories = self._construct_transaction_table()
        return FiniteMemoryPolicy(
            action_distributions, next_memories,
            make_greedy = make_greedy, reshape = reshape,
            initial_observation = self.instance.pomdp.initial_observation)

    def _construct_transaction_table(self):
        """Build the action-distribution and next-memory tables over all bottleneck codes.

        The code of the (quantized) initial hidden state is swapped into index 0, so memory
        node 0 is the initial node.

        returns: ``(action_distributions, next_memories)`` with shapes
                 ``(n_codes, input_dim, output_dim)`` and ``(n_codes, input_dim)``.
        """
        hqs = np.array(self.hqs)
        hqs_idxs = dict(self.hqs_idxs)

        # Swap the initial code with the first code.
        first_hq = tuple(hqs[0])
        first_hq_idx = hqs_idxs[first_hq]
        _, init_hq, _ = self.qbn_gru_rnn.reset(batch_dim = 1, quantize = True)
        init_hq_key = tuple(init_hq.numpy().ravel().tolist())
        init_hq_idx = hqs_idxs[init_hq_key]
        hqs_idxs[init_hq_key] = 0
        hqs_idxs[first_hq] = init_hq_idx
        hqs[[init_hq_idx, first_hq_idx]] = hqs[[first_hq_idx, init_hq_idx]]

        hxs = self.qbn_gru_rnn.qbn_gru.hx_qbn.decode(np.array(hqs, dtype = 'float64')).numpy()

        obs = self._extraction_observations()
        next_memories = np.full((len(hqs), self.input_dim), -1)
        action_distributions = np.zeros((len(hqs), self.input_dim, self.output_dim))

        for hq_idx, hx in enumerate(hxs):
            # The same decoded memory for every observation.
            hx_ = tf.expand_dims(tf.repeat(tf.expand_dims(hx, axis = 0), self.input_dim, 0), axis=1)

            next_h_, _ = self.qbn_gru_rnn.qbn_gru(obs, hx_, training=False)

            next_hq_ = self.qbn_gru_rnn.qbn_gru.hx_qbn.encode(next_h_, training=False)
            next_hq_ = np.array(next_hq_, dtype=int).reshape((self.input_dim, next_hq_.shape[-1]))

            if self.mealy:
                actor_in = tf.concat([obs, hx_], axis=-1)
            else:
                actor_in = next_h_

            distributions = self.actor(actor_in)
            action_distributions[hq_idx] = tf.reshape(distributions, (self.input_dim, self.output_dim))

            # Map every successor code (one row per observation) to its (swapped) index.
            next_memories[hq_idx] = [hqs_idxs[tuple(code)] for code in next_hq_]

        return action_distributions, next_memories

class GRUActor(tf.keras.layers.RNN):
    """
    Represents a GRU-RNN with quantized latent state.

    Wraps a single ``QBNGRU`` cell in a Keras ``RNN`` layer that returns the whole hidden-state
    sequence. When observations are not one-hot encoded, integer observations first go
    through a trainable embedding.
    """

    def __init__(self, instance, cfg):

        self.instance = instance
        self.cfg = cfg

        self.input_dim = instance.input_dim
        self.memory_dim = cfg['a_memory_dim']
        self.output_dim = instance.output_dim

        # Width of the observation embedding fed to the GRU cell (integer observations only).
        self.inter_dim = 32

        self.do_embed = not cfg['one_hot_obs']

        if self.do_embed:
            # Index 0 is reserved for masking (mask_zero=True). embed() shifts raw observations
            # by +1, so raw -1 is masked and 0..input_dim-1 map to 1..input_dim.
            self.embedder = tf.keras.layers.Embedding(self.input_dim + 1, self.inter_dim, mask_zero=True)

        self.qbn_gru = QBNGRU(instance, cfg, input_dim=self.inter_dim if self.do_embed else self.input_dim)

        super().__init__(cell = self.qbn_gru, return_sequences = True, dtype = 'float32')

        self.supports_masking = True

    def embed(self, x, training=None):
        """Embed integer observations (shifted by +1); one-hot inputs are returned unchanged."""
        if self.do_embed:
            x = self.embedder(tf.squeeze(x + 1, axis=-1), training=training)
        return x

    def call(self, x, **kwargs):
        """Embed ``x`` if needed and run the recurrent cell over the sequence."""
        x = self.embed(x, **kwargs)
        # Only the embedding attaches a Keras mask; one-hot inputs carry none.
        mask = getattr(x, '_keras_mask', None)
        return super().call(x, mask=mask, training=kwargs.get('training'))

    def reset(self, batch_dim, quantize):
        """Return the zero initial hidden state, plus its quantized code and decoding when ``quantize``."""
        initial_hidden_state = self.qbn_gru.reset(batch_dim)
        if quantize:
            encoded = self.qbn_gru.hx_qbn.encode(initial_hidden_state, training=False)
            decoded = self.qbn_gru.hx_qbn.decode(encoded, training=False)
            return initial_hidden_state, encoded, decoded
        return initial_hidden_state

class HxQBN(tf.keras.models.Model):
    """Quantized bottleneck auto-encoder for GRU hidden states.

    ``encode`` maps a hidden state of size ``a_memory_dim`` to a ``bottleneck_dim`` code and
    passes it through the ``ste_<quantization>`` activation (presumably larq's straight-through
    quantizers, registered by importing larq). ``decode`` maps a code back to a hidden state.
    """

    def __init__(self, instance, cfg):

        super().__init__()

        self.instance = instance
        self.cfg = cfg

        self.input_dim = cfg['a_memory_dim']
        self.bottleneck_dim = cfg['bottleneck_dim']
        self.inter_dim = self.bottleneck_dim * cfg['blow_up']

        # Normalise case, so that e.g. 'Tern' behaves like 'tern' (as Net already does).
        quantization = cfg["quantization"].lower()
        final_activation = flatter_tanh if quantization == 'tern' else 'tanh'

        self.encoder = tf.keras.models.Sequential([
            tf.keras.layers.Dense(self.inter_dim, activation = 'tanh', name = 'enc_1'),
            tf.keras.layers.Dense(self.inter_dim // 2, activation = 'tanh', name = 'enc_2'),
            tf.keras.layers.Dense(self.bottleneck_dim, activation = final_activation, name = 'enc_3')
        ], name = 'HxQBN-Encoder')

        self.quant = tf.keras.layers.Activation(f'ste_{quantization}', name = 'quantizer')

        self.decoder = tf.keras.models.Sequential([
            tf.keras.layers.Dense(self.inter_dim // 2, activation = 'tanh', name = 'dec_1'),
            tf.keras.layers.Dense(self.inter_dim, activation = 'tanh', name = 'dec_2'),
            tf.keras.layers.Dense(self.input_dim, activation = 'tanh', name = 'dec_3'),
        ], name = 'HxQBN-Decoder')

        # When True, call() returns (code, reconstruction) instead of only the reconstruction.
        self.inspect = False

        self.supports_masking = True

        super().build((None, self.input_dim))

        self.encoder.summary()
        self.decoder.summary()

        if self.cfg['method'].lower() == 'qbn':
            optimizer = tf.keras.optimizers.Adam(learning_rate = self.cfg['r_lr'], clipnorm=self.cfg['clipnorm'], clipvalue=self.cfg['clipvalue'], decay=self.cfg['weight_decay'])
            self.compile(optimizer = optimizer, loss = 'mse')

    def encode(self, x, training=None):
        """Map hidden states to quantized bottleneck codes."""
        return self.quant(self.encoder(x, training=training))

    def decode(self, x, training=None):
        """Map bottleneck codes back to hidden states."""
        return self.decoder(x, training=training)

    def set_inspect(self, inspect):
        """Toggle inspect mode and return the new value."""
        self.inspect = inspect
        return inspect

    def call(self, x, **kwargs):
        """Encode, quantize and decode ``x``; also return the code when in inspect mode."""
        code = self.encode(x, **kwargs)
        reconstruction = self.decode(code, **kwargs)

        if self.inspect:
            return code, reconstruction
        return reconstruction

class QBNGRU(tf.keras.layers.GRUCell):
    """GRU cell whose new hidden state can be routed through an HxQBN auto-encoder.

    When ``cfg['method']`` is ``'qrnn'``, the cell returns the raw GRU output but carries the
    encode-quantize-decode reconstruction forward as its recurrent state. Otherwise it behaves
    like a plain ``GRUCell``.
    """

    def __init__(self, instance, cfg, input_dim):

        self.instance = instance
        self.cfg = cfg

        self.num_hx_qbns = cfg['num_hx_qbns']
        self.hx_qbn_idx = 0
        self.hx_qbns = [HxQBN(instance, cfg) for _ in range(cfg['num_hx_qbns'])]
        # Always the HxQBN selected by hx_qbn_idx (see set_hx_qbn_idx).
        self.hx_qbn = self.hx_qbns[self.hx_qbn_idx]

        self.memory_dim = cfg['a_memory_dim']
        self.input_dim = input_dim

        super().__init__(self.memory_dim, name = 'qbn_gru', recurrent_initializer='orthogonal')
        # Build for the feature width actually fed to the cell (the embedding width when
        # observations are embedded, the one-hot width otherwise) instead of a hard-coded 32.
        super().build(input_shape = (None, self.input_dim))

        self.quantize = self.cfg['method'].lower() == 'qrnn'
        print("self.quantize:", self.quantize)
        self.inspect = False

        self.supports_masking = True

    def set_hx_qbn_idx(self, idx):
        """Select which HxQBN the cell uses, keeping ``hx_qbn`` in sync with ``hx_qbn_idx``."""
        # Index first, so that an invalid idx raises before any state changes.
        self.hx_qbn = self.hx_qbns[idx]
        self.hx_qbn_idx = idx

    def set_quantize(self, quantize):
        """Not supported: quantization is fixed by ``cfg['method']`` at construction time."""
        raise NotImplementedError("Quantization is determined by cfg['method'] and cannot be changed after construction.")

    def set_inspect(self, inspect):
        """Propagate inspect mode to every HxQBN and to this cell; returns the new value."""
        for hx_qbn in self.hx_qbns:
            hx_qbn.set_inspect(inspect)
        self.inspect = inspect
        return inspect

    def reset(self, batch_dim):
        """Zero the stored state and return it, shaped ``(batch_dim, 1, memory_dim)``."""
        self.states = tf.zeros((batch_dim, 1, self.memory_dim))
        return self.states

    def call(self, inputs, states = None, training = None, mask = None):
        """Advance the GRU one step; in quantized mode, carry the HxQBN reconstruction as state."""
        h, _ = super().call(inputs, states, training=training)

        if self.quantize:
            assert not self.cfg['method'].lower() in ['kmeans', 'qbn']
            if self.inspect:
                # The previous inspect branch was disabled (assert False) and referenced an
                # undefined variable; fail explicitly instead.
                raise NotImplementedError("Inspect mode is not supported by QBNGRU.call.")
            hx = self.hx_qbns[self.hx_qbn_idx](h, training=training)
            return h, hx

        assert self.cfg['method'].lower() != 'qrnn'

        return h, [h]

