import rlf.algos.utils as autils
import torch.nn as nn
import torch.optim as optim
from rlf.algos.base_algo import BaseAlgo
from rlf.args import str2bool


class BaseNetAlgo(BaseAlgo):
    """
    Base class for algorithms that update neural networks.

    Provides command-line argument handling (optionally namespaced with
    `arg_prefix`), Adam optimizer construction, optional linear learning-rate
    decay, gradient clipping, optimizer checkpointing, and per-update
    scheduling of the regularizer coefficients selected by the parsed
    arguments.
    """

    def __init__(self):
        super().__init__()
        # Prefix used for every argument this algorithm registers and reads.
        # Stored in command-line form (e.g. "foo-"); empty means no prefix.
        self.arg_prefix = ""

    def _arg(self, k):
        """
        Return the parsed argument `k`, honoring `self.arg_prefix`.

        The prefix is stored with dashes (command-line form), so it is
        converted to underscores to match the attribute names in `arg_vars`.
        """
        prefix_str = self.arg_prefix.replace("-", "_")
        return self.arg_vars[prefix_str + k]

    def _linear_decay(self, cur_update, end_fraction=None):
        """
        Return a linear decay factor for update index `cur_update`.

        The factor is `1 - cur_update / horizon`. `horizon` is
        `self.get_num_updates()`, scaled by `end_fraction` when it is given.
        The result is 1 at update 0 and 0 at the horizon. It is NOT clipped,
        so callers that need a floor apply one themselves. A zero horizon
        means the schedule is already finished, so 0.0 is returned instead
        of dividing by zero.
        """
        horizon = self.get_num_updates()
        if end_fraction is not None:
            horizon = end_fraction * horizon
        if horizon == 0:
            return 0.0
        return 1 - cur_update / horizon

    def _set_undecayed_expert_regularizer_coefs(self):
        """Set the expert-regularizer coefficients to their configured values."""
        if self._arg("add_expert_restricted_regularizer"):
            self.exres_coef = self._arg("exres_coef")
            return
        if self._arg("add_kl_regularizer"):
            self.regularizer_coef = self._arg("kl_coef")
        elif self._arg("add_ce_regularizer"):
            self.regularizer_coef = self._arg("ce_coef")
        if self._arg("add_non_expert_entropy"):
            self.non_expert_entropy_coef = self._arg("non_expert_entropy_coef")

    def _decay_expert_regularizer_coefs(self, cur_update):
        """Set the expert-regularizer coefficients to their decayed values."""
        if self._arg("add_expert_restricted_regularizer"):
            coef_value = self._arg("exres_coef")
            coef_end_ratio = self._arg("exres_coef_decay_end_percent")
            current_decay = max(
                0,
                self._linear_decay(
                    cur_update, self._arg("exres_coef_decay_end_ts")
                ),
            )
            # Interpolate from `coef_value` down to a floor of
            # `coef_value * coef_end_ratio`. The floor is reached once the
            # (clipped) decay factor hits 0.
            self.exres_coef = (
                coef_value * coef_end_ratio
                + coef_value * (1 - coef_end_ratio) * current_decay
            )
            return

        # Unlike the restricted schedule above, this factor is not clipped at 0.
        decay = self._linear_decay(cur_update)
        if self._arg("add_kl_regularizer"):
            self.regularizer_coef = self._arg("kl_coef") * decay
        elif self._arg("add_ce_regularizer"):
            self.regularizer_coef = self._arg("ce_coef") * decay
        if self._arg("add_non_expert_entropy"):
            self.non_expert_entropy_coef = (
                self._arg("non_expert_entropy_coef") * decay
            )

    def init(self, policy, args):
        """
        Finish setup once the policy and parsed arguments are available.

        Builds the optimizers, records the policy's observation and action
        spaces, computes the learning-rate decay horizon, and caches the
        initial values of the regularizer coefficients selected by `args`.
        """
        super().init(policy, args)
        self.arg_vars = vars(args)
        self._optimizers = self._get_optimizers()
        self.obs_space = policy.obs_space
        self.action_space = policy.action_space

        if self._arg("linear_lr_decay"):
            if self._arg("lr_env_steps") is None:
                self.lr_updates = self.get_num_updates()
            else:
                # Convert an env-step budget into a number of updates
                # (env steps per update = num_steps * num_processes).
                self.lr_updates = (
                    int(self._arg("lr_env_steps"))
                    // args.num_steps
                    // args.num_processes
                )

        if self._arg("add_expert_regularizer"):
            self._set_undecayed_expert_regularizer_coefs()
            if self._arg("add_expert_restricted_regularizer") and self._arg(
                "restricted_regularizer_sum_prob"
            ):
                self.sum_prob_update_freq = self._arg(
                    "restricted_regularizer_sum_prob_update_frequency"
                )
                self.sum_prob_update_ts = 0

        if self._arg("add_cdf_regularizer"):
            self.cdf_coef = self._arg("cdf_regularizer_coef")
            self.reg_const_bound = self._arg("reg_const_bound")

        if self._arg("add_pdf_regularizer"):
            self.pdf_coef = self._arg("pdf_coef")
            self.reg_const_bound = self._arg("reg_const_bound")

        if self._arg("add_bc_model_regularizer"):
            self.bc_model_reg_coef = self._arg("bc_model_reg_coef")

        if self._arg("add_bc_model_regularizer_reward_bonus"):
            self.bc_model_regularizer_reward_bonus = self._arg(
                "bc_model_regularizer_reward_bonus_value"
            )

        if self._arg("add_regularizer_on_proximity"):
            # Full weight until `pre_update` possibly anneals it.
            self.anneal_regularizer_on_proximity = 1.0

    def get_optimizer(self, opt_key: str):
        """Return the optimizer registered under `opt_key`."""
        return self._optimizers[opt_key][0]

    def update(self, storage):
        """Run the parent update and log each optimizer's current learning rate."""
        log_vals = super().update(storage)
        for k, (opt, _, _) in self._optimizers.items():
            # Report the last param group's LR. Optimizers built by
            # `_create_opt` have a single param group.
            groups = opt.param_groups
            log_vals[k + "_lr"] = groups[-1]["lr"] if groups else None
        return log_vals

    def _copy_policy(self):
        """Return a copy of the policy on the same device with identical weights."""
        cp_policy = super()._copy_policy()
        if next(self.policy.parameters()).is_cuda:
            cp_policy = cp_policy.cuda()
        autils.hard_update(self.policy, cp_policy)
        return cp_policy

    def load_resume(self, checkpointer):
        """Restore parent state plus every optimizer's state from `checkpointer`."""
        super().load_resume(checkpointer)
        # Load the optimizers where they left off.
        for k, (opt, _, _) in self._optimizers.items():
            opt.load_state_dict(checkpointer.get_key(k))

    def save(self, checkpointer):
        """Save parent state plus every optimizer's state under its key."""
        super().save(checkpointer)
        for k, (opt, _, _) in self._optimizers.items():
            checkpointer.save_key(k, opt.state_dict())

    def pre_update(self, cur_update):
        """
        Apply per-update schedules before update number `cur_update`.

        Applies linear learning-rate decay when enabled, then recomputes
        every scheduled coefficient: the expert regularizers, the BC
        reward bonus, the proximity-regularizer annealing factor, and
        `pf_est_prob`.
        """
        super().pre_update(cur_update)
        if self._arg("linear_lr_decay"):
            for opt, _, initial_lr in self._optimizers.values():
                autils.linear_lr_schedule(
                    cur_update, self.lr_updates, initial_lr, opt
                )

        if self._arg("add_expert_regularizer"):
            if self._arg("add_coef_decay"):
                self._decay_expert_regularizer_coefs(cur_update)
            else:
                self._set_undecayed_expert_regularizer_coefs()

        if self._arg("add_bc_model_regularizer_reward_bonus"):
            self.bc_model_regularizer_reward_bonus = self._arg(
                "bc_model_regularizer_reward_bonus_value"
            ) * self._linear_decay(cur_update)

        if self._arg("add_regularizer_on_proximity"):
            # Read through `_arg` so `arg_prefix` is honored, as for every
            # other option.
            if self._arg("anneal_regularizer_on_proximity"):
                # Linearly anneal from 1.0 to 0.0, reaching 0 at
                # `anneal_regularizer_on_proximity_end_ts` * total updates.
                self.anneal_regularizer_on_proximity = max(
                    0,
                    self._linear_decay(
                        cur_update,
                        self._arg("anneal_regularizer_on_proximity_end_ts"),
                    ),
                )
            else:
                self.anneal_regularizer_on_proximity = 1.0

        if self._arg("pf_est_decay"):
            # Linear decay of the proximity-estimate probability, floored at
            # `pf_est_decay_min_crop`.
            decay_factor = self._linear_decay(
                cur_update, self._arg("pf_est_decay_end_point")
            )
            self.pf_est_prob = max(self._arg("pf_est_decay_min_crop"), decay_factor)
        else:
            self.pf_est_prob = 0

    def _clip_grad(self, params):
        """
        Clip the gradient norm of `params` to `max_grad_norm`.

        A non-positive `max_grad_norm` disables clipping.
        """
        if self._arg("max_grad_norm") > 0:
            nn.utils.clip_grad_norm_(params, self._arg("max_grad_norm"))

    def _standard_step(self, loss, optimizer_key="actor_opt"):
        """
        Compute gradients of `loss`, clip them, and take an optimizer step.

        `optimizer_key` selects which registered optimizer (and its parameter
        set) to use.
        """
        opt, get_params_fn, _ = self._optimizers[optimizer_key]
        opt.zero_grad()
        loss.backward()
        self._clip_grad(get_params_fn())
        opt.step()

    def set_arg_prefix(self, arg_prefix):
        """
        Namespace this algorithm's arguments as `--<arg_prefix>-<name>`.

        Call this before `get_add_args` so the prefixed names are the ones
        registered.
        """
        self.arg_prefix = arg_prefix + "-"

    def get_add_args(self, parser):
        """
        Add default arguments useful for all algorithms that update neural
        networks. Added arguments:
        * --max-grad-norm
        * --linear-lr-decay
        * --lr-env-steps
        * --eps
        * --lr
        All are prefixed with `self.arg_prefix`.
        """
        super().get_add_args(parser)
        parser.add_argument(
            f"--{self.arg_prefix}max-grad-norm",
            default=0.5,
            type=float,
            help="-1 results in no grad norm",
        )
        parser.add_argument(
            f"--{self.arg_prefix}linear-lr-decay", type=str2bool, default=True
        )
        parser.add_argument(
            f"--{self.arg_prefix}lr-env-steps",
            type=float,
            default=None,
            help="only used for lr schedule",
        )
        parser.add_argument(
            f"--{self.arg_prefix}eps",
            type=float,
            default=1e-5,
            help="""
                            optimizer epsilon (default: 1e-5)
                            NOTE: The PyTorch default is 1e-8 see
                            https://pytorch.org/docs/stable/_modules/torch/optim/adam.html#Adam
                            """,
        )
        parser.add_argument(
            f"--{self.arg_prefix}lr",
            type=float,
            default=1e-3,
            help="learning rate (default: 1e-3)",
        )

    @staticmethod
    def _create_opt(module_to_opt: nn.Module, lr: float, eps: float = 1e-8):
        """
        Build an Adam optimizer over `module_to_opt`'s parameters.

        Returns the tuple `(optimizer, get_params_fn, initial_lr)`.
        `get_params_fn` re-fetches the module's parameters, for gradient
        clipping. `initial_lr` is the base LR used by the linear LR schedule.
        """
        get_params_fn = lambda: module_to_opt.parameters()
        return (optim.Adam(get_params_fn(), lr=lr, eps=eps), get_params_fn, lr)

    def _get_optimizers(self):
        """Return the optimizers keyed by name; by default only the policy's."""
        return {
            "actor_opt": BaseNetAlgo._create_opt(
                self.policy, self._arg("lr"), self._arg("eps")
            )
        }
