import sys
from typing import Optional, Dict, Tuple, Any

from stable_baselines3.common.logger import KVWriter, Video, Figure, Image, HParam
from stable_baselines3.common.logger import Logger, filter_excluded_keys
import wandb


class WandbOutputFormat(KVWriter):
    """
    Log to Weights & Biases (wandb).

    If no wandb run is active when the writer is created, one is started with
    ``wandb.init(project=project, name=name, config=config)``. Otherwise the
    already-active run is reused and those arguments are ignored.
    """

    def __init__(self, project: Optional[str] = None, name: Optional[str] = None, config: Optional[dict] = None):
        """
        :param project: wandb project, only used if this writer has to start the run
        :param name: wandb run name, only used if this writer has to start the run
        :param config: initial run config, only used if this writer has to start the run
        :raises ImportError: if ``wandb`` is unavailable
        """
        # Explicit check instead of ``assert`` so it is not stripped under ``python -O``.
        if wandb is None:
            raise ImportError("wandb is not installed. Install with `pip install wandb`")

        if wandb.run is None:
            wandb.init(project=project, name=name, config=config)

    def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None:
        """
        Convert the recorded SB3 values to wandb objects and send them in one ``wandb.log`` call.

        :param key_values: values recorded by the SB3 logger since the last dump
        :param key_excluded: per-key tuples of output formats that should skip that key
        :param step: step forwarded to ``wandb.log``
        """
        log_dict: Dict[str, Any] = {}
        for key, value in filter_excluded_keys(key_values, key_excluded, "wandb").items():
            if isinstance(value, Video):
                # wandb.Video expects a path or an ndarray, not a torch tensor.
                log_dict[key] = wandb.Video(self._to_numpy(value.frames), fps=value.fps, format="mp4")

            elif isinstance(value, Figure):
                # wandb.Image rasterizes the figure when constructed, so it can be released right after.
                log_dict[key] = wandb.Image(value.figure)
                if value.close:
                    self._close_figure(value.figure)

            elif isinstance(value, Image):
                image = self._to_channels_last(value.image, value.dataformats)
                log_dict[key] = wandb.Image(image, caption=key)

            elif isinstance(value, HParam):
                # Hyperparameters go to the run config; allow overwriting values set earlier.
                wandb.config.update(value.hparam_dict, allow_val_change=True)
                # Log the accompanying metrics under their own keys in the same step.
                for m_key, m_val in value.metric_dict.items():
                    log_dict[m_key] = m_val
            else:
                log_dict[key] = value

        if log_dict:
            wandb.log(log_dict, step=step)

    def close(self) -> None:
        """Finish the active wandb run, if any (including a run started outside this writer)."""
        if wandb.run is not None:
            wandb.finish()

    @staticmethod
    def _to_numpy(data: Any) -> Any:
        """Return ``data`` as a NumPy array if it looks like a torch tensor; otherwise return it unchanged."""
        # Duck-typed so this module never imports torch. ``detach`` avoids the error that
        # ``.numpy()`` raises on tensors that require grad.
        if hasattr(data, "detach"):
            return data.detach().cpu().numpy()
        return data

    @staticmethod
    def _to_channels_last(image: Any, dataformats: Optional[str]) -> Any:
        """
        Reorder a single 2D/3D image so its axes are HW (grayscale) or HWC.

        Inputs whose layout is not a plain permutation of HW/HWC pass through unchanged:
        batched formats such as "NCHW", file paths, and arrays whose rank does not match
        ``dataformats``.
        """
        fmt = (dataformats or "").upper()
        target = "HWC" if "C" in fmt else "HW"
        if sorted(fmt) != sorted(target) or getattr(image, "ndim", None) != len(fmt):
            return image
        array = WandbOutputFormat._to_numpy(image)
        return array.transpose([fmt.index(axis) for axis in target])

    @staticmethod
    def _close_figure(figure: Any) -> None:
        """Release a logged matplotlib figure, matching the intent of SB3's ``Figure.close`` flag."""
        # Figures created via pyplot stay registered in pyplot's figure manager until
        # ``pyplot.close`` is called; ``Figure.clf`` only clears their contents. pyplot is
        # looked up lazily so this module does not import matplotlib itself. If pyplot was
        # never imported, the figure is not tracked by it and clearing it is enough.
        pyplot = sys.modules.get("matplotlib.pyplot")
        if pyplot is not None:
            pyplot.close(figure)
        else:
            figure.clf()


