import torch
from pytorch_grad_cam.base_cam import BaseCAM
import numpy as np
from tensorly.decomposition import tucker

class TSM(BaseCAM):
    """Gradient-free CAM driven by a Tucker decomposition of the activations.

    For every sample in the activation batch a full-rank Tucker
    decomposition is computed; the first row of the first factor matrix is
    then used as per-channel weights for the feature maps.
    """

    def __init__(self, model, target_layers, use_cuda=False,
                 reshape_transform=None):
        """Set up the CAM without gradient hooks.

        Args:
            model: Forwarded unchanged to ``BaseCAM``.
            target_layers: Forwarded unchanged to ``BaseCAM``.
            use_cuda: Accepted for signature compatibility; this class
                neither reads it nor forwards it to ``BaseCAM``.
            reshape_transform: Forwarded unchanged to ``BaseCAM``.
        """
        # Bind by keyword so reshape_transform cannot land in a different
        # positional slot of BaseCAM.__init__.
        # NOTE(review): some BaseCAM releases appear to declare `use_cuda`
        # as the third positional parameter -- confirm against the
        # installed version.
        super(TSM, self).__init__(model=model,
                                  target_layers=target_layers,
                                  reshape_transform=reshape_transform,
                                  uses_gradients=False)

    def get_cam_image(self, input_tensor, target_layer, target_category, activation_batch, grads, eigen_smooth):
        """Compute one saliency map per sample of ``activation_batch``.

        Only ``activation_batch`` is read; the remaining parameters are
        unused here.

        Args:
            activation_batch: NumPy array of activations, iterated along
                its first axis. Assumed to be laid out as
                (batch, channels, height, width) -- TODO confirm with the
                caller.

        Returns:
            NumPy array holding the absolute value of the weighted feature
            maps averaged over axis 1.
        """
        feature_maps = torch.from_numpy(activation_batch)

        # One full-rank Tucker decomposition per sample (rank == the
        # sample's own shape); only the first factor matrix is kept.
        VT = []
        for feature_map in feature_maps:
            t = feature_map.numpy()
            _, factors = tucker(t, rank=list(t.shape))
            VT.append(factors[0])

        VT = torch.from_numpy(np.array(VT))

        # Row 0 of each sample's factor matrix, with two trailing singleton
        # axes appended so it broadcasts against the feature maps.
        channel_weights = VT[:, 0].unsqueeze(-1).unsqueeze(-1)
        cam = (feature_maps * channel_weights).mean(dim=1)
        return cam.abs().numpy()

    def __del__(self):
        # Deliberate no-op finalizer.
        # NOTE(review): this overrides whatever BaseCAM.__del__ does; if the
        # base class releases hooks there, that cleanup is skipped --
        # confirm this is intended.
        pass
    
class MTSM(BaseCAM):
    """Gradient-free CAM that blends every Tucker component of the activations.

    Each sample is decomposed with a full-rank Tucker decomposition. Every
    row of the first factor matrix yields one projection of the feature
    maps; the projections are weighted by the normalised Frobenius norm of
    the matching core slice and averaged.
    """

    def __init__(self, model, target_layers, use_cuda=False,
                 reshape_transform=None):
        """Set up the CAM without gradient hooks.

        Args:
            model: Forwarded unchanged to ``BaseCAM``.
            target_layers: Forwarded unchanged to ``BaseCAM``.
            use_cuda: Accepted for signature compatibility; this class
                neither reads it nor forwards it to ``BaseCAM``.
            reshape_transform: Forwarded unchanged to ``BaseCAM``.
        """
        super(MTSM, self).__init__(model=model,
                                   target_layers=target_layers,
                                   reshape_transform=reshape_transform,
                                   uses_gradients=False)

    def get_cam_image(self, input_tensor, target_layer, target_category, activation_batch, grads, eigen_smooth):
        """Compute one saliency map per sample of ``activation_batch``.

        Only ``activation_batch`` is read; the remaining parameters are
        unused here.

        Args:
            activation_batch: NumPy array of activations, iterated along
                its first axis. The 4-axis transpose below implies a
                4-D input, presumably (batch, channels, height, width) --
                TODO confirm the axis meaning with the caller.

        Returns:
            NumPy array with the component axis averaged away, i.e. one map
            per sample.
        """
        feature_maps = torch.from_numpy(activation_batch)

        # Per sample: keep the first factor matrix and the Frobenius norm of
        # each core slice taken over the last two axes.
        VT = []
        singular_values = []
        for feature_map in feature_maps:
            t = feature_map.numpy()
            core, factors = tucker(t, rank=list(t.shape))
            VT.append(factors[0])
            singular_values.append(
                np.linalg.norm(core, axis=(-2, -1), ord='fro'))

        VT = torch.from_numpy(np.array(VT))
        singular_values = torch.from_numpy(np.array(singular_values))

        # Normalise each sample's norms by its largest one. A sample whose
        # norms are all zero (e.g. an all-zero activation) would otherwise
        # produce 0/0 = NaN and poison the whole map, so divide by 1 there.
        peak = singular_values.max(dim=1, keepdim=True).values
        safe_peak = torch.where(peak > 0, peak, torch.ones_like(peak))
        singular_values = singular_values / safe_peak

        cams = []
        for i in range(singular_values.shape[1]):
            # Row i of every sample's factor matrix, broadcast over the
            # two trailing axes of the feature maps.
            component = VT[:, i].unsqueeze(-1).unsqueeze(-1)
            projection = (feature_maps * component).mean(dim=1)
            weight = singular_values[:, i].unsqueeze(-1).unsqueeze(-1)
            projection = projection * weight
            cams.append(projection.detach().abs().numpy())

        # Stacked as (component, batch, ...); move batch to the front and
        # average over the components.
        cam = np.array(cams)
        cam = np.transpose(cam, axes=(1, 0, 2, 3))
        cam = cam.mean(1)

        return cam

    def __del__(self):
        # Deliberate no-op finalizer.
        # NOTE(review): this overrides whatever BaseCAM.__del__ does; if the
        # base class releases hooks there, that cleanup is skipped --
        # confirm this is intended.
        pass