import torch
from pytorch_grad_cam.base_cam import BaseCAM
import numpy as np
from tensorly.decomposition import tucker


class Stats(BaseCAM):
    """Pseudo-CAM that records singular-value statistics of layer activations.

    Instead of producing a saliency map, every call to :meth:`get_cam_image`
    computes singular values of the activation batch, appends them to
    ``self.singular_values``, and returns an all-zero map of the expected
    shape.

    Args:
        model: Passed through to ``BaseCAM``.
        target_layers: Passed through to ``BaseCAM``.
        use_cuda: Accepted for signature compatibility; NOT forwarded to
            ``BaseCAM`` (see note in ``__init__``).
        reshape_transform: Passed through to ``BaseCAM``.
        tensor: If True, use the Tucker-core based values
            (:meth:`get_tensor_singular_values`); otherwise use a per-sample
            matrix SVD (:meth:`get_svd_singular_values`).
    """

    def __init__(self, model, target_layers, use_cuda=False,
                 reshape_transform=None, tensor=True):
        # NOTE(review): `use_cuda` is not passed to BaseCAM, so it currently
        # has no effect. Confirm whether the installed BaseCAM still accepts
        # it before wiring it through.
        super(Stats, self).__init__(model=model,
                                    target_layers=target_layers,
                                    reshape_transform=reshape_transform,
                                    uses_gradients=False)

        # One array is appended per get_cam_image call; never cleared here.
        self.singular_values = []
        self.tensor = tensor

    def get_cam_image(self, input_tensor, target_layer, target_category,
                      activation_batch, grads, eigen_smooth):
        """Record singular values of ``activation_batch`` and return a blank CAM.

        Only ``activation_batch`` is used; the other parameters exist to match
        the call made by ``BaseCAM``.

        Returns:
            ``np.ndarray`` of zeros with shape
            ``(activation_batch.shape[0], activation_batch.shape[-2],
            activation_batch.shape[-1])``.
        """
        if self.tensor:
            singular_values = self.get_tensor_singular_values(activation_batch)
        else:
            singular_values = self.get_svd_singular_values(activation_batch)
        self.singular_values.append(singular_values)

        size = (activation_batch.shape[0],
                activation_batch.shape[-2],
                activation_batch.shape[-1])
        return np.zeros(size)

    def get_tensor_singular_values(self, activation_batch):
        """Per-sample Frobenius norms of full-rank Tucker-core slices.

        For each sample ``t`` in ``activation_batch`` (presumably
        ``(channels, height, width)`` — TODO confirm for reshape_transform
        outputs), a Tucker decomposition with ``rank == t.shape`` is computed
        and the Frobenius norm over the last two axes of the core is taken,
        i.e. one value per ``core[i]`` slice.

        Args:
            activation_batch: numpy array whose first axis is the batch.

        Returns:
            ``np.ndarray`` stacking the per-sample norm vectors.
        """
        singular_values = []
        # Iterate the numpy batch directly; the former torch round-trip
        # produced the very same per-sample arrays.
        for feature_map in np.asarray(activation_batch):
            core, _ = tucker(feature_map, rank=list(feature_map.shape))
            singular_values.append(
                np.linalg.norm(core, axis=(-2, -1), ord='fro'))

        return np.array(singular_values)

    def get_svd_singular_values(self, activation_batch):
        """Singular values of each sample's mean-centred activation matrix.

        Each sample is reshaped to ``(C, L)`` (``C = shape[1]``, ``L`` = the
        product of the remaining trailing dims) and transposed to ``(L, C)``.
        Every column (channel) is then centred by subtracting its mean over
        the ``L`` positions, and the singular values are computed.

        Args:
            activation_batch: numpy array with at least 3 dims, batch first.

        Returns:
            ``np.ndarray`` of shape ``(batch, min(L, C))``; each row is in
            descending order.
        """
        feature_maps = torch.from_numpy(activation_batch)
        # reshape (not view): view raises on non-contiguous tensors, which
        # from_numpy yields for strided/transposed numpy inputs.
        feature_maps = feature_maps.reshape(
            feature_maps.shape[0], feature_maps.shape[1], -1).permute(0, 2, 1)
        feature_maps = feature_maps - feature_maps.mean(dim=1, keepdim=True)

        # Only the singular values are used. svdvals skips computing U/V,
        # whereas full_matrices=True built an (L x L) U for every sample.
        sigma = torch.linalg.svdvals(feature_maps)
        return sigma.numpy()

    def __del__(self):
        # NOTE(review): this no-op overrides any __del__ inherited from
        # BaseCAM. In pytorch-grad-cam, that finalizer presumably releases
        # the activation hooks registered on the model. If so, the hooks stay
        # attached after this object is collected. Confirm this is
        # intentional.
        pass
