import torch
from pytorch_grad_cam.base_cam import BaseCAM
import numpy as np

class EigenCamMultiVec(BaseCAM):
    """Gradient-free CAM that uses every principal component of the activations.

    The activations of each batch item are treated as a (pixels x channels)
    matrix, centred over pixels, and decomposed with an SVD. The item is
    projected onto every right singular vector. Each projection is weighted
    by its singular value, normalised so the largest is 1. The absolute
    values are then averaged over components to give one map per item.
    """

    def __init__(self, model, target_layers, use_cuda=False,
                 reshape_transform=None):
        # ``use_cuda`` is accepted only so existing callers keep working; it is
        # not forwarded to BaseCAM and has no effect in this class.
        super().__init__(model=model,
                         target_layers=target_layers,
                         reshape_transform=reshape_transform,
                         uses_gradients=False)

    def get_cam_image(self, input_tensor, target_layer, target_category,
                      activation_batch, grads, eigen_smooth):
        """Return one multi-component EigenCAM map per batch item.

        Only ``activation_batch`` is read; the other parameters exist to match
        the signature BaseCAM calls this method with.

        Args:
            activation_batch: NumPy array whose first axis is the batch, second
                axis the channels and last two axes spatial (it is reshaped to
                ``(batch, channels, -1)``, so presumably ``(B, C, H, W)``).

        Returns:
            NumPy array of shape ``(batch, H, W)`` with the same dtype as the
            activations. An item whose centred activations are all zero
            yields an all-zero map rather than NaN.
        """
        batch = activation_batch.shape[0]
        channels = activation_batch.shape[1]
        height, width = activation_batch.shape[-2], activation_batch.shape[-1]

        # reshape (not view) so stride-incompatible / non-contiguous
        # activations are accepted; result is (B, H*W, C).
        feature_maps = torch.from_numpy(activation_batch).reshape(
            batch, channels, -1).permute(0, 2, 1)
        # Centre each channel over the spatial positions (PCA with pixels as
        # samples and channels as features).
        feature_maps = feature_maps - feature_maps.mean(dim=1, keepdim=True)

        # Reduced SVD: only sigma (B, k) and the k right singular vectors
        # (B, k, C) are needed, with k = min(H*W, C). full_matrices=True would
        # also materialise an unused (H*W, H*W) U per item.
        _, sigma, vt = torch.linalg.svd(feature_maps, full_matrices=False)

        # Normalise so the largest singular value per item is 1. Clamping the
        # denominator avoids 0/0 = NaN when every singular value is zero.
        sigma_max = sigma.max(dim=1, keepdim=True).values
        sigma = sigma / sigma_max.clamp_min(torch.finfo(sigma.dtype).tiny)

        # Project onto all components at once: (B, HW, C) @ (B, C, k) -> (B, HW, k).
        projections = torch.bmm(feature_maps, vt.transpose(1, 2))
        # Weight each component by its normalised singular value; abs() makes
        # the arbitrary sign of each singular vector irrelevant.
        weighted = (projections * sigma[:, None, :]).abs()
        cam = weighted.mean(dim=2).reshape(batch, height, width)
        return cam.numpy()

    def __del__(self):
        # Delegate to BaseCAM's finalizer (in pytorch_grad_cam it releases the
        # hooks registered on the target layers). A no-op override would leave
        # those hooks attached to the model after this object is collected.
        parent_del = getattr(super(), "__del__", None)
        if parent_del is None:
            return
        try:
            parent_del()
        except Exception:
            # Exceptions cannot usefully propagate out of a finalizer. This
            # mainly covers a BaseCAM.__init__ that failed before its hooks
            # existed, and module teardown at interpreter exit.
            pass