import subprocess
import warnings

import torch
from torch.utils.tensorboard import SummaryWriter as TSW

class CustomSummaryWriter:
  """TensorBoard writer that also keeps its own record of logged scalars and
  mirrors the local run directory to a Google Cloud Storage bucket.

  Scalars are accumulated in ``self.stat`` as
  ``{main_tag: {global_step: {tag: value}}}``. After every update the whole
  dict is saved to ``runs/<savedir>/stats.pth`` and the local run directory is
  pushed to ``gs://<bucket>/runs/<savedir>`` with ``gsutil rsync -r``.
  """

  def __init__(self, savedir, bucket):
    """Create the writer.

    Args:
      savedir: run name; logs are written under ``runs/<savedir>``.
      bucket: GCS bucket name; the run directory is mirrored to
        ``gs://<bucket>/runs/<savedir>``.
    """
    self.localdir = f'runs/{savedir}'
    self.bucketdir = f'gs://{bucket}/{self.localdir}'
    self.tsw = TSW(self.localdir)
    self.stat = dict()
    # Most recent gsutil process; None until the first sync so that close()
    # is safe even if nothing was ever logged.
    self.sshproc = None

  def add_data(self, main_tag, tag_scalar_dict, global_step=None, walltime=None):
    """Record scalars in ``self.stat``, persist them, and sync to the bucket.

    Args:
      main_tag: parent tag grouping the scalars.
      tag_scalar_dict: mapping of sub-tag -> value; merged into any values
        already recorded for the same ``main_tag`` / ``global_step``.
      global_step: step key under ``main_tag`` (``None`` is used as the key
        when omitted).
      walltime: accepted for signature parity with ``add_scalars``; unused.
    """
    step_stats = self.stat.setdefault(main_tag, {}).setdefault(global_step, {})
    step_stats.update(tag_scalar_dict)
    torch.save(self.stat, f'{self.localdir}/stats.pth')
    self._sync_to_bucket()

  def add_scalars(self, main_tag, tag_scalar_dict, global_step=None, walltime=None):
    """Log scalars to TensorBoard, then record, persist and sync them."""
    self.tsw.add_scalars(main_tag, tag_scalar_dict, global_step=global_step, walltime=walltime)
    # SummaryWriter buffers events (flush_secs); flush so the rsync in
    # add_data uploads event files that include this step.
    self.tsw.flush()
    self.add_data(main_tag, tag_scalar_dict, global_step=global_step, walltime=walltime)

  def close(self):
    """Close the TensorBoard writer and push its final events to the bucket."""
    self.tsw.close()
    # getattr tolerates instances that bypassed __init__.
    proc = getattr(self, 'sshproc', None)
    if proc is not None:
      proc.wait()
    # close() flushes the remaining events to disk; sync once more so they
    # reach the bucket too.
    self._sync_to_bucket()

  def _sync_to_bucket(self):
    """Mirror ``self.localdir`` to ``self.bucketdir`` with ``gsutil rsync -r``.

    Best effort: a missing ``gsutil`` binary or a non-zero exit status is
    reported as a RuntimeWarning instead of interrupting the caller.
    """
    # Argument list with shell=False: paths are never interpreted by a shell.
    cmd = ['gsutil', 'rsync', '-r', self.localdir, self.bucketdir]
    try:
      # stdin=DEVNULL: gsutil can never block waiting for interactive input.
      self.sshproc = subprocess.Popen(cmd,
                                      stdin=subprocess.DEVNULL,
                                      stdout=subprocess.DEVNULL,
                                      stderr=subprocess.PIPE)
    except OSError as err:
      self.sshproc = None
      warnings.warn(f'could not launch gsutil: {err}', RuntimeWarning)
      return
    # communicate() drains stderr while waiting; wait() with a PIPE can
    # deadlock once the child fills the OS pipe buffer.
    _, stderr = self.sshproc.communicate()
    if self.sshproc.returncode != 0:
      detail = stderr.decode(errors='replace').strip() if stderr else ''
      warnings.warn(
          f'gsutil rsync to {self.bucketdir} failed '
          f'(exit {self.sshproc.returncode}): {detail}',
          RuntimeWarning)
