"""
This module creates a MoP DnC model that trains both a Q_i and Q_mixture by training on
task-specific policy-specific data and task-specific mixture data, respectively.
"""

# yapf: disable
from collections import deque
import copy
import time

from dowel import tabular
import numpy as np
import torch
import torch.nn.functional as F

from garage import StepType, EpisodeBatch
from garage.np.algos import RLAlgorithm
from garage.sampler import RaySampler
from garage.torch import as_torch_dict, as_torch, global_device

from learning.utils import (get_path_policy_id, get_policy_ids, get_path_task_id,
                            log_multitask_performance, log_wandb, obtain_multitask_multimode_evaluation_episodes)

from learning.algorithms import MoPDnC
# yapf: enable


class MoPDnCv2(MoPDnC):
    def __init__(self, *, Qi1s, Qi2s, policy_replay_buffers, **kwargs):
        """Set up the per-policy Qi critics on top of the MoPDnC state.

        Args:
            Qi1s (list): First Qi critic of every policy, indexed by policy
                id. Each entry must be deep-copyable and expose
                ``parameters()``.
            Qi2s (list): Second Qi critic of every policy, same layout as
                ``Qi1s``.
            policy_replay_buffers: Container, indexed by policy id, that
                stores the policy-specific transitions used to train the Qi
                critics.
            **kwargs: Forwarded unchanged to ``MoPDnC.__init__``.

        """
        super().__init__(**kwargs)

        def build_optimizers(critics):
            # `_optimizer` and `_qf_lr` are not defined in this class;
            # presumably they are set by the parent constructor.
            return [
                self._optimizer(critic.parameters(), lr=self._qf_lr)
                for critic in critics
            ]

        self._Qi1s, self._Qi2s = Qi1s, Qi2s

        # Every Qi critic gets its own independent target copy.
        self._target_Qi1s = list(map(copy.deepcopy, self._Qi1s))
        self._target_Qi2s = list(map(copy.deepcopy, self._Qi2s))

        self._Qi1_optimizers = build_optimizers(self._Qi1s)
        self._Qi2_optimizers = build_optimizers(self._Qi2s)

        self.policy_replay_buffers = policy_replay_buffers

    def train_once(self, trainer):
        """Collect one batch of episodes, store it, and update every policy.

        The method:

        1. samples episodes through ``trainer.obtain_samples``;
        2. appends every path to ``replay_buffers`` and the policy-specific
           subset of its steps to ``policy_replay_buffers``;
        3. refreshes the per-index training statistics (returns, success
           rates, completed stages, mixture probabilities);
        4. runs gradient steps for each policy, in proportion to the number
           of samples collected for that index, and logs the losses.

        Args:
            trainer: Object providing ``obtain_samples``, ``step_itr`` and a
                writable ``step_episode`` attribute (presumably a garage
                ``Trainer`` -- confirm against the caller).

        """
        # While any buffer is below `_min_buffer_size`, request an explicit
        # warm-up batch size; otherwise pass None to `obtain_samples`.
        buffers_ready = all(
            self.replay_buffers[i].n_transitions_stored >= self._min_buffer_size
            for i in range(self.n_policies)
        )
        batch_size = (
            None if buffers_ready else int(self._min_buffer_size) * self.n_policies
        )

        if isinstance(self._sampler, RaySampler):
            # ray only supports CPU sampling
            with torch.no_grad():
                agent_update = copy.copy(self.policy)
                agent_update.policies = [
                    copy.deepcopy(policy).cpu() for policy in self.policies
                ]
                agent_update.score_functions = [
                    copy.deepcopy(Qi).cpu() for Qi in self._Qi1s
                ]
                agent_update.score_function2s = [
                    copy.deepcopy(Qi).cpu() for Qi in self._Qi2s
                ]
        else:
            agent_update = None

        trainer.step_episode = trainer.obtain_samples(
            trainer.step_itr, batch_size, agent_update=agent_update
        )

        # Per-index accumulators for the statistics computed below.
        path_returns = [0] * self.n_policies
        num_samples = [0] * self.n_policies
        num_path_starts = [0] * self.n_policies
        num_path_ends = [0] * self.n_policies
        num_successes = [0] * self.n_policies
        num_stages_completed = [0] * self.n_policies
        policy_counts = np.zeros((self.n_policies, self.n_policies))

        for path in trainer.step_episode:
            task_id = get_path_task_id(path)
            terminals = np.array(
                [step_type == StepType.TERMINAL for step_type in path["step_types"]]
            ).reshape(-1, 1)

            ### Add path by task to replay_buffers
            policy_id = get_path_policy_id(path)
            self.replay_buffers[policy_id].add_path(
                dict(
                    observation=self.preproc_obs(path["observations"])[0],
                    action=path["actions"],
                    reward=path["rewards"].reshape(-1, 1),
                    next_observation=self.preproc_obs(path["next_observations"])[0],
                    terminal=terminals,
                )
            )

            ### Add policy specific samples to policy_replay_buffers
            policy_ids = get_policy_ids(path)
            idxs = np.nonzero(policy_ids == policy_id)
            # np.nonzero returns a tuple of index arrays (one per dimension),
            # so its length says nothing about how many steps matched; check
            # the number of selected indices instead.
            if idxs[0].size > 0:
                policy_path = dict(
                    observation=self.preproc_obs(path["observations"][idxs])[0],
                    action=path["actions"][idxs],
                    reward=path["rewards"][idxs].reshape(-1, 1),
                    next_observation=self.preproc_obs(path["next_observations"][idxs])[
                        0
                    ],
                    terminal=terminals[idxs],
                )
                try:
                    self.policy_replay_buffers[policy_id].add_path(policy_path)
                except Exception:
                    # Keep the diagnostic output, but surface the failure
                    # instead of opening an interactive debugger mid-training.
                    print(f"Idxs: {idxs}")
                    raise

            path_returns[task_id] += sum(path["rewards"])
            num_samples[task_id] += len(path["actions"])
            num_path_starts[task_id] += np.sum(
                [step_type == StepType.FIRST for step_type in path["step_types"]]
            )
            num_path_ends[task_id] += np.sum(terminals)
            if "success" in path["env_infos"]:
                num_successes[task_id] += path["env_infos"]["success"].any()
            if "stages_completed" in path["env_infos"]:
                num_stages_completed[task_id] += path["env_infos"]["stages_completed"][
                    -1
                ]

            real_policy_ids = path["agent_infos"]["real_policy_id"]
            for i in range(self.n_policies):
                policy_counts[task_id][i] += (real_policy_ids == i).sum()

        for i in range(self.n_policies):
            num_paths = max(num_path_starts[i], num_path_ends[i], 1)  # AD-HOC
            ended = num_path_ends[i]
            self.episode_rewards[i] = path_returns[i] / num_paths
            # Rates are reported as 0 when no path terminated for this index.
            self.success_rates[i] = num_successes[i] / ended if ended else 0
            self.stages_completed[i] = (
                num_stages_completed[i] / ended if ended else 0
            )
            # NOTE(review): when no step was recorded for index i this is 0/0
            # and yields NaNs (with a numpy warning); left unchanged because
            # the desired fallback is not visible from this file.
            self.mixture_probs[i] = policy_counts[i] / np.sum(policy_counts[i])

        total_samples = np.sum(num_samples)
        for policy_id in range(self.n_policies):
            # Split the gradient-step budget proportionally to the samples
            # collected for this index; with no samples at all, skip training
            # instead of dividing by zero.
            if total_samples > 0:
                num_grad_steps = int(
                    self._gradient_steps / total_samples * num_samples[policy_id]
                )
            else:
                num_grad_steps = 0

            # One history per value returned by train_policy_once, in order:
            # policy loss, KL penalty, KL, qf1, qf2, Qi1, Qi2.
            loss_histories = [[] for _ in range(7)]
            for _ in range(num_grad_steps):
                step_losses = self.train_policy_once(policy_id)
                for history, value in zip(loss_histories, step_losses):
                    history.append(value)

            self._log_statistics(trainer.step_itr, policy_id, *loss_histories)

    def train_policy_once(self, policy_id, itr=None, paths=None):
        """Run a single gradient update for one policy and its critics.

        When the replay buffer of ``policy_id`` does not yet hold
        ``_min_buffer_size`` transitions, no optimisation happens and all
        returned losses are ``0.0``. The target networks are soft-updated in
        either case.

        Args:
            policy_id (int): Index of the policy to update.
            itr (int): Iteration number. This argument is deprecated.
            paths (list[dict]): A list of collected paths.
                This argument is deprecated.

        Returns:
            tuple: ``(policy_loss, kl_penalty, kl, qf1_loss, qf2_loss,
            Qi1_loss, Qi2_loss)`` as produced by ``optimize_policy``.

        """
        del itr, paths

        # Result reported while the buffer is still warming up.
        losses = (0.0,) * 7

        stored = self.replay_buffers[policy_id].n_transitions_stored
        if stored >= self._min_buffer_size:
            # Task-level batch first, policy-specific batch second.
            all_obs, task_batch = self.replay_buffers.sample_transitions(
                self._buffer_batch_size, policy_id
            )
            policy_batch = self.policy_replay_buffers.sample_transitions(
                self._buffer_batch_size, policy_id
            )
            losses = tuple(
                self.optimize_policy(
                    [as_torch(obs) for obs in all_obs],
                    as_torch_dict(task_batch),
                    as_torch_dict(policy_batch),
                    policy_id=policy_id,
                )
            )

        self._update_targets(policy_id)

        return losses

    def _critic_objective(self, samples_data, policy_id, qf_obj=True):
        """Compute the twin critic losses for one policy.

        Depending on ``qf_obj``, the losses are computed for the qf critics
        (``_qf1s``/``_qf2s`` with their targets) or for the Qi critics
        (``_Qi1s``/``_Qi2s`` with their targets). The bootstrap target is the
        minimum of the two target critics at an action sampled from the
        policy for the next observation, minus the entropy term
        ``alpha * log_pi``.

        Args:
            samples_data (dict): Batch of transitions with the keys
                'observation', 'action', 'reward', 'terminal' and
                'next_observation'.
            policy_id (int): Index of the policy whose critics are evaluated.
            qf_obj (bool): ``True`` for the qf critics, ``False`` for the Qi
                critics.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: MSE losses of the first and
            second critic against the shared target.

        """
        obs = samples_data["observation"]
        actions = samples_data["action"]
        rewards = samples_data["reward"].flatten()
        terminals = samples_data["terminal"].flatten()
        next_obs = samples_data["next_observation"]
        with torch.no_grad():
            alpha = self._get_log_alpha(samples_data, policy_id).exp()

        if qf_obj:
            critic1, critic2 = self._qf1s[policy_id], self._qf2s[policy_id]
            target1 = self._target_qf1s[policy_id]
            target2 = self._target_qf2s[policy_id]
        else:
            critic1, critic2 = self._Qi1s[policy_id], self._Qi2s[policy_id]
            target1 = self._target_Qi1s[policy_id]
            target2 = self._target_Qi2s[policy_id]

        q1_pred = critic1(obs, actions)
        q2_pred = critic2(obs, actions)

        # The target is a constant with respect to the critic losses, so the
        # whole computation runs without autograd. This gives the same values
        # as before but avoids building a graph through the policy and the
        # target critics that is never back-propagated.
        with torch.no_grad():
            next_action_dist = self.policies[policy_id](next_obs)[0]
            (
                next_actions_pre_tanh,
                next_actions,
            ) = next_action_dist.rsample_with_pre_tanh_value()
            next_log_pi = next_action_dist.log_prob(
                value=next_actions, pre_tanh_value=next_actions_pre_tanh
            )
            target_q_values = (
                torch.min(
                    target1(next_obs, next_actions),
                    target2(next_obs, next_actions),
                ).flatten()
                - (alpha * next_log_pi)
            )
            q_target = (
                rewards * self._reward_scale
                + (1.0 - terminals) * self._discount * target_q_values
            )

        qf1_loss = F.mse_loss(q1_pred.flatten(), q_target)
        qf2_loss = F.mse_loss(q2_pred.flatten(), q_target)

        return qf1_loss, qf2_loss

    def _update_targets(self, policy_id):
        """Soft-update the target qf and target Qi critics of one policy.

        Every target parameter is moved towards its online counterpart with
        ``target <- (1 - tau) * target + tau * online``, where ``tau`` is
        ``self._tau``.

        Args:
            policy_id (int): Index of the policy whose targets are updated.

        """
        tau = self._tau
        # (target network, online network) pairs, qf critics first.
        network_pairs = (
            (self._target_qf1s[policy_id], self._qf1s[policy_id]),
            (self._target_qf2s[policy_id], self._qf2s[policy_id]),
            (self._target_Qi1s[policy_id], self._Qi1s[policy_id]),
            (self._target_Qi2s[policy_id], self._Qi2s[policy_id]),
        )
        for target_net, online_net in network_pairs:
            param_pairs = zip(target_net.parameters(), online_net.parameters())
            for target_param, online_param in param_pairs:
                blended = target_param.data * (1.0 - tau) + online_param.data * tau
                target_param.data.copy_(blended)

    def optimize_policy(
        self, all_obs, samples_data, policy_data, policy_id, log_betas=None
    ):
        """Optimize the qf and Qi critics, the policy, and the temperature.

        The qf critics are trained on ``samples_data`` and the Qi critics on
        ``policy_data``. The policy is then updated with the actor objective
        plus the DnC KL penalty, and finally the entropy temperature is
        updated when automatic entropy tuning is enabled.

        Args:
            all_obs (list[torch.Tensor]): Observations forwarded to
                ``_compute_kl_penalty`` together with the batch observations.
            samples_data (dict): Transitions(S,A,R,S') that are sampled from
                the replay buffer. It should have the keys 'observation',
                'action', 'reward', 'terminal', and 'next_observation'.
            policy_data (dict): Transitions with the same keys as
                ``samples_data``, sampled from the policy-specific replay
                buffer; used only for the Qi critics.
            policy_id (int): Index of the policy being optimized.
            log_betas (list | None): Forwarded to ``_compute_kl_penalty``.
                ``None`` (the default) is treated as an empty list.

        Note:
            samples_data's entries should be torch.Tensor's with the following
            shapes:
                observation: :math:`(N, O^*)`
                action: :math:`(N, A^*)`
                reward: :math:`(N, 1)`
                terminal: :math:`(N, 1)`
                next_observation: :math:`(N, O^*)`

        Returns:
            tuple[float, ...]: ``(policy_loss, kl_penalty, mean_kl, qf1_loss,
            qf2_loss, Qi1_loss, Qi2_loss)``, each extracted with ``.item()``.
            ``policy_loss`` includes the KL penalty.

        """
        # Use a fresh list per call rather than a shared mutable default.
        if log_betas is None:
            log_betas = []

        obs = samples_data["observation"]
        qf1_loss, qf2_loss = self._critic_objective(
            samples_data, policy_id, qf_obj=True
        )

        self._qf1_optimizers[policy_id].zero_grad()
        qf1_loss.backward()
        self._qf1_optimizers[policy_id].step()

        self._qf2_optimizers[policy_id].zero_grad()
        qf2_loss.backward()
        self._qf2_optimizers[policy_id].step()

        ### Optimize Qi1/2 with critic loss

        Qi1_loss, Qi2_loss = self._critic_objective(
            policy_data, policy_id, qf_obj=False
        )

        self._Qi1_optimizers[policy_id].zero_grad()
        Qi1_loss.backward()
        self._Qi1_optimizers[policy_id].step()

        self._Qi2_optimizers[policy_id].zero_grad()
        Qi2_loss.backward()
        # Apply the update; a second zero_grad() here would discard the
        # gradients just computed and leave the second Qi critic untrained.
        self._Qi2_optimizers[policy_id].step()

        action_dists = self.policies[policy_id](obs)[0]
        new_actions_pre_tanh, new_actions = action_dists.rsample_with_pre_tanh_value()
        log_pi_new_actions = action_dists.log_prob(
            value=new_actions, pre_tanh_value=new_actions_pre_tanh
        )

        policy_loss = self._actor_objective(
            samples_data, new_actions, log_pi_new_actions, policy_id
        )

        ### DnC KL Penalty
        kl, kl_penalty = self._compute_kl_penalty(
            obs, all_obs, policy_id, log_betas=log_betas
        )
        policy_loss = policy_loss + kl_penalty

        self._policy_optimizers[policy_id].zero_grad()
        policy_loss.backward()

        self._policy_optimizers[policy_id].step()

        if self._use_automatic_entropy_tuning:
            alpha_loss = self._temperature_objective(
                log_pi_new_actions, samples_data, policy_id
            )
            self._alpha_optimizers[policy_id].zero_grad()
            alpha_loss.backward()
            self._alpha_optimizers[policy_id].step()

        return (
            policy_loss.item(),
            kl_penalty.item(),
            kl.mean().item(),
            qf1_loss.item(),
            qf2_loss.item(),
            Qi1_loss.item(),
            Qi2_loss.item(),
        )

    def _log_statistics(
        self,
        step,
        policy_id,
        policy_losses,
        kl_penalties,
        kls,
        qf1_losses,
        qf2_losses,
        Qi1_losses,
        Qi2_losses,
    ):
        """Send the training statistics of one policy to ``log_wandb``.

        Every loss sequence is reduced with ``np.mean`` before logging. The
        entries are logged under the prefix ``Train/<policy name>/``.

        Args:
            step (int): Step value passed through to ``log_wandb``.
            policy_id (int): Index of the policy the statistics belong to.
            policy_losses (list[float]): Actor losses of the gradient steps.
            kl_penalties (list[float]): KL penalty terms of the gradient
                steps.
            kls (list[float]): Mean KL values of the gradient steps.
            qf1_losses (list[float]): Losses of the 1st qf/critic network.
            qf2_losses (list[float]): Losses of the 2nd qf/critic network.
            Qi1_losses (list[float]): Losses of the 1st Qi/critic network.
            Qi2_losses (list[float]): Losses of the 2nd Qi/critic network.

        """
        with torch.no_grad():
            alpha_temperature = self._log_alphas[policy_id].exp().mean().item()
            all_log_betas = self._log_kl_coeffs[policy_id].cpu().detach().numpy()
            # Drop this policy's own entry before averaging the coefficients.
            other_log_betas = np.concatenate(
                [all_log_betas[:policy_id], all_log_betas[policy_id + 1 :]]
            )
            beta_kl = np.mean(np.exp(other_log_betas))

        mixture_row = self.mixture_probs[policy_id]
        infos = {
            "AlphaTemperature": alpha_temperature,
            "BetaKL": beta_kl,
            "PolicyLoss": np.mean(policy_losses),
            "PolicyKLPenalty": np.mean(kl_penalties),
            "PolicyKL": np.mean(kls),
            "Qf1Loss": np.mean(qf1_losses),
            "Qf2Loss": np.mean(qf2_losses),
            "Qi1Loss": np.mean(Qi1_losses),
            "Qi2Loss": np.mean(Qi2_losses),
            "ReplayBufferSize": self.replay_buffers[policy_id].n_transitions_stored,
            "AverageReturn": np.mean(self.episode_rewards[policy_id]),
            "SuccessRate": np.mean(self.success_rates[policy_id]),
            "StagesCompleted": np.mean(self.stages_completed[policy_id]),
            "MixtureProbs": mixture_row[policy_id],
        }
        infos.update(
            {f"Policy{i}Prob": mixture_row[i] for i in range(self.n_policies)}
        )

        policy_name = self.policies[policy_id].name
        log_wandb(step, infos, prefix="Train/" + policy_name + "/")

    @property
    def networks(self):
        """Return all the networks within the model.

        The result is ordered as: policies, qf critics, target qf critics,
        Qi critics, target Qi critics (first critic before second in each
        pair).

        Returns:
            list: A flat list of networks.

        """
        network_groups = (
            self.policies,
            self._qf1s,
            self._qf2s,
            self._target_qf1s,
            self._target_qf2s,
            self._Qi1s,
            self._Qi2s,
            self._target_Qi1s,
            self._target_Qi2s,
        )
        return [network for group in network_groups for network in group]
