""""""

from __future__ import annotations

from typing import List, Optional, Tuple

import numpy as np

from ..scheduler import SonGokuScheduler


class StaticOneShotScheduler(SonGokuScheduler):
    """SonGoku scheduler that computes its task grouping once, then freezes it.

    The first call to :meth:`refresh` defers to ``SonGokuScheduler.refresh``
    to obtain the cosine-similarity matrix and colour classes. Every later
    call skips that computation: it only advances the round bookkeeping
    (``round_idx`` / ``round_start_step``) and returns the cached matrix
    together with ``self.schedule_for_round()``.
    """

    def __post_init__(self) -> None:
        super().__post_init__()
        # Becomes True once the first refresh has completed; nothing in this
        # class resets it.
        self._frozen: bool = False
        # Private copy of the matrix produced by the first refresh. Stored as a
        # copy so that callers mutating a returned matrix cannot alter it.
        self._frozen_cos_matrix: Optional[np.ndarray] = None

    def refresh(
        self,
        probe_gradients: Optional[np.ndarray] = None,
    ) -> Tuple[np.ndarray, List[List[int]]]:
        """Return the cosine matrix and task groups for a new round.

        Parameters
        ----------
        probe_gradients : Optional[np.ndarray]
            Forwarded to ``SonGokuScheduler.refresh`` on the first call only.
            Ignored on every subsequent call, because the schedule is frozen.

        Returns
        -------
        Tuple[np.ndarray, List[List[int]]]
            The cosine matrix and the groups as lists of task indices. On
            frozen calls the matrix is a fresh copy of the cached one, or a
            ``float32`` identity of size ``num_tasks`` if the first refresh
            produced ``None``.
        """
        if not self._frozen:
            cos_matrix, color_classes = super().refresh(probe_gradients)
            # Cache a private copy: the original object is handed back to the
            # caller, who may modify it in place.
            self._frozen_cos_matrix = (
                None if cos_matrix is None else np.array(cos_matrix, copy=True)
            )
            self._frozen = True
            return cos_matrix, [list(group) for group in color_classes]

        # Frozen path: advance the round counters without recomputing the
        # similarity matrix or the colouring.
        # NOTE(review): assumes this mirrors the bookkeeping performed by
        # SonGokuScheduler.refresh — confirm against the base class.
        self.round_idx += 1
        self.round_start_step = self.step_idx

        if self._frozen_cos_matrix is None:
            cos_matrix = np.eye(self.num_tasks, dtype=np.float32)
        else:
            # Return a copy so the caller cannot mutate the frozen cache.
            cos_matrix = self._frozen_cos_matrix.copy()

        # Normalise to List[List[int]] exactly as the first-call branch does.
        groups = [list(group) for group in self.schedule_for_round()]
        return cos_matrix, groups
