import numpy as np
import torch
# https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail/blob/master/a2c_ppo_acktr/model.py
from train.reinforcment_learning.utils.utils import scale_grad_norm_


class PPO:
    """PPO (clipped surrogate objective) trainer for an actor-critic model.

    The model is evaluated through ``dataset.ACTION_SPACE.evaluate_actions``.
    Two kinds of update are provided:

    * :meth:`update` -- joint actor + critic step with ``self.optimizer``
      (all model parameters).
    * :meth:`update_value_head` -- critic-only step with
      ``self.optimizer_critic_head`` (parameters not returned by
      ``model.policy_parameters()``).
    """

    # Environment whose action space only has binary + camera heads; every other
    # environment is expected to also provide equip/place/craft/nearbyCraft/nearbySmelt.
    TREECHOP_ENV = 'MineRLTreechop-v0'

    @staticmethod
    def compute_gradient_norm(params):
        """Return the global L2 norm of the gradients of ``params``.

        Parameters with zero elements or without a gradient are skipped.

        :param params: iterable of tensors / parameters
        :return: python float ``sqrt(sum_p ||p.grad||_2 ** 2)``; ``0.0`` if no gradient exists
        """
        total_norm = 0.
        for p in params:
            if p.numel() > 0 and p.grad is not None:
                total_norm += p.grad.detach().norm(2).item() ** 2
        return total_norm ** 0.5

    @staticmethod
    def unique_parameters(params):
        """Return ``params`` as a list with duplicates removed.

        The order of first occurrence is preserved (tensors hash by identity).
        The result is handed to optimizers, which need a deterministic parameter
        order; otherwise optimizer state dicts cannot be reloaded reliably.
        """
        return list(dict.fromkeys(params))

    def __init__(self, model, dataset, gamma=1., gae_lambda=1.,
                 eps_clip=.1, pg_coef=1., vf_coef=0.25, ent_coef_bin=0., ent_coef_cam=0.,
                 lrs=1e-3, betas=(0.9, 0.999), clamping=False, clip_grad_norm=None,
                 bound_value_loss=False, value_clip=None, lrs_critic_head=1e-4, scale_critic_norm=False, envname=None,
                 device=None):
        """
        :param model: actor-critic ``torch.nn.Module``; must also provide ``policy_parameters()``
        :param dataset: object exposing ``ACTION_SPACE`` (with ``evaluate_actions``) and ``INPUT_SPACE``
        :param gamma: discount factor (stored only; not used by the update methods of this class)
        :param gae_lambda: GAE lambda (stored only; not used by the update methods of this class)
        :param eps_clip: PPO clip range; ratios are clamped to ``[1 - eps_clip, 1 + eps_clip]``
        :param pg_coef: weight of the policy loss
        :param vf_coef: weight of the value loss; ``<= 0`` disables the value loss in :meth:`update`
        :param ent_coef_bin: entropy-bonus weight of the binary actions
        :param ent_coef_cam: entropy-bonus weight of the camera actions
        :param lrs: learning rate of the joint actor-critic Adam optimizer
        :param betas: Adam betas used by both optimizers
        :param clamping: ``"sum"`` clips the joint ratio of all components, ``"separate"`` clips each
            action head independently, any other value clips every log-prob column independently
        :param clip_grad_norm: max global gradient norm for the joint update; ``None`` disables clipping
        :param bound_value_loss: if True, use the clipped value loss (requires ``value_clip``)
        :param value_clip: max change of the value prediction w.r.t. the old values when
            ``bound_value_loss`` is set
        :param lrs_critic_head: learning rate of the critic-head optimizer; for pre-training the
            value function only (:meth:`update_value_head`)
        :param scale_critic_norm: rescale the critic-head gradients to the actor gradient norm
        :param envname: environment name; ``'MineRLTreechop-v0'`` selects the reduced action set
        :param device: torch device used for all mini-batch tensors
        """

        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.eps_clip = eps_clip
        self.device = device

        self.envname = envname

        self.eps_clip_initial = eps_clip
        self.ent_coef_bin_initial = ent_coef_bin
        self.ent_coef_cam_initial = ent_coef_cam
        self.lrs = lrs
        self.lrs_initial = lrs
        self.lrs_critic_head = lrs_critic_head
        self.clip_grad_norm = clip_grad_norm
        self.scale_critic_norm = scale_critic_norm

        self.pg_coef = pg_coef
        self.ent_coef_bin = ent_coef_bin
        self.ent_coef_cam = ent_coef_cam
        self.vf_coef = vf_coef

        self.model = model
        self.mse = torch.nn.MSELoss()

        self._freeze_batch_norm()

        # deterministic (ordered) parameter lists -- see unique_parameters
        self.actor_critic_params = self.unique_parameters(self.model.parameters())
        self.actor_params = self.unique_parameters(self.model.policy_parameters())
        actor_param_set = set(self.actor_params)
        self.critic_head_params = [p for p in self.actor_critic_params if p not in actor_param_set]

        self.optimizer = torch.optim.Adam(self.actor_critic_params, lr=self.lrs, betas=betas)
        self.optimizer_critic_head = torch.optim.Adam(self.critic_head_params, lr=self.lrs_critic_head, betas=betas)

        self.action_space = dataset.ACTION_SPACE
        self.input_space = dataset.INPUT_SPACE
        self.value_clip = value_clip
        self.clamping = clamping
        self.bound_value_loss = bound_value_loss

    def _make_input_dict(self, mb_pov, mb_bin_actions, mb_camera_actions):
        """Convert numpy mini-batch arrays into the model's input dict on ``self.device``."""
        return {"pov": torch.from_numpy(mb_pov).to(self.device).detach(),
                "binary_actions": torch.from_numpy(mb_bin_actions).to(self.device).detach(),
                "camera_actions": torch.from_numpy(mb_camera_actions).to(self.device).detach()}

    def _evaluate_actions(self, input_dict, mb_actions):
        """Run the model and evaluate ``mb_actions`` under the current policy.

        :return: ``(value, action_log_probs, camera_log_probs, extra_log_probs,
            action_entropy, camera_entropy)``. ``extra_log_probs`` is the list
            ``[equip, place, craft, nearbyCraft, nearbySmelt]`` of log-probs, or an
            empty list for the Treechop environment.
        """
        out_dict = self.model(input_dict)
        if self.envname == self.TREECHOP_ENV:
            value, action_log_probs, camera_log_probs, action_entropy, camera_entropy = \
                self.action_space.evaluate_actions(out_dict, mb_actions)
            return value, action_log_probs, camera_log_probs, [], action_entropy, camera_entropy

        (value, action_log_probs, camera_log_probs,
         equip_log_probs, place_log_probs, craft_log_probs,
         nearby_craft_log_probs, nearby_smelt_log_probs,
         action_entropy, camera_entropy,
         _equip_entropy, _place_entropy, _craft_entropy,
         _nearby_craft_entropy, _nearby_smelt_entropy) = self.action_space.evaluate_actions(out_dict, mb_actions)
        extra_log_probs = [equip_log_probs, place_log_probs, craft_log_probs,
                           nearby_craft_log_probs, nearby_smelt_log_probs]
        return value, action_log_probs, camera_log_probs, extra_log_probs, action_entropy, camera_entropy

    def update_value_head(self, mb_returns, mb_pov, mb_bin_actions, mb_camera_actions, mb_actions):
        """Take one gradient step on the critic head only (squared error to ``mb_returns``).

        Only ``self.critic_head_params`` are stepped. Gradients are still
        back-propagated into the rest of the model, but they are discarded by the
        ``zero_grad`` at the start of :meth:`update`.

        :return: ``(value_loss, camera_entropy, action_entropy)`` as numpy arrays;
            ``value_loss`` is the per-sample squared error
        """
        input_dict = self._make_input_dict(mb_pov, mb_bin_actions, mb_camera_actions)
        value, _, _, _, action_entropy, camera_entropy = self._evaluate_actions(input_dict, mb_actions)

        mb_returns = torch.from_numpy(mb_returns).float().to(self.device).detach()
        value_loss = (value.squeeze() - mb_returns) ** 2

        self.optimizer_critic_head.zero_grad()
        value_loss.mean().backward()
        self.optimizer_critic_head.step()

        return (value_loss.detach().cpu().numpy(),
                camera_entropy.detach().cpu().numpy(),
                action_entropy.detach().cpu().numpy())

    def update(self, mb_returns, mb_advs, mb_logprobs, mb_pov, mb_bin_actions, mb_camera_actions, mb_actions, mb_values):
        """Take one joint PPO step on actor and critic.

        :param mb_returns: numpy returns, one per sample
        :param mb_advs: ignored; advantages are recomputed as ``mb_returns - mb_values`` and normalized
        :param mb_logprobs: numpy log-probs of ``mb_actions`` under the behaviour policy
        :param mb_pov, mb_bin_actions, mb_camera_actions: numpy model inputs
        :param mb_actions: actions passed to ``ACTION_SPACE.evaluate_actions``
        :param mb_values: numpy value predictions of the behaviour policy
        :return: tuple of ``(loss, policy_loss averaged over the batch, policy_loss_mean,
            entropy_loss, value_loss, surrogate_ratio_loss averaged over the batch,
            surrogate_clamp averaged over the batch, camera_entropy, action_entropy,
            total_norm, actor_norm, critic_norm)``. The tensors are returned as numpy arrays.
            The three norms are floats measured before critic scaling and clipping.
        """
        input_dict = self._make_input_dict(mb_pov, mb_bin_actions, mb_camera_actions)

        mb_logprobs = torch.from_numpy(mb_logprobs).to(self.device).detach()
        # squeeze once so (B, 1) and (B,) value arrays behave identically below
        mb_values = torch.from_numpy(mb_values).float().to(self.device).detach().squeeze()
        mb_returns = torch.from_numpy(mb_returns).float().to(self.device).detach()

        # new log-probs / entropies of the taken actions under the current policy
        value, action_log_probs, camera_log_probs, extra_log_probs, action_entropy, camera_entropy = \
            self._evaluate_actions(input_dict, mb_actions)

        # advantages from the old (behaviour) values, normalized per mini-batch
        advs = mb_returns - mb_values
        advs = (advs - advs.mean()) / (advs.std() + 1e-8)

        # NOTE(review): the "sum" and default branches only use the binary + camera log-probs;
        # for non-Treechop envs mb_logprobs has more columns (see "separate"), so they would
        # not broadcast -- confirm these modes are only used with Treechop.
        if self.clamping == "sum":
            logprob_action = torch.cat((action_log_probs, camera_log_probs), 1)
            # clip the joint probability ratio: the joint log-prob is the sum of the
            # component log-probs, so ratio = exp(sum(new - old))
            ratio = torch.exp(torch.sum(logprob_action - mb_logprobs, dim=1))
            surrogate_ratio_loss = -advs * ratio
            surrogate_clamp = -torch.clamp(ratio, 1.0 - self.eps_clip, 1.0 + self.eps_clip) * advs
            # max of the negated objectives == min of the objectives (pessimistic bound)
            policy_loss = self.pg_coef * torch.max(surrogate_clamp, surrogate_ratio_loss)
            policy_loss_mean = policy_loss

        elif self.clamping == "separate":
            # advantages as a column (B, 1) so they broadcast against per-head log-probs
            advs = advs.unsqueeze(dim=1)

            # (new, old) log-probs per independently clipped head. The binary actions
            # (columns 0:8) are summed into one joint log-prob (OpenAI implementation);
            # the camera keeps its two columns (8:10).
            heads = [
                (torch.sum(action_log_probs, dim=-1, keepdim=True),
                 torch.sum(mb_logprobs[:, :8], dim=-1, keepdim=True)),
                (camera_log_probs, mb_logprobs[:, 8:10]),
            ]
            # equip, place, craft, nearbyCraft, nearbySmelt occupy columns 10..14
            for col, new_lp in enumerate(extra_log_probs, start=10):
                heads.append((new_lp, mb_logprobs[:, col:col + 1]))

            surrogate_ratio_loss = 0
            surrogate_clamp = 0
            policy_loss = 0
            for new_lp, old_lp in heads:
                ratio = torch.exp(new_lp - old_lp)
                head_ratio_obj = ratio * advs
                head_clamp_obj = torch.clamp(ratio, 1.0 - self.eps_clip, 1.0 + self.eps_clip) * advs
                # multi-column heads (camera) are averaged so every head adds one term
                surrogate_ratio_loss = surrogate_ratio_loss + head_ratio_obj.mean(dim=1, keepdim=True)
                surrogate_clamp = surrogate_clamp + head_clamp_obj.mean(dim=1, keepdim=True)
                # pessimistic clipped objective per head, negated into a loss
                policy_loss = policy_loss - torch.min(head_clamp_obj, head_ratio_obj).mean(dim=1)
            # heads are summed (independence assumption between action heads)
            policy_loss = self.pg_coef * policy_loss
            policy_loss_mean = policy_loss.mean()

        else:
            # clip every column of the concatenated binary + camera log-probs independently
            logprob_action = torch.cat((action_log_probs, camera_log_probs), 1)
            ratio = torch.exp(logprob_action - mb_logprobs)
            advs_col = advs.unsqueeze(dim=1)
            surrogate_ratio_loss = -advs_col * ratio
            surrogate_clamp = -torch.clamp(ratio, 1.0 - self.eps_clip, 1.0 + self.eps_clip) * advs_col
            policy_loss = self.pg_coef * torch.max(surrogate_clamp, surrogate_ratio_loss)
            # mean across action dimensions (independence assumption)
            policy_loss_mean = policy_loss.mean(dim=1)

        # entropy bonus (out-of-place so per-sample entropies broadcast correctly)
        entropy_loss = torch.tensor(0, dtype=torch.float32, device=self.device)
        if self.ent_coef_cam != 0:
            entropy_loss = entropy_loss - self.ent_coef_cam * camera_entropy
        if self.ent_coef_bin != 0:
            entropy_loss = entropy_loss - self.ent_coef_bin * action_entropy

        # value loss, optionally bounded around the old values (clipped value loss)
        if self.vf_coef > 0:
            value_pred = value.squeeze()
            value_loss = self.vf_coef * (value_pred - mb_returns) ** 2
            if self.bound_value_loss:
                vpredclipped = mb_values + torch.clamp(value_pred - mb_values, -self.value_clip, self.value_clip)
                value_loss = torch.max(value_loss, self.vf_coef * (vpredclipped - mb_returns) ** 2)
        else:
            value_loss = torch.tensor(0, dtype=torch.float32, device=self.device)

        # joint policy and value update
        self.optimizer.zero_grad()
        loss = (policy_loss_mean + entropy_loss + value_loss).mean()
        loss.backward()

        total_norm = self.compute_gradient_norm(self.actor_critic_params)
        actor_norm = self.compute_gradient_norm(self.actor_params)
        critic_norm = self.compute_gradient_norm(self.critic_head_params)

        # scale critic gradient norm to the actor's; skip when the critic has no gradient
        if self.scale_critic_norm and critic_norm > 0:
            scale_grad_norm_(self.critic_head_params, scale=actor_norm / critic_norm)

        if self.clip_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(self.actor_critic_params, self.clip_grad_norm)

        self.optimizer.step()

        return loss.mean().cpu().detach().numpy(), policy_loss.mean(dim=0).cpu().detach().numpy(), \
            policy_loss_mean.cpu().detach().numpy(), entropy_loss.cpu().detach().numpy(), \
            value_loss.cpu().detach().numpy(), surrogate_ratio_loss.mean(dim=0).cpu().detach().numpy(), \
            surrogate_clamp.mean(dim=0).cpu().detach().numpy(), camera_entropy.cpu().detach().numpy(), \
            action_entropy.cpu().detach().numpy(), total_norm, actor_norm, critic_norm

    def _freeze_batch_norm(self):
        """Put every BatchNorm1d/BatchNorm2d submodule of the model into eval mode.

        Only the module mode changes (running statistics stop updating).
        A later ``model.train()`` call switches these layers back to training mode.
        """

        def set_bn_eval(m):
            class_name = m.__class__.__name__
            if class_name.find('BatchNorm2d') != -1 or class_name.find('BatchNorm1d') != -1:
                m.eval()

        self.model.apply(set_bn_eval)
