"""Single-step variant of ``SonGokuScheduler`` that groups tasks by their latest gradients."""

from __future__ import annotations

from typing import List, Optional, Sequence, Tuple

import numpy as np

from ..scheduler import SonGokuScheduler
from ..graph_coloring import (
    welsh_powell_coloring,
    color_classes_from_assignments,
)


class SingleStepScheduler(SonGokuScheduler):
    """SonGoku scheduler variant driven by each task's most recent gradient.

    Instead of smoothing gradients over time, :meth:`update_ema` overwrites
    a per-task snapshot with the latest gradient it is given, and
    :meth:`refresh` builds the task-conflict graph from those snapshots.
    """

    def __post_init__(self) -> None:
        super().__post_init__()
        # Latest gradient recorded per task. Same shape/dtype as the parent's
        # EMA matrix — presumably (num_tasks, grad_dim); TODO confirm in base.
        self._last_grads: np.ndarray = np.zeros_like(self._ema_matrix)
        # Set to True for a task once update_ema has recorded a gradient for it.
        self._last_initialized: np.ndarray = np.zeros(
            self.num_tasks, dtype=bool
        )

    def update_ema(
        self,
        task_ids: Sequence[int],
        grad_vectors: np.ndarray,
    ) -> None:
        """Record the latest gradient for each task in ``task_ids``.

        Despite the inherited name, no moving average is computed: each
        listed task's stored gradient is overwritten with the matching row
        of ``grad_vectors``. If a task id appears more than once, its last
        occurrence wins.

        All shape and range checks run before any state is modified, so a
        call rejected by them leaves the scheduler unchanged.

        Args:
            task_ids: Task indices, each expected in ``[0, num_tasks)``.
            grad_vectors: Array of shape ``(len(task_ids), grad_dim)``;
                row ``i`` is the gradient for ``task_ids[i]``. It is
                converted to ``float32`` first.

        Raises:
            ValueError: If ``grad_vectors`` is not 2-D, its row count does
                not equal ``len(task_ids)``, or its width is not ``grad_dim``.
            IndexError: If any task id lies outside ``[0, num_tasks)``.
        """
        task_ids = list(task_ids)
        grad_vectors = np.asarray(grad_vectors, dtype=np.float32)
        if grad_vectors.ndim != 2 or grad_vectors.shape[0] != len(task_ids):
            raise ValueError(
                f"grad_vectors must have shape (len(task_ids), {self.grad_dim}), "
                f"got {grad_vectors.shape} for {len(task_ids)} task_ids"
            )
        if grad_vectors.shape[1] != self.grad_dim:
            raise ValueError(
                f"Each gradient must have dimension grad_dim={self.grad_dim}, "
                f"got {grad_vectors.shape[1]}"
            )

        # Validate every id before writing anything, so a bad id late in the
        # list cannot leave earlier tasks already overwritten.
        for task_id in task_ids:
            if not (0 <= task_id < self.num_tasks):
                raise IndexError(
                    f"task_id {task_id} out of range [0, {self.num_tasks})"
                )

        # Sequential writes keep "last occurrence wins" for duplicate ids;
        # NumPy does not guarantee an order for repeated fancy-index writes.
        for task_id, grad in zip(task_ids, grad_vectors):
            self._last_grads[task_id] = grad
            self._last_initialized[task_id] = True

    def refresh(
        self,
        probe_gradients: Optional[np.ndarray] = None,
    ) -> Tuple[np.ndarray, List[List[int]]]:
        """Rebuild the task grouping from the most recently recorded gradients.

        The stored per-task gradients are passed through
        ``self._sketcher.sketch`` and L2-normalised row-wise; rows whose
        sketch has zero norm stay all-zero. Pairwise cosine similarities are
        then computed, and two distinct tasks conflict when their cosine is
        below ``-current_tau``. The conflict graph is colored with
        ``welsh_powell_coloring`` and turned into groups by
        ``color_classes_from_assignments``. If ``duplicate_singletons`` is
        truthy, the groups are post-processed by
        ``_duplicate_singleton_classes``. The result is stored in
        ``_color_classes``, ``round_idx`` is incremented and
        ``round_start_step`` is set to ``step_idx``.

        Args:
            probe_gradients: Accepted for signature compatibility with the
                parent scheduler; ignored by this implementation.

        Returns:
            ``(cos_matrix, color_classes)``: the pairwise cosine matrix of
            the normalised sketches, and the task groups.
        """
        del probe_gradients  # unused; see docstring

        grads = np.asarray(self._last_grads, dtype=np.float32)
        sketched = self._sketcher.sketch(grads)
        norms = np.linalg.norm(sketched, axis=1, keepdims=True)

        # Explicit 1-D row mask. ``norms.squeeze()`` would collapse to a 0-d
        # scalar when there is only one task.
        nonzero = norms[:, 0] > 0
        unit = np.zeros_like(sketched)
        # Zero-norm rows stay zero, so their cosine with every task is 0 and
        # they never produce a conflict. Note: ``_last_initialized`` is not
        # consulted; never-updated tasks are zero rows here only if the sketch
        # maps a zero row to a zero row — TODO confirm for the sketcher in use.
        unit[nonzero] = sketched[nonzero] / norms[nonzero]

        cos_matrix = unit @ unit.T

        tau = float(self.current_tau)
        conflict = cos_matrix < -tau
        # A task must never conflict with itself, regardless of tau's sign.
        np.fill_diagonal(conflict, False)

        colors = welsh_powell_coloring(conflict.tolist())
        color_classes = color_classes_from_assignments(colors)
        if self.duplicate_singletons:
            color_classes = self._duplicate_singleton_classes(
                color_classes, conflict
            )

        self._color_classes = color_classes
        self.round_idx += 1
        self.round_start_step = self.step_idx

        return cos_matrix, color_classes
