import os
import torch
import pickle
import gzip
import numpy as np
from collections import defaultdict

from torchattacks.attack import Attack
import torch.nn as nn


class BackboneSM(Attack):
    r"""
    Single-step attack on the "flatten" features of a VISSL trunk.

    The objective is the batch mean of the per-sample p-norm of the trunk's
    "flatten" output. One signed-gradient ascent step of size ``eps`` is taken
    on the input and the result is clamped to [0, 1].

    Arguments:
        model: wrapper object; the attacked module is ``model.model.trunk``
            (see ``get_vissl_backbone``).
        eps (float): magnitude of the single signed-gradient step.
        p: stored as ``self.p`` but not read by this class; ``get_distance``
            uses its own ``p`` argument (default 1).
        **kwargs: stored unchanged on ``self.kwargs``.
    """

    def __init__(self, model, eps=8/255, p=2, **kwargs):
        super().__init__("BackboneSM", model)
        self.supported_mode = ['default', 'targeted']
        self.eps, self.p, self.kwargs = eps, p, kwargs
        # Replace the model stored by Attack.__init__ with just its trunk.
        self.model = self.get_vissl_backbone(model)

    def get_vissl_backbone(self, model):
        """Return ``model.model.trunk``, the sub-module that is attacked."""
        trunk = model.model.trunk
        return trunk

    def get_distance(self, images, p=1):
        """Return the batch-mean ``p``-norm (over dim 1) of the trunk's "flatten" features.

        Raises AssertionError if the features are all (near) zero.
        """
        feats = self.model(images, out_feat_keys=["flatten"])[0]
        assert not torch.allclose(feats, torch.zeros_like(feats), atol=1e-6)
        per_sample_norm = torch.norm(feats, p=p, dim=1)
        return per_sample_norm.mean()

    def forward(self, images, labels=None):
        r"""
        Overridden.

        Take one signed-gradient step of size ``self.eps`` that increases
        ``get_distance`` and return the result clamped to [0, 1], detached.
        ``labels`` is accepted for API compatibility and ignored.
        """
        x = images.clone().detach().to(self.device)
        x.requires_grad = True

        objective = self.get_distance(x)
        (grad,) = torch.autograd.grad(objective, x, retain_graph=False, create_graph=False)
        # A zero gradient means the step would be a no-op; fail loudly.
        assert not torch.allclose(grad, torch.zeros_like(grad), atol=1e-6)

        stepped = x + grad.sign() * self.eps
        return torch.clamp(stepped, min=0, max=1).detach()


class BackboneDist(Attack):
    r"""
    Single-step attack on the "flatten" features of a VISSL trunk.

    A random sign perturbation ``sigma`` is added to the input. The objective
    is the negative batch-mean cosine similarity between M(x + sigma) and a
    stop-gradient copy of M(x). One signed-gradient step of size ``eps`` is
    then taken on x and the result is clamped to [0, 1].

    Arguments:
        model: wrapper object; the attacked module is ``model.model.trunk``.
        eps (float): magnitude of the single signed-gradient step.
        p: stored as ``self.p`` but not read by this class.
        **kwargs: stored on ``self.kwargs``. Optional keys read here:
            ``sigma`` (float, default 8/255) - magnitude of the random sign
            perturbation; ``verbose`` (bool, default False) - print
            noise/gradient statistics on each call.
    """

    def __init__(self, model, eps=8/255, p=2, **kwargs):
        # NOTE(review): the registered name "BackboneL1" does not match the class
        # name or the cosine objective; left unchanged in case results/logs are
        # keyed on it.
        super().__init__("BackboneL1", model)
        self.supported_mode = ['default', 'targeted']
        self.eps = eps
        self.p = p
        self.kwargs = kwargs
        # Replace the model stored by Attack.__init__ with just its trunk.
        self.model = self.get_vissl_backbone(model)

    def get_vissl_backbone(self, model):
        """Return ``model.model.trunk``, the sub-module that is attacked."""
        return model.model.trunk

    def get_distance(self, images, p=2):
        """Return ``-mean cos(M(images + sigma), StopG(M(images)))``.

        ``p`` is accepted for API compatibility and not used.
        Raises AssertionError if M(images + sigma) is all (near) zero.
        """
        sigma_scale = self.kwargs.get("sigma", 8/255)
        sigma = torch.randn_like(images).sign() * sigma_scale
        differentiable_graph = self.model(images + sigma, out_feat_keys=["flatten"])[0]  # M(x_0 + sigma)

        # Stop-gradient reference M(x_0); no graph is needed for it.
        with torch.no_grad():
            fixed_input_point = self.model(images, out_feat_keys=["flatten"])[0]
        assert not torch.allclose(differentiable_graph, torch.zeros_like(differentiable_graph), atol=1e-6)

        return -torch.cosine_similarity(differentiable_graph, fixed_input_point, dim=1).mean()

    def forward(self, images, labels=None):
        r"""
        Overridden.

        Take one signed-gradient step of size ``self.eps`` that increases
        ``get_distance`` and return the result clamped to [0, 1], detached.
        ``labels`` is accepted for API compatibility and ignored.
        """
        images = images.clone().detach().to(self.device)
        images.requires_grad = True

        differentiable_distance = self.get_distance(images)
        grad = torch.autograd.grad(differentiable_distance, images, retain_graph=False, create_graph=False)[0]
        # A zero gradient means the step would be a no-op; fail loudly.
        assert not torch.allclose(grad, torch.zeros_like(grad), atol=1e-6)

        noise = grad.sign() * self.eps
        adv_images = images.detach() + noise

        # Debug statistics are opt-in instead of printed on every call.
        if self.kwargs.get("verbose", False):
            print(noise.max(), noise.min(), noise.mean(), noise.std())
            print(torch.norm(grad, p=2, dim=1))
            print("_________________________")

        return torch.clamp(adv_images, min=0, max=1).detach()

#backbone attack with PGD

class BackbonePGD(Attack):
    r"""
    Multi-step (PGD-style) attack that pushes the VISSL trunk representation of
    the input away from that of the clean input.

    Each step takes a signed-gradient ascent step of size ``alpha`` on
    ``1 - mean cosine_similarity(M(x_adv), StopG(M(x)))``. The perturbation is
    then projected onto an L-inf ball of radius ``eps`` around the clean pixels
    and clamped to [0, 1].

    Arguments:
        model: wrapper object; the attacked module is ``model.model.trunk``.
        eps (float): L-inf budget, measured in unnormalized pixel space.
        alpha (float): per-step size. NOTE(review): when normalization is
            active the step is taken in normalized space, so the pixel-space
            step is ``alpha * std`` - confirm this is intended.
        steps (int): number of PGD iterations.
        p: stored as ``self.p`` but not read by this class.
        random_start (bool): start from a uniform random point within
            ``eps / 10`` of the clean pixels.
        **kwargs: stored unchanged on ``self.kwargs``.
    """

    def __init__(self, model, eps=8/255, alpha=2/255, steps=10, p=2, random_start = True, **kwargs):
        super().__init__("BackbonePGD", model)
        self.supported_mode = ['default', 'targeted']
        self.eps = eps
        self.alpha = alpha
        self.p = p
        self.steps = steps
        self.random_start = random_start
        self.kwargs = kwargs
        # Replace the model stored by Attack.__init__ with just its trunk.
        self.model = self.get_vissl_backbone(model)
        # Not read by the current objective; kept for attribute compatibility.
        self.transform = None

    def get_vissl_backbone(self, model):
        """Return ``model.model.trunk``, the sub-module that is attacked."""
        return model.model.trunk

    def trunk_forward(self, images):
        """Return the trunk representation of ``images``.

        Uses the "flatten" feature key when the base model's
        ``feat_eval_mapping`` defines it. Otherwise it falls back to
        ``base_model(images)[0]``. Previously the fallback was skipped (and
        ``None`` returned) when the mapping existed but lacked "flatten".
        """
        base_model = self.model.base_model
        feat_eval_mapping = base_model.__dict__.get("feat_eval_mapping")
        if feat_eval_mapping is not None and "flatten" in feat_eval_mapping:
            return self.model(images, out_feat_keys=["flatten"])[0]
        return base_model(images)[0]

    def get_distance(self, adv_repres, original_repres, p=2):
        """Return ``1 - mean cosine similarity`` (over dim 1) of the two representations.

        ``p`` is accepted for API compatibility and not used.
        """
        return 1 - torch.cosine_similarity(adv_repres, original_repres, dim=1).mean()

    def forward(self, images, labels=None):
        r"""
        Overridden.

        Run ``self.steps`` PGD iterations and return adversarial images in
        unnormalized pixel space. ``labels`` is accepted for API compatibility
        and ignored.
        """
        images = images.clone().detach().to(self.device)
        # Clean pixels; the eps-ball projection is measured against these.
        unnormalized_start_images = images

        # True when the trunk (and therefore adv_images) works on normalized inputs.
        images_normalized = (
            self.normalization_used is not None
            and self._normalization_applied is False
        )
        if images_normalized:
            images = self.normalize(images)

        # Stop-gradient reference M(x_0).
        with torch.no_grad():
            fixed_input_point = self.trunk_forward(images).detach()

        adv_images = images.clone().detach()

        if self.random_start:
            # TODO: random start is deliberately limited to eps / 10.
            noise = torch.empty_like(adv_images).uniform_(-self.eps, self.eps) / 10
            if images_normalized:
                # Perturb and clamp in pixel space, then map back. Clamping
                # normalized values to [0, 1] would corrupt the image.
                start_pixels = torch.clamp(unnormalized_start_images + noise, min=0, max=1)
                adv_images = self.normalize(start_pixels).detach()
            else:
                adv_images = torch.clamp(adv_images + noise, min=0, max=1).detach()

        for _ in range(self.steps):
            adv_images.requires_grad = True

            differentiable_graph = self.trunk_forward(adv_images)
            differentiable_distance = self.get_distance(differentiable_graph, fixed_input_point)
            grad = torch.autograd.grad(differentiable_distance, adv_images, retain_graph=False, create_graph=False)[0]
            # A zero gradient means the step would be a no-op; fail loudly.
            assert not torch.allclose(grad, torch.zeros_like(grad), atol=1e-6)

            adv_images = adv_images.detach() + grad.sign() * self.alpha

            if self.normalization_used is not None:
                # Project onto the eps-ball in pixel space, then re-normalize.
                unnormalized_adv_images = self.inverse_normalize(adv_images)
                unnormalized_delta = torch.clamp(
                    unnormalized_adv_images - unnormalized_start_images, min=-self.eps, max=self.eps
                )
                unnormalized_projected = torch.clamp(
                    unnormalized_start_images + unnormalized_delta, min=0, max=1
                )
                adv_images = self.normalize(unnormalized_projected)
            else:
                delta = torch.clamp(adv_images - images, min=-self.eps, max=self.eps)
                adv_images = torch.clamp(images + delta, min=0, max=1).detach()

            assert not adv_images.requires_grad

        if self.normalization_used is not None:
            return self.inverse_normalize(adv_images)
        return adv_images

class BackboneEarlyLayerPGD(Attack):
    r"""
    PGD-style attack that pushes intermediate VISSL trunk features away from
    those of the clean input.

    The objective is a summed-MSE distance between adversarial and clean
    features at each requested block: the first block has weight 1 and later
    blocks have weight ``lamb``. Each step is a signed-gradient step of size
    ``alpha``. The result is projected onto an L-inf ball of radius
    ``self.eps`` around the clean images and clamped to [0, 1]. Iteration stops
    early when two consecutive gradients are element-wise close (atol=1e-1).

    Arguments:
        model: wrapper object; the attacked module is ``model.model.trunk``.
        eps (float): base budget. The effective L-inf radius is ``4 * eps``.
        alpha (float): per-step size.
        steps (int): maximum number of iterations.
        p: stored as ``self.p`` but not read by this class.
        random_start (bool): start from a uniform random point in the eps-ball.
        **kwargs: stored on ``self.kwargs``. Optional keys read here:
            ``block_names`` (list of feature keys, default ["res1", "res2"])
            and ``verbose`` (bool, default False) - print perturbation
            statistics each step.
    """

    def __init__(self, model, eps=8/255, alpha=2/255, steps=10, p=2, random_start = True, **kwargs):
        super().__init__("BackboneEarlyLayerPGD", model)
        self.supported_mode = ['default', 'targeted']
        # NOTE(review): the budget is scaled by 4 relative to the ``eps``
        # argument; kept as-is to preserve existing results.
        self.eps = eps*4
        self.alpha = alpha
        self.p = p
        self.steps = steps
        self.random_start = random_start
        self.kwargs = kwargs
        # Replace the model stored by Attack.__init__ with just its trunk.
        self.model = self.get_vissl_backbone(model)
        # Not read by this class; kept for attribute compatibility.
        self.transform = None

    def get_vissl_backbone(self, model):
        """Return ``model.model.trunk``, the sub-module that is attacked."""
        return model.model.trunk

    def _block_features(self, images, block_names):
        """Return ``[trunk(images, out_feat_keys=[name])[0] for name in block_names]``."""
        return [self.model(images, out_feat_keys=[name])[0] for name in block_names]

    def get_distance(self, adv_images, images, block_names, p=2, fixed_points=None, lamb=0.3):
        """Return the weighted summed-MSE feature distance between ``adv_images`` and ``images``.

        Arguments:
            adv_images: inputs whose features are differentiated.
            images: clean inputs. They are only used when ``fixed_points`` is
                None, in which case their features are computed without a graph.
            block_names: non-empty sequence of trunk feature keys.
            p: accepted for API compatibility and not used.
            fixed_points: optional precomputed clean features, one per entry of
                ``block_names``. Pass them to avoid recomputing on every step.
            lamb (float): weight of every block after the first.

        Raises:
            ValueError: if ``block_names`` is empty or ``fixed_points`` has a
                different length.
        """
        if not block_names:
            raise ValueError("block_names must contain at least one feature key")
        if fixed_points is None:
            with torch.no_grad():
                fixed_points = self._block_features(images, block_names)
        if len(fixed_points) != len(block_names):
            raise ValueError("fixed_points must have one entry per block name")

        adv_feats = self._block_features(adv_images, block_names)
        mse = nn.MSELoss(reduction="sum")

        differentiable_objective = None
        for adv_feat, ref_feat in zip(adv_feats, fixed_points):
            # Clean features act as a stop-gradient target.
            term = mse(adv_feat, ref_feat.detach())
            if differentiable_objective is None:
                differentiable_objective = term
            else:
                differentiable_objective = differentiable_objective + lamb * term
        return differentiable_objective

    def forward(self, images, labels=None):
        r"""
        Overridden.

        Run up to ``self.steps`` PGD iterations on intermediate-feature
        distance and return the adversarial images clamped to [0, 1], detached.
        ``labels`` is accepted for API compatibility and ignored.
        """
        images = images.clone().detach().to(self.device)
        chosen_keys = self.kwargs.get("block_names", ["res1", "res2"])
        verbose = self.kwargs.get("verbose", False)

        # Clean-image features do not change across steps: compute them once.
        with torch.no_grad():
            fixed_points = self._block_features(images, chosen_keys)

        adv_images = images.clone().detach()

        if self.random_start:
            # Start at a uniformly random point in the eps-ball.
            adv_images = adv_images + torch.empty_like(adv_images).uniform_(-self.eps, self.eps)
            adv_images = torch.clamp(adv_images, min=0, max=1).detach()

        prev_grad = None
        for _ in range(self.steps):
            adv_images.requires_grad_(True)

            differentiable_distance = self.get_distance(
                adv_images, images, chosen_keys, fixed_points=fixed_points
            )
            grad = torch.autograd.grad(differentiable_distance, adv_images, retain_graph=False, create_graph=False)[0]
            # A zero gradient means the step would be a no-op; fail loudly.
            assert not torch.allclose(grad, torch.zeros_like(grad), atol=1e-6)

            adv_images = adv_images.detach() + grad.sign() * self.alpha
            delta = torch.clamp(adv_images - images, min=-self.eps, max=self.eps)

            if verbose:
                print(delta.max(), delta.min(), delta.mean(), delta.std())

            adv_images = torch.clamp(images + delta, min=0, max=1).detach()

            # Early stop when the gradient has stopped changing between steps.
            if prev_grad is not None and torch.allclose(grad, prev_grad, atol=1e-1):
                print("Gradident is not changing")
                break
            prev_grad = grad

        return adv_images
