import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from utils import print, explained_variance

def evaluate_actions(act_logits, raw_actions):
    """Return per-sample log-probabilities and entropies of the given actions.

    Two kinds of policy head are supported, selected by the type of
    ``act_logits``:

    * a ``(mean, log_std)`` tuple -- continuous actions. ``raw_actions`` are
      the pre-squash Gaussian samples ``x``; the log-probability is that of
      the tanh-squashed action ``a = tanh(x)``.
    * anything else -- treated as the logits of a categorical distribution
      over the last dimension (discrete actions).

    Change of variables for the tanh squashing::

        a = tanh(x)  =>  log p(a) = log p(x) - log |tanh'(x)|
        tanh'(x)     =  1 - tanh(x)^2 = 4 / (e^x + e^-x)^2
        log tanh'(x) =  log(4) - 2*|x| - 2*log(1 + e^(-2*|x|))

    The last form is used because it stays finite for large ``|x|``, where
    ``1 - tanh(x)^2`` underflows to zero.

    Returns:
        ``(log_prob, entropy)``. In the continuous case both are summed over
        the last (action) dimension, and the entropy is that of the
        underlying Gaussian rather than of the squashed distribution (in
        general H(f(x)) <= H(x), so this is an upper bound).
    """
    if not isinstance(act_logits, tuple):
        # discrete action space
        categorical = torch.distributions.Categorical(logits=act_logits)
        flat_actions = raw_actions.view(act_logits.shape[:-1])
        return categorical.log_prob(flat_actions), categorical.entropy()

    # continuous action space; the second element is exponentiated, i.e. it
    # is consumed as a log standard deviation.
    mean, log_std = act_logits
    gaussian = torch.distributions.Normal(mean, log_std.exp())
    pre_tanh = raw_actions.view_as(mean)

    # -2|x|, shared by the linear term and the log1p term of log tanh'(x).
    neg_two_abs = pre_tanh.abs() * (-2.)
    log_tanh_grad = np.log(4) + neg_two_abs - 2 * torch.log1p(neg_two_abs.exp())

    log_prob = gaussian.log_prob(pre_tanh) - log_tanh_grad
    return log_prob.sum(-1), gaussian.entropy().sum(-1)

def evaluate_values(states, model):
    """Run the model's value path over a batch of states with arbitrary leading dims.

    ``states`` is converted to a tensor, flattened to ``(-1, feature_dim)``
    where ``feature_dim`` is its last dimension, passed to ``model(..., av=2)``,
    and the result is reshaped back to the original leading dimensions.
    """
    states = torch.as_tensor(states)
    leading_dims, feature_dim = states.shape[:-1], states.shape[-1]
    flat_states = states.view(-1, feature_dim)
    flat_values = model(flat_states, av=2)
    return flat_values.view(leading_dims)


# if args.gae:
# It's fine to set cumm_returns == returns for inplace computation
@torch.no_grad()
def compute_returns_gae(states, returns, masks, model, args, cumm_returns):
    """Fill ``cumm_returns`` with GAE(lambda) return targets, walking backwards in time.

    Args:
        states: states to evaluate; indexed up to ``returns.shape[0]`` along
            the first axis, so it needs one more step than ``returns``.
        returns: per-step values added to the TD error at each step
            (presumably the per-step rewards -- confirm against the caller).
        masks: indexed at ``step + 1``; multiplies both the bootstrapped next
            value and the carried-over advantage (presumably 0 where an
            episode ended -- confirm against the caller).
        model: passed through to ``evaluate_values``.
        args: namespace providing ``gamma`` and ``gae_lambda``.
        cumm_returns: output buffer, written in place one step at a time.
            Each step reads ``returns[step]`` before writing
            ``cumm_returns[step]``, so it may alias ``returns``.
    """
    values = evaluate_values(states, model)
    discount = args.gamma
    trace_decay = args.gamma * args.gae_lambda
    advantage = 0
    for t in range(returns.shape[0] - 1, -1, -1):
        keep = masks[t + 1]
        td_error = returns[t] + discount * values[t + 1] * keep - values[t]
        advantage = td_error + trace_decay * keep * advantage
        cumm_returns[t] = advantage + values[t]
# else:
@torch.no_grad()
def compute_returns(states, returns, masks, model, gamma, cumm_returns):
    """Fill ``cumm_returns`` with discounted returns bootstrapped from the last state's value.

    Args:
        states: states tensor; only ``states[-1]`` is evaluated, to obtain the
            bootstrap value.
        returns: per-step values added at each step (presumably the per-step
            rewards -- confirm against the caller).
        masks: indexed at ``step + 1``; multiplies the discounted value carried
            over from the following step.
        model: passed through to ``evaluate_values``.
        gamma: the discount factor, or an object exposing it as ``.gamma``.
            ``A2C_PPO`` calls this function with its whole ``args`` namespace
            in this position (same call shape as ``compute_returns_gae``), so
            both forms are accepted.
        cumm_returns: output buffer, written in place one step at a time.
            Each step reads ``returns[step]`` before writing
            ``cumm_returns[step]``, so it may alias ``returns``.
    """
    discount = getattr(gamma, 'gamma', gamma)
    next_value = evaluate_values(states[-1], model)
    for step in reversed(range(returns.shape[0])):
        next_value = cumm_returns[step] = next_value * discount * masks[step + 1] + returns[step]


def minibatch_generator(minibatch_size, *args, drop_last=True):
    """Yield randomly shuffled minibatches drawn jointly from several aligned arrays.

    A single random permutation of ``len(args[0])`` indices is drawn and cut
    into consecutive slices of ``minibatch_size``; each yielded item is a
    tuple holding every array in ``args`` indexed by the same slice.

    With ``drop_last=True`` a trailing partial slice is skipped, otherwise it
    is yielded too. At least one minibatch is always produced, so when there
    are fewer than ``minibatch_size`` samples the single batch is a partial one.

    (``torch.utils.data.SubsetRandomSampler`` with ``BatchSampler`` is an
    alternative way to get the same effect.)
    """
    total = len(args[0])
    if drop_last:
        num_batches = total // minibatch_size
    else:
        num_batches = (total + minibatch_size - 1) // minibatch_size
    num_batches = max(1, num_batches)

    order = torch.randperm(total)
    for start in range(0, num_batches * minibatch_size, minibatch_size):
        picked = order[start:start + minibatch_size]
        yield tuple(array[picked] for array in args)


class A2C_PPO:
    """Trainer performing A2C or clipped-surrogate PPO updates on an actor-critic model.

    The model is expected to support three call forms, as used below:
    ``model(s)`` returning ``(act_logits, values)``, ``model(s, av=1)``
    returning ``act_logits`` only and ``model(s, av=2)`` returning values
    only. It must also expose its value head as ``model.value``; every other
    parameter is treated as a policy parameter.

    Two optimisation modes are available, selected by ``args.sep_opt``:

    * separate -- the value function is fitted first with its own optimizer,
      then the policy is updated with a second optimizer (``_update_sep``);
    * joint -- one optimizer with two parameter groups minimises
      ``action_loss - entropy_coef * entropy + value_coef * value_loss``
      (``_update_joint``).
    """

    def __init__(self, model, args):
        """Read hyper-parameters from ``args`` and build the optimizer(s).

        Attributes read from ``args``: ``value_coef``, ``entropy_coef``,
        ``max_grad_norm``, ``opt_epochs``, ``minibatch_size``, ``algo``,
        ``adv_norm``, ``clip``, ``sep_opt``, ``niter``, ``ninner``, ``lr``,
        ``value_lr_mult``, ``opt``, ``gae``, ``anneal_lr`` and the
        optimizer-specific ``eps`` / ``alpha`` / ``beta``.

        Raises:
            ValueError: if ``args.opt`` is not 'rmsprop', 'adam' or 'sgd'.
        """
        self.model = model
        self.value_coef = args.value_coef
        self.entropy_coef = args.entropy_coef
        self.max_grad_norm = args.max_grad_norm
        self.opt_epochs = args.opt_epochs
        self.minibatch_size = args.minibatch_size
        self.ppo = args.algo == 'ppo'
        self.adv_norm = args.adv_norm
        self.clip = args.clip
        self.sep_opt = args.sep_opt
        self.num_updates = 0
        # Number of update() calls over which the learning rate is annealed.
        self.total_updates = args.niter * args.ninner

        self.params = list(model.parameters())
        self.value_params = list(model.value.parameters())
        # Policy parameters are everything that is not (by identity) a value parameter.
        value_ids = {id(q) for q in self.value_params}
        self.policy_params = [p for p in self.params if id(p) not in value_ids]

        value_lr = args.lr * args.value_lr_mult
        if self.sep_opt:
            self.policy_optimizer = self._make_optimizer(self.policy_params, args.lr, args)
            self.value_optimizer = self._make_optimizer(self.value_params, value_lr, args)
            # 'init_lr' is the reference rate that set_lr() scales.
            for param_group in self.policy_optimizer.param_groups:
                param_group['init_lr'] = args.lr
            for param_group in self.value_optimizer.param_groups:
                param_group['init_lr'] = value_lr
        else:
            param_groups = [
                {'params': self.policy_params, 'init_lr': args.lr},
                {'params': self.value_params, 'lr': value_lr, 'init_lr': value_lr},
            ]
            self.optimizer = self._make_optimizer(param_groups, args.lr, args)

        if args.gae:
            self.compute_returns = compute_returns_gae
        else:
            self.compute_returns = compute_returns
        self.args = args
        if not args.anneal_lr:
            # Shadow the method on this instance so update() leaves the rates untouched.
            self.set_lr = lambda f: None

    @staticmethod
    def _make_optimizer(params, lr, args):
        """Build the optimizer named by ``args.opt`` over ``params`` with base rate ``lr``."""
        if args.opt == 'rmsprop':
            return torch.optim.RMSprop(params, lr, eps=args.eps, alpha=args.alpha)
        if args.opt == 'adam':
            return torch.optim.Adam(params, lr, eps=args.eps, betas=(args.beta, args.alpha))
        if args.opt == 'sgd':
            return torch.optim.SGD(params, lr, momentum=args.beta)
        raise ValueError(f"unknown optimizer {args.opt!r}; expected 'rmsprop', 'adam' or 'sgd'")

    def set_lr(self, f):
        """Set each annealed group's learning rate to ``init_lr * f``.

        With separate optimizers only the policy rate is annealed; the value
        optimizer keeps its rate. With the joint optimizer both parameter
        groups are scaled.
        """
        if self.sep_opt:
            # anneal policy lr only, but not value lr
            optimizer = self.policy_optimizer
        else:
            optimizer = self.optimizer
        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['init_lr'] * f

    def clear_optimizer_state(self):
        """Reset the per-parameter state of whichever optimizer(s) this trainer owns."""
        if self.sep_opt:
            optimizers = (self.policy_optimizer, self.value_optimizer)
        else:
            optimizers = (self.optimizer,)
        for optimizer in optimizers:
            for param_group in optimizer.param_groups:
                for p in param_group['params']:
                    optimizer.state[p] = {}

    def update(self, states, actions, returns, masks, target_returns=None, f=1.0):
        """Run one training update on a rollout.

        The learning rate is first set to ``f`` times the linearly decaying
        fraction ``(total_updates - num_updates) / total_updates`` (a no-op
        when annealing is disabled), then the update is dispatched to the
        separate or joint implementation.

        Returns:
            With the joint optimizer, a 4-tuple
            ``(action_loss, value_loss, entropy, grad_norm)`` of per-step
            averages. With separate optimizers, the same four values followed
            by the returns buffer as a fifth element.
        """
        self.set_lr(f * float(self.total_updates - self.num_updates) / self.total_updates)
        self.num_updates += 1
        if self.sep_opt:
            return self._update_sep(states, actions, returns, masks, target_returns)
        else:
            return self._update_joint(states, actions, returns, masks, target_returns)

    def _compute_advantages(self, returns, values):
        """Return ``returns - values``, standardised in place when ``adv_norm`` is set."""
        advantages = returns - values.detach()
        if self.adv_norm:
            std, mean = torch.std_mean(advantages)
            advantages.sub_(mean).div_(std + 1e-5)
        return advantages

    def _policy_loss(self, act_log_probs, old_log_probs, adv, actions, epoch, mb_idx):
        """Return ``(action_loss, clipped)`` for one minibatch.

        For PPO this is the clipped-surrogate loss, and ``clipped`` counts the
        samples whose clipped surrogate differs from the unclipped one;
        diagnostics are printed every 100th minibatch. Otherwise it is the
        plain policy-gradient loss and ``clipped`` is 0.
        """
        if not self.ppo:
            # usual actor-critic PG
            return -(adv * act_log_probs).mean(), 0

        # hard clip ratio, original PPO
        ratio = torch.exp(act_log_probs - old_log_probs)
        surr1 = adv * ratio
        surr2 = adv * ratio.clamp(1.-self.clip, 1.+self.clip)
        surr = torch.min(surr1, surr2)
        with torch.no_grad():
            clipped = (surr1 != surr).sum().item()
            if mb_idx % 100 == 0:
                print (f' ppo [{epoch}], {clipped:3d} of {ratio.numel()} clipped\t',
                    f'min/max/mean/med: {ratio.min():.2f} / {ratio.max():.2f} / {ratio.mean():.2f} / {ratio.median():.2f}',
                    f'\t{act_log_probs.min():6.2f} / {act_log_probs.max():<6.2f} '
                    f' {adv.min():5.2f} / {adv.max():<5.2f}'
                    f' {actions.min():.3f} / {actions.max():<.3f}')
        return -surr.mean(), clipped

    def _update_sep(self, states0, actions0, returns0, masks0, returns=None):
        """Fit the value function, then update the policy, with separate optimizers.

        Return targets are computed into a fresh buffer unless ``returns`` is
        supplied. After every value-fitting epoch the targets are recomputed
        with the updated value function; this overwrites the buffer in place,
        including a caller-supplied ``returns`` tensor.

        The first axis of ``states0`` and ``masks0`` holds one more step than
        the rollout; the final step is dropped, the remaining ``T * B``
        samples are flattened, and only those selected by ``masks0[:-1]`` are
        kept (the mask is used as an index, so presumably boolean -- confirm
        against the caller).

        Returns:
            ``(action_loss, value_loss, entropy, grad_norm, returns_buf)``,
            where the first three are per-step averages and ``grad_norm`` is
            the sum of the mean policy and mean value gradient norms.
        """
        if returns is None:
            returns_buf = torch.empty_like(returns0)
            self.compute_returns(states0, returns0, masks0, self.model, self.args, returns_buf)
        else:
            returns_buf = returns

        T, B, *obs_shape = list(states0.shape)
        T -= 1
        masks = masks0[:-1].view(T*B)
        states = states0[:-1].view(T*B, *obs_shape)[masks]
        actions = actions0.view(T*B, -1)[masks]
        returns = returns_buf.view(T*B)[masks]

        tot_action_loss, tot_value_loss, tot_entropy, tot_grad_norm, val_grad_norm = 0, 0, 0, 0, 0

        # ---- phase 1: value function regression ----
        value_steps = 0
        for k in range(self.opt_epochs):
            gen = minibatch_generator(self.minibatch_size, states, returns)
            for mb_idx, mb in enumerate(gen):
                s, r = mb
                self.value_optimizer.zero_grad()

                values = self.model(s, av=2)
                value_loss = self.value_coef * (values - r).pow(2).mean()

                value_loss.backward()
                # The value head is clipped at 10x the policy's gradient-norm limit.
                grad_norm = nn.utils.clip_grad_norm_(self.value_params, self.max_grad_norm*10)
                self.value_optimizer.step()

                tot_value_loss += value_loss.item()
                val_grad_norm += float(grad_norm)
                value_steps += 1

            # now we have a new value function approximation, recompute the value estimates
            self.compute_returns(states0, returns0, masks0, self.model, self.args, returns_buf)
            returns = returns_buf.view(T*B)[masks]

        # Guard against opt_epochs == 0, which would otherwise divide by zero.
        value_steps = max(value_steps, 1)
        tot_value_loss /= value_steps; val_grad_norm /= value_steps

        with torch.no_grad():
            act_logits0, values0 = self.model(states)
            advantages = self._compute_advantages(returns, values0)
            batch = [states, actions, returns, advantages]
            if self.ppo:
                # Log-probs under the pre-update policy: the PPO ratio denominator.
                act_log_probs0, _ = evaluate_actions(act_logits0, actions)
                batch.append(act_log_probs0)
            ev = explained_variance(values0, returns)

        print (f" opt_epochs={self.opt_epochs}, iterations={len(states)}/{self.minibatch_size}={len(states)//self.minibatch_size}, v_grad={val_grad_norm:.4f}, ev={ev:.4f}, lr={self.policy_optimizer.param_groups[0]['lr']:.2e}")

        # ---- phase 2: policy update ----
        policy_steps = 0
        for k in range(self.opt_epochs):
            gen = minibatch_generator(self.minibatch_size, *batch)
            for mb_idx, mb in enumerate(gen):
                s, a, r, adv = mb[:4]
                old_log_probs = mb[4] if self.ppo else None
                self.policy_optimizer.zero_grad()

                act_logits = self.model(s, av=1)
                act_log_probs, entropy = evaluate_actions(act_logits, a)
                action_loss, _ = self._policy_loss(act_log_probs, old_log_probs, adv, a, k, mb_idx)

                entropy = entropy.mean()

                (action_loss - self.entropy_coef * entropy).backward()
                grad_norm = nn.utils.clip_grad_norm_(self.policy_params, self.max_grad_norm)
                self.policy_optimizer.step()

                tot_action_loss += action_loss.item()
                tot_entropy += entropy.item()
                tot_grad_norm += float(grad_norm)
                policy_steps += 1

        policy_steps = max(policy_steps, 1)
        tot_action_loss /= policy_steps; tot_entropy /= policy_steps; tot_grad_norm /= policy_steps
        return tot_action_loss, tot_value_loss, tot_entropy, tot_grad_norm + val_grad_norm, returns_buf

    def _update_joint(self, states, actions, returns, masks, target_returns=None):
        """Update policy and value function together with the single joint optimizer.

        When ``target_returns`` is not given, return targets are computed in
        place into ``returns`` (the caller's tensor is overwritten). The final
        step along the first axis of ``states`` and ``masks`` is dropped and
        only samples selected by ``masks[:-1]`` are used.

        For PPO, optimisation stops early -- skipping the offending minibatch
        and all remaining epochs -- as soon as more than a third of
        ``minibatch_size`` samples in a minibatch are clipped.

        Returns:
            ``(action_loss, value_loss, entropy, grad_norm)`` averaged over
            the optimisation steps actually taken (zeros if none were taken).
        """
        if target_returns is None:
            self.compute_returns(states, returns, masks, self.model, self.args, returns)
        else:
            returns = target_returns

        T, B, *obs_shape = list(states.shape)
        T -= 1
        masks = masks[:-1].view(T*B)
        states = states[:-1].view(T*B, *obs_shape)[masks]
        actions = actions.view(T*B, -1)[masks]
        returns = returns.view(T*B)[masks]

        print (f' opt_epochs={self.opt_epochs}, iterations={len(states)}/{self.minibatch_size}={len(states)//self.minibatch_size}')

        with torch.no_grad():
            act_logits0, values0 = self.model(states)
            advantages = self._compute_advantages(returns, values0)
            batch = [states, actions, returns, advantages]
            if self.ppo:
                # Log-probs under the pre-update policy: the PPO ratio denominator.
                act_log_probs0, _ = evaluate_actions(act_logits0, actions)
                batch.append(act_log_probs0)

        tot_action_loss, tot_value_loss, tot_entropy, tot_grad_norm = 0, 0, 0, 0
        K = 0
        # Initialised for both algorithms; only PPO ever sets it.
        to_break = False
        for k in range(self.opt_epochs):
            gen = minibatch_generator(self.minibatch_size, *batch)
            for mb_idx, mb in enumerate(gen):
                s, a, r, adv = mb[:4]
                old_log_probs = mb[4] if self.ppo else None
                self.optimizer.zero_grad()

                act_logits, values = self.model(s)
                act_log_probs, entropy = evaluate_actions(act_logits, a)
                action_loss, clipped = self._policy_loss(act_log_probs, old_log_probs, adv, a, k, mb_idx)
                if self.ppo and clipped > self.minibatch_size // 3:
                    # Policy has drifted too far from the behaviour policy: stop optimising.
                    to_break = True
                    break

                entropy = entropy.mean()
                # value loss lr = 0.001 seems to work well.
                value_loss = self.value_coef * (values - r).pow(2).mean()

                (action_loss - self.entropy_coef * entropy +
                    value_loss).backward()
                # self.params is a list of tensors, not a callable.
                grad_norm = nn.utils.clip_grad_norm_(self.params, self.max_grad_norm)
                self.optimizer.step()

                tot_action_loss += action_loss.item()
                tot_value_loss += value_loss.item()
                tot_entropy += entropy.item()
                tot_grad_norm += float(grad_norm)
                K += 1
            if to_break:
                break

        # K is 0 when the early stop fires on the very first minibatch.
        K = max(K, 1)
        return tot_action_loss/K, tot_value_loss/K, tot_entropy/K, tot_grad_norm/K
