import torch
import torch.nn as nn
import torch.optim as optim

from pytorch_rl.a2c_ppo_acktr.algo.kfac import KFACOptimizer


class A2C_ACKTR:
    """Actor-critic update step, optimised with RMSprop (A2C) or K-FAC (ACKTR).

    Value and policy losses are masked by the ``'valid_action_mask'`` entry of
    each observation dict and averaged over the number of valid entries.
    The entropy term is used exactly as returned by
    ``actor_critic.evaluate_actions`` (it is not masked here).
    """

    def __init__(self,
                 actor_critic,
                 value_loss_coef,
                 entropy_coef,
                 lr=None,
                 eps=None,
                 alpha=None,
                 max_grad_norm=None,
                 acktr=False):
        """
        Args:
            actor_critic: module exposing ``parameters()``, ``zero_grad()`` and
                ``evaluate_actions(obs_batch, masks_batch, actions_batch)``
                returning ``(values, action_log_probs, dist_entropy)``.
            value_loss_coef: weight of the value loss in the total loss.
            entropy_coef: weight of the entropy bonus (subtracted from the loss).
            lr, eps, alpha: RMSprop hyper-parameters (A2C only). ``None``
                falls back to the ``torch.optim.RMSprop`` default.
            max_grad_norm: gradient-norm clipping threshold (A2C only).
                ``None`` disables clipping.
            acktr: if True, optimise with ``KFACOptimizer`` instead of RMSprop.
        """
        self.actor_critic = actor_critic
        self.acktr = acktr

        self.value_loss_coef = value_loss_coef
        self.entropy_coef = entropy_coef

        self.max_grad_norm = max_grad_norm

        if acktr:
            self.optimizer = KFACOptimizer(actor_critic)
        else:
            # Forward only explicitly-set hyper-parameters: passing None
            # straight to RMSprop fails its range checks with a TypeError.
            rmsprop_kwargs = {
                name: value
                for name, value in (('lr', lr), ('eps', eps), ('alpha', alpha))
                if value is not None
            }
            self.optimizer = optim.RMSprop(actor_critic.parameters(),
                                           **rmsprop_kwargs)

    @staticmethod
    def _flatten_obs(obs_list, batch_size):
        """Stack per-step observation dicts into one dict of batch-first tensors."""
        obs_batch = {}
        for key in obs_list[0]:
            # Stacking yields (num_steps, num_processes, ...); merge the two
            # leading dims into a single batch dimension.
            flat_shape = [batch_size] + list(obs_list[0][key].shape)[1:]
            obs_batch[key] = torch.stack([o[key] for o in obs_list]).view(flat_shape)
        return obs_batch

    def _accumulate_fisher_stats(self, values, action_log_probs):
        """Back-propagate sampled Fisher losses so K-FAC collects statistics (see Martens 2014)."""
        self.actor_critic.zero_grad()
        pg_fisher_loss = -action_log_probs.mean()

        # Targets sampled from a unit Gaussian centred on the predicted values.
        value_noise = torch.randn_like(values)
        sample_values = values + value_noise
        vf_fisher_loss = -(values - sample_values.detach()).pow(2).mean()

        fisher_loss = pg_fisher_loss + vf_fisher_loss
        self.optimizer.acc_stats = True
        # retain_graph: the main loss is back-propagated through the same graph next.
        fisher_loss.backward(retain_graph=True)
        self.optimizer.acc_stats = False

    def update(self, rollouts):
        """Run one optimisation step on a completed rollout.

        Args:
            rollouts: storage exposing ``obs`` (a sequence of num_steps + 1
                dicts of tensors, each with leading dim num_processes and
                including ``'valid_action_mask'``), plus ``actions``,
                ``masks``, ``rewards`` and ``returns`` tensors.

        Returns:
            ``(value_loss, action_loss, dist_entropy)`` as Python floats.
        """
        num_steps, num_processes = rollouts.rewards.size()[0:2]
        batch_size = num_processes * num_steps

        obs_batch = self._flatten_obs(rollouts.obs[:-1], batch_size)
        actions_batch = rollouts.actions.view(batch_size, -1)
        masks_batch = rollouts.masks[:-1].view(batch_size, -1)
        values, action_log_probs, dist_entropy = self.actor_critic.evaluate_actions(
            obs_batch, masks_batch, actions_batch)

        valid_action_mask = obs_batch['valid_action_mask'].view(num_steps, num_processes, -1)
        values = values.view(num_steps, num_processes, -1)
        action_log_probs = action_log_probs.view(num_steps, num_processes, -1)

        # Only entries with a valid action contribute to the losses.
        advantages = (rollouts.returns[:-1] - values) * valid_action_mask
        # Average over valid entries. clamp(min=1) guards an all-invalid batch
        # without the n/(n+1) bias a "+ 1" would introduce.
        valid_num = valid_action_mask.sum().clamp(min=1)
        value_loss = advantages.pow(2).sum() / valid_num
        action_loss = -(advantages.detach() * action_log_probs).sum() / valid_num

        if self.acktr and self.optimizer.steps % self.optimizer.Ts == 0:
            self._accumulate_fisher_stats(values, action_log_probs)

        self.optimizer.zero_grad()
        (value_loss * self.value_loss_coef + action_loss -
         dist_entropy * self.entropy_coef).backward()

        # K-FAC handles step scaling itself; clip only for RMSprop and only
        # when a threshold was configured.
        if not self.acktr and self.max_grad_norm is not None:
            nn.utils.clip_grad_norm_(self.actor_critic.parameters(),
                                     self.max_grad_norm)

        self.optimizer.step()

        return value_loss.item(), action_loss.item(), dist_entropy.item()
