import torch


class OptiMaster:
    """Own an optimizer together with its warmup, main LR schedule and optional SWA.

    Intended usage, as implied by how ``self.epoch`` gates the schedules (confirm against the
    training loop): call :meth:`train_step` once per training iteration and :meth:`epoch_step`
    with the index of each finished epoch.
    """

    def __init__(self, model, epochs, iter_per_epoch, optimizer, scheduler, warmup_epochs, lr_low, lr_high, beta1,
                 beta2, weight_decay, factor, swa, swa_start_epoch, swa_lr, plateua_metric,
                 model_size=128, plateau_patience=5, plateau_cooldown=2):
        """
        Args:
            model: ``torch.nn.Module`` whose parameters are optimized.
            epochs: total number of epochs, warmup included.
            iter_per_epoch: iterations per epoch; sizes the per-iteration warmup.
            optimizer: ``'adam'``, ``'adamW'`` or ``'rmsprop'``.
            scheduler: ``'noam'``, ``'step'``, ``'linear'``, ``'inv_sqrt'``, ``'const'``,
                ``'cosine'`` or ``'plateau'``.
            warmup_epochs: epochs of per-iteration linear warmup (``0`` disables it); for
                ``'noam'`` it sets the warmup length of the Noam curve.
            lr_low: final LR targeted by ``'step'``, ``'linear'`` and ``'cosine'``.
            lr_high: peak / base LR (not used by ``'noam'``, whose base LR is 1.0).
            beta1, beta2: Adam/AdamW betas (not passed to RMSprop).
            weight_decay: decay for whitelisted weights; ``0`` disables the decay/no-decay split.
            factor: Noam scale factor, or the ``ReduceLROnPlateau`` reduction factor.
            swa: enable stochastic weight averaging.
            swa_start_epoch: epoch after which SWA replaces the main schedule (only compared
                when ``swa`` is true).
            swa_lr: target LR of ``SWALR``.
            plateua_metric: key read from the dict passed to :meth:`update_score`.
            model_size: model width used by the Noam formula (default 128, the former constant).
            plateau_patience: ``ReduceLROnPlateau`` patience in epochs.
            plateau_cooldown: ``ReduceLROnPlateau`` cooldown in epochs.
        """
        self.lr_low = lr_low
        self.lr_high = lr_high
        self.scheduler = scheduler
        self.factor = factor
        self.weight_decay = weight_decay

        self.iter_per_epoch = iter_per_epoch
        self.epochs = epochs

        # Width used by the Noam formula (LR scales with model_size ** -0.5).
        self.model_size = model_size
        # Settings for the 'plateau' scheduler.
        self.plateau_patience = plateau_patience
        self.plateau_cooldown = plateau_cooldown

        self.model = model

        # LR that linear warmup starts from (non-noam schedules only).
        self.warmup_low = 1e-9
        self.plateau_score = 0
        # Attribute name (misspelled) kept for backward compatibility. TODO configurable
        self.plateua_metric = plateua_metric
        # -1 means "epoch_step not called yet"; train_step/epoch_step compare against it.
        self.epoch = -1

        # Noam's LambdaLR yields the absolute LR, so the optimizer's base LR must be 1.
        init_lr = 1.0 if self.scheduler == 'noam' else self.lr_high

        self.optimizer = self._get_optimizer(optimizer, model, lr=init_lr, beta1=beta1, beta2=beta2,
                                             weight_decay=weight_decay)

        if self.scheduler == 'noam':
            # Clamp to >= 1: a zero warmup would raise ZeroDivisionError in warmup_steps ** -1.5.
            self.warmup_steps = max(1, int(iter_per_epoch * warmup_epochs))
            noam_factor = self.factor * (self.model_size ** (-0.5))
            # Linear ramp for warmup_steps iterations, then inverse-square-root decay.
            lr_func = lambda step: noam_factor * min((1 + step) ** (-0.5), (1 + step) * (self.warmup_steps ** (-1.5)))
            self.warmup_schedule = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_func)
            self.warmup_epochs = warmup_epochs
        elif warmup_epochs > 0:
            # Per-iteration linear ramp of the LR multiplier from 0 to 1 over warmup_steps iterations.
            self.warmup_steps = iter_per_epoch * warmup_epochs
            lr_func = lambda step: step / self.warmup_steps
            self.warmup_schedule = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_func)
            self.warmup_epochs = warmup_epochs
        else:
            self.warmup_epochs = 0
        self.train_epochs = epochs - self.warmup_epochs

        # With SWA the main schedule only needs to run until swa_start_epoch.
        max_train_epoch = swa_start_epoch if swa else self.train_epochs

        if self.scheduler != 'noam':
            self.main_schedule = self._get_schedule(scheduler, max_epoch=max_train_epoch)

        self.swa = swa
        self.swa_start_epoch = swa_start_epoch
        if swa:
            self.swa_model = torch.optim.swa_utils.AveragedModel(model)
            self.swa_scheduler = torch.optim.swa_utils.SWALR(self.optimizer, swa_lr=swa_lr, anneal_epochs=5,
                                                             anneal_strategy='cos')

        if warmup_epochs > 0 and self.scheduler != 'noam':
            # The schedulers built above set each group's LR. Start *every* group at warmup_low, not
            # just the first; otherwise the no-decay group begins warmup at a different LR.
            for group in self.optimizer.param_groups:
                group['lr'] = self.warmup_low

    def _swa_active(self):
        """Return True when SWA is enabled and the current epoch is past ``swa_start_epoch``."""
        # Test the flag first so swa_start_epoch (possibly None when SWA is off) is never compared.
        return bool(self.swa) and self.epoch > self.swa_start_epoch

    def epoch_step(self, epoch):
        """Record ``epoch`` and advance the epoch-level schedule.

        Nothing is stepped while ``epoch < warmup_epochs - 1`` (warmup runs per iteration in
        :meth:`train_step`), for ``'noam'``, or at exactly ``epoch == warmup_epochs - 1``.
        Past ``swa_start_epoch`` with SWA on, the averaged model is updated and SWALR is
        stepped instead of the main schedule. ``'plateau'`` is fed ``self.plateau_score``.
        """
        self.epoch = epoch

        if self.epoch < self.warmup_epochs - 1 or self.scheduler == 'noam':
            pass
        elif self._swa_active():
            self.swa_model.update_parameters(self.model)
            self.swa_scheduler.step()
        elif self.epoch > self.warmup_epochs - 1:
            if self.scheduler == 'plateau':
                self.main_schedule.step(self.plateau_score)
            else:
                self.main_schedule.step()

    def train_step(self):
        """Step the per-iteration warmup (or Noam) schedule while it is active."""
        if self.epoch < self.warmup_epochs - 1 or self.scheduler == 'noam':
            self.warmup_schedule.step()

    def update_score(self, score_dict):
        """Store ``score_dict[self.plateua_metric]`` as the value fed to the plateau scheduler."""
        self.plateau_score = score_dict[self.plateua_metric]

    def get_model(self):
        """Return the SWA-averaged model once SWA is active, otherwise the wrapped model."""
        if self._swa_active():
            return self.swa_model
        return self.model

    @property
    def lr(self):
        """Current learning rate of the first optimizer param group."""
        return self.optimizer.param_groups[0]['lr']

    def config_weight_decay(self, model):
        """Split ``model``'s parameters into weight-decay and no-decay optimizer param groups.

        Weights of Linear/Conv1d/Conv2d modules are decayed. These are not decayed:
        - any parameter whose name contains 'bias';
        - LayerNorm/GroupNorm/Embedding weights;
        - parameters ending in 'scale' or 'key_dim_scaler';
        - parameters whose name contains 'lagmul' or 'arange'.

        Returns:
            Two param-group dicts: the decayed params with ``weight_decay=self.weight_decay``,
            then the non-decayed params with ``weight_decay=0.0``.

        Raises:
            ValueError: if a parameter lands in both sets or in neither.
        """
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, torch.nn.Conv2d, torch.nn.Conv1d)
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.GroupNorm, torch.nn.Embedding)
        # NOTE: named_parameters() recurses by default. Each parameter is visited once per ancestor
        # module, named by its dotted path relative to that ancestor. The substring tests below
        # therefore also see submodule names on that path. That can put a parameter into both
        # sets, which the check further down reports.
        for mn, m in model.named_modules():
            for pn, _ in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name

                if 'bias' in pn:
                    no_decay.add(fpn)
                elif 'weight' in pn and isinstance(m, whitelist_weight_modules):
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    no_decay.add(fpn)
                elif pn.endswith(('scale', 'key_dim_scaler')):
                    # scale-type parameters are not decayed
                    no_decay.add(fpn)
                elif 'lagmul' in pn or 'arange' in pn:
                    no_decay.add(fpn)

        # TODO: consider excluding a root-level positional embedding (e.g. 'pos_emb') from decay.

        # Validate that every parameter was classified exactly once.
        param_dict = dict(model.named_parameters())
        inter_params = decay & no_decay
        if inter_params:
            raise ValueError("parameters %s made it into both decay/no_decay sets!" % (str(inter_params),))
        unassigned = param_dict.keys() - (decay | no_decay)
        if unassigned:
            raise ValueError("parameters %s were not separated into either decay/no_decay set!"
                             % (str(unassigned),))

        return [
            {"params": [param_dict[pn] for pn in sorted(decay)], "weight_decay": self.weight_decay},
            {"params": [param_dict[pn] for pn in sorted(no_decay)], "weight_decay": 0.0},
        ]

    def _get_optimizer(self, optim_name, model, lr, beta1, beta2, weight_decay):
        """Build the optimizer named ``optim_name`` (``'adam'``, ``'adamW'`` or ``'rmsprop'``).

        When ``self.weight_decay`` is non-zero, parameters are grouped by config_weight_decay.
        Each group's own ``weight_decay`` then overrides the default passed here.

        Raises:
            ValueError: for an unknown ``optim_name``.
        """
        # `== 0` also matches False.
        if self.weight_decay == 0:
            params = model.parameters()
        else:
            params = self.config_weight_decay(model)

        if optim_name == "adam":
            return torch.optim.Adam(params, lr=lr, betas=(beta1, beta2), eps=1e-9, weight_decay=weight_decay)
        if optim_name == "adamW":
            return torch.optim.AdamW(params, lr=lr, betas=(beta1, beta2), eps=1e-9, weight_decay=weight_decay)
        if optim_name == "rmsprop":
            return torch.optim.RMSprop(params, lr=lr, alpha=0.98, momentum=0.1, eps=1e-9, weight_decay=weight_decay)
        raise ValueError("unknown optimizer %r (expected 'adam', 'adamW' or 'rmsprop')" % (optim_name,))

    def _get_schedule(self, schedule_name, max_epoch):
        """Build the epoch-level main LR schedule that epoch_step advances.

        Behaviour by name:
        - ``'step'``, ``'linear'`` and ``'cosine'`` go from lr_high to lr_low over ``max_epoch``
          steps. ``'step'`` lands exactly on lr_low when ``max_epoch`` is a multiple of 5.
        - ``'inv_sqrt'`` scales lr_high by
          ``sqrt(warmup_steps / ((warmup_epochs + epoch) * iter_per_epoch))``.
        - ``'const'`` keeps lr_high.
        - ``'plateau'`` wraps ReduceLROnPlateau in 'max' mode.

        Raises:
            ValueError: for an unknown name, or ``'inv_sqrt'`` without warmup (it needs
                ``self.warmup_steps``).
        """
        if schedule_name == "step":
            step_size = 5
            # gamma is applied once every step_size epochs. The exponent must therefore be
            # step_size / max_epoch, not 1 / max_epoch, to actually reach lr_low.
            train_gamma = (self.lr_low / self.lr_high) ** (step_size / max_epoch)
            return torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=step_size, gamma=train_gamma)

        if schedule_name == "linear":
            lr_func = lambda epoch: (self.lr_low / self.lr_high - 1) * epoch / max_epoch + 1
            return torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_func)

        if schedule_name == "inv_sqrt":
            if self.warmup_epochs <= 0:
                raise ValueError("the 'inv_sqrt' schedule requires warmup_epochs > 0")
            lr_func = lambda epoch: self.warmup_steps ** 0.5 / (
                (self.warmup_epochs + epoch) * self.iter_per_epoch) ** 0.5
            return torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_func)

        if schedule_name == "const":
            lr_func = lambda epoch: 1
            return torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_func)

        if schedule_name == "cosine":
            return torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, max_epoch, eta_min=self.lr_low)

        if schedule_name == "plateau":
            return torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='max', factor=self.factor,
                                                              patience=self.plateau_patience,
                                                              threshold=0.0001, threshold_mode='rel',
                                                              cooldown=self.plateau_cooldown, min_lr=0, eps=1e-08, )

        raise ValueError("unknown scheduler %r" % (schedule_name,))