from torch import Tensor
from collections import OrderedDict
import os
# from .plotter import Plotter


class Logger(object):
    """Track how many times each scalar metric key has been reported.

    Args:
        log_dir: Directory associated with this logger; stored as
            ``self.log_dir``.
        matplotlib: Flag reserved for an optional plotter. The plotter is
            currently disabled (commented out), so the flag has no effect.
    """

    def __init__(self, log_dir, matplotlib=True):
        # Pass ``matplotlib`` by keyword: positionally it would land in
        # ``reset``'s ``tensorboard`` parameter instead.
        self.reset(log_dir, matplotlib=matplotlib)

    def reset(self, log_dir=None, tensorboard=True, matplotlib=True):
        """Clear all per-key counters and optionally switch ``log_dir``.

        Args:
            log_dir: New directory. ``None`` keeps the current one; if no
                directory has been set yet, ``self.log_dir`` becomes ``None``.
            tensorboard: Currently unused.
            matplotlib: Currently unused (the plotter is disabled).
        """
        if log_dir is not None:
            self.log_dir = log_dir
        elif not hasattr(self, "log_dir"):
            # Guarantee the attribute exists even when the first call passes None.
            self.log_dir = None
        # self.plotter = Plotter() if matplotlib else None
        self.counter = OrderedDict()

    def update_scalers(self, ordered_dict):
        """Record one report of every key in ``ordered_dict``.

        Tensor values are converted in place to Python scalars via
        ``Tensor.item()`` (so they are presumably single-element tensors —
        confirm with callers). Each key's count in ``self.counter`` is then
        incremented by one.

        Args:
            ordered_dict: Mapping of metric name to value; mutated in place.
        """
        for key, value in ordered_dict.items():
            if isinstance(value, Tensor):
                # Reassigning an existing key while iterating is safe: the
                # mapping's size does not change.
                ordered_dict[key] = value.item()
            self.counter[key] = self.counter.get(key, 0) + 1

def laps_update_wandb(args, wandb, data_dict, epoch, task_id):
    """Log one epoch's training metrics to Weights & Biases.

    The logged ``epoch`` value is ``epoch + 1 + task_id * args.num_epochs``.
    This makes the x-axis monotonic across sequential tasks, presumably with
    ``epoch`` being a 0-based index within the current task (confirm with
    callers).

    Args:
        args: Object exposing ``num_epochs`` (epochs per task).
        wandb: The ``wandb`` module or any object with a ``log(dict)`` method.
        data_dict: Mapping with keys ``'loss'``, ``'lr'`` and ``'penalty'``.
        epoch: Epoch index within the current task.
        task_id: Index of the current task.

    Raises:
        KeyError: If ``data_dict`` lacks one of the required keys.
    """
    total_epochs = epoch + 1 + task_id * args.num_epochs
    # One wandb.log call keeps all of this epoch's metrics on a single step;
    # separate calls would each advance wandb's internal step counter.
    wandb.log({
        'epoch': total_epochs,
        'train_total_loss': data_dict['loss'],
        'lr': data_dict['lr'],
        'train_penalty_loss': data_dict['penalty'],
    })



