# # Grad-CAM for a single pixel of the last feature map
# import numpy as np

# from typing import Callable, List, Optional, Tuple

# import numpy as np
# import torch
# import ttach as tta

# from pytorch_grad_cam.utils.image import scale_cam_image
# from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
# from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection


# # class GradCAM_PIXEL(BaseCAM):
# #     def __init__(self, model, target_layers, reshape_transform=None):
# #         super(GradCAM_PIXEL, self).__init__(model, target_layers, reshape_transform)

# #     def get_cam_weights(self, input_tensor, target_layer, targets, activations, grads):
# #         # 2D image
# #         if len(grads.shape) == 4:
# #             return np.mean(grads, axis=(2, 3))
# #         # 3D image
# #         elif len(grads.shape) == 5:
# #             return np.mean(grads, axis=(2, 3, 4))
# #         else:
# #             raise ValueError("Invalid grads shape. Shape of grads should be 4 (2D image) or 5 (3D image).")

# #     def get_cam_image(self, input_tensor, target_layer, targets, activations, grads, eigen_smooth=False, feature_coords=None):
# #         if feature_coords is not None:
# #             # Compute Grad-CAM only for the specified feature point
# #             x, y = feature_coords
# #             activations = activations[:, :, x, y]
# #             grads = grads[:, :, x, y]

# #             # Restore the 4-D shape with np.expand_dims
# #             activations = np.expand_dims(np.expand_dims(activations, -1), -1)
# #             grads = np.expand_dims(np.expand_dims(grads, -1), -1)
# #         # import pdb
# #         # pdb.set_trace()
# #         weights = self.get_cam_weights(input_tensor, target_layer, targets, activations, grads)
# #         # 2D conv
# #         if len(activations.shape) == 4:
# #             weighted_activations = weights[:, :, None, None] * activations
# #         # 3D conv
# #         elif len(activations.shape) == 5:
# #             weighted_activations = weights[:, :, None, None, None] * activations
# #         else:
# #             raise ValueError(f"Invalid activation shape. Got {len(activations.shape)}.")

# #         if eigen_smooth:
# #             cam = get_2d_projection(weighted_activations)
# #         else:
# #             cam = weighted_activations.sum(axis=1)
# #         return cam

# #     def __call__(self, input_tensor, targets=None, aug_smooth=False, eigen_smooth=False, feature_coords=None):
# #         if aug_smooth:
# #             return self.forward_augmentation_smoothing(input_tensor, targets, eigen_smooth)

# #         return self.forward(input_tensor, targets, eigen_smooth, feature_coords=feature_coords)

# #     def forward(self, input_tensor, targets, eigen_smooth=False, feature_coords=None):
# #         input_tensor = input_tensor.to(self.device)

# #         if self.compute_input_gradient:
# #             input_tensor = torch.autograd.Variable(input_tensor, requires_grad=True)

# #         self.outputs = outputs = self.activations_and_grads(input_tensor)

# #         if targets is None:
# #             target_categories = np.argmax(outputs.cpu().data.numpy(), axis=-1)
# #             targets = [ClassifierOutputTarget(category) for category in target_categories]

# #         if self.uses_gradients:
# #             self.model.zero_grad()
# #             loss = sum([target(output) for target, output in zip(targets, outputs)])
# #             loss.backward(retain_graph=True)

# #         cam_per_layer = self.compute_cam_per_layer(input_tensor, targets, eigen_smooth, feature_coords=feature_coords)
# #         return self.aggregate_multi_layers(cam_per_layer)

# #     def compute_cam_per_layer(self, input_tensor, targets, eigen_smooth, feature_coords=None):
# #         activations_list = [a.cpu().data.numpy() for a in self.activations_and_grads.activations]
# #         grads_list = [g.cpu().data.numpy() for g in self.activations_and_grads.gradients]
# #         target_size = self.get_target_width_height(input_tensor)

# #         cam_per_target_layer = []
# #         for i in range(len(self.target_layers)):
# #             target_layer = self.target_layers[i]
# #             layer_activations = activations_list[i] if i < len(activations_list) else None
# #             layer_grads = grads_list[i] if i < len(grads_list) else None

# #             cam = self.get_cam_image(input_tensor, target_layer, targets, layer_activations, layer_grads, eigen_smooth, feature_coords)
# #             cam = np.maximum(cam, 0)
# #             scaled = scale_cam_image(cam, target_size)
# #             cam_per_target_layer.append(scaled[:, None, :])

# #         return cam_per_target_layer



# class ActivationsAndGradients_pixel:
#     """ Class for extracting activations and
#     registering gradients from targeted intermediate layers """

#     def __init__(self, model, target_layers, reshape_transform):
#         self.model = model
#         self.gradients = []
#         self.activations = []
#         self.reshape_transform = reshape_transform
#         self.handles = []
#         for target_layer in target_layers:
#             self.handles.append(
#                 target_layer.register_forward_hook(self.save_activation))
#             # Because of https://github.com/pytorch/pytorch/issues/61519,
#             # we don't use backward hook to record gradients.
#             self.handles.append(
#                 target_layer.register_forward_hook(self.save_gradient))

#     def save_activation(self, module, input, output):
#         activation = output

#         if self.reshape_transform is not None:
#             activation = self.reshape_transform(activation)
#         self.activations.append(activation.cpu().detach())

#     def save_gradient(self, module, input, output):
#         if not hasattr(output, "requires_grad") or not output.requires_grad:
#             # You can only register hooks on tensor requires grad.
#             return

#         # Gradients are computed in reverse order
#         def _store_grad(grad):
#             if self.reshape_transform is not None:
#                 grad = self.reshape_transform(grad)
#             self.gradients = [grad.cpu().detach()] + self.gradients

#         output.register_hook(_store_grad)

#     def __call__(self, x, feature_coords):
#         self.gradients = []
#         self.activations = []
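#         # The original code returned the model output directly here
#         # (presumably what the targets turn into a loss); adapted below for this model.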
#         # return self.model(x)

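#         # Main change: classify only the roi feature-map pixel selected by
#         # feature_coords = (row, col) instead of using the whole model output.
#         # NOTE(review): the flat index row * w + col only addresses batch
#         # element 0 of roi_flatten, so this effectively assumes batch size 1.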
#         _, roi = self.model(x)
#         bs, c, h, w = roi.size()
#         roi_flatten = roi.permute(0, 2, 3, 1).contiguous().view(-1, c)
#         coord_x, coord_y = feature_coords
#         index = coord_x * w + coord_y
#         pixel = roi_flatten[index, :]
#         pixel = torch.unsqueeze(pixel, dim=0)

#         scores = self.model.classifier(pixel)
#         pred = torch.argmax(scores, dim=1)
#         pred_label = pred.item()

#         print(f"in gradcam, pred is {pred_label}")

#         return scores

#     def release(self):
#         for handle in self.handles:
#             handle.remove()


# class BaseCAM_pixel:
#     def __init__(
#         self,
#         model: torch.nn.Module,
#         target_layers: List[torch.nn.Module],
#         reshape_transform: Callable = None,
#         compute_input_gradient: bool = False,
#         uses_gradients: bool = True,
#         tta_transforms: Optional[tta.Compose] = None,
#     ) -> None:
#         self.model = model.eval()
#         self.target_layers = target_layers

#         # Use the same device as the model.
#         self.device = next(self.model.parameters()).device
#         self.reshape_transform = reshape_transform
#         self.compute_input_gradient = compute_input_gradient
#         self.uses_gradients = uses_gradients
#         if tta_transforms is None:
#             self.tta_transforms = tta.Compose(
#                 [
#                     tta.HorizontalFlip(),
#                     tta.Multiply(factors=[0.9, 1, 1.1]),
#                 ]
#             )
#         else:
#             self.tta_transforms = tta_transforms

#         self.activations_and_grads = ActivationsAndGradients_pixel(self.model, target_layers, reshape_transform)

#     """ Get a vector of weights for every channel in the target layer.
#         CAM methods that compute per-channel weights
#         typically only need to implement this function. """

#     def get_cam_weights(
#         self,
#         input_tensor: torch.Tensor,
#         target_layers: List[torch.nn.Module],
#         targets: List[torch.nn.Module],
#         activations: torch.Tensor,
#         grads: torch.Tensor,
#     ) -> np.ndarray:
#         raise Exception("Not Implemented")

#     def get_cam_image(
#         self,
#         input_tensor: torch.Tensor,
#         target_layer: torch.nn.Module,
#         targets: List[torch.nn.Module],
#         activations: torch.Tensor,
#         grads: torch.Tensor,
#         eigen_smooth: bool = False,
#     ) -> np.ndarray:
#         weights = self.get_cam_weights(input_tensor, target_layer, targets, activations, grads)
#         # 2D conv
#         if len(activations.shape) == 4:
#             weighted_activations = weights[:, :, None, None] * activations
#         # 3D conv
#         elif len(activations.shape) == 5:
#             weighted_activations = weights[:, :, None, None, None] * activations
#         else:
#             raise ValueError(f"Invalid activation shape. Get {len(activations.shape)}.")

#         if eigen_smooth:
#             cam = get_2d_projection(weighted_activations)
#         else:
#             cam = weighted_activations.sum(axis=1)
#         return cam

#     def forward(
#         self, input_tensor: torch.Tensor, targets: List[torch.nn.Module], eigen_smooth: bool = False, feature_coords: tuple = None,
#     ) -> np.ndarray:
#         input_tensor = input_tensor.to(self.device)

#         if self.compute_input_gradient:
#             input_tensor = torch.autograd.Variable(input_tensor, requires_grad=True)
        
#         self.outputs = outputs = self.activations_and_grads(input_tensor, feature_coords)

#         if targets is None:
#             target_categories = np.argmax(outputs.cpu().data.numpy(), axis=-1)
#             targets = [ClassifierOutputTarget(category) for category in target_categories]

#         if self.uses_gradients:
#             self.model.zero_grad()
#             loss = sum([target(output) for target, output in zip(targets, outputs)])
#             loss.backward(retain_graph=True)

#         # In most of the saliency attribution papers, the saliency is
#         # computed with a single target layer.
#         # Commonly it is the last convolutional layer.
#         # Here we support passing a list with multiple target layers.
#         # It will compute the saliency image for every image,
#         # and then aggregate them (with a default mean aggregation).
#         # This gives you more flexibility in case you just want to
#         # use all conv layers for example, all Batchnorm layers,
#         # or something else.
#         cam_per_layer = self.compute_cam_per_layer(input_tensor, targets, eigen_smooth)
#         return self.aggregate_multi_layers(cam_per_layer)

#     def get_target_width_height(self, input_tensor: torch.Tensor) -> Tuple[int, int]:
#         if len(input_tensor.shape) == 4:
#             width, height = input_tensor.size(-1), input_tensor.size(-2)
#             return width, height
#         elif len(input_tensor.shape) == 5:
#             depth, width, height = input_tensor.size(-1), input_tensor.size(-2), input_tensor.size(-3)
#             return depth, width, height
#         else:
#             raise ValueError("Invalid input_tensor shape. Only 2D or 3D images are supported.")

#     def compute_cam_per_layer(
#         self, input_tensor: torch.Tensor, targets: List[torch.nn.Module], eigen_smooth: bool,
#     ) -> np.ndarray:
#         # import pdb
#         # pdb.set_trace()
#         activations_list = [a.cpu().data.numpy() for a in self.activations_and_grads.activations]
#         grads_list = [g.cpu().data.numpy() for g in self.activations_and_grads.gradients]
#         target_size = self.get_target_width_height(input_tensor)

#         cam_per_target_layer = []
#         # Loop over the saliency image from every layer
#         for i in range(len(self.target_layers)):
#             target_layer = self.target_layers[i]
#             layer_activations = None
#             layer_grads = None
#             if i < len(activations_list):
#                 layer_activations = activations_list[i]
#             if i < len(grads_list):
#                 layer_grads = grads_list[i]

#             cam = self.get_cam_image(input_tensor, target_layer, targets, layer_activations, layer_grads, eigen_smooth)
#             cam = np.maximum(cam, 0)
#             scaled = scale_cam_image(cam, target_size)
#             cam_per_target_layer.append(scaled[:, None, :])

#         return cam_per_target_layer

#     def aggregate_multi_layers(self, cam_per_target_layer: np.ndarray) -> np.ndarray:
#         cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1)
#         cam_per_target_layer = np.maximum(cam_per_target_layer, 0)
#         result = np.mean(cam_per_target_layer, axis=1)
#         return scale_cam_image(result)

#     def forward_augmentation_smoothing(
#         self, input_tensor: torch.Tensor, targets: List[torch.nn.Module], eigen_smooth: bool = False
#     ) -> np.ndarray:
#         cams = []
#         for transform in self.tta_transforms:
#             augmented_tensor = transform.augment_image(input_tensor)
#             cam = self.forward(augmented_tensor, targets, eigen_smooth)

#             # The ttach library expects a tensor of size BxCxHxW
#             cam = cam[:, None, :, :]
#             cam = torch.from_numpy(cam)
#             cam = transform.deaugment_mask(cam)

#             # Back to a numpy array of shape BxHxW
#             cam = cam.numpy()
#             cam = cam[:, 0, :, :]
#             cams.append(cam)

#         cam = np.mean(np.float32(cams), axis=0)
#         return cam

#     def __call__(
#         self,
#         input_tensor: torch.Tensor,
#         targets: List[torch.nn.Module] = None,
#         aug_smooth: bool = False,
#         eigen_smooth: bool = False,
#         feature_coords: tuple = None,
#     ) -> np.ndarray:
#         # Smooth the CAM result with test time augmentation
#         if aug_smooth is True:
#             return self.forward_augmentation_smoothing(input_tensor, targets, eigen_smooth)

#         return self.forward(input_tensor, targets, eigen_smooth, feature_coords)

#     def __del__(self):
#         self.activations_and_grads.release()

#     def __enter__(self):
#         return self

#     def __exit__(self, exc_type, exc_value, exc_tb):
#         self.activations_and_grads.release()
#         if isinstance(exc_value, IndexError):
#             # Handle IndexError here...
#             print(f"An exception occurred in CAM with block: {exc_type}. Message: {exc_value}")
#             return True



# class GradCAM_PIXEL(BaseCAM_pixel):
#     def __init__(self, model, target_layers,
#                  reshape_transform=None):
#         super(
#             GradCAM_PIXEL,
#             self).__init__(
#             model,
#             target_layers,
#             reshape_transform)

#     def get_cam_weights(self,
#                         input_tensor,
#                         target_layer,
#                         target_category,
#                         activations,
#                         grads):
#         # 2D image
#         if len(grads.shape) == 4:
#             return np.mean(grads, axis=(2, 3))
        
#         # 3D image
#         elif len(grads.shape) == 5:
#             return np.mean(grads, axis=(2, 3, 4))
        
#         else:
#             raise ValueError("Invalid grads shape." 
#                              "Shape of grads should be 4 (2D image) or 5 (3D image).")

