from pytorch_lightning import Callback
from src.models.layers.fouriermask import FourierMaskLR, FourierMask

from einops import rearrange
import torch

class MatLoss(Callback):
    """Penalize how much FourierMaskLR masks change each module's dense weight.

    For every ``FourierMaskLR`` module found in ``pl_module.model``, the dense
    weight ``W = W2 @ W1`` is built twice:
    - once from the masked low-rank factors (with autograd);
    - once from the unmasked factors (under ``torch.no_grad()``).

    The mean squared difference is taken per module, and the results are
    averaged over modules. That average is scaled by ``lambd``, logged as
    ``"mask loss"`` and added in place to ``outputs['loss']``.

    NOTE(review): under Lightning's automatic optimization,
    ``on_train_batch_end`` runs after ``backward()`` / ``optimizer.step()``
    for the batch. Adding to ``outputs['loss']`` here therefore most likely
    changes only the reported loss, not the gradients. If the penalty is
    meant to train the masks, it belongs in ``training_step``. Confirm the
    intent.
    """

    def __init__(self, lambd=10000.0):
        """
        Args:
            lambd: Multiplier applied to the averaged mask loss before it is
                logged and added to the batch loss.
        """
        super().__init__()
        self.lambd = lambd

    @staticmethod
    def _dense_weight(w1, w2):
        """Collapse low-rank factors into a dense ``(out, in)`` matrix.

        ``w1`` is laid out as ``(nc, rpc, in)`` and ``w2`` as
        ``(nc, out, rpc)``. The ``nc`` and ``rpc`` axes are merged into one
        rank axis so the product is ``(out, nc*rpc) @ (nc*rpc, in)``.
        """
        w1 = rearrange(w1, 'nc rpc in -> (nc rpc) in')
        w2 = rearrange(w2, 'nc out rpc -> out (nc rpc)')
        return w2 @ w1

    @staticmethod
    def _mask_loss(model):
        """Return the average masked-vs-unmasked weight MSE over FourierMaskLR modules.

        Returns ``None`` when ``model`` contains no ``FourierMaskLR`` module,
        because there is nothing to average.
        """
        losses = []
        for m in model.modules():
            if not isinstance(m, FourierMaskLR):
                continue
            masked_w = MatLoss._dense_weight(
                m.lr_weight1 * m.get_mask_by_ind(0),
                m.lr_weight2 * m.get_mask_by_ind(1),
            )
            # The unmasked weight is only a fixed target, so no graph is built for it.
            with torch.no_grad():
                orig_w = MatLoss._dense_weight(m.lr_weight1, m.lr_weight2)
            losses.append(torch.mean((masked_w - orig_w) ** 2))
        if not losses:
            return None
        return sum(losses) / len(losses)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Log the scaled mask loss and add it to ``outputs['loss']``.

        Does nothing when ``pl_module.model`` has no ``FourierMaskLR`` modules.
        This avoids dividing by a zero module count.
        """
        mask_loss = self._mask_loss(pl_module.model)
        if mask_loss is None:
            return
        scaled = self.lambd * mask_loss
        pl_module.log("mask loss", scaled, rank_zero_only=True, prog_bar=True, on_epoch=True, on_step=True)
        # See the class-level NOTE: this is likely too late to affect gradients.
        outputs['loss'] += scaled









