import os, sys, time
import numpy as np
import matplotlib
# matplotlib.use('agg')
import matplotlib.pyplot as plt
import random

import torch
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from tqdm import tqdm

class AverageMeter(object):
  """Keep the latest value and a count-weighted running mean of a metric.

  Attributes:
    val: the value passed to the most recent ``update`` call.
    sum: accumulated ``val * n`` over all updates.
    count: accumulated ``n`` over all updates.
    avg: ``sum / count`` as of the most recent update.
  """

  def __init__(self):
    self.reset()

  def reset(self):
    """Zero the latest value, running mean, running sum and count."""
    self.val = self.avg = self.sum = self.count = 0

  def update(self, val, n=1):
    """Record ``val`` as observed ``n`` times and refresh the running mean.

    Note: if the accumulated count is still 0 after this call (e.g. the
    first call uses ``n=0``), the mean computation divides by zero.
    """
    weighted = val * n
    self.val = val
    self.sum += weighted
    self.count += n
    self.avg = self.sum / self.count
    

def time_string():
  """Return the current UTC time in square brackets.

  The timestamp uses the format '%Y-%m-%d %X', which is the date followed by
  the locale's time representation.
  """
  stamp = time.strftime('%Y-%m-%d %X', time.gmtime(time.time()))
  return f'[{stamp}]'

def convert_secs2time(epoch_time):
  """Split a duration in seconds into whole (hours, minutes, seconds).

  Each component is truncated toward zero with ``int``.
  """
  hours = int(epoch_time / 3600)
  remaining = epoch_time - 3600 * hours
  minutes = int(remaining / 60)
  seconds = int(remaining - 60 * minutes)
  return hours, minutes, seconds

def time_file_str():
  """Return today's UTC date ('%Y-%m-%d') followed by '-<random int in [1, 10000]>'."""
  date_part = time.strftime('%Y-%m-%d', time.gmtime(time.time()))
  suffix = random.randint(1, 10000)
  return f'{date_part}-{suffix}'

def timing(f):
    """Decorate ``f`` so each successful call prints its duration in ms.

    Positional and keyword arguments are forwarded unchanged, and the return
    value of ``f`` is passed through. ``functools.wraps`` preserves ``f``'s
    name and docstring on the wrapper.
    """
    import functools

    @functools.wraps(f)
    def wrap(*args, **kwargs):
        # perf_counter is monotonic and high-resolution; time.time() can jump
        # if the system clock is adjusted, which corrupts interval timings.
        time1 = time.perf_counter()
        ret = f(*args, **kwargs)
        time2 = time.perf_counter()
        print('%s function took %0.3f ms' % (f.__name__, (time2 - time1) * 1000.0))
        return ret
    return wrap

def print_log(print_string, log):
  """Echo ``print_string`` to stdout and append it, newline-terminated, to ``log``.

  Both stdout and ``log`` are flushed immediately. ``log`` must provide
  ``write`` and ``flush`` (e.g. an open text file).
  """
  message = '{:}'.format(print_string)
  print(message)
  sys.stdout.flush()
  log.write(message + '\n')
  log.flush()

def get_balanced_subset(data_path, split, n_sample_per_pabel, transform):
  """Build a class-balanced random subset of an ImageNet split.

  Draws ``n_sample_per_pabel`` distinct images uniformly at random from every
  class of ``torchvision.datasets.ImageNet(data_path, split=split)``, shuffles
  the chosen indices, and wraps them in a ``torch.utils.data.Subset``.

  Args:
    data_path: root directory passed to ``datasets.ImageNet``.
    split: dataset split name, passed through as ``split=``.
    n_sample_per_pabel: number of images to draw from each class.
    transform: per-image transform handed to the dataset.

  Returns:
    A ``torch.utils.data.Subset`` of ``len(dataset.classes) * n_sample_per_pabel``
    images in shuffled order.

  Raises:
    ValueError: if a class cannot supply ``n_sample_per_pabel`` images.
  """
  dataset = datasets.ImageNet(
    data_path,
    split=split,
    transform = transform
  )
  labels = _dataset_labels(dataset)
  num_classes = len(dataset.classes)
  indices_all_labels = [[] for _ in range(num_classes)]
  for idx, label in enumerate(labels):
    indices_all_labels[label].append(idx)
  indices = []
  # Sample classes in label order so results are reproducible under a fixed
  # `random` seed.
  for label, candidates in enumerate(indices_all_labels):
    try:
      indices += random.sample(candidates, n_sample_per_pabel)
    except ValueError as err:
      raise ValueError(
        f"cannot draw {n_sample_per_pabel} images from class {label}, "
        f"which has {len(candidates)} images"
      ) from err
  random.shuffle(indices)
  subset_dataset = torch.utils.data.Subset(dataset, indices)
  print(f"sample {len(indices)} images totally.")
  return subset_dataset


def _dataset_labels(dataset, batch_size=128):
  """Return the integer label of every sample in ``dataset``, in index order."""
  # Folder-based torchvision datasets (ImageNet included) expose `targets`,
  # which avoids decoding and transforming every image just to read labels.
  targets = getattr(dataset, "targets", None)
  if targets is not None:
    return [int(t) for t in targets]
  # Fallback: scan the dataset in order. The DataLoader does not shuffle, so
  # the running position equals the dataset index.
  loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
  labels = []
  for _, target in tqdm(loader):
    # Iterate the actual batch: the final batch can be smaller than batch_size.
    labels.extend(int(t) for t in target)
  return labels