import sys
import wandb

class TrainLogger:
    """Thin logging front-end that can echo metrics and forward them to wandb.

    Attributes are plain and public:
      * ``logger``     -- object returned by ``wandb.init`` or ``None``.
      * ``print_logs`` -- when truthy, ``log`` may also ``print`` the payload.
    """

    def __init__(self):
        # Default instance: nothing is forwarded and nothing is printed.
        self.print_logs = False
        self.logger = None

    @classmethod
    def from_config(cls, config):
        """Build a logger from a config object.

        Reads ``config.use_wandb`` and ``config.print_logs``; when
        ``use_wandb`` is truthy it also reads ``name``, ``learn``, ``env``,
        ``alg``, ``seed`` and ``project``, and passes ``dict(config)`` to
        ``wandb.init``.

        The instance is created with ``cls.__new__`` so ``__init__`` is not
        run; both attributes are assigned explicitly below.
        """
        obj = cls.__new__(cls)

        run = None
        if config.use_wandb:
            wandb.login()
            run_name = (
                f'{config.name}_learn={config.learn}'
                f'_env={config.env}_alg={config.alg}_seed={config.seed}'
            )
            run = wandb.init(
                config=dict(config),
                project=config.project,
                name=run_name,
            )

        obj.logger = run
        obj.print_logs = config.print_logs
        return obj

    def log(self, dct, do_print=True, **kwargs):
        """Print ``dct`` if allowed, then forward it to the backing logger.

        Printing requires both ``do_print`` and ``self.print_logs`` to be
        truthy. Extra keyword arguments are passed through unchanged to
        ``self.logger.log`` when a logger is attached.
        """
        echo = do_print and self.print_logs
        if echo:
            print(dct)

        backend = self.logger
        if backend is None:
            return
        backend.log(dct, **kwargs)

##############################################################################
# REDIRECT LOGGER #
##############################################################################

def redirect_stdout(logfile):
    """Mirror everything written to ``sys.stdout`` into ``logfile``.

    Installs a ``PrintHook`` on stdout whose callback writes each chunk of
    text to ``logfile`` and flushes it immediately, while still letting the
    text through to the original stream.

    Args:
        logfile: object providing ``write(text)`` and ``flush()``.

    Returns:
        The installed ``PrintHook``; call its ``Stop()`` method to undo the
        redirection. (Earlier versions returned ``None``, leaving callers
        without a handle on the hook.)
    """
    def _copy_to_logfile(text):
        logfile.write(text)
        logfile.flush()
        # (proceed, lineNoMode, newText): keep echoing the unmodified text
        # and do not add a file/func/line prefix.
        return 1, 0, text

    hook = PrintHook()
    hook.Start(_copy_to_logfile)
    return hook


# This class intercepts output written to stdout (e.g. by print statements)
# or stderr and passes it to a user-defined function.
class PrintHook:
    """Stand-in for ``sys.stdout`` / ``sys.stderr`` that filters writes.

    After ``Start`` the instance replaces the chosen stream. Every ``write``
    is first handed to a user-defined hook, which decides whether (and in
    what form) the text is forwarded to the interpreter's original stream.

    The hook is called as ``func(text)`` and must return a 3-tuple
    ``(proceed, lineNoMode, newText)``:
      * ``proceed``    -- falsy suppresses forwarding altogether.
      * ``lineNoMode`` -- truthy prefixes stdout output with the caller's
                          file, function and line number.
      * ``newText``    -- text forwarded instead of the original.
    """

    # out = 1 means stdout will be hooked
    # out = 0 means stderr will be hooked
    def __init__(self, out=1):
        self.func = None  # user-defined hook; None while no hook is installed
        self.origOut = None  # stream receiving forwarded output; set by Start
        self.out = out

    def TestHook(self, text):
        """Default hook: append ``text`` to ``hook_log.txt`` and suppress it.

        Returns ``(0, 0, text)``, i.e. nothing is forwarded to the original
        stream.
        """
        # Context manager guarantees the file is closed even if write fails.
        with open('hook_log.txt', 'a') as log_file:
            log_file.write(text)
        return 0, 0, text

    def Start(self, func=None):
        """Install this object as ``sys.stdout`` or ``sys.stderr``.

        Args:
            func: hook callable (see class docstring). When falsy,
                ``TestHook`` is used.
        """
        # NOTE: output is forwarded to the interpreter's original stream
        # (sys.__stdout__ / sys.__stderr__), not to whatever stream was
        # installed immediately before this call.
        if self.out:
            sys.stdout = self
            self.origOut = sys.__stdout__
        else:
            sys.stderr = self
            self.origOut = sys.__stderr__

        self.func = func if func else self.TestHook

    # Stop will stop routing of print statements thru this class
    def Stop(self):
        """Restore the interpreter's original stream and drop the hook.

        Calling ``Stop`` on a hook that was never started is a no-op.
        """
        if self.origOut is None:
            # Never started: there is nothing to flush or restore.
            return
        self.origOut.flush()
        if self.out:
            sys.stdout = sys.__stdout__
        else:
            sys.stderr = sys.__stderr__
        self.func = None

    # override write of stdout
    def write(self, text):
        """Pass ``text`` through the hook and forward the result.

        Whitespace-only chunks (such as the newline ``print`` emits as a
        separate write) are forwarded verbatim, without prefix or rewrite.
        """
        # With no hook installed the text passes through unchanged, so a
        # stopped hook that is still referenced somewhere does not swallow
        # output.
        proceed, line_no_mode, new_text = 1, 0, text

        if self.func is not None:
            proceed, line_no_mode, new_text = self.func(text)

        if not proceed:
            return

        if not text.split():
            self.origOut.write(text)
            return

        # The location prefix is only added for stdout; the stderr hook
        # forwards new_text as-is.
        if self.out and line_no_mode:
            caller = sys._getframe(1)  # frame that invoked write()/print()
            code = caller.f_code
            self.origOut.write(
                f'file {code.co_filename},func {code.co_name}:'
                f'line({caller.f_lineno}):'
            )
        self.origOut.write(new_text)

    # pass all other methods to __stdout__ so that we don't have to override them
    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails. If ``origOut``
        # itself is missing (e.g. instance created without __init__ during
        # copy/unpickle), looking it up here would recurse forever.
        if name == 'origOut':
            raise AttributeError(name)
        return getattr(self.origOut, name)


if __name__ == '__main__':

    # Manual demo of PrintHook; set to True to run it.
    test_printhook = False
    if test_printhook:
        import traceback

        def MyHookOut(text):
            # Log stdout text to log.txt, then forward it with a location
            # prefix (lineNoMode=1) and a marker.
            with open('log.txt', 'a') as f:
                f.write(text)
            return 1, 1, 'Out Hooked:' + text

        def MyHookErr(text):
            # Log stderr text to hook_log.txt, then forward it with a marker.
            with open('hook_log.txt', 'a') as f:
                f.write(text)
            return 1, 1, 'Err Hooked:' + text

        print('Hook Start')
        phOut = PrintHook()
        phOut.Start(MyHookOut)
        phErr = PrintHook(0)
        phErr.Start(MyHookErr)
        print('Is this working?')
        print('It seems so!')
        phOut.Stop()
        print('STDOUT Hook end')
        # ',' is deliberately invalid source: the resulting SyntaxError
        # traceback is written to the hooked stderr. Catching it here lets
        # the demo go on to unhook stderr instead of dying first.
        try:
            compile(',', '<string>', 'exec')
        except SyntaxError:
            traceback.print_exc()
        phErr.Stop()
        print('Hook end')