## adapted from https://github.com/rail-berkeley/softlearning/blob/master/softlearning/algorithms/sac.py

import os
import math
import pickle
from collections import OrderedDict
from numbers import Number
from itertools import count
import gtimer as gt
import pdb

import numpy as np
import scipy
import tensorflow as tf
from tensorflow.python.training import training_util

from softlearning.algorithms.rl_algorithm import RLAlgorithm
from softlearning.replay_pools.simple_replay_pool import SimpleReplayPool
from softlearning.policies.utils import get_policy_from_variant, get_policy
from softlearning.value_functions.utils import get_Q_function_from_variant
from softlearning.utils.PCGrad import PCGrad

from cuds.models.constructor import construct_model, format_samples_for_training
from cuds.models.fake_env import FakeEnv
from cuds.utils.writer import Writer
from cuds.utils.visualization import visualize_policy
from cuds.utils.logging import Progress, Silent
import cuds.utils.filesystem as filesystem
import cuds.off_policy.loader as loader


def td_target(reward, discount, next_value):
    """Return the one-step temporal-difference target.

    Computes ``reward + discount * next_value`` using the operands' own
    ``*`` and ``+`` operators, so plain floats and NumPy arrays (or anything
    else supporting those operators) are accepted.
    """
    discounted_next_value = discount * next_value
    return reward + discounted_next_value


class MultitaskCQL(RLAlgorithm):
    def __init__(
            self,
            training_environment,
            evaluation_environment,
            policy,
            Qs,
            pool,
            static_fns,
            plotter=None,
            tf_summaries=False,

            lr=3e-4,
            q_lr=3e-4,
            reward_scale=1.0,
            target_entropy='auto',
            discount=0.99,
            tau=5e-3,
            target_update_interval=1,
            action_prior='uniform',
            reparameterize=False,
            store_extra_policy_info=False,

            deterministic=False,
            rollout_random=True,
            rollout_half_random=False,
            rollout_random_action_scale=1.0,
            num_repeat=1,
            model_train_freq=250,
            num_networks=7,
            num_elites=5,
            model_retain_epochs=20,
            rollout_batch_size=100e3,
            real_ratio=0.1,
            rollout_length=1,
            hidden_dim=200,
            num_layers=4,
            model_lr=1e-3,
            max_model_t=None,
            model_type='mlp',
            multi_step_prediction=False,
            num_plan_steps=1,
            reward_classification=False,
            model_rew_zero=False,
            sn=False,
            no_sn_last=False,
            gradient_penalty=0.0,
            gradient_penalty_scale=10.0,
            separate_mean_var=False,
            std_thresh=0.0,
            std_percentile=0.0,
            per_batch_std_percentile=0.0,
            oracle=False,
            identity_terminal=0,
            normalize_rew=False,
            rew_normalize_shift=0.5,
            rew_normalize_scale=4.0,

            pool_load_path='',
            pool_load_max_size=0,
            pool_eval_load_path='',
            pool_eval_load_max_size=0,
            model_name=None,
            model_load_dir=None,
            multitask_type=None,
            penalty_coeff=0.,
            penalty_learned_var=False,

            ## For min_Q runs
            with_min_q=False,
            new_min_q=False,
            min_q_for_real=False,
            min_q_for_real_and_fake=False,
            min_q_for_fake_only=False,
            min_q_for_fake_states=False,
            min_q_for_real_fake_states=False,
            min_q_real_states_weight=1.0,
            backup_for_fake_only=False,
            policy_for_fake_only=False,
            backup_for_one_step=False,
            backup_with_uniform=False,
            min_q_version=3,
            temp=1.0,
            hinge_bellman=False,
            use_projected_grad=False,
            normalize_magnitudes=False,
            regress_constant=False,
            min_q_weight=1.0,
            data_subtract=False,
            policy_eval_start=0,
            learn_bc_policy=False,
            learn_bc_epoch=200,
            policy_variant=None,

            ## sort of backup
            max_q_backup=False,
            deterministic_backup=False,
            num_random=4,

            ## Lagrange multiplier
            with_lagrange=False,
            lagrange_thresh=10.0,

            ## Cross-validate
            restore=False,
            cross_validate=False,
            cross_validate_model_eval=False,
            cross_validate_eval_n_episodes=10,
            cross_validate_n_steps=1,
            use_fqe=False,
            fqe_num_qs=7,
            fqe_minq=False,
            Q_variant=None,

            # Multi-Task
            num_tasks=1,
            goal_conditioned=False,
            goal_dim=6,
            PCGrad=False,
            PCGrad_bellman_only=False,
            share_data=False,
            relabel_reward=False,
            share_data_cql=False,
            policy_prob_weighting=False,
            policy_prob_weighting_no_cql=False,
            use_old_policy=False,
            old_policy_update_interval=100,
            use_relabel_weights=False,
            use_relabel_weights_diff=False,
            relabel_balance_batch=False,
            relabel_weight_temp=1.0,
            relabel_weight_temp_adaptive=False,
            relabel_weight_temp_adaptive_per_task=False,
            relabel_weight_temp_tau=5e-3,
            relabel_weight_temp_min=10.0,
            relabel_weight_temp_max=10000.0,
            relabel_weight_orig_task=False,
            relabel_prob=1.0,
            **kwargs,
    ):
        """Set up multi-task CQL with a learned dynamics model and an offline replay pool.

        Args:
            training_environment: Environment used for training. Its
                `active_observation_shape` and `action_space.shape` must both be
                one-dimensional (asserted below).
            evaluation_environment: Environment used for evaluation; it is also
                passed to `FakeEnv` as `oracle_env`.
            policy: A policy function approximator.
            Qs: Q-function approximators. A target copy of each is created with
                `tf.keras.models.clone_model`.
            pool (`PoolBase`): Replay pool. It is filled from `pool_load_path`
                during construction, and its size at that point is recorded as
                `self._init_pool_size`.
            static_fns: Passed through to `FakeEnv`.
            plotter (`QFPolicyPlotter`): Plotter instance to be used for
                visualizing Q-function during training.
            lr (`float`): Learning rate used for the policy.
            q_lr (`float`): Learning rate used for the Q-functions.
            discount (`float`): Discount factor for Q-function updates.
            tau (`float`): Soft value function target update weight.
            target_update_interval ('int'): Frequency at which target network
                updates occur in iterations.
            reparameterize ('bool'): If True, we use a gradient estimator for
                the policy derived using the reparameterization trick. We use
                a likelihood ratio based estimator otherwise.
            policy_variant (dict, optional): Variant used to build the old-policy
                copy (`use_old_policy`) and the behavior-cloning policy
                (`learn_bc_policy`). It is not modified; the BC policy is built
                from a deep copy whose hidden layers are set to (256, 256, 256).
                Defaults to an empty dict.
            Q_variant (dict, optional): Variant used to build the FQE
                Q-functions when `use_fqe` is set. Defaults to an empty dict.
            **kwargs: Forwarded to `RLAlgorithm.__init__`.
        """

        super(MultitaskCQL, self).__init__(**kwargs)

        # The signature uses None instead of shared mutable `{}` defaults;
        # normalize so callers that omit these arguments see the same behavior.
        if policy_variant is None:
            policy_variant = {}
        if Q_variant is None:
            Q_variant = {}

        obs_dim = np.prod(training_environment.active_observation_shape)
        act_dim = np.prod(training_environment.action_space.shape)

        ## multi-task / relabeling options
        self.num_tasks = num_tasks
        self.goal_conditioned = goal_conditioned
        self.goal_dim = goal_dim
        self.PCGrad = PCGrad
        self.PCGrad_bellman_only = PCGrad_bellman_only
        self.share_data = share_data
        self.relabel_reward = relabel_reward
        self.share_data_cql = share_data_cql
        self.policy_prob_weighting = policy_prob_weighting
        self.policy_prob_weighting_no_cql = policy_prob_weighting_no_cql
        self.use_old_policy = use_old_policy
        self.old_policy_update_interval = old_policy_update_interval
        self.use_relabel_weights = use_relabel_weights
        self.relabel_balance_batch = relabel_balance_batch
        self.use_relabel_weights_diff = use_relabel_weights_diff
        self.relabel_weight_temp = relabel_weight_temp
        self.relabel_weight_temp_adaptive = relabel_weight_temp_adaptive
        self.relabel_weight_temp_adaptive_per_task = relabel_weight_temp_adaptive_per_task
        self.relabel_weight_temp_tau = relabel_weight_temp_tau
        self.relabel_weight_temp_min = relabel_weight_temp_min
        self.relabel_weight_temp_max = relabel_weight_temp_max
        self.relabel_weight_orig_task = relabel_weight_orig_task
        self.multitask_type = multitask_type
        self.relabel_prob = relabel_prob
        # Relabel weights are only meaningful when a multitask type is given.
        if multitask_type is None:
            self.use_relabel_weights = False
            self.use_relabel_weights_diff = False

        ## cross-validation / FQE options
        self.restore = restore
        self.cross_validate = cross_validate
        self.cross_validate_model_eval = cross_validate_model_eval
        self.cross_validate_eval_n_episodes = cross_validate_eval_n_episodes
        self.cross_validate_n_steps = cross_validate_n_steps
        self.use_fqe = use_fqe
        self.fqe_minq = fqe_minq
        self.fqe_num_qs = fqe_num_qs

        ## dynamics model
        self._model_type = model_type
        self._identity_terminal = identity_terminal
        self.multi_step_prediction = multi_step_prediction
        self.num_plan_steps = num_plan_steps
        self.num_networks = num_networks
        if self.cross_validate_model_eval:
            num_elites = 1
        self._model = construct_model(obs_dim=obs_dim, act_dim=act_dim, hidden_dim=hidden_dim, num_layers=num_layers,
                                      model_lr=model_lr, num_networks=num_networks, num_elites=num_elites,
                                      model_type=model_type, sn=sn, gradient_penalty=gradient_penalty,
                                      gradient_penalty_scale=gradient_penalty_scale, separate_mean_var=separate_mean_var,
                                      no_sn_last=no_sn_last, name=model_name, load_dir=model_load_dir, deterministic=deterministic,
                                      multi_step_prediction=multi_step_prediction, num_plan_steps=num_plan_steps,
                                      reward_classification=reward_classification)
        self._static_fns = static_fns
        self.fake_env = FakeEnv(self._model, self._static_fns, penalty_coeff=penalty_coeff, penalty_learned_var=penalty_learned_var,
                                std_thresh=std_thresh, per_batch_std_percentile=per_batch_std_percentile,
                                oracle=oracle, oracle_env=evaluation_environment, model_rew_zero=model_rew_zero)
        self.std_percentile = std_percentile

        self._rollout_schedule = [20, 100, rollout_length, rollout_length]
        self._max_model_t = max_model_t

        self._model_retain_epochs = model_retain_epochs

        self._model_train_freq = model_train_freq
        self._rollout_batch_size = int(rollout_batch_size)
        self._deterministic = deterministic
        self._rollout_random = rollout_random
        self._rollout_half_random = rollout_half_random
        self._rollout_random_action_scale = rollout_random_action_scale
        self.num_repeat = num_repeat
        self._real_ratio = real_ratio

        self._normalize_rew = normalize_rew
        self._rew_normalize_shift = rew_normalize_shift
        self._rew_normalize_scale = rew_normalize_scale

        self._log_dir = os.getcwd()
        self._writer = Writer(self._log_dir)

        self._training_environment = training_environment
        self._evaluation_environment = evaluation_environment
        self._policy = policy

        ## Q-functions and their target copies
        self._Qs = Qs
        self._Q_targets = tuple(tf.keras.models.clone_model(Q) for Q in Qs)
        if self.use_fqe:
            self._fqe_Qs = [get_Q_function_from_variant(Q_variant, training_environment) for i in range(fqe_num_qs)]
            self._fqe_Q_targets = [tuple(tf.keras.models.clone_model(Q) for Q in fqe_Qs) for fqe_Qs in self._fqe_Qs]

        if self.use_old_policy:
            self._old_policy = get_policy_from_variant(policy_variant, training_environment, Qs)
            self._update_old_policy()

        self._pool = pool
        self._plotter = plotter
        self._tf_summaries = tf_summaries

        self._policy_lr = lr
        self._Q_lr = q_lr

        self._reward_scale = reward_scale
        # 'auto' sets the entropy target to minus the action dimensionality.
        self._target_entropy = (
            -np.prod(self._training_environment.action_space.shape)
            if target_entropy == 'auto'
            else target_entropy)
        print('[ CUDS ] Target entropy: {}'.format(self._target_entropy))

        self._discount = discount
        self._tau = tau
        self._target_update_interval = target_update_interval
        self._action_prior = action_prior

        self._reparameterize = reparameterize
        self._store_extra_policy_info = store_extra_policy_info

        observation_shape = self._training_environment.active_observation_shape
        action_shape = self._training_environment.action_space.shape

        assert len(observation_shape) == 1, observation_shape
        self._observation_shape = observation_shape
        assert len(action_shape) == 1, action_shape
        self._action_shape = action_shape

        self.with_lagrange = with_lagrange
        if self.with_lagrange:
            self.target_action_gap = lagrange_thresh

        ## min Q
        self.with_min_q = with_min_q
        self.new_min_q = new_min_q
        self.min_q_for_real = min_q_for_real
        self.min_q_for_real_and_fake = min_q_for_real_and_fake
        self.min_q_for_fake_only = min_q_for_fake_only
        self.min_q_for_fake_states = min_q_for_fake_states
        self.min_q_for_real_fake_states = min_q_for_real_fake_states
        self.min_q_real_states_weight = min_q_real_states_weight
        self.backup_for_fake_only = backup_for_fake_only
        self.policy_for_fake_only = policy_for_fake_only
        self.backup_for_one_step = backup_for_one_step
        self.backup_with_uniform = backup_with_uniform
        self.temp = temp
        self.min_q_version = min_q_version
        self.use_projected_grad = use_projected_grad
        self.normalize_magnitudes = normalize_magnitudes
        self.regress_constant = regress_constant
        self.min_q_weight = min_q_weight
        # Non-float weights (e.g. per-task sequences) are turned into a constant.
        # NOTE(review): a plain int is wrapped as well, which yields an integer
        # constant -- confirm that is intended.
        if type(self.min_q_weight) is not float:
            self.min_q_weight = tf.constant(self.min_q_weight)
        self.data_subtract = data_subtract
        self.policy_eval_start = policy_eval_start

        self.max_q_backup = max_q_backup
        self.deterministic_backup = deterministic_backup
        self.num_random = num_random

        self._build()

        self.learn_bc_epoch = learn_bc_epoch
        if learn_bc_policy:
            import copy
            # Deep-copy so overriding the BC network size does not leak into
            # the caller's variant, which the caller may reuse elsewhere.
            bc_policy_variant = copy.deepcopy(policy_variant)
            bc_policy_variant['policy_params']['kwargs']['hidden_layer_sizes'] = (256, 256, 256)
            self._bc_policy = get_policy_from_variant(bc_policy_variant, training_environment, Qs)
            self._init_bc_policy_update()
        else:
            self._bc_policy = None

        #### load replay pool data
        self._pool_load_path = pool_load_path
        self._pool_load_max_size = pool_load_max_size

        loader.restore_pool(self._pool, self._pool_load_path, self._pool_load_max_size, save_path=self._log_dir, env=training_environment,
                            multitask_type=multitask_type, relabel_balance_batch=relabel_balance_batch)
        # `_train` asserts every epoch that the pool never grows past this size.
        self._init_pool_size = self._pool.size
        print('[ CUDS ] Starting with pool size: {}'.format(self._init_pool_size))
        if pool_eval_load_path != '':
            self._pool_eval_load_path = pool_eval_load_path
            self._pool_eval_load_max_size = pool_eval_load_max_size

            obs_space = self._pool._observation_space
            act_space = self._pool._action_space
            print('[ CUDS ] Initializing evaluation pool with size {:.2e}'.format(
                pool_eval_load_max_size
            ))
            self._eval_pool = SimpleReplayPool(obs_space, act_space, pool_eval_load_max_size)

            # TODO: fix this
            loader.restore_pool(self._eval_pool, self._pool_eval_load_path, self._pool_eval_load_max_size, save_path=self._log_dir)
            self._init_eval_pool_size = self._eval_pool.size
        else:
            self._eval_pool = None
        ####

    def _build(self):
        self._training_ops = {}

        self._init_global_step()
        self._init_placeholders()
        self._init_relabeling()
        self._init_actor_update()
        self._init_critic_update()
        if self.use_fqe:
            self._fqe_Q_values, self._fqe_Q_losses, self.min_fqe_Q_losses, self._fqe_Q_optimizers = [], [], [], []
            if self.with_lagrange:
                self.log_alpha_prime_fqe, self._alpha_prime_fqe_optimizer, self.alpha_prime_fqe, self.orig_min_fqe_Q_losses, self.alpha_prime_fqe_loss, self._alpha_prime_fqe_train_op = [], [], [], [], [], []
            for i in range(len(self._fqe_Qs)):
                self._init_critic_update(use_fqe=self.use_fqe, idx=i)

    def _train(self):
        """Return a generator that performs offline multi-task training.

        The dynamics model is trained (or loaded) once up front. Next, an
        optional behavior-cloning policy is trained and evaluated. Each epoch
        then runs training timesteps until at least `self._epoch_length` have
        passed and `self.ready_to_train` holds. Model rollouts happen every
        `self._model_train_freq` timesteps when `self._real_ratio < 1.0`. The
        epoch ends by evaluating the policy and yielding a diagnostics dict. No
        environment data is collected: the replay pool size is asserted to stay
        at its initial value.

        Yields:
            Diagnostics (an OrderedDict-updated dict) once per epoch, followed
            by a final dict containing `'done': True` plus the last epoch's
            diagnostics (empty if no epochs were run).
        """
        training_environment = self._training_environment
        evaluation_environment = self._evaluation_environment
        policy = self._policy
        pool = self._pool
        model_metrics = {}
        # Defined up front so the final `yield` works even when no epochs remain
        # (e.g. when resuming with `self._epoch == self._n_epochs`).
        diagnostics = {}

        if not self._training_started:
            self._init_training()

        self.sampler.initialize(training_environment, policy, pool)

        gt.reset_root()
        gt.rename_root('RLAlgorithm')
        gt.set_def_unique(False)

        self._training_before_hook()

        #### model training
        print('[ CUDS ] log_dir: {} | ratio: {}'.format(self._log_dir, self._real_ratio))
        print('[ CUDS ] Training model at epoch {} | freq {} | timestep {} (total: {})'.format(
            self._epoch, self._model_train_freq, self._timestep, self._total_timestep)
        )

        # A single model epoch suffices when the model was loaded from disk or
        # will never be sampled (real_ratio == 1.0 disables model rollouts below).
        if self._model.model_loaded or self._real_ratio == 1.0:
            max_epochs = 1
        elif self.multi_step_prediction:
            max_epochs = 150
        else:
            max_epochs = None
        model_train_kwargs = dict(batch_size=256, max_epochs=max_epochs, holdout_ratio=0.1, max_t=self._max_model_t)
        if self.multi_step_prediction:
            model_train_kwargs['no_early_stop'] = True
        model_train_metrics = self._train_model(**model_train_kwargs)
        if self._model.model_loaded and self.cross_validate_model_eval:
            # NOTE(review): hard-coded ensemble member index 4 -- confirm intent.
            self._model._model_inds = [4]
        model_metrics.update(model_train_metrics)

        if self.restore and self.cross_validate:
            self._cross_validate(rollout_batch_size=self._rollout_batch_size)

        self._log_model()
        gt.stamp('epoch_train_model')

        ## optional behavior-cloning baseline, evaluated once before RL training
        bc_evaluation_metrics = None
        if self._bc_policy is not None:
            self._train_bc_policy(batch_size=256, max_epochs=self.learn_bc_epoch, holdout_ratio=0.1)
            bc_evaluation_paths = self._evaluation_paths(
                self._bc_policy, evaluation_environment)
            gt.stamp('bc_evaluation_paths')

            if bc_evaluation_paths:
                bc_evaluation_metrics = self._evaluate_rollouts(
                    bc_evaluation_paths, evaluation_environment)
                gt.stamp('bc_evaluation_metrics')
        ####

        for self._epoch in gt.timed_for(range(self._epoch, self._n_epochs)):

            self._epoch_before_hook()
            gt.stamp('epoch_before_hook')

            self._training_progress = Progress(self._epoch_length * self._n_train_repeat)
            for timestep in count():
                self._timestep = timestep

                if (timestep >= self._epoch_length
                    and self.ready_to_train):
                    break

                self._timestep_before_hook()
                gt.stamp('timestep_before_hook')

                ## model rollouts
                if timestep % self._model_train_freq == 0 and self._real_ratio < 1.0:
                    self._training_progress.pause()
                    self._set_rollout_length()
                    self._reallocate_model_pool()
                    model_rollout_metrics = self._rollout_model(rollout_batch_size=self._rollout_batch_size, deterministic=self._deterministic)
                    model_metrics.update(model_rollout_metrics)

                    gt.stamp('epoch_rollout_model')
                    self._training_progress.resume()

                ## environment rollouts are disabled (offline training); the
                ## stamp is kept so timing diagnostics keep the same keys
                gt.stamp('sample')

                ## train actor and critic
                if self.ready_to_train:
                    self._do_training_repeats(timestep=timestep)
                gt.stamp('train')

                self._timestep_after_hook()
                gt.stamp('timestep_after_hook')

            training_paths = self.sampler.get_last_n_paths(
                math.ceil(self._epoch_length / self.sampler._max_path_length))

            evaluation_paths = self._evaluation_paths(
                policy, evaluation_environment)
            gt.stamp('evaluation_paths')

            if evaluation_paths:
                evaluation_metrics = self._evaluate_rollouts(
                    evaluation_paths, evaluation_environment)
                gt.stamp('evaluation_metrics')
            else:
                evaluation_metrics = {}

            gt.stamp('epoch_after_hook')

            sampler_diagnostics = self.sampler.get_diagnostics()

            diagnostics = self.get_diagnostics(
                iteration=self._total_timestep,
                batch=self._evaluation_batch(),
                training_paths=training_paths,
                evaluation_paths=evaluation_paths)

            time_diagnostics = gt.get_times().stamps.itrs

            diagnostics.update(OrderedDict((
                *(
                    (f'evaluation/{key}', evaluation_metrics[key])
                    for key in sorted(evaluation_metrics.keys())
                ),
                *(
                    (f'times/{key}', time_diagnostics[key][-1])
                    for key in sorted(time_diagnostics.keys())
                ),
                *(
                    (f'sampler/{key}', sampler_diagnostics[key])
                    for key in sorted(sampler_diagnostics.keys())
                ),
                *(
                    (f'model/{key}', model_metrics[key])
                    for key in sorted(model_metrics.keys())
                ),
                ('epoch', self._epoch),
                ('timestep', self._timestep),
                ('timesteps_total', self._total_timestep),
                ('train-steps', self._num_train_steps),
            )))

            if bc_evaluation_metrics:
                diagnostics.update(OrderedDict(
                    (f'bc_evaluation/{key}', bc_evaluation_metrics[key])
                    for key in sorted(bc_evaluation_metrics.keys())
                ))

            # Render with the same environment whose capability was checked;
            # the evaluation paths were collected in `evaluation_environment`.
            if self._eval_render_mode is not None and hasattr(
                    evaluation_environment, 'render_rollouts'):
                evaluation_environment.render_rollouts(evaluation_paths)

            ## ensure we did not collect any more data
            assert self._pool.size == self._init_pool_size

            yield diagnostics

        self.sampler.terminate()

        self._training_after_hook()

        # `_training_progress` only exists if at least one epoch ran.
        training_progress = getattr(self, '_training_progress', None)
        if training_progress is not None:
            training_progress.close()

        yield {'done': True, **diagnostics}

    def train(self, *args, **kwargs):
        return self._train(*args, **kwargs)

    def _evaluate_rollouts(self, paths, env):
        """Compute evaluation metrics for the given rollouts."""
        total_returns = [path['rewards'].sum() for path in paths]
        episode_lengths = [len(p['rewards']) for p in paths]
        if 'infos' in paths[0] and 'success' in paths[0]['infos'][0]:
            total_successes = [[info['success'] for info in path['infos']][-1] for path in paths]
        else:
            total_successes = None

        if total_successes:
            diagnostics = OrderedDict((
                ('return-average', np.mean(total_returns)),
                ('return-min', np.min(total_returns)),
                ('return-max', np.max(total_returns)),
                ('return-std', np.std(total_returns)),
                ('episode-length-avg', np.mean(episode_lengths)),
                ('episode-length-min', np.min(episode_lengths)),
                ('episode-length-max', np.max(episode_lengths)),
                ('episode-length-std', np.std(episode_lengths)),
                ('success-average', np.mean(total_successes)),
                ('success-min', np.min(total_successes)),
                ('success-max', np.max(total_successes)),
                ('success-std', np.std(total_successes)),
            ))
        else:
            diagnostics = OrderedDict((
                ('return-average', np.mean(total_returns)),
                ('return-min', np.min(total_returns)),
                ('return-max', np.max(total_returns)),
                ('return-std', np.std(total_returns)),
                ('episode-length-avg', np.mean(episode_lengths)),
                ('episode-length-min', np.min(episode_lengths)),
                ('episode-length-max', np.max(episode_lengths)),
                ('episode-length-std', np.std(episode_lengths)),
            ))
        if hasattr(self, 'num_tasks') and paths[0]['observations'][0, -self.num_tasks:].sum() == 1:
            total_returns_per_task = [[] for _ in range(self.num_tasks)]
            total_successes_per_task = [[] for _ in range(self.num_tasks)]
            for path in paths:
                total_returns_per_task[path['observations'][0,-self.num_tasks:].argmax()].append(path['rewards'].sum())
                if total_successes:
                    path_results = [info['success'] for info in path['infos']]
                    total_successes_per_task[path['observations'][0,-self.num_tasks:].argmax()].append(path_results[-1])
                else:
                    total_successes_per_task = None
            for i in range(self.num_tasks):
                if len(total_returns_per_task[i]) == 0:
                    total_returns_per_task[i].append(0.)
                if total_successes_per_task and len(total_successes_per_task[i]) == 0:
                    total_successes_per_task[i].append(0.)                
                diagnostics[f'return-average-task{i}'] = np.mean(total_returns_per_task[i])
                diagnostics[f'return-max-task{i}'] = np.max(total_returns_per_task[i])
                if total_successes_per_task:
                    diagnostics[f'success-average-task{i}'] = np.mean(total_successes_per_task[i])
                    diagnostics[f'success-max-task{i}'] = np.max(total_successes_per_task[i])

        if not hasattr(env, "oracle_env"):
            env_infos = env.get_path_infos(paths)
            for key, value in env_infos.items():
                diagnostics[f'env_infos/{key}'] = value

        return diagnostics

    def _cross_validate(self, rollout_batch_size):
        if self.cross_validate_model_eval:
            from softlearning.samplers import rollouts

            if self.cross_validate_eval_n_episodes < 1: return ()
            evaluation_metrics_list = []
            for i in range(self.num_networks):
                self._model._model_inds = [i]
                with self._policy.set_deterministic(self._eval_deterministic):
                    evaluation_paths = rollouts(
                        self.cross_validate_eval_n_episodes,
                        self.fake_env,
                        self._policy,
                        self.sampler._max_path_length,
                        render_mode=None)
                evaluation_metrics = self._evaluate_rollouts(
                        evaluation_paths, self.fake_env)
                # print('[ Cross Validation ] return-average: {} | return-std: {}'.format(evaluation_metrics['return-average'], evaluation_metrics['return-std']))
                evaluation_metrics_list.append(evaluation_metrics)
            # print('[ Cross Validation across models ] return-average: {} | return-std: {}'.format(np.mean([evaluation_metrics['return-average'] for evaluation_metrics in evaluation_metrics_list]),
            #                                                                                     np.std([evaluation_metrics['return-average'] for evaluation_metrics in evaluation_metrics_list])))
            return evaluation_metrics_list
        else:
            all_values, model_disagreements, branched_returns = [[] for _ in range(self.cross_validate_n_steps)], [], []
            for _ in range(self._pool.size // rollout_batch_size):
                print('[ Cross Validation ] Batch {} / {}'.format(_, self._pool.size // rollout_batch_size))
                batch = self.sampler.random_batch(rollout_batch_size)
                obs = batch['observations']
                steps_added = []
                rewards = np.zeros_like(batch['rewards'])
                branched_return = rewards.copy()
                feed_dict = self._get_feed_dict(None, batch)

                Q_values = self._session.run(self._Q_values, feed_dict)
                all_values[0].append(np.mean(np.array(Q_values), axis=0))
                for i in range(self.cross_validate_n_steps):
                    act = self._policy.actions_np(obs)

                    if self._model_type == 'identity':
                        next_obs = obs
                        rew = np.zeros((len(obs), 1))
                        term = (np.ones((len(obs), 1)) * self._identity_terminal).astype(np.bool)
                        info = {}
                    else:
                        next_obs, rew, term, info = self.fake_env.step(obs, act)
                        if info.get('mask', None) is not None:
                            obs = obs[info['mask']]
                            act = act[info['mask']]
                    steps_added.append(len(obs))

                    if info.get('mask', None) is not None and len(obs) == 0:
                        print('[ Model Rollout ] Breaking early due to variance threshold: {}'.format(i))
                        break

                    samples = {'observations': obs, 'actions': act, 'next_observations': next_obs, 'rewards': rew, 'terminals': term}
                    if i > 0:
                        feed_dict = self._get_feed_dict(None, samples)
                        Q_values = self._session.run(self._Q_values, feed_dict)
                        # TODO: this doesn't seem to be right
                        returns = rewards + (self._discount ** i) * np.array(Q_values).mean(axis=0) * (1.0 - term)
                        all_values[i].append(returns)
                    if i == self.cross_validate_n_steps - 1:
                        inputs = np.concatenate((obs, act), axis=-1)
                        ensemble_model_means, ensemble_model_vars = self._model.predict(inputs, factored=True)
                        ensemble_model_means[:, :, 1:] += obs
                        model_disagreements.append(np.amax(np.linalg.norm(ensemble_model_means[:, :, 1:] - np.mean(ensemble_model_means[:, :, 1:], axis=0), axis=-1), axis=0))
                        branched_returns.append(branched_return)

                    nonterm_mask = ~term.squeeze(-1)
                    if nonterm_mask.sum() == 0:
                        print('[ Model Rollout ] Breaking early: {} | {} / {}'.format(i, nonterm_mask.sum(), nonterm_mask.shape))
                        break

                    obs = next_obs[nonterm_mask]
                    rewards = rewards[nonterm_mask] + (self._discount**i) * rew[nonterm_mask]
                    branched_return = branched_return[nonterm_mask] + rew[nonterm_mask]
                    for j in range(i+1):
                        all_values[j][_] = all_values[j][_][nonterm_mask]

            for i in range(self.cross_validate_n_steps):
                all_values[i] = np.concatenate(all_values[i], axis=0)
            all_values = np.array(all_values)
            value_disagreement = np.var(all_values, axis=0).mean()
            model_disagreement = np.mean(np.concatenate(model_disagreements, axis=0))
            branched_returns = np.mean(np.concatenate(branched_returns, axis=0))
            print('[ Cross Validation ] Value Disagreement: {} | Model Disagreement: {}'.format(value_disagreement, model_disagreement))
            print('[ Cross Validation ] N-step Values are: {}'.format(np.mean(all_values, axis=1)))
            print('[ Cross Validation ] branched_returns are: {}'.format(branched_returns))
        import pdb; pdb.set_trace()
        return value_disagreement, model_disagreement

    def _log_policy(self):
        """Pickle the current policy weights to ``<log_dir>/models/policy_<timestep>.pkl``.

        The pickled object is ``{'policy_weights': self._policy.get_weights()}``
        and the file name uses ``self._total_timestep``.
        """
        save_path = os.path.join(self._log_dir, 'models')
        filesystem.mkdir(save_path)
        weights = self._policy.get_weights()
        data = {'policy_weights': weights}
        full_path = os.path.join(save_path, 'policy_{}.pkl'.format(self._total_timestep))
        print('Saving policy to: {}'.format(full_path))
        # The context manager guarantees the handle is flushed and closed,
        # even if pickling raises part-way through.
        with open(full_path, 'wb') as f:
            pickle.dump(data, f)

    def _log_model(self):
        print('MODEL: {}'.format(self._model_type))
        if self._model_type == 'identity':
            print('[ CUDS ] Identity model, skipping save')
        elif self._model.model_loaded:
            print('[ CUDS ] Loaded model, skipping save')
        else:
            save_path = os.path.join(self._log_dir, 'models')
            filesystem.mkdir(save_path)
            print('[ CUDS ] Saving model to: {}'.format(save_path))
            self._model.save(save_path, self._total_timestep)

    def _set_rollout_length(self):
        min_epoch, max_epoch, min_length, max_length = self._rollout_schedule
        if self._epoch <= min_epoch:
            y = min_length
        else:
            dx = (self._epoch - min_epoch) / (max_epoch - min_epoch)
            dx = min(dx, 1)
            y = dx * (max_length - min_length) + min_length

        self._rollout_length = int(y)
        print('[ Model Length ] Epoch: {} (min: {}, max: {}) | Length: {} (min: {} , max: {})'.format(
            self._epoch, min_epoch, max_epoch, self._rollout_length, min_length, max_length
        ))

    def _reallocate_model_pool(self):
        obs_space = self._pool._observation_space
        act_space = self._pool._action_space

        rollouts_per_epoch = self._rollout_batch_size * self._epoch_length / self._model_train_freq
        model_steps_per_epoch = int(self._rollout_length * rollouts_per_epoch)
        new_pool_size = self._model_retain_epochs * model_steps_per_epoch

        if not hasattr(self, '_model_pool'):
            print('[ CUDS ] Initializing new model pool with size {:.2e}'.format(
                new_pool_size
            ))
            self._model_pool = SimpleReplayPool(obs_space, act_space, new_pool_size)
            if self.backup_for_one_step:
                self._model_pool_one_step = SimpleReplayPool(obs_space, act_space, new_pool_size)
        
        elif self._model_pool._max_size != new_pool_size:
            print('[ CUDS ] Updating model pool | {:.2e} --> {:.2e}'.format(
                self._model_pool._max_size, new_pool_size
            ))
            samples = self._model_pool.return_all_samples()
            new_pool = SimpleReplayPool(obs_space, act_space, new_pool_size)
            new_pool.add_samples(samples)
            assert self._model_pool.size == new_pool.size
            self._model_pool = new_pool
            if self.backup_for_one_step:
                # One-step pool
                samples = self._model_pool_one_step.return_all_samples()
                new_pool = SimpleReplayPool(obs_space, act_space, new_pool_size)
                new_pool.add_samples(samples)
                assert self._model_pool_one_step.size == new_pool.size
                self._model_pool_one_step = new_pool

    def _train_model(self, **kwargs):
        if self._model_type == 'identity':
            print('[ CUDS ] Identity model, skipping model')
            model_metrics = {}
        else:
            env_samples = self._pool.return_all_samples()
            train_inputs, train_outputs = format_samples_for_training(env_samples,
                                                                    multi_step_prediction=self.multi_step_prediction,
                                                                    num_plan_steps=self.num_plan_steps)
            model_metrics = self._model.train(train_inputs, train_outputs, **kwargs)
            if self.std_percentile > 0.0:
                assert self.fake_env.std_thresh == 0.0
                obs = env_samples['observations']
                act = env_samples['actions']
                inputs = np.concatenate((obs, act), axis=-1)
                ensemble_model_means, ensemble_model_vars = self._model.predict(inputs, factored=True)
                ensemble_model_means[:,:,1:] += obs
                ensemble_model_stds = np.sqrt(ensemble_model_vars)
                if self.std_percentile < 100:
                    std_percentile = np.percentile(np.amax(np.linalg.norm(ensemble_model_stds[:, :, 1:], axis=-1), axis=0), self.std_percentile)
                else:
                    std_percentile = np.amax(np.linalg.norm(ensemble_model_stds[:, :, 1:], axis=-1), axis=0).max()
                self.fake_env.std_thresh = std_percentile
        return model_metrics

    def _train_bc_policy(self, batch_size=256, max_epochs=200, holdout_ratio=0.1, max_logging=1000):
        """Fit the policy to the dataset's actions by behavior cloning.

        Runs ``self.bc_policy_train_op`` (fetching ``self.policy_bc_loss``) on
        (observation, action) pairs from ``self._pool``. A random holdout split of
        ``min(int(N * holdout_ratio), max_logging)`` transitions is excluded from
        training and used only for reporting. Each epoch walks a bootstrap resample
        (with replacement) of the training indices in ``batch_size`` minibatches.
        That index set is reshuffled after every epoch.

        Args:
            batch_size: Minibatch size per gradient step.
            max_epochs: Number of passes over the bootstrap index set.
            holdout_ratio: Fraction of the data held out for validation.
            max_logging: Cap on the holdout size and on the number of training
                samples used to report the training loss each epoch.

        Raises:
            FloatingPointError: If a minibatch training loss is NaN.
        """
        import time

        env_samples = self._pool.return_all_samples()
        inputs, targets = env_samples['observations'], env_samples['actions']

        # Split into training and holdout sets
        num_holdout = min(int(inputs.shape[0] * holdout_ratio), max_logging)
        permutation = np.random.permutation(inputs.shape[0])
        inputs, holdout_inputs = inputs[permutation[num_holdout:]], inputs[permutation[:num_holdout]]
        targets, holdout_targets = targets[permutation[num_holdout:]], targets[permutation[:num_holdout]]

        def bc_loss(obs, act):
            # Evaluate the BC loss without taking a gradient step.
            return self._session.run(
                self.policy_bc_loss,
                feed_dict={self._observations_ph: obs, self._actions_ph: act})

        # Bootstrap resample of the training indices (sampling with replacement).
        idxs = np.random.randint(inputs.shape[0], size=[inputs.shape[0]])
        num_batches = int(np.ceil(idxs.shape[-1] / batch_size))
        progress = Progress(max_epochs)

        t0 = time.time()
        for _ in range(max_epochs):
            for batch_num in range(num_batches):
                batch_idxs = idxs[batch_num * batch_size:(batch_num + 1) * batch_size]
                _, train_loss = self._session.run(
                    [self.bc_policy_train_op, self.policy_bc_loss],
                    feed_dict={self._observations_ph: inputs[batch_idxs], self._actions_ph: targets[batch_idxs]}
                )
                if np.isnan(train_loss):
                    # Fail loudly instead of dropping into a debugger, which would
                    # hang non-interactive runs.
                    raise FloatingPointError(
                        'BC policy training loss became NaN (batch {})'.format(batch_num))

            idxs = np.random.permutation(idxs)

            named_losses = [['BC', bc_loss(inputs[idxs[:max_logging]], targets[idxs[:max_logging]])]]
            # An empty holdout set would give a meaningless (NaN) validation loss.
            if num_holdout > 0:
                named_losses.append(['BC Val', bc_loss(holdout_inputs, holdout_targets)])
            named_losses.append(['T', time.time() - t0])
            progress.set_description(named_losses)
            progress.update()

        progress.stamp()

        if num_holdout > 0:
            holdout_losses = bc_loss(holdout_inputs, holdout_targets)
            print('[ BC Policy ] Holdout', holdout_losses)

    def _rollout_model(self, rollout_batch_size, **kwargs):
        print('[ Model Rollout ] Starting | Epoch: {} | Rollout length: {} | Batch size: {} | Type: {}'.format(
            self._epoch, self._rollout_length, rollout_batch_size, self._model_type
        ))
        batch = self.sampler.random_batch(rollout_batch_size)
        steps_added = []
        for j in range(self.num_repeat):
            obs = batch['observations']
            for i in range(self._rollout_length):
                if self._bc_policy is not None:
                    act = self._bc_policy.actions_np(obs)
                elif not self._rollout_random:
                    act = self._policy.actions_np(obs)
                elif self._rollout_half_random:
                    if np.random.random() > 0.5:
                        act_ = self._policy.actions_np(obs)
                        act = np.random.uniform(low=-self._rollout_random_action_scale, high=self._rollout_random_action_scale, size=act_.shape)
                    else:
                        act = self._policy.actions_np(obs)
                else:
                    act_ = self._policy.actions_np(obs)
                    act = np.random.uniform(low=-self._rollout_random_action_scale, high=self._rollout_random_action_scale, size=act_.shape)

                if self._model_type == 'identity':
                    next_obs = obs
                    rew = np.zeros((len(obs), 1))
                    term = (np.ones((len(obs), 1)) * self._identity_terminal).astype(np.bool)
                    info = {}
                else:
                    next_obs, rew, term, info = self.fake_env.step(obs, act, **kwargs)
                    if info.get('mask', None) is not None:
                        obs = obs[info['mask']]
                        act = act[info['mask']]
                steps_added.append(len(obs))

                if info.get('mask', None) is not None and len(obs) == 0:
                    print('[ Model Rollout ] Breaking early {} due to varaince threshold: {}'.format(j, i))
                    break

                samples = {'observations': obs, 'actions': act, 'next_observations': next_obs, 'rewards': rew, 'terminals': term}
                if self.backup_for_one_step and i == 0:
                    self._model_pool_one_step.add_samples(samples)
                self._model_pool.add_samples(samples)

                nonterm_mask = ~term.squeeze(-1)
                if nonterm_mask.sum() == 0:
                    print('[ Model Rollout ] Breaking early {} : {} | {} / {}'.format(j, i, nonterm_mask.sum(), nonterm_mask.shape))
                    break

                obs = next_obs[nonterm_mask]

        mean_rollout_length = sum(steps_added) / rollout_batch_size / self.num_repeat
        rollout_stats = {'mean_rollout_length': mean_rollout_length}
        print('[ Model Rollout ] Added: {:.1e} | Model pool: {:.1e} (max {:.1e}) | Length: {} | Train rep: {}'.format(
            sum(steps_added), self._model_pool.size, self._model_pool._max_size, mean_rollout_length, self._n_train_repeat
        ))
        return rollout_stats

    def _visualize_model(self, env, timestep):
        """Visualize the policy with ``visualize_policy``, then restore ``env``'s simulator state.

        The state vector of ``env.unwrapped`` is split into ``qpos`` and ``qvel``
        (using the length of ``sim.data.qpos``) before visualization. It is written
        back with ``set_state`` afterwards.
        """
        # Snapshot the simulator state so visualization leaves it unchanged.
        snapshot = env.unwrapped.state_vector()
        split_at = len(env.unwrapped.sim.data.qpos)
        saved_qpos, saved_qvel = snapshot[:split_at], snapshot[split_at:]

        print('[ Visualization ] Starting | Epoch {} | Log dir: {}\n'.format(self._epoch, self._log_dir))
        visualize_policy(env, self.fake_env, self._policy, self._writer, timestep)
        print('[ Visualization ] Done')

        # Put the simulator back where it was.
        env.unwrapped.set_state(saved_qpos, saved_qvel)

    def _training_batch(self, batch_size=None):
        batch_size = batch_size or self.sampler._batch_size
        env_batch_size = int(batch_size*self._real_ratio)
        model_batch_size = batch_size - env_batch_size

        ## can sample from the env pool even if env_batch_size == 0
        if self.multitask_type == 'hipi':
            env_batch = self._pool.random_batch(env_batch_size, use_hipi=True)
        elif self.multitask_type == 'relabel-all' and self.goal_conditioned:
            env_batch = self._pool.random_batch(env_batch_size, relabel_all=True)
        else:
            env_batch = self._pool.random_batch(env_batch_size)
        env_batch['real_indicator'] = np.ones_like(env_batch['rewards'])
        env_batch['fake_indicator'] = np.zeros_like(env_batch['rewards'])

        if model_batch_size > 0:
            model_batch = self._model_pool.random_batch(model_batch_size)
            model_batch['real_indicator'] = np.zeros_like(model_batch['rewards'])
            model_batch['fake_indicator'] = np.ones_like(model_batch['rewards'])

            if self.backup_for_one_step:
                assert self.min_q_for_fake_only
                one_step_model_batch = self._model_pool_one_step.random_batch(model_batch_size)
                one_step_model_batch['real_indicator'] = np.zeros_like(one_step_model_batch['rewards'])
                one_step_model_batch['fake_indicator'] = np.zeros_like(one_step_model_batch['rewards'])
            if not self.backup_for_one_step:
                # keys = env_batch.keys()
                keys = set(env_batch.keys()) & set(model_batch.keys())
                batch = {k: np.concatenate((env_batch[k], model_batch[k]), axis=0) for k in keys}
            else:
                keys = set(env_batch.keys()) & set(model_batch.keys()) & set(one_step_model_batch.keys())
                batch = {k: np.concatenate((env_batch[k], model_batch[k], one_step_model_batch[k]), axis=0) for k in keys}
        else:
            ## if real_ratio == 1.0, no model pool was ever allocated,
            ## so skip the model pool sampling
            batch = env_batch
        return batch

    def _init_global_step(self):
        """Fetch or create the TF global step and register an op that increments it by one.

        The increment op is stored in ``self._training_ops`` under
        ``'increment_global_step'``.
        """
        self.global_step = training_util.get_or_create_global_step()
        increment_op = training_util._increment_global_step(1)
        self._training_ops.update(increment_global_step=increment_op)

    def _init_placeholders(self):
        """Create the ``tf.placeholder`` inputs used to build the training graph.

        Always created:

        - ``_iteration_ph`` (int64, unspecified shape);
        - ``_observations_ph`` and ``_next_observations_ph``, shape
          ``(None, *observation_shape)``;
        - ``_actions_ph``, shape ``(None, *action_shape)``;
        - float32 ``(None, 1)`` columns ``_rewards_ph``, ``_terminals_ph``,
          ``_real_indicator_ph`` and ``_fake_indicator_ph``.

        Created conditionally:

        - ``_relabel_masks`` ``(None, 1)`` when ``use_relabel_weights``;
        - ``_tasks_ph`` ``(num_tasks, goal_dim)`` when ``multitask_type == 'hipi'``
          and ``goal_conditioned``;
        - ``_log_pis_ph`` ``(None, 1)`` and ``_raw_actions_ph``
          ``(None, *action_shape)`` when ``_store_extra_policy_info``.
        """
        def float_ph(shape, name):
            return tf.placeholder(tf.float32, shape=shape, name=name)

        obs_shape = (None, *self._observation_shape)
        act_shape = (None, *self._action_shape)
        column = (None, 1)

        self._iteration_ph = tf.placeholder(tf.int64, shape=None, name='iteration')
        self._observations_ph = float_ph(obs_shape, 'observation')
        self._next_observations_ph = float_ph(obs_shape, 'next_observation')
        self._actions_ph = float_ph(act_shape, 'actions')
        self._rewards_ph = float_ph(column, 'rewards')
        self._terminals_ph = float_ph(column, 'terminals')
        self._real_indicator_ph = float_ph(column, 'real_indicator')
        self._fake_indicator_ph = float_ph(column, 'fake_indicator')

        if self.use_relabel_weights:
            self._relabel_masks = float_ph(column, 'relabel_masks')

        if self.multitask_type == 'hipi' and self.goal_conditioned:
            self._tasks_ph = float_ph((self.num_tasks, self.goal_dim), 'tasks')

        if self._store_extra_policy_info:
            self._log_pis_ph = float_ph(column, 'log_pis')
            self._raw_actions_ph = float_ph(act_shape, 'raw_actions')

    def _get_Q_target(self, use_fqe=False, idx=0):
        """Build the TD-target tensor ``r' + discount * (1 - done) * V(s')`` for critic regression.

        Target networks:

        - the FQE target ensemble ``self._fqe_Q_targets[idx]`` when ``use_fqe``;
        - otherwise ``self._Q_targets``.

        Inputs are the relabeled tensors when ``multitask_type == 'hipi'``, and the
        plain placeholders otherwise.

        ``V(s')`` is computed as follows:

        - with ``max_q_backup``: 10 policy actions are sampled per next state, and
          the per-network max is followed by the min over networks;
        - otherwise: the min over networks at one policy action, minus
          ``alpha * log_pi`` unless ``deterministic_backup``.

        ``r'`` is ``reward_scale * r``. With ``_normalize_rew``, ``r`` is first
        shifted by ``_rew_normalize_shift`` and multiplied by ``_rew_normalize_scale``.
        """
        target_Qs = self._fqe_Q_targets[idx] if use_fqe else self._Q_targets

        if self.multitask_type == 'hipi':
            next_obs = self.relabeled_next_observations
            rewards = self.relabeled_rewards
            terminals = self.relabeled_terminals
        else:
            next_obs = self._next_observations_ph
            rewards = self._rewards_ph
            terminals = self._terminals_ph

        if self.max_q_backup:
            # Max-Q backup: repeat each next state and keep the best sampled action per network.
            num_samples = 10
            tiled_next_obs = tf.reshape(
                tf.tile(tf.expand_dims(next_obs, axis=1), tf.constant([1, num_samples, 1], tf.int32)),
                [-1, *self._observation_shape])
            sampled_actions = self._policy.actions([tiled_next_obs])
            per_network_max = tuple(
                tf.reduce_max(tf.reshape(Q([tiled_next_obs, sampled_actions]), [-1, num_samples, 1]), axis=1)
                for Q in target_Qs)
            next_value = tf.reduce_min(per_network_max, axis=0)
        else:
            next_actions = self._policy.actions([next_obs])
            next_log_pis = self._policy.log_pis([next_obs], next_actions)
            per_network_Q = tuple(Q([next_obs, next_actions]) for Q in target_Qs)
            min_next_Q = tf.reduce_min(per_network_Q, axis=0)
            if self.deterministic_backup:
                next_value = min_next_Q
            else:
                # Soft (entropy-regularized) backup.
                next_value = min_next_Q - self._alpha * next_log_pis

        if self._normalize_rew:
            scaled_reward = self._reward_scale * (rewards - self._rew_normalize_shift) * self._rew_normalize_scale
        else:
            scaled_reward = self._reward_scale * rewards

        return td_target(
            reward=scaled_reward,
            discount=self._discount,
            next_value=(1 - terminals) * next_value)

    def _init_critic_update(self, use_fqe=False, idx=0):
        """Create minimization operation for critic Q-function.

        Creates a `tf.optimizer.minimize` operation for updating
        critic Q-function with gradient descent, and appends it to
        `self._training_ops` attribute.
        """
        if use_fqe:
            Qs = self._fqe_Qs[idx]
        else:
            Qs = self._Qs
        if self.multitask_type != 'hipi':
            observations_ph = self._observations_ph
            next_observations_ph = self._next_observations_ph
        else:
            observations_ph = self.relabeled_observations
            next_observations_ph = self.relabeled_next_observations
        Q_target = tf.stop_gradient(self._get_Q_target(use_fqe=use_fqe, idx=idx))

        if self.multitask_type != 'hipi':
            assert Q_target.shape.as_list() == [None, 1]

        Q_values = tuple(
            Q([observations_ph, self._actions_ph])
            for Q in Qs)
        if self.with_min_q and self.backup_for_fake_only:
            Q_losses = tuple(
                tf.losses.mean_squared_error(
                    labels=Q_target*self._real_indicator_ph, predictions=Q_value*self._real_indicator_ph, weights=0.5) / self._real_ratio
                for Q_value in Q_values)
        elif self.with_min_q and self.backup_for_one_step:
            Q_losses = tuple(
                tf.losses.mean_squared_error(
                    labels=Q_target*(1.0 - self._fake_indicator_ph), predictions=Q_value*(1.0 - self._fake_indicator_ph), weights=0.5)
                for Q_value in Q_values)
        elif self.with_min_q and self.backup_with_uniform:
            mask = tf.math.equal(self._fake_indicator_ph, tf.ones_like(self._fake_indicator_ph))
            policy_prob = self._policy_prob = tf.boolean_mask(self._policy.pis([observations_ph], self._actions_ph), mask)
            policy_prob = tf.maximum(self._policy_prob, 1e-5) / tf.maximum(tf.reduce_min(self._policy_prob), 1e-5)
            Q_losses = []
            for Q_value in Q_values:
                Q_real_loss = tf.reduce_sum(tf.square(Q_target*self._real_indicator_ph - Q_value*self._real_indicator_ph)) / tf.reduce_sum(self._real_indicator_ph)
                Q_target_fake = tf.boolean_mask(Q_target, mask)
                Q_value_fake = tf.boolean_mask(Q_value, mask)
                Q_fake_loss = tf.reduce_sum(tf.square(Q_target_fake - Q_value_fake) * 0.5 / policy_prob)
                Q_fake_loss = Q_fake_loss / tf.reduce_sum(0.5 / policy_prob)
                Q_losses.append(0.5*Q_real_loss*self._real_ratio + 0.5*Q_fake_loss*(1.0-self._real_ratio))
            Q_losses = tuple(Q_losses)
        elif self.with_min_q and self.policy_prob_weighting:
            # policy_prob = tf.maximum(policy_prob, 1e-5) / tf.maximum(tf.reduce_max(policy_prob), 1e-5)
            Q_losses = tuple(
                tf.losses.mean_squared_error(
                    labels=Q_target, predictions=Q_value, weights=0.5*self._policy_prob)
                for Q_value in Q_values)
        elif self.with_min_q and self.use_relabel_weights:
            # policy_prob = tf.maximum(policy_prob, 1e-5) / tf.maximum(tf.reduce_max(policy_prob), 1e-5)
            Q_losses = tuple(
                tf.losses.mean_squared_error(
                    labels=Q_target*self.relabel_weights[i], predictions=Q_value*self.relabel_weights[i], weights=0.5)
                for i, Q_value in enumerate(Q_values))
        else:
            Q_losses = tuple(
                tf.losses.mean_squared_error(
                    labels=Q_target, predictions=Q_value, weights=0.5)
                for Q_value in Q_values)
        if not use_fqe:
            self._Q_values = Q_values
            self._Q_losses = Q_losses
        else:
            self._fqe_Q_values.append(Q_values)
            self._fqe_Q_losses.append(Q_losses)

        self._Q_values_curr_actions = tf.constant(0.)

        if self.with_min_q and ((use_fqe and self.fqe_minq) or not use_fqe):
            if not self.min_q_for_fake_states:
                action_tile = tf.tile(tf.expand_dims(self._actions_ph, axis=1), tf.constant([1, self.num_random, 1], tf.int32))
                random_actions_tensor = tf.reshape(tf.random.uniform(tf.shape(action_tile), minval=-1, maxval=1), [-1, *self._action_shape])
                obs_tile = tf.reshape(tf.tile(tf.expand_dims(observations_ph, axis=1),
                                    tf.constant([1, self.num_random, 1], tf.int32)),
                                    [-1, *self._observation_shape])
                next_obs_tile = tf.reshape(tf.tile(tf.expand_dims(next_observations_ph, axis=1),
                                    tf.constant([1, self.num_random, 1], tf.int32)),
                                    [-1, *self._observation_shape])

                curr_actions_tensor = self._policy.actions([obs_tile])
                curr_log_pis = tf.reshape(self._policy.log_pis([obs_tile], curr_actions_tensor), [-1, self.num_random, 1])
                new_actions_tensor = self._policy.actions([next_obs_tile])
                new_log_pis = tf.reshape(self._policy.log_pis([next_obs_tile], new_actions_tensor), [-1, self.num_random, 1])
                Q_values_rand = tuple(
                    tf.reshape(Q([obs_tile, random_actions_tensor]), [-1, self.num_random, 1])
                    for Q in Qs)
                Q_values_curr_actions = tuple(
                    tf.reshape(Q([obs_tile, curr_actions_tensor]), [-1, self.num_random, 1])
                    for Q in Qs)
                Q_values_next_actions = tuple(
                    tf.reshape(Q([obs_tile, new_actions_tensor]), [-1, self.num_random, 1])
                    for Q in Qs)
                cat_Q_values = tuple(
                    tf.concat([Q_values_rand[i], tf.expand_dims(Q_values[i], axis=1), Q_values_next_actions[i], Q_values_curr_actions[i]], axis=1)
                    for i in range(len(Qs)))
                Q_values_std = tuple(
                    tf.math.reduce_std(cat_Q_value, axis=1)
                    for cat_Q_value in cat_Q_values)
                self._Q_values_curr_actions = tuple(
                    tf.reduce_mean(Q_value_curr_actions, axis=1)
                    for Q_value_curr_actions in Q_values_curr_actions)

                if self.min_q_version == 3:
                    # importance sammpled version
                    random_density = np.log(0.5 ** self._action_shape[-1])
                    if self.min_q_for_real and not self.min_q_for_real_and_fake:
                        cat_Q_values = tuple(
                            tf.concat([Q_values_rand[i] - random_density, Q_values_next_actions[i] - tf.stop_gradient(new_log_pis), Q_values_curr_actions[i] - tf.stop_gradient(curr_log_pis)], axis=1) * tf.expand_dims(self._real_indicator_ph, axis=1)
                            for i in range(len(Qs)))
                        """log sum exp for the min"""
                        min_Q_losses = tuple(
                            tf.reduce_sum(tf.math.reduce_logsumexp(cat_Q_value / self.temp, axis=1) * self._real_indicator_ph) * self.min_q_weight * self.temp / tf.reduce_sum(self._real_indicator_ph)
                            for cat_Q_value in cat_Q_values)
                    elif self.min_q_for_fake_only:
                        cat_Q_values = tuple(
                            tf.concat([Q_values_rand[i] - random_density, Q_values_next_actions[i] - tf.stop_gradient(new_log_pis), Q_values_curr_actions[i] - tf.stop_gradient(curr_log_pis)], axis=1) * tf.expand_dims(self._fake_indicator_ph, axis=1)
                            for i in range(len(Qs)))
                        # cat_Q_values = tuple(
                        #     tf.concat([Q_values_rand[i] - random_density, Q_values_curr_actions[i] - tf.stop_gradient(curr_log_pis)], axis=1) * tf.expand_dims(self._fake_indicator_ph, axis=1)
                        #     for i in range(len(Qs)))
                        """log sum exp for the min"""
                        min_Q_losses = tuple(
                            tf.reduce_sum(tf.math.reduce_logsumexp(cat_Q_value / self.temp, axis=1) * (self._fake_indicator_ph)) * self.min_q_weight * self.temp / tf.reduce_sum(self._fake_indicator_ph)
                            for cat_Q_value in cat_Q_values)
                    else:
                        cat_Q_values = tuple(
                            tf.concat([Q_values_rand[i] - random_density, Q_values_next_actions[i] - tf.stop_gradient(new_log_pis), Q_values_curr_actions[i] - tf.stop_gradient(curr_log_pis)], axis=1)
                            for i in range(len(Qs)))
                        
                        """log sum exp for the min"""
                        if self.policy_prob_weighting and not self.policy_prob_weighting_no_cql:
                            if type(self.min_q_weight) is not float:
                                min_Q_losses = tuple(
                                    tf.reduce_mean(tf.reduce_mean(tf.reshape(tf.math.reduce_logsumexp(cat_Q_value / self.temp, axis=1), [self.num_tasks, -1, 1]) * self._policy_prob, axis=[1, 2]) * self.min_q_weight) * self.temp
                                    for cat_Q_value in cat_Q_values)
                            else:
                                min_Q_losses = tuple(
                                    tf.reduce_mean(tf.math.reduce_logsumexp(cat_Q_value / self.temp, axis=1) * self._policy_prob) * self.min_q_weight * self.temp
                                    for cat_Q_value in cat_Q_values)
                        elif self.use_relabel_weights:
                            if type(self.min_q_weight) is not float:
                                min_Q_losses = tuple(
                                    tf.reduce_mean(tf.reduce_mean(tf.reshape(tf.math.reduce_logsumexp(cat_Q_value / self.temp, axis=1) * self.relabel_weights[i], [self.num_tasks, -1, 1]), axis=[1, 2]) * self.min_q_weight) * self.temp
                                    for i, cat_Q_value in enumerate(cat_Q_values))
                            else:
                                min_Q_losses = tuple(
                                    tf.reduce_mean(tf.math.reduce_logsumexp(cat_Q_value / self.temp, axis=1) * self.relabel_weights[i]) * self.min_q_weight * self.temp
                                    for i, cat_Q_value in enumerate(cat_Q_values))
                        else:
                            if type(self.min_q_weight) is not float:
                                min_Q_losses = tuple(
                                    tf.reduce_mean(tf.reduce_mean(tf.reshape(tf.math.reduce_logsumexp(cat_Q_value / self.temp, axis=1), [self.num_tasks, -1, 1]), axis=[1, 2]) * self.min_q_weight) * self.temp
                                    for cat_Q_value in cat_Q_values)
                            else:
                                min_Q_losses = tuple(
                                    tf.reduce_mean(tf.math.reduce_logsumexp(cat_Q_value / self.temp, axis=1)) * self.min_q_weight * self.temp
                                    for cat_Q_value in cat_Q_values)
                else:
                    # importance sammpled version
                    random_density = np.log(0.5 ** self._action_shape[-1])
                    if self.min_q_for_real and not self.min_q_for_real_and_fake:
                        cat_Q_values = tuple(
                            tf.concat([Q_values_rand[i], Q_values_next_actions[i], Q_values_curr_actions[i]], axis=1) * tf.expand_dims(self._real_indicator_ph, axis=1)
                            for i in range(len(Qs)))
                        """log sum exp for the min"""
                        min_Q_losses = tuple(
                            tf.reduce_sum(tf.math.reduce_logsumexp(cat_Q_value / self.temp, axis=1) * self._real_indicator_ph) * self.min_q_weight * self.temp / tf.reduce_sum(self._real_indicator_ph)
                            for cat_Q_value in cat_Q_values)
                    elif self.min_q_for_fake_only:
                        cat_Q_values = tuple(
                            tf.concat([Q_values_rand[i], Q_values_next_actions[i], Q_values_curr_actions[i]], axis=1) * tf.expand_dims(self._fake_indicator_ph, axis=1)
                            for i in range(len(Qs)))
                        # cat_Q_values = tuple(
                        #     tf.concat([Q_values_rand[i], Q_values_curr_actions[i]], axis=1) * tf.expand_dims(self._fake_indicator_ph, axis=1)
                        #     for i in range(len(Qs)))
                        """log sum exp for the min"""
                        min_Q_losses = tuple(
                            tf.reduce_sum(tf.math.reduce_logsumexp(cat_Q_value / self.temp, axis=1) * (self._fake_indicator_ph)) * self.min_q_weight * self.temp / tf.reduce_sum(self._fake_indicator_ph)
                            for cat_Q_value in cat_Q_values)
                    else:
                        cat_Q_values = tuple(
                            tf.concat([Q_values_rand[i], Q_values_next_actions[i], Q_values_curr_actions[i]], axis=1)
                            for i in range(len(Qs)))
                
                        """log sum exp for the min"""
                        if type(self.min_q_weight) is not float:
                            min_Q_losses = tuple(
                                tf.reduce_mean(tf.reduce_mean(tf.reshape(tf.math.reduce_logsumexp(cat_Q_value / self.temp, axis=1), [self.num_tasks, -1, 1]), axis=[1, 2]) * self.min_q_weight) * self.temp
                                for cat_Q_value in cat_Q_values)
                        else:
                            min_Q_losses = tuple(
                                tf.reduce_mean(tf.math.reduce_logsumexp(cat_Q_value / self.temp, axis=1)) * self.min_q_weight * self.temp
                                for cat_Q_value in cat_Q_values)
            else:
                mask = tf.squeeze(tf.math.equal(self._fake_indicator_ph, tf.ones_like(self._fake_indicator_ph)), axis=-1)
                real_mask = tf.squeeze(tf.math.equal(self._real_indicator_ph, tf.ones_like(self._real_indicator_ph)), axis=-1)

                action_tile = tf.tile(tf.expand_dims(self._actions_ph, axis=1), tf.constant([1, self.num_random, 1], tf.int32))
                random_actions_tensor = tf.reshape(tf.random.uniform(tf.shape(action_tile), minval=-1, maxval=1), [-1, *self._action_shape])
                # next_obs_mix = tf.concat([tf.boolean_mask(self._observations_ph, real_mask), tf.boolean_mask(self._next_observations_ph, mask)], axis=0)
                next_obs_tile = tf.reshape(tf.tile(tf.expand_dims(observations_ph, axis=1),
                                    tf.constant([1, self.num_random, 1], tf.int32)),
                                    [-1, *self._observation_shape])

                new_actions_tensor = self._policy.actions([next_obs_tile])
                new_log_pis = tf.reshape(self._policy.log_pis([next_obs_tile], new_actions_tensor), [-1, self.num_random, 1])
                Q_values_next_rand = tuple(
                    tf.reshape(Q([next_obs_tile, random_actions_tensor]), [-1, self.num_random, 1])
                    for Q in Qs)
                Q_values_next_actions = tuple(
                    tf.reshape(Q([next_obs_tile, new_actions_tensor]), [-1, self.num_random, 1])
                    for Q in Qs)

                if self.min_q_version == 3:
                    # importance sammpled version
                    random_density = np.log(0.5 ** self._action_shape[-1])
                    if self.min_q_for_real_fake_states:
                        scale = tf.concat([tf.cast(real_mask, tf.float32)*self.min_q_real_states_weight, tf.cast(mask, tf.float32)], axis=0)
                        cat_Q_values = tuple(
                            tf.concat([Q_values_next_rand[i] - random_density, Q_values_next_actions[i] - tf.stop_gradient(new_log_pis)], axis=1)*scale
                            for i in range(len(Qs)))
                        """log sum exp for the min"""
                        min_Q_losses = tuple(
                            tf.reduce_mean(tf.math.reduce_logsumexp(cat_Q_value / self.temp, axis=1)) * self.min_q_weight * self.temp
                            for cat_Q_value in cat_Q_values)
                    else:
                        cat_Q_values = tuple(
                            tf.boolean_mask(tf.concat([Q_values_next_rand[i] - random_density, Q_values_next_actions[i] - tf.stop_gradient(new_log_pis)], axis=1), mask)
                            for i in range(len(Qs)))
                        """log sum exp for the min"""
                        if type(self.min_q_weight) is not float:
                            min_Q_losses = tuple(
                                tf.reduce_mean(tf.reduce_mean(tf.reshape(tf.math.reduce_logsumexp(cat_Q_value / self.temp, axis=1), [self.num_tasks, -1, 1]), axis=[1, 2]) * self.min_q_weight) * self.temp
                                for cat_Q_value in cat_Q_values)
                        else:
                            min_Q_losses = tuple(
                                tf.reduce_mean(tf.math.reduce_logsumexp(cat_Q_value / self.temp, axis=1)) * self.min_q_weight * self.temp
                                for cat_Q_value in cat_Q_values)
                else:
                    if self.min_q_for_real_fake_states:
                        scale = tf.concat([tf.cast(real_mask, tf.float32)*self.min_q_real_states_weight, tf.cast(mask, tf.float32)], axis=0)
                        cat_Q_values = tuple(
                            tf.concat([Q_values_next_rand[i], Q_values_next_actions[i]], axis=1)*scale
                            for i in range(len(Qs)))
                        """log sum exp for the min"""
                        min_Q_losses = tuple(
                            tf.reduce_mean(tf.math.reduce_logsumexp(cat_Q_value / self.temp, axis=1)) * self.min_q_weight * self.temp
                            for cat_Q_value in cat_Q_values)
                    else:
                        cat_Q_values = tuple(
                            tf.boolean_mask(tf.concat([Q_values_next_rand[i], Q_values_next_actions[i]], axis=1), mask)
                            for i in range(len(Qs)))
                        """log sum exp for the min"""
                        if type(self.min_q_weight) is not float:
                            min_Q_losses = tuple(
                                tf.reduce_mean(tf.reduce_mean(tf.reshape(tf.math.reduce_logsumexp(cat_Q_value / self.temp, axis=1), [self.num_tasks, -1, 1]), axis=[1, 2]) * self.min_q_weight) * self.temp
                                for cat_Q_value in cat_Q_values)
                        else:
                            min_Q_losses = tuple(
                                tf.reduce_mean(tf.math.reduce_logsumexp(cat_Q_value / self.temp, axis=1)) * self.min_q_weight * self.temp
                                for cat_Q_value in cat_Q_values)
                    
            if self.data_subtract:
                """Subtract the log likelihood of data"""
                if self.min_q_for_real or self.min_q_for_fake_only or self.min_q_for_fake_states:
                    min_Q_losses = tuple(
                    # min_Q_losses[i] - tf.reduce_mean(Q_values[i] * self._real_indicator_ph) * self.min_q_weight
                    # for i in range(len(Qs)))
                    min_Q_losses[i] - tf.reduce_sum(Q_values[i] * self._real_indicator_ph) * self.min_q_weight / tf.reduce_sum(self._real_indicator_ph)
                    for i in range(len(Qs)))
                else:
                    if self.policy_prob_weighting:# or self.policy_prob_weighting_no_cql:
                        if type(self.min_q_weight) is not float:
                            min_Q_losses = tuple(
                                min_Q_losses[i] - tf.reduce_mean(tf.reduce_mean(tf.reshape(Q_values[i]*self._policy_prob, [self.num_tasks, -1, 1]), axis=[1, 2]) * self.min_q_weight)
                                for i in range(len(Qs)))
                        else:
                            min_Q_losses = tuple(
                                min_Q_losses[i] - tf.reduce_mean(Q_values[i]*self._policy_prob) * self.min_q_weight
                                for i in range(len(Qs)))
                    elif self.use_relabel_weights:# or self.policy_prob_weighting_no_cql:
                        if type(self.min_q_weight) is not float:
                            min_Q_losses = tuple(
                                min_Q_losses[i] - tf.reduce_mean(tf.reduce_mean(tf.reshape(Q_values[i]*self.relabel_weights[i], [self.num_tasks, -1, 1]), axis=[1, 2]) * self.min_q_weight)
                                for i in range(len(Qs)))
                        else:
                            min_Q_losses = tuple(
                                min_Q_losses[i] - tf.reduce_mean(Q_values[i]*self.relabel_weights[i]) * self.min_q_weight
                                for i in range(len(Qs)))
                    else:
                        if type(self.min_q_weight) is not float:
                            min_Q_losses = tuple(
                                min_Q_losses[i] - tf.reduce_mean(tf.reduce_mean(tf.reshape(Q_values[i], [self.num_tasks, -1, 1]), axis=[1, 2]) * self.min_q_weight)
                                for i in range(len(Qs)))
                        else:
                            min_Q_losses = tuple(
                                min_Q_losses[i] - tf.reduce_mean(Q_values[i]) * self.min_q_weight
                                for i in range(len(Qs)))

            if self.with_lagrange:
                if use_fqe:
                    assert self.fqe_minq
                    self.log_alpha_prime_fqe.append(tf.get_variable('log_alpha_prime_fqe_%d' % idx,
                                                            dtype=tf.float32,
                                                            initializer=0.0))
                    self._alpha_prime_fqe_optimizer.append(tf.train.AdamOptimizer(
                        self._Q_lr, name='alpha_prime_fqe_optimizer_%d' % idx))

                    self.alpha_prime_fqe.append(tf.clip_by_value(tf.exp(self.log_alpha_prime_fqe[-1]), 0.0, 2000000.0))
                    self.orig_min_fqe_Q_losses.append(min_Q_losses)
                    min_Q_losses = tuple(
                        self.alpha_prime_fqe[-1] * (min_Q_losses[i] - self.target_action_gap)
                        for i in range(len(Qs)))
                    alpha_prime_loss = -0.5 * (min_Q_losses[0] + min_Q_losses[1])
                    self.alpha_prime_fqe_loss.append(alpha_prime_loss)
                    self._alpha_prime_fqe_train_op.append(self._alpha_prime_fqe_optimizer[-1].minimize(
                        loss=alpha_prime_loss, var_list=[self.log_alpha_prime_fqe[-1]]))

                    self._training_ops.update({
                        'temperature_alpha_prime_fqe_%d' % idx: self._alpha_prime_fqe_train_op[-1]
                    })
                else:
                    self.log_alpha_prime = tf.get_variable('log_alpha_prime',
                                                            dtype=tf.float32,
                                                            initializer=0.0)
                    self._alpha_prime_optimizer = tf.train.AdamOptimizer(
                        self._Q_lr, name='alpha_prime_optimizer')

                    self.alpha_prime = tf.clip_by_value(tf.exp(self.log_alpha_prime), 0.0, 2000000.0)
                    self.orig_min_Q_losses = min_Q_losses
                    min_Q_losses = tuple(
                        self.alpha_prime * (min_Q_losses[i] - self.target_action_gap)
                        for i in range(len(Qs)))
                    self.alpha_prime_loss = alpha_prime_loss = -0.5 * (min_Q_losses[0] + min_Q_losses[1])
                    self._alpha_prime_train_op = self._alpha_prime_optimizer.minimize(
                        loss=alpha_prime_loss, var_list=[self.log_alpha_prime])

                    self._training_ops.update({
                        'temperature_alpha_prime': self._alpha_prime_train_op
                    })
            if use_fqe:
                assert self.fqe_minq
                self.min_fqe_Q_losses.append(min_Q_losses)
            else:         
                self.min_Q_losses = min_Q_losses
            Q_losses = tuple(
                Q_losses[i] + min_Q_losses[i]
                for i in range(len(Qs)))

        if use_fqe:
            self._fqe_Q_optimizers.append(tuple(
                tf.train.AdamOptimizer(
                    learning_rate=self._Q_lr,
                    name='fqe_{}_{}_optimizer'.format(Q._name, i)
                ) for i, Q in enumerate(Qs)))
            Q_training_ops = tuple(
                tf.contrib.layers.optimize_loss(
                    Q_loss,
                    self.global_step,
                    learning_rate=self._Q_lr,
                    optimizer=Q_optimizer,
                    variables=Q.trainable_variables,
                    increment_global_step=False,
                    summaries=((
                        "loss", "gradients", "gradient_norm", "global_gradient_norm"
                    ) if self._tf_summaries else ()))
                for i, (Q, Q_loss, Q_optimizer)
                in enumerate(zip(Qs, Q_losses, self._fqe_Q_optimizers[-1])))

            self._training_ops.update({'fqe_Q_%d' % idx: tf.group(Q_training_ops)})
        else:
            self._Q_optimizers = tuple(
                tf.train.AdamOptimizer(
                    learning_rate=self._Q_lr,
                    name='{}_{}_optimizer'.format(Q._name, i)
                ) for i, Q in enumerate(self._Qs))
            if self.PCGrad and self.with_min_q:
                Q_training_ops = []
                if self.PCGrad_bellman_only:
                    for i in range(len(self._Q_optimizers)):
                        PCGrad_Q_optimizer = PCGrad(self._Q_optimizers[i], name='PCGrad_Q_%d' % i)
                        Q_target_task = tf.reshape(Q_target, [self.num_tasks, -1, 1])
                        Q_value_task = tf.reshape(Q_values[i], [self.num_tasks, -1, 1])
                        Q_losses_task = [
                            tf.losses.mean_squared_error(
                                labels=Q_target_task[j], predictions=Q_value_task[j], weights=0.5)
                            for j in range(self.num_tasks)]
                        # grads_and_vars = PCGrad_Q_optimizer.compute_gradients(Q_losses_task, var_list=Q.trainable_variables)
                        # train_op = PCGrad_Q_optimizer.apply_gradients(grads_and_vars, self.global_step)
                        train_op = PCGrad_Q_optimizer.minimize(Q_losses_task, self.global_step, var_list=self._Qs[i].trainable_variables)
                        minq_train_op = tf.contrib.layers.optimize_loss(
                                        min_Q_losses[i],
                                        self.global_step,
                                        learning_rate=self._Q_lr,
                                        optimizer=self._Q_optimizers[i],
                                        variables=self._Qs[i].trainable_variables,
                                        increment_global_step=False,
                                        summaries=((
                                            "loss", "gradients", "gradient_norm", "global_gradient_norm"
                                        ) if self._tf_summaries else ()))
                        Q_training_ops.extend([train_op, minq_train_op])
                else:
                    assert not self.with_lagrange
                    for i in range(len(self._Q_optimizers)):
                        PCGrad_Q_optimizer = PCGrad(self._Q_optimizers[i], name='PCGrad_Q_%d' % i)
                        Q_target_task = tf.reshape(Q_target, [self.num_tasks, -1, 1])
                        Q_value_task = tf.reshape(Q_values[i], [self.num_tasks, -1, 1])
                        Q_losses_task = [
                            tf.losses.mean_squared_error(
                                labels=Q_target_task[j], predictions=Q_value_task[j], weights=0.5)
                            for j in range(self.num_tasks)]
                        cat_Q_values_per_task = tf.reshape(cat_Q_values[i], [self.num_tasks, -1, self.num_random*3,  1])
                        if type(self.min_q_weight) is not float:
                            min_q_loss_task = [tf.reduce_mean(tf.reduce_mean(tf.reshape(tf.math.reduce_logsumexp(cat_Q_values_per_task[j] / self.temp, axis=1), [self.num_tasks, -1, 1]), axis=[1, 2]) * self.min_q_weight) * self.temp
                                                for j in range(self.num_tasks)]
                            if self.data_subtract:
                                min_q_loss_task = [min_q_loss_task[j] - tf.reduce_mean(tf.reduce_mean(tf.reshape(Q_value_task[j], [self.num_tasks, -1, 1]), axis=[1, 2]) * self.min_q_weight) for j in range(self.num_tasks)]
                        else:
                            min_q_loss_task = [tf.reduce_mean(tf.math.reduce_logsumexp(cat_Q_values_per_task[j] / self.temp, axis=1)) * self.min_q_weight * self.temp
                                                for j in range(self.num_tasks)]
                            if self.data_subtract:
                                min_q_loss_task = [min_q_loss_task[j] - tf.reduce_mean(Q_value_task[j]) * self.min_q_weight for j in range(self.num_tasks)]
                        Q_losses_task = [Q_losses_task[j] + min_q_loss_task[j] for j in range(self.num_tasks)]
                        train_op = PCGrad_Q_optimizer.minimize(Q_losses_task, self.global_step, var_list=self._Qs[i].trainable_variables)
                        Q_training_ops.append(train_op)
                Q_training_ops = tuple(Q_training_ops)
            else:
                # self._Q_gradient_norm_tasks = []
                self._Q_gradient_norm = []
                for i, Q in enumerate(self._Qs):
                    # Q_target_task = tf.reshape(Q_target, [self.num_tasks, -1, 1])
                    # Q_value_task = tf.reshape(Q_values[i], [self.num_tasks, -1, 1])
                    # Q_losses_task = [
                    #     tf.losses.mean_squared_error(
                    #         labels=Q_target_task[j], predictions=Q_value_task[j], weights=0.5)
                    #     for j in range(self.num_tasks)]
                    # Q_gradients_task = [
                    #     tf.gradients(
                    #         Q_losses_task[j], Q.trainable_variables)
                    #     for j in range(self.num_tasks)]
                    # Q_gradients_task_list = [[grad for grad in Q_gradient_task if grad is not None]
                    #                         for Q_gradient_task in Q_gradients_task]
                    # Q_gradients_norm_task = tf.stack([tf.math.l2_normalize(tf.concat([tf.reshape(grad, [-1,]) for grad in Q_gradient_task], axis=0))
                    #                         for Q_gradient_task in Q_gradients_task_list])
                    # self._Q_gradient_norm_tasks.append(Q_gradients_norm_task)
                    Q_gradients = tf.gradients(Q_losses[i], Q.trainable_variables)
                    Q_gradients_list = [grad for grad in Q_gradients if grad is not None]
                    Q_gradients_norm_all = tf.math.l2_normalize(tf.concat([tf.reshape(grad, [-1,]) for grad in Q_gradients_list], axis=0))
                    self._Q_gradient_norm.append(Q_gradients_norm_all)
                Q_training_ops = tuple(
                    tf.contrib.layers.optimize_loss(
                        Q_loss,
                        self.global_step,
                        learning_rate=self._Q_lr,
                        optimizer=Q_optimizer,
                        variables=Q.trainable_variables,
                        increment_global_step=False,
                        summaries=((
                            "loss", "gradients", "gradient_norm", "global_gradient_norm"
                        ) if self._tf_summaries else ()))
                    for i, (Q, Q_loss, Q_optimizer)
                    in enumerate(zip(self._Qs, Q_losses, self._Q_optimizers)))
                if self.with_min_q and self.relabel_weight_temp_adaptive:
                    Q_training_ops_ema = []
                    for i in range(len(Q_training_ops)):
                        with tf.control_dependencies([Q_training_ops[i]]):
                            Q_training_ops_ema.append(self.relabel_weight_temp_emas[i].apply([self.relabel_Q_diff_abs_mean[i]]))
                    Q_training_ops = tuple(Q_training_ops_ema)

            self._training_ops.update({'Q': tf.group(Q_training_ops)})

    def _init_actor_update(self):
        """Create minimization operations for the policy and entropy temperature.

        Builds the reparameterized SAC actor loss, a behavior-cloning variant of
        it, and (when ``self._target_entropy`` is a number) the temperature
        loss for ``log_alpha``. ``policy_train_op`` and, when applicable,
        ``temperature_alpha`` are added to ``self._training_ops``;
        ``self.policy_bc_train_op`` is stored on the instance but is not
        registered in ``self._training_ops``.

        Attributes set: ``self._log_alpha``, ``self._alpha``,
        ``self._policy_optimizer``, ``self.policy_train_op``,
        ``self.policy_bc_train_op`` and, depending on configuration,
        ``self._alpha_optimizer``, ``self._alpha_train_op`` and
        ``self._policy_prob``.

        Raises:
            ValueError: if ``self._action_prior`` is neither ``'normal'`` nor
                ``'uniform'``.
            NotImplementedError: if ``self._reparameterize`` is False.
        """
        def masked_mean(values, mask):
            # Mean of `values` over the rows selected by the 0/1 float `mask`.
            return tf.reduce_sum(values * mask) / tf.reduce_sum(mask)

        # HiPI relabeling trains the actor on the relabeled observations.
        if self.multitask_type != 'hipi':
            observation_ph = self._observations_ph
        else:
            observation_ph = self.relabeled_observations

        actions = self._policy.actions([observation_ph])
        log_pis = self._policy.log_pis([observation_ph], actions)

        if self.multitask_type != 'hipi':
            assert log_pis.shape.as_list() == [None, 1]

        log_alpha = self._log_alpha = tf.get_variable(
            'log_alpha',
            dtype=tf.float32,
            initializer=0.0)
        alpha = tf.exp(log_alpha)

        if isinstance(self._target_entropy, Number):
            # The mask is applied to the whole (log_pi + target_entropy) term so
            # that the target entropy keeps weight 1 after dividing by the
            # number of selected rows (masking only log_pi would scale the
            # target by N / n_selected).
            if self.with_min_q and self.backup_for_one_step:
                alpha_loss = -masked_mean(
                    log_alpha * tf.stop_gradient(log_pis + self._target_entropy),
                    1.0 - self._fake_indicator_ph)
            elif self.with_min_q and self.policy_for_fake_only:
                alpha_loss = -masked_mean(
                    log_alpha * tf.stop_gradient(log_pis + self._target_entropy),
                    self._real_indicator_ph)
            elif self.with_min_q and self.use_relabel_weights:
                alpha_loss = -tf.reduce_mean(
                    log_alpha * tf.stop_gradient(
                        log_pis * tf.reduce_min(self.relabel_weights, axis=0)
                        + self._target_entropy))
            else:
                alpha_loss = -tf.reduce_mean(
                    log_alpha * tf.stop_gradient(log_pis + self._target_entropy))

            self._alpha_optimizer = tf.train.AdamOptimizer(
                self._policy_lr, name='alpha_optimizer')
            self._alpha_train_op = self._alpha_optimizer.minimize(
                loss=alpha_loss, var_list=[log_alpha])

            self._training_ops.update({
                'temperature_alpha': self._alpha_train_op
            })

        self._alpha = alpha

        if self._action_prior == 'normal':
            policy_prior = tf.contrib.distributions.MultivariateNormalDiag(
                loc=tf.zeros(self._action_shape),
                scale_diag=tf.ones(self._action_shape))
            policy_prior_log_probs = policy_prior.log_prob(actions)
        elif self._action_prior == 'uniform':
            policy_prior_log_probs = 0.0
        else:
            # Previously this fell through and failed later with an unbound
            # local; fail early with a clear message instead.
            raise ValueError(
                "Unsupported action_prior: {!r}".format(self._action_prior))

        Q_log_targets = tuple(
            Q([observation_ph, actions])
            for Q in self._Qs)
        # Pessimistic (clipped double-Q) estimate across the critic ensemble.
        min_Q_log_target = tf.reduce_min(Q_log_targets, axis=0)

        if not self._reparameterize:
            raise NotImplementedError

        policy_kl_losses = (
            alpha * log_pis
            - min_Q_log_target
            - policy_prior_log_probs)
        # BC variant: replaces the Q term with the log-likelihood of the
        # dataset actions under the current policy.
        policy_log_prob = self._policy.log_pis([observation_ph], self._actions_ph)
        policy_bc_losses = (
            alpha * log_pis
            - policy_log_prob
            - policy_prior_log_probs)

        assert policy_kl_losses.shape.as_list() == [None, 1]

        if self.with_min_q and self.backup_for_one_step:
            real_mask = 1.0 - self._fake_indicator_ph
            policy_loss = masked_mean(policy_kl_losses, real_mask)
            policy_bc_loss = masked_mean(policy_bc_losses, real_mask)
        elif self.with_min_q and self.policy_for_fake_only:
            policy_loss = masked_mean(policy_kl_losses, self._real_indicator_ph)
            policy_bc_loss = masked_mean(policy_bc_losses, self._real_indicator_ph)
        elif self.with_min_q and self.policy_prob_weighting:
            assert self.use_old_policy
            policy_prob = self._policy_prob = tf.stop_gradient(
                self._old_policy.pis([observation_ph], self._actions_ph))
            if self.num_tasks > 1:
                # Normalize the old-policy probabilities within each task
                # (assumes rows are grouped task-major), then add a 0.1 floor.
                policy_prob = tf.reshape(policy_prob, [self.num_tasks, -1, 1])
                policy_prob = self._policy_prob = tf.reshape(
                    policy_prob / tf.expand_dims(tf.reduce_sum(policy_prob, axis=1), axis=1) + 0.1,
                    [-1, 1])
            else:
                policy_prob = self._policy_prob = policy_prob / tf.reduce_sum(policy_prob) + 0.1
            policy_loss = tf.reduce_mean(policy_kl_losses * policy_prob)
            policy_bc_loss = tf.reduce_mean(policy_bc_losses * policy_prob)
        elif self.with_min_q and self.use_relabel_weights:
            min_relabel_weights = tf.reduce_min(self.relabel_weights, axis=0)
            policy_loss = tf.reduce_mean(policy_kl_losses * min_relabel_weights)
            policy_bc_loss = tf.reduce_mean(policy_bc_losses * min_relabel_weights)
        else:
            policy_loss = tf.reduce_mean(policy_kl_losses)
            policy_bc_loss = tf.reduce_mean(policy_bc_losses)

        self._policy_optimizer = tf.train.AdamOptimizer(
            learning_rate=self._policy_lr,
            name="policy_optimizer")
        if self.PCGrad:
            PCGrad_policy_optimizer = PCGrad(self._policy_optimizer, name='PCGrad_policy')
            # One scalar loss per task. Rows are assumed to be grouped
            # contiguously by task (the same task-major layout the critic's
            # PCGrad path reshapes with). Indexing `policy_kl_losses[j]`
            # directly would select single samples, not tasks.
            policy_kl_losses_task = tf.reshape(
                policy_kl_losses, [self.num_tasks, -1, 1])
            policy_losses_task = [
                tf.reduce_mean(policy_kl_losses_task[j])
                for j in range(self.num_tasks)]
            self.policy_train_op = PCGrad_policy_optimizer.minimize(
                policy_losses_task, self.global_step,
                var_list=self._policy.trainable_variables)
        else:
            self.policy_train_op = tf.contrib.layers.optimize_loss(
                policy_loss,
                self.global_step,
                learning_rate=self._policy_lr,
                optimizer=self._policy_optimizer,
                variables=self._policy.trainable_variables,
                increment_global_step=False,
                summaries=(
                    "loss", "gradients", "gradient_norm", "global_gradient_norm"
                ) if self._tf_summaries else ())
        # NOTE: shares `self._policy_optimizer` with policy_train_op.
        self.policy_bc_train_op = tf.contrib.layers.optimize_loss(
            policy_bc_loss,
            self.global_step,
            learning_rate=self._policy_lr,
            optimizer=self._policy_optimizer,
            variables=self._policy.trainable_variables,
            increment_global_step=False,
            summaries=(
                "loss", "gradients", "gradient_norm", "global_gradient_norm"
            ) if self._tf_summaries else ())

        self._training_ops.update({'policy_train_op': self.policy_train_op})

    def _init_bc_policy_update(self):
        """Create the behavior-cloning train op for ``self._bc_policy``.

        The loss is the mean negative log-likelihood of the dataset actions
        ``self._actions_ph`` under ``self._bc_policy``. No entropy term is
        involved. The op is stored in ``self.bc_policy_train_op``. Unlike the
        actor/critic ops, it is *not* added to ``self._training_ops``, so it only
        runs if some caller fetches ``self.bc_policy_train_op`` directly.

        Attributes set: ``self.policy_bc_loss``, ``self._bc_policy_optimizer``,
        ``self.bc_policy_train_op``.
        """
        # HiPI relabeling clones actions conditioned on relabeled observations.
        if self.multitask_type != 'hipi':
            observation_ph = self._observations_ph
        else:
            observation_ph = self.relabeled_observations
        policy_log_prob = self._bc_policy.log_pis([observation_ph], self._actions_ph)
        policy_bc_losses = -policy_log_prob

        assert policy_bc_losses.shape.as_list() == [None, 1]

        self.policy_bc_loss = policy_bc_loss = tf.reduce_mean(policy_bc_losses)

        self._bc_policy_optimizer = tf.train.AdamOptimizer(
            learning_rate=self._policy_lr,
            name="bc_policy_optimizer")
        self.bc_policy_train_op = tf.contrib.layers.optimize_loss(
            policy_bc_loss,
            self.global_step,
            learning_rate=self._policy_lr,
            optimizer=self._bc_policy_optimizer,
            variables=self._bc_policy.trainable_variables,
            increment_global_step=False,
            summaries=(
                "loss", "gradients", "gradient_norm", "global_gradient_norm"
            ) if self._tf_summaries else ())

    def _init_relabeling(self):
        """Build the graph ops that control how relabeled transitions are used.

        One of three mutually exclusive modes is built, chosen by flags:

        * ``with_min_q and use_relabel_weights``: for every Q-function a
          per-sample weight is built from a sigmoid of a Q-value difference
          between each sample and the mean over the original-task samples of
          the same task (or, with ``use_relabel_weights_diff``, the same
          comparison applied to the CQL logsumexp-minus-Q gap). Results are
          stored as tuples in ``self.relabel_weights``,
          ``self.relabel_Q_diff`` and ``self.relabel_Q_diff_abs_mean``. With
          ``relabel_weight_temp_adaptive`` the sigmoid temperature is the
          placeholder ``self.relabel_weight_temp_ph`` and one EMA per
          Q-function is created in ``self.relabel_weight_temp_emas``.
        * ``multitask_type`` in (``'hipi'``, ``'random'``): a task is sampled
          for each transition from a categorical over ``num_tasks``
          candidates, and ``self.relabeled_observations``,
          ``self.relabeled_next_observations``, ``self.relabeled_rewards``
          and ``self.relabeled_terminals`` are built; weights are all ones.
        * otherwise: ``self.relabel_weights`` is all ones.
        """
        if self.with_min_q and self.use_relabel_weights:
            if self.relabel_weight_temp_adaptive:
                # Temperature is fed each step from the EMAs below
                # (see _get_feed_dict).
                if self.relabel_weight_temp_adaptive_per_task:
                    self.relabel_weight_temp_ph = tf.placeholder(
                        tf.float32,
                        shape=(len(self._Qs), self.num_tasks),
                        name='relabel_weight_temp')
                else:
                    self.relabel_weight_temp_ph = tf.placeholder(
                        tf.float32,
                        shape=(len(self._Qs),),
                        name='relabel_weight_temp')
                self.relabel_weight_temp_emas = tuple(
                    tf.train.ExponentialMovingAverage(decay=self.relabel_weight_temp_tau)
                    for _ in range(len(self._Qs)))
            else:
                self.relabel_weight_temp_ph = [self.relabel_weight_temp for _ in range(len(self._Qs))]
            # Samples whose relabel mask is 0 are treated as original-task samples.
            mask_orig = tf.math.equal(self._relabel_masks, tf.zeros_like(self._relabel_masks))
            Q_values = tuple(
                Q([self._observations_ph, self._actions_ph])
                for Q in self._Qs)
            self.relabel_weights = []
            self.relabel_Q_diff = []
            self.relabel_Q_diff_abs_mean = []
            for i, Q_value in enumerate(Q_values):
                # NOTE(review): assumes the batch is grouped task-major so it
                # reshapes to [num_tasks, per_task_batch, 1] - confirm with the pool.
                Q_value_orig_task = tf.reshape(tf.boolean_mask(Q_value, mask_orig), [self.num_tasks, -1, 1])
                Q_value_task = tf.reshape(Q_value, [self.num_tasks, -1, 1])
                Q_value_orig_task_avg = tf.expand_dims(tf.reduce_mean(Q_value_orig_task, axis=1), axis=1)
                if self.use_relabel_weights_diff:
                    # CQL-style gap: logsumexp over Q at uniform-random actions,
                    # current-state policy actions and next-state policy actions
                    # (each corrected by its log-density), minus Q(s, a).
                    action_tile = tf.tile(tf.expand_dims(self._actions_ph, axis=1), tf.constant([1, self.num_random, 1], tf.int32))
                    random_actions_tensor = tf.reshape(tf.random.uniform(tf.shape(action_tile), minval=-1, maxval=1), [-1, *self._action_shape])
                    obs_tile = tf.reshape(tf.tile(tf.expand_dims(self._observations_ph, axis=1),
                                        tf.constant([1, self.num_random, 1], tf.int32)),
                                        [-1, *self._observation_shape])
                    next_obs_tile = tf.reshape(tf.tile(tf.expand_dims(self._next_observations_ph, axis=1),
                                        tf.constant([1, self.num_random, 1], tf.int32)),
                                        [-1, *self._observation_shape])

                    curr_actions_tensor = self._policy.actions([obs_tile])
                    curr_log_pis = tf.reshape(self._policy.log_pis([obs_tile], curr_actions_tensor), [-1, self.num_random, 1])
                    new_actions_tensor = self._policy.actions([next_obs_tile])
                    new_log_pis = tf.reshape(self._policy.log_pis([next_obs_tile], new_actions_tensor), [-1, self.num_random, 1])
                    Q_values_rand = tf.reshape(self._Qs[i]([obs_tile, random_actions_tensor]), [-1, self.num_random, 1])
                    Q_values_curr_actions = tf.reshape(self._Qs[i]([obs_tile, curr_actions_tensor]), [-1, self.num_random, 1])
                    Q_values_next_actions = tf.reshape(self._Qs[i]([obs_tile, new_actions_tensor]), [-1, self.num_random, 1])
                    # Log-density of Uniform(-1, 1) over every action dimension.
                    random_density = np.log(0.5 ** self._action_shape[-1])
                    cat_Q_value = tf.concat([Q_values_rand - random_density, Q_values_next_actions - tf.stop_gradient(new_log_pis), Q_values_curr_actions - tf.stop_gradient(curr_log_pis)], axis=1)
                    cat_Q_value = tf.reshape(cat_Q_value, [self.num_tasks, -1, self.num_random*3, 1])
                    # Per-task mean of the gap over original-task samples only.
                    min_Q_losses_orig = tf.reduce_sum(tf.math.reduce_logsumexp(cat_Q_value / self.temp, axis=2) * (1.0 - tf.reshape(self._relabel_masks, [self.num_tasks, -1, 1])), axis=[1,2]) * self.temp / tf.reduce_sum((1.0 - tf.reshape(self._relabel_masks, [self.num_tasks, -1, 1])), axis=[1,2])
                    min_Q_losses_orig = tf.expand_dims(min_Q_losses_orig, axis=1) - tf.squeeze(Q_value_orig_task_avg, axis=-1)
                    min_Q_losses_relabel = tf.math.reduce_logsumexp(cat_Q_value / self.temp, axis=2) * self.temp - Q_value_task
                    Q_diff = tf.reshape(tf.expand_dims(min_Q_losses_orig, axis=1) - min_Q_losses_relabel, [-1, 1])
                else:
                    Q_diff = tf.reshape(Q_value_task - Q_value_orig_task_avg, [-1, 1])
                if self.relabel_weight_temp_adaptive:
                    relabel_weight_temp_avg = tf.clip_by_value(self.relabel_weight_temp_ph[i], self.relabel_weight_temp_min, self.relabel_weight_temp_max)
                    if self.relabel_weight_temp_adaptive_per_task:
                        Q_diff_normalized = tf.reshape(tf.reshape(Q_diff, [-1, self.num_tasks, 1]) / tf.expand_dims(tf.expand_dims(relabel_weight_temp_avg, axis=0), axis=2), [-1, 1])
                    else:
                        Q_diff_normalized = Q_diff / relabel_weight_temp_avg
                    if self.relabel_weight_orig_task:
                        if 0 < self.relabel_prob < 1:
                            # Each original-task sample is additionally weighted
                            # with probability relabel_prob.
                            logits = tf.log([1.0 - self.relabel_prob, self.relabel_prob])
                            mask = tf.squeeze(
                                tf.random.categorical(
                                    logits[None], num_samples=self.sampler._batch_size),
                                axis=0)
                            mask = tf.cast(mask, tf.float32)[:, None] * (1.0 - self._relabel_masks)
                            relabel_weight = tf.stop_gradient(tf.math.sigmoid(Q_diff_normalized) * (self._relabel_masks + mask) + (1.0 - self._relabel_masks - mask))
                        else:
                            relabel_weight = tf.stop_gradient(tf.math.sigmoid(Q_diff_normalized))
                    else:
                        # Original-task samples keep weight 1.
                        relabel_weight = tf.stop_gradient(tf.math.sigmoid(Q_diff_normalized) * self._relabel_masks + (1.0 - self._relabel_masks))
                else:
                    relabel_weight = tf.stop_gradient(tf.math.sigmoid(Q_diff / self.relabel_weight_temp_ph[i]) * self._relabel_masks + (1.0 - self._relabel_masks))
                self.relabel_weights.append(relabel_weight)
                self.relabel_Q_diff.append(Q_diff)
                if self.relabel_weight_temp_adaptive_per_task:
                    self.relabel_Q_diff_abs_mean.append(tf.reduce_mean(tf.reshape(tf.abs(Q_diff), [-1, self.num_tasks, 1]), axis=[0,2]))
                else:
                    self.relabel_Q_diff_abs_mean.append(tf.reduce_mean(tf.abs(Q_diff)))
            self.relabel_weights = tuple(self.relabel_weights)
            self.relabel_Q_diff = tuple(self.relabel_Q_diff)
            self.relabel_Q_diff_abs_mean = tuple(self.relabel_Q_diff_abs_mean)
        elif self.multitask_type == 'hipi' or self.multitask_type == "random":
            # Split observations into state and task parts, then tile every
            # transition once per candidate task.
            if self.goal_conditioned:
                observations = self._observations_ph[:, :-self.goal_dim]
                next_observations = self._next_observations_ph[:, :-self.goal_dim]
                obs_tile = tf.reshape(tf.tile(tf.expand_dims(observations, axis=1), [1, self.num_tasks, 1]), [-1, (self._observation_shape[0]-self.goal_dim)])
                # Tile the NEXT states: they feed the Q(s', a') bootstrap below.
                next_obs_tile = tf.reshape(tf.tile(tf.expand_dims(next_observations, axis=1), [1, self.num_tasks, 1]), [-1, (self._observation_shape[0]-self.goal_dim)])
                action_tile = tf.reshape(tf.tile(tf.expand_dims(self._actions_ph, axis=1), [1, self.num_tasks, 1]), [-1, *self._action_shape])
                orig_tasks = self._observations_ph[:, -self.goal_dim:]
                orig_tasks_tile = tf.tile(tf.expand_dims(orig_tasks, axis=1), [1, self.num_tasks, 1])
                tasks = tf.reshape(tf.zeros_like(orig_tasks_tile) + tf.expand_dims(self._tasks_ph, axis=0), [-1, self.goal_dim])
                candidate_tasks = self._tasks_ph
            else:
                observations = self._observations_ph[:, :-self.num_tasks]
                next_observations = self._next_observations_ph[:, :-self.num_tasks]
                obs_tile = tf.reshape(tf.tile(tf.expand_dims(observations, axis=1), [1, self.num_tasks, 1]), [-1, (self._observation_shape[0]-self.num_tasks)])
                # Tile the NEXT states: they feed the Q(s', a') bootstrap below.
                next_obs_tile = tf.reshape(tf.tile(tf.expand_dims(next_observations, axis=1), [1, self.num_tasks, 1]), [-1, (self._observation_shape[0]-self.num_tasks)])
                action_tile = tf.reshape(tf.tile(tf.expand_dims(self._actions_ph, axis=1), [1, self.num_tasks, 1]), [-1, *self._action_shape])
                orig_tasks = self._observations_ph[:, -self.num_tasks:]
                orig_tasks_tile = tf.tile(tf.expand_dims(orig_tasks, axis=1), [1, self.num_tasks, 1])
                tasks = tf.reshape(tf.zeros_like(orig_tasks_tile) + tf.expand_dims(tf.eye(self.num_tasks), axis=0), [-1, self.num_tasks])
                candidate_tasks = tf.eye(self.num_tasks)
            next_relabelled_obs = tf.concat([next_obs_tile, tasks], axis=-1)

            sampled_actions_tiled = self._policy.actions([next_relabelled_obs])
            q_vals = tuple(
                Q([next_relabelled_obs, sampled_actions_tiled])
                for Q in self._Qs)
            q_vals = tf.reduce_min(q_vals, axis=0)
            q_vals_vec = tf.reshape(q_vals, [-1, self.num_tasks])
            # NOTE(review): presumably the pool provides one reward/terminal
            # per candidate task for each transition - confirm.
            rewards_vec = tf.reshape(self._rewards_ph, [-1, self.num_tasks])
            dones_vec = tf.reshape(self._terminals_ph, [-1, self.num_tasks])

            relabelled_obs = tf.concat([obs_tile, tasks], axis=-1)
            log_pi = self._policy.log_pis([relabelled_obs], action_tile)
            log_pi_vec = tf.reshape(log_pi, [-1, self.num_tasks])

            # Task score: r - log pi(a|s) + discount * (1 - done) * min_Q(s', a').
            logits_vec = (
                rewards_vec - log_pi_vec + self._discount * (1.0 - dones_vec) * q_vals_vec)
            if self.multitask_type == "random":
                logits_vec = tf.ones_like(logits_vec)

            # Normalize each task's scores over the batch axis before sampling.
            logits_vec = logits_vec - tf.math.reduce_logsumexp(
                logits_vec, axis=0)[None]
            relabel_indices = tf.random.categorical(logits=logits_vec, num_samples=1)
            # Squeeze only the sample axis so a batch of size 1 stays rank-1.
            task_indices = tf.squeeze(relabel_indices, axis=1)
            relabelled_tasks = tf.stop_gradient(tf.gather(candidate_tasks, task_indices))
            if 0 < self.relabel_prob < 1:
                # NOTE(review): mask == 1 (drawn with probability relabel_prob)
                # selects the ORIGINAL task, which looks inverted relative to
                # the parameter name - confirm intent.
                logits = tf.log([1.0 - self.relabel_prob, self.relabel_prob])
                mask = tf.squeeze(
                    tf.random.categorical(
                        logits[None], num_samples=self.sampler._batch_size),
                    axis=0)
                mask = tf.cast(mask, tf.float32)[:, None]
                relabelled_tasks = tf.stop_gradient(mask * orig_tasks + (1 - mask) * relabelled_tasks)
            self.relabelled_tasks = relabelled_tasks
            self.relabeled_observations = tf.concat([observations, relabelled_tasks], axis=-1)
            self.relabeled_next_observations = tf.concat([next_observations, relabelled_tasks], axis=-1)
            if not self.goal_conditioned:
                # Tasks are one-hot rows, so the product selects that task's column.
                self.relabeled_rewards = tf.expand_dims(tf.reduce_sum(tf.reshape(self._rewards_ph, [-1, self.num_tasks]) * relabelled_tasks, axis=-1), axis=1)
                self.relabeled_terminals = tf.expand_dims(tf.reduce_sum(tf.reshape(self._terminals_ph, [-1, self.num_tasks]) * relabelled_tasks, axis=-1), axis=1)
            else:
                relabel_indices = tf.one_hot(task_indices, self.num_tasks)
                self.relabeled_rewards = tf.expand_dims(tf.reduce_sum(tf.reshape(self._rewards_ph, [-1, self.num_tasks]) * relabel_indices, axis=-1), axis=1)
                self.relabeled_terminals = tf.expand_dims(tf.reduce_sum(tf.reshape(self._terminals_ph, [-1, self.num_tasks]) * relabel_indices, axis=-1), axis=1)
            self.relabel_weights = tuple(tf.ones_like(self.relabeled_rewards) for _ in range(len(self._Qs)))
        else:
            self.relabel_weights = tuple(tf.ones_like(self._rewards_ph) for _ in range(len(self._Qs)))

    def _init_training(self):
        self._update_target(tau=1.0)

    def _update_target(self, tau=None):
        tau = tau or self._tau

        for Q, Q_target in zip(self._Qs, self._Q_targets):
            source_params = Q.get_weights()
            target_params = Q_target.get_weights()
            Q_target.set_weights([
                tau * source + (1.0 - tau) * target
                for source, target in zip(source_params, target_params)
            ])
        if self.use_fqe:
            for fqe_Qs, fqe_Q_targets in zip(self._fqe_Qs, self._fqe_Q_targets):
                for Q, Q_target in zip(fqe_Qs, fqe_Q_targets):
                    source_params = Q.get_weights()
                    target_params = Q_target.get_weights()
                    Q_target.set_weights([
                        tau * source + (1.0 - tau) * target
                        for source, target in zip(source_params, target_params)
                    ])

    def _update_old_policy(self, tau=1.0):
        source_params = self._policy.get_weights()
        target_params = self._old_policy.get_weights()
        self._old_policy.set_weights([
            tau * source + (1.0 - tau) * target
            for source, target in zip(source_params, target_params)
        ])


    def _do_training(self, iteration, batch):
        """Runs the operations for updating training and target ops."""

        # self._training_progress.update()
        # self._training_progress.set_description()
        if self.use_old_policy and self._epoch % self.old_policy_update_interval == 0:
            self._update_old_policy()

        feed_dict = self._get_feed_dict(iteration, batch)

        if self.with_min_q and self._epoch < self.policy_eval_start:
            self._training_ops['policy_train_op'] = self.policy_bc_train_op
            self._session.run(self.policy_bc_train_op, feed_dict)
        else:
            self._training_ops['policy_train_op'] = self.policy_train_op
            self._session.run(self._training_ops, feed_dict)

        if iteration % self._target_update_interval == 0:
            # Run target ops here.
            self._update_target()
        if iteration % self._epoch_length == 0:
            print('Epoch: %d' % self._epoch)

    def _get_feed_dict(self, iteration, batch):
        """Construct TensorFlow feed_dict from sample batch."""

        feed_dict = {
            self._observations_ph: batch['observations'],
            self._actions_ph: batch['actions'],
            self._next_observations_ph: batch['next_observations'],
            self._rewards_ph: batch['rewards'],
            self._terminals_ph: batch['terminals'],
            self._real_indicator_ph: batch.get('real_indicator', np.ones_like(batch['rewards'])),
            self._fake_indicator_ph: batch.get('fake_indicator', np.zeros_like(batch['rewards'])),
        }
        if self.use_relabel_weights:
            feed_dict[self._relabel_masks] = batch['relabel_masks']

        if self.with_min_q and self.relabel_weight_temp_adaptive:
            relabel_Q_diff_avg = np.array(self._session.run([self.relabel_weight_temp_emas[i].average(self.relabel_Q_diff_abs_mean[i]) for i in range(len(self._Qs))], {}))
            feed_dict[self.relabel_weight_temp_ph] = relabel_Q_diff_avg

        if self.multitask_type == 'hipi' and self.goal_conditioned:
            feed_dict[self._tasks_ph] = batch['tasks']

        if self._store_extra_policy_info:
            feed_dict[self._log_pis_ph] = batch['log_pis']
            feed_dict[self._raw_actions_ph] = batch['raw_actions']

        if iteration is not None:
            feed_dict[self._iteration_ph] = iteration

        return feed_dict

    def _evaluation_batch(self, batch_size=None, task_id=None):
        batch_size = batch_size or self.sampler._batch_size
        env_batch_size = int(batch_size*self._real_ratio)
        model_batch_size = batch_size - env_batch_size

        ## can sample from the env pool even if env_batch_size == 0
        if self.goal_conditioned and self.multitask_type == 'relabel-all':
            env_batch = self._pool.random_batch(env_batch_size, task_id=task_id, relabel_all=True)
        else:
            env_batch = self._pool.random_batch(env_batch_size, task_id=task_id, use_hipi=(self.multitask_type=='hipi'))
        return env_batch

    def get_diagnostics(self,
                        iteration,
                        batch,
                        training_paths,
                        evaluation_paths):
        """Return diagnostic information as ordered dictionary.

        Records mean and standard deviation of Q-function and state
        value function, and TD-loss (mean squared Bellman error)
        for the sample batch.

        Also calls the `draw` method of the plotter, if plotter defined.
        """

        feed_dict = self._get_feed_dict(iteration, batch)
        if not self.use_fqe:
            if not self.with_lagrange:
                if not self.with_min_q:
                    (Q_values, Q_losses, alpha, global_step) = self._session.run(
                        (self._Q_values,
                         self._Q_losses,
                         self._alpha,
                         self.global_step),
                        feed_dict)

                    diagnostics = OrderedDict({
                        'Q-avg': np.mean(Q_values),
                        'Q-std': np.std(Q_values),
                        'Q_loss': np.mean(Q_losses),
                        'alpha': alpha,
                    })
                elif self.backup_with_uniform or self.policy_prob_weighting:
                    (Q_values, Q_losses, min_Q_losses, Q_values_curr_actions, alpha, global_step, policy_prob) = self._session.run(
                        (self._Q_values,
                         self._Q_losses,
                         self.min_Q_losses,
                         self._Q_values_curr_actions,
                         self._alpha,
                         self.global_step,
                         self._policy_prob),
                        feed_dict)
                    if np.any(np.isnan(policy_prob)):
                        import pdb; pdb.set_trace()
                    diagnostics = OrderedDict({
                        'Q-avg': np.mean(Q_values),
                        'Q-std': np.std(Q_values),
                        'Q_loss': np.mean(Q_losses),
                        'Q-avg-curr_actions': np.mean(Q_values_curr_actions),
                        'min_Q_loss': np.mean(min_Q_losses),
                        'alpha': alpha,
                        'policy_prob_min': policy_prob.min(),
                        'policy_prob_max': policy_prob.max(),
                        'policy_prob_mean': policy_prob.mean()
                    })
                elif self.multitask_type == 'hipi':
                    (Q_values, Q_losses, min_Q_losses, Q_values_curr_actions, alpha, global_step, relabelled_tasks) = self._session.run(
                        (self._Q_values,
                         self._Q_losses,
                         self.min_Q_losses,
                         self._Q_values_curr_actions,
                         self._alpha,
                         self.global_step,
                         self.relabelled_tasks),
                        feed_dict)
                    try:
                        diagnostics = OrderedDict({
                            'Q-avg': np.mean(Q_values),
                            'Q-std': np.std(Q_values),
                            'Q_loss': np.mean(Q_losses),
                            'Q-avg-curr_actions': np.mean(Q_values_curr_actions),
                            'min_Q_loss': np.mean(min_Q_losses),
                            'alpha': alpha,
                        })
                    except:
                        import pdb; pdb.set_trace()
                elif self.use_relabel_weights:
                    if not self.relabel_weight_temp_adaptive:
                        (Q_values, Q_losses, min_Q_losses, Q_values_curr_actions, alpha, global_step, relabel_weights, relabel_Q_diff) = self._session.run(
                            (self._Q_values,
                             self._Q_losses,
                             self.min_Q_losses,
                             self._Q_values_curr_actions,
                             self._alpha,
                             self.global_step,
                             self.relabel_weights,
                             self.relabel_Q_diff),
                            feed_dict)
                        relabel_weights = np.array(relabel_weights)
                        relabel_Q_diff = np.array(relabel_Q_diff)
                        if np.any(np.isnan(relabel_weights)):
                            import pdb; pdb.set_trace()
                        try:
                            diagnostics = OrderedDict({
                                'Q-avg': np.mean(Q_values),
                                'Q-std': np.std(Q_values),
                                'Q_loss': np.mean(Q_losses),
                                'Q-avg-curr_actions': np.mean(Q_values_curr_actions),
                                'min_Q_loss': np.mean(min_Q_losses),
                                'alpha': alpha,
                                'relabel_weights_min': relabel_weights.min(),
                                'relabel_weights_max': relabel_weights.max(),
                                'relabel_weights_mean': relabel_weights.mean(),
                                'relabel_weights_max_no_orig': relabel_weights.reshape(2, self.num_tasks, -1, 1)[:, :, relabel_weights.shape[1]//2//self.num_tasks:, :].max(),
                                'relabel_weights_mean_no_orig': relabel_weights.reshape(2, self.num_tasks, -1, 1)[:, :, relabel_weights.shape[1]//2//self.num_tasks:, :].mean(),
                                'relabel_Q_diff_max': relabel_Q_diff.reshape(2, self.num_tasks, -1, 1)[:, :, relabel_Q_diff.shape[1]//2//self.num_tasks:, :].max(),
                                'relabel_Q_diff_mean': relabel_Q_diff.reshape(2, self.num_tasks, -1, 1)[:, :, relabel_Q_diff.shape[1]//2//self.num_tasks:, :].mean(),
                                'relabel_Q_diff_min': relabel_Q_diff.reshape(2, self.num_tasks, -1, 1)[:, :, relabel_Q_diff.shape[1]//2//self.num_tasks:, :].min(),
                            })
                        except:
                            import pdb; pdb.set_trace()
                    else:
                        (Q_values, Q_losses, min_Q_losses, Q_values_curr_actions, alpha, global_step, relabel_weights, relabel_Q_diff) = self._session.run(
                            (self._Q_values,
                             self._Q_losses,
                             self.min_Q_losses,
                             self._Q_values_curr_actions,
                             self._alpha,
                             self.global_step,
                             self.relabel_weights,
                             self.relabel_Q_diff),
                            feed_dict)
                        relabel_weights = np.array(relabel_weights)
                        relabel_Q_diff = np.array(relabel_Q_diff)
                        relabel_weight_temp_avgs = feed_dict[self.relabel_weight_temp_ph]
                        if np.any(np.isnan(relabel_weights)):
                            import pdb; pdb.set_trace()
                        try:
                            diagnostics = OrderedDict({
                                'Q-avg': np.mean(Q_values),
                                'Q-std': np.std(Q_values),
                                'Q_loss': np.mean(Q_losses),
                                'Q-avg-curr_actions': np.mean(Q_values_curr_actions),
                                'min_Q_loss': np.mean(min_Q_losses),
                                'alpha': alpha,
                                'relabel_weights_min': relabel_weights.min(),
                                'relabel_weights_max': relabel_weights.max(),
                                'relabel_weights_mean': relabel_weights.mean(),
                                'relabel_weights_max_no_orig': relabel_weights.reshape(2, self.num_tasks, -1, 1)[:, :, relabel_weights.shape[1]//2//self.num_tasks:, :].max(),
                                'relabel_weights_mean_no_orig': relabel_weights.reshape(2, self.num_tasks, -1, 1)[:, :, relabel_weights.shape[1]//2//self.num_tasks:, :].mean(),
                                'relabel_Q_diff_max': relabel_Q_diff.reshape(2, self.num_tasks, -1, 1)[:, :, relabel_Q_diff.shape[1]//2//self.num_tasks:, :].max(),
                                'relabel_Q_diff_mean': relabel_Q_diff.reshape(2, self.num_tasks, -1, 1)[:, :, relabel_Q_diff.shape[1]//2//self.num_tasks:, :].mean(),
                                'relabel_Q_diff_min': relabel_Q_diff.reshape(2, self.num_tasks, -1, 1)[:, :, relabel_Q_diff.shape[1]//2//self.num_tasks:, :].min(),
                                'relabel_weight_temp_avgs': np.mean(relabel_weight_temp_avgs),
                            })
                        except:
                            import pdb; pdb.set_trace()
                        if self.relabel_weight_temp_adaptive_per_task:
                            for i in range(self.num_tasks):
                                diagnostics.update({
                                    'relabel_weight_temp_avgs_task%s' % i: np.mean(relabel_weight_temp_avgs, axis=0)[i],
                                    })
                else:
                    (Q_values, Q_losses, min_Q_losses, Q_values_curr_actions, alpha, global_step) = self._session.run(
                        (self._Q_values,
                         self._Q_losses,
                         self.min_Q_losses,
                         self._Q_values_curr_actions,
                         self._alpha,
                         self.global_step),
                        feed_dict)
                    diagnostics = OrderedDict({
                        'Q-avg': np.mean(Q_values),
                        'Q-std': np.std(Q_values),
                        'Q_loss': np.mean(Q_losses),
                        'Q-avg-curr_actions': np.mean(Q_values_curr_actions),
                        'min_Q_loss': np.mean(min_Q_losses),
                        'alpha': alpha,
                    })
            else:
                if self.policy_prob_weighting:
                    (Q_values, Q_losses, min_Q_losses, Q_values_curr_actions, alpha_prime, alpha_prime_loss, alpha, policy_prob, global_step) = self._session.run(
                        (self._Q_values,
                         self._Q_losses,
                         self.orig_min_Q_losses,
                         self._Q_values_curr_actions,
                         self.alpha_prime,
                         self.alpha_prime_loss,
                         self._alpha,
                         self._policy_prob,
                         self.global_step),
                        feed_dict)

                    diagnostics = OrderedDict({
                        'Q-avg': np.mean(Q_values),
                        'Q-std': np.std(Q_values),
                        'Q_loss': np.mean(Q_losses),
                        'Q-avg-curr_actions': np.mean(Q_values_curr_actions),
                        'min_Q_loss': np.mean(min_Q_losses),
                        'alpha_prime': alpha_prime,
                        'alpha': alpha,
                        'max_policy_prob': np.amax(policy_prob),
                        'mean_policy_prob': np.mean(policy_prob),
                    })
                elif self.use_relabel_weights:
                    if not self.relabel_weight_temp_adaptive:
                        (Q_values, Q_losses, min_Q_losses, Q_values_curr_actions, alpha, alpha_prime, alpha_prime_loss, global_step, relabel_weights, relabel_Q_diff) = self._session.run(
                            (self._Q_values,
                             self._Q_losses,
                             self.min_Q_losses,
                             self._Q_values_curr_actions,
                             self._alpha,
                             self.alpha_prime,
                             self.alpha_prime_loss,
                             self.global_step,
                             self.relabel_weights,
                             self.relabel_Q_diff),
                            feed_dict)
                        relabel_weights = np.array(relabel_weights)
                        relabel_Q_diff = np.array(relabel_Q_diff)
                        if np.any(np.isnan(relabel_weights)):
                            import pdb; pdb.set_trace()
                        try:
                            diagnostics = OrderedDict({
                                'Q-avg': np.mean(Q_values),
                                'Q-std': np.std(Q_values),
                                'Q_loss': np.mean(Q_losses),
                                'Q-avg-curr_actions': np.mean(Q_values_curr_actions),
                                'min_Q_loss': np.mean(min_Q_losses),
                                'alpha': alpha,
                                'alpha_prime': alpha_prime,
                                'relabel_weights_min': relabel_weights.min(),
                                'relabel_weights_max': relabel_weights.max(),
                                'relabel_weights_mean': relabel_weights.mean(),
                                'relabel_weights_max_no_orig': relabel_weights.reshape(2, self.num_tasks, -1, 1)[:, :, relabel_weights.shape[1]//2//self.num_tasks:, :].max(),
                                'relabel_weights_mean_no_orig': relabel_weights.reshape(2, self.num_tasks, -1, 1)[:, :, relabel_weights.shape[1]//2//self.num_tasks:, :].mean(),
                                'relabel_Q_diff_max': relabel_Q_diff.reshape(2, self.num_tasks, -1, 1)[:, :, relabel_Q_diff.shape[1]//2//self.num_tasks:, :].max(),
                                'relabel_Q_diff_mean': relabel_Q_diff.reshape(2, self.num_tasks, -1, 1)[:, :, relabel_Q_diff.shape[1]//2//self.num_tasks:, :].mean(),
                                'relabel_Q_diff_min': relabel_Q_diff.reshape(2, self.num_tasks, -1, 1)[:, :, relabel_Q_diff.shape[1]//2//self.num_tasks:, :].min(),
                            })
                        except:
                            import pdb; pdb.set_trace()
                    else:
                        (Q_values, Q_losses, min_Q_losses, Q_values_curr_actions, alpha, alpha_prime, alpha_prime_loss, global_step, relabel_weights, relabel_Q_diff) = self._session.run(
                            (self._Q_values,
                             self._Q_losses,
                             self.min_Q_losses,
                             self._Q_values_curr_actions,
                             self._alpha,
                             self.alpha_prime,
                             self.alpha_prime_loss,
                             self.global_step,
                             self.relabel_weights,
                             self.relabel_Q_diff),
                            feed_dict)
                        relabel_weights = np.array(relabel_weights)
                        relabel_Q_diff = np.array(relabel_Q_diff)
                        relabel_weight_temp_avgs = feed_dict[self.relabel_weight_temp_ph]
                        if np.any(np.isnan(relabel_weights)):
                            import pdb; pdb.set_trace()
                        try:
                            diagnostics = OrderedDict({
                                'Q-avg': np.mean(Q_values),
                                'Q-std': np.std(Q_values),
                                'Q_loss': np.mean(Q_losses),
                                'Q-avg-curr_actions': np.mean(Q_values_curr_actions),
                                'min_Q_loss': np.mean(min_Q_losses),
                                'alpha': alpha,
                                'alpha_prime': alpha_prime,
                                'relabel_weights_min': relabel_weights.min(),
                                'relabel_weights_max': relabel_weights.max(),
                                'relabel_weights_mean': relabel_weights.mean(),
                                'relabel_weights_max_no_orig': relabel_weights.reshape(2, self.num_tasks, -1, 1)[:, :, relabel_weights.shape[1]//2//self.num_tasks:, :].max(),
                                'relabel_weights_mean_no_orig': relabel_weights.reshape(2, self.num_tasks, -1, 1)[:, :, relabel_weights.shape[1]//2//self.num_tasks:, :].mean(),
                                'relabel_Q_diff_max': relabel_Q_diff.reshape(2, self.num_tasks, -1, 1)[:, :, relabel_Q_diff.shape[1]//2//self.num_tasks:, :].max(),
                                'relabel_Q_diff_mean': relabel_Q_diff.reshape(2, self.num_tasks, -1, 1)[:, :, relabel_Q_diff.shape[1]//2//self.num_tasks:, :].mean(),
                                'relabel_Q_diff_min': relabel_Q_diff.reshape(2, self.num_tasks, -1, 1)[:, :, relabel_Q_diff.shape[1]//2//self.num_tasks:, :].min(),
                                'relabel_weight_temp_avgs': np.mean(relabel_weight_temp_avgs),
                            })
                        except:
                            import pdb; pdb.set_trace()
                        if self.relabel_weight_temp_adaptive_per_task:
                            for i in range(self.num_tasks):
                                diagnostics.update({
                                    'relabel_weight_temp_avgs_task%s' % i: np.mean(relabel_weight_temp_avgs, axis=0)[i],
                                    })
                else:
                    (Q_values, Q_losses, min_Q_losses, Q_values_curr_actions, alpha_prime, alpha_prime_loss, alpha, global_step) = self._session.run(
                        (self._Q_values,
                         self._Q_losses,
                         self.orig_min_Q_losses,
                         self._Q_values_curr_actions,
                         self.alpha_prime,
                         self.alpha_prime_loss,
                         self._alpha,
                         self.global_step),
                        feed_dict)

                    diagnostics = OrderedDict({
                        'Q-avg': np.mean(Q_values),
                        'Q-std': np.std(Q_values),
                        'Q_loss': np.mean(Q_losses),
                        'Q-avg-curr_actions': np.mean(Q_values_curr_actions),
                        'min_Q_loss': np.mean(min_Q_losses),
                        'alpha_prime': alpha_prime,
                        'alpha': alpha,
                    })
        else:
            if not self.with_lagrange:
                if not self.with_min_q:
                    (Q_values, Q_losses, alpha, global_step, fqe_Q_values) = self._session.run(
                        (self._Q_values,
                         self._Q_losses,
                         self._alpha,
                         self.global_step,
                         self._fqe_Q_values),
                        feed_dict)
                    fqe_Q_values = np.array(fqe_Q_values)
                    diagnostics = OrderedDict({
                        'Q-avg': np.mean(Q_values),
                        'Q-std': np.std(Q_values),
                        'Q_loss': np.mean(Q_losses),
                        'FQE-Q-avg': np.mean(fqe_Q_values),
                        'FQE-Q-std': np.std(np.mean(fqe_Q_values, axis=-1)),
                        'alpha': alpha,
                    })
                elif self.backup_with_uniform:
                    (Q_values, Q_losses, min_Q_losses, Q_values_curr_actions, alpha, global_step, policy_prob, fqe_Q_values) = self._session.run(
                        (self._Q_values,
                         self._Q_losses,
                         self.min_Q_losses,
                         self._Q_values_curr_actions,
                         self._alpha,
                         self.global_step,
                         self._policy_prob,
                         self._fqe_Q_values),
                        feed_dict)
                    if np.any(np.isnan(policy_prob)):
                        import pdb; pdb.set_trace()
                    fqe_Q_values = np.array(fqe_Q_values)
                    min_fqe_Q_losses = np.array(min_fqe_Q_losses)
                    diagnostics = OrderedDict({
                        'Q-avg': np.mean(Q_values),
                        'Q-std': np.std(Q_values),
                        'Q_loss': np.mean(Q_losses),
                        'Q-avg-curr_actions': np.mean(Q_values_curr_actions),
                        'FQE-Q-avg': np.mean(fqe_Q_values),
                        'FQE-Q-std': np.std(np.mean(fqe_Q_values, axis=-1)),
                        'min_Q_loss': np.mean(min_Q_losses),
                        'alpha': alpha,
                        'policy_prob_min': policy_prob.min()
                    })
                else:
                    (Q_values, Q_losses, min_Q_losses, Q_values_curr_actions, alpha, global_step, fqe_Q_values) = self._session.run(
                        (self._Q_values,
                         self._Q_losses,
                         self.min_Q_losses,
                         self._Q_values_curr_actions,
                         self._alpha,
                         self.global_step,
                         self._fqe_Q_values),
                        feed_dict)
                    fqe_Q_values = np.array(fqe_Q_values)
                    diagnostics = OrderedDict({
                        'Q-avg': np.mean(Q_values),
                        'Q-std': np.std(Q_values),
                        'Q_loss': np.mean(Q_losses),
                        'Q-avg-curr_actions': np.mean(Q_values_curr_actions),
                        'FQE-Q-avg': np.mean(fqe_Q_values),
                        'FQE-Q-std': np.std(np.mean(fqe_Q_values, axis=-1)),
                        'min_Q_loss': np.mean(min_Q_losses),
                        'alpha': alpha,
                    })
            else:
                (Q_values, Q_losses, min_Q_losses, Q_values_curr_actions, alpha_prime, alpha_prime_loss, alpha, global_step, fqe_Q_values) = self._session.run(
                    (self._Q_values,
                     self._Q_losses,
                     self.orig_min_Q_losses,
                     self._Q_values_curr_actions,
                     self.alpha_prime,
                     self.alpha_prime_loss,
                     self._alpha,
                     self.global_step,
                     self._fqe_Q_values),
                    feed_dict)
                fqe_Q_values = np.array(fqe_Q_values)
                diagnostics = OrderedDict({
                    'Q-avg': np.mean(Q_values),
                    'Q-std': np.std(Q_values),
                    'Q_loss': np.mean(Q_losses),
                    'Q-avg-curr_actions': np.mean(Q_values_curr_actions),
                    'FQE-Q-avg': np.mean(fqe_Q_values),
                    'FQE-Q-std': np.std(np.mean(fqe_Q_values, axis=-1)),
                    'min_Q_loss': np.mean(min_Q_losses),
                    'alpha_prime': alpha_prime,
                    'alpha': alpha,
                })

        if self.num_tasks > 1 and not self.goal_conditioned:
            for task_idx in range(self.num_tasks):
                batch = self._evaluation_batch(batch_size=self.sampler._batch_size, task_id=task_idx)
                per_task_feed_dict = self._get_feed_dict(iteration, batch)
                if not self.with_lagrange:
                    (Q_values, Q_losses, min_Q_losses, Q_values_curr_actions, Q_gradient_norm) = self._session.run(
                        (self._Q_values,
                         self._Q_losses,
                         self.min_Q_losses,
                         self._Q_values_curr_actions,
                         self._Q_gradient_norm),
                        per_task_feed_dict)
                else:
                    (Q_values, Q_losses, min_Q_losses, Q_values_curr_actions, Q_gradient_norm) = self._session.run(
                        (self._Q_values,
                         self._Q_losses,
                         self.orig_min_Q_losses,
                         self._Q_values_curr_actions,
                         self._Q_gradient_norm),
                        per_task_feed_dict)

                diagnostics.update({
                    'Q-avg-task%d' % task_idx: np.mean(Q_values),
                    'Q-std-task%d' % task_idx: np.std(Q_values),
                    'Q_loss-task%d' % task_idx: np.mean(Q_losses),
                    'Q-avg-curr_actions-task%d' % task_idx: np.mean(Q_values_curr_actions),
                    'min_Q_loss-task%d' % task_idx: np.mean(min_Q_losses),
                    'Q-grad-norm-task%d' % task_idx: np.mean(Q_gradient_norm),
                })

        if self.cross_validate and self.cross_validate_model_eval:
            model_eval_metrics = self._cross_validate(rollout_batch_size=self._rollout_batch_size)
            diagnostics.update({
                'model_eval_return_average': np.mean([evaluation_metrics['return-average'] for evaluation_metrics in model_eval_metrics]),
                'model_eval_return_std': np.mean([evaluation_metrics['return-std'] for evaluation_metrics in model_eval_metrics]),
                })
        policy_diagnostics = self._policy.get_diagnostics(
            batch['observations'])
        diagnostics.update({
            f'policy/{key}': value
            for key, value in policy_diagnostics.items()
        })

        if self._plotter:
            self._plotter.draw()

        return diagnostics

    @property
    def tf_saveables(self):
        saveables = {
            '_policy_optimizer': self._policy_optimizer,
            **{
                f'Q_optimizer_{i}': optimizer
                for i, optimizer in enumerate(self._Q_optimizers)
            },
            '_log_alpha': self._log_alpha,
        }

        if hasattr(self, '_alpha_optimizer'):
            saveables['_alpha_optimizer'] = self._alpha_optimizer

        return saveables
