from typing import Any

import wandb

from .base import MetricsLogger


class WandBLogger(MetricsLogger):
    """Metrics logger that forwards to Weights & Biases.

    The value returned by ``wandb.init`` is not retained; logging and
    shutdown go through the module-level ``wandb.log`` and ``wandb.finish``.
    """

    def __init__(
        self,
        project: str,
        entity: str | None = None,
        name: str | None = None,
        config: dict[str, Any] | None = None,
    ):
        """Start a W&B run with the given settings.

        All arguments are passed to ``wandb.init`` unchanged, including
        ``None`` for the optional ones.

        Args:
            project: Name of the W&B project
            entity: W&B entity (team or username)
            name: Name of the run
            config: Configuration dictionary to log

        """
        init_kwargs = {
            "project": project,
            "entity": entity,
            "name": name,
            "config": config,
        }
        wandb.init(**init_kwargs)

    def log_metrics(self, metrics: dict[str, Any], step: int, prefix: str = "") -> None:
        """Send a batch of metrics to W&B at the given step.

        Args:
            metrics: Mapping from metric name to value
            step: Step/iteration number the values belong to
            prefix: When non-empty, each name is logged as ``<prefix>/<name>``

        """
        if not prefix:
            # No prefix: hand the caller's dict to wandb as-is.
            payload = metrics
        else:
            # Build a new dict so the caller's mapping is left untouched.
            payload = {f"{prefix}/{key}": value for key, value in metrics.items()}
        wandb.log(payload, step=step)

    def close(self) -> None:
        """End the W&B run by calling ``wandb.finish``."""
        wandb.finish()
