from methods.base import BaseMethod

class INTL(BaseMethod):
    """Iterative Normalization with Trace loss (INTL).

    This class only defines the loss computation. The sub-modules and
    hyper-parameters it uses (``backbone``, ``projection``, ``IterNorm``,
    ``INTL``, ``norm_mse``, ``trade_off``) are not defined here.
    NOTE(review): presumably they are provided by ``BaseMethod`` / ``cfg``;
    confirm.
    """

    def __init__(self, cfg):
        super().__init__(cfg)

    def forward(self, samples):
        """Compute the INTL loss over a sequence of augmented crops.

        Every crop is passed through ``backbone`` -> ``projection`` ->
        ``IterNorm``. Each crop ``i >= 1`` is then paired with crop 0. A pair
        contributes ``norm_mse(t[i], t[0])`` plus
        ``trade_off * (INTL(t[i]) + INTL(t[0]))``. The per-pair sum is
        averaged over the ``len(samples) - 1`` pairs.

        Args:
            samples: sized sequence of input batches, one per crop. Each
                element must support ``.cuda(non_blocking=True)``
                (presumably ``torch.Tensor``; confirm against the data loader).

        Returns:
            The averaged loss value.

        Raises:
            ValueError: if fewer than two crops are supplied, since no pair
                can be formed.
        """
        nmb_crops = len(samples)
        if nmb_crops < 2:
            raise ValueError(
                f"INTL needs at least 2 crops to form a pair, got {nmb_crops}"
            )

        # Tensor.cuda() is NOT in-place: it returns a (possibly new) tensor.
        # Keep the returned tensors, otherwise the transfer is thrown away and
        # the model would still be fed the original batches.
        samples = [x.cuda(non_blocking=True) for x in samples]

        t = [self.IterNorm(self.projection(self.backbone(x))) for x in samples]
        intl = [self.INTL(z) for z in t]

        loss = 0
        for i in range(1, nmb_crops):
            # Crop 0 acts as the anchor of every pair, so its trace term is
            # added once per pair (weight 1 after averaging), while each other
            # crop's term gets weight 1 / (nmb_crops - 1).
            loss += self.norm_mse(t[i], t[0]) + self.trade_off * (intl[i] + intl[0])
        loss /= nmb_crops - 1
        return loss
