from typing import List, Optional, Tuple

from absl import logging


class BaseCfgScheduler:
    """Base class for schedulers that choose among a fixed list of configs.

    Stores a scheduler ``name`` and a non-empty list of candidate ``configs``.
    The default :meth:`schedule` never selects anything; subclasses override it
    to decide which config is active.
    """

    def __init__(self, name: str, configs: List[dict]) -> None:
        """Store the scheduler name and its candidate configs.

        Args:
            name: Identifier for this scheduler.
            configs: Non-empty list of config dicts to choose from.

        Raises:
            AssertionError: If ``configs`` is not a list or is empty (only
                while assertions are enabled, i.e. not under ``python -O``).
        """
        # Kept as asserts (not TypeError/ValueError) so the exception type that
        # existing callers may rely on is unchanged; messages aid debugging.
        assert isinstance(configs, list), (
            f"configs must be a list, got {type(configs).__name__}"
        )
        assert len(configs) > 0, "configs must be non-empty"
        self.name = name
        self.configs = configs

    def schedule(
        self, step: int, metrics: Optional[dict] = None
    ) -> Tuple[bool, Optional[dict]]:
        """Report whether the active config changed at ``step``.

        This base implementation never picks a config and always returns
        ``(False, None)``; hence the ``Optional[dict]`` in the return type.

        Args:
            step: Current step. Unused here.
            metrics: Optional metrics dict. Unused here.

        Returns:
            ``(changed, config)``. Always ``(False, None)`` for the base class.
        """
        return False, None


class StepCfgScheduler(BaseCfgScheduler):
    """Scheduler that switches configs once the step crosses fixed thresholds.

    ``configs[0]`` becomes active on the first :meth:`schedule` call. After
    that, ``configs[i + 1]`` becomes active once a call passes a ``step``
    strictly greater than ``change_points[i]``. The active index only moves
    forward; a later call with a smaller ``step`` never reverts it.
    """

    def __init__(
        self, name: str, configs: List[dict], change_points: List[int]
    ) -> None:
        """Store configs plus the step thresholds that separate them.

        Args:
            name: Identifier for this scheduler (used in log messages).
            configs: Non-empty list of config dicts, in activation order.
            change_points: Step thresholds; must hold exactly one entry fewer
                than ``configs``. NOTE(review): presumably ascending, but
                ordering is not validated here.

        Raises:
            AssertionError: On a non-list or mismatched-length argument
                (only while assertions are enabled).
        """
        super().__init__(name, configs)

        assert isinstance(change_points, list)
        assert len(configs) == len(change_points) + 1
        self.change_points = change_points
        # -1 means "nothing activated yet"; the first schedule() call moves to 0.
        self.curr_id = -1

    def schedule(self, step: int, metrics: Optional[dict] = None) -> Tuple[bool, dict]:
        """Advance past every threshold that ``step`` has already exceeded.

        Several configs may be skipped in a single call if ``step`` jumps past
        multiple thresholds at once.

        Args:
            step: Current step, compared against ``change_points``.
            metrics: Unused by this scheduler.

        Returns:
            ``(changed, config)``. ``changed`` is True when this call moved to
            a different config, which is always the case on the first call.
            ``config`` is the currently active config.
        """
        changed = False
        last_id = len(self.configs) - 1
        while self.curr_id < last_id:
            # Before the first activation there is no threshold to check.
            activated = self.curr_id != -1
            if activated and not (step > self.change_points[self.curr_id]):
                break
            logging.info(
                f"Scheduler {self.name} advancing: {self.curr_id} + 1  = {self.configs[self.curr_id + 1]}"
            )
            self.curr_id += 1
            changed = True

        return changed, self.configs[self.curr_id]
