import torch
import torch.nn as nn
import numpy as np


def drop_scheduler(drop_rate, epochs, cutoff_epoch=50, mode="early", schedule="linear"):
    """Build a per-epoch dropout-rate schedule.

    Args:
        drop_rate: Dropout probability used while dropout is active.
        epochs: Number of entries in the returned schedule.
        cutoff_epoch: Index at which the schedule switches from its first
            phase to its second phase. Ignored when ``mode`` is
            ``"standard"``. A value larger than ``epochs`` is accepted: the
            schedule is built as if it ran for ``cutoff_epoch`` epochs and is
            then truncated to ``epochs`` entries.
        mode: ``"standard"`` keeps ``drop_rate`` for every epoch;
            ``"early"`` applies dropout before ``cutoff_epoch`` and 0 after;
            ``"late"`` applies 0 before ``cutoff_epoch`` and ``drop_rate``
            after.
        schedule: Shape of the active phase. ``"early"`` supports
            ``"constant"`` and ``"linear"`` (ramp from ``drop_rate`` down to
            0 over ``cutoff_epoch`` entries); ``"late"`` supports only
            ``"constant"``.

    Returns:
        A 1-D numpy array of length ``epochs``.

    Raises:
        AssertionError: If ``mode`` or ``schedule`` is not supported.
        ValueError: If ``epochs`` or ``cutoff_epoch`` is negative.
    """
    assert mode in ["standard", "early", "late"], f"unsupported mode: {mode!r}"
    if mode == "standard":
        return np.full(epochs, drop_rate)

    if epochs < 0 or cutoff_epoch < 0:
        raise ValueError(
            "epochs and cutoff_epoch must be non-negative, "
            f"got epochs={epochs}, cutoff_epoch={cutoff_epoch}"
        )

    # Clamp the first phase to the schedule length so that a cutoff beyond
    # the last epoch truncates the schedule instead of requesting a
    # negative-length second phase.
    early_iters = min(cutoff_epoch, epochs)
    late_iters = epochs - early_iters

    if mode == "early":
        assert schedule in ["constant", "linear"], (
            f"unsupported schedule for mode 'early': {schedule!r}"
        )
        if schedule == "constant":
            early_schedule = np.full(early_iters, drop_rate)
        else:
            # The ramp is always defined over the full cutoff_epoch span; a
            # truncated schedule simply stops part-way down the ramp.
            early_schedule = np.linspace(drop_rate, 0, cutoff_epoch)[:early_iters]
        late_schedule = np.full(late_iters, 0)
    else:  # mode == "late"
        assert schedule in ["constant"], (
            f"unsupported schedule for mode 'late': {schedule!r}"
        )
        early_schedule = np.full(early_iters, 0)
        late_schedule = np.full(late_iters, drop_rate)

    final_schedule = np.concatenate((early_schedule, late_schedule))
    assert len(final_schedule) == epochs
    return final_schedule



def updatedp(model: nn.Module, dropout: float):
    """Set the drop probability of every ``nn.Dropout`` layer in ``model``.

    Walks ``model.modules()`` (which includes ``model`` itself) and assigns
    ``dropout`` to the ``p`` attribute of each module that is an instance of
    ``nn.Dropout``. Modules of any other type are left untouched.

    Args:
        model: Module whose dropout layers are updated in place.
        dropout: New value stored in each matching layer's ``p``.
    """
    dropout_layers = (
        layer for layer in model.modules() if isinstance(layer, nn.Dropout)
    )
    for layer in dropout_layers:
        layer.p = dropout
            


class dropoutscheduler:
    """Step-wise dropout scheduler for a model's ``nn.Dropout`` layers.

    A rate table (``self.routine``) is precomputed at construction time; each
    call to :meth:`step` advances a counter and writes the corresponding rate
    into the model through ``updatedp``.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, model: nn.Module, dropout: float, epochs: int, dppreepoch: int) -> None:
        """Precompute the dropout-rate table.

        Args:
            model: Module whose dropout layers are updated on each step.
            dropout: Base dropout rate.
            epochs: Nominal number of epochs; the table holds ``epochs + 10``
                entries.
            dppreepoch: If positive, passed to ``drop_scheduler`` as its
                ``cutoff_epoch`` (with that function's default mode and
                schedule). If zero or negative, the rate is constant.
        """
        self.model = model
        # NOTE(review): the 10 extra entries look like headroom for step()
        # calls beyond `epochs` — confirm against the training loop.
        total_entries = epochs + 10
        if dppreepoch <= 0:
            self.routine = np.full(total_entries, dropout)
        else:
            self.routine = drop_scheduler(dropout, total_entries, dppreepoch)
        # Number of step() calls made so far.
        self.idx = 0

    def step(self):
        """Advance one step and apply the scheduled dropout rate to the model.

        The counter is incremented before the lookup, so the first call
        applies ``routine[1]``; ``routine[0]`` is never applied by this class
        (presumably the model is constructed with that rate already — TODO
        confirm). Once the counter runs past the end of the table the last
        entry is held, instead of raising ``IndexError``.
        """
        self.idx += 1
        # Clamp only the lookup; self.idx keeps counting every call.
        last_entry = len(self.routine) - 1
        updatedp(self.model, self.routine[min(self.idx, last_entry)])

if __name__ == "__main__":
    # Manual smoke check: print a 100-entry schedule built with the default
    # mode and schedule, base rate 0.1 and cutoff at epoch 20.
    demo_schedule = drop_scheduler(drop_rate=0.1, epochs=100, cutoff_epoch=20)
    print(demo_schedule)