""""""
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Callable, List, Optional, Sequence, Tuple

import numpy as np

from .graph_coloring import welsh_powell_coloring, color_classes_from_assignments
from .sketching import RandomProjectionSketch


@dataclass
class TauAnnealer:
    """Exponential schedule for a threshold ``tau``, indexed by step.

    Calling the instance with a step ``t`` returns ``tau_init`` while
    ``t < warmup_steps``.  From ``warmup_steps`` onwards it returns::

        tau_target + (tau_init - tau_target) * exp(-rate * (t - warmup_steps))

    Attributes:
        tau_init: Value returned during warm-up and at ``t == warmup_steps``.
        tau_target: Value the schedule approaches as ``t`` grows (for a
            positive ``rate``).
        warmup_steps: Number of initial steps during which ``tau_init`` is
            returned unchanged.
        rate: Coefficient of the exponential decay applied after warm-up.
    """

    tau_init: float = 1.0
    tau_target: float = 0.5
    warmup_steps: int = 0
    rate: float = 1e-3

    def __call__(self, t: int) -> float:
        """Return the value of ``tau`` at step ``t`` as a Python ``float``."""
        if not t < self.warmup_steps:
            # Past warm-up: decay the gap between the initial and target values.
            steps_since_warmup = float(t - self.warmup_steps)
            decay = np.exp(-self.rate * steps_since_warmup)
            gap = self.tau_init - self.tau_target
            return float(self.tau_target + gap * decay)
        return float(self.tau_init)


@dataclass
class SonGokuScheduler:
    """Group tasks for training steps by coloring a gradient-conflict graph.

    The scheduler keeps one exponential-moving-average (EMA) gradient row per
    task.  ``refresh`` sketches those rows with ``RandomProjectionSketch``,
    computes pairwise cosine similarities, flags a pair as conflicting when
    its cosine is below ``-current_tau``, and passes the resulting adjacency
    matrix to ``welsh_powell_coloring``.  The color classes become the task
    groups, and ``next_active_set`` cycles through them one group per step.

    Attributes:
        num_tasks: Number of tasks, i.e. rows of the EMA matrix (positive).
        grad_dim: Length of every gradient vector (positive).
        refresh_period: ``should_refresh`` is true whenever ``step_idx`` is a
            multiple of this value (positive).
        beta: EMA decay factor in ``[0, 1)``.
        tau_annealer: Callable mapping a step index to ``tau``.  When ``None``
            a ``TauAnnealer`` is built from ``tau_init``, ``tau_target``,
            ``warmup_steps`` and ``anneal_rate``.
        tau_init: Initial ``tau``; also the starting value of ``current_tau``.
        tau_target: Target ``tau`` for the default annealer.
        warmup_steps: Warm-up length for the default annealer.
        anneal_rate: Decay rate for the default annealer.
        sketch_dim: Forwarded to ``RandomProjectionSketch`` as ``sketch_dim``.
        duplicate_singletons: When true, ``refresh`` additionally tries to
            place each single-task group's task into another compatible group.
        random_state: Forwarded to ``RandomProjectionSketch``.
        step_idx: Current step counter, advanced by ``step_finished``.
        round_idx: Number of ``refresh`` calls performed so far.
        round_start_step: Value of ``step_idx`` at the latest ``refresh``.
        current_tau: Threshold last produced by ``next_active_set`` (initially
            ``tau_init``); this is the value ``refresh`` uses.
    """

    num_tasks: int
    grad_dim: int
    refresh_period: int
    beta: float = 0.9
    tau_annealer: Optional[Callable[[int], float]] = None
    tau_init: float = 1.0
    tau_target: float = 0.5
    warmup_steps: int = 0
    anneal_rate: float = 1e-3
    sketch_dim: Optional[int] = None
    duplicate_singletons: bool = True
    random_state: Optional[int] = None

    step_idx: int = field(init=False, default=0)
    round_idx: int = field(init=False, default=0)
    round_start_step: int = field(init=False, default=0)
    current_tau: float = field(init=False, default=1.0)
    _ema_matrix: np.ndarray = field(init=False, repr=False)
    # Set to True per task once an update touched its row; not read anywhere
    # in this class.
    _ema_initialized: np.ndarray = field(init=False, repr=False)
    _sketcher: RandomProjectionSketch = field(init=False, repr=False)
    _color_classes: List[List[int]] = field(init=False, repr=False)

    def __post_init__(self) -> None:
        """Validate the configuration and build the initial scheduler state.

        Raises:
            ValueError: If ``num_tasks``, ``grad_dim`` or ``refresh_period``
                is not positive, or ``beta`` is outside ``[0, 1)``.
        """
        if self.num_tasks <= 0:
            raise ValueError("num_tasks must be positive")
        if self.grad_dim <= 0:
            raise ValueError("grad_dim must be positive")
        if self.refresh_period <= 0:
            raise ValueError("refresh_period must be positive")
        if not (0.0 <= self.beta < 1.0):
            raise ValueError("beta must be in [0, 1)")

        self._ema_matrix = np.zeros(
            (self.num_tasks, self.grad_dim), dtype=np.float32
        )
        self._ema_initialized = np.zeros(self.num_tasks, dtype=bool)

        self._sketcher = RandomProjectionSketch(
            input_dim=self.grad_dim,
            sketch_dim=self.sketch_dim,
            random_state=self.random_state,
        )

        if self.tau_annealer is None:
            self.tau_annealer = TauAnnealer(
                tau_init=self.tau_init,
                tau_target=self.tau_target,
                warmup_steps=self.warmup_steps,
                rate=self.anneal_rate,
            )

        # Until the first refresh, every task belongs to one single group.
        self._color_classes = [list(range(self.num_tasks))]
        # Starts from the ``tau_init`` field even when a custom annealer was
        # supplied; ``next_active_set`` overwrites it on its first call.
        self.current_tau = float(self.tau_init)
        self.step_idx = 0
        self.round_idx = 0
        self.round_start_step = 0

    @property
    def color_classes(self) -> List[List[int]]:
        """Current task groups.

        This is the internal list itself, not a copy; use
        ``schedule_for_round`` for an independent copy.
        """
        return self._color_classes

    def next_active_set(self) -> List[int]:
        """Update ``current_tau`` and return the task group for this step.

        The group index is ``(step_idx - round_start_step)`` modulo the number
        of groups, so groups are visited cyclically since the last refresh.

        Returns:
            A new list with the task ids of the selected group.

        Raises:
            RuntimeError: If there are no groups at all.
        """
        self.current_tau = float(self.tau_annealer(self.step_idx))

        m_r = len(self._color_classes)
        if m_r == 0:
            raise RuntimeError("No color classes defined; did you call refresh()?")

        offset = self.step_idx - self.round_start_step
        color_idx = int(offset % m_r)
        return list(self._color_classes[color_idx])

    def update_ema(
        self,
        task_ids: Sequence[int],
        grad_vectors: np.ndarray,
    ) -> None:
        """Blend new gradients into the EMA rows of the given tasks.

        For every pair ``(task_id, g)`` the row is replaced by
        ``beta * row + (1 - beta) * g``.  Pairs are applied in order, so a
        task id that appears more than once is updated once per occurrence.

        Args:
            task_ids: Ids of the tasks whose rows are updated; each must lie
                in ``[0, num_tasks)``.
            grad_vectors: Array-like of shape ``(len(task_ids), grad_dim)``;
                row ``i`` is the gradient for ``task_ids[i]``.  Converted to
                ``float32``.

        Raises:
            ValueError: If ``grad_vectors`` does not have the expected shape.
            IndexError: If any task id is out of range.  No row is modified
                in that case.
        """
        task_ids = list(task_ids)
        grad_vectors = np.asarray(grad_vectors, dtype=np.float32)
        if grad_vectors.ndim != 2 or grad_vectors.shape[0] != len(task_ids):
            raise ValueError(
                f"grad_vectors must have shape (len(task_ids), {self.grad_dim}), "
                f"got {grad_vectors.shape} for {len(task_ids)} task_ids"
            )
        if grad_vectors.shape[1] != self.grad_dim:
            raise ValueError(
                f"Each gradient must have dimension grad_dim={self.grad_dim}, "
                f"got {grad_vectors.shape[1]}"
            )

        # Validate every id before touching any state so that a bad id cannot
        # leave the EMA matrix partially updated.
        for task_id in task_ids:
            if not (0 <= task_id < self.num_tasks):
                raise IndexError(f"task_id {task_id} out of range [0, {self.num_tasks})")

        for task_id, g in zip(task_ids, grad_vectors):
            self._ema_matrix[task_id] = (
                self.beta * self._ema_matrix[task_id]
                + (1.0 - self.beta) * g
            )
            self._ema_initialized[task_id] = True

    def step_finished(self) -> None:
        """Advance the step counter by one."""
        self.step_idx += 1

    def should_refresh(self) -> bool:
        """Return whether ``step_idx`` is a multiple of ``refresh_period``.

        This is also true at step 0.
        """
        return (self.step_idx % self.refresh_period) == 0

    def refresh(
        self,
        probe_gradients: Optional[np.ndarray] = None,
    ) -> Tuple[np.ndarray, List[List[int]]]:
        """Recompute the task groups from the current EMA gradients.

        Args:
            probe_gradients: Optional array-like with the same shape as the
                EMA matrix, ``(num_tasks, grad_dim)``.  When given, it is
                blended into every EMA row before the groups are computed.

        Returns:
            A tuple ``(cos_matrix, color_classes)``: the pairwise cosine
            matrix of the sketched EMA rows, and the new groups.  The returned
            group list is the same object the scheduler stores internally.

        Raises:
            ValueError: If ``probe_gradients`` has the wrong shape.
        """
        if probe_gradients is not None:
            probe_gradients = np.asarray(probe_gradients, dtype=np.float32)
            if probe_gradients.shape != self._ema_matrix.shape:
                raise ValueError(
                    f"probe_gradients must have shape {self._ema_matrix.shape}, "
                    f"got {probe_gradients.shape}"
                )
            self._ema_matrix = (
                self.beta * self._ema_matrix
                + (1.0 - self.beta) * probe_gradients
            )
            self._ema_initialized[:] = True

        # NOTE(review): assumes ``sketch`` returns a 2-D array with one row
        # per task -- confirm against RandomProjectionSketch.
        Mf = self._sketcher.sketch(self._ema_matrix)

        # Normalize rows to unit length; rows with zero norm stay all-zero and
        # therefore have cosine 0 with every other row.
        norms = np.linalg.norm(Mf, axis=1, keepdims=True)
        Mff = np.zeros_like(Mf)
        # Column indexing keeps the mask 1-D even when there is a single task
        # (``squeeze`` would collapse it to a 0-d value).
        nonzero = norms[:, 0] > 0
        if np.any(nonzero):
            Mff[nonzero] = Mf[nonzero] / norms[nonzero]

        cos_matrix = Mff @ Mff.T

        # Uses the tau last set by ``next_active_set`` (or the initial value).
        tau = float(self.current_tau)
        conflict = cos_matrix < -tau
        np.fill_diagonal(conflict, False)

        # Presumably gives conflicting tasks different colors; the helpers
        # live in ``graph_coloring`` and are not visible here.
        colors = welsh_powell_coloring(conflict.tolist())
        color_classes = color_classes_from_assignments(colors)

        if self.duplicate_singletons:
            color_classes = self._duplicate_singleton_classes(
                color_classes, conflict
            )

        self._color_classes = color_classes
        self.round_idx += 1
        # Restart the cyclic walk over groups from the current step.
        self.round_start_step = self.step_idx

        return cos_matrix, color_classes

    def _duplicate_singleton_classes(
        self,
        color_classes: List[List[int]],
        conflict_adj: np.ndarray,
    ) -> List[List[int]]:
        """Add the task of each single-task group to one compatible group.

        Groups are scanned in order.  For a group that holds exactly one task
        at the time it is visited, that task is appended to the first other
        group that does not already contain it and has no member conflicting
        with it.  The task also stays in its own group.  A group that has
        already received a task is no longer treated as single-task.

        Args:
            color_classes: Groups of task ids.  Not modified.
            conflict_adj: Boolean ``(num_tasks, num_tasks)`` matrix where
                ``conflict_adj[i, j]`` is truthy when tasks ``i`` and ``j``
                conflict.

        Returns:
            New group lists with the duplicated tasks appended.

        Raises:
            ValueError: If ``conflict_adj`` does not have shape
                ``(num_tasks, num_tasks)``.
        """
        if not color_classes:
            return color_classes

        K = self.num_tasks
        if conflict_adj.shape != (K, K):
            raise ValueError(
                f"conflict_adj must have shape ({K}, {K}), got {conflict_adj.shape}"
            )

        # Work on copies so the caller's lists are left untouched.
        groups = [list(group) for group in color_classes]

        for class_idx, group in enumerate(groups):
            if len(group) != 1:
                continue

            task = group[0]
            for other_idx, other_group in enumerate(groups):
                if other_idx == class_idx or task in other_group:
                    continue
                if not any(
                    conflict_adj[task, other_task] for other_task in other_group
                ):
                    other_group.append(task)
                    break

        return groups

    def schedule_for_round(self) -> List[List[int]]:
        """Return a copy of the current groups (new outer and inner lists)."""
        return [list(group) for group in self._color_classes]

    def ema_matrix(self) -> np.ndarray:
        """Return a copy of the ``(num_tasks, grad_dim)`` EMA gradient matrix."""
        return self._ema_matrix.copy()
