from collections import OrderedDict

import os
import time
import numpy as np
import torch
import torch.optim as optim
from torch import nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

import rlkit.torch.pytorch_util as ptu
from rlkit.core.eval_util import create_stats_ordered_dict
from rlkit.torch.torch_rl_algorithm import TorchTrainer
from rlkit.samplers.data_collector.path_collector import MdpPathCollector

class CQLTrainer(TorchTrainer):
    """Conservative Q-learning trainer with behaviour-density-weighted penalties.

    Every call to :meth:`train_from_torch` takes one optimiser step on

    * a VAE fitted to the dataset actions (its ``decode_multiple`` supplies the
      "behaviour" action samples used by the penalty),
    * the policy (SAC-style objective) and, optionally, the entropy
      temperature ``alpha``,
    * one or two Q-functions, each trained on its Bellman error plus
      ``min_q_weight`` times a conservative penalty.

    The penalty weights sampled actions by ``(1 - (beta/rho) ** ratio_temp)_+``,
    where ``beta`` is the density returned by ``behavior_poliy.log_prob`` and
    ``rho`` is the uniform density on ``[-1, 1] ** action_dim``.  The
    log-ratio is clamped to ``[-20, 0]`` before exponentiation.
    ``behavior_poliy`` is only queried here; this class never updates it.
    """

    def __init__(
            self,
            env,
            exp_name,
            policy,
            qf1,
            qf2,
            target_qf1,
            target_qf2,
            behavior_poliy,
            vae,

            discount=0.99,
            reward_scale=1.0,

            policy_lr=1e-3,
            qf_lr=1e-3,
            optimizer_class=optim.Adam,

            soft_target_tau=1e-2,
            target_update_period=2,
            plotter=None,
            render_eval_paths=False,

            use_automatic_entropy_tuning=True,
            target_entropy=None,
            policy_eval_start=0,
            num_qs=2,

            # CQL
            min_q_version=3,
            temp=1.0,
            min_q_weight=1.0,

            ## sort of backup
            max_q_backup=False,
            deterministic_backup=True,
            num_total=100,
            with_lagrange=False,
            ratio_temp=1.0,
            ver=1,
    ):
        """Build optimisers and store hyper-parameters.

        Selected arguments:
            env: environment; its ``action_space.shape`` sets the default
                target entropy and it is used for diagnostic rollouts.
            exp_name: a directory ``./<exp_name>`` is created if missing.
            behavior_poliy: model exposing ``log_prob(obs, actions)``; used as
                the behaviour density ``beta`` in the penalty weights.
            vae: model called as ``vae(obs, actions) -> (recon, mean, std)``
                and exposing ``decode_multiple(obs, num_decode=...)``.
            target_entropy: target for automatic entropy tuning; when ``None``
                it defaults to ``-prod(action_space.shape)``.
            policy_eval_start: for the first ``policy_eval_start`` training
                steps the policy is trained on a log-likelihood objective on
                dataset actions instead of the Q-value objective.
            num_qs: 1 uses only ``qf1``; otherwise ``qf1`` and ``qf2`` with a
                min over both.
            min_q_version: 3 (log-sum-exp over policy samples) or 2 (also
                uniform-random and dataset actions); any other value raises
                ``NotImplementedError`` at train time.
            ratio_temp: exponent applied to ``beta/rho`` inside the weights.
            num_total: sampled actions per state for each penalty term.
            ver: with ``min_q_version == 3``, ``ver == 2`` uses only actions
                sampled at the current state; otherwise actions sampled at the
                next state are included too (both evaluated at the current
                state).
            with_lagrange: accepted for interface compatibility; unused.
        """
        super().__init__()
        self.env = env
        self.exp_name = exp_name
        self.policy = policy
        self.behavior_poliy = behavior_poliy
        self.qf1 = qf1
        self.qf2 = qf2
        self.target_qf1 = target_qf1
        self.target_qf2 = target_qf2
        self.vae = vae

        self.soft_target_tau = soft_target_tau
        self.target_update_period = target_update_period

        self.use_automatic_entropy_tuning = use_automatic_entropy_tuning
        if self.use_automatic_entropy_tuning:
            # `is not None` so that an explicit target of 0 is honoured.
            if target_entropy is not None:
                self.target_entropy = target_entropy
            else:
                self.target_entropy = -np.prod(self.env.action_space.shape).item()
            self.log_alpha = ptu.zeros(1, requires_grad=True)
            self.alpha_optimizer = optimizer_class(
                [self.log_alpha],
                lr=policy_lr,
            )

        self.plotter = plotter
        self.render_eval_paths = render_eval_paths

        self.qf_criterion = nn.MSELoss()
        self.vf_criterion = nn.MSELoss()

        self.policy_optimizer = optimizer_class(
            self.policy.parameters(),
            lr=policy_lr,
        )
        self.qf1_optimizer = optimizer_class(
            self.qf1.parameters(),
            lr=qf_lr,
        )
        self.qf2_optimizer = optimizer_class(
            self.qf2.parameters(),
            lr=qf_lr,
        )
        self.vae_optimizer = optimizer_class(
            self.vae.parameters(),
            lr=policy_lr,
        )

        self.discount = discount
        self.reward_scale = reward_scale
        self.eval_statistics = OrderedDict()
        self.eval_wandb = OrderedDict()
        self._n_train_steps_total = 0
        self._need_to_update_eval_statistics = True
        self.policy_eval_start = policy_eval_start

        # NOTE: incremented once per train_from_torch call, i.e. it counts
        # gradient steps despite its name.
        self._current_epoch = 0
        self._policy_update_ctr = 0
        self._num_q_update_steps = 0
        self._num_policy_update_steps = 0
        self._num_policy_steps = 1

        self.num_qs = num_qs

        ## min Q
        self.temp = temp
        self.min_q_version = min_q_version
        self.min_q_weight = min_q_weight
        self.ratio_temp = ratio_temp

        print("self.min_q_version: \t", min_q_version)
        print("self.min_q_weight: \t", min_q_weight)
        print('self.ratio_temp: \t', ratio_temp)

        self.softmax = torch.nn.Softmax(dim=1)

        self.max_q_backup = max_q_backup
        self.deterministic_backup = deterministic_backup

        self.num_total = num_total

        print('self.num_total: \t', self.num_total)

        # Only the continuous-action code path is implemented.
        self.discrete = False
        self.ver = ver

        self.path = f'./{self.exp_name}'
        os.makedirs(self.path, exist_ok=True)

    def _get_tensor_values(self, obs, actions, network=None):
        """Evaluate ``network`` on every (state, sampled action) pair.

        ``actions`` holds ``k = actions.shape[0] // obs.shape[0]`` consecutive
        rows per state, in the layout produced by :meth:`_get_policy_actions`.
        Returns a tensor of shape ``(obs.shape[0], k, 1)``; this assumes
        ``network`` returns one value per input row.
        """
        batch_size = obs.shape[0]
        num_repeat = actions.shape[0] // batch_size
        obs_temp = obs.unsqueeze(1).repeat(1, num_repeat, 1).view(batch_size * num_repeat, obs.shape[1])
        preds = network(obs_temp, actions)
        return preds.view(batch_size, num_repeat, 1)

    def _get_policy_actions(self, obs, num_actions, network=None):
        """Sample ``num_actions`` actions per state from ``network``.

        Returns ``(actions, log_pis)`` with ``actions`` of shape
        ``(obs.shape[0] * num_actions, action_dim)`` (rows grouped by state)
        and ``log_pis`` reshaped to ``(obs.shape[0], num_actions, 1)``.  In
        the (unused) discrete mode only the actions are returned.
        """
        obs_temp = obs.unsqueeze(1).repeat(1, num_actions, 1).view(obs.shape[0] * num_actions, obs.shape[1])
        new_obs_actions, _, _, new_obs_log_pi, *_ = network(
            obs_temp, reparameterize=False, return_log_prob=True,
        )
        if not self.discrete:
            return new_obs_actions, new_obs_log_pi.view(obs.shape[0], num_actions, 1)
        else:
            return new_obs_actions

    def _calculate_log_prob(self, obs, num_actions, network=None):
        """Return ``network.log_prob`` of ``num_actions`` policy samples per state.

        Actions are sampled from ``self.policy`` at ``obs`` repeated
        ``num_actions`` times and then scored under ``network``.
        """
        obs_temp = obs.unsqueeze(1).repeat(1, num_actions, 1).view(obs.shape[0] * num_actions, obs.shape[1])
        act_temp, *_ = self.policy(
            obs_temp, reparameterize=False, return_log_prob=True,
        )
        return network.log_prob(obs_temp, act_temp)

    def _train_vae(self, obs, actions):
        """Take one optimiser step on the VAE (MSE reconstruction + 0.5 * KL); return the loss."""
        recon, mean, std = self.vae(obs, actions)
        recon_loss = self.qf_criterion(recon, actions)
        # KL(N(mean, std) || N(0, I)), averaged over all elements.
        kl_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean()
        vae_loss = recon_loss + 0.5 * kl_loss

        self.vae_optimizer.zero_grad()
        vae_loss.backward()
        self.vae_optimizer.step()
        return vae_loss

    def _sample_penalty_actions(self, obs, next_obs, actions):
        """Sample the actions used by the conservative penalty and their weights.

        Everything is computed without gradients.  With ``B = obs.shape[0]``,
        ``N = self.num_total`` and ``A = actions.shape[-1]`` the returned dict
        holds:

        * ``curr_actions`` / ``next_actions``: ``(B*N, A)`` policy samples at
          ``obs`` / ``next_obs``, with log-probs ``curr_log_pis`` /
          ``next_log_pis`` of shape ``(B, N, 1)``;
        * ``random_actions``: ``(B*N, A)`` uniform samples in ``[-1, 1]``;
        * ``beta_actions``: ``(B*N, A)`` samples from ``vae.decode_multiple``;
        * ``log_<k>_ratio`` / ``<k>_ratio`` for ``k`` in ``beta``, ``pi``,
          ``npi``: ``log_ratio = clamp(ratio_temp * (log beta(a|obs) - log rho),
          -20, 0)`` and weight ``1 - exp(log_ratio)``, reshaped to ``(B, -1)``
          (``(B, N)`` if ``log_prob`` returns one value per row);
        * for ``min_q_version == 2`` additionally ``rand_ratio`` (random
          actions) and ``batch_ratio`` (dataset actions, ``(B, -1)``).

        Note that ``npi`` weights score the next-state samples at ``obs``.
        """
        batch_size = obs.shape[0]
        action_dim = actions.shape[-1]
        num = self.num_total
        # Log-density of the uniform distribution on [-1, 1]^action_dim.
        random_density = np.log(0.5 ** action_dim)

        def ratio_weights(log_betas, log_rho):
            # Returns (clamped log-ratio, (1 - ratio)_+ weight).
            log_ratio = (self.ratio_temp * (log_betas - log_rho)).clamp(min=-20., max=0.0)
            return log_ratio, 1 - torch.exp(log_ratio)

        with torch.no_grad():
            curr_actions, curr_log_pis = self._get_policy_actions(obs, num_actions=num, network=self.policy)
            next_actions, next_log_pis = self._get_policy_actions(next_obs, num_actions=num, network=self.policy)
            random_actions = torch.FloatTensor(batch_size * num, action_dim).uniform_(-1, 1).to(ptu.device)
            beta_actions, *_ = self.vae.decode_multiple(obs, num_decode=num)
            beta_actions = beta_actions.view(batch_size * num, action_dim)

            obs_rep = obs.unsqueeze(1).repeat(1, num, 1).view(batch_size * num, obs.shape[1])
            beta_log_betas = self.behavior_poliy.log_prob(obs_rep, beta_actions).view(batch_size, -1)
            pi_log_betas = self.behavior_poliy.log_prob(obs_rep, curr_actions).view(batch_size, -1)
            npi_log_betas = self.behavior_poliy.log_prob(obs_rep, next_actions).view(batch_size, -1)

            log_rho = torch.ones_like(beta_log_betas) * random_density

            samples = dict(
                curr_actions=curr_actions,
                curr_log_pis=curr_log_pis,
                next_actions=next_actions,
                next_log_pis=next_log_pis,
                random_actions=random_actions,
                beta_actions=beta_actions,
            )
            samples['log_beta_ratio'], samples['beta_ratio'] = ratio_weights(beta_log_betas, log_rho)
            samples['log_pi_ratio'], samples['pi_ratio'] = ratio_weights(pi_log_betas, log_rho)
            samples['log_npi_ratio'], samples['npi_ratio'] = ratio_weights(npi_log_betas, log_rho)

            if self.min_q_version == 2:
                rand_log_betas = self.behavior_poliy.log_prob(obs_rep, random_actions).view(batch_size, -1)
                samples['log_rand_ratio'], samples['rand_ratio'] = ratio_weights(rand_log_betas, log_rho)
                batch_log_betas = self.behavior_poliy.log_prob(obs, actions).view(batch_size, -1)
                samples['log_batch_ratio'], samples['batch_ratio'] = ratio_weights(
                    batch_log_betas, log_rho[:, :1])
        return samples

    def _conservative_penalty(self, network, q_pred, obs, samples):
        """Return the three-term conservative penalty for one Q-function.

        ``q_pred`` is ``network(obs, dataset_actions)`` and ``samples`` comes
        from :meth:`_sample_penalty_actions`.  With ``w = (1 - beta/rho)_+``:

        * s1 = log E_rho[w exp Q] * stop_grad(E_rho[w exp Q] / E_rho[exp Q]),
          estimated with ``logsumexp`` over the sampled actions;
        * s2 = -E_beta[w Q];
        * s3 = -(E_pi[w] - E_beta[w]) * E_beta[Q].

        Returns a dict with ``loss`` (s1 + s2 + s3), ``s1``, ``s2``, ``s3``,
        the ``logsum`` / ``nw_logsum`` / ``logsum_ratio`` intermediates and
        the Q-values ``q_curr``, ``q_next``, ``q_rand`` of shape ``(B, N, 1)``.

        Raises:
            NotImplementedError: if ``min_q_version`` is neither 2 nor 3.
        """
        num = self.num_total
        q_curr = self._get_tensor_values(obs, samples['curr_actions'], network=network)
        q_next = self._get_tensor_values(obs, samples['next_actions'], network=network)
        q_rand = self._get_tensor_values(obs, samples['random_actions'], network=network)

        def log_weight(ratio, width):
            # 1e-20 keeps log() finite where the weight is exactly 0.
            return torch.log(ratio + 1e-20).view(-1, width, 1)

        if self.min_q_version == 3:
            if self.ver == 2:
                cat_q = log_weight(samples['pi_ratio'], num) + q_curr - samples['curr_log_pis']
                nw_cat_q = q_curr - samples['curr_log_pis']
            else:
                cat_q = torch.cat(
                    [log_weight(samples['npi_ratio'], num) + q_next - samples['next_log_pis'],
                     log_weight(samples['pi_ratio'], num) + q_curr - samples['curr_log_pis'], ], dim=1)
                nw_cat_q = torch.cat(
                    [q_next - samples['next_log_pis'],
                     q_curr - samples['curr_log_pis'], ], dim=1)
        elif self.min_q_version == 2:
            cat_q = torch.cat(
                [log_weight(samples['rand_ratio'], num) + q_rand,
                 log_weight(samples['npi_ratio'], num) + q_next,
                 log_weight(samples['pi_ratio'], num) + q_curr,
                 log_weight(samples['batch_ratio'], 1) + q_pred.unsqueeze(1), ], dim=1)
            nw_cat_q = torch.cat(
                [q_rand,
                 q_next,
                 q_curr,
                 q_pred.unsqueeze(1)], dim=1)
        else:
            raise NotImplementedError

        logsum = torch.logsumexp(cat_q, dim=1)
        with torch.no_grad():
            nw_logsum = torch.logsumexp(nw_cat_q, dim=1)
            # All weights are < 1, so this ratio is <= 1.
            logsum_ratio = torch.exp(logsum - nw_logsum)
        s1 = (logsum * logsum_ratio).mean()

        q_betas = self._get_tensor_values(obs, samples['beta_actions'], network=network).squeeze(-1)
        beta_ratio = samples['beta_ratio']
        s2 = -((beta_ratio * q_betas).mean(dim=-1, keepdim=True)).mean()
        s3 = -((samples['pi_ratio'].mean(dim=1, keepdim=True)
                - beta_ratio.mean(dim=1, keepdim=True)
                ) * q_betas.mean(dim=1, keepdim=True)).mean()

        return dict(
            loss=s1 + s2 + s3,
            s1=s1,
            s2=s2,
            s3=s3,
            logsum=logsum,
            nw_logsum=nw_logsum,
            logsum_ratio=logsum_ratio,
            q_curr=q_curr,
            q_next=q_next,
            q_rand=q_rand,
        )

    def _compute_overestimation(self, paths):
        """Compare learned Q-values with observed discounted returns on ``paths``.

        For each path the return ``G_t = reward_scale * r_t + discount *
        G_{t+1} * (1 - done_t)`` is computed back to front, so that it is on
        the same scale as the Bellman target used for training.  Q-values of
        the visited (observation, action) pairs are taken from ``qf1`` (min
        with ``qf2`` when ``num_qs > 1``).

        Returns one dict per path with time-ordered lists ``'path_overestim'``
        (Q - G), ``'gamm_return'`` (G) and ``'q_val'`` (Q).
        """
        overestim = []
        for path in paths:
            rewards = np.asarray(path["rewards"], dtype=np.float64).reshape(-1)
            terminals = np.asarray(path["terminals"], dtype=np.float64).reshape(-1)
            path_obs = ptu.from_numpy(path["observations"])
            path_actions = ptu.from_numpy(path["actions"])
            with torch.no_grad():
                q_vals = self.qf1(path_obs, path_actions)
                if self.num_qs > 1:
                    q_vals = torch.min(q_vals, self.qf2(path_obs, path_actions))
            q_vals = ptu.get_numpy(q_vals).reshape(-1).astype(np.float64)

            returns = np.empty_like(rewards)
            running = 0.0
            for t in reversed(range(rewards.size)):
                running = self.reward_scale * rewards[t] + self.discount * running * (1 - terminals[t])
                returns[t] = running

            overestim.append({
                'path_overestim': (q_vals - returns).tolist(),
                'gamm_return': returns.tolist(),
                'q_val': q_vals.tolist(),
            })
        return overestim

    def train_from_torch(self, batch):
        """Run one VAE, policy/alpha and Q-function update on ``batch``.

        ``batch`` must provide ``'rewards'``, ``'terminals'``,
        ``'observations'``, ``'actions'`` and ``'next_observations'`` tensors.
        On the first call after :meth:`end_epoch` (and on the very first call)
        diagnostics are written to ``self.eval_statistics``; this includes
        collecting up to 1000 environment steps with the current policy to
        measure Q-value overestimation.
        """
        self._current_epoch += 1

        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']

        # ---- VAE -----------------------------------------------------------
        self._train_vae(obs, actions)

        # ---- Policy and alpha loss -----------------------------------------
        new_obs_actions, policy_mean, policy_log_std, log_pi, *_ = self.policy(
            obs, reparameterize=True, return_log_prob=True,
        )

        if self.use_automatic_entropy_tuning:
            alpha_loss = (self.log_alpha * ((-log_pi) - self.target_entropy).detach()).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = 1

        if self.num_qs == 1:
            q_new_actions = self.qf1(obs, new_obs_actions)
        else:
            q_new_actions = torch.min(
                self.qf1(obs, new_obs_actions),
                self.qf2(obs, new_obs_actions),
            )

        policy_loss = (alpha * log_pi - q_new_actions).mean()

        if self._current_epoch < self.policy_eval_start:
            # Warm-up: maximise the likelihood of dataset actions instead of Q.
            policy_log_prob = self.policy.log_prob(obs, actions)
            policy_loss = (alpha * log_pi - policy_log_prob).mean()

        # ---- Bellman error --------------------------------------------------
        q1_pred = self.qf1(obs, actions)
        if self.num_qs > 1:
            q2_pred = self.qf2(obs, actions)

        new_next_actions, _, _, new_log_pi, *_ = self.policy(
            next_obs, reparameterize=False, return_log_prob=True,
        )

        if self.max_q_backup:
            # Back up the max over 10 policy samples at next_obs.
            next_actions_temp, _ = self._get_policy_actions(next_obs, num_actions=10, network=self.policy)
            target_q_values = self._get_tensor_values(
                next_obs, next_actions_temp, network=self.target_qf1).max(1)[0].view(-1, 1)
            if self.num_qs > 1:
                target_qf2_values = self._get_tensor_values(
                    next_obs, next_actions_temp, network=self.target_qf2).max(1)[0].view(-1, 1)
                target_q_values = torch.min(target_q_values, target_qf2_values)
        else:
            if self.num_qs == 1:
                target_q_values = self.target_qf1(next_obs, new_next_actions)
            else:
                target_q_values = torch.min(
                    self.target_qf1(next_obs, new_next_actions),
                    self.target_qf2(next_obs, new_next_actions),
                )
            if not self.deterministic_backup:
                target_q_values = target_q_values - alpha * new_log_pi

        q_target = self.reward_scale * rewards + (1. - terminals) * self.discount * target_q_values
        q_target = q_target.detach()

        qf1_bellman = self.qf_criterion(q1_pred, q_target)
        if self.num_qs > 1:
            qf2_bellman = self.qf_criterion(q2_pred, q_target)

        # ---- Conservative penalty ------------------------------------------
        samples = self._sample_penalty_actions(obs, next_obs, actions)

        qf1_penalty = self._conservative_penalty(self.qf1, q1_pred, obs, samples)
        min_qf1_loss = qf1_penalty['loss']
        qf1_loss = qf1_bellman + min_qf1_loss * self.min_q_weight

        if self.num_qs > 1:
            qf2_penalty = self._conservative_penalty(self.qf2, q2_pred, obs, samples)
            min_qf2_loss = qf2_penalty['loss']
            qf2_loss = qf2_bellman + min_qf2_loss * self.min_q_weight

        # ---- Update networks -----------------------------------------------
        self._num_policy_update_steps += 1
        self.policy_optimizer.zero_grad()
        policy_loss.backward(retain_graph=False)
        self.policy_optimizer.step()

        self._num_q_update_steps += 1
        # zero_grad also discards the Q gradients left by policy_loss.backward().
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward(retain_graph=True)
        self.qf1_optimizer.step()

        if self.num_qs > 1:
            self.qf2_optimizer.zero_grad()
            qf2_loss.backward(retain_graph=True)
            self.qf2_optimizer.step()

        # ---- Soft target updates -------------------------------------------
        # `_num_train_steps` is maintained by the TorchTrainer base class.
        if self._num_train_steps % self.target_update_period == 0:
            ptu.soft_update_from_to(
                self.qf1, self.target_qf1, self.soft_target_tau
            )
            if self.num_qs > 1:
                ptu.soft_update_from_to(
                    self.qf2, self.target_qf2, self.soft_target_tau
                )

        # ---- Diagnostics (one batch per epoch) -----------------------------
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            stats = self.eval_statistics

            def add_stats(name, tensor):
                stats.update(create_stats_ordered_dict(name, ptu.get_numpy(tensor)))

            add_stats('pi_ratio', samples['pi_ratio'])
            add_stats('npi_ratio', samples['npi_ratio'])
            add_stats('beta_ratio', samples['beta_ratio'])
            stats['QF1 Bellman Loss'] = np.mean(ptu.get_numpy(qf1_bellman))
            add_stats('log_beta_ratio', samples['log_beta_ratio'])
            add_stats('log_pi_ratio', samples['log_pi_ratio'])
            add_stats('log_npi_ratio', samples['log_npi_ratio'])

            # These three come from the last Q-function penalised
            # (qf2 when num_qs > 1).
            last_penalty = qf2_penalty if self.num_qs > 1 else qf1_penalty
            add_stats('logsum', last_penalty['logsum'])
            add_stats('nw_logsum', last_penalty['nw_logsum'])
            add_stats('logsum_ratio', last_penalty['logsum_ratio'])

            stats['1st min QF1 Loss'] = np.mean(ptu.get_numpy(qf1_penalty['s1']))
            stats['2nd min QF1 Loss'] = np.mean(ptu.get_numpy(qf1_penalty['s2']))
            stats['3rd min QF1 Loss'] = np.mean(ptu.get_numpy(qf1_penalty['s3']))
            stats['Total min QF1 Loss'] = np.mean(ptu.get_numpy(min_qf1_loss))
            stats['Total QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            if self.num_qs > 1:
                stats['Bellman QF2 Loss'] = np.mean(ptu.get_numpy(qf2_bellman))
                stats['Min QF2 Loss'] = np.mean(ptu.get_numpy(min_qf2_loss))
                stats['Total QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))

            if not self.discrete:
                add_stats('QF1 batch values', q1_pred)
                add_stats('QF1 random values', qf1_penalty['q_rand'])
                add_stats('QF1 pi values', qf1_penalty['q_curr'])
                add_stats('QF1 next_actions values', qf1_penalty['q_next'])
                add_stats('actions', actions)
                add_stats('rewards', rewards)

            stats['Num Q Updates'] = self._num_q_update_steps
            stats['Num Policy Updates'] = self._num_policy_update_steps
            stats['Policy Loss'] = np.mean(ptu.get_numpy(policy_loss))
            stats['Entropy'] = np.mean(ptu.get_numpy(-log_pi))
            add_stats('Q1 Predictions', q1_pred)
            if self.num_qs > 1:
                add_stats('Q2 Predictions', q2_pred)
            add_stats('Q Targets', q_target)

            add_stats('Log Pis', log_pi)
            if not self.discrete:
                add_stats('Policy mu', policy_mean)
                add_stats('Policy log std', policy_log_std)

            if self.use_automatic_entropy_tuning:
                stats['Alpha'] = alpha.item()
                stats['Alpha Loss'] = alpha_loss.item()

            # Roll out the current policy and compare its Q-values with the
            # observed discounted returns.
            self._reserve_path_collector = MdpPathCollector(
                env=self.env, policy=self.policy,
            )
            self._reserve_path_collector.update_policy(self.policy)
            eval_paths = self._reserve_path_collector.collect_new_paths(
                max_path_length=1000,
                num_steps=1000,
                discard_incomplete_paths=True,
            )
            overestim_info = self._compute_overestimation(eval_paths)
            all_overestim = [value for info in overestim_info for value in info['path_overestim']]
            # Skip the statistic when no complete path was collected.
            if all_overestim:
                stats.update(create_stats_ordered_dict('Overestimation', all_overestim))

        self._n_train_steps_total += 1

    def get_diagnostics(self):
        """Return the statistics dict filled by :meth:`train_from_torch`."""
        return self.eval_statistics

    def end_epoch(self, epoch):
        """Request fresh diagnostics on the next training step."""
        self._need_to_update_eval_statistics = True

    @property
    def networks(self):
        """All modules owned by this trainer (policy, behaviour model, Q-nets, VAE)."""
        base_list = [
            self.policy,
            self.behavior_poliy,
            self.qf1,
            self.qf2,
            self.target_qf1,
            self.target_qf2,
            self.vae,
        ]
        return base_list

    def get_snapshot(self):
        """Return the modules to checkpoint, keyed by attribute name."""
        return dict(
            policy=self.policy,
            behavior_poliy=self.behavior_poliy,
            qf1=self.qf1,
            qf2=self.qf2,
            target_qf1=self.target_qf1,
            target_qf2=self.target_qf2,
            vae=self.vae,
        )

    def set_snapshot(self, snapshot):
        """Replace the trainer's modules with those stored in ``snapshot``.

        NOTE(review): the optimisers built in ``__init__`` still hold the
        parameters of the previous modules, so further training after this
        call would not update the restored networks unless the optimisers
        are rebuilt.
        """
        self.policy = snapshot['policy']
        self.behavior_poliy = snapshot['behavior_poliy']
        self.qf1 = snapshot['qf1']
        self.qf2 = snapshot['qf2']
        self.target_qf1 = snapshot['target_qf1']
        self.target_qf2 = snapshot['target_qf2']
        self.vae = snapshot['vae']

