"""This module creates a Distral SAC model based on garage SAC."""

# yapf: disable
import copy

import numpy as np
import torch
import torch.nn.functional as F

from garage import obtain_evaluation_episodes, StepType, EpisodeBatch
from garage.np.algos import RLAlgorithm
from garage.sampler import RaySampler
from garage.torch import as_torch_dict, as_torch, global_device

from learning.utils import (get_path_policy_id, log_performance,
                            log_multitask_performance, log_wandb)

# yapf: enable


class DistralSQLDiscrete(RLAlgorithm):
    def __init__(
        self,
        env_spec,
        central_policy,
        policy,
        policies,
        replay_buffers,
        sampler,
        visualizer,
        get_stage_id,
        preproc_obs,
        *,  # Everything after this is numbers.
        two_column=False,
        initial_kl_coeff=[0.0],
        max_episode_length_eval=None,
        gradient_steps_per_itr,
        initial_log_entropy=0.0,
        discount=0.99,
        buffer_batch_size=64,
        min_buffer_size=int(1e4),
        target_update_tau=5e-3,
        central_lr=3e-4,
        model_lr=3e-4,
        reward_scale=1.0,
        optimizer=torch.optim.Adam,
        steps_per_epoch=1,
        num_evaluation_episodes=10,
        eval_env=None,
        use_deterministic_evaluation=True,
    ):
        self.get_stage_id = get_stage_id
        self.preproc_obs = preproc_obs or (lambda x: (x, x))

        self.central_policy = central_policy
        self.policy = policy
        self.policies = policies
        self.target_policies = [copy.deepcopy(policy) for policy in self.policies]
        self.replay_buffers = replay_buffers
        self._tau = target_update_tau

        self._two_column = two_column
        self._central_lr = central_lr
        self._model_lr = model_lr
        self._gradient_steps = gradient_steps_per_itr
        self._optimizer = optimizer
        self._num_evaluation_episodes = num_evaluation_episodes
        self._eval_env = eval_env

        self._min_buffer_size = min_buffer_size
        self._steps_per_epoch = steps_per_epoch
        self._buffer_batch_size = buffer_batch_size
        self._discount = discount
        self._reward_scale = reward_scale
        self.max_episode_length = env_spec.max_episode_length
        self._max_episode_length_eval = env_spec.max_episode_length

        if max_episode_length_eval is not None:
            self._max_episode_length_eval = max_episode_length_eval
        self._use_deterministic_evaluation = use_deterministic_evaluation

        self.env_spec = env_spec
        self.n_policies = len(policies)

        self._sampler = sampler
        self._visualizer = visualizer

        self._central_optimizer = self._optimizer(
            central_policy.parameters(), lr=self._central_lr
        )
        self._policy_optimizers = [
            self._optimizer(policy.parameters(), lr=self._model_lr)
            for policy in self.policies
        ]

        self._log_alpha = torch.ones(self.n_policies) * initial_log_entropy

        self._initial_kl_coeffs = initial_kl_coeff

        ### Uniform Beta
        if len(initial_kl_coeff) == 1:
            kl_coeffs = torch.full(self.n_policies, initial_kl_coeff[0])

        ### Task-wise Betas
        else:
            assert len(initial_kl_coeff) == self.n_policies
            kl_coeffs = torch.Tensor(initial_kl_coeff)

        self._log_kl_coeffs = kl_coeffs.log()
        self.episode_rewards = np.zeros(self.n_policies)
        self.success_rates = np.zeros(self.n_policies)
        self.stages_completed = np.zeros(self.n_policies)

    def train(self, trainer):
        """Run the full training loop, one evaluation per epoch.

        Each epoch performs ``self._steps_per_epoch`` calls to ``train_once``,
        evaluates and (optionally) renders the central policy, and logs the
        per-task training statistics averaged over all tasks.

        Args:
            trainer (Trainer): Gives the algorithm the access to
                :method:`~Trainer.step_epochs()`, which provides services
                such as snapshotting and sampler control.

        Returns:
            float: Mean of the value returned by the last evaluation.

        """
        if not self._eval_env:
            self._eval_env = trainer.get_env_copy()

        last_return = None
        for _ in trainer.step_epochs():
            for _ in range(self._steps_per_epoch):
                self.train_once(trainer)

            last_return = self._evaluate_policy(trainer.step_itr)
            videos = self._visualize_policy(trainer.step_itr)

            # Per-task running statistics, reported as an across-task mean.
            tracked = (
                ("AverageReturn", self.episode_rewards),
                ("SuccessRate", self.success_rates),
                ("StagesCompleted", self.stages_completed),
            )
            infos = {}
            for label, per_task in tracked:
                infos[label] = np.mean(
                    [np.mean(per_task[task]) for task in range(self.n_policies)]
                )
            infos["TotalEnvSteps"] = trainer.total_env_steps

            log_wandb(trainer.step_itr, infos, medias=videos, prefix="Train/")
            trainer.step_itr += 1

        return np.mean(last_return)

    def train_once(self, trainer):
        """Collect one batch of episodes, store them, and run gradient steps.

        The collected paths are routed to the replay buffer of the task they
        belong to, per-task return / success / stage statistics are refreshed,
        and the gradient-step budget is divided among tasks in proportion to
        the number of samples each task contributed.

        Args:
            trainer (Trainer): Provides ``obtain_samples`` and ``step_itr``;
                the sampled episodes are stored on ``trainer.step_episode``.

        """
        n_tasks = self.n_policies

        # While no buffer has reached the minimum size, request enough samples
        # to fill every buffer; afterwards pass None as the batch size.
        any_buffer_ready = np.any(
            [
                self.replay_buffers[task].n_transitions_stored
                >= self._min_buffer_size
                for task in range(n_tasks)
            ]
        )
        if any_buffer_ready:
            batch_size = None
        else:
            batch_size = int(self._min_buffer_size) * n_tasks

        agent_update = None
        if isinstance(self._sampler, RaySampler):
            # ray only supports CPU sampling
            with torch.no_grad():
                agent_update = copy.copy(self.policy)
                agent_update.policies = [
                    copy.deepcopy(net).cpu() for net in self.policies
                ]

        trainer.step_episode = trainer.obtain_samples(
            trainer.step_itr, batch_size, agent_update=agent_update
        )

        path_returns = [0] * n_tasks
        num_samples = [0] * n_tasks
        num_path_starts = [0] * n_tasks
        num_path_ends = [0] * n_tasks
        num_successes = [0] * n_tasks
        num_stages_completed = [0] * n_tasks

        # Accumulated across paths but not consumed further in this method.
        step_types = []

        for path in trainer.step_episode:
            task = get_path_policy_id(path)
            path_step_types = path["step_types"]
            step_types.extend(path_step_types)

            terminals = np.array(
                [kind == StepType.TERMINAL for kind in path_step_types]
            ).reshape(-1, 1)

            observations = self.preproc_obs(path["observations"])[0]
            rewards = path["rewards"].reshape(-1, 1)
            next_observations = self.preproc_obs(path["next_observations"])[0]
            self.replay_buffers[task].add_path(
                dict(
                    observation=observations,
                    action=path["actions"],
                    reward=rewards,
                    next_observation=next_observations,
                    terminal=terminals,
                )
            )

            path_returns[task] += sum(path["rewards"])
            num_samples[task] += len(path["actions"])
            num_path_starts[task] += np.sum(
                [kind == StepType.FIRST for kind in path_step_types]
            )
            num_path_ends[task] += np.sum(terminals)

            env_infos = path["env_infos"]
            if "success" in env_infos:
                num_successes[task] += env_infos["success"].any()
            if "stages_completed" in env_infos:
                num_stages_completed[task] += env_infos["stages_completed"][-1]

        for task in range(n_tasks):
            ended = num_path_ends[task]
            # Ad hoc: never divide the return by fewer than one path.
            num_paths = max(num_path_starts[task], ended, 1)
            self.episode_rewards[task] = path_returns[task] / num_paths
            # With no finished path the rate is the (zero) count itself.
            self.success_rates[task] = (
                num_successes[task] / ended if ended else ended
            )
            self.stages_completed[task] = (
                num_stages_completed[task] / ended if ended else ended
            )

        # Open question from the original author: should the loops be nested
        # the other way round (gradient step outside, policy id inside)?
        for task in range(n_tasks):
            num_grad_steps = int(
                self._gradient_steps / np.sum(num_samples) * num_samples[task]
            )
            central_losses = []
            model_losses = []
            for _ in range(num_grad_steps):
                central_loss, model_loss = self.train_policy_once(task)
                central_losses.append(central_loss)
                model_losses.append(model_loss)

            self._log_statistics(
                trainer.step_itr,
                task,
                central_losses,
                model_losses,
            )

    def train_policy_once(self, policy_id, itr=None, paths=None):
        """Complete 1 training iteration of SAC.

        Args:
            itr (int): Iteration number. This argument is deprecated.
            paths (list[dict]): A list of collected paths.
                This argument is deprecated.

        Returns:
            torch.Tensor: loss from central policy after optimization.
            torch.Tensor: loss from task-specific policy after optimization.

        """
        del itr
        del paths
        if self.replay_buffers[policy_id].n_transitions_stored >= (
            self._min_buffer_size
        ):
            all_obs, policy_samples = self.replay_buffers.sample_transitions(
                self._buffer_batch_size, policy_id
            )
            all_obs = [as_torch(obs) for obs in all_obs]
            policy_samples = as_torch_dict(policy_samples)
            central_loss, model_loss = self.optimize_policy(
                all_obs, policy_samples, policy_id=policy_id
            )

        else:
            central_loss, model_loss = (0.0, 0.0)

        self._update_targets(policy_id)

        return central_loss, model_loss

    def _learning_objective_1col(
        self, samples_data, new_actions, log_pi_new_actions, policy_id
    ):
        obs = samples_data["observation"]

        alpha = self._log_alpha[policy_id].exp()
        beta = self._log_kl_coeffs[policy_id].exp()
        paper_alpha = beta / (alpha + beta)
        paper_beta = 1.0 / (alpha + beta)

        return objective

    def _actor_objective_2col(
        self, samples_data, new_actions, log_pi_new_actions, policy_id
    ):
        obs = samples_data["observation"]

        alpha = self._log_alpha[policy_id].exp()
        beta = self._log_kl_coeffs[policy_id].exp()
        paper_alpha = beta / (alpha + beta)
        paper_beta = 1.0 / (alpha + beta)

        # Use advantage instead of regularized reward-to-go
        # Note: footnote #2
        central_dist = self.central_policy(obs)[0]
        q_dist = self._qfs[policy_id](obs)
        q = q_dist.gather(1, new_actions).flatten()
        v = (
            torch.logsumexp(
                paper_alpha * central_dist.probs + paper_beta * q_dist,
                dim=-1,
            )
            / paper_beta
        )
        advantage = q - v

        # M-2: return-to-go and V

        policy_objective = (log_pi_new_actions * advantage).mean()
        return policy_objective

    def _central_objective_1col(
        self, samples_data, new_actions, log_pi_new_actions, policy_id
    ):
        obs = samples_data["observation"]

        alpha = self._log_alpha[policy_id].exp()
        beta = self._log_kl_coeffs[policy_id].exp()
        paper_alpha = beta / (alpha + beta)
        paper_beta = 1.0 / (alpha + beta)

        # Use advantage instead of regularized reward-to-go
        central_dist = self.central_policy(obs)[0]
        q_dist = self._qfs[policy_id](obs)
        q = q_dist.gather(1, new_actions).flatten()
        v = (
            torch.logsumexp(
                paper_alpha * central_dist.probs + paper_beta * q_dist,
                dim=-1,
            )
            / paper_beta
        )
        advantage = q - v

        policy_objective = (log_pi_new_actions * advantage).mean()
        return policy_objective

    def _distillation_objective(self, policy_dist, central_dist, policy_id):
        policy_dist = policy_dist.__class__(
            policy_dist.mean.detach(),
            policy_dist.stddev.detach(),
        )
        kl = torch.distributions.kl.kl_divergence(policy_dist, central_dist)
        kl = kl.sum(-1).mean() * self._log_kl_coeffs[policy_id].exp()
        return kl

    def _q_objective(self, samples_data, policy_id):
        obs = samples_data["observation"]
        actions = samples_data["action"]
        rewards = samples_data["reward"].flatten()
        terminals = samples_data["terminal"].flatten()
        next_obs = samples_data["next_observation"]

        alpha = self._log_alpha[policy_id].exp()
        beta = self._log_kl_coeffs[policy_id].exp()
        paper_alpha = beta / (alpha + beta)
        paper_beta = 1.0 / (alpha + beta)

        q_pred = torch.gather(
            self.policies[policy_id](obs).logits,
            dim=-1,
            index=actions.unsqueeze(-1),
        )

        if self._two_column:
            # TODO
            pass

        with torch.no_grad():
            next_central_logp = self.central_policy(next_obs)[0].probs().log()
            next_q = self.target_policies[policy_id](next_obs).logits
            target_v = (
                torch.logsumexp(
                    paper_alpha * next_central_logp + paper_beta * next_q, dim=-1
                )
                / paper_beta
            )

            if self._two_column:
                # TODO
                pass

            q_target = (
                rewards * self._reward_scale
                + (1.0 - terminals) * self._discount * target_v
            )

        qf_loss = F.mse_loss(q_pred.flatten(), q_target)

        return qf_loss

    def _learning_objective(self, samples_data, policy_id):
        """Compute the twin Q-function/critic losses (SAC-style).

        NOTE(review): this method reads ``_get_log_alpha``, ``_qf1s``,
        ``_qf2s``, ``_target_qf1s``, ``_target_qf2s`` and, in two-column
        mode, ``_central_qf1``/``_central_qf2`` and their targets. None of
        these are assigned in the constructor of this class - confirm they
        are provided elsewhere before relying on this code path.

        Args:
            samples_data (dict): Transitions(S,A,R,S') that are sampled from
                the replay buffer. It should have the keys 'observation',
                'action', 'reward', 'terminal', and 'next_observations'.
            policy_id (int): Index of the task policy.

        Note:
            samples_data's entries should be torch.Tensor's with the following
            shapes:
                observation: :math:`(N, O^*)`
                action: :math:`(N, A^*)`
                reward: :math:`(N, 1)`
                terminal: :math:`(N, 1)`
                next_observation: :math:`(N, O^*)`

        Returns:
            torch.Tensor: loss from 1st q-function after optimization.
            torch.Tensor: loss from 2nd q-function after optimization.

        """
        obs = samples_data["observation"]
        actions = samples_data["action"]
        rewards = samples_data["reward"].flatten()
        terminals = samples_data["terminal"].flatten()
        next_obs = samples_data["next_observation"]
        with torch.no_grad():
            alpha = self._get_log_alpha(samples_data, policy_id).exp()
        beta = self._log_kl_coeffs[policy_id].exp()
        # Note: alpha and beta correspond to c_Ent and c_KL, respectively.

        q1_pred = self._qf1s[policy_id](obs, actions)
        q2_pred = self._qf2s[policy_id](obs, actions)
        if self._two_column:
            # Blend the central and task-specific critics, weighted by the
            # entropy and KL coefficients.
            central_q1 = self._central_qf1(obs, actions)
            central_q2 = self._central_qf2(obs, actions)
            q1_pred = (alpha * central_q1 + beta * q1_pred) / (alpha + beta)
            q2_pred = (alpha * central_q2 + beta * q2_pred) / (alpha + beta)

        # Sample next actions from the task policy and score them.
        new_next_actions_dist = self.policies[policy_id](next_obs)[0]
        (
            new_next_actions_pre_tanh,
            new_next_actions,
        ) = new_next_actions_dist.rsample_with_pre_tanh_value()
        new_log_pi = new_next_actions_dist.log_prob(
            value=new_next_actions, pre_tanh_value=new_next_actions_pre_tanh
        )

        # Also get log_pi from central policy
        # NOTE(review): this scores actions sampled from the central policy,
        # not the task policy's `new_next_actions` - confirm this is intended.
        central_dist = self.central_policy(next_obs)[0]
        (central_pre_tanh, central_actions) = central_dist.rsample_with_pre_tanh_value()
        central_log_pi = central_dist.log_prob(
            value=central_actions, pre_tanh_value=central_pre_tanh
        )

        target_q1 = self._target_qf1s[policy_id](next_obs, new_next_actions)
        target_q2 = self._target_qf2s[policy_id](next_obs, new_next_actions)
        if self._two_column:
            target_central_q1 = self._target_central_qf1(next_obs, new_next_actions)
            target_central_q2 = self._target_central_qf2(next_obs, new_next_actions)
            target_q1 = (alpha * target_central_q1 + beta * target_q1) / (alpha + beta)
            target_q2 = (alpha * target_central_q2 + beta * target_q2) / (alpha + beta)

        # Clipped double-Q value, regularised by the entropy term and by the
        # log-ratio between the central and task policies.
        target_q_values = (
            torch.min(target_q1, target_q2).flatten()
            - (alpha * new_log_pi)
            + (beta * (central_log_pi - new_log_pi))
        )

        # Terminal transitions do not bootstrap from the next state.
        with torch.no_grad():
            q_target = (
                rewards * self._reward_scale
                + (1.0 - terminals) * self._discount * target_q_values
            )
        qf1_loss = F.mse_loss(q1_pred.flatten(), q_target)
        qf2_loss = F.mse_loss(q2_pred.flatten(), q_target)

        return qf1_loss, qf2_loss

    def _update_targets(self, policy_id):
        """Update parameters in the target q-functions."""

        for t_param, param in zip(
            self.target_policies[policy_id].parameters(),
            self.policies[policy_id].parameters(),
        ):
            t_param.data.copy_(
                t_param.data * (1.0 - self._tau) + param.data * self._tau
            )

    def optimize_policy(self, all_obs, samples_data, policy_id):
        """Optimize the policy q_functions, and temperature coefficient.

        Args:
            samples_data (dict): Transitions(S,A,R,S') that are sampled from
                the replay buffer. It should have the keys 'observation',
                'action', 'reward', 'terminal', and 'next_observations'.

        Note:
            samples_data's entries should be torch.Tensor's with the following
            shapes:
                observation: :math:`(N, O^*)`
                action: :math:`(N, A^*)`
                reward: :math:`(N, 1)`
                terminal: :math:`(N, 1)`
                next_observation: :math:`(N, O^*)`

        Returns:
            torch.Tensor: loss from actor/policy network after optimization.
            torch.Tensor: loss from 1st q-function after optimization.
            torch.Tensor: loss from 2nd q-function after optimization.

        """
        if self._two_column:
            loss = self._learning_objective(samples_data, policy_id)

            self._central_optimizer.zero_grad()
            self._policy_optimizers[policy_id].zero_grad()
            loss.backward()
            self._central_optimizer.step()
            self._policy_optimizers[policy_id].step()
        else:
            policy_loss = self._q_objective(samples_data, policy_id)
            self._policy_optimizers[policy_id].zero_grad()
            policy_loss.backward()
            self._policy_optimizers[policy_id].step()

        return (loss.item(),)

    def _evaluate_policy(self, epoch):
        """Roll out the central policy on the evaluation env(s) and log it.

        With a list of evaluation environments the episode budget is split
        evenly (integer division) across them and the results are logged as
        multitask performance; otherwise a single environment is evaluated.

        Args:
            epoch (int): The current training epoch.

        Returns:
            The value returned by ``log_multitask_performance`` (list of
            environments) or ``log_performance`` (single environment).

        """

        def collect(env, num_episodes):
            # Evaluation uses the central policy rather than self.policy.
            return obtain_evaluation_episodes(
                self.central_policy,
                env,
                self._max_episode_length_eval,
                num_eps=num_episodes,
                deterministic=self._use_deterministic_evaluation,
            )

        if not isinstance(self._eval_env, list):
            eval_episodes = collect(self._eval_env, self._num_evaluation_episodes)
            return log_performance(epoch, eval_episodes, discount=self._discount)

        episodes_per_env = self._num_evaluation_episodes // len(self._eval_env)
        batches = [collect(env, episodes_per_env) for env in self._eval_env]
        merged = EpisodeBatch.concatenate(*batches)
        return log_multitask_performance(epoch, merged, discount=self._discount)

    def _log_statistics(self, step, policy_id, central_losses, model_losses):
        """Log one task's training statistics through ``log_wandb``.

        Args:
            step (int): Step index passed to the logger.
            policy_id (int): Index of the task whose statistics are logged.
            central_losses (list): Central-policy losses collected during this
                iteration's gradient steps; their mean is logged.
            model_losses (list): Task-policy losses collected during this
                iteration's gradient steps; their mean is logged.

        """

        infos = {}

        with torch.no_grad():
            # The entropy coefficients live in `_log_alpha` (set in __init__).
            infos["AlphaTemperature"] = self._log_alpha[policy_id].exp().mean().item()
            log_betas = self._log_kl_coeffs[policy_id].cpu().detach().numpy()
            infos["BetaKL"] = np.exp(log_betas)

        infos["CentralLoss"] = np.mean(central_losses)
        infos["ModelLoss"] = np.mean(model_losses)
        infos["ReplayBufferSize"] = self.replay_buffers[policy_id].n_transitions_stored
        infos["AverageReturn"] = np.mean(self.episode_rewards[policy_id])
        infos["SuccessRate"] = np.mean(self.success_rates[policy_id])
        infos["StagesCompleted"] = np.mean(self.stages_completed[policy_id])

        log_wandb(step, infos, prefix="Train/" + self.policies[policy_id].name + "/")

    @property
    def networks(self):
        """Return all the networks within the model.

        Returns:
            list: A list of networks.

        """
        nets = [
            self.central_policy,
            *self.policies,
            *self.target_policies,
        ]
        return nets

    def to(self, device=None):
        """Put all the networks within the model on device.

        Args:
            device (str): ID of GPU or CPU.

        """
        if device is None:
            device = global_device()
        for net in self.networks:
            net.to(device)
        self._log_kl_coeffs = self._log_kl_coeffs.to(device)

    def _visualize_policy(self, epoch):
        if not self._visualizer.do_visualization(epoch):
            return None

        self._visualizer.reset()
        if isinstance(self._eval_env, list):
            num_vis = self._visualizer.num_videos // len(self._eval_env)
            for env in self._eval_env:
                self._render_env(env, num_vis)
        else:
            self._render_env(self._eval_env, self._visualizer.num_videos)

        return self._visualizer.get_video()

    def _render_env(self, env, num_vis):
        for _ in range(num_vis):
            last_obs, _ = env.reset()
            self.central_policy.reset()  # self.policy.reset()
            episode_length = 0
            while episode_length < self._max_episode_length_eval:
                # a, agent_info = self.policy.get_action(last_obs)
                a, agent_info = self.central_policy.get_action(last_obs)
                if self._use_deterministic_evaluation and "mean" in agent_info:
                    a = agent_info["mean"]
                es = env.step(a)
                info = {"reward": es.reward, "step": episode_length}
                info.update(es.env_info)
                self._visualizer.add(env.render(mode="rgb_array"), info)
                episode_length += 1
                if es.last:
                    break
                last_obs = es.observation
