import torch

class AdamRel(torch.optim.Adam):
    """Adam optimizer with support for resetting the internal timestep t (step).

    ``reset_timestep`` restarts only the per-parameter step counter that Adam
    uses for bias correction. Every other piece of per-parameter optimizer
    state (e.g. the first/second moment estimates) is left untouched.
    """

    def reset_timestep(self):
        """Reset the timestep t = 0 (i.e., ``state['step']``) for all parameters.

        Parameters that have no optimizer state yet (never stepped) are
        skipped, and no state entry is created for them.

        Recent PyTorch versions store ``state['step']`` as a singleton tensor,
        possibly on the parameter's device for ``capturable``/``fused`` Adam.
        Adam's ``step`` raises a ``RuntimeError`` if it finds a non-tensor
        step. The tensor is therefore zeroed in place, which preserves its
        dtype, device and identity. States from older PyTorch versions that
        hold a plain Python number are reset to a plain ``0``.
        """
        for group in self.param_groups:
            for p in group['params']:
                # Use .get() instead of [] so the state defaultdict does not
                # grow empty entries for parameters that were never stepped.
                state = self.state.get(p)
                if not state or 'step' not in state:
                    continue
                step = state['step']
                if torch.is_tensor(step):
                    # In place: keeps dtype/device, and keeps any outside
                    # reference valid (e.g. a captured CUDA graph with
                    # capturable=True).
                    step.zero_()
                else:
                    state['step'] = 0