import copy
import math
import multiprocessing
import pickle
import sys
import numpy as np
from copy import deepcopy
import time
from simulation import FIB, POUCT, QMDP
from policytree import PolicyTree
from scipy.special import softmax
import stormpy
from scipy.stats import entropy
import tensorflow as tf
import pycarl
from joblib import Parallel, delayed
from mem_top import mem_top
import inspect
from fsc import FiniteMemoryPolicy
from interval_models import IDTMC, IPOMDP, MDPSpec
from models import PDTMCModelWrapper, POMDPWrapper
from datetime import datetime
from net import Net
from instance import Instance
from check import Checker
from in_out import Log, clear_cache
import utils
import random
import subprocess
import os

import matplotlib.pyplot as plt

from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

def get_git_revision_hash() -> str:
    """Return the commit hash of the current git ``HEAD``.

    The hash is only used as provenance metadata, so a missing ``git``
    executable or a working directory that is not a git repository must not
    abort an experiment.

    Returns:
        The output of ``git rev-parse HEAD`` with surrounding whitespace
        stripped, or the string ``'unknown'`` when the command cannot be run
        or exits with a non-zero status.
    """
    try:
        raw = subprocess.check_output(
            ['git', 'rev-parse', 'HEAD'],
            stderr=subprocess.DEVNULL,  # keep git's "not a repository" noise out of the run log
        )
    except (subprocess.CalledProcessError, OSError):
        # CalledProcessError: git ran but failed; OSError: git is not installed.
        return 'unknown'
    return raw.decode('ascii').strip()

class Experiment:
    """ Represents a set of cfgs that serve an experiment. """
    def __init__(self, name, cfg, num_runs):
        """Create an experiment that starts out with a single base configuration.

        Args:
            name: Experiment name; also used in the output path by ``execute``.
            cfg: Base configuration dictionary. It is stored by reference and
                becomes the first entry of ``self.cfgs``.
            num_runs: Number of runs (one seed each) per configuration.
        """
        self.name, self.num_runs = name, num_runs
        # The base configuration is both the template for `add_cfg` and the
        # first configuration that gets executed.
        self.cfg = cfg
        self.cfgs = [self.cfg]

    def add_cfg(self, new_cfg):
        """Register a variant of the base configuration.

        Args:
            new_cfg: Dictionary of overrides. Every key in it replaces (or
                adds to) the corresponding entry of a deep copy of
                ``self.cfg``; the base configuration itself is not modified.
        """
        variant = deepcopy(self.cfg)
        variant.update(new_cfg)
        self.cfgs.append(variant)


    def execute(self, multi_thread, seeds):
        """Run every configuration ``num_runs`` times and store the collected logs.

        Args:
            multi_thread: If truthy, the (configuration, run) pairs are
                dispatched through ``joblib.Parallel``; otherwise they are
                executed one after another in this process.
            seeds: One seed per run; ``seeds[run_idx]`` is used for run
                ``run_idx`` of every configuration.

        Returns:
            None. If at least one run returned a log, the list of logs is
            pickled to ``./data/output/<name>/logs.pickle`` and passed to
            ``Log.output_benchmark_table``.

        Raises:
            AssertionError: If ``len(seeds)`` differs from ``self.num_runs``.
        """
        assert len(seeds) == self.num_runs, f"len(seeds) != self.num_runs <=> {len(seeds)} != {self.num_runs}. Specify seeds for the runs!"

        # Record the provenance in *every* configuration. Configurations added
        # through `add_cfg` are deep copies made earlier, so writing only to
        # `self.cfg` would leave them without seeds / git hash.
        git_sha = get_git_revision_hash()
        for cfg in (self.cfg, *self.cfgs):
            cfg['seeds'] = tuple(seeds)
            cfg['git_sha'] = git_sha

        # Determinism-related settings for this process and any worker it spawns.
        os.environ['PYTHONHASHSEED'] = str(0)
        os.environ['TF_DETERMINISTIC_OPS'] = '1'
        os.environ['TF_CUDNN_DETERMINISTIC'] = '1'
        tf.config.threading.set_inter_op_parallelism_threads(1)
        tf.config.threading.set_intra_op_parallelism_threads(1)

        jobs = [
            (cfg_idx, run_idx)
            for cfg_idx in range(len(self.cfgs))
            for run_idx in range(self.num_runs)
        ]

        if multi_thread:
            # joblib rejects n_jobs == 0, hence the lower bound of one worker.
            n_jobs = max(1, min(self.num_runs, multiprocessing.cpu_count()))
            logs = Parallel(n_jobs = n_jobs)(
                delayed(self._run)(Log(self.cfgs, self.num_runs, self.name), cfg_idx, run_idx, seeds[run_idx])
                for cfg_idx, run_idx in jobs
            )
        else:
            logs = [
                self._run(Log(self.cfgs, self.num_runs, self.name), cfg_idx, run_idx, seeds[run_idx])
                for cfg_idx, run_idx in jobs
            ]

        utils.inform(f'Finished experiment {self.name}.', indent = 0, itype = 'OKGREEN')
        logs = [log for log in logs if log is not None]
        if not logs:
            return

        output_dir = f'./data/output/{self.name}'
        # Do not rely on another component having created the directory.
        os.makedirs(output_dir, exist_ok=True)
        with open(f"{output_dir}/logs.pickle", 'wb') as handle:
            pickle.dump(logs, handle)
        Log.output_benchmark_table(logs, output_dir)
    
    def _run(self, log, cfg_idx, run_idx, seed):
        """Forward one job to ``run``, supplying this experiment's configurations.

        Args:
            log: Log object handed through to ``run``.
            cfg_idx: Index into ``self.cfgs`` of the configuration to run.
            run_idx: Index of the run within that configuration.
            seed: Seed for this run.

        Returns:
            Whatever ``run`` returns.
        """
        all_cfgs = self.cfgs
        return self.run(log, all_cfgs, cfg_idx, run_idx, seed)

    # @profile
    @staticmethod
    def run(log, cfgs, cfg_idx, run_idx, seed):
        """Execute one run of configuration ``cfgs[cfg_idx]`` and return its log.

        After seeding NumPy, TensorFlow and ``random``, the method builds the
        POMDP and its interval version, computes MDP action values, and then
        repeats the following for ``cfg['rounds']`` rounds:

        1. extract a finite-state controller (FSC),
        2. induce the interval DTMC and compute the robust value of the FSC,
        3. if ``cfg['dynamic_uncertainty']`` is set, search for a worst-case
           transition model,
        4. simulate the RNN, the supervision policy and the FSC,
        5. train the quantizer (K-means or QBN) and the RNN,
        6. flush the round's statistics to ``log``.

        Finally the FSC with the lowest robust value and the network weights
        that produced it are pickled below ``log.base_output_dir``.

        Args:
            log: Log object; ``log.base_output_dir`` is the root of all files
                written here and ``log.flush`` receives the per-round values.
            cfgs: List of configuration dictionaries.
            cfg_idx: Index of the configuration to run.
            run_idx: Index of the run; selects ``cfg['p_init'][run_idx]``.
            seed: Seed for NumPy, TensorFlow and ``random``.

        Returns:
            The ``log`` object that was passed in.

        Raises:
            KeyError: If ``cfg['method']`` is not one of ``kmeans``, ``qbn``,
                ``qrnn`` or ``policytree`` (case-insensitive).
            ValueError: If ``cfg['policy']`` is not supported by the selected
                way of producing training labels.
            AssertionError: If one of the sanity checks fails.

        Note:
            Entries of ``cfg['p_init'][run_idx]`` that are ``None`` are
            replaced in place by a value sampled from ``cfg['p_bounds']``.
        """
        cfg = cfgs[cfg_idx]
        pycarl.clear_pools()
        tf.keras.backend.clear_session()
        np.random.seed(seed)
        tf.random.set_seed(seed)
        random.seed(seed)

        utils.inform(f'Starting run {run_idx} with seed {seed}.', indent = 0, itype = 'OKBLUE')
        instance = Instance(cfg)
        pomdp : POMDPWrapper = instance.build_pomdp()

        # Parameters without an initial value are drawn uniformly from their bounds.
        ps = cfg['p_init'][run_idx]
        for key, value in ps.items():
            if value is None:
                assert cfg['p_bounds'][key][0] < cfg['p_bounds'][key][1], (key, value, cfg['p_bounds'][key])
                ps[key] = random.uniform(cfg['p_bounds'][key][0], cfg['p_bounds'][key][1])

        length = instance.simulation_length()
        # Not used below; the call is kept in case construction has side effects.
        checker = Checker(instance, cfg)
        net : Net = Net(instance, cfg)

        dynamic_uncertainty = cfg['dynamic_uncertainty']
        deterministic_target_policy = cfg['train_deterministic']

        use_nominal_mdp_policy = not dynamic_uncertainty
        use_nominal_pomdp = use_nominal_mdp_policy or not dynamic_uncertainty

        assert instance.objective == 'min'

        spec = MDPSpec(cfg['specification'])

        reset_qbn_weights = cfg['fresh_qbn_every_iter']
        tau = cfg['temperature']
        policy = cfg['policy'].lower()

        max_k = (3 if cfg['quantization'].lower() == 'tern' else 2)**cfg['bottleneck_dim']

        utils.inform(f"Max k = {max_k}. Target policy: {cfg['policy']} ({'deterministic' if deterministic_target_policy else 'stochastic'}). BASELINE={not dynamic_uncertainty}. Method={cfg['method']}. Policy loss: {cfg['a_loss']}. Spec: {spec}.", indent = 0, itype = 'OKBLUE')

        mdp_goal_states = [i for i, x in enumerate(pomdp.state_labels) if instance.label_to_reach in x]

        assert len(mdp_goal_states) > 0

        print("GOALS:", mdp_goal_states)

        assert np.array(pomdp.initial_state).item() == 0

        # Exactly one of the four extraction methods is active.
        method = cfg['method'].lower()
        if method not in ('kmeans', 'qbn', 'qrnn', 'policytree'):
            raise KeyError(method)
        use_kmeans = method == 'kmeans'
        qrnn = method == 'qrnn'
        qbn = method == 'qbn'
        policytree = method == 'policytree'

        ipomdp = IPOMDP(instance, pomdp, cfg['p_bounds'], mdp_goal_states, force_intervals=False)

        # Degenerate intervals [p, p] turn the interval model into the nominal one.
        nominal_parameters = {name : (ps[name], ps[name]) for name in cfg['p_bounds'].keys()}
        nominal_pomdp = IPOMDP(instance, pomdp, nominal_parameters, mdp_goal_states, force_intervals=True, build_prism_file=use_nominal_pomdp)

        nominal_T = {(s,a) : {next_s : prob[0] for next_s, prob in next_s_dict.items()} for (s,a), next_s_dict in nominal_pomdp.T.items()}

        np.set_printoptions(threshold=sys.maxsize)

        run_dir = f"{log.base_output_dir}/{cfg_idx}/{run_idx}"

        if not use_nominal_pomdp:
            assert dynamic_uncertainty
            Q = ipomdp.mdp_action_values(spec, use_existing='pomdp_attempt' in cfg['name'])
            print(pomdp.initial_state, np.nanmin(Q, axis=-1)[pomdp.initial_state], 0, np.nanmin(Q, axis=-1)[0])
            utils.inform(f'Synthesized iMDP-policy (min: {np.nanmin(Q):.2f}, max: {np.nanmax(Q):.2f}) w/ value = {np.nanmin(Q, axis=-1)[pomdp.initial_state]}')
            # Context manager so that the dump is flushed and closed right away.
            with open(f"{run_dir}/MDP-Q.txt", 'w') as q_file:
                print(Q, file=q_file)
            print(Q[pomdp.initial_state])
            print(Q.shape, ipomdp.imdp_V[pomdp.initial_state[0]])
        else:
            assert not dynamic_uncertainty or use_nominal_mdp_policy
            Q = nominal_pomdp.mdp_action_values(spec)
            utils.inform(f'Synthesized nominal ({nominal_parameters}) MDP-policy (min: {np.nanmin(Q):.2f}, max: {np.nanmax(Q):.2f}) w/ value = {np.nanmin(Q, axis=-1)[pomdp.initial_state]} = {nominal_pomdp.imdp_V[pomdp.initial_state]}')
            with open(f"{run_dir}/MDP-Q.txt", 'w') as q_file:
                print(Q, file=q_file)
            print(Q[pomdp.initial_state])
            print(Q.shape, nominal_pomdp.imdp_V[pomdp.initial_state[0]])

        # NaN entries of Q are replaced by a value that is worse than every real
        # Q-value, which the assertions below verify.
        if instance.objective == 'min':
            nan_fixer = 1e6 if deterministic_target_policy else 1e3
            assert np.nanmax(Q) < nan_fixer
        else:
            nan_fixer = -1e6 if deterministic_target_policy else -1e3
            assert np.nanmin(Q) > nan_fixer

        mdp_q_values = np.nan_to_num(Q, nan=nan_fixer)

        best_robust_fsc = None
        # Initialised here so that the dump after the loop also works when no
        # round improves on the initial (infinite) robust value.
        best_robust_weights = None
        best_worst_static_fsc = None
        best_robust_value = np.inf
        best_worst_static_value = np.inf

        utils.inform(f'{run_idx}\t(empir)\t\t(s,a)-rewards: {ipomdp.state_action_rewards}')

        assert ipomdp.state_action_rewards == nominal_pomdp.state_action_rewards

        memory_dependence_on_worst_case_T = cfg['memory_dependent_worst_case_T']

        if use_kmeans:
            k = max_k
            kmeans = KMeans(k, n_init='auto', init='k-means++')

        os.makedirs(f"{run_dir}/plots/", exist_ok=True)

        # Until a worst-case model has been found, the nominal transitions are used.
        T = nominal_T

        reshape = True
        deterministic_fsc = False

        for round_idx in range(cfg['rounds']):

            tik = time.time()

            ########### EXTRACT FSC ###########

            if qrnn or qbn:
                fsc = net.extract_fsc(reshape=reshape, make_greedy = deterministic_fsc)
            elif use_kmeans and round_idx in [0,1,2]:
                fsc = PolicyTree.qmdp_to_fsc(ipomdp, pomdp, T=nominal_T, deterministic = deterministic_fsc, mdp_q_values=Q, tau=tau, nan_fixer=nan_fixer, length = length, batch_dim = instance.cfg['batch_dim'], T_has_memory_dep = memory_dependence_on_worst_case_T)
            elif use_kmeans and round_idx > 0:
                fsc = net.extract_fsc_with_kmeans(kmeans, k, make_greedy=deterministic_fsc, reshape=reshape)
            elif policytree:
                fsc = PolicyTree.qmdp_to_fsc(ipomdp, pomdp, depth=7, T=T, deterministic = deterministic_fsc, mdp_q_values=Q, tau=tau, nan_fixer=nan_fixer, length = length, batch_dim = instance.cfg['batch_dim'], T_has_memory_dep = memory_dependence_on_worst_case_T)
            else:
                raise ValueError("How do we get the FSC?")

            fsc.mask(pomdp.policy_mask)

            # NOTE(review): the initial state is asserted to be 0 above, so this
            # product is an all-zero index vector and `robust_values` below holds
            # nM_generated copies of V[0] -- confirm against IDTMC's state indexing.
            dtmc_initial_state = np.array(pomdp.initial_state).item() * np.arange(fsc.nM_generated)

            tok = time.time()

            fsc_extract_time = tok - tik

            utils.inform(f"{run_idx}-{round_idx}\t({fsc.nM_generated}-FSC)\t\tExtracted {fsc.nM_generated}-FSC from {'policy tree' if policytree else 'RNN'} in {fsc_extract_time:.4f}s", indent = 0)

            tik = time.time()

            idtmc : IDTMC = ipomdp.create_iDTMC(fsc, add_noise=0)

            tok = time.time()

            ########### Construct robust/interval DTMC and compute robust value ###########

            utils.inform(f"{run_idx}-{round_idx}\t(iDTMC)\t\tInduced iMC with {pomdp.nS} x {fsc.nM_generated} = {idtmc.nS} states in {tok-tik:.4f}s", indent = 0)

            target_mc_states = idtmc.labels_to_states["goal"]
            V = idtmc.check_reward(spec, target_mc_states)
            robust_values = V[dtmc_initial_state]
            initial_node = np.argmin(robust_values)
            robust_value = robust_values[initial_node]

            ########### Construct and solve LP ###########

            utils.inform(f'{run_idx}-{round_idx}\t(iDTMC)\t\tRVI \t>> %.4f <<' % robust_value + f" from {np.unique(robust_values).size} unique initial value(s) | previous best: {best_robust_value:.4f} (PRISM)", indent = 0)
            if dynamic_uncertainty:
                if np.isfinite(robust_value).item():
                    heuristic_T, LP_value = ipomdp.find_sup_pomdp_T(V, instance, fsc, memory_dependence=memory_dependence_on_worst_case_T, pessimistic={MDPSpec.Rminmax : True, MDPSpec.Rminmin : False}[spec])
                    T = heuristic_T[0]
                    utils.inform(f'{run_idx}-{round_idx}\t(LP)\t\tFound worst-case rPOMDP transition model. LP value approx: {(LP_value/pomdp.nS)/fsc.nM_generated}', indent = 0)
                else:
                    utils.inform(f'{run_idx}-{round_idx}\t(LP)\t\tNon-finite value function, re-using previous T.', indent = 0)

            if robust_value < best_robust_value:
                best_robust_value = robust_value
                best_robust_fsc = fsc
                best_robust_weights = net.get_weights()

            ########### SIMULATE RNN ###########

            rnn_beliefs, rnn_states, hs, rnn_observations, rnn_policies, rnn_actions, rnn_rewards, rnn_dones = net.simulate_with_random_uncertainty(ipomdp, pomdp, greedy = deterministic_target_policy, length = length, batch_dim = instance.cfg['batch_dim'], empirical_rewards_only=cfg['use_supervision_simulation_data'], collect_hs = not qrnn, quantize=qrnn)

            assert rnn_rewards[rnn_dones].sum() == 0, "Rewards (costs) were collected by RNN while in the goal state!"

            rnn_cum_rewards = rnn_rewards.sum(axis=1)
            rnn_mean_cum_rewards = rnn_cum_rewards.mean()

            rnn_random_uncertainty_return = rnn_mean_cum_rewards

            utils.inform(f'{run_idx}-{round_idx}\t(RNN)\t\tempir\t%.4f' % rnn_mean_cum_rewards + f" ± {rnn_cum_rewards.std():.2f} (RANDOM)", indent = 0)

            # The second roll-out (against the current T) is the one whose
            # beliefs, states and hidden states are used for the rest of the round.
            rnn_beliefs, rnn_states, hs, rnn_observations, rnn_policies, rnn_actions, rnn_rewards, rnn_dones = net.simulate_with_dynamic_uncertainty(ipomdp, pomdp, T, fsc, greedy = deterministic_target_policy, length = length, batch_dim = instance.cfg['batch_dim'], empirical_rewards_only=cfg['use_supervision_simulation_data'], collect_hs = True, quantize=qrnn)

            assert rnn_rewards[rnn_dones].sum() == 0, "Rewards (costs) were collected by RNN while in the goal state!"

            rnn_cum_rewards = rnn_rewards.sum(axis=1)
            rnn_mean_cum_rewards = rnn_cum_rewards.mean()

            utils.inform(f'{run_idx}-{round_idx}\t(RNN)\t\tempir\t%.4f' % rnn_mean_cum_rewards + f" ± {rnn_cum_rewards.std():.2f} {'(nominal)' if round_idx == 0 or not dynamic_uncertainty else '(worst-case)'}", indent = 0)

            ########### COMPUTE AND SIMULATE SUPERVISION POLICY ###########

            # Defaults for policies that are not simulated below (e.g. 'mdp'),
            # so that the logging at the end of the round has defined values.
            simul_policy_value = None
            supervision_policy_return = {}

            if policy in ['pouct', 'pomcp', 'qmdp', 'qumdp']:
                if dynamic_uncertainty:
                    # Shallow copy of the interval model with the current
                    # worst-case transitions plugged in.
                    adversarial_pomdp = copy.copy(ipomdp)
                    adversarial_pomdp.has_intervals = False
                    adversarial_pomdp.T = T
                    adversarial_pomdp.build_prism_imdp()
                    adversarial_Q = adversarial_pomdp.mdp_action_values(spec)
                else:
                    adversarial_Q = Q

            if policy == 'fib':
                if dynamic_uncertainty or round_idx == 0:
                    fib = FIB(ipomdp, pomdp, T, label_to_reach=instance.label_to_reach, mdp_q_values=Q, tau=tau, nan_fixer=nan_fixer)
                simul_policy_value = fib.get_value(pomdp.initial_state)
                fib_beliefs, fib_states, fib_observations, fib_policies, fib_actions, fib_rewards, fib_dones = fib.simulate_on_current_POMDP(length = length, batch_dim = instance.cfg['batch_dim'], deterministic = deterministic_target_policy)
                assert fib_rewards[fib_dones].sum() == 0, "Rewards (costs) were collected by FIB while in the goal state!"
                fib_cum_rewards = fib_rewards.sum(axis=1)
                fib_mean_cum_rewards = fib_cum_rewards.mean()
                supervision_policy_return = {'supervision_return' : fib_mean_cum_rewards}
                utils.inform(f'{run_idx}-{round_idx}\t(FIB)\t\tempir\t%.4f' % fib_mean_cum_rewards + f" ± {fib_cum_rewards.std():.2f} {'(nominal' if round_idx == 0 or not dynamic_uncertainty else '(worst-case'}, MDP value={simul_policy_value})", indent = 0)

            if policy in ('qmdp', 'qumdp'):
                if dynamic_uncertainty or round_idx == 0:
                    qmdp = QMDP(ipomdp, pomdp, T, label_to_reach=instance.label_to_reach, mdp_q_values=adversarial_Q, tau=tau, nan_fixer=nan_fixer)
                qmdp_beliefs, qmdp_states, qmdp_observations, qmdp_policies, qmdp_actions, qmdp_rewards, qmdp_dones = qmdp.simulate_on_current_POMDP(length = length, batch_dim = instance.cfg['batch_dim'], deterministic = deterministic_target_policy)

                assert qmdp_rewards[qmdp_dones].sum() == 0, "Rewards (costs) were collected by QMDP while in the goal state!"

                qmdp_cum_rewards = qmdp_rewards.sum(axis=1)
                qmdp_mean_cum_rewards = qmdp_cum_rewards.mean()

                supervision_policy_return = {'supervision_return' : qmdp_mean_cum_rewards}

                simul_policy_value = np.nanmin(adversarial_Q,axis=-1)[pomdp.initial_state]

                utils.inform(f'{run_idx}-{round_idx}\t(QMDP)\t\tempir\t%.4f' % qmdp_mean_cum_rewards + f" ± {qmdp_cum_rewards.std():.2f} {'(nominal' if round_idx == 0 or not dynamic_uncertainty else '(worst-case'}, MDP value={simul_policy_value})", indent = 0)

            fsc_rewards = fsc.simulate_fsc(ipomdp, pomdp, T, label_to_reach=instance.label_to_reach, greedy = deterministic_target_policy, length = length, batch_dim = instance.cfg['batch_dim'])

            fsc_cum_rewards = fsc_rewards.sum(axis=1)
            fsc_mean_cum_rewards = fsc_cum_rewards.mean()

            utils.inform(f'{run_idx}-{round_idx}\t(FSC)\t\tempir\t%.4f' % fsc_mean_cum_rewards + f" ± {fsc_cum_rewards.std():.2f} {'(nominal)' if round_idx == 0 or not dynamic_uncertainty else '(worst-case)'}", indent = 0)

            if policy in ['pouct', 'pomcp']:
                pouct = POUCT(ipomdp, pomdp, T, label_to_reach=instance.label_to_reach,mdp_q_values=adversarial_Q, tau=tau, nan_fixer=nan_fixer)
                pouct_beliefs, pouct_states, pouct_observations, pouct_policies, pouct_actions, pouct_rewards, pouct_dones = pouct.simulate_on_current_POMDP(length = length, batch_dim = instance.cfg['batch_dim'], deterministic = deterministic_target_policy)

                assert pouct_rewards[pouct_dones].sum() == 0, "Rewards (costs) were collected by POUCT while in the goal state!"

                pouct_cum_rewards = pouct_rewards.sum(axis=1)
                pouct_mean_cum_rewards = pouct_cum_rewards.mean()
                supervision_policy_return = {'supervision_return' : pouct_mean_cum_rewards}

                simul_policy_value = np.nanmin(adversarial_Q,axis=-1)[pomdp.initial_state]

                utils.inform(f'{run_idx}-{round_idx}\t(POUCT)\t\tempir\t%.4f' % pouct_mean_cum_rewards + f" ± {pouct_cum_rewards.std():.2f} {'(nominal' if round_idx == 0 or not dynamic_uncertainty else '(worst-case'}, MDP value={simul_policy_value})", indent = 0)

            ########### ASSEMBLE TRAINING DATA ###########

            if cfg['use_supervision_simulation_data']:
                # Train on the trajectories of the supervision policy itself.
                if policy == 'qumdp' or policy == 'qmdp':
                    train_observations = qmdp_observations
                    train_dones = qmdp_dones
                    a_labels = qmdp_policies
                elif policy == 'fib':
                    train_observations = fib_observations
                    train_dones = fib_dones
                    a_labels = fib_policies
                elif policy in ['pouct', 'pomcp']:
                    train_observations = pouct_observations
                    train_dones = pouct_dones
                    a_labels = pouct_policies
                else:
                    raise ValueError()
            else:
                # Train on the RNN's own trajectories, labelled with the
                # supervision policy evaluated on the RNN's states / beliefs.
                train_observations = rnn_observations
                train_dones = rnn_dones.astype(bool)
                if policy == 'mdp' or policy == 'umdp':
                    q_values = mdp_q_values[rnn_states]
                elif policy == 'qumdp' or policy == 'qmdp':
                    assert rnn_beliefs.shape[-1] == mdp_q_values.shape[0], "Shape mismatch."
                    assert not np.isnan(rnn_beliefs).any()
                    q_values = np.matmul(rnn_beliefs, mdp_q_values)
                else:
                    raise ValueError("Invalid policy?")
                if deterministic_target_policy:
                    nanarg = np.nanargmin if instance.objective == 'min' else np.nanargmax
                    a_labels = utils.one_hot_encode(nanarg(q_values, axis = -1), pomdp.nA, dtype ='float32')
                else:
                    a_labels = np.nan_to_num(utils.nan_soft_max_norm(q_values, minimize=(instance.objective == 'min'), axis=-1, tau=tau), nan=0.0)

            if cfg['one_hot_obs']:
                a_inputs = utils.one_hot_encode(train_observations, pomdp.nO, dtype = 'float32')
            else:
                a_inputs = train_observations

            if not np.isfinite(a_inputs).all():
                print(a_inputs)
                raise AssertionError("The inputs (observations) contain non-finite values, we are prone to fail inside the fit function!")

            if not np.isfinite(a_labels).all():
                print(a_labels)
                raise AssertionError("The labels contain non-finite values, we are prone to fail inside the fit function!")

            # Longest number of not-yet-done steps over all sequences.
            max_sequence_length = np.count_nonzero(~train_dones, axis=-1).max()

            qbn_log = ""
            if use_kmeans or qbn or qrnn:
                # Only hidden states of steps that are not done are relevant.
                active_steps = ~(rnn_dones.astype(bool))
                print(rnn_dones.shape, hs.shape, hs[active_steps].shape)
                relevant_hs = np.unique(hs[active_steps], axis=0)
                qbn_log += f"({relevant_hs.shape[0]}/{hs.size})"

            ########### TRAIN QBN (+ RNN) ###########
            utils.inform(f'{run_idx}-{round_idx}\t(NNs)\t\tTraining QBN and RNN on {np.count_nonzero(~train_dones)}/{np.size(train_dones)} steps with a maximum length of {max_sequence_length}.' + qbn_log, indent = 0)

            loss_info = True
            if use_kmeans:
                if round_idx > 1:
                    # Fit k = 2, 3, ... and stop at 8 or as soon as the inertia
                    # increases; candidates[i] has i + 2 clusters.
                    candidates = [KMeans(2, n_init='auto').fit(relevant_hs)]
                    for n_clusters in range(3, 9):
                        candidate = KMeans(n_clusters, n_init='auto').fit(relevant_hs)
                        if candidate.inertia_ > candidates[-1].inertia_:
                            break
                        candidates.append(candidate)
                    scores = [silhouette_score(relevant_hs, c.labels_) for c in candidates]
                    # A higher silhouette score means better separated clusters,
                    # so the best candidate is the arg-max (the previous arg-min
                    # selected the worst clustering).
                    k = np.argmax(scores) + 2
                    kmeans = candidates[k-2]
                    clusters = kmeans.predict(relevant_hs)

                    fig = plt.figure()
                    ax = fig.add_subplot(projection='3d')

                    # Plots the first three hidden-state dimensions only.
                    ax.scatter(relevant_hs[:, 0], relevant_hs[:, 1], relevant_hs[:, 2], c=clusters, alpha=0.5)

                    plt.savefig(f"{run_dir}/plots/kmeans-{round_idx}.png")

                    plt.close(fig)
                r_loss = kmeans.inertia_ if round_idx > 1 else np.inf
                disc_info = f'{run_idx}-{round_idx}\t(Kmeans)\tinertia \t%.4f' % r_loss
            elif qbn or qrnn: # use QBN
                if round_idx > 0:
                    try:
                        r_loss = net.improve_r(relevant_hs, reset_weights=reset_qbn_weights)
                    except Exception as err:
                        # Best effort: a failed QBN update must not end the run,
                        # but it is reported instead of being silently dropped.
                        # KeyboardInterrupt / SystemExit still propagate.
                        print(f'{run_idx}-{round_idx}\t(QBN)\t\ttraining failed: {err!r}', file=sys.stderr)
                        loss_info = False
                        r_loss = [0.0]
                else:
                    r_loss = [np.inf]
                disc_info = f'{run_idx}-{round_idx}\t(QBN)\t\trloss \t%.4f' % r_loss[0] + ('\t>>>> %3.4f' % r_loss[-1]) + ", reset: " + str(reset_qbn_weights)
                projected_hidden_states = net.qbn_gru_rnn.qbn_gru.hx_qbn.encoder(relevant_hs, training=False)
                clusters = net.qbn_gru_rnn.qbn_gru.hx_qbn.quant(projected_hidden_states).numpy()
                clusters = [net.hqs_idxs[tuple(c.tolist())] for c in clusters]
                # The latent space is only plotted when it is two-dimensional.
                if projected_hidden_states.shape[1] == 2:
                    fig = plt.figure()
                    plt.scatter(projected_hidden_states[:, 0], projected_hidden_states[:, 1], c=clusters, alpha=0.5)
                    plt.savefig(f"{run_dir}/plots/qbn-{round_idx}.png")
                    plt.close(fig)
                if qrnn:
                    r_loss = None
                    loss_info = False
                else:
                    r_loss = np.array(r_loss)[-1]
            elif policytree:
                r_loss = None
                loss_info = False
            else:
                raise ValueError()

            if loss_info:
                utils.inform(disc_info, indent = 0)

            ########### TRAIN RNN ###########
            a_loss = net.improve_a(a_inputs[:, :max_sequence_length, ...], a_labels[:, :max_sequence_length, ...], quantize=qrnn, reset_weights=cfg['fresh_qrnn_every_iter'])

            utils.inform(f'{run_idx}-{round_idx}\t(RNN)\t\taloss \t%.4f' % a_loss[0] + '\t>>>> %3.4f' % a_loss[-1] + ", reset: " + str(cfg['fresh_qrnn_every_iter']) , indent = 0)

            to_log = {
                'k' : fsc.nM_generated,
                'ps' : ps,
                'robust_value' : robust_value,
                'rnn_random_uncertainty_return' : rnn_random_uncertainty_return,
                'rnn_return' : rnn_mean_cum_rewards,
                'fsc_return' : fsc_mean_cum_rewards,
                'a_loss' : np.array(a_loss)[-1],
                'r_loss' : r_loss,
                'supervision_policy_value' : simul_policy_value
            }

            to_log.update(supervision_policy_return)

            log.flush(cfg_idx, run_idx, **to_log)

        print("Best (min) robust value:", best_robust_value, "best (min) worst (max) static value:", best_worst_static_value)
        with open(f"{run_dir}/best_robust_fsc.pickle", 'wb') as handle:
            pickle.dump(best_robust_fsc, handle)
        with open(f"{run_dir}/best_robust_fsc.tf", "wb") as handle:
            pickle.dump(best_robust_weights, handle)
        # The worst-static FSC is never updated in this method, so this file
        # always contains a pickled None; it is still written to keep the
        # output layout unchanged.
        with open(f"{run_dir}/best_worst_static_fsc.pickle", 'wb') as handle:
            pickle.dump(best_worst_static_fsc, handle)
        return log
