import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import print, explained_variance

def evaluate_actions(act_logits, actions):
    """Score ``actions`` under the categorical policy defined by ``act_logits``.

    Args:
        act_logits: unnormalised log-probabilities; the last dimension indexes
            the discrete actions.
        actions: integer action indices, one per distribution in the batch.

    Returns:
        ``(log_probs, entropies)``: the log-probability of each taken action
        and the entropy of each distribution.
    """
    policy = torch.distributions.Categorical(logits=act_logits)
    log_probs = policy.log_prob(actions)
    entropies = policy.entropy()
    return log_probs, entropies

def minibatch_generator(num_batches, *args):
    """Yield random, disjoint minibatches drawn in lockstep from ``args``.

    Every tensor in ``args`` is indexed with the same random permutation, so
    the i-th rows of all tensors stay aligned within each minibatch. The last
    minibatch absorbs any remainder when the length is not divisible.

    Args:
        num_batches: requested number of minibatches. Values <= 1 yield
            ``args`` unchanged as a single batch. Values larger than the
            number of samples are capped at the sample count, so no
            minibatch is ever empty.
        *args: indexable tensors sharing the same length along dim 0.

    Yields:
        Tuples with one entry per element of ``args``.
    """
    if num_batches > 1:
        n = len(args[0])
        # With num_batches > n the floor division below would give a batch
        # size of 0. Every batch but the last would then be empty, and the
        # caller's ``.mean()`` over an empty batch is NaN.
        num_batches = min(num_batches, n)
    if num_batches <= 1:
        yield args
        return

    batch_size = n // num_batches
    perm = torch.randperm(n)
    for i in range(num_batches):
        # The final batch runs to the end so leftover samples are not dropped.
        stop = None if i == num_batches - 1 else (i + 1) * batch_size
        idx = perm[i * batch_size:stop]
        yield tuple(a[idx] for a in args)

class A2C:
    """Advantage actor-critic learner.

    Holds a model, its optimizer and the loss coefficients. :meth:`update`
    takes one gradient step per minibatch of a collected rollout.
    """

    def __init__(self, model, args, compute_returns):
        """Build the learner and its optimizer.

        Args:
            model: module whose forward pass returns ``(act_logits, values)``.
            args: config namespace. Reads ``value_coef``, ``entropy_coef``,
                ``max_grad_norm``, ``opt`` (``'rmsprop'`` or ``'adam'``),
                ``lr``, ``eps``, ``alpha``, ``beta`` (Adam only) and
                ``anneal`` (warm-up length in updates; falsy disables it).
            compute_returns: callable invoked as
                ``compute_returns(states, returns, masks, model)`` at the
                start of every :meth:`update`.

        Raises:
            ValueError: if ``args.opt`` is not a supported optimizer name.
        """
        self.model = model
        self.value_coef = args.value_coef
        self.entropy_coef = args.entropy_coef
        self.max_grad_norm = args.max_grad_norm

        params = model.parameters()
        if args.opt == 'rmsprop':
            self.optimizer = torch.optim.RMSprop(
                params, args.lr, eps=args.eps, alpha=args.alpha)
        elif args.opt == 'adam':
            # ``args.alpha`` doubles as Adam's second-moment decay (beta2).
            self.optimizer = torch.optim.Adam(
                params, args.lr, eps=args.eps, betas=(args.beta, args.alpha))
        else:
            # Fail fast. Otherwise ``self.optimizer`` would stay unset and the
            # mistake would only surface as an AttributeError on the first update.
            raise ValueError(
                f"unsupported optimizer {args.opt!r}; expected 'rmsprop' or 'adam'")

        self.compute_returns = compute_returns
        # Number of updates over which the LR ramps linearly up to ``init_lr``.
        self.anneal = args.anneal or 0
        self.num_updates = 0
        self.init_lr = args.lr

    def set_lr(self, lr):
        """Set the learning rate of every optimizer parameter group to ``lr``."""
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def update(self, states, actions, returns, masks, k=1):
        """Run one A2C update over a rollout, split into ``k`` minibatches.

        Args:
            states: tensor of shape ``(T + 1, B, *obs_shape)``. The final time
                step is excluded from the loss.
            actions: tensor holding ``T * B`` action indices.
            returns: tensor holding ``T * B`` targets. It is passed to
                ``compute_returns`` before being flattened.
            masks: tensor whose first ``T`` time steps (``masks[:-1]``) hold
                ``T * B`` entries. Only samples with a nonzero mask
                contribute to the loss.
            k: number of minibatches (one optimizer step each).

        Returns:
            ``(action_loss, value_loss, entropy, grad_norm)`` for the *last*
            minibatch. The first three are Python floats; ``grad_norm`` is the
            tensor returned by ``clip_grad_norm_``.
        """
        self.num_updates += 1
        if self.num_updates <= self.anneal:
            # Linear warm-up: init_lr / anneal, 2 * init_lr / anneal, ..., init_lr.
            self.set_lr(self.init_lr * self.num_updates / self.anneal)
        # The return value is unused here, so the call matters only for its
        # side effects. NOTE(review): presumably it fills ``returns`` in place; confirm.
        self.compute_returns(states, returns, masks, self.model)

        T, B, *obs_shape = states.shape
        T -= 1  # states carry one more time step than actions/returns
        valid = masks[:-1].view(T * B).to(torch.bool)
        flat_states = states[:-1].view(T * B, *obs_shape)[valid]
        flat_actions = actions.view(T * B)[valid]
        flat_returns = returns.view(T * B)[valid]

        batches = minibatch_generator(k, flat_states, flat_actions, flat_returns)
        for mb_states, mb_actions, mb_returns in batches:
            act_logits, values = self.model(mb_states)
            act_log_probs, entropy = evaluate_actions(act_logits, mb_actions)

            # Align the value head with the flat returns. A ``(N, 1)`` head
            # would otherwise broadcast against ``(N,)`` returns into an
            # ``(N, N)`` advantage matrix. This is a no-op when shapes agree.
            values = values.reshape_as(mb_returns)
            advantages = mb_returns - values

            # Detach for the policy term so the actor loss does not send
            # gradients into the value estimate. Advantages are not normalised.
            action_loss = -(advantages.detach() * act_log_probs).mean()
            value_loss = advantages.pow(2).mean()
            entropy = entropy.mean()

            self.optimizer.zero_grad()
            loss = (action_loss + self.value_coef * value_loss
                    - self.entropy_coef * entropy)
            loss.backward()
            grad_norm = nn.utils.clip_grad_norm_(
                self.model.parameters(), self.max_grad_norm)
            self.optimizer.step()

        return action_loss.item(), value_loss.item(), entropy.item(), grad_norm

