from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau, StepLR, CosineAnnealingLR, MultiStepLR

from utils import key_is_none

def _copy_scheduler_kwargs(config):
    # Return a private dict copy of config.scheduler_kwargs (None means "no
    # kwargs") so building a scheduler never mutates the caller's config.
    return dict(config.scheduler_kwargs or {})

def initialize_scheduler(config, optimizer, n_train_steps):
    """Build the learning-rate scheduler selected by ``config.scheduler``.

    Args:
        config: object exposing ``scheduler`` (scheduler name, or None /
            ``'none'`` for no scheduler), ``scheduler_kwargs`` (keyword
            arguments forwarded to the scheduler constructor; not read by
            ``'FixMatchLR'``) and, for ``'ReduceLROnPlateau'``,
            ``scheduler_metric_name``. ``config`` is never modified.
        optimizer: the torch optimizer whose learning rate is scheduled.
        n_train_steps: total number of training steps. Passed as
            ``num_training_steps`` to the transformers warmup schedules,
            used by ``'FixMatchLR'``'s decay, and used to scale
            ``'FractionStepLR'``'s fractional ``step_size``.

    Returns:
        None when no scheduler is configured. Otherwise the scheduler, with
        two extra attributes attached: ``step_every_batch`` (True for
        schedules meant to be stepped after every batch) and ``use_metric``
        (True when ``step`` must be given a metric value).

    Raises:
        ValueError: if ``config.scheduler`` is not a supported name.
        AssertionError: if a setting required by the chosen scheduler is
            missing (``scheduler_metric_name`` / ``step_size``).
    """
    name = config.scheduler
    if name is None or name == 'none':
        return None

    if name == 'linear_schedule_with_warmup':
        # Imported lazily so transformers is only needed when requested.
        from transformers import get_linear_schedule_with_warmup
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_training_steps=n_train_steps,
            **_copy_scheduler_kwargs(config))
        step_every_batch = True
        use_metric = False
    elif name == 'cosine_schedule_with_warmup':
        from transformers import get_cosine_schedule_with_warmup
        scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_training_steps=n_train_steps,
            **_copy_scheduler_kwargs(config))
        step_every_batch = True
        use_metric = False
    elif name == 'ReduceLROnPlateau':
        assert config.scheduler_metric_name, f'scheduler metric must be specified for {config.scheduler}'
        scheduler = ReduceLROnPlateau(
            optimizer,
            **_copy_scheduler_kwargs(config))
        step_every_batch = False
        use_metric = True
    elif name == 'StepLR':
        scheduler = StepLR(optimizer, **_copy_scheduler_kwargs(config))
        step_every_batch = False
        use_metric = False
    elif name == 'FractionStepLR':
        kwargs = _copy_scheduler_kwargs(config)
        assert not key_is_none(kwargs, 'step_size'), \
            f"'step_size' must be specified for {name}"
        # step_size is given as a fraction of n_train_steps. Rescale the
        # local copy only: rewriting config in place would compound the
        # scaling on every call. Clamp to >= 1 because StepLR computes
        # last_epoch % step_size, so 0 would raise ZeroDivisionError.
        kwargs['step_size'] = max(1, int(kwargs['step_size'] * n_train_steps))
        scheduler = StepLR(optimizer, **kwargs)
        step_every_batch = True
        use_metric = False
    elif name == 'FixMatchLR':
        # Multiplicative LR factor (1 + 10 * step / n_train_steps) ** -0.75.
        scheduler = LambdaLR(
            optimizer,
            lambda x: (1.0 + 10 * float(x) / n_train_steps) ** -0.75
        )
        step_every_batch = True
        use_metric = False
    elif name == 'MultiStepLR':
        scheduler = MultiStepLR(optimizer, **_copy_scheduler_kwargs(config))
        step_every_batch = False
        use_metric = False
    else:
        raise ValueError(f'Scheduler: {config.scheduler} not supported.')

    # Attach the stepping policy read by callers when they step schedulers.
    scheduler.step_every_batch = step_every_batch
    scheduler.use_metric = use_metric
    return scheduler

def step_scheduler(scheduler, metric=None):
    """Advance ``scheduler`` by a single step.

    ``ReduceLROnPlateau`` decides whether to lower the learning rate from a
    monitored value, so it must receive ``metric`` (asserted non-None).
    Every other scheduler is stepped with no arguments and ``metric`` is
    ignored.
    """
    if not isinstance(scheduler, ReduceLROnPlateau):
        scheduler.step()
        return

    assert metric is not None
    scheduler.step(metric)

class LinearScheduleWithWarmupAndThreshold():
    """
    Piecewise-linear schedule for a scalar parameter other than the learning rate.

    The scheduled ``value`` is 0 while the step counter is in [0, T1). It ramps
    linearly from 0 toward ``max_value`` while the counter is in [T1, T2), and
    it equals ``max_value`` from T2 onward. The object exposes ``step()``,
    ``step_every_batch`` and ``use_metric`` so it can be driven by
    step_scheduler() and used within the Algorithm class like the other
    schedulers.

    Args:
        - max_value: value the parameter reaches at step T2 and keeps afterwards
        - last_warmup_step: aka T1. For steps [0, T1) the parameter stays 0
        - threshold_step: aka T2. The parameter ramps up over [T1, T2); requires 0 <= T1 < T2
        - step_every_batch: stored as-is for the caller that decides how often step() is invoked
    """
    def __init__(self, max_value, last_warmup_step=0, threshold_step=1, step_every_batch=False):
        self.max_value = max_value
        self.T1, self.T2 = last_warmup_step, threshold_step
        assert 0 <= self.T1 < self.T2

        # Step counter and the parameter value for that step; value starts at 0.
        self.current_step = 0
        self.value = 0

        # Attributes read by Algorithm when it steps its schedulers.
        self.step_every_batch = step_every_batch
        self.use_metric = False  # step() never takes a metric

    def step(self):
        """Advance the counter by one and refresh ``value`` for the new step.

        The first call happens after step 0, so the counter is incremented
        before the value for the upcoming step is computed.
        """
        self.current_step += 1
        self.value = self._value_at(self.current_step)

    def _value_at(self, step):
        # 0 before T1, linear ramp on [T1, T2), max_value from T2 on.
        if step < self.T1:
            return 0
        if step < self.T2:
            return (step - self.T1) / (self.T2 - self.T1) * self.max_value
        return self.max_value
