from termcolor import cprint
import pprint
import wandb
import numpy as np
import os

class Logger:
    """Console logger that prefixes output with the process rank.

    Every call to :meth:`log` is first passed through :meth:`stream_control`,
    which consults the predicates registered with :meth:`add_rule`. A message
    is emitted only if at least one rule returns a truthy value, so a logger
    with no rules emits nothing.

    Attributes:
        pp: ``pprint.PrettyPrinter`` used for ``"dict"`` messages.
        local_rank: Integer read from the ``LOCAL_RANK`` environment variable.
        global_rank: Integer read from the ``RANK`` environment variable; used
            in the ``GPU<rank>:`` prefix.
        indx: Number of times :meth:`log` has been called (emitted or not).
        logged_contents: String forms of the ``"once"`` messages already
            emitted; used to suppress repeats.
        rules: Predicates called as ``rule(logger, message, type)``.
    """

    def __init__(self):
        """Set up the pretty-printer, rank information and empty rule list."""
        cprint("Initializing Logger", "green")
        self.pp = pprint.PrettyPrinter(indent=2)
        # Default to rank 0 when the variables are absent (e.g. a
        # single-process run without a distributed launcher) instead of
        # failing with KeyError.
        self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
        self.global_rank = int(os.environ.get("RANK", 0))
        self.indx = 0
        self.logged_contents = set()
        self.rules = []

    def log(self, message, type="info"):
        """Emit ``message`` if any registered rule allows it.

        Args:
            message: Content to log. For ``type="dict"`` it is pretty-printed;
                otherwise it is formatted into a tagged line.
            type: One of ``"info"``, ``"highlight"``, ``"error"``,
                ``"warning"``, ``"once"`` or ``"dict"``. Any other value is
                printed uncoloured with the value itself as the tag.
        """
        # Counted before filtering so rules can key off the call index.
        self.indx += 1

        if self.stream_control(message, type):
            self.__log__(message, type)

    def __log__(self, message, type="info"):
        """Write ``message`` unconditionally, bypassing the rules.

        ``"once"`` messages whose string form was already emitted are dropped
        before anything (including the rank prefix) is printed.
        """
        once_key = None
        if type == "once":
            # Keyed on the string form so unhashable messages do not raise.
            once_key = str(message)
            if once_key in self.logged_contents:
                return

        self.log_preprocess(message, type)

        handlers = {
            "info": self.log_info,
            "highlight": self.log_highlight,
            "error": self.log_error,
            "warning": self.log_warning,
            "once": self.log_once,
            "dict": self.log_dict,
        }
        handler = handlers.get(type)
        if handler is not None:
            handler(message)
        else:
            # The rank prefix was already printed without a newline; finish
            # the line so an unrecognised type never loses the message or
            # leaves a dangling prefix.
            print(f"[{type}]{message}")

        if once_key is not None:
            self.logged_contents.add(once_key)

        self.log_postprocess(message, type)

    def stream_control(self, message, type):
        """Return True if any rule accepts the message.

        Rules are evaluated in registration order and evaluation stops at the
        first truthy result. With no rules registered the result is False.
        """
        return any(rule(self, message, type) for rule in self.rules)

    def log_preprocess(self, message, type):
        """Print the ``GPU<global_rank>:`` prefix without a trailing newline."""
        print(f"GPU{self.global_rank}:", end=" ")

    def log_postprocess(self, message, type):
        """Hook called after a message has been written; does nothing here."""

    def log_error(self, message):
        """Print ``message`` in red with the ``[error]`` tag."""
        cprint(f"[error]{message}", "red")

    def log_warning(self, message):
        """Print ``message`` in yellow with the ``[warning]`` tag."""
        cprint(f"[warning]{message}", "yellow")

    def log_info(self, message):
        """Print ``message`` in black with the ``[info]`` tag."""
        cprint(f"[info]{message}", "black")

    def log_highlight(self, message):
        """Print ``message`` in blue with the ``[highlight]`` tag."""
        cprint(f"[highlight]{message}", "blue")

    def log_once(self, message):
        """Print ``message`` in green with the ``[once]`` tag.

        This method only formats and prints; repeat suppression is applied by
        :meth:`__log__` when the message is logged with ``type="once"``.
        """
        cprint(f"[once]{message}", "green")

    def log_dict(self, dic):
        """Pretty-print ``dic`` using the instance's ``PrettyPrinter``."""
        self.pp.pprint(dic)

    def add_rule(self, func):
        """Register a predicate called as ``func(logger, message, type)``.

        A message is emitted when at least one registered predicate returns a
        truthy value.
        """
        self.rules.append(func)
