import torch


class BayesianScheduler(torch.optim.lr_scheduler.ReduceLROnPlateau):
    """Plateau scheduler that alternates between raising a prior weight and lowering the LR.

    While ``is_step_prior`` is True, plateaus in ``metrics`` multiply
    ``prior_weight`` by ``factor_weight`` (capped at ``max_prior_weight``) and
    the new value is pushed to ``optimizer.param_groups[0]['prior_weight']`` or
    to the ``update_prior_weight`` callback.  Otherwise the call is delegated to
    ``ReduceLROnPlateau.step``.  After construction and after every change of
    either quantity, the next ``cooldown - 1`` calls are ignored; then control
    is handed to the other quantity, provided that one can still change.

    Args:
        optimizer, mode, factor, patience, threshold, threshold_mode,
        cooldown, min_lr, eps: forwarded to ``ReduceLROnPlateau``.  ``mode``
            also decides whether lower ('min') or higher ('max') metrics count
            as an improvement for the prior weight; ``patience`` and
            ``cooldown`` are reused by the prior-weight schedule as well.
        verbose: print a message whenever the prior weight or the LR changes.
        temperature: initial prior weight.
        max_temperature: upper bound on the prior weight; ``None`` means
            ``temperature``, i.e. the prior weight is never raised.
        factor_weight: multiplier applied to the prior weight on a plateau.
        update_prior_weight: optional callable receiving the new prior weight;
            when ``None`` the weight is written to
            ``optimizer.param_groups[0]['prior_weight']``.
    """

    def __init__(self, optimizer, mode='min', factor=0.1, patience=10,
                 threshold=1e-4, threshold_mode='rel', cooldown=10,
                 min_lr=0, eps=1e-8, verbose=False, temperature=1.0,
                 max_temperature=None, factor_weight=10., update_prior_weight=None):
        super().__init__(optimizer, mode, factor, patience,
                         threshold, threshold_mode, cooldown, min_lr, eps)
        self.max_prior_weight = max_temperature
        if self.max_prior_weight is None:
            self.max_prior_weight = temperature
        self.show = verbose
        self.prior_update = True  # kept for backward compatibility; not read by this class
        self.prior_best = None  # best metric seen in the current prior phase
        self.prior_count = 0  # consecutive non-improving calls in the prior phase
        self.prior_weight = temperature
        self.factor_weight = factor_weight
        self.update_prior_weight = update_prior_weight
        self.is_step_prior = True  # True: step the prior weight; False: step the LR
        self.epoch_count = 0  # calls since construction / the last change
        self.current_lr = self.optimizer.param_groups[0]['lr']  # float LR of group 0

    def step(self, metrics, epoch=None):
        """Advance either the prior-weight schedule or the LR schedule by one call.

        Args:
            metrics: monitored quantity; must be convertible with ``float()``.
            epoch: forwarded unchanged to ``ReduceLROnPlateau.step``.
        """
        self.epoch_count += 1
        # Wait out the cooldown after construction or after any change.
        if self.epoch_count < self.cooldown:
            return

        if self.prior_weight == self.max_prior_weight:
            self.is_step_prior = False

        if self.is_step_prior:
            previous = self.prior_weight
            self.step_prior(metrics)
            if previous != self.prior_weight:
                self.epoch_count = 0
                # Forget the old best so the next metric starts a fresh plateau search.
                self.prior_best = self.mode_worse
                # Hand over to the LR schedule only if it can still lower the LR;
                # otherwise keep raising the prior weight.
                if self._lr_can_decrease():
                    self.is_step_prior = False
                if self.show:
                    print(f"Prior weight increased to {self.prior_weight}")
        else:
            super().step(metrics, epoch)
            # Read the float LR straight from the optimizer so it compares like
            # with like against `current_lr` (get_last_lr() returns a list).
            new_lr = self.optimizer.param_groups[0]['lr']
            if new_lr != self.current_lr:
                self.epoch_count = 0
                # Restart the parent's plateau search; mode_worse is +inf for
                # 'min' and -inf for 'max'.
                self.best = self.mode_worse
                self.current_lr = new_lr
                if self.prior_weight != self.max_prior_weight:
                    self.is_step_prior = True

                if self.show:
                    print(f"Learning rate reduced to {new_lr}")

    def step_prior(self, metrics):
        """Count plateaus of ``metrics`` and raise the prior weight after ``patience`` of them.

        Improvement follows ``mode`` ('min': lower is better, 'max': higher is
        better) using a strict comparison; ``threshold`` is not applied here.
        """
        current = float(metrics)
        if self.prior_best is None:
            # First observation counts as an improvement, the same as the first
            # metric after a reset (prior_best == mode_worse).
            self.prior_best = self.mode_worse

        if self._prior_improved(current):
            self.prior_best = current
            self.prior_count = 0
        else:
            self.prior_count += 1

        if self.prior_count >= self.patience:
            self.prior_count = 0
            new_weight = min(self.prior_weight * self.factor_weight, self.max_prior_weight)
            if self.update_prior_weight is None:
                self.optimizer.param_groups[0]['prior_weight'] = new_weight
            else:
                self.update_prior_weight(new_weight)

            self.prior_weight = new_weight
            self.prior_best = current

    def _prior_improved(self, current):
        """Return True if ``current`` strictly beats ``prior_best`` in the direction of ``mode``."""
        if self.mode == 'max':
            return current > self.prior_best
        return current < self.prior_best

    def _lr_can_decrease(self):
        """Return True if one more ReduceLROnPlateau reduction would change group 0's LR.

        Applies the parent's documented rule (new LR is ``max(lr * factor,
        min_lr)``; updates smaller than ``eps`` are ignored).  A plain
        ``lr != min_lr`` test is not enough: ``lr * factor`` can land a rounding
        error above ``min_lr`` (e.g. 0.1 * 0.1 != 0.01), after which the parent
        never moves the LR again.
        """
        lr = self.optimizer.param_groups[0]['lr']
        return lr - max(lr * self.factor, self.min_lrs[0]) > self.eps
