"""torch API helpers: CPU/CUDA device selection, torchvision model
construction and output-head replacement, checkpoint loading, image
preprocessing and prediction."""
__author__ = "XYZ"


# import pdb
import os
import sys
import time


from PIL import Image
import numpy as np

import torch

try:
  import torch
  import torch.nn as nn
except ImportError:
  print('torch is not installed')


try:
  import torchvision
  import torchvision.models
  import torchvision.transforms as transforms

  from torchvision.models._api import get_weight
except ImportError:
  print('torchvision is not installed')

"""
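Reference examples for torch.load (torch.load unpickles its input, so only load files from trusted sources):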
    >>> torch.load('tensors.pt')
    # Load all tensors onto the CPU
    >>> torch.load('tensors.pt', map_location=torch.device('cpu'))
    # Load all tensors onto the CPU, using a function
    >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage)
    # Load all tensors onto GPU 1
    >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))
    # Map tensors from GPU 1 to GPU 0
    >>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})
    # Load tensor from io.BytesIO object
    >>> import io
    >>> with open('tensor.pt', 'rb') as f:
    ...     buffer = io.BytesIO(f.read())
    >>> torch.load(buffer)
    # Load a module with 'ascii' encoding for unpickling
    >>> torch.load('module.pt', encoding='ascii')

    ---

    Examples for torch.jit.load (TorchScript modules):

    import torch
    import io

    torch.jit.load('scriptmodule.pt')

    # Load ScriptModule from io.BytesIO object
    with open('scriptmodule.pt', 'rb') as f:
        buffer = io.BytesIO(f.read())

    # Load all tensors to the original device
    torch.jit.load(buffer)

    # Load all tensors onto CPU, using a device
    buffer.seek(0)
    torch.jit.load(buffer, map_location=torch.device('cpu'))

    # Load all tensors onto CPU, using a string
    buffer.seek(0)
    torch.jit.load(buffer, map_location='cpu')

    # Load with extra files.
    extra_files = {'foo.txt': ''}  # values will be replaced with data
    torch.jit.load('scriptmodule.pt', _extra_files=extra_files)
    print(extra_files['foo.txt'])

"""


## Model-zoo configuration shared by the helpers in this module.
##   nets          - torchvision model names this module is set up for.
##   target_layers - per-net default target layer, as a dotted name yielded by
##                   `model.named_modules()` (looked up by get_target_layers()).
##   params        - extra per-net constructor kwargs (not read by get_torch_network()).
##   modelzoo      - model families; modify_output() uses them to locate the head.
modelzoocfg = {
  "pretrained": True,
  "nets": [
    "resnet18",
    "resnet34",
    "resnet50",
    "resnet101",
    "resnet152",
    "wide_resnet50_2",
    "wide_resnet101_2",
    "vgg19_bn",
    "vgg19",
    "vgg16_bn",
    "vgg16",
    "vgg13_bn",
    "vgg13",
    "vgg11_bn",
    "vgg11",
    "densenet121",
    "densenet161",
    "densenet169",
    "densenet201",
    "inception_v3",
    "mobilenet_v2",
    "mobilenet_v3_large",
    "mobilenet_v3_small",
    "shufflenet_v2_x1_0",
    "squeezenet1_0",
    "efficientnet_b0"
  ],
  "target_layers": {
    "resnet18": "layer4",
    "resnet34": "layer4",
    "resnet50": "layer4",
    "resnet101": "layer4",
    "resnet152": "layer4",
    "wide_resnet50_2": "layer4",
    "wide_resnet101_2": "layer4",
    ## VGG: the ReLU after the last conv, i.e. features[-2] (features[-1] is the
    ## final MaxPool2d). In `features`, each conv adds 2 modules (Conv2d, ReLU),
    ## or 3 with BatchNorm, and each pool adds 1. The index therefore depends on
    ## both depth and the `_bn` variant.
    "vgg11": "features.19",
    "vgg11_bn": "features.27",
    "vgg13": "features.23",
    "vgg13_bn": "features.33",
    "vgg16": "features.29",
    "vgg16_bn": "features.42",
    "vgg19": "features.35",
    "vgg19_bn": "features.51",
    "densenet121": "features.norm5",
    "densenet161": "features.norm5",
    "densenet169": "features.norm5",
    "densenet201": "features.norm5",
    "mobilenet_v2": "features.18",
    "mobilenet_v3_large": "features.16",
    "mobilenet_v3_small": "features.12",
    "efficientnet_b0": "features.7",
    "inception_v3": "Mixed_7c",
    "shufflenet_v2_x1_0": "conv5",
    "squeezenet1_0": "features.12",
    "squeezenet1_1": "features.12"
  },
  "params": {
    "inception_v3": {
      "aux_logits": False
    }
  },
  "modelzoo": {
    ## NOTE(review): 'ghost1_0' and 'inception_v4' are not torchvision.models
    ## builtins, so torchvision.models.get_model() cannot build them. Confirm
    ## where they are expected to come from.
    "classifier_zoo": [
      "densenet121",
      "densenet161",
      "densenet169",
      "densenet201",
      "ghost1_0"
    ],
    "resnet_zoo": [
      "resnet18",
      "resnet34",
      "resnet50",
      "resnet101",
      "resnet152",
      "wide_resnet50_2",
      "wide_resnet101_2"
    ],
    "vgg_zoo": [
      "vgg19_bn",
      "vgg19",
      "vgg16_bn",
      "vgg16",
      "vgg13_bn",
      "vgg13",
      "vgg11_bn",
      "vgg11"
    ],
    "inception_zoo": [
      "inception_v3",
      "inception_v4"
    ],
    "mobile_zoo": [
      "mobilenet_v2",
      "mobilenet_v3_large",
      "mobilenet_v3_small",
      "efficientnet_b0"
    ],
    "squeeze_zoo": [
      "squeezenet1_0",
      "squeezenet1_1"
    ]
  }
}


def select_device(device='cpu', batch_size=None):
  """Select the torch device to use (CPU by default) and log the choice.

  Restricts ``CUDA_VISIBLE_DEVICES`` to the requested GPUs, or to none for
  CPU. The variable only takes effect if CUDA has not been initialised yet in
  this process, so call this early.

  Args:
    device: 'cpu', a GPU index such as '0', or a comma-separated list such as
      '0,1,2,3'. Integers, surrounding whitespace, upper case and a 'cuda:'
      prefix are accepted ('cuda' alone means GPU 0). None or '' select CPU.
    batch_size: optional batch size. With more than one visible GPU it must
      be divisible by the GPU count.

  Returns:
    'cuda:0' when CUDA is used, otherwise 'cpu'.

  Raises:
    ValueError: if a GPU is requested but CUDA is unavailable, or if
      ``batch_size`` is not divisible by the number of visible GPUs.
  """
  ## Normalise the request so that e.g. 0, ' CPU ', 'cuda:1' and '0, 1' all work;
  ## previously they crashed (.lower() on int/None) or produced an invalid env value.
  requested = '' if device is None else str(device).strip().lower()
  requested = requested.replace('cuda:', '').replace(' ', '')
  if requested == 'cuda':
    requested = '0'
  cpu = requested in ('', 'cpu')

  ## Force CPU ('-1' hides every GPU, so torch.cuda.is_available() reports False)
  ## or expose only the requested GPUs.
  os.environ['CUDA_VISIBLE_DEVICES'] = '-1' if cpu else requested

  if not cpu and not torch.cuda.is_available():
    raise ValueError(f"CUDA unavailable, invalid device '{device}' requested")

  cuda = not cpu and torch.cuda.is_available()
  s = f'Using torch {torch.__version__} ' + ('CUDA' if cuda else 'CPU')

  ## If using CUDA, validate the batch split and gather device properties for logging.
  if cuda:
    num_gpus = torch.cuda.device_count()
    ## Multi-GPU runs require the batch to split evenly across the GPUs.
    if num_gpus > 1 and batch_size and batch_size % num_gpus != 0:
      raise ValueError(f"Batch size {batch_size} not divisible by GPU count {num_gpus}")

    for i in range(num_gpus):
      p = torch.cuda.get_device_properties(i)
      s += f"\nCUDA:{i} ({p.name}, {p.total_memory / 1024 ** 2:.0f}MB)"

  print(f'{s}\n')
  return 'cuda:0' if cuda else 'cpu'


def resolve_pretrained_weight_path(name, weights_enum):
  """Report where torch.hub caches the default pretrained weights of a model.

  The path is ``<torch.hub.get_dir()>/checkpoints/<file name of the URL path>``,
  matching how ``torch.hub.load_state_dict_from_url`` names its cache file.
  Nothing is downloaded here; the function only logs whether that file exists.

  Args:
    name: model name, used in the log message only.
    weights_enum: a torchvision weights enum (e.g. ``ResNet18_Weights``).
      Its ``DEFAULT`` member provides the download ``url``.

  Returns:
    The expected cache file path, or ``None`` if it could not be resolved.
  """
  from urllib.parse import urlparse

  ## Best-effort diagnostic: a lookup failure must never break model construction.
  try:
    url = weights_enum.DEFAULT.url
    ## Use the URL *path*, not the raw URL, so a query string cannot leak into the file name.
    filename = os.path.basename(urlparse(url).path)
    cached_path = os.path.join(torch.hub.get_dir(), 'checkpoints', filename)
  except Exception as e:
    print(f"[!] Could not resolve pretrained weight file path: {e}")
    return None

  if os.path.exists(cached_path):
    print(f"Pretrained weights for '{name}' loaded from: {cached_path}")
  else:
    print(f"Weight file expected at: {cached_path} (but not found yet — may be downloading)")
  return cached_path


def get_torch_network(args):
  """Build a torchvision model by name, optionally with its default pretrained weights.

  The ``pretrained`` argument was replaced by ``weights`` in torchvision 0.13
  and is slated for removal. This uses ``weights`` instead of the old
  ``torchvision.models.__dict__[name](pretrained=pretrained)`` for forward
  compatibility; this also avoids the deprecation warning.

  Args:
    args: object with ``net`` (torchvision model name, e.g. 'resnet18') and
      ``pretrain`` (truthy to load that model's ``DEFAULT`` weights).

  Returns:
    The constructed model, still with its original classification head.
  """
  name = args.net
  pretrain = args.pretrain

  weights_enum = None
  weights = None
  ## Only look up the weights enum when pretrained weights are actually wanted.
  if pretrain:
    try:
      weights_enum = torchvision.models.get_model_weights(name)
      weights = weights_enum.DEFAULT
    except Exception as e:
      print(f"[!] Warning: Failed to resolve weights enum for {name}: {e}")
      weights = None

  net = torchvision.models.get_model(name, weights=weights)

  if weights is None:
    print(f"No pretrained weights used for model: {name}")
    return net

  resolve_pretrained_weight_path(name, weights_enum)

  ## Print a few first-layer weights for inspection. Use next() rather than
  ## list(): there is no need to materialise every parameter, and a model with
  ## no parameters no longer raises IndexError.
  first_param = next(net.parameters(), None)
  if first_param is not None:
    first_layer_weights = first_param.detach()
    print(f"First layer shape: {first_layer_weights.shape}")
    print(f"Sample first-layer weights: {first_layer_weights.reshape(-1)[:10]}")

  return net


def get_target_layers(model, model_name):
  """Return the configured default target layer of ``model`` as a one-item list.

  Args:
    model: the ``torch.nn.Module`` instance built for ``model_name``.
    model_name: key into ``modelzoocfg['target_layers']``.

  Returns:
    ``[submodule]``, where ``submodule`` is the module whose name in
    ``model.named_modules()`` equals the configured layer name.

  Raises:
    ValueError: if no (or an empty) layer name is configured for
      ``model_name``, or that name is not a submodule of ``model``.
  """
  configured = modelzoocfg['target_layers'].get(model_name)
  if not configured:
    raise ValueError(f"No default target layer defined for {model_name}")

  for qualified_name, submodule in model.named_modules():
    if qualified_name == configured:
      return [submodule]

  raise ValueError(f"Target layer '{configured}' not found in model '{model_name}'")


def modify_output(args, net):
  """Replace the classification head of ``net`` with one sized for this task.

  Adapted from: https://github.com/Shenqishaonv/100-Driver-Source

  @author: wenjing
  @article{100-Driver,
  author    = {Wang Jing, Li Wengjing, Li Fang, Zhang Jun, Wu Zhongcheng, Zhong Zhun and Sebe Nicu},
  title     = {100-Driver: A Large-scale, Diverse Dataset for Distracted Driver Classification},
  journal={IEEE Transactions on Intelligent Transportation Systems},
  year      = {2023},
  publisher={IEEE}}

  https://drive.google.com/file/d/1JhLdk-feblXi_pepF7GTb-6vHwtrbxSr/view

  Enhanced ScoreLossPlus: https://github.com/congduan-HNU/SSoftmax.git

  The new head has ``args.num_class * score_level`` outputs.
  ``score_level`` is ``args.score_level`` (default 1) when
  ``args.loss == 'ScoreLossPlus'``, and 1 otherwise.

  The head that gets replaced depends on which ``modelzoocfg['modelzoo']``
  family contains ``args.net``:

    * classifier_zoo        -> ``net.classifier`` (Linear)
    * mobile_zoo, vgg_zoo   -> ``net.classifier[-1]`` (Linear)
    * squeeze_zoo           -> ``net.classifier[1]`` (1x1 Conv2d)
    * inception_zoo, resnet_zoo and any other net -> ``net.fc`` (Linear)

  The input size of each new layer is read from the layer it replaces.

  Args:
    args: needs ``net``, ``num_class`` and ``loss``; ``score_level`` is optional.
    net: the model whose head is replaced.

  Returns:
    ``net`` itself, modified in place.
  """
  zoo = modelzoocfg['modelzoo']

  num_class = args.num_class
  loss_fname = args.loss

  score_level = getattr(args, 'score_level', 1) if loss_fname == 'ScoreLossPlus' else 1
  channel_out = num_class * score_level

  print(f"[modify_output] loss: {loss_fname}, score_level: {score_level}, channel_out: {channel_out}")

  net_name = args.net
  if net_name in zoo['classifier_zoo']:
    net.classifier = nn.Linear(net.classifier.in_features, channel_out)
  elif net_name in zoo['mobile_zoo'] or net_name in zoo['vgg_zoo']:
    ## The last module of the classifier Sequential is the output Linear. Reading
    ## its in_features replaces the 4096 that used to be hard-coded for VGG
    ## (torchvision VGG's final Linear takes 4096 inputs, so results are unchanged there).
    net.classifier[-1] = nn.Linear(net.classifier[-1].in_features, channel_out)
  elif net_name in zoo['squeeze_zoo']:
    ## SqueezeNet's classifier uses a 1x1 conv at classifier[1] instead of a Linear.
    net.classifier[1] = nn.Conv2d(net.classifier[1].in_channels, channel_out, kernel_size=1)
  else:
    ## inception_zoo, resnet_zoo and any other net expose the head as `fc`.
    net.fc = nn.Linear(net.fc.in_features, channel_out)
  return net


def unloadmodel(model):
  """Drop this function's reference to ``model``, collect garbage and free cached CUDA memory.

  Note: ``del`` here only removes the local name. The caller's own reference
  is still alive while this runs, so the model is typically not freed by this
  call alone. Callers should also drop their references (e.g. ``del net``).
  """
  from gc import collect as run_garbage_collector

  del model
  run_garbage_collector()
  if not torch.cuda.is_available():
    return
  ## Return blocks cached by the CUDA allocator to the driver.
  torch.cuda.empty_cache()


def _extract_state_dict(checkpoint):
  """Return the parameter mapping held by an object produced by ``torch.load``.

  Besides a plain ``state_dict``, this accepts:
    * the wrappers ``{'state_dict': ...}``, ``{'model_state_dict': ...}`` and
      ``{'model': ...}`` (the first whose value is a dict or Module wins);
    * a whole pickled ``nn.Module`` (its ``state_dict()`` is used);
    * keys carrying the ``'module.'`` prefix added when saving from
      ``nn.DataParallel`` / ``DistributedDataParallel``. The prefix is stripped
      only when every key has it.
  Anything else is returned unchanged and left for ``load_state_dict`` to reject.
  """
  from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

  if isinstance(checkpoint, dict):
    for wrapper_key in ('state_dict', 'model_state_dict', 'model'):
      inner = checkpoint.get(wrapper_key)
      if isinstance(inner, (dict, nn.Module)):
        checkpoint = inner
        break
  if isinstance(checkpoint, nn.Module):
    checkpoint = checkpoint.state_dict()

  if isinstance(checkpoint, dict) and checkpoint and all(
      isinstance(k, str) and k.startswith('module.') for k in checkpoint):
    ## In-place; also rewrites the `_metadata` entries load_state_dict relies on.
    consume_prefix_in_state_dict_if_present(checkpoint, 'module.')
  return checkpoint


def loadmodel(args):
  """Build the network described by ``args``, optionally load custom weights, and move it to a device.

  Steps: ``get_torch_network(args)`` -> ``modify_output(args, net)`` ->
  optionally load ``args.weights_path`` with ``strict=False`` -> ``net.to(device)``.

  Args:
    args: needs ``gpu`` (truthy to use CUDA when it is available) and
      ``weights_path`` (falsy for none), plus whatever ``get_torch_network``
      and ``modify_output`` read.

  Returns:
    The model, on 'cuda' or 'cpu'.
  """
  start = time.time()
  device = 'cuda' if args.gpu and torch.cuda.is_available() else 'cpu'
  weights_path = args.weights_path or None

  print(torch.version.cuda)
  print(f'loadmodel::device: {device}')
  print(f'loadmodel::args.gpu: {args.gpu}')
  print(f'loadmodel::weights_path: {args.weights_path}')

  ## Build and modify the model
  net = get_torch_network(args)
  net = modify_output(args, net)

  ## If weights_path is provided, load the custom weights
  if weights_path:
    print(f"torch_loadmodel.loadmodel::weights_path '{weights_path}'")
    ## NOTE(security): torch.load unpickles the file unless weights_only=True is
    ## in effect (the default only in newer torch releases). Only load trusted checkpoints.
    checkpoint = torch.load(weights_path, map_location=device)
    state_dict = _extract_state_dict(checkpoint)
    ## strict=False tolerates partial matches, but size mismatches still raise.
    missing_keys, unexpected_keys = net.load_state_dict(state_dict, strict=False)
    if missing_keys:
      print(f"Missing keys in the loaded state_dict: {missing_keys}")
    if unexpected_keys:
      print(f"Unexpected keys in the loaded state_dict: {unexpected_keys}")

    ## With strict=False a completely mismatched checkpoint loads nothing without
    ## raising, so do not report success in that case.
    if state_dict and len(unexpected_keys) == len(state_dict):
      print(f"[!] Warning: none of the {len(state_dict)} keys in {weights_path} match the model; weights were NOT loaded")
    else:
      print(f"Custom model weights successfully loaded from {weights_path}")
  else:
    print("No custom weights provided. Using the default model.")

  net.to(device)

  finish = time.time()
  print(f'Epoch loadmodel time consumed: {finish - start:.2f}s')
  return net


def preprocessor(im, height=224, width=224, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), device='cuda'):
  """Turn an image into a normalised, batched float tensor for a model.

  Args:
    im: a ``PIL.Image.Image`` or a ``numpy.ndarray``. Arrays are wrapped with
      ``Image.fromarray`` and their channel order is used as-is, so OpenCV BGR
      arrays must be converted to RGB by the caller.
    height, width: output size passed to ``transforms.Resize``.
    mean, std: per-channel normalisation statistics (ImageNet values by default).
      Tuples are used as defaults to avoid shared mutable default arguments;
      any sequence is accepted.
    device: 'cuda' moves the tensor to the GPU when CUDA is available.
      Any other value, or no CUDA, keeps it on the CPU.

  Returns:
    A tensor of shape (1, C, height, width). C is 3 whenever 3-value ``mean``
    statistics are used.
  """
  if isinstance(im, np.ndarray):
    # e.g. im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB)) for OpenCV input
    im = Image.fromarray(im)

  ## Normalising with 3-channel statistics requires a 3-channel image. Convert
  ## grayscale / RGBA / palette / CMYK input, which previously made Normalize fail.
  if isinstance(im, Image.Image) and len(mean) == 3 and im.mode != 'RGB':
    im = im.convert('RGB')

  transform = transforms.Compose([
    transforms.Resize((height, width)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
  ])
  ## Add batch dimension
  batch = transform(im).unsqueeze(0)

  use_device = 'cuda' if device == 'cuda' and torch.cuda.is_available() else 'cpu'
  return batch.to(use_device)


def predict(model, input_tensor, labels):
  """Classify the first item of a batch and return its label with its softmax probability.

  Args:
    model: callable mapping ``input_tensor`` to unnormalised class scores,
      indexed as ``scores[0]`` for the first batch item.
    input_tensor: model input (see ``preprocessor``).
    labels: indexable by class index, e.g. a list of class names.

  Returns:
    ``(label, confidence)``: the top class's label and its probability as a float.
  """
  with torch.no_grad():
    scores = model(input_tensor)

  ## Softmax over the class scores of the first batch item turns them into probabilities.
  probs = scores[0].softmax(dim=0)
  top_prob, top_index = probs.max(dim=0)
  return labels[top_index.item()], top_prob.item()


def predict_deprecated(model, input_tensor):
  """Older prediction helper: return only the arg-max class index.

  Expects a single-item batch, because the result is converted with ``.item()``.
  """
  with torch.no_grad():
    scores = model(input_tensor)
  best_index = torch.max(scores, dim=1).indices
  return best_index.item()


def inference(weights_path):
  """Placeholder for an end-to-end inference entry point; not implemented yet.

  Args:
    weights_path: currently unused.

  Returns:
    None.
  """
  return None
