import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import scipy
import scipy.stats as stats


class IntegratedGradientsAttack:
    
    def __init__(self, mean_image, test_image, original_label, device, NET, NET2, dim,
                 num_cka=1, k_top=100, num_steps=100, reference_image=None, target_map = None):
        """
        https://github.com/amiratag/InterpretationFragility/blob/9dd5fb995745e335d6698be3b9b31b5f713ae45f/utils.py#L268
        """
        self.device = device
        self.num_cka = num_cka
        self.dim = dim
        self.model = NET
        self.model2 = NET2
        self.target_map = target_map
        self.test_image, self.mean_image = test_image.to(device), mean_image.to(device)
        self.reference_image = torch.zeros_like(test_image).to(device) if reference_image is None else reference_image.to(device)
        self.original_label = original_label.to(device)
        self.k_top, self.num_steps = k_top, num_steps
        if self.check_prediction(original_label, test_image):
            self.valid = False
            return
        self.valid = True
        counterfactuals = self.create_counterfactuals(test_image).to(device)
        self.saliency1, self.saliency1_flatten = self.get_saliencies(self.model, counterfactuals)
        self.mass_center1 = self.get_mass(self.saliency1, self.saliency1_flatten)
        top_val, self.topK = torch.topk(self.saliency1_flatten, k_top)
    
    def check_prediction(self, original_label, image):
        _, c, h, w = image.size()
        preds = self.model(image.to(self.device))
        if self.num_cka != 1:
            preds = preds[-1].to(self.device)
        else:
            preds = preds.to(self.device)
        if torch.max(preds, 1).indices != original_label:
            #print("Network's Prediction is Already Incorrect!")
            return True
        self.original_confidence2 = torch.max(F.softmax(preds, dim=1))
        self.original_confidence = torch.max(preds)
        return False

    def get_saliencies(self, model, inputs):
        _, c, h, w = inputs.size()
        preds = model(inputs.to(self.device))
        if self.num_cka != 1:
            preds = preds[-1].to(self.device)
        else:
            preds = preds.to(self.device)
        sum_logits = torch.sum(preds * F.one_hot(self.original_label.long(), self.dim)).to(self.device)
        parallel_gradients = torch.autograd.grad(sum_logits, inputs, allow_unused=True)[0]
        average_gradients = parallel_gradients.mean(0)
        difference_multiplied = torch.mul(average_gradients, inputs[-1]-self.reference_image)
        saliency_unnormalized = torch.sum(difference_multiplied.abs(), 1).reshape(h, w)
        saliency = h*w*torch.div(saliency_unnormalized, torch.sum(saliency_unnormalized))
        saliency_flatten = saliency.flatten()
        return saliency, saliency_flatten
        
    def get_mass(self, saliency, saliency_flatten): 
        _, c, h, w = self.test_image.size()
        top_val, top_idx = torch.topk(saliency_flatten, self.k_top)
        x_mesh, y_mesh = torch.meshgrid(torch.arange(h), torch.arange(w))
        saliency, x_mesh, y_mesh = saliency.to(self.device), x_mesh.to(self.device), y_mesh.to(self.device)
        mass_center = torch.stack([torch.sum(saliency*x_mesh)/(h*w), torch.sum(saliency*y_mesh)/(h*w)])
        return mass_center

    def create_counterfactuals(self, in_image):
        ref_subtracted = (in_image - self.reference_image).to(self.device)
        counterfactuals = torch.arange(1, self.num_steps+1).reshape(self.num_steps, 1, 1, 1).to(self.device)/self.num_steps * ref_subtracted.repeat(self.num_steps, 1, 1, 1) + self.reference_image.repeat(self.num_steps, 1, 1, 1).to(self.device)
        return counterfactuals
   
    def give_simple_perturbation(self, attack_method, in_image):
        counterfactuals = self.create_counterfactuals(in_image)
        _, c, h, w = self.test_image.size()
        s, sf = self.get_saliencies(self.model2, counterfactuals)
        elem1 = torch.argsort(self.saliency1.reshape(h*w))[-self.k_top:]
        self.elements1 = torch.zeros(h*w, device=self.device)
        self.elements1[elem1] = 1
        if attack_method == "random":
            perturbation = torch.randn(self.num_steps, w, h, c, device=self.device)
        elif attack_method == "topK":
            topK_loss = torch.sum(sf * self.elements1)
            topK_direction = -torch.autograd.grad(topK_loss, counterfactuals)[0].to(self.device)
            perturbation = topK_direction.reshape(self.num_steps, c, h, w).to(self.device)
        elif attack_method == "mass_center":
            mass_center = self.get_mass(s, sf).to(self.device)
            mass_center_loss = -torch.sum((mass_center - self.mass_center1)**2).to(self.device)
            #print(mass_center, self.mass_center1, mass_center_loss)
            mass_center_direction = -torch.autograd.grad(mass_center_loss, counterfactuals)[0].to(self.device)
            perturbation = mass_center_direction.reshape(self.num_steps, c, h, w).to(self.device)
        elif attack_method == "target":
            if self.target_map is None:
                raise ValueError("No target region determined!")
            else:
                target_loss = -torch.sum(s * self.target_map)
                target_direction = -torch.autograd.grad(target_loss, in_image)[0].to(self.device)
                perturbation = target_direction.reshape(self.num_steps, c, h, w).to(self.device)
                perturbation = np.reshape(perturbation,[self.num_steps,w,h,c])
        perturbation_summed = torch.sum(torch.arange(1, self.num_steps+1).reshape(self.num_steps, 1, 1, 1).to(self.device)/self.num_steps * perturbation, 0).to(self.device)
        return torch.sign(perturbation_summed)

    
    def apply_perturb(self, in_image, pert, alpha, epsilon=7.8641):
        out_image = in_image + alpha*pert
        pert = out_image - self.test_image
        pert = torch.clamp(pert, -epsilon, epsilon)
        out_image = self.test_image + pert
        return out_image
    
    def check_measure(self, test_image_pert, measure):
        prob = self.model(test_image_pert)
        if self.num_cka != 1:
            prob = prob[-1].to(self.device)
        else:
            prob = prob.to(self.device)
        if torch.argmax(prob, 1) == self.original_label:
            counterfactuals = self.create_counterfactuals(test_image_pert)
            s, sf = self.get_saliencies(self.model, counterfactuals)
            if measure=="intersection":
                _, top2 = torch.topk(sf, self.k_top)
                criterion = float(len(np.intersect1d(self.topK.clone().detach().cpu(),top2.clone().detach().cpu())))/self.k_top
            elif measure=="correlation":
                criterion = scipy.stats.spearmanr(self.saliency1_flatten.clone().detach().cpu(), sf.clone().detach().cpu())[0]
            elif measure=="mass_center":
                center2 = self.get_mass(s, sf)
                criterion = -torch.norm(self.mass_center1.float()-center2.float())
            else:
                raise ValueError("Invalid measure!")
            return criterion
        else:
            return 1.
        
    def iterative_attack(self, attack_num, attack_method, epsilon, iters=100, alpha=1,  measure="intersection"):
        _, c, h, w = self.test_image.size()
        test_image_pert = self.test_image.clone().detach().requires_grad_(True).to(self.device)
        min_criterion, perturb_size = 1., 0.
        for counter in range(iters):
            pert = self.give_simple_perturbation(attack_method, test_image_pert).to(self.device)
            test_image_pert = self.apply_perturb(test_image_pert, pert, alpha, epsilon).requires_grad_(True).to(self.device)
            criterion = self.check_measure(test_image_pert, measure)
            if criterion < min_criterion:
                min_criterion = criterion
                self.perturbed_image = test_image_pert.clone().detach().requires_grad_(True).to(self.device)
                perturb_size = torch.max((self.test_image-self.perturbed_image).abs())
        if min_criterion==1.:
            #print("Attack unsuccessful (max: {})".format(epsilon))
            return None
        #print("[%d] Max: %.3f | Real: %.3f"%(attack_num+1, epsilon, torch.max((self.test_image-self.perturbed_image).abs())))
        predicted_scores = self.model(self.perturbed_image)
        if self.num_cka != 1:
            predicted_scores = predicted_scores[-1].to(self.device)
        else:
            predicted_scores = predicted_scores.to(self.device)
        confidence2 = torch.max(F.softmax(predicted_scores, dim=1))
        confidence = torch.max(predicted_scores)
        counterfactuals = self.create_counterfactuals(self.perturbed_image).to(self.device)
        self.saliency2, self.saliency2_flatten = self.get_saliencies(self.model, counterfactuals)
        _, self.top2 = torch.topk(self.saliency2_flatten, self.k_top)
        self.mass_center2 = self.get_mass(self.saliency2, self.saliency2_flatten)
        correlation = scipy.stats.spearmanr(self.saliency1_flatten.clone().detach().cpu(), self.saliency2.reshape(h*w).clone().detach().cpu())[0]
        intersection = float(len(np.intersect1d(self.topK.clone().detach().cpu(), self.top2.clone().detach().cpu())))/self.k_top
        center_dislocation = torch.norm(self.mass_center1-self.mass_center2)
        return [self.original_confidence.item(), self.original_confidence2.item()], [confidence.item(), confidence2.item()], intersection, correlation, center_dislocation.item(), perturb_size.item()
