import os
import wandb
import numpy as np

class Logger:
    """Thin wrapper that forwards training metrics and images to Weights & Biases.

    Calling the constructor starts a wandb run (``wandb.init``) in the "etude"
    project; the logging methods forward to ``wandb.log``.
    """

    def __init__(self, log_dir, net=None, n_logged_samples=10, summary_writer=None, config=None):
        """Start a wandb run named after the last component of ``log_dir``.

        Args:
            log_dir: Path of the run's log directory; its final path component
                becomes the wandb run name.
            net: Optional model passed to ``wandb.watch`` (with ``log='all'``).
            n_logged_samples: Accepted for interface compatibility; unused here.
            summary_writer: Accepted for interface compatibility; unused here.
            config: Hyperparameters / run metadata forwarded to ``wandb.init``.
        """
        wandb.init(
            # Set the project where this run will be logged
            project="etude",
            name=self._run_name(log_dir),
            # Track hyperparameters and run metadata
            config=config,
        )
        # Explicit None check: a truthiness test would skip a model object that
        # defines __len__ and is empty (e.g. a container module with no children).
        if net is not None:
            wandb.watch(net, log='all')
        # NOTE(review): created but never written to by this class; kept in case
        # external code reads ``logger.table``.
        self.table = wandb.Table(columns=["image", "description", "Iteration"])

    @staticmethod
    def _run_name(log_dir):
        """Return the final path component of ``log_dir``.

        Trailing separators are ignored, so ``"runs/exp1/"`` yields ``"exp1"``
        rather than an empty string. Empty or ``None`` input is returned as-is.
        """
        if not log_dir:
            return log_dir
        # normpath strips trailing separators (and normalises them on Windows)
        # before basename extracts the last component.
        return os.path.basename(os.path.normpath(log_dir))

    def log_scalar(self, scalar, name, step):
        """Log a single scalar value under key ``name`` at ``step``."""
        wandb.log({name: scalar}, step=step)

    def log_scalars(self, scalar_dict, group_name, phase, step):
        """Log every key/value pair of ``scalar_dict`` at ``step`` in one call.

        ``group_name`` and ``phase`` are accepted for interface compatibility but
        are currently unused: keys are logged exactly as given, without any
        grouping prefix.
        """
        wandb.log(scalar_dict, step=step)

    def log_image(self, image, name, step):
        """Log ``image`` (wrapped in ``wandb.Image``) at ``step``.

        NOTE(review): ``name`` is ignored; every image is logged under the fixed
        key ``"GraphImage"``. Kept as-is so existing dashboards keep working —
        confirm whether callers expect ``name`` to be used as the key.
        """
        wandb.log({"GraphImage": wandb.Image(image)}, step=step)
