from pytorch_lightning import Callback, Trainer
from src.models.layers.fouriermask import FourierMaskLR



class WLFreeze(Callback):
    """Freeze, then later thaw, the ``widths``/``locations`` of every ``FourierMaskLR``.

    At the start of epoch ``freeze_epoch`` the ``widths`` and ``locations``
    attributes of every ``FourierMaskLR`` submodule of the LightningModule get
    ``requires_grad = False``. At the start of epoch ``thaw_epoch`` they get
    ``requires_grad = True`` again. If both epochs are equal, the freeze wins.
    Because the epochs are matched with ``==``, passing ``None`` for either one
    disables that transition.

    Args:
        freeze_epoch: Value of ``trainer.current_epoch`` at whose start the
            parameters are frozen.
        thaw_epoch: Value of ``trainer.current_epoch`` at whose start the
            parameters are made trainable again.
    """

    def __init__(self, freeze_epoch, thaw_epoch):
        super().__init__()
        self.freeze_epoch = freeze_epoch
        self.thaw_epoch = thaw_epoch

    @staticmethod
    def _set_trainable(pl_module, trainable):
        """Set ``requires_grad`` on widths/locations of every FourierMaskLR in ``pl_module``."""
        for module in pl_module.modules():
            if isinstance(module, FourierMaskLR):
                module.widths.requires_grad = trainable
                module.locations.requires_grad = trainable

    def _in_frozen_window(self, epoch):
        """Return True if an uninterrupted run would be frozen at the start of ``epoch``.

        Mirrors ``on_train_epoch_start``: the most recent transition at or
        before ``epoch`` decides the state, with freeze winning a tie.
        """
        if self.freeze_epoch is None or epoch < self.freeze_epoch:
            return False
        thawed_since = (
            self.thaw_epoch is not None
            and self.freeze_epoch < self.thaw_epoch <= epoch
        )
        return not thawed_since

    def on_train_start(self, trainer, pl_module):
        # Parameter requires_grad flags are not part of a state_dict, so a run
        # resumed after freeze_epoch (but before thaw_epoch) would otherwise
        # train these parameters: the exact-epoch check below never fires again.
        if self._in_frozen_window(trainer.current_epoch):
            self._set_trainable(pl_module, False)

    def on_train_epoch_start(self, trainer, pl_module):
        epoch = trainer.current_epoch
        if epoch == self.freeze_epoch:
            self._set_trainable(pl_module, False)
        elif epoch == self.thaw_epoch:
            self._set_trainable(pl_module, True)

