from torch.utils.tensorboard import SummaryWriter
import neptune
import json
import os

# Keyword arguments splatted into every neptune.init_run call made by
# Neptune_SummaryWriter. Loaded once, at import time; the path is resolved
# against the current working directory, not this file's location.
# JSON is UTF-8 by specification, so don't rely on the platform's locale encoding.
with open('neptune_credentials.json', encoding='utf-8') as f:
    neptune_credentials = json.load(f)

# run = neptune.init_run(
#     **neptune_credentials,
#     with_id='test'
# )

class Neptune_SummaryWriter(SummaryWriter):
    """TensorBoard ``SummaryWriter`` that also mirrors scalars and text to Neptune.

    ``add_scalar`` and ``add_text`` write to TensorBoard first and then log the
    same value to ``self.neptune_run``; ``log_args`` and ``get_runid`` talk to
    Neptune only. Neptune failures inside ``log_args``, ``add_scalar``,
    ``add_text`` and ``close`` are printed and otherwise ignored. All other
    ``SummaryWriter`` methods are inherited unchanged (TensorBoard only).
    """

    # Glob patterns uploaded as run source code when ``source_files`` is not given.
    DEFAULT_SOURCE_FILES = ('*.py', 'AIDomains/*.py', 'scripts/**')

    def __init__(self, log_dir=None, neptune_id=None, tags=None, neptune_dir_path=None,
                 source_files=None, **writer_kwargs):
        """Create the TensorBoard writer and start a Neptune run.

        Args:
            log_dir: TensorBoard log directory, forwarded to ``SummaryWriter``.
            neptune_id: Optional value forwarded to ``neptune.init_run`` as
                ``name=`` (the run's display name) -- not as ``with_id=``.
            tags: Optional list of tags for the Neptune run; defaults to no tags.
            neptune_dir_path: If given, exported as the ``NEPTUNE_DATA_DIRECTORY``
                environment variable before the run is created. Note this
                modifies the environment of the whole process.
            source_files: Source-file patterns uploaded with the run; defaults
                to ``DEFAULT_SOURCE_FILES``.
            **writer_kwargs: Extra keyword arguments forwarded to
                ``SummaryWriter.__init__`` (e.g. ``flush_secs``).
        """
        if neptune_dir_path is not None:
            os.environ["NEPTUNE_DATA_DIRECTORY"] = neptune_dir_path
        super().__init__(log_dir, **writer_kwargs)

        init_kwargs = {
            'source_files': list(self.DEFAULT_SOURCE_FILES) if source_files is None else source_files,
            # Fresh list per instance instead of a shared mutable default argument.
            'tags': [] if tags is None else tags,
        }
        if neptune_id is not None:
            init_kwargs['name'] = neptune_id
        # Kept as two separate ** expansions so a key duplicated between the
        # credentials file and init_kwargs still raises TypeError, as before.
        self.neptune_run = neptune.init_run(**neptune_credentials, **init_kwargs)

    def log_args(self, args):
        """Store ``args`` under the ``parameters`` field of the Neptune run.

        Best effort: if Neptune rejects the value, the failure is printed and
        ignored.
        """
        try:
            self.neptune_run['parameters'] = args
        except Exception as err:
            # Exception, not a bare except, so Ctrl-C / SystemExit still propagate.
            print('log args failed', args, 'error:', repr(err))

    def add_scalar(self, tag, scalar_value, global_step=None, walltime=None, **kwargs):
        """Log a scalar to TensorBoard, then append it to the Neptune series ``tag``.

        Extra keyword arguments (e.g. ``new_style`` / ``double_precision`` on
        newer PyTorch versions) are forwarded to ``SummaryWriter.add_scalar``
        only. TensorBoard errors propagate; a Neptune failure is printed and
        ignored.
        """
        super().add_scalar(tag, scalar_value, global_step, walltime, **kwargs)
        try:
            self.neptune_run[tag].log(scalar_value, step=global_step)
        except Exception as err:
            print('log scalar failed', tag, scalar_value, global_step,
                  'types:', type(tag), type(scalar_value), type(global_step),
                  'error:', repr(err))

    def add_text(self, tag, text_string, global_step=None, walltime=None):
        """Log text to TensorBoard, then append it to the Neptune series ``tag``.

        Every call also echoes the text to stdout. TensorBoard errors propagate;
        a Neptune failure is printed and ignored.
        """
        super().add_text(tag, text_string, global_step, walltime)
        print('log text', tag, text_string, global_step)
        try:
            self.neptune_run[tag].log(text_string, step=global_step)
        except Exception as err:
            print('log text failed', tag, text_string, global_step,
                  'types:', type(tag), type(text_string), type(global_step),
                  'error:', repr(err))

    def close(self):
        """Close the TensorBoard writer, then stop the Neptune run (best effort)."""
        super().close()
        try:
            self.neptune_run.stop()
        except Exception as err:
            print('neptune run stop failed', repr(err))

    def get_runid(self) -> str:
        """Return the Neptune run ID by fetching the run's ``sys/id`` field."""
        return self.neptune_run["sys/id"].fetch()

    # Disabled overrides, kept for reference. While commented out, these
    # methods fall through to the plain SummaryWriter (TensorBoard only).
    # NOTE(review): the histogram/image bodies pass raw arrays straight to
    # File.as_image and add_graph is unfinished -- verify before enabling.

    # def add_scalars(self, main_tag, tag_scalar_dict, global_step=None, walltime=None):
    #     super().add_scalars(main_tag, tag_scalar_dict, global_step, walltime)
    #     for tag, scalar_value in tag_scalar_dict.items():
    #         self.neptune_run[f'{main_tag}/{tag}'].log(scalar_value)

    # def add_histogram(self, tag, values, global_step=None, bins='tensorflow', walltime=None, max_bins=None):
    #     super().add_histogram(tag, values, global_step, bins, walltime, max_bins)
    #     self.neptune_run[tag].log(neptune.types.File.as_image(values))

    # def add_image(self, tag, img_tensor, global_step=None, walltime=None, dataformats='CHW'):
    #     super().add_image(tag, img_tensor, global_step, walltime, dataformats)
    #     self.neptune_run[tag].log(neptune.types.File.as_image(img_tensor))

    # def add_images(self, tag, img_tensor, global_step=None, walltime=None, dataformats='NCHW'):
    #     super().add_images(tag, img_tensor, global_step, walltime, dataformats)
    #     self.neptune_run[tag].log(neptune.types.File.as_image(img_tensor))

    # def add_figure(self, tag, figure, global_step=None, close=True, walltime=None):
    #     super().add_figure(tag, figure, global_step, close, walltime)
    #     self.neptune_run[tag].log(neptune.types.File.as_image(figure))

    # def add_graph(self, model, input_to_model=None, verbose=False):
    #     super().add_graph(model, input_to_model, verbose)
    #     self
