from torch.utils.tensorboard import SummaryWriter
import torch
import os
import time
import shutil
from datetime import datetime
from .utils import run_cmd


def mk_alldir(dirpath):
  """
  Create the directory dirpath together with any missing parent directories.

  Works for relative and absolute paths; directories that already exist are left untouched.
  An empty dirpath is a no-op.

  Arguments:
  dirpath: path of the directory to create
  """
  if not dirpath:
    return
  # exist_ok avoids the check-then-create race of a manual exists()/mkdir() loop
  os.makedirs(dirpath, exist_ok=True)


class Recorder(SummaryWriter):
  """
  Recorder is a class for saving and loading stats and big files (such as model checkpoints) to both local directory and the bucket.

  Light objects are accumulated in the dict self.stat and periodically written to <run_local>/stat.pt.
  Bucket transfers are done by running gsutil commands through run_cmd; when no bucket is configured
  (self.bucket_dir is None) every bucket transfer is skipped and everything stays local.
  """
  def __init__(self, out_dir=None, bucket_dir="bucket2", comment='', max_queue=10, flush_secs=120, upload_secs=120, save_ckpt_freq=10, sep_ckpt_freq=1, upload_ckpt_freq=1):
    """
    Arguments:
    out_dir: name of the folder for saving the output of the given experiment to be used both locally and in the bucket.
             If None, "<short hostname>/<YYYY/mm/dd/HH/MM/SS>" is used.
    bucket_dir: path to the bucket directory. This is typically the directory that includes all experiments related to a project.
                If None, the BUCKET_DIR environment variable is used; if that is unset too, nothing is sent to or fetched from a bucket.
    comment: suffix appended to out_dir (as "-<comment>")
    max_queue: size of the queue for the light objects (that do not need much memory) before saving the stat in the file
    flush_secs: how often the light objects will be saved in the stat file
    upload_secs: how often the stat file of the light objects will be uploaded to the bucket
    save_ckpt_freq: frequency of saving checkpoints
    sep_ckpt_freq: Every sep_ckpt_freq * save_ckpt_freq checkpoint will also be saved separately on the bucket
    upload_ckpt_freq: frequency of uploading saved checkpoints (every upload_ckpt_freq * save_ckpt_freq) to the bucket.
    """

    self.out_dir = out_dir
    if self.out_dir is None:
      now = datetime.now()
      self.out_dir = f"{os.uname()[1].split('.')[0]}/"
      self.out_dir+= now.strftime("%Y/%m/%d/%H/%M/%S")
    if len(comment) > 0:
      self.out_dir+= f'-{comment}'

    self.bucket_dir = os.environ.get('BUCKET_DIR') if bucket_dir is None else bucket_dir

    if self.bucket_dir is None:
      print('No bucket_dir provided. The logs and checkpoints will be saved locally!')
    else:
      print(f'Logs and checkpoints will be loaded to bucket gs://{self.bucket_dir}')

    print(f'Output will be saved to: {self.out_dir}', flush=True)

    self.run_local = f'runs/{self.out_dir}'
    self.run_bucket = f'gs://{self.bucket_dir}/runs/{self.out_dir}'
    self.ckpt_local = f'checkpoints/{self.out_dir}'
    self.ckpt_bucket = f'gs://{self.bucket_dir}/checkpoints/{self.out_dir}'
    self.max_queue = max_queue
    self.flush_secs = flush_secs
    self.upload_secs = upload_secs
    self.queue_size = 0
    self.last_flush = time.time()
    self.last_upload = time.time()
    self.save_ckpt_freq = save_ckpt_freq
    self.sep_ckpt_freq = sep_ckpt_freq
    self.upload_ckpt_freq = upload_ckpt_freq
    self.best_acc = 0

    mk_alldir(self.run_local)
    mk_alldir(self.ckpt_local)

    # handle of the most recent gsutil process; stays None when no bucket is configured
    self.sp = None
    if self.bucket_dir is not None:
      print(f'\n>> Attempting to download previous run files from {self.run_bucket} ...')
      self.sp = run_cmd(f'gsutil -m cp {self.run_bucket}/* {self.run_local}')
      if self.sp.wait() == 0:
        print(f'--> Donwloaded previous run files from {self.run_bucket}!')
      else:
        print(f'--> No data found in {self.run_bucket}!')

    print(f'\n>> Attempting to load previous run stat.pt from {self.run_local} ...')
    if os.path.exists(f'{self.run_local}/stat.pt'):
      self.stat= torch.load(f'{self.run_local}/stat.pt')
      print(f'--> Loaded previous run stat.pt from {self.run_local}!')
    else:
      self.stat={}
      print(f'--> No previous stat.pt found in {self.run_local}!\nInitializing stat as empty dict.\n')

    # calls SummaryWriter who saves TB event files to its log_dir
    super().__init__(log_dir=self.run_local, max_queue=max_queue, flush_secs=flush_secs)


  def _run_bucket_cmd(self, cmd):
    """
    Run a gsutil command chained after the previous one, unless no bucket is configured.

    The handle returned by run_cmd is stored in self.sp.
    NOTE(review): prev_sp presumably makes run_cmd sequence the command after the previous process -- confirm in utils.run_cmd.

    Returns True if the command was started, False if it was skipped because self.bucket_dir is None.
    """
    if self.bucket_dir is None:
      return False
    self.sp = run_cmd(cmd, prev_sp=self.sp)
    return True


  def add_objects(self, main_tag, tag_obj_dict, global_step=None, walltime=None):
    """
    This function can be used for saving a dict of light objects in a similar way to tensorboard.
    For heavy objects, such as model checkpoints, use save_checkpoint

    Each value is appended to self.stat[main_tag][key] together with its global_step and walltime.
    Once max_queue objects are queued or flush_secs have passed since the last flush, self.stat is written
    to <run_local>/stat.pt; if additionally upload_secs have passed since the last upload (and a bucket is
    configured), the local run directory is copied to the bucket.
    """
    if main_tag not in self.stat:
      self.stat[main_tag] = {}
    for key, value in tag_obj_dict.items():
      if key not in self.stat[main_tag]:
        self.stat[main_tag][key] = {'step':[],'time':[], 'object':[]}
      self.stat[main_tag][key]['step'].append(global_step)
      self.stat[main_tag][key]['time'].append(walltime)
      self.stat[main_tag][key]['object'].append(value)
      self.queue_size += 1

    if self.queue_size >= self.max_queue or time.time() - self.last_flush >= self.flush_secs:
      torch.save(self.stat, f'{self.run_local}/stat.pt')
      # uploads only happen right after a flush, and never without a bucket
      if self.bucket_dir is not None and time.time() - self.last_upload >= self.upload_secs:
        print('uploading now')
        self._run_bucket_cmd(f'gsutil -m cp {self.run_local}/* {self.run_bucket}')
        self.last_upload = time.time()
      self.queue_size = 0
      self.last_flush = time.time()


  def add_object(self, tag, obj, global_step=None, walltime=None):
    """
    This function can be used for saving a light object in a similar way to tensorboard.
    The object is stored under the key 'value' of the given tag.
    """
    self.add_objects(tag, {'value': obj}, global_step=global_step, walltime=walltime)


  def add_scalars(self, main_tag, tag_scalar_dict, global_step=None, walltime=None):
    """
    This functions extends tensorboards add_scalars function to objects:
    the scalars are logged to tensorboard and also recorded in self.stat.
    """
    super().add_scalars(main_tag, tag_scalar_dict, global_step=global_step, walltime=walltime)
    self.add_objects(main_tag, tag_scalar_dict, global_step=global_step, walltime=walltime)

  def add_scalar(self, tag, scalar, global_step=None, walltime=None):
    """
    This functions extends tensorboards add_scalar function to objects:
    the scalar is logged to tensorboard and also recorded in self.stat.
    """
    super().add_scalar(tag, scalar, global_step=global_step, walltime=walltime)
    self.add_object(tag, scalar, global_step=global_step, walltime=walltime)

  def add_losses(self, tr_loss, tr_acc1, tr_acc5, val_loss, val_acc1, val_acc5, global_step=None, walltime=None):
    """
    Record train/test loss, top-1 accuracy and top-5 accuracy as scalars under the tags
    'Loss/*', 'Top1Accuracy/*' and 'Top5Accuracy/*'.
    """
    self.add_scalar('Loss/train', tr_loss, global_step=global_step, walltime=walltime)
    self.add_scalar('Loss/test', val_loss, global_step=global_step, walltime=walltime)
    self.add_scalar('Top1Accuracy/train', tr_acc1, global_step=global_step, walltime=walltime)
    self.add_scalar('Top1Accuracy/test', val_acc1, global_step=global_step, walltime=walltime)
    self.add_scalar('Top5Accuracy/train', tr_acc5, global_step=global_step, walltime=walltime)
    self.add_scalar('Top5Accuracy/test', val_acc5, global_step=global_step, walltime=walltime)

  def add_losses_ind(self, tr_loss_ind,val_loss_ind,global_step=None, walltime=None):
    """
    Record the train/test objects under 'IndLoss/train' and 'IndLoss/test' in self.stat only (not in tensorboard).
    """
    self.add_object('IndLoss/train', tr_loss_ind, global_step=global_step, walltime=walltime)
    self.add_object('IndLoss/test', val_loss_ind, global_step=global_step, walltime=walltime)


  def get_object(self, main_tag, key=None):
    """
    Returning an object given the main tag and the key

    Arguments:
    main_tag: tag the objects were recorded under
    key: key inside the tag (objects recorded with add_object/add_scalar use the key 'value').
         If None, the whole dict of keys recorded under main_tag is returned.

    Returns the dict {'step': [...], 'time': [...], 'object': [...]} of the key, or a dict of such dicts when key is None.
    Raises KeyError if main_tag (or key) has not been recorded.
    """
    entries = self.stat[main_tag]
    # a literal None key that was explicitly recorded still wins over the "return everything" shortcut
    if key is None and None not in entries:
      return entries
    return entries[key]


  def save_checkpoint(self, state_dict, prefix='model', global_step=None, is_best=False):
    """
    This function saves a checkpoint (heavy) to the local checkpoint directory and uploads it to the bucket.

    Arguments:
    state_dict: object to be saved with torch.save
    prefix: file name prefix of the checkpoint
    global_step: 0 saves '<prefix>_init.pt'; None or a multiple of save_ckpt_freq saves '<prefix>_last.pt';
                 any other step is ignored
    is_best: if True, '<prefix>_last.pt' is also copied to '<prefix>_best.pt'
    """

    upload_cmd='gsutil -o GSUtil:parallel_composite_upload_threshold=150M cp'

    if global_step==0: # init run
      mpath= f'{self.ckpt_local}/{prefix}_init.pt'
      torch.save(state_dict, mpath)
      dest= f'{self.ckpt_bucket}/{prefix}_init.pt'
      self._run_bucket_cmd(f'{upload_cmd} {mpath} {dest}')
      # smask.pt is not written by this class; only upload it when it is actually there
      smask_path= f'{self.ckpt_local}/smask.pt'
      if os.path.exists(smask_path):
        smask_dest= f'{self.ckpt_bucket}/smask.pt'
        self._run_bucket_cmd(f'{upload_cmd} {smask_path} {smask_dest}')

    elif global_step is None or global_step%self.save_ckpt_freq == 0:

      mpath_last= f'{self.ckpt_local}/{prefix}_last.pt'
      mpath_best= f'{self.ckpt_local}/{prefix}_best.pt'
      torch.save(state_dict, mpath_last)

      if is_best:
        shutil.copyfile(mpath_last, mpath_best)

      if global_step is None or global_step%(self.save_ckpt_freq * self.upload_ckpt_freq) == 0:
        dest= f'{self.ckpt_bucket}/{prefix}_last.pt'
        self._run_bucket_cmd(f'{upload_cmd} {mpath_last} {dest}')
        if os.path.exists(mpath_best):
          dest_best= f'{self.ckpt_bucket}/{prefix}_best.pt'
          # chained after the previous upload like every other transfer
          self._run_bucket_cmd(f'{upload_cmd} {mpath_best} {dest_best}')
      # NOTE(review): because this is an elif, a step-numbered copy is only uploaded on save steps that are
      # not also "last" upload steps; with the default upload_ckpt_freq=1 this branch is never reached.
      # Confirm whether that is intended.
      elif global_step%(self.save_ckpt_freq * self.sep_ckpt_freq) == 0:
        dest= f'{self.ckpt_bucket}/{prefix}_{global_step}.pt'
        self._run_bucket_cmd(f'{upload_cmd} {mpath_last} {dest}')


  def load_checkpoint(self, prefix='model', global_step=None, log_dir=None):
    """
    This function loads checkpoint (heavy) from local directory or from bucket

    Arguments:
    prefix: file name prefix of the checkpoint
    global_step: 0 loads '<prefix>_init.pt', None loads '<prefix>_last.pt', anything else (e.g. 'best' or a
                 step number) loads '<prefix>_<global_step>.pt'
    log_dir: experiment folder to load from; defaults to this recorder's out_dir

    The local file is used if it exists; otherwise the file with the same name is downloaded from the bucket
    (when one is configured).

    Returns the object loaded with torch.load, or None if the checkpoint could not be found.
    """

    ckpt_dir= f'checkpoints/{self.out_dir}' if log_dir is None else f'checkpoints/{log_dir}'
    global_step= 'last' if global_step is None else global_step

    # the same file name is used locally and in the bucket, so a numbered step never fetches '<prefix>_last.pt'
    step_name= 'init' if global_step==0 else global_step
    fname= f'{prefix}_{step_name}.pt'
    dest= f'{ckpt_dir}/{fname}'
    src = f'gs://{self.bucket_dir}/{ckpt_dir}/{fname}'

    if not os.path.exists(dest):
      if self.bucket_dir is None:
        print(f'Path {dest} does not exist and no bucket is configured!')
      else:
        print(f'Path {dest} does not exist!\n>> Attempting to download it from bucket ...')
        self._run_bucket_cmd(f'gsutil cp {src} {dest}')
        if self.sp.wait() == 0:
          print(f'--> Checkpoint downloaded from {src} to {dest}')
        else:
          print('--> Uh-oh: attempt failed.')

    if os.path.exists(dest):
        ckpt= torch.load(dest)
        print(f'--> Checkpoint loaded from {dest}!')
    else:
        if self.bucket_dir is None:
          print(f'--> No Checkpoint was found in {dest}!')
        else:
          print(f'--> No Checkpoint was found in {dest} or {src}!')
        ckpt= None
    return ckpt


  def save_full_checkpoint(self, model, optimizer, scheduler, args, epoch, acc):
    """
    Save epoch, args, best accuracy and the state dicts of model, optimizer and scheduler through save_checkpoint.

    self.best_acc is updated with acc first; the checkpoint is flagged as best when acc exceeds the previous best.
    """
    is_best= acc > self.best_acc
    self.best_acc= max(acc, self.best_acc)
    self.save_checkpoint({
                'epoch': epoch,
                'args': args,
                'best_acc': self.best_acc,
                'state_dict': model.state_dict(),
                'optimizer' : optimizer.state_dict(),
                'scheduler': scheduler.state_dict()
    }, global_step=epoch, is_best=is_best)


  def resume_full_checkpoint(self, resume, model, optimizer, scheduler):
    """
    Restore model, optimizer, scheduler and self.best_acc from a checkpoint written by save_full_checkpoint.

    Arguments:
    resume: falsy -> nothing is loaded; 'init' -> the init checkpoint; 'best' -> the best checkpoint;
            'previous' -> the last checkpoint of this out_dir; any other value is treated as the out_dir
            of another experiment whose last checkpoint is loaded
    model, optimizer, scheduler: objects whose load_state_dict is called with the stored state

    Returns the epoch stored in the checkpoint, or 0 when resume is falsy.
    NOTE(review): when no checkpoint is found load_checkpoint returns None and the indexing below raises TypeError.
    """
    if resume:
      if resume=='init':
        print(resume)
        full_dict= self.load_checkpoint(global_step=0)
      elif resume=='best':
        print(resume)
        full_dict= self.load_checkpoint(global_step='best')
      elif resume=='previous':
        full_dict= self.load_checkpoint()
      else:
        full_dict= self.load_checkpoint(log_dir=resume)

      model.load_state_dict(full_dict['state_dict'])
      optimizer.load_state_dict(full_dict['optimizer'])
      scheduler.load_state_dict(full_dict['scheduler'])
      self.best_acc= full_dict['best_acc']
      print(f'best_acc from model checkpoint is {self.best_acc} and has format: {type(self.best_acc)}')
      print(f'and epoch is = {full_dict["epoch"]}')
      return full_dict['epoch']
    return 0

  def close(self):
    """
    Flush and close the tensorboard writer, write self.stat to <run_local>/stat.pt and, when a bucket is
    configured, copy the local run directory to the bucket and wait for that transfer to finish.
    """
    super().flush()
    super().close()
    torch.save(self.stat, f'{self.run_local}/stat.pt')
    if self._run_bucket_cmd(f'gsutil -m cp {self.run_local}/* {self.run_bucket}'):
      self.sp.wait()
