class MultipleOptimizer:
    """Group several named optimizers behind a single optimizer-like interface.

    Each wrapped object must provide ``zero_grad``, ``step``, ``state_dict``
    and ``load_state_dict``. These are presumably ``torch.optim.Optimizer``
    instances; confirm against the callers. Every method fans out to all
    wrapped optimizers in the order they were passed to the constructor,
    because dicts preserve keyword-argument insertion order.

    Example::

        opt = MultipleOptimizer(encoder=enc_opt, decoder=dec_opt)
        opt.zero_grad()
        loss.backward()
        opt.step()
    """

    def __init__(self, **optimizers):
        """Store the optimizers, keyed by the keyword names they were given.

        The same names are used as keys in ``state_dict``/``load_state_dict``.
        """
        self.optimizers = optimizers

    def zero_grad(self, *args, **kwargs):
        """Call ``zero_grad`` on every wrapped optimizer.

        Any positional or keyword arguments are forwarded unchanged, so
        callers can pass options such as ``set_to_none=...``. With no
        arguments each optimizer's ``zero_grad()`` is called bare.
        """
        for optimizer in self.optimizers.values():
            optimizer.zero_grad(*args, **kwargs)

    def step(self):
        """Call ``step()`` on every wrapped optimizer."""
        for optimizer in self.optimizers.values():
            optimizer.step()

    def load_state_dict(self, state_dict):
        """Restore each wrapped optimizer from ``state_dict[name]``.

        Args:
            state_dict: Mapping from optimizer name to that optimizer's own
                state dict, as produced by :meth:`state_dict`.

        Raises:
            KeyError: If ``state_dict`` lacks an entry for a wrapped
                optimizer. Keys that match no wrapped optimizer are ignored.
        """
        for key, optimizer in self.optimizers.items():
            optimizer.load_state_dict(state_dict[key])

    def state_dict(self):
        """Return ``{name: optimizer.state_dict()}`` for every wrapped optimizer."""
        return {key: optimizer.state_dict() for key, optimizer in self.optimizers.items()}