import torch.nn as nn
import torch.optim as optim
from rlf.args import str2bool
import rlf.algos.utils as autils
from rlf.algos.base_algo import BaseAlgo


class BaseNetAlgo(BaseAlgo):
    def __init__(self):
        super().__init__()
        self.arg_prefix = ''

    def _arg(self, k):
        prefix_str = self.arg_prefix.replace('-', '_')
        return self.arg_vars[prefix_str + k]

    def init(self, policy, args):
        super().init(policy, args)
        self.arg_vars = vars(args)
        self.optimizers = self._get_optimizers()

        if self._arg('linear_lr_decay'):
            if self._arg('lr_env_steps') is None:
                self.lr_updates = self.get_num_updates()
            else:
                self.lr_updates = int(self._arg('lr_env_steps')) // args.num_steps // args.num_processes

    def update(self, storage):
        log_vals = super().update(storage)
        for k, (opt, _, initial_lr) in self.optimizers.items():
            lr = None
            for param_group in opt.param_groups:
                lr = param_group['lr']
            log_vals[k + '_lr'] = lr
        return log_vals

    def _copy_policy(self):
        cp_policy = super()._copy_policy()
        if next(self.policy.parameters()).is_cuda:
            cp_policy = cp_policy.cuda()
        autils.hard_update(self.policy, cp_policy)
        return cp_policy

    def load_resume(self, checkpointer):
        super().load_resume(checkpointer)
        # Load the optimizers where they left off.
        for k, (opt, _, _) in self.optimizers.items():
            opt.load_state_dict(checkpointer.get_key(k))

    def save(self, checkpointer):
        super().save(checkpointer)
        for k, (opt, _, _) in self.optimizers.items():
            checkpointer.save_key(k, opt.state_dict())

    def pre_update(self, cur_update):
        super().pre_update(cur_update)
        if self._arg('linear_lr_decay'):
            for k, (opt, _, initial_lr) in self.optimizers.items():
                lr = initial_lr - \
                    (initial_lr * (cur_update / float(self.lr_updates)))
                for param_group in opt.param_groups:
                    param_group['lr'] = lr

    def _clip_grad(self, params):
        """
        Helper function to clip gradients
        """
        if self._arg('max_grad_norm') > 0:
            nn.utils.clip_grad_norm_(params,
                                     self._arg('max_grad_norm'))

    def _standard_step(self, loss, optimizer_key='actor_opt'):
        """
        Helper function to compute gradients, clip gradients and then take
        optimization step.
        """
        opt, get_params_fn, _ = self.optimizers[optimizer_key]
        opt.zero_grad()
        loss.backward()
        self._clip_grad(get_params_fn())
        opt.step()

    def set_arg_prefix(self, arg_prefix):
        self.arg_prefix = arg_prefix + '-'

    def get_add_args(self, parser):
        parser.add_argument(f"--{self.arg_prefix}max-grad-norm", default=0.5, type=float,
                            help='-1 results in no grad norm')
        parser.add_argument(
            f"--{self.arg_prefix}linear-lr-decay", type=str2bool, default=False)
        parser.add_argument(f"--{self.arg_prefix}lr-env-steps", type=float,
                            default=None, help='only used for lr schedule')
        parser.add_argument(f"--{self.arg_prefix}eps", type=float, default=1e-5,
                            help='optimizer epsilon (default: 1e-5)')
        parser.add_argument(f"--{self.arg_prefix}lr", type=float, default=1e-3,
                            help='learning rate (default: 1e-3)')

    def _get_optimizers(self):
        return {
            'actor_opt': (
                optim.Adam(self.policy.parameters(),
                           lr=self._arg('lr'), eps=self._arg('eps')),
                lambda: self.policy.parameters(),
                self._arg('lr')
            )
        }
