"""Divide-and-Conquer Vanilla Policy Gradient (REINFORCE) with KL-coupled local policies."""
import collections
import copy
import warnings

from dowel import tabular
import numpy as np
import torch
import torch.nn.functional as F

from garage import log_performance, EpisodeBatch
from garage.np import discount_cumsum
from garage.np.algos import RLAlgorithm
from garage.torch import compute_advantages, filter_valids
from garage.torch.optimizers import OptimizerWrapper

from learning.utils import DnCOptimizerWrapper, extract_policy_samples


class DnCVPG(RLAlgorithm):
    """Divide-and-Conquer Vanilla Policy Gradient (REINFORCE).

    VPG, also known as Reinforce, trains stochastic policies in an on-policy
    way. This variant trains several local policies, each on the episodes
    attributed to it by ``extract_policy_samples``. From each local policy's
    surrogate objective it subtracts ``kl_coeff`` times a mutual-KL penalty:
    the sum over every other local policy ``j`` of KL(pi_j || pi_i),
    evaluated on the observations collected by all policies. Optionally, a
    centroid policy is fit to all collected actions by behavioral cloning.

    Args:
        env_spec (EnvSpec): Environment specification. Its
            ``max_episode_length`` is used when computing advantages.
        policy (garage.torch.policies.Policy): Policy. Stored as
            ``self.policy``; not otherwise used by this class.
        policies (list[garage.torch.policies.Policy]): The local policies
            trained by this algorithm.
        centroid (garage.torch.policies.Policy): Centroid policy. Only trained
            and logged when ``track_centroid`` is True.
        value_functions (list[garage.torch.value_functions.ValueFunction]):
            One value function (baseline) per entry of ``policies``.
        sampler (garage.sampler.Sampler): Sampler.
        policy_optimizers (list or None): One optimizer wrapper per local
            policy. Defaults to an Adam ``DnCOptimizerWrapper`` per policy.
        centroid_optimizer (garage.torch.optimizer.OptimizerWrapper or None):
            Optimizer for the centroid; only used when ``track_centroid`` is
            True. Defaults to an Adam ``OptimizerWrapper``.
        vf_optimizers (list[garage.torch.optimizer.OptimizerWrapper] or None):
            One optimizer per value function. Defaults to Adam wrappers.
        kl_coeff (float): Weight of the mutual-KL penalty subtracted from each
            local policy's objective.
        num_train_per_epoch (int): Number of train_once calls per epoch.
        discount (float): Discount.
        gae_lambda (float): Lambda used for generalized advantage
            estimation.
        center_adv (bool): Whether to rescale the advantages
            so that they have mean 0 and standard deviation 1.
        positive_adv (bool): Whether to shift the advantages
            so that they are always positive. When used in
            conjunction with center_adv the advantages will be
            standardized before shifting.
        policy_ent_coeff (float): The coefficient of the policy entropy.
            Setting it to zero would mean no entropy regularization.
        use_softplus_entropy (bool): Whether to pass the entropy through a
            softplus to prevent it from being negative.
        stop_entropy_gradient (bool): Whether to stop the entropy gradient.
        entropy_method (str): A string from: 'max', 'regularized',
            'no_entropy'. The type of entropy method to use. 'max' adds the
            dense entropy to the reward for each time step. 'regularized' adds
            the mean entropy to the surrogate objective. See
            https://arxiv.org/abs/1805.00909 for more details.
        track_centroid (bool): Whether to train and log the centroid policy.

    """

    def __init__(
        self,
        env_spec,
        policy,
        policies,
        centroid,
        value_functions,
        sampler,
        policy_optimizers=None,
        centroid_optimizer=None,
        vf_optimizers=None,
        kl_coeff=1.0,
        num_train_per_epoch=1,
        discount=0.99,
        gae_lambda=1,
        center_adv=True,
        positive_adv=False,
        policy_ent_coeff=0.0,
        use_softplus_entropy=False,
        stop_entropy_gradient=False,
        entropy_method="no_entropy",
        track_centroid=False,
    ):
        self._discount = discount
        self.policy = policy
        self.policies = policies
        self.centroid = centroid
        self.max_episode_length = env_spec.max_episode_length
        self.n_policies = len(policies)

        self._kl_coeff = kl_coeff
        self._value_functions = value_functions
        self._gae_lambda = gae_lambda
        self._center_adv = center_adv
        self._positive_adv = positive_adv
        self._policy_ent_coeff = policy_ent_coeff
        self._use_softplus_entropy = use_softplus_entropy
        self._stop_entropy_gradient = stop_entropy_gradient
        self._entropy_method = entropy_method
        self._n_samples = num_train_per_epoch
        self._env_spec = env_spec
        self._track_centroid = track_centroid

        self._maximum_entropy = entropy_method == "max"
        # NOTE: the attribute name is misspelled ("regularzied"); it is kept
        # unchanged so any external/subclass references keep working.
        self._entropy_regularzied = entropy_method == "regularized"
        self._check_entropy_configuration(
            entropy_method, center_adv, stop_entropy_gradient, policy_ent_coeff
        )
        self._episode_reward_mean = collections.deque(maxlen=100)
        self._sampler = sampler

        if policy_optimizers:
            self._policy_optimizers = policy_optimizers
        else:
            self._policy_optimizers = [
                DnCOptimizerWrapper(torch.optim.Adam, local_policy)
                for local_policy in policies
            ]
        if self._track_centroid:
            if centroid_optimizer:
                self._centroid_optimizer = centroid_optimizer
            else:
                self._centroid_optimizer = OptimizerWrapper(
                    torch.optim.Adam, self.centroid
                )
        if vf_optimizers:
            self._vf_optimizers = vf_optimizers
        else:
            self._vf_optimizers = [
                OptimizerWrapper(torch.optim.Adam, value_function)
                for value_function in value_functions
            ]

        self._mseloss = torch.nn.MSELoss()

        # Copies of each local policy taken before its latest update; used to
        # report the KL divergence caused by that update.
        self._old_policies = [
            copy.deepcopy(local_policy) for local_policy in self.policies
        ]

        # Cumulative number of environment steps collected by each policy.
        self._total_policy_samples = [0] * self.n_policies

    @staticmethod
    def _check_entropy_configuration(
        entropy_method, center_adv, stop_entropy_gradient, policy_ent_coeff
    ):
        """Validate the combination of entropy-related options.

        Args:
            entropy_method (str): One of 'max', 'regularized', 'no_entropy'.
            center_adv (bool): Whether advantages are centered.
            stop_entropy_gradient (bool): Whether the entropy gradient is
                stopped.
            policy_ent_coeff (float): Entropy coefficient.

        Raises:
            ValueError: If ``entropy_method`` is unknown, or if the other
                options are inconsistent with it.

        """
        if entropy_method not in ("max", "regularized", "no_entropy"):
            raise ValueError("Invalid entropy_method")

        if entropy_method == "max":
            if center_adv:
                raise ValueError(
                    "center_adv should be False when entropy_method is max"
                )
            if not stop_entropy_gradient:
                raise ValueError(
                    "stop_gradient should be True when entropy_method is max"
                )
        if entropy_method == "no_entropy":
            if policy_ent_coeff != 0.0:
                raise ValueError(
                    "policy_ent_coeff should be zero when there is no entropy method"
                )

    @property
    def discount(self):
        """Discount factor used by the algorithm.

        Returns:
            float: discount factor.
        """
        return self._discount

    def _record_empty_policy_stats(self, policy_id):
        """Record placeholder statistics for a policy with no samples.

        Called when no episodes were collected with ``policy_id`` in the
        current batch, presumably so the same tabular keys are logged every
        iteration. Per-batch statistics are recorded as 0; the cumulative
        sample count keeps its running total.

        Args:
            policy_id (int): Index of the local policy / value function.

        """
        with tabular.prefix(self.policies[policy_id].name):
            tabular.record("/AverageReturn", 0)
            tabular.record(
                "/TotalNumSamples", self._total_policy_samples[policy_id]
            )
            tabular.record("/LossAfter", 0)
            tabular.record("/KL", 0)
            tabular.record("/Entropy", 0)
            tabular.record("/MutualKL", 0)
        with tabular.prefix(self._value_functions[policy_id].name):
            tabular.record("/LossAfter", 0)

    def _train_once(self, itr, all_eps):
        """Train every local policy and its value function once.

        Args:
            itr (int): Iteration number.
            all_eps (EpisodeBatch): A batch of collected episodes from all
                policies. It is split per policy with
                ``extract_policy_samples``.

        Returns:
            numpy.float64: Calculated mean value of undiscounted returns over
                ``all_eps``.

        """
        eps_by_policy = extract_policy_samples(all_eps, self.n_policies)
        # The mutual-KL penalty and the centroid use the samples of all
        # policies, not only those of the policy being trained.
        all_obs_flat = torch.Tensor(all_eps.observations)
        all_acs_flat = torch.Tensor(all_eps.actions)

        for i in range(self.n_policies):
            eps = eps_by_policy[i]
            if eps == []:
                # No samples were collected with this policy in this batch.
                self._record_empty_policy_stats(i)
                continue

            obs = torch.Tensor(eps.padded_observations)
            rewards = torch.Tensor(eps.padded_rewards)
            returns = torch.Tensor(
                np.stack(
                    [
                        discount_cumsum(reward, self.discount)
                        for reward in eps.padded_rewards
                    ]
                )
            )
            valids = eps.lengths
            with torch.no_grad():
                baselines = self._value_functions[i](obs)

            if self._maximum_entropy:
                # The dense entropy bonus comes from this policy (index i).
                policy_entropies = self._compute_policy_entropy(obs, i)
                rewards += self._policy_ent_coeff * policy_entropies

            obs_flat = torch.Tensor(eps.observations)
            actions_flat = torch.Tensor(eps.actions)
            rewards_flat = torch.Tensor(eps.rewards)
            returns_flat = torch.cat(filter_valids(returns, valids))
            advs_flat = self._compute_advantage(rewards, valids, baselines)

            self._train(
                all_obs_flat,
                all_acs_flat,
                obs_flat,
                actions_flat,
                rewards_flat,
                returns_flat,
                advs_flat,
                i,
            )

            with torch.no_grad():
                policy_loss_after = self._compute_loss_with_adv(
                    all_obs_flat, obs_flat, actions_flat, rewards_flat, advs_flat, i
                )
                vf_loss_after = self._value_functions[i].compute_loss(
                    obs_flat, returns_flat
                )
                kl_after = self._compute_kl_constraint(obs, i)
                policy_entropy = self._compute_policy_entropy(obs, i)
                mutual_kl = self._compute_mutual_kl(all_obs_flat, i)
            undiscounted_returns = log_performance(
                itr, eps, discount=self._discount
            )
            self._total_policy_samples[i] += np.sum(eps.lengths)

            with tabular.prefix(self.policies[i].name):
                tabular.record("/AverageReturn", np.mean(undiscounted_returns))
                tabular.record("/TotalNumSamples", self._total_policy_samples[i])
                tabular.record("/LossAfter", policy_loss_after.item())
                tabular.record("/KL", kl_after.item())
                tabular.record("/Entropy", policy_entropy.mean().item())
                tabular.record("/MutualKL", mutual_kl.item())

            with tabular.prefix(self._value_functions[i].name):
                tabular.record("/LossAfter", vf_loss_after.item())

            # Snapshot the updated policy for next iteration's KL report.
            self._old_policies[i].load_state_dict(self.policies[i].state_dict())

        if self._track_centroid:
            # The centroid only changes inside ``_train``, so evaluating it
            # once after all policies were trained yields the final loss.
            with torch.no_grad():
                centroid_loss = self._compute_centroid_loss(
                    all_obs_flat, all_acs_flat
                )
            # TODO: centroid returns are not measured yet; 0 is a placeholder.
            centroid_returns = 0
            with tabular.prefix(self.centroid.name):
                tabular.record("/AverageReturn", np.mean(centroid_returns))
                tabular.record("/Loss", centroid_loss.item())

        undiscounted_returns = log_performance(itr, all_eps, discount=self._discount)
        return np.mean(undiscounted_returns)

    def train(self, trainer):
        """Obtain samplers and start actual training for each epoch.

        Args:
            trainer (Trainer): Gives the algorithm the access to
                :method:`~Trainer.step_epochs()`, which provides services
                such as snapshotting and sampler control.

        Returns:
            float: The average return in last epoch cycle.

        """
        last_return = None

        for _ in trainer.step_epochs():
            for _ in range(self._n_samples):
                eps = trainer.obtain_episodes(trainer.step_itr)
                last_return = self._train_once(trainer.step_itr, eps)
                trainer.step_itr += 1

        return last_return

    def _train(self, all_obs, all_acs, obs, actions, rewards, returns, advs, policy_id):
        r"""Train one policy, its value function and (optionally) the centroid.

        Args:
            all_obs (torch.Tensor): Observations collected by all policies,
                used for the mutual-KL penalty and the centroid.
            all_acs (torch.Tensor): Actions collected by all policies, used as
                behavioral-cloning targets for the centroid.
            obs (torch.Tensor): Observation from the environment with shape
                :math:`(N, O*)`.
            actions (torch.Tensor): Actions fed to the environment with shape
                :math:`(N, A*)`.
            rewards (torch.Tensor): Acquired rewards with shape :math:`(N, )`.
            returns (torch.Tensor): Acquired returns with shape :math:`(N, )`.
            advs (torch.Tensor): Advantage value at each step with shape
                :math:`(N, )`.
            policy_id (int): Index of the policy / value function to train.

        """
        # Train the local policy on minibatches of its own samples, paired
        # with minibatches of all observations for the mutual-KL term.
        for all_obs_dataset, dataset in self._policy_optimizers[
            policy_id
        ].get_minibatch(all_obs, obs, actions, rewards, advs):
            self._train_policy(*all_obs_dataset, *dataset, policy_id)
        for dataset in self._vf_optimizers[policy_id].get_minibatch(obs, returns):
            self._train_value_function(*dataset, policy_id)

        if self._track_centroid:
            # NOTE(review): this runs once per local policy per iteration, so
            # the centroid sees all samples n_policies times per batch.
            for dataset in self._centroid_optimizer.get_minibatch(all_obs, all_acs):
                self._train_centroid(*dataset)

    def _train_policy(self, all_obs, obs, actions, rewards, advantages, policy_id):
        r"""Train the policy.

        Args:
            all_obs (torch.Tensor): Observations from all policies, used for
                the mutual-KL penalty.
            obs (torch.Tensor): Observation from the environment
                with shape :math:`(N, O*)`.
            actions (torch.Tensor): Actions fed to the environment
                with shape :math:`(N, A*)`.
            rewards (torch.Tensor): Acquired rewards
                with shape :math:`(N, )`.
            advantages (torch.Tensor): Advantage value at each step
                with shape :math:`(N, )`.
            policy_id (int): Index of the policy to train.

        Returns:
            torch.Tensor: Calculated mean scalar value of policy loss (float).

        """
        self._policy_optimizers[policy_id].zero_grad()
        loss = self._compute_loss_with_adv(
            all_obs, obs, actions, rewards, advantages, policy_id
        )
        loss.backward()
        self._policy_optimizers[policy_id].step()

        return loss

    def _train_centroid(self, all_obs, all_actions):
        """Train the centroid with a behavioral-cloning loss on all samples.

        Args:
            all_obs (torch.Tensor): Observations collected by all policies.
            all_actions (torch.Tensor): Corresponding actions (the targets).

        """
        self._centroid_optimizer.zero_grad()
        loss = self._compute_centroid_loss(all_obs, all_actions)
        loss.backward()
        self._centroid_optimizer.step()

    def _compute_centroid_loss(self, all_obs, all_actions):
        """Compute the MSE between taken actions and the centroid's mean action.

        Only works for policies whose distribution-info dict contains a
        ``"mean"`` entry (e.g. Gaussian policies).

        Args:
            all_obs (torch.Tensor): Observations collected by all policies.
            all_actions (torch.Tensor): Corresponding actions.

        Returns:
            torch.Tensor: Scalar mean-squared-error loss.

        """
        centroid_actions = self.centroid(all_obs)[1]["mean"]
        return self._mseloss(all_actions, centroid_actions)

    def _train_value_function(self, obs, returns, policy_id):
        r"""Train the value function.

        Args:
            obs (torch.Tensor): Observation from the environment
                with shape :math:`(N, O*)`.
            returns (torch.Tensor): Acquired returns
                with shape :math:`(N, )`.
            policy_id (int): Index of the value function to train.

        Returns:
            torch.Tensor: Calculated mean scalar value of value function loss
                (float).

        """
        self._vf_optimizers[policy_id].zero_grad()
        loss = self._value_functions[policy_id].compute_loss(obs, returns)
        loss.backward()
        self._vf_optimizers[policy_id].step()

        return loss

    def _compute_loss(
        self, obs, actions, rewards, valids, baselines, policy_id, all_obs=None
    ):
        r"""Compute mean value of loss.

        Notes: P is the maximum episode length (self.max_episode_length)

        Args:
            obs (torch.Tensor): Observation from the environment
                with shape :math:`(N, P, O*)`.
            actions (torch.Tensor): Actions fed to the environment
                with shape :math:`(N, P, A*)`.
            rewards (torch.Tensor): Acquired rewards
                with shape :math:`(N, P)`.
            valids (list[int]): Numbers of valid steps in each episode
            baselines (torch.Tensor): Value function estimation at each step
                with shape :math:`(N, P)`.
            policy_id (int): Index of the policy whose loss is computed.
            all_obs (torch.Tensor or None): Flat observations on which the
                mutual-KL penalty is evaluated. Defaults to the valid steps of
                ``obs``.

        Returns:
            torch.Tensor: Calculated negative mean scalar value of
                objective (float).

        """
        obs_flat = torch.cat(filter_valids(obs, valids))
        actions_flat = torch.cat(filter_valids(actions, valids))
        rewards_flat = torch.cat(filter_valids(rewards, valids))
        advantages_flat = self._compute_advantage(rewards, valids, baselines)
        if all_obs is None:
            all_obs = obs_flat

        return self._compute_loss_with_adv(
            all_obs, obs_flat, actions_flat, rewards_flat, advantages_flat, policy_id
        )

    def _compute_loss_with_adv(
        self, all_obs, obs, actions, rewards, advantages, policy_id
    ):
        r"""Compute mean value of loss.

        Args:
            all_obs (torch.Tensor): Observations from all policies on which
                the mutual-KL penalty is evaluated.
            obs (torch.Tensor): Observation from the environment
                with shape :math:`(N \dot [T], O*)`.
            actions (torch.Tensor): Actions fed to the environment
                with shape :math:`(N \dot [T], A*)`.
            rewards (torch.Tensor): Acquired rewards
                with shape :math:`(N \dot [T], )`.
            advantages (torch.Tensor): Advantage value at each step
                with shape :math:`(N \dot [T], )`.
            policy_id (int): Index of the policy whose loss is computed.

        Returns:
            torch.Tensor: Calculated negative mean scalar value of objective.

        """
        objectives = self._compute_objective(
            advantages, obs, actions, rewards, policy_id
        )

        if self._entropy_regularzied:
            policy_entropies = self._compute_policy_entropy(obs, policy_id)
            objectives += self._policy_ent_coeff * policy_entropies

        # Penalize divergence between this policy and the other local ones.
        kl = self._compute_mutual_kl(all_obs, policy_id)
        objectives -= self._kl_coeff * kl

        return -objectives.mean()

    def _compute_advantage(self, rewards, valids, baselines):
        r"""Compute (optionally normalized) advantages for the valid steps.

        Notes: P is the maximum episode length (self.max_episode_length)

        Args:
            rewards (torch.Tensor): Acquired rewards
                with shape :math:`(N, P)`.
            valids (list[int]): Numbers of valid steps in each episode
            baselines (torch.Tensor): Value function estimation at each step
                with shape :math:`(N, P)`.

        Returns:
            torch.Tensor: Calculated advantage values given rewards and
                baselines with shape :math:`(N \dot [T], )`.

        """
        advantages = compute_advantages(
            self._discount,
            self._gae_lambda,
            self.max_episode_length,
            baselines,
            rewards,
        )
        advantage_flat = torch.cat(filter_valids(advantages, valids))

        if self._center_adv:
            # Standardize to mean 0 / standard deviation 1 (not variance 1).
            advantage_flat = advantage_flat - advantage_flat.mean()
            # The unbiased std of a single element is NaN; only center it.
            if advantage_flat.numel() > 1:
                advantage_flat = advantage_flat / (advantage_flat.std() + 1e-8)

        if self._positive_adv:
            advantage_flat -= advantage_flat.min()

        return advantage_flat

    def _compute_kl_constraint(self, obs, policy_id):
        r"""Compute KL divergence.

        Compute the KL divergence between the old policy distribution and
        current policy distribution.

        Notes: P is the maximum episode length (self.max_episode_length)

        Args:
            obs (torch.Tensor): Observation from the environment
                with shape :math:`(N, P, O*)`.
            policy_id (int): Index of the policy.

        Returns:
            torch.Tensor: Calculated mean scalar value of KL divergence
                (float).

        """
        with torch.no_grad():
            old_dist = self._old_policies[policy_id](obs)[0]

        new_dist = self.policies[policy_id](obs)[0]

        kl_constraint = torch.distributions.kl.kl_divergence(old_dist, new_dist)

        return kl_constraint.mean()

    def _compute_policy_entropy(self, obs, policy_id):
        r"""Compute entropy value of probability distribution.

        Notes: P is the maximum episode length (self.max_episode_length)

        Args:
            obs (torch.Tensor): Observation from the environment
                with shape :math:`(N, P, O*)`.
            policy_id (int): Index of the policy.

        Returns:
            torch.Tensor: Calculated entropy values given observation
                with shape :math:`(N, P)`.

        """
        if self._stop_entropy_gradient:
            with torch.no_grad():
                policy_entropy = self.policies[policy_id](obs)[0].entropy()
        else:
            policy_entropy = self.policies[policy_id](obs)[0].entropy()

        # This prevents entropy from becoming negative for small policy std
        if self._use_softplus_entropy:
            policy_entropy = F.softplus(policy_entropy)

        return policy_entropy

    def _compute_objective(self, advantages, obs, actions, rewards, policy_id):
        r"""Compute objective value.

        Args:
            advantages (torch.Tensor): Advantage value at each step
                with shape :math:`(N \dot [T], )`.
            obs (torch.Tensor): Observation from the environment
                with shape :math:`(N \dot [T], O*)`.
            actions (torch.Tensor): Actions fed to the environment
                with shape :math:`(N \dot [T], A*)`.
            rewards (torch.Tensor): Acquired rewards
                with shape :math:`(N \dot [T], )`. Unused.
            policy_id (int): Index of the policy.

        Returns:
            torch.Tensor: Calculated objective values
                with shape :math:`(N \dot [T], )`.

        """
        del rewards
        log_likelihoods = self.policies[policy_id](obs)[0].log_prob(actions)

        return log_likelihoods * advantages

    def _compute_mutual_kl(self, all_obs, policy_id):
        """Sum the KL divergences from the other local policies to one policy.

        For every local policy ``j != policy_id`` this adds the mean over
        ``all_obs`` of KL(pi_j || pi_policy_id).

        Args:
            all_obs (torch.Tensor): Observations on which to evaluate the
                policies.
            policy_id (int): Index of the policy being compared against.

        Returns:
            torch.Tensor: Summed KL divergence with shape :math:`(1, )`
                (zero when there is only one policy).

        """
        dists = [local_policy(all_obs)[0] for local_policy in self.policies]
        policy_dist = dists[policy_id]

        kl = torch.Tensor([0])
        for j, dist in enumerate(dists):
            if j != policy_id:
                kl += torch.distributions.kl.kl_divergence(dist, policy_dist).mean()

        return kl
