from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.optim.lr_scheduler import StepLR
from torch.optim.lr_scheduler import OneCycleLR
from torch.optim.lr_scheduler import LinearLR
from torch.optim.lr_scheduler import SequentialLR
from bisect import bisect_right
from typing import Optional


class Scheduler_manager:
    """Build the learning-rate scheduler described by a model configuration.

    The configuration is read from ``model_info['SCHEDULER']``, which must
    contain a ``'MAIN'`` and a ``'WARMUP'`` section, each with a ``'NAME'``.
    The resulting scheduler (or ``None``) is stored in ``self.scheduler``.

    Main schedulers (``SCHEDULER['MAIN']['NAME']``):
        'NONE'     -- no scheduler.
        'ROP'      -- ReduceLROnPlateau in 'min' mode; reads 'PATIENCE', 'FACTOR'.
        'STEPLR'   -- StepLR; reads 'STEP', 'GAMMA'.
        'ONECYCLE' -- OneCycleLR with max_lr=model_info['LR'] and
                      total_steps=model_info['EPOCHS'].
    Warmup schedulers (``SCHEDULER['WARMUP']['NAME']``):
        'NONE'     -- no warmup.
        'LINEAR'   -- LinearLR; 'DECAY' is the start factor and 'EPOCHS' the
                      number of warmup iterations (also the hand-off milestone).
    Any other name is reported and treated as 'NONE'. A warmup is only used
    together with a main scheduler.
    """

    def __init__(self, optimizer, model_info: dict):
        self.model_info = model_info
        self.scheduler_options = model_info['SCHEDULER']
        self.main_name = self.scheduler_options['MAIN']['NAME']
        self.warmup_name = self.scheduler_options['WARMUP']['NAME']
        self.optimizer = optimizer

        main_scheduler = self.get_main_scheduler()
        if main_scheduler is None:
            # combine_schedulers() discards the warmup when there is no main
            # scheduler, so do not construct it at all: LinearLR scales the
            # optimizer's lr by start_factor as soon as it is created, which
            # would leave the optimizer stuck at the scaled-down lr.
            if self.warmup_name != 'NONE':
                print('[INFO] Warmup scheduler ignored because no main scheduler is selected')
            warmup_scheduler = None
        else:
            warmup_scheduler = self.get_warmup_scheduler()

        self.scheduler = self.combine_schedulers(main_scheduler, warmup_scheduler)

    def get_main_scheduler(self):
        """Return the main scheduler named by ``self.main_name``, or ``None``."""
        options = self.scheduler_options['MAIN']
        if self.main_name == 'NONE':
            print('[INFO] No scheduler selected')
            return None
        if self.main_name == 'ROP':
            return ReduceLROnPlateau(optimizer=self.optimizer, mode='min',
                                     patience=options['PATIENCE'],
                                     factor=options['FACTOR'])
        if self.main_name == 'STEPLR':
            return StepLR(optimizer=self.optimizer, step_size=options['STEP'],
                          gamma=options['GAMMA'])
        if self.main_name == 'ONECYCLE':
            # total_steps is the epoch count, so this assumes the scheduler is
            # stepped once per epoch.
            return OneCycleLR(optimizer=self.optimizer, max_lr=self.model_info['LR'],
                              total_steps=self.model_info['EPOCHS'])
        print('[INFO] Unknown scheduler selected, omitting')
        return None

    def get_warmup_scheduler(self):
        """Return the warmup scheduler named by ``self.warmup_name``, or ``None``."""
        if self.warmup_name == 'NONE':
            print('[INFO] No warmup scheduler selected')
            return None
        if self.warmup_name == 'LINEAR':
            options = self.scheduler_options['WARMUP']
            return LinearLR(optimizer=self.optimizer, start_factor=options['DECAY'],
                            total_iters=options['EPOCHS'])
        print('[INFO] Unknown warmup scheduler selected, omitting')
        return None

    def combine_schedulers(self, main, warmup):
        """Chain ``warmup`` in front of ``main``.

        Returns ``None`` when there is no main scheduler (any warmup is
        dropped), ``main`` alone when there is no warmup, and otherwise a
        SequentialLR_ROP that uses ``WARMUP['EPOCHS']`` as its milestone.
        """
        if main is None:
            return None
        if warmup is None:
            return main
        # NOTE(review): recent torch releases reject ReduceLROnPlateau inside
        # SequentialLR.__init__ (which SequentialLR_ROP calls); confirm the
        # 'ROP' + warmup combination works with the pinned torch version.
        return SequentialLR_ROP(self.optimizer, schedulers=[warmup, main],
                                milestones=[self.scheduler_options['WARMUP']['EPOCHS']])


class SequentialLR_ROP(SequentialLR):
    """SequentialLR variant whose ``step`` can forward a metric to ReduceLROnPlateau.

    The stock ``SequentialLR.step`` takes no arguments, so it cannot drive a
    ``ReduceLROnPlateau`` scheduler. This subclass accepts the monitored
    metric and forwards it only to ``ReduceLROnPlateau`` instances; every other
    scheduler is stepped without arguments.

    Hand-off timing: the active scheduler is selected with
    ``bisect_right(milestones, last_epoch - 1)``. The scheduler before a
    milestone is therefore still stepped on the milestone epoch itself, and the
    next scheduler takes over one step later.

    NOTE(review): recent torch releases reject ReduceLROnPlateau in
    ``SequentialLR.__init__``; confirm against the pinned torch version.
    """

    def __init__(self, optimizer, schedulers, milestones):
        super().__init__(optimizer, schedulers, milestones)

    def step(self, metrics: Optional[float] = None, epoch: Optional[int] = None) -> None:
        """Advance the scheduler sequence by one step.

        Args:
            metrics: Value passed to ``ReduceLROnPlateau.step`` when that
                scheduler is active; ignored for every other scheduler.
            epoch: Accepted for call compatibility only; ignored.

        Raises:
            TypeError: if the active scheduler is a ReduceLROnPlateau and no
                ``metrics`` value was given. No state is changed in that case.
        """
        next_epoch = self.last_epoch + 1
        # The "- 1" keeps the previous scheduler active through the milestone
        # epoch. As a result ``_milestones[idx - 1]`` is always < next_epoch,
        # so the stock SequentialLR ``step(0)`` restart on the milestone
        # epoch can never trigger and is omitted.
        idx = bisect_right(self._milestones, next_epoch - 1)
        scheduler = self._schedulers[idx]
        is_plateau = isinstance(scheduler, ReduceLROnPlateau)
        if is_plateau and metrics is None:
            raise TypeError('ReduceLROnPlateau is the active scheduler; '
                            'step() requires a `metrics` value')

        self.last_epoch = next_epoch
        if is_plateau:
            scheduler.step(metrics)
        else:
            scheduler.step()

        # Read the learning rates straight from the optimizer. This matches
        # what get_last_lr() reports and also works for ReduceLROnPlateau
        # versions that do not provide get_last_lr().
        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
