"""Explainability module.

@adapted-from: https://github.com/jacobgil/pytorch-grad-cam; Class Activation Map methods implemented in Pytorch for classification, segmentation, object detection and more
@credit: Jacob Gildenblat, https://github.com/jacobgil
@license: https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/refs/heads/master/LICENSE
"""
__author__ = 'XYZ'

import pdb
import os

from importlib import import_module
from pathlib import Path

import cv2
import numpy as np

from tqdm import tqdm

from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget, ClassifierOutputReST
from pytorch_grad_cam import GuidedBackpropReLUModel
from pytorch_grad_cam.utils.image import (
    show_cam_on_image,
    deprocess_image,
    preprocess_image
)

from ..core._log_ import logger
# Module-level logger obtained from the project's ``logger`` factory,
# built from this file's path (``__file__``).
log = logger(__file__)


def get_cam_methods():
  """Return the names of the supported CAM methods.

  Each entry is the name of a class expected to be exported by the
  ``pytorch_grad_cam`` package. A fresh list is built on every call, so
  the caller may modify the result without affecting later calls.
  """
  # NOTE: "ScoreCAM" has been observed to hit out-of-memory (OOM) errors.
  supported = (
    "GradCAM", "FEM", "HiResCAM", "ScoreCAM",
    "GradCAMPlusPlus", "AblationCAM", "XGradCAM",
    "EigenCAM", "EigenGradCAM", "LayerCAM",
    "FullGrad", "GradCAMElementWise", "KPCA_CAM",
    "ShapleyCAM",
  )
  return list(supported)


def _resolve_targets(target_class):
  """Normalize ``target_class`` into ``(cam_targets, gb_target_category)``.

  A bare integer class index (``bool`` excluded) becomes
  ``[ClassifierOutputTarget(index)]`` for the CAM call and ``index`` for the
  guided-backprop model, so both explain the same class. Any other value
  (``None`` or a list of pytorch-grad-cam target objects) is returned
  unchanged for the CAM call, and guided backprop keeps its default
  ``target_category=None``.
  """
  if isinstance(target_class, (int, np.integer)) and not isinstance(target_class, bool):
    index = int(target_class)
    return [ClassifierOutputTarget(index)], index
  return target_class, None


def _process(
  fn_cam_class,
  model,
  target_layers,
  rgb_img,
  input_tensor,
  target_class,
  aug_smooth=False,
  eigen_smooth=False,
  batch_size=16,
  device='cuda',
):
  """Compute a CAM overlay, a guided-backprop map and their product for one input.

  Args:
    fn_cam_class: CAM class (e.g. ``pytorch_grad_cam.GradCAM``); it is
      instantiated as ``fn_cam_class(model=..., target_layers=...)`` and used
      as a context manager.
    model: Network being explained.
    target_layers: Layers handed to the CAM class.
    rgb_img: Image the heatmap is overlaid on via ``show_cam_on_image``
      (presumably an HxWx3 float RGB array in [0, 1], as ``process_cam``
      builds it -- TODO confirm for other callers).
    input_tensor: Preprocessed model input corresponding to ``rgb_img``.
    target_class: Either ``None`` / a list of pytorch-grad-cam target
      objects (forwarded unchanged as ``targets``), or a single integer
      class index, which is wrapped in ``ClassifierOutputTarget`` and also
      used as the guided-backprop target category.
    aug_smooth: Forwarded to the CAM call.
    eigen_smooth: Forwarded to the CAM call.
    batch_size: Assigned to ``cam.batch_size``.
    device: Passed to ``GuidedBackpropReLUModel``.

  Returns:
    Tuple ``(cam_image, gb, cam_gb)``: the CAM overlay converted to BGR
    (ready for ``cv2.imwrite``), the deprocessed guided-backprop map, and the
    deprocessed CAM-masked guided-backprop map.
  """
  targets, gb_category = _resolve_targets(target_class)

  with fn_cam_class(model=model, target_layers=target_layers) as cam:
    cam.batch_size = batch_size
    grayscale_cam = cam(
      input_tensor=input_tensor,
      targets=targets,
      aug_smooth=aug_smooth,
      eigen_smooth=eigen_smooth,
    )
    # Keep the map of the first image in the batch.
    grayscale_cam = grayscale_cam[0, :]

    # The overlay is produced with use_rgb=True; cv2.imwrite expects BGR.
    cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
    cam_image = cv2.cvtColor(cam_image, cv2.COLOR_RGB2BGR)

  gb_model = GuidedBackpropReLUModel(model=model, device=device)
  gb = gb_model(input_tensor, target_category=gb_category)

  # Stack the single-channel CAM into 3 channels so it can mask ``gb``
  # elementwise (assumes gb is HxWx3 -- TODO confirm).
  cam_mask = cv2.merge([grayscale_cam, grayscale_cam, grayscale_cam])
  cam_gb = deprocess_image(cam_mask * gb)
  gb = deprocess_image(gb)

  return cam_image, gb, cam_gb


def process_cam(
  model,
  target_layers,
  source,
  methods,
  target_class,
  aug_smooth=False,
  eigen_smooth=False,
  mean=(0.485, 0.456, 0.406),
  std=(0.229, 0.224, 0.225),
  batch_size=16,
  device='cuda',
  to_path='logs',
):
  """Run every requested CAM method over every image and save the outputs.

  For each method name a directory ``<to_path>/<method>`` is created. For
  each image, three JPEGs are written there: ``<stem>-cam.jpg``,
  ``<stem>-gb.jpg`` and ``<stem>-cam_gb.jpg``.

  Args:
    model: Network being explained; forwarded to ``_process``.
    target_layers: Forwarded to ``_process``.
    source: Image paths. Any iterable is accepted and materialized once, so
      generators (e.g. ``Path.glob``) are processed by every method. A
      single ``str`` / path-like value is treated as one image.
    methods: Names of classes exported by ``pytorch_grad_cam``.
    target_class: Forwarded to ``_process``.
    aug_smooth: Forwarded to ``_process``.
    eigen_smooth: Forwarded to ``_process``.
    mean: Per-channel normalization mean passed to ``preprocess_image``
      (defaults are the widely used ImageNet values).
    std: Per-channel normalization std passed to ``preprocess_image``.
    batch_size: Forwarded to ``_process``.
    device: Device the input tensor is moved to; forwarded to ``_process``.
    to_path: Root output directory.

  Raises:
    AttributeError: If a name in ``methods`` is not an attribute of
      ``pytorch_grad_cam``.

  Failures on an individual image are logged and skipped, so one bad image
  does not abort the run.
  """
  # Materialize once: a one-shot iterator would otherwise be exhausted
  # after the first CAM method.
  if isinstance(source, (str, os.PathLike)):
    image_paths = [source]
  else:
    image_paths = list(source)

  # The outer loop runs over CAM methods (not over model names).
  mod = import_module('pytorch_grad_cam')

  for cam_method in methods:
    log.info(f"Running CAM method: {cam_method}")
    fn_cam_class = getattr(mod, cam_method)

    basepath = Path(to_path) / cam_method
    basepath.mkdir(parents=True, exist_ok=True)
    for image_path in tqdm(image_paths, desc=f"{cam_method}"):
      try:
        bgr_img = cv2.imread(str(image_path), 1)
        if bgr_img is None:
          # cv2.imread reports missing/unreadable files by returning None.
          log.error(f"Error processing {image_path} with {cam_method}: could not read image")
          continue

        # Reverse the channel axis (BGR -> RGB) and scale to [0, 1].
        rgb_img = np.float32(bgr_img[:, :, ::-1]) / 255
        input_tensor = preprocess_image(rgb_img, mean=mean, std=std).to(device)

        cam_image, gb, cam_gb = _process(
          fn_cam_class,
          model,
          target_layers,
          rgb_img,
          input_tensor,
          target_class,
          aug_smooth,
          eigen_smooth,
          batch_size,
          device,
        )

        filename = Path(image_path).stem
        cv2.imwrite(str(basepath / f'{filename}-cam.jpg'), cam_image)
        cv2.imwrite(str(basepath / f'{filename}-gb.jpg'), gb)
        cv2.imwrite(str(basepath / f'{filename}-cam_gb.jpg'), cam_gb)

      except Exception as e:
        # Best-effort per image: log and continue with the next one.
        log.error(f"Error processing {image_path} with {cam_method}: {e}")
