"""This modules creates a Distral SAC model based on garage SAC."""

# yapf: disable
import copy

import numpy as np
import torch
import torch.nn.functional as F

from garage import obtain_evaluation_episodes, StepType, EpisodeBatch
from garage.np.algos import RLAlgorithm
from garage.sampler import RaySampler
from garage.torch import as_torch_dict, as_torch, global_device

from learning.utils import (get_path_policy_id, log_performance,
                            log_multitask_performance, log_wandb)

# yapf: enable


class DistralSAC(RLAlgorithm):
    def __init__(
        self,
        env_spec,
        central_policy,
        policy,
        policies,
        central_qf1,
        central_qf2,
        qf1s,
        qf2s,
        replay_buffers,
        sampler,
        visualizer,
        get_stage_id,
        preproc_obs,
        *,  # Everything after this is numbers.
        two_column=False,
        initial_kl_coeff=[0.0],
        entropy_beta=False,
        max_episode_length_eval=None,
        gradient_steps_per_itr,
        fixed_alpha=None,
        target_entropy=None,
        initial_log_entropy=0.0,
        discount=0.99,
        buffer_batch_size=64,
        min_buffer_size=int(1e4),
        target_update_tau=5e-3,
        policy_lr=3e-4,
        qf_lr=3e-4,
        reward_scale=1.0,
        optimizer=torch.optim.Adam,
        steps_per_epoch=1,
        num_evaluation_episodes=10,
        eval_env=None,
        use_deterministic_evaluation=True,
    ):
        self.get_stage_id = get_stage_id
        self.preproc_obs = preproc_obs or (lambda x: (x, x))

        self._two_column = two_column
        self._central_qf1 = central_qf1
        self._central_qf2 = central_qf2
        self._qf1s = qf1s
        self._qf2s = qf2s
        self.replay_buffers = replay_buffers
        self._tau = target_update_tau
        self._policy_lr = policy_lr
        self._qf_lr = qf_lr
        self._initial_log_entropy = initial_log_entropy
        self._gradient_steps = gradient_steps_per_itr
        self._optimizer = optimizer
        self._num_evaluation_episodes = num_evaluation_episodes
        self._eval_env = eval_env

        self._min_buffer_size = min_buffer_size
        self._steps_per_epoch = steps_per_epoch
        self._buffer_batch_size = buffer_batch_size
        self._discount = discount
        self._reward_scale = reward_scale
        self.max_episode_length = env_spec.max_episode_length
        self._max_episode_length_eval = env_spec.max_episode_length

        if max_episode_length_eval is not None:
            self._max_episode_length_eval = max_episode_length_eval
        self._use_deterministic_evaluation = use_deterministic_evaluation

        self.central_policy = central_policy
        self.policies = policies
        self.policy = policy
        self.env_spec = env_spec
        self.n_policies = len(policies)

        self._sampler = sampler
        self._visualizer = visualizer

        # use 2 target q networks
        self._target_central_qf1 = copy.deepcopy(self._central_qf1)
        self._target_central_qf2 = copy.deepcopy(self._central_qf2)
        self._target_qf1s = [copy.deepcopy(qf) for qf in self._qf1s]
        self._target_qf2s = [copy.deepcopy(qf) for qf in self._qf2s]
        self._central_policy_optimizer = self._optimizer(
            central_policy.parameters(), lr=self._policy_lr
        )
        self._policy_optimizers = [
            self._optimizer(policy.parameters(), lr=self._policy_lr)
            for policy in self.policies
        ]
        self._qf1_optimizers = [
            self._optimizer(qf.parameters(), lr=self._qf_lr) for qf in self._qf1s
        ]
        if self._two_column:
            self._qf1_optimizers += [
                self._optimizer(self._central_qf1.parameters(), lr=self._qf_lr)
            ]
        self._qf2_optimizers = [
            self._optimizer(qf.parameters(), lr=self._qf_lr) for qf in self._qf2s
        ]
        if self._two_column:
            self._qf2_optimizers += [
                self._optimizer(self._central_qf2.parameters(), lr=self._qf_lr)
            ]
        # automatic entropy coefficient tuning
        self._use_automatic_entropy_tuning = fixed_alpha is None
        self._fixed_alpha = fixed_alpha
        if self._use_automatic_entropy_tuning:
            if target_entropy:
                self._target_entropy = target_entropy
            else:
                self._target_entropy = -np.prod(self.env_spec.action_space.shape).item()
            self._log_alphas = [
                torch.Tensor([self._initial_log_entropy]).requires_grad_()
                for _ in range(self.n_policies)
            ]
            self._alpha_optimizers = [
                self._optimizer([a], lr=self._policy_lr) for a in self._log_alphas
            ]
        else:
            self._log_alphas = [
                torch.Tensor([self._fixed_alpha]).log() for _ in range(self.n_policies)
            ]

        self._entropy_beta = entropy_beta

        self._initial_kl_coeffs = initial_kl_coeff
        self._log_kl_coeffs = torch.ones(self.n_policies)

        ### Uniform Beta
        if len(initial_kl_coeff) == 1:
            self._log_kl_coeffs = (self._log_kl_coeffs * initial_kl_coeff[0]).log()

        ### Task-wise Betas
        else:
            assert len(initial_kl_coeff) == self.n_policies
            self._log_kl_coeffs = torch.Tensor(initial_kl_coeff).log()

        self.episode_rewards = np.zeros(self.n_policies)
        self.success_rates = np.zeros(self.n_policies)
        self.stages_completed = np.zeros(self.n_policies)

    def train(self, trainer):
        """Run the epoch loop: train, evaluate, visualize and log.

        Args:
            trainer (Trainer): Gives the algorithm the access to
                :method:`~Trainer.step_epochs()`, which provides services
                such as snapshotting and sampler control.

        Returns:
            float: The average return in last epoch cycle.

        """
        if not self._eval_env:
            self._eval_env = trainer.get_env_copy()

        def _mean_over_tasks(per_task_values):
            # Average a per-task statistic over all task policies.
            return np.mean(
                [np.mean(per_task_values[i]) for i in range(self.n_policies)]
            )

        last_return = None
        for _ in trainer.step_epochs():
            for _ in range(self._steps_per_epoch):
                self.train_once(trainer)

            itr = trainer.step_itr
            last_return = self._evaluate_policy(itr)
            videos = self._visualize_policy(itr)

            infos = {
                "AverageReturn": _mean_over_tasks(self.episode_rewards),
                "SuccessRate": _mean_over_tasks(self.success_rates),
                "StagesCompleted": _mean_over_tasks(self.stages_completed),
                "TotalEnvSteps": trainer.total_env_steps,
            }
            log_wandb(trainer.step_itr, infos, medias=videos, prefix="Train/")
            trainer.step_itr += 1

        return np.mean(last_return)

    def train_once(self, trainer):
        """Collect episodes, store them per task and run gradient steps.

        The collected paths are written to the replay buffer of the policy
        that produced them, per-task episode statistics are refreshed, and
        the `gradient_steps_per_itr` budget is split across tasks in
        proportion to the number of samples each task contributed in this
        call.

        Args:
            trainer (Trainer): Provides `obtain_samples`, `step_itr` and
                receives the collected episodes in `step_episode`.

        """
        # Until at least one task buffer holds `min_buffer_size` transitions
        # request a warm-up batch sized for all tasks; afterwards pass None
        # as the batch size to `trainer.obtain_samples`.
        any_buffer_ready = any(
            self.replay_buffers[i].n_transitions_stored >= self._min_buffer_size
            for i in range(self.n_policies)
        )
        if any_buffer_ready:
            batch_size = None
        else:
            batch_size = int(self._min_buffer_size) * self.n_policies

        if isinstance(self._sampler, RaySampler):
            # ray only supports CPU sampling
            with torch.no_grad():
                agent_update = copy.copy(self.policy)
                agent_update.policies = [
                    copy.deepcopy(policy).cpu() for policy in self.policies
                ]
        else:
            agent_update = None
        trainer.step_episode = trainer.obtain_samples(
            trainer.step_itr, batch_size, agent_update=agent_update
        )

        # Per-task accumulators, indexed by policy id.
        path_returns = [0] * self.n_policies
        num_samples = [0] * self.n_policies
        num_path_starts = [0] * self.n_policies
        num_path_ends = [0] * self.n_policies
        num_successes = [0] * self.n_policies
        num_stages_completed = [0] * self.n_policies

        for path in trainer.step_episode:
            policy_id = get_path_policy_id(path)
            terminals = np.array(
                [step_type == StepType.TERMINAL for step_type in path["step_types"]]
            ).reshape(-1, 1)
            self.replay_buffers[policy_id].add_path(
                dict(
                    observation=self.preproc_obs(path["observations"])[0],
                    action=path["actions"],
                    reward=path["rewards"].reshape(-1, 1),
                    next_observation=self.preproc_obs(path["next_observations"])[0],
                    terminal=terminals,
                )
            )
            path_returns[policy_id] += sum(path["rewards"])
            num_samples[policy_id] += len(path["actions"])
            num_path_starts[policy_id] += np.sum(
                [step_type == StepType.FIRST for step_type in path["step_types"]]
            )
            num_path_ends[policy_id] += np.sum(terminals)
            env_infos = path["env_infos"]
            if "success" in env_infos:
                num_successes[policy_id] += env_infos["success"].any()
            if "stages_completed" in env_infos:
                num_stages_completed[policy_id] += env_infos["stages_completed"][-1]

        for i in range(self.n_policies):
            # AD-HOC: a path may be cut at either end of the batch, so use
            # whichever boundary count is larger (and never divide by zero).
            num_paths = max(num_path_starts[i], num_path_ends[i], 1)
            ended = num_path_ends[i]
            self.episode_rewards[i] = path_returns[i] / num_paths
            # Rates are only defined once at least one path has terminated.
            self.success_rates[i] = num_successes[i] / ended if ended else 0
            self.stages_completed[i] = (
                num_stages_completed[i] / ended if ended else 0
            )

        # TODO: decide on the loop order (policy id first, then gradient
        # steps, or the other way round).
        total_samples = int(np.sum(num_samples))
        for policy_id in range(self.n_policies):
            if total_samples > 0:
                num_grad_steps = int(
                    self._gradient_steps / total_samples * num_samples[policy_id]
                )
            else:
                # No samples were collected at all: skip optimization rather
                # than fail on a 0/0 division.
                num_grad_steps = 0

            policy_losses = []
            distillation_losses = []
            qf1_losses = []
            qf2_losses = []
            for _ in range(num_grad_steps):
                (
                    policy_loss,
                    distillation_loss,
                    qf1_loss,
                    qf2_loss,
                ) = self.train_policy_once(policy_id)
                policy_losses.append(policy_loss)
                distillation_losses.append(distillation_loss)
                qf1_losses.append(qf1_loss)
                qf2_losses.append(qf2_loss)

            self._log_statistics(
                trainer.step_itr,
                policy_id,
                policy_losses,
                distillation_losses,
                qf1_losses,
                qf2_losses,
            )

    def train_policy_once(self, policy_id, itr=None, paths=None):
        """Complete 1 training iteration of SAC.

        Args:
            itr (int): Iteration number. This argument is deprecated.
            paths (list[dict]): A list of collected paths.
                This argument is deprecated.

        Returns:
            torch.Tensor: loss from actor/policy network after optimization.
            torch.Tensor: loss from 1st q-function after optimization.
            torch.Tensor: loss from 2nd q-function after optimization.

        """
        del itr
        del paths
        if self.replay_buffers[policy_id].n_transitions_stored >= (
            self._min_buffer_size
        ):
            all_obs, policy_samples = self.replay_buffers.sample_transitions(
                self._buffer_batch_size, policy_id
            )
            all_obs = [as_torch(obs) for obs in all_obs]
            policy_samples = as_torch_dict(policy_samples)
            policy_loss, distillation_loss, qf1_loss, qf2_loss = self.optimize_policy(
                all_obs, policy_samples, policy_id=policy_id
            )

        else:
            policy_loss, distillation_loss, qf1_loss, qf2_loss = (0.0, 0.0, 0.0, 0.0)

        self._update_targets(policy_id)

        return policy_loss, distillation_loss, qf1_loss, qf2_loss

    def _get_log_alpha(self, samples_data, policy_id):
        """Return the value of log_alpha.

        Args:
            samples_data (dict): Transitions(S,A,R,S') that are sampled from
                the replay buffer. It should have the keys 'observation',
                'action', 'reward', 'terminal', and 'next_observations'.

        This function exists in case there are versions of sac that need
        access to a modified log_alpha, such as multi_task sac.

        Note:
            samples_data's entries should be torch.Tensor's with the following
            shapes:
                observation: :math:`(N, O^*)`
                action: :math:`(N, A^*)`
                reward: :math:`(N, 1)`
                terminal: :math:`(N, 1)`
                next_observation: :math:`(N, O^*)`

        Returns:
            torch.Tensor: log_alpha

        """
        del samples_data
        log_alpha = self._log_alphas[policy_id]
        return log_alpha

    def _temperature_objective(self, log_pi, samples_data, policy_id):
        """Compute the temperature/alpha coefficient loss.

        Args:
            log_pi(torch.Tensor): log probability of actions that are sampled
                from the replay buffer. Shape is (1, buffer_batch_size).
            samples_data (dict): Transitions(S,A,R,S') that are sampled from
                the replay buffer. It should have the keys 'observation',
                'action', 'reward', 'terminal', and 'next_observations'.

        Note:
            samples_data's entries should be torch.Tensor's with the following
            shapes:
                observation: :math:`(N, O^*)`
                action: :math:`(N, A^*)`
                reward: :math:`(N, 1)`
                terminal: :math:`(N, 1)`
                next_observation: :math:`(N, O^*)`

        Returns:
            torch.Tensor: the temperature/alpha coefficient loss.

        """
        alpha_loss = 0
        if self._use_automatic_entropy_tuning:
            alpha_loss = (
                -(self._get_log_alpha(samples_data, policy_id))
                * (log_pi.detach() + self._target_entropy)
            ).mean()
        return alpha_loss

    def _actor_objective(
        self, samples_data, new_actions, log_pi_new_actions, policy_id
    ):
        obs = samples_data["observation"]
        with torch.no_grad():
            alpha = self._get_log_alpha(samples_data, policy_id).exp()
        beta = self._log_kl_coeffs[policy_id].exp()
        # Note: alpha and beta correspond to c_Ent and c_KL, respectively.

        # TODO: two-column for Q value (alpha * Q0 + beta * Qi)
        q1 = self._qf1s[policy_id](obs, new_actions)
        q2 = self._qf2s[policy_id](obs, new_actions)
        if self._two_column:
            central_q1 = self._central_qf1(obs, new_actions)
            central_q2 = self._central_qf2(obs, new_actions)
            q1 = (alpha * central_q1 + beta * q1) / (alpha + beta)
            q2 = (alpha * central_q2 + beta * q2) / (alpha + beta)
        min_q_new_actions = torch.min(q1, q2)
        policy_objective = (
            ((alpha + beta) * log_pi_new_actions) - min_q_new_actions.flatten()
        ).mean()
        return policy_objective

    def _distillation_objective(self, policy_dist, central_dist, policy_id):
        policy_dist = policy_dist.__class__(
            policy_dist.mean.detach(),
            policy_dist.stddev.detach(),
        )
        kl = torch.distributions.kl.kl_divergence(policy_dist, central_dist)
        kl = kl.sum(-1).mean() * self._log_kl_coeffs[policy_id].exp()
        return kl

    def _critic_objective(self, samples_data, policy_id):
        """Compute the Q-function/critic loss.

        Args:
            samples_data (dict): Transitions(S,A,R,S') that are sampled from
                the replay buffer. It should have the keys 'observation',
                'action', 'reward', 'terminal', and 'next_observations'.

        Note:
            samples_data's entries should be torch.Tensor's with the following
            shapes:
                observation: :math:`(N, O^*)`
                action: :math:`(N, A^*)`
                reward: :math:`(N, 1)`
                terminal: :math:`(N, 1)`
                next_observation: :math:`(N, O^*)`

        Returns:
            torch.Tensor: loss from 1st q-function after optimization.
            torch.Tensor: loss from 2nd q-function after optimization.

        """
        obs = samples_data["observation"]
        actions = samples_data["action"]
        rewards = samples_data["reward"].flatten()
        terminals = samples_data["terminal"].flatten()
        next_obs = samples_data["next_observation"]
        with torch.no_grad():
            alpha = self._get_log_alpha(samples_data, policy_id).exp()
        beta = self._log_kl_coeffs[policy_id].exp()
        # Note: alpha and beta correspond to c_Ent and c_KL, respectively.

        q1_pred = self._qf1s[policy_id](obs, actions)
        q2_pred = self._qf2s[policy_id](obs, actions)
        if self._two_column:
            central_q1 = self._central_qf1(obs, actions)
            central_q2 = self._central_qf2(obs, actions)
            q1_pred = (alpha * central_q1 + beta * q1_pred) / (alpha + beta)
            q2_pred = (alpha * central_q2 + beta * q2_pred) / (alpha + beta)

        new_next_actions_dist = self.policies[policy_id](next_obs)[0]
        (
            new_next_actions_pre_tanh,
            new_next_actions,
        ) = new_next_actions_dist.rsample_with_pre_tanh_value()
        new_log_pi = new_next_actions_dist.log_prob(
            value=new_next_actions, pre_tanh_value=new_next_actions_pre_tanh
        )

        # Also get log_pi from central policy
        central_dist = self.central_policy(next_obs)[0]
        (central_pre_tanh, central_actions) = central_dist.rsample_with_pre_tanh_value()
        central_log_pi = central_dist.log_prob(
            value=central_actions, pre_tanh_value=central_pre_tanh
        )

        target_q1 = self._target_qf1s[policy_id](next_obs, new_next_actions)
        target_q2 = self._target_qf2s[policy_id](next_obs, new_next_actions)
        if self._two_column:
            target_central_q1 = self._target_central_qf1(next_obs, new_next_actions)
            target_central_q2 = self._target_central_qf2(next_obs, new_next_actions)
            target_q1 = (alpha * target_central_q1 + beta * target_q1) / (alpha + beta)
            target_q2 = (alpha * target_central_q2 + beta * target_q2) / (alpha + beta)

        target_q_values = (
            torch.min(target_q1, target_q2).flatten()
            - (alpha * new_log_pi)
            + (beta * (central_log_pi - new_log_pi))
        )
        # Original code from fb-mtrl
        # TODO: task-dependent beta
        # --------------------------
        # agent_alpha = self.agent.get_alpha(batch.task_obs).detach()
        # alpha_from_paper = self.distral_alpha / (self.distral_alpha + agent_alpha)
        # beta_from_paper = 1.0 / (self.distral_alpha + agent_alpha)
        # return (
        #     torch.min(target_Q1, target_Q2)
        #     + (alpha_from_paper * distral_log_pi - log_pi) / beta_from_paper
        # )

        with torch.no_grad():
            q_target = (
                rewards * self._reward_scale
                + (1.0 - terminals) * self._discount * target_q_values
            )
        qf1_loss = F.mse_loss(q1_pred.flatten(), q_target)
        qf2_loss = F.mse_loss(q2_pred.flatten(), q_target)

        return qf1_loss, qf2_loss

    def _update_targets(self, policy_id):
        """Update parameters in the target q-functions."""

        target_qfs = [self._target_qf1s[policy_id], self._target_qf2s[policy_id]]
        qfs = [self._qf1s[policy_id], self._qf2s[policy_id]]
        for target_qf, qf in zip(target_qfs, qfs):
            for t_param, param in zip(target_qf.parameters(), qf.parameters()):
                t_param.data.copy_(
                    t_param.data * (1.0 - self._tau) + param.data * self._tau
                )

    def optimize_policy(self, all_obs, samples_data, policy_id, log_betas=[]):
        """Optimize the policy q_functions, and temperature coefficient.

        Args:
            samples_data (dict): Transitions(S,A,R,S') that are sampled from
                the replay buffer. It should have the keys 'observation',
                'action', 'reward', 'terminal', and 'next_observations'.

        Note:
            samples_data's entries should be torch.Tensor's with the following
            shapes:
                observation: :math:`(N, O^*)`
                action: :math:`(N, A^*)`
                reward: :math:`(N, 1)`
                terminal: :math:`(N, 1)`
                next_observation: :math:`(N, O^*)`

        Returns:
            torch.Tensor: loss from actor/policy network after optimization.
            torch.Tensor: loss from 1st q-function after optimization.
            torch.Tensor: loss from 2nd q-function after optimization.

        """
        obs = samples_data["observation"]
        qf1_loss, qf2_loss = self._critic_objective(samples_data, policy_id)

        self._qf1_optimizers[policy_id].zero_grad()
        qf1_loss.backward()
        self._qf1_optimizers[policy_id].step()

        self._qf2_optimizers[policy_id].zero_grad()
        qf2_loss.backward()
        self._qf2_optimizers[policy_id].step()

        action_dists = self.policies[policy_id](obs)[0]
        new_actions_pre_tanh, new_actions = action_dists.rsample_with_pre_tanh_value()
        log_pi_new_actions = action_dists.log_prob(
            value=new_actions, pre_tanh_value=new_actions_pre_tanh
        )

        policy_loss = self._actor_objective(
            samples_data, new_actions, log_pi_new_actions, policy_id
        )

        self._policy_optimizers[policy_id].zero_grad()
        policy_loss.backward()
        self._policy_optimizers[policy_id].step()

        # Update central policy
        # TODO: when to compute this loss? (before or after policy update?)
        central_dist = self.central_policy(obs)[0]
        # with torch.no_grad():
        #     policy_dist = self.policies[policy_id](obs)[0]
        # distillation_loss = self._distillation_objective(
        #     policy_dist, central_dist, policy_id
        # )
        distillation_loss = self._distillation_objective(
            action_dists, central_dist, policy_id
        )

        self._central_policy_optimizer.zero_grad()
        distillation_loss.backward()
        self._central_policy_optimizer.step()

        if self._use_automatic_entropy_tuning:
            alpha_loss = self._temperature_objective(
                log_pi_new_actions, samples_data, policy_id
            )
            self._alpha_optimizers[policy_id].zero_grad()
            alpha_loss.backward()
            self._alpha_optimizers[policy_id].step()

        return (
            policy_loss.item(),
            distillation_loss.item(),
            qf1_loss.item(),
            qf2_loss.item(),
        )

    def _evaluate_policy(self, epoch):
        """Roll out the central policy on the evaluation environment(s).

            With a list of environments the episode budget is split evenly
            (integer division) and the results are logged as multi-task
            performance; otherwise single-task performance is logged.

        Args:
            epoch (int): The current training epoch.

        Returns:
            float: The value returned by the performance logger for the
                collected evaluation episodes.

        """

        def _collect(env, num_eps):
            # Evaluation always uses the central policy, not self.policy.
            return obtain_evaluation_episodes(
                self.central_policy,
                env,
                self._max_episode_length_eval,
                num_eps=num_eps,
                deterministic=self._use_deterministic_evaluation,
            )

        if not isinstance(self._eval_env, list):
            episodes = _collect(self._eval_env, self._num_evaluation_episodes)
            return log_performance(epoch, episodes, discount=self._discount)

        eps_per_env = self._num_evaluation_episodes // len(self._eval_env)
        per_env_batches = [_collect(env, eps_per_env) for env in self._eval_env]
        merged = EpisodeBatch.concatenate(*per_env_batches)
        return log_multitask_performance(epoch, merged, discount=self._discount)

    def _log_statistics(
        self,
        step,
        policy_id,
        policy_losses,
        distillation_losses,
        qf1_losses,
        qf2_losses,
    ):
        """Log the training statistics of one task via `log_wandb`.

        Args:
            step (int): Step passed on to the logger.
            policy_id (int): Index of the task policy being reported.
            policy_losses (list[float]): Actor losses of this iteration.
            distillation_losses (list[float]): Distillation losses of this
                iteration.
            qf1_losses (list[float]): Losses of the 1st qf/critic network.
            qf2_losses (list[float]): Losses of the 2nd qf/critic network.

        """
        with torch.no_grad():
            alpha = self._log_alphas[policy_id].exp().mean().item()
            log_beta = self._log_kl_coeffs[policy_id].cpu().detach().numpy()
            beta = np.exp(log_beta)

        task_buffer = self.replay_buffers[policy_id]
        infos = {
            "AlphaTemperature": alpha,
            "BetaKL": beta,
            "PolicyLoss": np.mean(policy_losses),
            "DistillationLoss": np.mean(distillation_losses),
            "Qf1Loss": np.mean(qf1_losses),
            "Qf2Loss": np.mean(qf2_losses),
            "ReplayBufferSize": task_buffer.n_transitions_stored,
            "AverageReturn": np.mean(self.episode_rewards[policy_id]),
            "SuccessRate": np.mean(self.success_rates[policy_id]),
            "StagesCompleted": np.mean(self.stages_completed[policy_id]),
        }

        task_prefix = "Train/" + self.policies[policy_id].name + "/"
        log_wandb(step, infos, prefix=task_prefix)

    @property
    def networks(self):
        """Return all the networks within the model.

        Returns:
            list: A list of networks.

        """
        nets = [
            self.central_policy,
            *self.policies,
            *self._qf1s,
            *self._qf2s,
            *self._target_qf1s,
            *self._target_qf2s,
        ]
        if self._two_column:
            nets += [
                self._central_qf1,
                self._central_qf2,
                self._target_central_qf1,
                self._target_central_qf2,
            ]
        return nets

    def to(self, device=None):
        """Put all the networks within the model on device.

        Args:
            device (str): ID of GPU or CPU.

        """
        if device is None:
            device = global_device()
        for net in self.networks:
            net.to(device)
        if not self._use_automatic_entropy_tuning:
            self._log_alphas = [
                torch.Tensor([self._fixed_alpha]).log().to(device)
                for _ in range(self.n_policies)
            ]
        else:
            self._log_alphas = [
                (torch.Tensor([self._initial_log_entropy]).to(device).requires_grad_())
                for _ in range(self.n_policies)
            ]
            self._alpha_optimizers = [
                self._optimizer([a], lr=self._policy_lr) for a in self._log_alphas
            ]

        self._log_kl_coeffs = self._log_kl_coeffs.to(device)

    def _visualize_policy(self, epoch):
        if not self._visualizer.do_visualization(epoch):
            return None

        self._visualizer.reset()
        if isinstance(self._eval_env, list):
            num_vis = self._visualizer.num_videos // len(self._eval_env)
            for env in self._eval_env:
                self._render_env(env, num_vis)
        else:
            self._render_env(self._eval_env, self._visualizer.num_videos)

        return self._visualizer.get_video()

    def _render_env(self, env, num_vis):
        for _ in range(num_vis):
            last_obs, _ = env.reset()
            self.central_policy.reset()  # self.policy.reset()
            episode_length = 0
            while episode_length < self._max_episode_length_eval:
                # a, agent_info = self.policy.get_action(last_obs)
                a, agent_info = self.central_policy.get_action(last_obs)
                if self._use_deterministic_evaluation and "mean" in agent_info:
                    a = agent_info["mean"]
                es = env.step(a)
                info = {"reward": es.reward, "step": episode_length}
                info.update(es.env_info)
                self._visualizer.add(env.render(mode="rgb_array"), info)
                episode_length += 1
                if es.last:
                    break
                last_obs = es.observation
