"""Utility functions
Adapted from: https://github.com/Shenqishaonv/100-Driver-Source

@author: wenjing
@article{100-Driver,
author    = {Wang Jing, Li Wengjing, Li Fang, Zhang Jun, Wu Zhongcheng, Zhong Zhun and Sebe Nicu},
title     = {100-Driver: A Large-scale, Diverse Dataset for Distracted Driver Classification},
journal={IEEE Transactions on Intelligent Transportation Systems},
year      = {2023},
publisher={IEEE}}
"""
__author__ = 'XYZ'


import os
import sys
import re
import datetime

import numpy

try:
  import torch
  import torch.nn as nn

  from torch.optim.lr_scheduler import _LRScheduler

  # import torch.backends.cudnn as cudnn
  # import torch.nn.functional as F
except ImportError:
  print('torch is not installed')


try:
  import torchvision
  import torchvision.models as model
  import torchvision.transforms as transforms
except ImportError:
  print('torchvision is not installed')


# Module-level log object, built by the project's `logger` helper from this
# file's path.
from .core._log_ import logger
log = logger(__file__)

from .models.ghostnet import ghostnet


class WarmUpLR(_LRScheduler):
  """Learning-rate scheduler used during the warm-up phase of training.

  At every scheduler step the learning rate of each parameter group is set
  to ``base_lr * last_epoch / total_iters``.

  Args:
    optimizer: wrapped optimizer (e.g. SGD).
    total_iters: number of iterations the warm-up phase lasts.
    last_epoch: index of the last step, forwarded unchanged to the parent
      scheduler (default ``-1``).
  """

  def __init__(self, optimizer, total_iters, last_epoch=-1):
    # Stored before calling the parent constructor because get_lr depends
    # on it.
    self.total_iters = total_iters
    super().__init__(optimizer, last_epoch)

  def get_lr(self):
    """Return the warm-up learning rate for every base learning rate.

    After ``m`` steps the rate is ``base_lr * m / total_iters``; a tiny
    epsilon in the denominator guards against ``total_iters == 0``.
    """
    step = self.last_epoch
    denominator = self.total_iters + 1e-8
    return [base_lr * step / denominator for base_lr in self.base_lrs]



def compute_mean_std(dataset):
  """Compute the per-channel mean and std over every image of a dataset.

  Every sample is fetched from ``dataset`` exactly once (indexing a dataset
  may be expensive), and the channels are then processed one at a time so
  only a single stacked channel is held in memory alongside the samples.

  Args:
    dataset: object supporting ``len()`` and integer indexing whose items
      carry the image at position 1. Each image must be indexable as
      ``image[:, :, channel]`` for channels 0, 1 and 2.

  Returns:
    A tuple ``(mean, std)``; each element is a 3-tuple holding the value
    for channels 0, 1 and 2 computed over all pixels of all images.

  Raises:
    ValueError: if the dataset contains no samples.
  """
  # Fetch each sample once instead of once per channel.
  images = [dataset[i][1] for i in range(len(dataset))]
  if not images:
    raise ValueError('cannot compute mean/std of an empty dataset')

  mean, std = [], []
  for channel_index in range(3):
    channel = numpy.dstack([image[:, :, channel_index] for image in images])
    mean.append(numpy.mean(channel))
    std.append(numpy.std(channel))

  return tuple(mean), tuple(std)

def get_mean_std(dataset):
  """Look up the normalisation mean and std registered for a dataset name.

  Args:
    dataset: the name of the dataset.

  Returns:
    A tuple ``(mean, std)``. For a known name both are three-element lists
    (one value per channel). For an unknown name a notice is printed and
    ``(0.0, 0.0)`` is returned.
  """
  day_names = (
    'statefarm',
    'aucv2-camera1', 'aucv2-camera2', 'aucv1',
    '3mdad-day', '3mdad-cam1', '3mdad-cam2',
    '3mdad-cam1-wash', '3mdad-cam2-wash',
    'pic-day-all',
    'pic-day-cam1', 'pic-day-cam2', 'pic-day-cam3', 'pic-day-cam4',
    'pic-xiandai-cam1', 'pic-xiandai-cam2',
    'pic-xiandai-cam3', 'pic-xiandai-cam4',
    'pic-day-car-cam1', 'pic-day-car-cam2',
    'pic-day-car-cam3', 'pic-day-car-cam4',
    '100-driver-day-cam1', '100-driver-day-cam2',
    '100-driver-day-cam3', '100-driver-day-cam4',
  )
  mdad_night_names = (
    '3mdad-cam1-night', '3mdad-cam2-night',
  )
  pic_night_names = (
    'pic-night-all',
    'pic-night-cam1', 'pic-night-cam2', 'pic-night-cam3', 'pic-night-cam4',
    'pic-night-car-cam1', 'pic-night-car-cam2',
    'pic-night-car-cam3', 'pic-night-car-cam4',
    'traditional-night-cam1', 'traditional-night-cam2',
    'traditional-night-cam3', 'traditional-night-cam4',
    '100-driver-night-cam1', '100-driver-night-cam2',
    '100-driver-night-cam3', '100-driver-night-cam4',
  )

  # (dataset names, mean, std) -- checked in order, first match wins.
  statistics = (
    (day_names,
     [.5, .5, .5],
     [0.229, 0.224, 0.225]),
    (mdad_night_names,
     [0.046468433, 0.046468433, 0.046468433],
     [0.051598676, 0.051598676, 0.051598676]),
    (pic_night_names,
     [0.29414198, 0.3019768, 0.29021993],
     [0.24205828, 0.24205923, 0.24205303]),
  )

  for names, group_mean, group_std in statistics:
    if dataset in names:
      return group_mean, group_std

  print('the dataset is not available ')
  return 0.0, 0.0


def most_recent_folder(net_weights, fmt):
  """Return the name of the newest non-empty sub-folder of ``net_weights``.

  "Newest" is decided by parsing each folder name as a timestamp with
  ``fmt``; filesystem times are not consulted. Entries of ``net_weights``
  that are not directories, and directories that are empty, are ignored.

  Args:
    net_weights: path of the directory holding the timestamp-named folders.
    fmt: ``datetime.datetime.strptime`` format the folder names follow.

  Returns:
    The folder name (not the full path), or an empty string when no
    non-empty sub-folder exists.

  Raises:
    ValueError: if the name of a non-empty sub-folder does not match
      ``fmt``.
  """
  candidates = []
  for name in os.listdir(net_weights):
    path = os.path.join(net_weights, name)
    # Listing a plain file would raise NotADirectoryError, so only
    # directories are inspected for content.
    if os.path.isdir(path) and os.listdir(path):
      candidates.append(name)

  if not candidates:
    return ''

  # Order by the timestamp encoded in the folder name.
  candidates = sorted(
    candidates, key=lambda name: datetime.datetime.strptime(name, fmt))
  return candidates[-1]

def most_recent_weights(weights_folder):
  """Return the weights file with the highest epoch number in a folder.

  Weights files are recognised by the pattern ``<name>-<epoch>-<type>``
  where ``<type>`` is ``regular`` or ``best``. Files that do not follow the
  pattern are ignored.

  Args:
    weights_folder: path of the directory containing the weights files.

  Returns:
    The file name (not the full path) with the largest epoch, or an empty
    string when the folder holds no weights file.
  """
  regex_str = r'([A-Za-z0-9]+)-([0-9]+)-(regular|best)'

  # Collect (epoch, file name) for every file that follows the naming scheme.
  candidates = []
  for file_name in os.listdir(weights_folder):
    match = re.search(regex_str, file_name)
    if match:
      candidates.append((int(match.group(2)), file_name))

  # The emptiness check must look at the files, not at the path string.
  if not candidates:
    return ''

  # Stable sort on the epoch only, so ties keep their directory order.
  candidates = sorted(candidates, key=lambda pair: pair[0])
  return candidates[-1][1]

def last_epoch(weights_folder):
  """Return the epoch number of the most recent weights file in a folder.

  Args:
    weights_folder: path of the directory containing the weights files.

  Returns:
    The epoch encoded in the name of the file chosen by
    ``most_recent_weights``, as an int.

  Raises:
    Exception: if the folder holds no weights file, or the epoch cannot be
      read from the chosen file name.
  """
  weight_file = most_recent_weights(weights_folder)
  if not weight_file:
    raise Exception('no recent weights were found')

  # Read the epoch with the same pattern most_recent_weights sorts by;
  # splitting on '-' picks the wrong field when the network name itself
  # contains a hyphen.
  match = re.search(r'([A-Za-z0-9]+)-([0-9]+)-(regular|best)', weight_file)
  if match is None:
    raise Exception(
      'cannot read the epoch from weights file {!r}'.format(weight_file))

  return int(match.group(2))

def best_acc_weights(weights_folder):
  """Return the "best" weights file with the highest epoch in a folder.

  Weights files are recognised by the pattern ``<name>-<epoch>-<type>``;
  only those whose ``<type>`` is ``best`` are considered. Files that do not
  follow the pattern are ignored.

  Args:
    weights_folder: path of the directory containing the weights files.

  Returns:
    The file name (not the full path) of the best-accuracy weights with the
    largest epoch, or an empty string when no such file exists.
  """
  regex_str = r'([A-Za-z0-9]+)-([0-9]+)-(regular|best)'

  # Collect (epoch, file name) for every file tagged "best".
  best_files = []
  for file_name in os.listdir(weights_folder):
    match = re.search(regex_str, file_name)
    # Unrelated files yield no match and must not crash the lookup.
    if match and match.group(3) == 'best':
      best_files.append((int(match.group(2)), file_name))

  if not best_files:
    return ''

  # Stable sort on the epoch only, so ties keep their directory order.
  best_files = sorted(best_files, key=lambda pair: pair[0])
  return best_files[-1][1]

def unnormalized(img, mean=None, std=None):
  """Undo a per-channel normalisation and convert the result to a PIL image.

  Computes ``img * std + mean`` channel-wise and passes the result to
  ``transforms.ToPILImage``.

  Args:
    img: image tensor with the channel axis first; the per-channel
      statistics are broadcast over the remaining two axes, so any spatial
      size is accepted.
    mean: per-channel mean used for the normalisation. When omitted, a
      module-level ``mean`` is used if one exists (legacy behaviour).
    std: per-channel std used for the normalisation. When omitted, a
      module-level ``std`` is used if one exists (legacy behaviour).

  Returns:
    The image produced by ``transforms.ToPILImage``.

  Raises:
    NameError: if ``mean`` or ``std`` is neither passed nor defined at
      module level.
  """
  # NOTE(review): no module-level mean/std is visible in this file; the
  # lookup below only keeps callers that define them externally working.
  if mean is None:
    mean = globals().get('mean')
  if std is None:
    std = globals().get('std')
  if mean is None or std is None:
    raise NameError(
      "unnormalized() needs 'mean' and 'std': pass them as arguments "
      "or define them at module level")

  # Shape (C, 1, 1) broadcasts over height and width, so no fixed image
  # size has to be assumed.
  t_mean = torch.as_tensor(mean, dtype=torch.float32).reshape(-1, 1, 1)
  t_std = torch.as_tensor(std, dtype=torch.float32).reshape(-1, 1, 1)

  img = img * t_std + t_mean     # unnormalize
  return transforms.ToPILImage()(img)
