"""This modules creates a sac model in PyTorch."""
# yapf: disable
from collections import deque
import copy

from dowel import tabular
import numpy as np
import torch
import torch.nn.functional as F

from garage import log_performance, obtain_evaluation_episodes, StepType
from garage.torch import as_torch_dict, as_torch, global_device

from learning.utils import get_path_policy_id
from .dnc_sac import DnCSAC

# yapf: enable

### DnC with high level policy
class HDnCSAC(DnCSAC):
    """Divide-and-Conquer SAC with an additional high-level Q-function.

    Each low-level policy ``i`` owns its own pair of Q-functions, target
    Q-functions, optimizers, entropy temperature and replay buffer
    (``replay_buffers[i]``). Low-level policies are trained with the SAC
    objective plus a mutual-KL penalty towards the other low-level
    policies. A separate high-level Q-function ``hl_qf``, whose output is
    indexed by policy id, is regressed onto returns drawn with
    ``replay_buffers.sample_contexts``.

    Note:
        This constructor does not call ``DnCSAC.__init__``; all state used
        by this class is initialized here.
    """

    def __init__(
        self,
        env_spec,
        policy,
        hl_policy,
        ll_policies,
        hl_qf,
        qf1s,
        qf2s,
        replay_buffers,
        sampler,
        *,  # Everything after this is numbers.
        kl_coeff=0.01,
        max_episode_length_eval=None,
        gradient_steps_per_itr,
        fixed_alpha=None,
        target_entropy=None,
        initial_log_entropy=0.0,
        discount=0.99,
        buffer_batch_size=64,
        min_buffer_size=int(1e4),
        target_update_tau=5e-3,
        policy_lr=3e-4,
        qf_lr=3e-4,
        reward_scale=1.0,
        optimizer=torch.optim.Adam,
        steps_per_epoch=1,
        num_evaluation_episodes=10,
        eval_env=None,
        use_deterministic_evaluation=True
    ):
        """Build optimizers, target networks and temperatures per policy.

        Args:
            env_spec (EnvSpec): Environment specification.
            policy: Policy used for evaluation (``_evaluate_policy``).
            hl_policy: High-level policy (included in ``networks``).
            ll_policies (list): Low-level policies, one per sub-task.
            hl_qf (torch.nn.Module): High-level Q-function over policy ids.
            qf1s (list): First Q-function of each low-level policy.
            qf2s (list): Second Q-function of each low-level policy.
            replay_buffers: Indexable collection of per-policy replay
                buffers that also exposes ``sample_contexts`` and
                ``sample_transitions``.
            sampler: Sampler stored for the trainer.
            kl_coeff (float): Weight of the mutual-KL penalty.
            max_episode_length_eval (int): Evaluation episode length;
                defaults to ``env_spec.max_episode_length``.
            gradient_steps_per_itr (int): Gradient budget per sampling round.
            fixed_alpha (float): Fixed entropy temperature; ``None`` enables
                automatic entropy tuning.
            target_entropy (float): Entropy target for automatic tuning;
                ``None`` means ``-prod(action_space.shape)``.
            initial_log_entropy (float): Initial ``log(alpha)``.
            discount (float): Discount factor.
            buffer_batch_size (int): Minibatch size sampled from buffers.
            min_buffer_size (int): Sample count requested per round until
                some buffer holds this many transitions.
            target_update_tau (float): Polyak coefficient for targets.
            policy_lr (float): Learning rate of policies and temperatures.
            qf_lr (float): Learning rate of all Q-functions.
            reward_scale (float): Multiplier applied to rewards.
            optimizer (type): Optimizer class.
            steps_per_epoch (int): Sampling/training rounds per epoch.
            num_evaluation_episodes (int): Evaluation episodes per epoch.
            eval_env (Environment): Evaluation environment; a copy of the
                trainer's environment is used when not given.
            use_deterministic_evaluation (bool): Evaluate deterministically.
        """
        self._hl_qf = hl_qf
        self._qf1s = qf1s
        self._qf2s = qf2s
        self.replay_buffers = replay_buffers
        self._tau = target_update_tau
        self._policy_lr = policy_lr
        self._qf_lr = qf_lr
        self._initial_log_entropy = initial_log_entropy
        self._gradient_steps = gradient_steps_per_itr
        self._optimizer = optimizer
        self._num_evaluation_episodes = num_evaluation_episodes
        self._eval_env = eval_env

        self._kl_coeff = kl_coeff
        self._min_buffer_size = min_buffer_size
        self._steps_per_epoch = steps_per_epoch
        self._buffer_batch_size = buffer_batch_size
        self._discount = discount
        self._reward_scale = reward_scale
        self.max_episode_length = env_spec.max_episode_length
        self._max_episode_length_eval = env_spec.max_episode_length

        if max_episode_length_eval is not None:
            self._max_episode_length_eval = max_episode_length_eval
        self._use_deterministic_evaluation = use_deterministic_evaluation

        self.policies = ll_policies
        self.hl_policy = hl_policy
        self.policy = policy
        self.env_spec = env_spec
        self.n_policies = len(self.policies)

        self._sampler = sampler

        # use 2 target q networks per low-level policy
        self._target_qf1s = [copy.deepcopy(qf) for qf in self._qf1s]
        self._target_qf2s = [copy.deepcopy(qf) for qf in self._qf2s]
        self._hl_qf_optimizer = self._optimizer(
            self._hl_qf.parameters(), lr=self._qf_lr
        )
        self._policy_optimizers = [
            self._optimizer(policy.parameters(), lr=self._policy_lr)
            for policy in self.policies
        ]
        self._qf1_optimizers = [
            self._optimizer(qf.parameters(), lr=self._qf_lr) for qf in self._qf1s
        ]
        self._qf2_optimizers = [
            self._optimizer(qf.parameters(), lr=self._qf_lr) for qf in self._qf2s
        ]
        # automatic entropy coefficient tuning
        self._use_automatic_entropy_tuning = fixed_alpha is None
        self._fixed_alpha = fixed_alpha
        if self._use_automatic_entropy_tuning:
            # Compare against None so an explicit target of 0.0 is honored.
            if target_entropy is not None:
                self._target_entropy = target_entropy
            else:
                self._target_entropy = -np.prod(self.env_spec.action_space.shape).item()
            self._log_alphas = [
                torch.Tensor([self._initial_log_entropy]).requires_grad_()
                for _ in range(self.n_policies)
            ]
            self._alpha_optimizers = [
                self._optimizer([a], lr=self._policy_lr) for a in self._log_alphas
            ]
        else:
            self._log_alphas = [
                torch.Tensor([self._fixed_alpha]).log() for _ in range(self.n_policies)
            ]

        # Sliding window (last 30 sampling rounds) of mean training returns.
        self.episode_rewards = [deque(maxlen=30) for _ in range(self.n_policies)]
        self.num_trajectories = [0] * self.n_policies

    def train(self, trainer):
        """Obtain samples and train the high- and low-level learners.

        Args:
            trainer (Trainer): Gives the algorithm the access to
                :method:`~Trainer.step_epochs()`, which provides services
                such as snapshotting and sampler control.

        Returns:
            float: The average evaluation return of the last epoch.

        """
        if not self._eval_env:
            self._eval_env = trainer.get_env_copy()
        last_return = None
        for _ in trainer.step_epochs():
            # Per-policy loss histories for the whole epoch; index = policy id.
            policy_losses = [[] for _ in range(self.n_policies)]
            kl_penalties = [[] for _ in range(self.n_policies)]
            qf1_losses = [[] for _ in range(self.n_policies)]
            qf2_losses = [[] for _ in range(self.n_policies)]

            for _ in range(self._steps_per_epoch):
                # Until some buffer holds min_buffer_size transitions, ask the
                # sampler for exactly that many samples.
                if any(
                    self.replay_buffers[i].n_transitions_stored
                    >= self._min_buffer_size
                    for i in range(self.n_policies)
                ):
                    batch_size = None
                else:
                    batch_size = int(self._min_buffer_size)

                trainer.step_episode = trainer.obtain_samples(
                    trainer.step_itr, batch_size
                )
                path_returns = [[] for _ in range(self.n_policies)]
                num_samples = [0] * self.n_policies

                for path in trainer.step_episode:
                    policy_id = get_path_policy_id(path)
                    self.replay_buffers[policy_id].add_path(
                        dict(
                            observation=path["observations"],
                            action=path["actions"],
                            reward=path["rewards"].reshape(-1, 1),
                            next_observation=path["next_observations"],
                            terminal=np.array(
                                [
                                    step_type == StepType.TERMINAL
                                    for step_type in path["step_types"]
                                ]
                            ).reshape(-1, 1),
                            policy_id=path["agent_infos"]["policy_id"],
                        )
                    )
                    path_returns[policy_id].append(sum(path["rewards"]))
                    num_samples[policy_id] += len(path["actions"])
                    self.num_trajectories[policy_id] += 1

                # Only record a round's mean return for policies that produced
                # episodes; a NaN entry would evict real returns from the
                # fixed-length window.
                for i in range(self.n_policies):
                    if path_returns[i]:
                        self.episode_rewards[i].append(np.mean(path_returns[i]))

                for _ in range(self._gradient_steps):
                    self.train_once_hl()

                ### Train LL policies
                ### ASDF Which way should the for loop go?  Policy id then gradient steps?
                # Split the gradient budget in proportion to each policy's share
                # of new samples. Integer arithmetic avoids float truncation
                # error; a round with no samples yields no steps instead of a
                # division by zero.
                total_samples = sum(num_samples)
                for policy_id in range(self.n_policies):
                    if total_samples == 0:
                        num_grad_steps = 0
                    else:
                        num_grad_steps = int(
                            self._gradient_steps
                            * num_samples[policy_id]
                            // total_samples
                        )
                    for _ in range(num_grad_steps):
                        policy_loss, kl_penalty, qf1_loss, qf2_loss = self.train_once(
                            policy_id
                        )
                        policy_losses[policy_id].append(float(policy_loss))
                        kl_penalties[policy_id].append(float(kl_penalty))
                        qf1_losses[policy_id].append(float(qf1_loss))
                        qf2_losses[policy_id].append(float(qf2_loss))
            last_return = self._evaluate_policy(trainer.step_itr)
            self._log_statistics(
                [self._mean_or_zero(v) for v in policy_losses],
                [self._mean_or_zero(v) for v in kl_penalties],
                [self._mean_or_zero(v) for v in qf1_losses],
                [self._mean_or_zero(v) for v in qf2_losses],
            )
            self.num_trajectories = [0] * self.n_policies

            all_episode_rewards = []
            for i in range(self.n_policies):
                all_episode_rewards.extend(self.episode_rewards[i])
            tabular.record(
                "Average/TrainAverageReturn",
                np.nanmean(all_episode_rewards),
            )
            tabular.record("TotalEnvSteps", trainer.total_env_steps)
            trainer.step_itr += 1

        return np.mean(last_return)

    @staticmethod
    def _mean_or_zero(values):
        """Return the mean of ``values`` as a float, or 0.0 when it is empty.

        0.0 matches the placeholder losses ``train_once`` reports when a
        policy's buffer is too small to train.
        """
        return float(np.mean(values)) if values else 0.0

    def train_once_hl(self):
        """Complete one regression step of the high-level Q-function.

        Samples contexts, returns and policy ids from the replay buffers,
        selects the Q-value of each sampled policy id and minimizes the MSE
        to the sampled returns.

        Returns:
            torch.Tensor: The high-level Q-function loss.

        """
        samples = self.replay_buffers.sample_contexts(self._buffer_batch_size)
        samples = as_torch_dict(samples)

        contexts = samples["contexts"]
        returns = samples["returns"]
        policy_ids = samples["policy_ids"]

        # Pick, per row, the Q-value of the policy that was actually used.
        qf_pred = torch.gather(
            self._hl_qf(contexts), dim=1, index=policy_ids.long().unsqueeze(1)
        )
        qf_loss = F.mse_loss(qf_pred.flatten(), returns)

        self._hl_qf_optimizer.zero_grad()
        qf_loss.backward()
        self._hl_qf_optimizer.step()

        return qf_loss

    def train_once(self, policy_id, itr=None, paths=None):
        """Complete 1 SAC training iteration for one low-level policy.

        Args:
            policy_id (int): Index of the low-level policy to train.
            itr (int): Iteration number. This argument is deprecated.
            paths (list[dict]): A list of collected paths.
                This argument is deprecated.

        Returns:
            torch.Tensor: loss from actor/policy network (incl. KL penalty).
            torch.Tensor: mutual-KL penalty term.
            torch.Tensor: loss from 1st q-function after optimization.
            torch.Tensor: loss from 2nd q-function after optimization.
            All four are zero tensors when the policy's buffer holds fewer
            than ``min_buffer_size / n_policies`` transitions.

        """
        del itr
        del paths
        ### ASDF self._min_buffer_size / self.n_policies
        if self.replay_buffers[policy_id].n_transitions_stored >= (
            self._min_buffer_size / self.n_policies
        ):
            all_obs, policy_samples = self.replay_buffers.sample_transitions(
                self._buffer_batch_size, policy_id
            )
            all_obs = as_torch(all_obs)
            policy_samples = as_torch_dict(policy_samples)
            policy_loss, kl_penalty, qf1_loss, qf2_loss = self.optimize_policy(
                all_obs, policy_samples, policy_id=policy_id
            )
        else:
            policy_loss, kl_penalty, qf1_loss, qf2_loss = (
                torch.Tensor([0]),
                torch.Tensor([0]),
                torch.Tensor([0]),
                torch.Tensor([0]),
            )

        self._update_targets(policy_id)

        return policy_loss, kl_penalty, qf1_loss, qf2_loss

    def _get_log_alpha(self, samples_data, policy_id):
        """Return the value of log_alpha for one low-level policy.

        Args:
            samples_data (dict): Transitions(S,A,R,S') that are sampled from
                the replay buffer. Unused here.
            policy_id (int): Index of the low-level policy.

        This function exists in case there are versions of sac that need
        access to a modified log_alpha, such as multi_task sac.

        Returns:
            torch.Tensor: log_alpha

        """
        del samples_data
        log_alpha = self._log_alphas[policy_id]
        return log_alpha

    def _temperature_objective(self, log_pi, samples_data, policy_id):
        """Compute the temperature/alpha coefficient loss.

        Args:
            log_pi(torch.Tensor): log probability of actions that are sampled
                from the policy on buffer observations.
            samples_data (dict): Transitions(S,A,R,S') that are sampled from
                the replay buffer.
            policy_id (int): Index of the low-level policy.

        Returns:
            torch.Tensor: the temperature/alpha coefficient loss (the
            integer 0 when automatic entropy tuning is disabled).

        """
        alpha_loss = 0
        if self._use_automatic_entropy_tuning:
            alpha_loss = (
                -(self._get_log_alpha(samples_data, policy_id))
                * (log_pi.detach() + self._target_entropy)
            ).mean()
        return alpha_loss

    def _actor_objective(
        self, samples_data, new_actions, log_pi_new_actions, policy_id
    ):
        """Compute the SAC Policy/Actor loss for one low-level policy.

        Args:
            samples_data (dict): Transitions(S,A,R,S') that are sampled from
                the replay buffer; only ``'observation'`` is read.
            new_actions (torch.Tensor): Actions resampled from the policy on
                the sampled observations.
            log_pi_new_actions (torch.Tensor): Log probability of the new
                actions under the distribution they were sampled from.
            policy_id (int): Index of the low-level policy.

        Returns:
            torch.Tensor: mean of ``alpha * log_pi - min(Q1, Q2)``.

        """
        obs = samples_data["observation"]
        with torch.no_grad():
            alpha = self._get_log_alpha(samples_data, policy_id).exp()
        min_q_new_actions = torch.min(
            self._qf1s[policy_id](obs, new_actions),
            self._qf2s[policy_id](obs, new_actions),
        )
        policy_objective = (
            (alpha * log_pi_new_actions) - min_q_new_actions.flatten()
        ).mean()
        return policy_objective

    def _compute_mutual_kl(self, all_obs, policy_id):
        """Sum of KL divergences from every other policy to ``policy_id``.

        For each low-level policy ``i != policy_id`` this accumulates
        ``KL(pi_i || pi_policy_id)``, averaged over ``all_obs``.

        Args:
            all_obs (torch.Tensor): Observations on which every low-level
                policy is evaluated.
            policy_id (int): Index of the policy being regularized.

        Returns:
            torch.Tensor: One-element tensor holding the summed KL.

        """
        # Gradients also reach the other policies' parameters, but only this
        # policy's optimizer steps, and each optimizer zeroes its gradients
        # before its own backward pass.
        dists = [policy(all_obs)[0] for policy in self.policies]
        policy_dist = dists[policy_id]

        kl = torch.zeros(1).to(global_device())
        for (i, dist) in enumerate(dists):
            if i != policy_id:
                kl += torch.distributions.kl.kl_divergence(dist, policy_dist).mean()

        return kl

    def _critic_objective(self, samples_data, policy_id):
        """Compute the Q-function/critic losses for one low-level policy.

        Args:
            samples_data (dict): Transitions(S,A,R,S') that are sampled from
                the replay buffer. It should have the keys 'observation',
                'action', 'reward', 'terminal', and 'next_observation'.
            policy_id (int): Index of the low-level policy.

        Returns:
            torch.Tensor: loss from 1st q-function.
            torch.Tensor: loss from 2nd q-function.

        """
        obs = samples_data["observation"]
        actions = samples_data["action"]
        rewards = samples_data["reward"].flatten()
        terminals = samples_data["terminal"].flatten()
        next_obs = samples_data["next_observation"]
        with torch.no_grad():
            alpha = self._get_log_alpha(samples_data, policy_id).exp()

        q1_pred = self._qf1s[policy_id](obs, actions)
        q2_pred = self._qf2s[policy_id](obs, actions)

        # The bootstrap target is a constant for the critic update, so build
        # it entirely without an autograd graph.
        with torch.no_grad():
            new_next_actions_dist = self.policies[policy_id](next_obs)[0]
            (
                new_next_actions_pre_tanh,
                new_next_actions,
            ) = new_next_actions_dist.rsample_with_pre_tanh_value()
            new_log_pi = new_next_actions_dist.log_prob(
                value=new_next_actions, pre_tanh_value=new_next_actions_pre_tanh
            )

            target_q_values = (
                torch.min(
                    self._target_qf1s[policy_id](next_obs, new_next_actions),
                    self._target_qf2s[policy_id](next_obs, new_next_actions),
                ).flatten()
                - (alpha * new_log_pi)
            )
            q_target = (
                rewards * self._reward_scale
                + (1.0 - terminals) * self._discount * target_q_values
            )
        qf1_loss = F.mse_loss(q1_pred.flatten(), q_target)
        qf2_loss = F.mse_loss(q2_pred.flatten(), q_target)

        return qf1_loss, qf2_loss

    def _update_targets(self, policy_id):
        """Polyak-average one policy's Q-functions into its target copies."""

        target_qfs = [self._target_qf1s[policy_id], self._target_qf2s[policy_id]]
        qfs = [self._qf1s[policy_id], self._qf2s[policy_id]]
        for target_qf, qf in zip(target_qfs, qfs):
            for t_param, param in zip(target_qf.parameters(), qf.parameters()):
                t_param.data.copy_(
                    t_param.data * (1.0 - self._tau) + param.data * self._tau
                )

    def optimize_policy(self, all_obs, samples_data, policy_id):
        """Optimize one policy, its q_functions, and its temperature.

        Args:
            all_obs (torch.Tensor): Observations used for the mutual-KL
                penalty.
            samples_data (dict): Transitions(S,A,R,S') that are sampled from
                the replay buffer. It should have the keys 'observation',
                'action', 'reward', 'terminal', and 'next_observation'.
            policy_id (int): Index of the low-level policy.

        Returns:
            torch.Tensor: actor loss, including the KL penalty.
            torch.Tensor: KL penalty (``kl_coeff * KL``).
            torch.Tensor: loss from 1st q-function.
            torch.Tensor: loss from 2nd q-function.

        """
        obs = samples_data["observation"]
        qf1_loss, qf2_loss = self._critic_objective(samples_data, policy_id)

        self._qf1_optimizers[policy_id].zero_grad()
        qf1_loss.backward()
        self._qf1_optimizers[policy_id].step()

        self._qf2_optimizers[policy_id].zero_grad()
        qf2_loss.backward()
        self._qf2_optimizers[policy_id].step()

        action_dists = self.policies[policy_id](obs)[0]
        new_actions_pre_tanh, new_actions = action_dists.rsample_with_pre_tanh_value()
        log_pi_new_actions = action_dists.log_prob(
            value=new_actions, pre_tanh_value=new_actions_pre_tanh
        )

        policy_loss = self._actor_objective(
            samples_data, new_actions, log_pi_new_actions, policy_id
        )

        ### DnC KL Penalty
        kl = self._compute_mutual_kl(all_obs, policy_id)
        kl_penalty = self._kl_coeff * kl.mean()
        policy_loss += kl_penalty

        self._policy_optimizers[policy_id].zero_grad()
        policy_loss.backward()

        self._policy_optimizers[policy_id].step()

        if self._use_automatic_entropy_tuning:
            alpha_loss = self._temperature_objective(
                log_pi_new_actions, samples_data, policy_id
            )
            self._alpha_optimizers[policy_id].zero_grad()
            alpha_loss.backward()
            self._alpha_optimizers[policy_id].step()

        return policy_loss, kl_penalty, qf1_loss, qf2_loss

    def _evaluate_policy(self, epoch):
        """Evaluate ``self.policy`` and log performance statistics.

        Args:
            epoch (int): The current training epoch.

        Returns:
            float: The average return across self._num_evaluation_episodes
                episodes

        """
        eval_episodes = obtain_evaluation_episodes(
            self.policy,
            self._eval_env,
            self._max_episode_length_eval,
            num_eps=self._num_evaluation_episodes,
            deterministic=self._use_deterministic_evaluation,
        )
        last_return = log_performance(epoch, eval_episodes, discount=self._discount)
        return last_return

    def _log_statistics(self, policy_losses, kl_penalties, qf1_losses, qf2_losses):
        """Record per-policy training statistics to dowel.

        Args:
            policy_losses (list): Actor loss per policy (index = policy id).
            kl_penalties (list): KL penalty per policy.
            qf1_losses (list): 1st Q-function loss per policy.
            qf2_losses (list): 2nd Q-function loss per policy.
            Entries may be floats or one-element tensors.

        """
        for i, policy in enumerate(self.policies):
            with tabular.prefix(policy.name):
                with torch.no_grad():
                    tabular.record(
                        "AlphaTemperature/mean", self._log_alphas[i].exp().mean().item()
                    )
                tabular.record("Policy/Loss", float(policy_losses[i]))
                tabular.record("Policy/KLPenalty", float(kl_penalties[i]))
                tabular.record("QF/{}".format("Qf1Loss"), float(qf1_losses[i]))
                tabular.record("QF/{}".format("Qf2Loss"), float(qf2_losses[i]))
                tabular.record(
                    "ReplayBuffer/buffer_size",
                    self.replay_buffers[i].n_transitions_stored,
                )

                ### ASDF Need a way to keep track of trajectories
                tabular.record("Average/Trajectories", self.num_trajectories[i])
                tabular.record(
                    "Average/TrainAverageReturn", np.nanmean(self.episode_rewards[i])
                )

    @property
    def networks(self):
        """Return all the networks within the model.

        Returns:
            list: A list of networks.

        """
        ### ASDF do we need self.policy as well?
        # hl_qf is listed so that ``to()`` moves it to the device as well.
        return [
            self.hl_policy,
            self._hl_qf,
            *self.policies,
            *self._qf1s,
            *self._qf2s,
            *self._target_qf1s,
            *self._target_qf2s,
        ]

    def to(self, device=None):
        """Put all the networks within the model on device.

        Re-creates the temperature tensors on ``device``; with automatic
        tuning they are reset to ``initial_log_entropy`` and new alpha
        optimizers are built.

        Args:
            device (str): ID of GPU or CPU.

        """
        if device is None:
            device = global_device()
        for net in self.networks:
            net.to(device)
        if not self._use_automatic_entropy_tuning:
            self._log_alphas = [
                torch.Tensor([self._fixed_alpha]).log().to(device)
                for _ in range(self.n_policies)
            ]
        else:
            self._log_alphas = [
                (torch.Tensor([self._initial_log_entropy]).to(device).requires_grad_())
                for _ in range(self.n_policies)
            ]
            self._alpha_optimizers = [
                self._optimizer([a], lr=self._policy_lr) for a in self._log_alphas
            ]
