from typing import Tuple, Union, List
import numpy as np
from scipy.ndimage.interpolation import zoom

import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import torch

from .utils import norm_cam

def filter_bound(max_idx, h, w):
    """Flag flat spatial indices that are not on the border of an ``h`` x ``w`` grid.

    ``max_idx`` holds row-major flattened positions (``row * w + col``). The
    result is computed element-wise and is True wherever the position lies
    neither on the first/last row nor on the first/last column.
    """
    row = max_idx // w
    col = max_idx % w
    away_from_top_bottom = (row != 0) & (row != (h - 1))
    away_from_left_right = (col != 0) & (col != (w - 1))
    return away_from_top_bottom & away_from_left_right

class TraceTopk:
    """Class-activation maps built by tracing a layer's top-k channels back to the input.

    For each image, every channel of the hooked layer is ranked by its peak
    activation; channels whose peak lies on the feature-map border are not
    eligible. For each of the top-k channels, an impulse of size ``amp`` is
    placed at that channel's peak location and passed to ``model.trace_back``.
    The response is thresholded into a binary mask, and the masks are summed,
    weighted by the channels' peak activations.

    The model must provide ``trace(bool)`` and ``trace_back(x, module_name)``
    (defined outside this file).
    """

    def __init__(self, topk_num=64, scale=3, amp=1000, **kwargs):
        # Number of channels (ranked by peak activation) traced back per image;
        # clamped to the layer's channel count at call time.
        self.topk_num = topk_num
        # A traced pixel enters a channel's mask when its response exceeds
        # ``scale`` times that image's mean response.
        self.scale = scale
        # Value of the impulse injected at each selected channel's peak location.
        self.amp = amp

    def _register_model(self, model: nn.Module, layer_name: Union[str, List]):
        """Attach a forward hook to ``model.<layer_name>`` that records its output.

        Resets ``self.data_hidden`` to an empty list that the hook appends to.

        Returns:
            The hook's removable handle (also stored as ``self._hook_handle``),
            so the caller can detach the hook once it has what it needs.
        """
        self.data_hidden = []

        def hook(module, input, output):
            self.data_hidden.append(output)

        # NOTE(review): eval allows paths with indexing such as "layer4[-1]";
        # layer_name must never come from untrusted input.
        layer: nn.Module = eval(f"model.{layer_name}")
        self._hook_handle = layer.register_forward_hook(hook)
        return self._hook_handle

    def __call__(self, model: nn.Module, img: Union[np.ndarray, Tensor], layer_name: Union[str, List]) -> dict:
        """Compute the trace-back CAM for a batch of images.

        Args:
            model (nn.Module): Pre-trained model exposing ``trace`` and ``trace_back``.
            img (Tensor): Input batch, passed directly to ``model``.
            layer_name (str): Attribute path of the target module, evaluated as
                ``model.<layer_name>``.

        Returns:
            dict: ``"score"`` is the model output, ``"cam"`` is ``norm_cam`` of
            the weighted mask sum, and ``"pred"`` is the argmax of the output
            over its last dimension.

        Raises:
            IndexError: if the target layer did not run during the forward pass
                (the hook recorded nothing).
        """
        topk_num = self.topk_num
        scale = self.scale

        handle = self._register_model(model, layer_name)
        model.trace(True)
        try:
            output = model(img)
            # Only the first forward output is needed. Detach the hook now so it
            # neither stacks up over repeated calls nor fires during trace_back.
            handle.remove()
            pred = torch.argmax(output, dim=-1)

            sample = self.data_hidden[0]
            bs, ch_n, h, w = sample.size()
            # Build index tensors on the activation's device.
            batch_idx = torch.arange(bs, device=sample.device)

            # Peak value and flat spatial index of every channel. reshape (not
            # view) tolerates non-contiguous layouts such as channels_last.
            chm_data, max_idx = sample.detach().reshape(bs, ch_n, -1).max(dim=-1)
            # Channels whose peak sits on the feature-map border are ineligible.
            mask_bound = filter_bound(max_idx, h, w)
            k = min(topk_num, ch_n)
            ranked = chm_data.masked_fill(~mask_bound, float("-inf"))
            chm_topk_v, chm_topk = ranked.topk(k=k, dim=-1, largest=True)
            # When fewer than k channels are eligible, ineligible ones fill the
            # tail of the top-k. Give them zero weight rather than a sentinel value.
            chm_topk_v = chm_topk_v.masked_fill(~mask_bound.gather(1, chm_topk), 0)

            with torch.no_grad():
                # Response to an all-zero input. It is subtracted below so that
                # only the impulse's contribution remains.
                base_zeros = model.trace_back(torch.zeros_like(sample), layer_name)
                recon_img_list = []
                for idx in range(k):
                    # One impulse of value amp per image, at the peak location of
                    # that image's idx-th ranked channel.
                    tmp_data = sample.new_zeros(bs, ch_n, h * w)
                    cur_chidx = chm_topk[:, idx]
                    cur_spidx = max_idx[batch_idx, cur_chidx]
                    tmp_data[batch_idx, cur_chidx, cur_spidx] = self.amp

                    tmp_recon = model.trace_back(x=tmp_data.view(bs, ch_n, h, w), module_name=layer_name)
                    # Binary mask: summed absolute response above scale x the
                    # image's own mean. The mean is per sample so images in a
                    # batch do not affect each other's threshold.
                    tmp_recon = (tmp_recon - base_zeros).abs().sum(dim=1)
                    spatial_dims = tuple(range(1, tmp_recon.dim()))
                    thresh = scale * tmp_recon.mean(dim=spatial_dims, keepdim=True)
                    recon_img_list.append((tmp_recon > thresh).float())
                recon_img = torch.stack(recon_img_list, dim=1)
                # Sum of the k masks, each weighted by its channel's peak activation.
                cam = torch.einsum("ij,ijmn->imn", chm_topk_v, recon_img).detach()
        finally:
            handle.remove()  # no-op if already removed after the forward pass
            model.trace(False)

        return {
            "score": output,
            "cam": norm_cam(cam),
            "pred": pred
        }
