import importlib
from functools import partial
from typing import Dict, Tuple

import numpy as np
import torch
import torch.nn as nn


class AttributionGenerator:
    """Wrap a configurable attribution function resolved from a dotted import path.

    The configuration dict names a callable by its fully qualified path and,
    optionally, extra keyword arguments that are bound to it up front. The
    bound callable is later invoked with the model, the preprocessed image,
    the label and the shape-check flag as keyword arguments.
    """

    def __init__(self, attribution_method: Dict) -> None:
        """Resolve and bind the attribution function described by ``attribution_method``.

        Args:
            attribution_method: Mapping with a required ``'method'`` key holding
                a dotted path such as ``'package.module.function'``, and an
                optional ``'kwargs'`` mapping of extra keyword arguments bound
                to that function. A missing or ``None`` ``'kwargs'`` entry is
                treated as "no extra arguments".

        Raises:
            KeyError: If ``'method'`` is missing from ``attribution_method``.
            ValueError: If ``'method'`` is not of the form ``'module.attribute'``.
            ImportError: If the module part cannot be imported.
            AttributeError: If the module has no attribute with the given name.
            TypeError: If the resolved attribute is not callable.
        """
        method_path = attribution_method['method']
        module_name, _, func_name = method_path.rpartition('.')
        # 'func', 'pkg.' and '.func' all yield an empty part here; reject them
        # up front instead of failing later with an opaque unpack/import error.
        if not module_name or not func_name:
            raise ValueError(
                f"attribution method must be a dotted path 'module.function', got {method_path!r}")

        module_obj = importlib.import_module(module_name)
        try:
            fn = getattr(module_obj, func_name)
        except AttributeError as err:
            raise AttributeError(
                f"module {module_name!r} has no attribute {func_name!r}") from err
        if not callable(fn):
            raise TypeError(f"attribution method {method_path!r} is not callable")

        # `or {}` also covers an explicit `kwargs: null` entry (e.g. from YAML),
        # which would otherwise make `**kwargs` raise TypeError.
        kwargs = attribution_method.get('kwargs') or {}
        self.attribution_fn = partial(fn, **kwargs)

    def generate_attribution(
            self, model: nn.Module, preprocessed_image: torch.Tensor, label: torch.Tensor,
            need_check_shape: bool) -> Tuple[np.ndarray, np.ndarray]:
        """Run the configured attribution function on one input.

        All four arguments are forwarded as keyword arguments on top of the
        kwargs bound from the configuration. The target function must therefore
        accept parameters named ``model``, ``preprocessed_image``, ``label`` and
        ``need_check_shape``. If a configured kwarg has the same name as one of
        these arguments, the value passed here wins.

        Returns:
            Whatever the configured function returns. The annotation expects a
            pair of NumPy arrays, but this class does not check that.
        """
        return self.attribution_fn(
            model=model, preprocessed_image=preprocessed_image, label=label, need_check_shape=need_check_shape)
