"""Multi-task training loop in which a SonGoku scheduler picks the active tasks for each step."""
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader

from son_goku import SonGokuScheduler


def _move_to_device(obj: Any, device: torch.device) -> Any:
    """Recursively move every tensor inside *obj* to *device*.

    Tensors are moved with ``Tensor.to``. Lists, tuples (including named
    tuples) and dicts are rebuilt with their elements moved. Any other object
    is returned unchanged.

    Args:
        obj: A tensor, an arbitrarily nested list/tuple/dict of tensors, or
            any other value.
        device: Target device.

    Returns:
        An object with the same structure and its tensors on *device*. Dict
        subclasses come back as a plain ``dict``.
    """
    if isinstance(obj, torch.Tensor):
        return obj.to(device)
    if isinstance(obj, tuple) and hasattr(obj, "_fields"):
        # Named tuples take their fields as positional arguments, not as one
        # iterable, so ``type(obj)(generator)`` would raise TypeError here.
        return type(obj)(*(_move_to_device(o, device) for o in obj))
    if isinstance(obj, (list, tuple)):
        return type(obj)(_move_to_device(o, device) for o in obj)
    if isinstance(obj, dict):
        return {k: _move_to_device(v, device) for k, v in obj.items()}
    return obj


def _flatten_grads(grads: Sequence[Optional[torch.Tensor]], device: torch.device) -> np.ndarray:
    """Concatenate gradient tensors into one flat float32 NumPy vector.

    Args:
        grads: Gradient tensors, e.g. as returned by ``torch.autograd.grad``
            with ``allow_unused=True``. ``None`` entries are skipped.
        device: Currently unused. It is kept for API compatibility; the
            result is always materialised on the CPU.

    Returns:
        A 1-D ``float32`` array holding every non-``None`` gradient,
        flattened and concatenated in input order.

    Raises:
        ValueError: If *grads* is empty or every entry is ``None``.
    """
    flat_parts = [g.detach().reshape(-1) for g in grads if g is not None]
    if not flat_parts:
        raise ValueError("No gradients to flatten; check loss / graph.")
    flat = torch.cat(flat_parts)
    # Cast on the torch side. ``Tensor.numpy()`` rejects dtypes NumPy lacks
    # (e.g. bfloat16), and a single ``.to`` avoids a second host-side copy.
    return flat.to(device="cpu", dtype=torch.float32).numpy()


@dataclass
class TaskSpec:
    """Definition of a single task for multi-task training and evaluation.

    Fields are positional, in the order they are declared below.
    """

    # Task identifier; used as the key of the per-task metric dictionaries.
    name: str
    # Invoked as ``forward_fn(model, batch, device)``; returns ``(preds, target)``.
    forward_fn: Callable[[nn.Module, Any, torch.device], Tuple[torch.Tensor, torch.Tensor]]
    # Invoked as ``loss_fn(preds, target)``; returns this task's loss tensor.
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
    # Optional ``metric_fn(preds, target)``; leave as ``None`` to track no metric.
    metric_fn: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None
    # Multiplier applied to this task's loss before the active losses are combined.
    weight: float = 1.0


class MultiTaskTrainer:
    """Train one model on several tasks, with a SonGoku scheduler choosing the active tasks.

    For each batch, the scheduler supplies the active task ids. The trainer
    then computes each active task's loss and records the gradient of that
    loss with respect to the shared parameters for the scheduler. Finally, it
    takes one optimizer step on the mean of the weighted task losses.
    """

    def __init__(
        self,
        model: nn.Module,
        tasks: Sequence[TaskSpec],
        scheduler: SonGokuScheduler,
        optimizer: torch.optim.Optimizer,
        device: torch.device,
        grad_clip: Optional[float] = None,
    ) -> None:
        """Store the training components and resolve the shared parameters.

        Args:
            model: Model to train; it is moved to *device*. If it defines
                ``shared_parameters()``, those parameters are used for the
                per-task gradient vectors and for clipping. Otherwise all of
                ``model.parameters()`` are used. Parameters with
                ``requires_grad=False`` are excluded in both cases.
            tasks: Task specifications. Scheduler task ids index into this list.
            scheduler: Supplies the active task ids and receives the per-task
                gradient vectors.
            optimizer: Stepped once per batch that has at least one active task.
            device: Device the model and every batch are moved to.
            grad_clip: Max-norm passed to ``clip_grad_norm_`` over the shared
                parameters. ``None`` disables clipping.

        Raises:
            ValueError: If no shared parameter requires gradients.
        """
        self.model = model.to(device)
        self.tasks = list(tasks)
        self.scheduler = scheduler
        self.optimizer = optimizer
        self.device = device
        self.grad_clip = grad_clip

        # Identify shared parameters
        if hasattr(model, "shared_parameters"):
            candidates = model.shared_parameters()  # type: ignore[attr-defined]
        else:
            candidates = model.parameters()
        # Frozen parameters must be dropped: torch.autograd.grad raises when
        # asked to differentiate w.r.t. a tensor that does not require grad.
        shared_params = [p for p in candidates if p.requires_grad]
        if not shared_params:
            raise ValueError("Model has no parameters to optimize.")
        self.shared_params: List[nn.Parameter] = shared_params

    def _grad_vector(self, loss: torch.Tensor) -> np.ndarray:
        """Return d(loss)/d(shared params) as one flat float32 NumPy vector.

        The vector always covers every shared parameter, in ``shared_params``
        order. This keeps vectors from different tasks the same length and
        aligned, so they can be stacked.

        Raises:
            ValueError: If *loss* does not depend on any shared parameter.
        """
        grads = torch.autograd.grad(
            loss,
            self.shared_params,
            # The graph is reused by the combined backward() in train_epoch.
            retain_graph=True,
            allow_unused=True,
        )
        if all(g is None for g in grads):
            raise ValueError("No gradients to flatten; check loss / graph.")
        # A parameter this loss does not reach (e.g. another task's head) gets
        # an explicit zero gradient instead of being skipped. Skipping it would
        # change the vector length and shift every later entry.
        dense = [
            torch.zeros_like(p) if g is None else g
            for g, p in zip(grads, self.shared_params)
        ]
        return _flatten_grads(dense, self.device)

    @staticmethod
    def _metric_to_float(value: Any) -> float:
        """Convert a scalar metric (tensor or Python number) to a Python float."""
        return float(torch.as_tensor(value).detach().cpu())

    def train_epoch(self, loader: DataLoader, epoch: int = 0) -> Dict[str, float]:
        """Run one pass over *loader*, updating the model and the scheduler.

        Args:
            loader: Yields batches that every task's ``forward_fn`` accepts.
            epoch: Currently unused; accepted for caller convenience.

        Returns:
            Mean training metric per task name. The value is 0.0 for tasks
            that recorded no metric this epoch.
        """
        self.model.train()
        running: Dict[str, List[float]] = {t.name: [] for t in self.tasks}
        for batch in loader:
            batch = _move_to_device(batch, self.device)

            active_ids = self.scheduler.next_active_set()
            task_losses: List[torch.Tensor] = []
            grad_vectors: List[np.ndarray] = []

            for task_id in active_ids:
                task = self.tasks[task_id]
                preds, target = task.forward_fn(self.model, batch, self.device)
                loss = task.loss_fn(preds, target)
                task_losses.append(loss * task.weight)
                # The scheduler receives the gradient of the *unweighted* loss.
                grad_vectors.append(self._grad_vector(loss))

                if task.metric_fn is not None:
                    with torch.no_grad():
                        metric_val = task.metric_fn(preds, target)
                        running[task.name].append(self._metric_to_float(metric_val))

            if not task_losses:
                # NOTE(review): the scheduler is not advanced (no step_finished)
                # when the active set is empty. Confirm this is intended.
                continue

            total_loss = torch.stack(task_losses).mean()
            self.optimizer.zero_grad()
            total_loss.backward()
            if self.grad_clip is not None:
                torch.nn.utils.clip_grad_norm_(self.shared_params, self.grad_clip)
            self.optimizer.step()

            self.scheduler.update_ema(active_ids, np.stack(grad_vectors, axis=0))
            self.scheduler.step_finished()
            if self.scheduler.should_refresh():
                self.scheduler.refresh()

        return {k: float(np.mean(v)) if v else 0.0 for k, v in running.items()}

    @torch.no_grad()
    def evaluate(self, loader: DataLoader) -> Dict[str, float]:
        """Compute the mean metric of every task that defines ``metric_fn``.

        Args:
            loader: Yields batches that the tasks' ``forward_fn`` accept.

        Returns:
            Mean metric per task name, for tasks with a ``metric_fn`` only.
            The value is 0.0 if *loader* yields no batches.
        """
        self.model.eval()
        # Tasks without a metric contribute nothing, so don't run their forward pass.
        metric_tasks = [t for t in self.tasks if t.metric_fn is not None]
        running: Dict[str, List[float]] = {t.name: [] for t in metric_tasks}
        for batch in loader:
            batch = _move_to_device(batch, self.device)
            for task in metric_tasks:
                preds, target = task.forward_fn(self.model, batch, self.device)
                metric_val = task.metric_fn(preds, target)
                running[task.name].append(self._metric_to_float(metric_val))
        return {k: float(np.mean(v)) if v else 0.0 for k, v in running.items()}

    def save_checkpoint(self, path: str, extra: Optional[Dict[str, Any]] = None) -> None:
        """Save model, optimizer and scheduler state to *path* via ``torch.save``.

        Args:
            path: Destination passed straight to ``torch.save``.
            extra: Optional additional entries, stored under ``"extra"``.
        """
        payload = {
            "model_state": self.model.state_dict(),
            "optimizer_state": self.optimizer.state_dict(),
            # NOTE(review): this pickles the scheduler object itself, not a
            # state dict. Loading then needs the scheduler class importable and
            # a non-weights-only torch.load. Confirm this is intended.
            "scheduler_state": self.scheduler,
            "extra": extra or {},
        }
        torch.save(payload, path)


# Public API of this module for ``from ... import *``.
__all__ = [
    "TaskSpec",
    "MultiTaskTrainer",
    # Underscore-prefixed, but listed explicitly so star-imports still expose it.
    "_move_to_device",
]
