from typing import Any
from collections import OrderedDict

from pytorch_lightning import Callback, Trainer
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.model_summary import get_human_readable_count
from src.models.layers.fouriermask import FourierMaskLR

import torch
import torch.nn as nn


class MaskHandler(Callback):
    """Periodically unfreeze the parameters of ``FourierMaskLR`` layers.

    On most training batches the ``widths`` and ``locations`` parameters of
    every ``FourierMaskLR`` module inside ``pl_module.model`` are frozen
    (``requires_grad=False``). On every ``freq``-th batch -- the batches whose
    index satisfies ``batch_idx % freq == freq - 1`` -- they are made
    trainable, so only those batches produce gradients for them. With
    ``freq=1`` the mask parameters are trainable on every batch.

    NOTE(review): toggling ``requires_grad`` only controls whether autograd
    populates ``.grad``. An optimizer may still step a frozen parameter if it
    holds a stale (e.g. zero-filled rather than ``None``) ``.grad`` and uses
    momentum or weight decay. This schedule therefore relies on gradients
    being reset to ``None`` between steps -- confirm against the
    trainer/optimizer configuration.

    Args:
        freq: Period, in training batches, of the mask-update schedule.
            Must be >= 1.

    Raises:
        ValueError: If ``freq`` is smaller than 1.
    """

    def __init__(self, freq):
        super().__init__()
        if freq < 1:
            # freq == 0 would raise ZeroDivisionError on the first batch, and a
            # negative freq can never match the schedule, silently keeping the
            # masks frozen forever.
            raise ValueError(f"freq must be >= 1, got {freq!r}")
        self.freq = freq
        # Filled in by on_fit_start. Initialised here so the batch hook is a
        # harmless no-op instead of an AttributeError if it ever runs first.
        self.modules = []

    def on_fit_start(self, trainer, pl_module):
        """Collect every ``FourierMaskLR`` submodule of ``pl_module.model``.

        Assumes the LightningModule exposes the wrapped network as a
        ``.model`` attribute.
        """
        target_modules = (FourierMaskLR,)
        self.modules = [
            m for m in pl_module.model.modules() if isinstance(m, target_modules)
        ]

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        """Freeze or unfreeze the mask parameters for the upcoming batch.

        ``batch_idx`` is used exactly as Lightning passes it (presumably the
        index within the current epoch -- confirm for the Lightning version in
        use).
        """
        # Trainable only on the last batch of each window of `freq` batches.
        trainable = batch_idx % self.freq == self.freq - 1
        for m in self.modules:
            m.widths.requires_grad_(trainable)
            m.locations.requires_grad_(trainable)

