import torch
from composer.core import Algorithm, Event

class DataSeedAlgorithm(Algorithm):
    """Set the ``seed`` attribute of the training dataloader's sampler at fit start.

    On ``Event.FIT_START`` the configured seed is assigned to
    ``state.train_dataloader.sampler.seed`` and a confirmation line is
    printed. A configured seed of ``1`` leaves the sampler untouched.

    Args:
        **config: Keyword configuration. Must contain the key ``'seed'``;
            any other keys are ignored.

    Raises:
        KeyError: If ``'seed'`` is not present in ``config``.
    """

    def __init__(self, **config):
        # Run the base-class initializer so any state it sets up exists on
        # this instance before our own attributes are added.
        super().__init__()
        self.seed = config['seed']

    def match(self, event, state):
        """Return True only when ``event`` is ``Event.FIT_START``."""
        return event == Event.FIT_START

    def apply(self, event, state, logger):
        """Assign the configured seed to the training sampler.

        Args:
            event: The event that triggered this call (unused here).
            state: Object exposing ``train_dataloader.sampler``.
            logger: Unused here.
        """
        # NOTE(review): a seed of 1 is skipped — presumably 1 is the
        # sampler's existing default; confirm against the training config.
        if self.seed == 1:
            return

        # Assumes the sampler reads its ``seed`` attribute when it builds its
        # shuffle order — TODO confirm for the sampler type in use.
        state.train_dataloader.sampler.seed = self.seed
        print(f'data seed set to {self.seed}')
