from dataclasses import dataclass
from typing import List

import numpy as np


@dataclass
class State:
    """Mutable bookkeeping for a training run's progress.

    Tracks the current epoch/iteration, the iteration that produced the best
    validation metric seen so far, and how many epochs have passed without
    improvement. All fields have immutable defaults, so a bare ``State()`` is a
    fresh run.
    """

    # Current epoch counter.
    epoch: int = 0
    # Current iteration (step) counter.
    iteration: int = 0
    # Iteration at which ``best_val_metric`` was recorded.
    best_iteration: int = 0
    # Best validation metric seen so far; higher is better, so start at -inf
    # so that any finite first measurement counts as an improvement.
    best_val_metric: float = -np.inf
    # Epochs since the last improvement of ``best_val_metric``; presumably
    # consumed by an early-stopping check in the caller — TODO confirm.
    num_epochs_not_improved: int = 0

    def get_checkpoint_iters(self) -> List[int]:
        """Return the iterations whose checkpoints are of interest.

        Returns:
            ``[self.iteration, self.best_iteration]``: the latest iteration
            first, then the best one. When the current iteration is also the
            best one, the same value appears twice. Callers that need unique
            entries should deduplicate.
        """
        # ``List`` (from typing) instead of ``list[int]`` keeps the annotation
        # valid on Python < 3.9 and matches this file's imports.
        return [self.iteration, self.best_iteration]
