""""""

from __future__ import annotations

from typing import List, Optional, Tuple

import numpy as np

from ..scheduler import SonGokuScheduler
from ..graph_coloring import welsh_powell_coloring, color_classes_from_assignments


class KnnSymmetricScheduler(SonGokuScheduler):
    """SonGoku scheduler with a symmetric k-nearest-neighbour conflict graph.

    At every refresh each task is linked to the ``knn_k`` other tasks with the
    *lowest* cosine similarity to it (its most conflicting neighbours). The
    directed relation is then symmetrised with a logical OR, so a task can end
    up with more than ``knn_k`` conflict edges. ``knn_k`` is effectively
    clamped to ``num_tasks - 1``.
    """

    def __init__(
        self,
        num_tasks: int,
        grad_dim: int,
        refresh_period: int,
        *,
        knn_k: int,
        **kwargs,
    ) -> None:
        """Store ``knn_k`` and forward everything else to the base scheduler.

        Args:
            num_tasks: Forwarded to ``SonGokuScheduler.__init__``.
            grad_dim: Forwarded to ``SonGokuScheduler.__init__``.
            refresh_period: Forwarded to ``SonGokuScheduler.__init__``.
            knn_k: Number of most-conflicting neighbours linked per task.
                Stored via ``int()``, so fractional values >= 1 are truncated.
            **kwargs: Forwarded to ``SonGokuScheduler.__init__``.

        Raises:
            ValueError: If ``knn_k`` is smaller than 1.
        """
        # Compare against 1, not 0: values in (0, 1) such as 0.5 used to pass a
        # "<= 0" check and then truncate to knn_k == 0, silently producing an
        # empty conflict graph.
        if knn_k < 1:
            raise ValueError("knn_k must be a positive integer")
        # Assigned before super().__init__(), presumably so the attribute exists
        # for anything the base constructor calls -- keep this order.
        self.knn_k = int(knn_k)
        super().__init__(
            num_tasks=num_tasks,
            grad_dim=grad_dim,
            refresh_period=refresh_period,
            **kwargs,
        )

    def _build_knn_conflict_adj(self, cos_matrix: np.ndarray) -> np.ndarray:
        """Build the symmetric k-NN conflict adjacency from ``cos_matrix``.

        Args:
            cos_matrix: Square ``(num_tasks, num_tasks)`` similarity matrix.
                It is not modified.

        Returns:
            A symmetric boolean ``(num_tasks, num_tasks)`` array with a False
            diagonal. It is all-False when there are fewer than two tasks.
            Among tied cosines, which neighbour is picked is unspecified.

        Raises:
            ValueError: If ``cos_matrix`` does not have shape
                ``(num_tasks, num_tasks)``.
        """
        K = self.num_tasks
        if cos_matrix.shape != (K, K):
            raise ValueError(
                f"cos_matrix must have shape ({K}, {K}), got {cos_matrix.shape}"
            )

        # A task has at most K - 1 possible neighbours. m <= 0 covers K <= 1
        # and also guards against knn_k having been mutated to a non-positive
        # value after construction.
        m = min(self.knn_k, K - 1)
        if m <= 0:
            return np.zeros((K, K), dtype=bool)

        # Work on a float32 copy (np.array copies by default) so the caller's
        # matrix is untouched. Push self-similarity to +inf so a task is never
        # selected as its own neighbour.
        scores = np.array(cos_matrix, dtype=np.float32)
        np.fill_diagonal(scores, np.inf)
        # Column indices of the m smallest cosines in every row (unordered).
        nearest = np.argpartition(scores, m - 1, axis=1)[:, :m]

        conflict = np.zeros((K, K), dtype=bool)
        conflict[np.arange(K)[:, None], nearest] = True
        # Symmetrise: i and j conflict if either lists the other as a neighbour.
        conflict = conflict | conflict.T
        np.fill_diagonal(conflict, False)
        return conflict

    def refresh(
        self,
        probe_gradients: Optional[np.ndarray] = None,
    ) -> Tuple[np.ndarray, List[List[int]]]:
        """Recompute the task colour classes from the k-NN conflict graph.

        Args:
            probe_gradients: Optional array with the same shape as
                ``self._ema_matrix``. When given, it is blended into the EMA as
                ``beta * ema + (1 - beta) * probe_gradients`` and every row of
                ``self._ema_initialized`` is set to True.

        Returns:
            ``(cos_matrix, color_classes)``. ``cos_matrix`` holds the pairwise
            cosine similarities of the sketched EMA rows; all-zero rows give a
            cosine of 0 with every task, including themselves. ``color_classes``
            is the list of task-index lists derived from the colouring.

        Raises:
            ValueError: If ``probe_gradients`` has the wrong shape.
        """
        if probe_gradients is not None:
            probe_gradients = np.asarray(probe_gradients, dtype=np.float32)
            if probe_gradients.shape != self._ema_matrix.shape:
                raise ValueError(
                    f"probe_gradients must have shape {self._ema_matrix.shape}, "
                    f"got {probe_gradients.shape}"
                )
            # NOTE(review): the blend is applied even on the first probe. Whether
            # that biases the EMA depends on how the base class initialises
            # _ema_matrix -- confirm.
            self._ema_matrix = (
                self.beta * self._ema_matrix
                + (1.0 - self.beta) * probe_gradients
            )
            self._ema_initialized[:] = True

        sketch = self._sketcher.sketch(self._ema_matrix)

        # Row-normalise. Index the kept axis explicitly instead of squeeze(),
        # which would collapse the mask to 0-d when there is a single task.
        norms = np.linalg.norm(sketch, axis=1, keepdims=True)
        unit = np.zeros_like(sketch)
        nonzero = norms[:, 0] > 0
        if np.any(nonzero):
            unit[nonzero] = sketch[nonzero] / norms[nonzero]

        cos_matrix = unit @ unit.T

        conflict = self._build_knn_conflict_adj(cos_matrix)

        colors = welsh_powell_coloring(conflict.tolist())
        color_classes = color_classes_from_assignments(colors)

        if self.duplicate_singletons:
            color_classes = self._duplicate_singleton_classes(
                color_classes, conflict
            )

        self._color_classes = color_classes
        self.round_idx += 1
        self.round_start_step = self.step_idx

        return cos_matrix, color_classes


class SignedOnlyScheduler(SonGokuScheduler):
    """SonGoku scheduler whose conflict graph is the sign of the cosine.

    Two tasks conflict exactly when the cosine similarity of their sketched,
    row-normalised EMA gradients is strictly negative. No neighbour count or
    threshold parameter is involved.
    """

    def _build_signed_conflict_adj(self, cos_matrix: np.ndarray) -> np.ndarray:
        """Return the boolean adjacency ``cos_matrix < 0`` with a False diagonal.

        Raises:
            ValueError: If ``cos_matrix`` does not have shape
                ``(num_tasks, num_tasks)``.
        """
        K = self.num_tasks
        shape_ok = cos_matrix.shape == (K, K)
        if not shape_ok:
            raise ValueError(
                f"cos_matrix must have shape ({K}, {K}), got {cos_matrix.shape}"
            )

        adjacency = np.less(cos_matrix, 0.0)
        np.fill_diagonal(adjacency, False)
        return adjacency

    def refresh(
        self,
        probe_gradients: Optional[np.ndarray] = None,
    ) -> Tuple[np.ndarray, List[List[int]]]:
        """Recompute the task colour classes from the sign-based conflict graph.

        Args:
            probe_gradients: Optional array shaped like ``self._ema_matrix``.
                When supplied, the EMA becomes
                ``beta * ema + (1 - beta) * probe_gradients`` and all entries of
                ``self._ema_initialized`` are set to True.

        Returns:
            A ``(cos_matrix, color_classes)`` pair. ``cos_matrix`` holds the
            cosine similarities between sketched EMA rows, with zero rows
            contributing 0. ``color_classes`` lists the task indices per colour.

        Raises:
            ValueError: If ``probe_gradients`` has the wrong shape.
        """
        if probe_gradients is not None:
            probe = np.asarray(probe_gradients, dtype=np.float32)
            ema_shape = self._ema_matrix.shape
            if probe.shape != ema_shape:
                raise ValueError(
                    f"probe_gradients must have shape {self._ema_matrix.shape}, "
                    f"got {probe.shape}"
                )
            blended = self.beta * self._ema_matrix + (1.0 - self.beta) * probe
            self._ema_matrix = blended
            self._ema_initialized[:] = True

        sketched = self._sketcher.sketch(self._ema_matrix)

        # Unit-normalise each row; rows with zero norm are left as zeros.
        row_norms = np.linalg.norm(sketched, axis=1, keepdims=True)
        unit_rows = np.zeros_like(sketched)
        has_norm = row_norms[:, 0] > 0
        if has_norm.any():
            unit_rows[has_norm] = sketched[has_norm] / row_norms[has_norm]

        cos_matrix = unit_rows @ unit_rows.T

        adjacency = self._build_signed_conflict_adj(cos_matrix)
        groups = color_classes_from_assignments(
            welsh_powell_coloring(adjacency.tolist())
        )
        if self.duplicate_singletons:
            groups = self._duplicate_singleton_classes(groups, adjacency)

        # Publish the new grouping and advance the round bookkeeping.
        self._color_classes = groups
        self.round_idx += 1
        self.round_start_step = self.step_idx

        return cos_matrix, groups


class QuantileScheduler(SonGokuScheduler):
    """SonGoku scheduler that thresholds cosines at a data-driven percentile.

    At every refresh the ``percentile``-th percentile of all off-diagonal
    cosine similarities becomes the threshold. Every pair at or below it is
    marked as conflicting, and the threshold is stored in
    ``self.current_tau``.
    """

    def __init__(
        self,
        num_tasks: int,
        grad_dim: int,
        refresh_period: int,
        *,
        percentile: float,
        **kwargs,
    ) -> None:
        """Validate and store ``percentile``, then defer to the base scheduler.

        Args:
            num_tasks: Forwarded to ``SonGokuScheduler.__init__``.
            grad_dim: Forwarded to ``SonGokuScheduler.__init__``.
            refresh_period: Forwarded to ``SonGokuScheduler.__init__``.
            percentile: Percentile in the open interval ``(0, 100)``.
            **kwargs: Forwarded to ``SonGokuScheduler.__init__``.

        Raises:
            ValueError: If ``percentile`` is outside ``(0, 100)`` (NaN included).
        """
        # The chained comparison is False for NaN, so NaN is rejected too.
        in_open_interval = 0.0 < percentile < 100.0
        if not in_open_interval:
            raise ValueError("percentile must lie in the open interval (0, 100)")
        self.percentile = float(percentile)
        super().__init__(
            num_tasks=num_tasks,
            grad_dim=grad_dim,
            refresh_period=refresh_period,
            **kwargs,
        )

    def _build_quantile_conflict_adj(self, cos_matrix: np.ndarray) -> np.ndarray:
        """Mark every task pair whose cosine is at or below the percentile.

        Side effect: stores the computed threshold in ``self.current_tau``
        whenever at least two tasks exist.

        Returns:
            A symmetric boolean ``(num_tasks, num_tasks)`` array with a False
            diagonal.

        Raises:
            ValueError: If ``cos_matrix`` does not have shape
                ``(num_tasks, num_tasks)``.
        """
        K = self.num_tasks
        if not cos_matrix.shape == (K, K):
            raise ValueError(
                f"cos_matrix must have shape ({K}, {K}), got {cos_matrix.shape}"
            )

        adjacency = np.zeros((K, K), dtype=bool)
        if K <= 1:
            return adjacency

        # Strict upper triangle: each unordered pair is considered once.
        rows, cols = np.triu_indices(K, k=1)
        pair_cos = cos_matrix[rows, cols]
        if not pair_cos.size:
            return adjacency

        threshold = float(np.percentile(pair_cos, self.percentile))
        self.current_tau = threshold

        selected = pair_cos <= threshold
        sel_rows = rows[selected]
        sel_cols = cols[selected]
        adjacency[sel_rows, sel_cols] = True
        adjacency[sel_cols, sel_rows] = True
        np.fill_diagonal(adjacency, False)
        return adjacency

    def refresh(
        self,
        probe_gradients: Optional[np.ndarray] = None,
    ) -> Tuple[np.ndarray, List[List[int]]]:
        """Recompute the task colour classes from the quantile conflict graph.

        Args:
            probe_gradients: Optional array shaped like ``self._ema_matrix``.
                When supplied, the EMA becomes
                ``beta * ema + (1 - beta) * probe_gradients`` and all entries of
                ``self._ema_initialized`` are set to True.

        Returns:
            A ``(cos_matrix, color_classes)`` pair. ``cos_matrix`` holds the
            cosine similarities between sketched EMA rows, with zero rows
            contributing 0. ``color_classes`` lists the task indices per colour.

        Raises:
            ValueError: If ``probe_gradients`` has the wrong shape.
        """
        if probe_gradients is not None:
            probe = np.asarray(probe_gradients, dtype=np.float32)
            ema_shape = self._ema_matrix.shape
            if probe.shape != ema_shape:
                raise ValueError(
                    f"probe_gradients must have shape {self._ema_matrix.shape}, "
                    f"got {probe.shape}"
                )
            blended = self.beta * self._ema_matrix + (1.0 - self.beta) * probe
            self._ema_matrix = blended
            self._ema_initialized[:] = True

        sketched = self._sketcher.sketch(self._ema_matrix)

        # Unit-normalise each row; rows with zero norm are left as zeros.
        row_norms = np.linalg.norm(sketched, axis=1, keepdims=True)
        unit_rows = np.zeros_like(sketched)
        has_norm = row_norms[:, 0] > 0
        if has_norm.any():
            unit_rows[has_norm] = sketched[has_norm] / row_norms[has_norm]

        cos_matrix = unit_rows @ unit_rows.T

        # Also updates self.current_tau as a side effect.
        adjacency = self._build_quantile_conflict_adj(cos_matrix)
        groups = color_classes_from_assignments(
            welsh_powell_coloring(adjacency.tolist())
        )
        if self.duplicate_singletons:
            groups = self._duplicate_singleton_classes(groups, adjacency)

        # Publish the new grouping and advance the round bookkeeping.
        self._color_classes = groups
        self.round_idx += 1
        self.round_start_step = self.step_idx

        return cos_matrix, groups
