import torch
import numpy as np
from typing import List, Callable

# from pytorch_grad_cam.metrics.perturbation_confidence import PerturbationConfidenceMetric

class CosSimilarity:
    """Target callable that scores model outputs by cosine similarity.

    Calling the instance compares ``model_output`` with the stored
    ``features`` along dim 1 (eps 1e-8, the ``torch.nn.CosineSimilarity``
    defaults) and returns the resulting similarity tensor.
    """

    def __init__(self, features):
        # Reference features every model output is compared against.
        self.features = features

    def __call__(self, model_output):
        similarity = torch.nn.functional.cosine_similarity(
            model_output, self.features, dim=1, eps=1e-8)
        return similarity
    
class ImageFeatureExtractor(torch.nn.Module):
    """Image feature wrapper.

    Calling the module returns ``model.get_image_features(x)``. The wrapped
    model must provide a ``get_image_features`` method.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        # Implement ``forward`` rather than overriding ``__call__`` so that
        # nn.Module's call machinery (forward / pre-forward hooks) still runs
        # when the wrapper is invoked as ``wrapper(x)``.
        return self.model.get_image_features(x)

class BrainFeatureExtractor(torch.nn.Module):
    """Brain feature wrapper.

    Calling the module returns ``model.get_brain_features(x)``. The wrapped
    model must provide a ``get_brain_features`` method.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        # Implement ``forward`` rather than overriding ``__call__`` so that
        # nn.Module's call machinery (forward / pre-forward hooks) still runs
        # when the wrapper is invoked as ``wrapper(x)``.
        return self.model.get_brain_features(x)

class TextFeatureExtractor(torch.nn.Module):
    """Text feature wrapper.

    Calling the module returns ``model.get_text_features(x)``. The wrapped
    model must provide a ``get_text_features`` method.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        # Implement ``forward`` rather than overriding ``__call__`` so that
        # nn.Module's call machinery (forward / pre-forward hooks) still runs
        # when the wrapper is invoked as ``wrapper(x)``.
        return self.model.get_text_features(x)

def multiply_tensor_with_cam(input_tensor: torch.Tensor,
                             cam: torch.Tensor):
    """Multiply an input tensor (after normalization) with a pixel
    attribution map.

    ``cam`` may be a ``torch.Tensor`` or a NumPy array (e.g. one slice of a
    NumPy batch of CAMs). It is converted with ``torch.as_tensor`` onto
    ``input_tensor``'s device, so the result is always a tensor on that
    device, also for GPU inputs. Shapes combine with normal torch
    broadcasting, and dtypes follow torch type promotion.
    """
    # Explicit conversion: mixing a (possibly CUDA) tensor with a raw ndarray
    # is not reliable (device mismatch / non-tensor result).
    cam = torch.as_tensor(cam, device=input_tensor.device)
    return input_tensor * cam

class PerturbationConfidenceMetric:
    """Measure how target scores change when the input is perturbed by CAMs.

    ``perturbation(sample, cam)`` builds a perturbed version of one sample
    of the batch from its CAM. ``model1`` is run on the clean batch and
    ``model2`` on the perturbed batch. Every callable in ``targets`` is
    applied to the *whole* batch output of the respective model.
    """

    def __init__(self, perturbation):
        # Callable (sample_tensor, cam) -> perturbed sample tensor.
        self.perturbation = perturbation

    @staticmethod
    def _score(outputs, targets):
        """Apply each target to the full batch output and stack as float32.

        Runs under ``torch.no_grad()`` so that targets holding tensors which
        require grad (e.g. reference features) still produce tensors that
        ``.numpy()`` accepts. The first axis of the result indexes targets.
        """
        with torch.no_grad():
            scores = [target(outputs).cpu().numpy() for target in targets]
        return np.float32(scores)

    def __call__(self, input_tensor: torch.Tensor,
                 cams: np.ndarray,
                 targets: List[Callable],
                 model1: torch.nn.Module,
                 model2: torch.nn.Module,
                 return_visualization=False,
                 return_diff=True):
        """Score the CAM-perturbed batch, optionally relative to the clean one.

        Args:
            input_tensor: batch whose first dimension is the batch size.
            cams: per-sample CAMs; ``cams[i]`` is paired with
                ``input_tensor[i]``.
            targets: callables applied to the whole model output.
            model1: model scored on the clean ``input_tensor`` (only used
                when ``return_diff`` is True).
            model2: model scored on the perturbed batch.
            return_visualization: also return the perturbed batch.
            return_diff: return ``after - before`` instead of ``after``.

        Returns:
            float32 array stacked over ``targets``, or a tuple
            ``(scores, perturbated_tensors)`` when ``return_visualization``.
        """
        if return_diff:
            with torch.no_grad():
                outputs = model1(input_tensor)
            scores = self._score(outputs, targets)

        # Perturb each sample with its own CAM and rebuild the batch on the
        # input's device.
        perturbated_tensors = torch.cat([
            self.perturbation(input_tensor[i, ...], cams[i])
            .to(input_tensor.device)
            .unsqueeze(0)
            for i in range(input_tensor.size(0))
        ])

        with torch.no_grad():
            outputs_after_imputation = model2(perturbated_tensors)
        scores_after_imputation = self._score(outputs_after_imputation,
                                              targets)

        if return_diff:
            result = scores_after_imputation - scores
        else:
            result = scores_after_imputation

        if return_visualization:
            return result, perturbated_tensors
        return result


class CamMultImageConfidenceChange(PerturbationConfidenceMetric):
    """Perturbation metric whose perturbation is ``multiply_tensor_with_cam``.

    Each sample is multiplied element-wise by its CAM before being re-scored.
    """

    def __init__(self):
        super().__init__(perturbation=multiply_tensor_with_cam)


class DropInConfidenceText(CamMultImageConfidenceChange):
    """Clipped drop in target score caused by the CAM-multiplied perturbation.

    Returns ``max(-scores, 0)`` element-wise, where ``scores`` is the
    parent's result (``after - before`` with the default ``return_diff``),
    so only decreases are counted and increases become 0.
    """

    def __init__(self):
        super().__init__()

    def __call__(self, *args, **kwargs):
        result = super().__call__(*args, **kwargs)
        # With return_visualization=True the parent returns
        # (scores, perturbated_tensors); transform only the scores.
        if isinstance(result, tuple):
            scores, visualization = result
            return np.maximum(-scores, 0), visualization
        return np.maximum(-result, 0)


class IncreaseInConfidenceText(CamMultImageConfidenceChange):
    """Indicator of a target-score increase after the CAM perturbation.

    Returns float32 ``1.0`` where the parent's score (``after - before`` with
    the default ``return_diff``) is positive and ``0.0`` elsewhere.
    """

    def __init__(self):
        super().__init__()

    def __call__(self, *args, **kwargs):
        result = super().__call__(*args, **kwargs)
        # With return_visualization=True the parent returns
        # (scores, perturbated_tensors); transform only the scores.
        if isinstance(result, tuple):
            scores, visualization = result
            return np.float32(scores > 0), visualization
        return np.float32(result > 0)


