"""PyTorch utils

Adapted and modified from
* yolov5/utils/torch_utils.py
* fvcore/fvcore/transforms/transform_util.py

Notes:
* Parameter and return type annotations are intentionally omitted. This module provides cross-AI-framework utilities, so torch and torchvision are not guaranteed to be installed, and annotations referencing their types would raise errors even with the try/except import guards. This is a conscious tradeoff that is handled in the platform MLOps process. Alternatively, types can be documented in the function docstrings for documentation parsers.
* model summary and reporting added
"""
__author__ = 'XYZ'


# import pdb
import math
import json
import os
import re
import psutil
import time

from contextlib import contextmanager
from copy import deepcopy
from datetime import datetime
from typing import Dict, Union

import pandas as pd
import numpy as np

from ..core._log_ import logger
log = logger(__file__)

try:
  import torch
  import torch.backends.cudnn as cudnn
  import torch.nn as nn
  import torch.nn.functional as F
except ImportError:
  log.warning('torch is not installed')


try:
  import torchvision
except ImportError:
  log.warning('torchvision is not installed')


try:
  ## for FLOPS computation
  import thop
except ImportError:
  thop = None
  log.warning('thop is not installed. Required for FLOPS computation functionality.')


def log_mem_usage(prefix=""):
  """Log host RAM usage and, when CUDA is available, GPU memory usage.

  Args:
    prefix (str): text prepended to every log line.
  """
  mib = 1024 ** 2
  ram = psutil.virtual_memory()
  used_mb, total_mb = ram.used // mib, ram.total // mib
  log.info(f"{prefix}RAM: {used_mb}MB used / {total_mb}MB total")
  if not torch.cuda.is_available():
    return
  # Drain pending CUDA work before reading the allocator statistics.
  torch.cuda.synchronize()
  allocated_mb = torch.cuda.memory_allocated() // mib
  reserved_mb = torch.cuda.memory_reserved() // mib
  log.info(f"{prefix}CUDA: {allocated_mb}MB allocated, "
           f"{reserved_mb}MB reserved")

def compute_capability(args):
  """Log the CUDA compute capability (major.minor) of the default GPU.

  Args:
    args: unused; kept for call-site compatibility.
  """
  if not torch.cuda.is_available():
    log.info("CUDA is not available!")
    return
  major, minor = torch.cuda.get_device_capability(torch.device("cuda"))
  log.info(f"Compute Capability: {major}.{minor}")


def select_device(device='cpu'):
  """Resolve a device request to a ``torch.device``.

  Any falsy value or the literal ``'cpu'`` selects the CPU. Any other value
  selects ``'cuda:0'`` when CUDA is available and falls back to the CPU
  otherwise (a specific GPU index inside ``device`` is not honoured).
  """
  wants_gpu = bool(device) and device != 'cpu'
  if wants_gpu and torch.cuda.is_available():
    return torch.device('cuda:0')
  return torch.device('cpu')


def _select_device(device='', batch_size=None):
  """Select a torch device, restrict the visible GPUs and log the choice.

  Args:
    device (str | int | None): ``'cpu'``; ``''``/``None`` (all visible GPUs,
      or the CPU if CUDA is unavailable); or GPU indices such as ``'0'`` or
      ``'0,1,2,3'``. Case, surrounding/inner spaces and a ``'cuda:'`` prefix
      are ignored.
    batch_size (int | None): when more than one GPU is visible, must be a
      multiple of the GPU count.

  Returns:
    torch.device: ``cuda:0`` when CUDA is used, otherwise ``cpu``.

  Raises:
    AssertionError: if GPUs were requested but CUDA is unavailable, or the
      batch size is not divisible by the GPU count.
  """
  s = f'Using torch {torch.__version__} '  # string
  # Normalise the request so None, ' 0 ', 0 and 'cuda:0' are all accepted.
  device = (
      '' if device is None
      else str(device).strip().lower().replace('cuda:', '').replace(' ', '')
  )
  cpu = device == 'cpu'
  # NOTE: CUDA_VISIBLE_DEVICES only takes effect if CUDA has not been
  # initialised in this process yet.
  if cpu:
    # Hide every GPU so torch.cuda.is_available() reports False.
    os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
  elif device:  # specific GPU(s) requested
    os.environ['CUDA_VISIBLE_DEVICES'] = device
    assert torch.cuda.is_available(), (
        f'CUDA unavailable, invalid device {device} requested'
    )

  cuda = torch.cuda.is_available() and not cpu
  if cuda:
    n = torch.cuda.device_count()
    if n > 1 and batch_size:  # batch must split evenly across GPUs
      assert batch_size % n == 0, (
          f'batch-size {batch_size} not multiple of GPU count {n}'
      )
    space = ' ' * len(s)
    ids = device.split(',') if device else range(n)
    for i, d in enumerate(ids):
      p = torch.cuda.get_device_properties(i)
      mem_mb = p.total_memory // 1024 ** 2  # bytes to whole MB
      s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {mem_mb}MB)\n"
  else:
    s += 'CPU'

  log.info(f'{s}\n')  # skip a line
  return torch.device('cuda:0' if cuda else 'cpu')


# pyre-ignore-all-errors
def to_float_tensor(numpy_array: np.ndarray):
  # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
  """
  Convert a numpy image array to a torch tensor laid out as NxCxHxW.

  PyTorch does not fully support integer image tensors, so any integer or
  boolean input (uint8, int8, int16, int32, int64, bool, ...) is converted
  to float32. Floating-point inputs keep their dtype.

  Args:
      numpy_array (ndarray): of shape NxHxWxC, HxWxC or HxW representing an
          image. Integer arrays are typically uint8 in [0, 255]; floating
          arrays may be in [0, 1] or [0, 255].
  Returns:
      float_tensor (tensor): tensor of shape NxCxHxW (N = C = 1 for HxW input,
          N = 1 for HxWxC input).
  Raises:
      AssertionError: if the input is not an ndarray of rank 2, 3 or 4.
  """
  assert isinstance(numpy_array, np.ndarray)
  assert numpy_array.ndim in (2, 3, 4)

  # Some input arrays have negative strides (e.g. ``arr[::-1]``), which
  # torch.from_numpy rejects; ascontiguousarray copies them when needed.
  float_tensor = torch.from_numpy(np.ascontiguousarray(numpy_array))
  dtype = numpy_array.dtype
  if np.issubdtype(dtype, np.integer) or dtype == np.bool_:
    float_tensor = float_tensor.float()

  if numpy_array.ndim == 2:
    # HxW -> 1x1xHxW.
    float_tensor = float_tensor[None, None, :, :]
  elif numpy_array.ndim == 3:
    # HxWxC -> 1xCxHxW.
    float_tensor = float_tensor.permute(2, 0, 1)[None, :, :, :]
  elif numpy_array.ndim == 4:
    # NxHxWxC -> NxCxHxW
    float_tensor = float_tensor.permute(0, 3, 1, 2)
  else:
    # Only reachable when asserts are stripped (python -O).
    raise NotImplementedError(
        f'Unknow numpy_array dimension of {float_tensor.shape}',
    )
  return float_tensor


def to_numpy(
    float_tensor,
    target_shape: list,
    target_dtype: np.dtype,
) -> np.ndarray:
  # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
  """
  Convert a float tensor laid out as NxCxHxW back to a numpy array.

  The tensor is detached from autograd and moved to host memory before the
  conversion, so GPU tensors and tensors that require grad are accepted.

  Args:
      float_tensor (tensor): a float pytorch tensor with shape of NxCxHxW.
      target_shape (list): shape of the output image; only its length is
          used, to pick NxHxWxC (4), HxWxC (3) or HxW (2).
      target_dtype (dtype): if ``np.uint8`` the values are rounded and cast
          to uint8; otherwise the tensor's own dtype is kept.
  Returns:
      (ndarray): converted numpy array.
  Raises:
      AssertionError: if ``len(target_shape)`` is not 2, 3 or 4, or the
          batch (and, for HxW, channel) dimension is not 1 when required.
  """
  assert len(target_shape) in (2, 3, 4)

  if len(target_shape) == 2:
    # 1x1xHxW -> HxW.
    assert float_tensor.shape[0] == 1
    assert float_tensor.shape[1] == 1
    float_tensor = float_tensor[0, 0, :, :]
  elif len(target_shape) == 3:
    assert float_tensor.shape[0] == 1
    # 1xCxHxW -> HxWxC.
    float_tensor = float_tensor[0].permute(1, 2, 0)
  elif len(target_shape) == 4:
    # NxCxHxW -> NxHxWxC
    float_tensor = float_tensor.permute(0, 2, 3, 1)
  else:
    raise NotImplementedError(
        f'Unknow target shape dimension of {target_shape}',
    )
  if target_dtype == np.uint8:
    # Round explicitly before the cast; note torch.round rounds half to even.
    # https://github.com/pytorch/pytorch/issues/16498
    float_tensor = float_tensor.round().byte()
  # Tensor.numpy() only works on CPU tensors that do not require grad.
  return float_tensor.detach().cpu().numpy()


@contextmanager
def torch_distributed_zero_first(local_rank: int):
  """Context manager that lets the local master (rank 0) run its body first.

  Ranks other than -1 and 0 block on a barrier before entering the body;
  rank 0 enters immediately and reaches the matching barrier after its body
  finishes, releasing the others. Rank -1 (non-distributed) never waits.
  """
  must_wait = local_rank not in (-1, 0)
  if must_wait:
    torch.distributed.barrier()
  yield
  if local_rank == 0:
    torch.distributed.barrier()


def init_torch_seeds(seed=0):
  """Seed torch's RNG and choose the cuDNN speed/reproducibility tradeoff.

  ``seed == 0`` favours reproducibility (deterministic kernels, benchmark
  autotuning off); any other seed favours speed (benchmark on,
  non-deterministic kernels allowed).
  See https://pytorch.org/docs/stable/notes/randomness.html
  """
  torch.manual_seed(seed)
  reproducible = seed == 0
  cudnn.deterministic = reproducible
  cudnn.benchmark = not reproducible


def time_synchronized():
  """Return ``time.time()`` after waiting for pending CUDA work, if any.

  CUDA kernels launch asynchronously, so synchronizing first makes
  wall-clock timings of GPU code meaningful.
  """
  cuda_ready = torch.cuda.is_available()
  if cuda_ready:
    torch.cuda.synchronize()
  now = time.time()
  return now


def profile(x, ops, n=100, device=None):
  """Profile speed, GFLOPS and parameter count of one or more ops.

  Example::

      x = torch.randn(16, 3, 640, 640)  # input
      m1 = lambda x: x * torch.sigmoid(x)
      m2 = nn.SiLU()
      profile(x, [m1, m2], n=100)  # profile speed over 100 iterations

  Args:
    x (torch.Tensor): input fed to every op. It is moved to ``device`` and
      marked ``requires_grad`` so the backward pass can be timed.
    ops: a callable / ``nn.Module`` or a list of them.
    n (int): timed iterations per op; values below 1 are treated as 1.
    device (torch.device | None): defaults to ``cuda:0`` if available,
      otherwise the CPU.

  Prints one row per op: parameter count, GFLOPS (0 if ``thop`` is missing
  or cannot profile the op), mean forward/backward time in ms (backward is
  NaN when the output cannot be backpropagated), input and output shapes.
  """
  device = device or torch.device(
      'cuda:0' if torch.cuda.is_available() else 'cpu',
  )
  n = max(int(n), 1)  # at least one iteration so `y` is always bound
  x = x.to(device)
  x.requires_grad = True
  print(
      torch.__version__,
      device.type,
      torch.cuda.get_device_properties(device) if device.type == 'cuda' else '',
  )
  print(
      f"\n{'Params':>12s}{'GFLOPS':>12s}{'forward (ms)':>16s}{'backward (ms)':>16s}{'input':>24s}{'output':>24s}",
  )
  for m in ops if isinstance(ops, list) else [ops]:
    m = m.to(device) if hasattr(m, 'to') else m  # device
    if (hasattr(m, 'half') and isinstance(x, torch.Tensor) and
        x.dtype is torch.float16):
      m = m.half()  # match a half-precision input
    dtf, dtb, t = 0.0, 0.0, [0.0, 0.0, 0.0]  # dt forward, backward

    flops = 0
    if thop is not None:
      try:
        # GFLOPS = thop count / 1e9 * 2 (presumably thop reports MACs and
        # x2 approximates FLOPs — confirm against thop docs).
        flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1e9 * 2
      except Exception:  # op not supported by thop, e.g. a plain lambda
        flops = 0

    for _ in range(n):
      t[0] = time_synchronized()
      y = m(x)
      t[1] = time_synchronized()
      try:
        _ = y.sum().backward()
        t[2] = time_synchronized()
      except Exception:  # no backward method
        t[2] = float('nan')
      dtf += (t[1] - t[0]) * 1000 / n  # ms per op forward
      dtb += (t[2] - t[1]) * 1000 / n  # ms per op backward

    s_in = tuple(x.shape) if isinstance(x, torch.Tensor) else 'list'
    s_out = tuple(y.shape) if isinstance(y, torch.Tensor) else 'list'
    p = (
        sum(q.numel() for q in m.parameters())
        if isinstance(m, nn.Module)
        else 0
    )  # parameters
    print(
        f'{p:12.4g}{flops:12.4g}{dtf:16.4g}{dtb:16.4g}{str(s_in):>24s}{str(s_out):>24s}',
    )


def is_parallel(model):
  """Return True if ``model`` is wrapped in DataParallel or
  DistributedDataParallel (including subclasses of either).

  Callers use this to decide whether to unwrap ``model.module``.
  """
  return isinstance(
      model,
      (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel),
  )


def intersect_dicts(da, db, exclude=()):
  """Return the entries of ``da`` whose key is also in ``db`` with a value of
  the same ``.shape``, skipping keys that contain any ``exclude`` substring.

  Values are taken from ``da``; insertion order of ``da`` is preserved.
  """
  kept = {}
  for key, value in da.items():
    if key not in db:
      continue
    if any(fragment in key for fragment in exclude):
      continue
    if value.shape != db[key].shape:
      continue
    kept[key] = value
  return kept


def initialize_weights(model):
  """Apply default hyper-parameters to selected layer types, in place.

  Matching is on the exact module type (subclasses are not affected):
  * ``nn.Conv2d``: left untouched (PyTorch's default init is kept; a
    kaiming_normal_ init was considered and disabled).
  * ``nn.BatchNorm2d``: ``eps = 1e-3``, ``momentum = 0.03``.
  * ``nn.Hardswish``, ``nn.LeakyReLU``, ``nn.ReLU``, ``nn.ReLU6``:
    ``inplace = True``.
  """
  inplace_activations = (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6)
  for module in model.modules():
    kind = type(module)
    if kind is nn.BatchNorm2d:
      module.eps = 1e-3
      module.momentum = 0.03
    elif kind in inplace_activations:
      module.inplace = True


def find_modules(model, mclass):
  """Return the indices of ``model.module_list`` entries that are instances
  of ``mclass`` (for example ``nn.Conv2d``)."""
  indices = []
  for index, module in enumerate(model.module_list):
    if isinstance(module, mclass):
      indices.append(index)
  return indices


def sparsity(model):
  """Return the global sparsity of ``model``.

  Sparsity is the fraction of parameter elements that are exactly zero.

  Returns:
    float: zero-count / element-count over all parameters, or ``0.0`` when
    the model has no parameters.
  """
  total, zeros = 0, 0
  for p in model.parameters():
    total += p.numel()
    # Tally on the host so parameters on different devices can be mixed.
    zeros += int((p == 0).sum())
  return zeros / total if total else 0.0


def prune(model, amount=0.3):
  """Apply L1 unstructured pruning to every ``nn.Conv2d`` weight, in place.

  ``amount`` is forwarded to ``torch.nn.utils.prune.l1_unstructured``; the
  pruning is then made permanent with ``prune.remove``. Prints progress and
  the resulting global sparsity.
  """
  import torch.nn.utils.prune as torch_prune

  print('Pruning model... ', end='')
  convs = (m for m in model.modules() if isinstance(m, nn.Conv2d))
  for conv in convs:
    torch_prune.l1_unstructured(conv, name='weight', amount=amount)  # prune
    torch_prune.remove(conv, 'weight')  # make permanent
  print(' %.3g global sparsity' % sparsity(model))


def fuse_conv_and_bn(conv, bn):
  """Fold a BatchNorm2d into the preceding Conv2d and return the fused conv.

  Intended for inference (uses ``bn``'s running statistics). With
  ``s = gamma / sqrt(running_var + eps)`` the fused layer has
  ``W' = s * W`` (per output channel) and
  ``b' = s * b_conv + beta - s * running_mean``.
  See https://tehnokv.com/posts/fusing-batchnorm-and-conv/

  Args:
    conv (nn.Conv2d): convolution whose output feeds ``bn``.
    bn (nn.BatchNorm2d): batch-norm layer; ``gamma``/``beta`` default to
      1/0 when it was built with ``affine=False``.

  Returns:
    nn.Conv2d: a new conv (``requires_grad=False``) on ``conv``'s device with
    the same geometry as ``conv`` (stride, padding, dilation, groups,
    padding mode) and a bias.
  """
  fusedconv = (
      nn.Conv2d(
          conv.in_channels,
          conv.out_channels,
          kernel_size=conv.kernel_size,
          stride=conv.stride,
          padding=conv.padding,
          dilation=conv.dilation,
          groups=conv.groups,
          bias=True,
          padding_mode=conv.padding_mode,
      )
      .requires_grad_(False)
      .to(conv.weight.device)
  )

  # Fused weights are constants: keep them out of any autograd graph.
  with torch.no_grad():
    gamma = bn.weight if bn.weight is not None else torch.ones_like(bn.running_var)
    beta = bn.bias if bn.bias is not None else torch.zeros_like(bn.running_var)
    scale = gamma.div(torch.sqrt(bn.running_var + bn.eps))

    # prepare filters
    w_conv = conv.weight.clone().view(conv.out_channels, -1)
    w_bn = torch.diag(scale)
    fusedconv.weight.copy_(
        torch.mm(w_bn, w_conv).view(fusedconv.weight.size()),
    )

    # prepare spatial bias
    b_conv = (
        torch.zeros(conv.weight.size(0), device=conv.weight.device)
        if conv.bias is None
        else conv.bias
    )
    b_bn = beta - scale.mul(bn.running_mean)
    fusedconv.bias.copy_(
        torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn,
    )

  return fusedconv


def model_info(model, verbose=False, img_size=640):
  """Log a one-line model summary; optionally print a per-parameter table.

  Args:
    model (nn.Module): model to summarise. For the FLOPS probe,
      ``model.stride`` (max taken) and ``model.yaml['ch']`` are used when
      present; otherwise a stride of 32 and 3 input channels are assumed.
    verbose (bool): also print name, grad flag, size, shape, mean and std
      of every parameter.
    img_size (int | float | list | tuple): image size the GFLOPS estimate is
      scaled to, e.g. ``640`` or ``[640, 320]``.

  The GFLOPS figure is appended only when ``thop`` is installed and the
  profiling probe succeeds; any failure there is ignored.
  """
  n_p = sum(x.numel() for x in model.parameters())  # number parameters
  n_g = sum(
      x.numel() for x in model.parameters() if x.requires_grad
  )  # number gradients
  if verbose:
    print(
        '%5s %40s %9s %12s %20s %10s %10s'
        % ('layer', 'name', 'gradient', 'parameters', 'shape', 'mu', 'sigma'),
    )
    for i, (name, p) in enumerate(model.named_parameters()):
      name = name.replace('module_list.', '')
      print(
          '%5g %40s %9s %12g %20s %10.3g %10.3g'
          % (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()),
      )

  try:  # FLOPS
    from thop import profile as thop_profile

    stride = int(model.stride.max()) if hasattr(model, 'stride') else 32
    in_ch = getattr(model, 'yaml', {}).get('ch', 3)
    img = torch.zeros(
        (1, in_ch, stride, stride),
        device=next(model.parameters()).device,
    )  # input
    flops = (
        thop_profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1e9 * 2
    )  # stride GFLOPS
    if not isinstance(img_size, (list, tuple)):
      img_size = [img_size, img_size]  # expand if int/float
    fs = ', %.1f GFLOPS' % (
        flops * img_size[0] / stride * img_size[1] / stride
    )  # GFLOPS at img_size
  except Exception:  # thop missing, model not profilable, no parameters, ...
    fs = ''

  log.info(
      f'Model Summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}',
  )


def load_classifier(name='resnet101', n=2):
  """Load a pretrained torchvision classifier and reshape its head to ``n`` classes.

  Args:
    name (str): constructor name in ``torchvision.models`` (e.g.
      ``'resnet101'``). The model must expose its classification head as
      ``model.fc`` with a ``weight`` of shape (out, in), as ResNet-style
      models do.
    n (int): number of output classes.

  Returns:
    nn.Module: the model with ``fc`` given zero-initialised, trainable
    weight ``(n, in_features)`` and bias ``(n,)``.

  Raises:
    KeyError: if ``name`` is not a ``torchvision.models`` attribute.
  """
  # Loads a pretrained model; pretrained=True downloads weights on first use.
  model = torchvision.models.__dict__[name](pretrained=True)

  # ResNet model properties (reference only):
  # input_size = [3, 224, 224], input_space = 'RGB', input_range = [0, 1]
  # mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225]

  # Reshape output to n classes
  filters = model.fc.weight.shape[1]
  model.fc.bias = nn.Parameter(torch.zeros(n), requires_grad=True)
  model.fc.weight = nn.Parameter(torch.zeros(n, filters), requires_grad=True)
  model.fc.out_features = n
  return model

def modify_output(args, net):
  """Replace the classification head of ``net`` so it outputs ``args.num_class`` values.

  The head is located according to ``args.net``:

  * ``classifier_zoo`` (DenseNet/GhostNet names): ``net.classifier`` (Linear).
  * ``mobile_zoo`` and ``vgg_zoo``: last element of ``net.classifier``.
  * ``inception_zoo``: ``net.fc``, or ``net.linear`` if the model has no ``fc``.
  * ``squeeze_zoo``: the 1x1 conv at ``net.classifier[1]``.
  * anything else: ``net.fc``.

  Input sizes are read from the existing head, so they are not hard-coded.

  Args:
    args: object with ``net`` (architecture name) and ``num_class`` (int).
    net (nn.Module): model to modify in place.

  Returns:
    nn.Module: ``net`` itself.
  """
  classifier_zoo = ['dense121', 'dense161', 'dense201', 'ghost1_0']
  vgg_zoo = ['vgg19_bn', 'vgg16_bn']
  inception_zoo = ['inception_v3', 'inception_v4']
  mobile_zoo = ['mobilenet_v2', 'mobilenet_v3_large', 'mobilenet_v3_small', 'efficientnet_b0']
  squeeze_zoo = ['squeezenet1_0', 'squeezenet1_1']

  if args.net in classifier_zoo:
    channel_in = net.classifier.in_features
    net.classifier = nn.Linear(channel_in, args.num_class)
  elif args.net in mobile_zoo or args.net in vgg_zoo:
    channel_in = net.classifier[-1].in_features
    net.classifier[-1] = nn.Linear(channel_in, args.num_class)
  elif args.net in inception_zoo:
    # Read and replace the *same* head attribute: torchvision's Inception
    # exposes ``fc``; other implementations may call it ``linear``.
    head = 'fc' if hasattr(net, 'fc') else 'linear'
    channel_in = getattr(net, head).in_features
    setattr(net, head, nn.Linear(channel_in, args.num_class))
  elif args.net in squeeze_zoo:
    final_conv = net.classifier[1]
    net.classifier[1] = nn.Conv2d(final_conv.in_channels, args.num_class, kernel_size=1)
  else:
    channel_in = net.fc.in_features
    net.fc = nn.Linear(channel_in, args.num_class)
  return net


def scale_img(img, ratio=1.0, same_shape=False, gs=32):  # img(16,3,256,416)
  """Rescale a batch of images by ``ratio`` (bilinear) and pad/crop the result.

  Args:
    img (torch.Tensor): batch indexed as (bs, 3, H, W).
    ratio (float): scale factor; ``1.0`` returns ``img`` itself.
    same_shape (bool): if True, pad/crop back to the original H x W;
      otherwise pad the scaled size up to the next multiple of ``gs``.
    gs (int): grid size the padded output is rounded up to.

  Padding value is 0.447 (the ImageNet mean).
  """
  if ratio == 1.0:
    return img
  height, width = img.shape[2:]
  new_size = (int(height * ratio), int(width * ratio))
  resized = F.interpolate(img, size=new_size, mode='bilinear', align_corners=False)
  if same_shape:
    target_h, target_w = height, width
  else:
    target_h = math.ceil(height * ratio / gs) * gs
    target_w = math.ceil(width * ratio / gs) * gs
  # Pad right/bottom; negative amounts crop instead.
  padding = [0, target_w - new_size[1], 0, target_h - new_size[0]]
  return F.pad(resized, padding, value=0.447)


def copy_attr(a, b, include=(), exclude=()):
  """Copy public instance attributes (from ``b.__dict__``) onto ``a``.

  Names starting with ``_`` or listed in ``exclude`` are skipped; when
  ``include`` is non-empty, only names listed in it are copied.
  """
  for key, value in b.__dict__.items():
    is_private = key.startswith('_')
    not_included = len(include) > 0 and key not in include
    if is_private or not_included or key in exclude:
      continue
    setattr(a, key, value)


class ModelEMA:
  """Model Exponential Moving Average from
  https://github.com/rwightman/pytorch-image-models

  Keeps a moving average of everything in the model state_dict (parameters
  and buffers); only floating-point entries are averaged. This is intended
  to allow functionality like
  https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
  A smoothed version of the weights is necessary for some training schemes
  to perform well. This class is sensitive to where it is initialized in the
  sequence of model init, GPU assignment and distributed training wrappers.
  """

  def __init__(self, model, decay=0.9999, updates=0, tau=2000):
    """Create the EMA copy.

    Args:
      model (nn.Module): model to track; DataParallel/DDP wrappers are
        unwrapped via ``model.module``.
      decay (float): asymptotic EMA decay.
      updates (int): number of updates already applied (e.g. when resuming).
      tau (float): ramp time constant. After ``k`` updates the effective
        decay is ``decay * (1 - exp(-k / tau))``, which keeps the average
        responsive during early training.
    """
    # Create EMA
    self.ema = deepcopy(
        model.module if is_parallel(model) else model,
    ).eval()  # FP32 EMA
    # if next(model.parameters()).device.type != 'cpu':
    #     self.ema.half()  # FP16 EMA
    self.updates = updates  # number of EMA updates
    self.decay = lambda x: decay * (
        1 - math.exp(-x / tau)
    )  # decay exponential ramp (to help early epochs)
    for p in self.ema.parameters():
      p.requires_grad_(False)

  def update(self, model):
    """Blend ``model``'s state into the EMA: ``ema = d * ema + (1 - d) * model``."""
    with torch.no_grad():
      self.updates += 1
      d = self.decay(self.updates)

      msd = (
          model.module.state_dict()
          if is_parallel(model)
          else model.state_dict()
      )  # model state_dict
      # state_dict() tensors share storage with the EMA model, so the
      # in-place ops below update it directly.
      for k, v in self.ema.state_dict().items():
        if v.dtype.is_floating_point:
          v *= d
          v += (1.0 - d) * msd[k].detach()

  def update_attr(
      self,
      model,
      include=(),
      exclude=('process_group', 'reducer'),
  ):
    """Copy public attributes of ``model`` onto the EMA model (see copy_attr)."""
    copy_attr(self.ema, model, include, exclude)



def time_sync():
  """PyTorch-accurate wall-clock time.

  Waits for queued CUDA work (when CUDA is available), then returns
  ``time.time()``.
  """
  if not torch.cuda.is_available():
    return time.time()
  torch.cuda.synchronize()
  return time.time()


class TracedModel(nn.Module):
  """Wrap a model traced with TorchScript, applying its last layer separately.

  On construction the model is moved to the CPU, set to eval mode, traced
  with ``torch.jit.trace`` on a random input, saved as ``traced_model.pt``
  in the current working directory, and moved to ``device``. The model's
  last layer (``model.model[-1]``) is kept as a regular module and applied
  to the traced output in ``forward``.
  """

  def __init__(self, model=None, device=None, img_size=(640, 640)):
    """
    Args:
      model: model exposing ``stride``, ``names`` and an indexable ``model``
        attribute whose last element is applied after the traced part.
      device: device the traced model and last layer are moved to.
      img_size (int | tuple[int, int]): tracing input size, either a single
        side length or ``(height, width)``.
    """
    super(TracedModel, self).__init__()
    print(" Convert model to Traced-model... ")
    self.stride = model.stride
    self.names = model.names
    self.model = model

    self.model = revert_sync_batchnorm(self.model)
    self.model.to('cpu')
    self.model.eval()

    self.detect_layer = self.model.model[-1]
    # NOTE(review): ``traced`` is presumably read by the wrapped model's
    # forward to skip its last layer while tracing — confirm in that class.
    self.model.traced = True

    # Accept an int (square) or an (h, w) pair; the tuple default used to
    # crash torch.rand.
    if isinstance(img_size, (list, tuple)):
      height, width = img_size
    else:
      height = width = img_size
    rand_example = torch.rand(1, 3, height, width)

    traced_script_module = torch.jit.trace(self.model, rand_example, strict=False)
    traced_script_module.save("traced_model.pt")
    print(" traced_script_module saved! ")
    self.model = traced_script_module
    self.model.to(device)
    self.detect_layer.to(device)
    print(" model is traced! \n")

  def forward(self, x, augment=False, profile=False):
    """Run the traced model, then the last layer, and return its output.

    ``augment`` and ``profile`` are accepted for interface compatibility
    and ignored.
    """
    out = self.model(x)
    out = self.detect_layer(out)
    return out