import warnings

import torch

from composer.core import Algorithm, Event

class LossWeighter(Algorithm):
    """Composer algorithm that rescales the loss by a per-call weight.

    On every ``Event.AFTER_LOSS`` the current ``state.loss`` is multiplied by
    ``self.weights[self.cntr]`` and ``self.cntr`` is advanced by one. The
    weights are loaded once, at construction, from ``config['weights_path']``
    via :func:`torch.load`. If no path is given, the algorithm is a no-op.

    Once indexing the weights raises ``IndexError`` or ``KeyError`` (weights
    exhausted), the loss is left unchanged for every remaining call and a
    warning is emitted once. Any other error (e.g. a shape mismatch in the
    multiplication) propagates instead of being silently ignored.

    NOTE(review): ``self.cntr`` counts AFTER_LOSS events, not optimizer steps.
    If AFTER_LOSS fires per microbatch under gradient accumulation, the
    weights file must be laid out per microbatch. Confirm against the
    trainer configuration.

    Args:
        **config: Keyword configuration, stored verbatim on ``self.config``.
            Recognized keys:

            * ``weights_path`` (optional): path passed to ``torch.load``.
              The loaded object must be indexable by consecutive ints starting
              at 0 (e.g. a list or 1-D tensor). Presumably each entry is a
              scalar; TODO confirm.
            * ``weights_only`` (optional, default ``False``): forwarded to
              ``torch.load``. Set to ``True`` to restrict loading to tensors
              and primitive containers.
    """

    def __init__(self, **config):
        self.config = config
        weights_path = config.get('weights_path', None)
        self.weights = None
        if weights_path is not None:
            # SECURITY: with weights_only=False, torch.load unpickles arbitrary
            # objects and can execute code embedded in the file. Only point
            # weights_path at trusted files, or set weights_only=True.
            self.weights = torch.load(
                weights_path,
                weights_only=config.get('weights_only', False),
            )
        self.cntr = 0  # index of the next weight to apply
        self._warned_exhausted = False  # ensures the exhaustion warning fires once

    def match(self, event, state):
        """Return True only for ``Event.AFTER_LOSS``."""
        return event == Event.AFTER_LOSS

    def apply(self, event, state, logger):
        """Multiply ``state.loss`` by the next weight, if one is available.

        Args:
            event: The triggering event (always AFTER_LOSS, per ``match``).
            state: Trainer state; ``state.loss`` is read and overwritten.
            logger: Unused.
        """
        if self.weights is None:
            return
        try:
            weight = self.weights[self.cntr]
        except (IndexError, KeyError):
            # Weights exhausted: leave the loss untouched. Warn only once so
            # the remaining steps do not flood the output.
            if not self._warned_exhausted:
                warnings.warn(
                    f'LossWeighter: weights exhausted after {self.cntr} '
                    f'applications; skipping loss weighting from now on.'
                )
                self._warned_exhausted = True
            return
        state.loss = state.loss * weight
        self.cntr += 1
        