import torch
import torch.nn as nn
from attack.approximators import ConvApproximator, UNetAutoEncoder, TransformerAutoEncoder

class Attack:
    """Common interface for the adversarial attacks in this module.

    Concrete attacks override :meth:`attack`. Instances are callable; calling
    one simply forwards ``(images, labels)`` to :meth:`attack`.
    """

    def __init__(self):
        """The base class keeps no state."""

    def attack(self, images, labels):
        """Return adversarial versions of ``images``; subclasses must override this."""
        message = "Attack class should not be instantiated directly. Please use a subclass of Attack."
        raise NotImplementedError(message)

    def __call__(self, images, labels):
        """Shorthand for :meth:`attack`."""
        result = self.attack(images, labels)
        return result

class StraightThroughBPDA(Attack):
    """BPDA attack that treats the defence ``g`` as the identity on the backward pass.

    The forward pass evaluates ``f(g(x))`` exactly. The gradient of the
    cross-entropy loss with respect to ``g(x)`` is then used as if it were the
    gradient with respect to ``x`` (straight-through estimator). The
    perturbation is built with signed-gradient ascent steps and projected onto
    an L-inf ball.

    Args:
        model: Indexable ``(g, f)`` pair, or a single classifier (then ``g``
            is the identity).
        eps: L-inf radius of the perturbation.
        alpha: Size of each signed-gradient step.
        steps: Number of attack iterations.
        device: Device the models and tensors are moved to.
    """

    def __init__(self, model, eps=8/255, alpha=2/255, steps=10, device='cuda'):
        super(StraightThroughBPDA, self).__init__()
        try:
            self.g, self.f = model[0], model[1]
        except (IndexError, TypeError):
            # Non-indexable model (TypeError) or one-element container (IndexError):
            # fall back to an identity defence. nn.Identity, unlike a lambda,
            # supports the .to() call below.
            print('BDPA should be used in white box setting, setting g = identity')
            self.g = nn.Identity()
            self.f = model
        self.device = device
        self.g.to(device)
        self.f.to(device)
        self.eps = eps
        self.steps = steps
        self.alpha = alpha

    def attack(self, images, labels):
        """Return ``images`` plus an L-inf bounded adversarial perturbation.

        Args:
            images: Clean input batch.
            labels: Labels for ``images``, used as cross-entropy targets.

        Returns:
            The adversarial batch on ``self.device``. The added perturbation lies
            element-wise in ``[-eps, eps]``.
        """
        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)

        loss_fn = nn.CrossEntropyLoss()
        adv_images = images.clone().detach()

        # Random start inside the eps-ball.
        adv_noise = torch.empty_like(adv_images).uniform_(-self.eps, self.eps)

        for _ in range(self.steps):
            x = adv_images + adv_noise
            # Evaluate the defence, then make its output a fresh leaf so that
            # d loss / d g(x) can stand in for d loss / d x.
            g_x = self.g(x).detach()
            g_x.requires_grad_()
            loss = loss_fn(self.f(g_x), labels)
            grad = torch.autograd.grad(
                loss, g_x, retain_graph=False, create_graph=False
            )[0]
            # Accumulate the ascent step onto the running perturbation, then
            # project back onto the eps-ball.
            adv_noise = adv_noise + self.alpha * grad.sign()
            adv_noise = torch.clamp(adv_noise, -self.eps, self.eps)

        return adv_images + adv_noise
    
class OptimizationBPDA(Attack):
    """Straight-through BPDA attack whose update comes from an Adam step.

    As in the straight-through variant, the gradient with respect to ``g(x)``
    is used as the gradient with respect to the perturbation. Each iteration
    applies one Adam step to the negated cross-entropy. The change is limited
    to ``alpha`` per step, and the perturbation is projected onto the L-inf
    ball of radius ``eps``.

    Args:
        model: Indexable ``(g, f)`` pair, or a single classifier (then ``g``
            is the identity).
        eps: L-inf radius of the perturbation.
        alpha: Maximum element-wise change of the perturbation per step.
        steps: Number of attack iterations.
        device: Device the models and tensors are moved to.
        verbose: If true, print the L-inf norm of the perturbation every step.
        lr: Adam learning rate.
        betas: Adam ``betas``.
    """

    def __init__(self, model, eps=8/255, alpha=2/255, steps=10, device='cuda', verbose=False, lr=1.0, betas=(0.9, 0.999)):
        super(OptimizationBPDA, self).__init__()
        try:
            self.g, self.f = model[0], model[1]
        except (IndexError, TypeError):
            # Non-indexable model (TypeError) or one-element container (IndexError):
            # fall back to an identity defence that supports .to().
            print('BDPA should be used in white box setting, setting g = identity')
            self.g = nn.Identity()
            self.f = model
        self.device = device
        self.g.to(device)
        self.f.to(device)
        self.eps = eps
        self.steps = steps
        self.alpha = alpha
        self.verbose = verbose
        self.lr = lr
        self.betas = betas

    def attack(self, images, labels):
        """Return ``images`` plus an L-inf bounded adversarial perturbation.

        Args:
            images: Clean input batch.
            labels: Labels for ``images``, used as cross-entropy targets.

        Returns:
            The adversarial batch on ``self.device``.
        """
        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)

        loss_fn = nn.CrossEntropyLoss()
        adv_images = images.clone().detach()

        adv_noise = torch.zeros_like(adv_images)

        for _ in range(self.steps):
            adv_noise = adv_noise.detach()
            # The perturbation tensor is replaced every iteration, so no Adam
            # moments carry over. A fresh optimizer per step states that directly
            # and leaves no stale per-tensor state behind.
            optimizer = torch.optim.Adam([adv_noise], lr=self.lr, betas=self.betas)
            optimizer.zero_grad()

            x = adv_images + adv_noise
            g_x = self.g(x).detach().requires_grad_()
            # Adam minimises, so negate the loss to ascend the cross-entropy.
            loss = -loss_fn(self.f(g_x), labels)
            # Differentiate w.r.t. g(x) only, so f's parameters collect no .grad.
            grad = torch.autograd.grad(loss, g_x)[0]
            old_adv_noise = adv_noise.clone().detach()
            # Straight-through: d loss / d g(x) is used as the noise gradient.
            adv_noise.grad = grad
            optimizer.step()
            # Limit the per-step change to alpha, then project onto the eps-ball.
            adv_noise = torch.clamp(adv_noise, old_adv_noise - self.alpha, old_adv_noise + self.alpha)
            adv_noise = torch.clamp(adv_noise, -self.eps, self.eps)
            if self.verbose:
                print(torch.linalg.vector_norm(adv_noise, ord=float('inf')))

        return adv_images + adv_noise
    
    
class SelfLearnedBPDA(Attack):
    """BPDA attack that learns a differentiable surrogate of ``g`` for every image.

    For each image in the batch, a fresh approximator network is fitted to
    ``g`` on uniformly-noised copies of that image. The attack then runs ``g``
    on the forward pass and differentiates the approximator on the backward
    pass.

    Args:
        model: Indexable ``(g, f)`` pair, or a single classifier (then ``g``
            is the identity).
        approximator_type: ``'conv'``, ``'unet'`` or ``'transformer'``.
        eps: L-inf radius of the perturbation (also the training-noise scale).
        alpha: Maximum element-wise change of the perturbation per step.
        steps: Number of attack iterations per image.
        device: Device the models and tensors are moved to.
        verbose: Stored; not read by this class's methods.
        input_size: ``input_size[0]`` is passed to the conv and transformer
            approximators.
    """

    def __init__(self, model, approximator_type='unet', eps=8/255, alpha=2/255, steps=10, device='cuda', verbose=False, input_size=(224, 224)):
        super(SelfLearnedBPDA, self).__init__()
        try:
            self.g, self.f = model[0], model[1]
        except (IndexError, TypeError):
            # Non-indexable model (TypeError) or one-element container (IndexError):
            # fall back to an identity defence that supports .to().
            print('BDPA should be used in white box setting, setting g = identity')
            self.g = nn.Identity()
            self.f = model
        self.device = device
        self.g.to(device)
        self.f.to(device)

        self.eps = eps
        self.steps = steps
        self.alpha = alpha
        self.verbose = verbose
        self.input_size = input_size
        self.approximator_type = approximator_type

    def _make_approximation_function(self, approximator):
        """Build an autograd Function: ``g`` on the forward pass, ``approximator`` on the backward pass.

        Args:
            approximator: Differentiable module trained to mimic ``self.g``.

        Returns:
            A ``torch.autograd.Function`` subclass; use it via ``.apply(x)``.
        """
        g = self.g

        class ApproximationFunction(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input):
                ctx.save_for_backward(input)
                return g(input)

            @staticmethod
            def backward(ctx, grad_output):
                (input,) = ctx.saved_tensors
                # Differentiate the surrogate on a detached copy so the
                # vector-Jacobian product stays local to the approximator graph.
                surrogate_input = input.detach().requires_grad_(True)
                with torch.enable_grad():
                    approx = approximator(surrogate_input)
                grad_input = torch.autograd.grad(approx, surrogate_input, grad_outputs=grad_output, retain_graph=False)[0]
                return grad_input

        return ApproximationFunction

    def _create_approximator(self, input_size, approximator_type='unet'):
        """Instantiate the surrogate network selected by ``approximator_type``.

        Raises:
            ValueError: If ``approximator_type`` is not one of the known types.
        """
        if approximator_type == 'conv':
            Approximator = ConvApproximator(input_size[0], feature_dim=128)
        elif approximator_type == 'unet':
            Approximator = UNetAutoEncoder(in_channels=3, feature_dim=128)
        elif approximator_type == 'transformer':
            Approximator = TransformerAutoEncoder(input_size[0])
        else:
            raise ValueError(f"Unknown approximator type: {approximator_type}")
        return Approximator

    def _make_noisy_batch(self, input_tensor, batch_size, epsilon):
        """Stack ``batch_size`` copies of ``input_tensor`` with uniform noise in ``[-epsilon, epsilon]``."""
        noise = (torch.rand((batch_size, *input_tensor.shape)) * 2 - 1) * epsilon
        return input_tensor.unsqueeze(0) + noise.to(self.device)

    def _train_approximator(self, X, epochs=30, batches=10, lr=1e-3):
        """Fit a fresh approximator to ``self.g`` in an eps-neighbourhood of the single input ``X``.

        Each epoch runs ``batches`` L1-regression steps on noisy 16-copy batches
        around ``X``, followed by one step on ``X`` itself. The learning rate is
        decayed by 10x every 8 epochs.

        Returns:
            The trained approximator module.
        """
        model = self._create_approximator(input_size=self.input_size, approximator_type=self.approximator_type).to(self.device)
        criterion = nn.L1Loss()
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.1)

        for epoch in range(epochs):
            total_loss = 0.0
            for _ in range(batches):
                images = self._make_noisy_batch(X, 16, self.eps).to(self.device)
                outputs = model(images)
                # g provides regression targets only; no graph through it is needed.
                with torch.no_grad():
                    transformed_outputs = self.g(images)
                loss = criterion(outputs, transformed_outputs)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()

            # One additional step on the clean image itself.
            images = X.unsqueeze(0).to(self.device)
            outputs = model(images)
            with torch.no_grad():
                transformed_outputs = self.g(images)
            loss = criterion(outputs, transformed_outputs)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            avg_loss = total_loss / max(batches, 1)
            print(f"Epoch {epoch+1}/{epochs} - Loss: {avg_loss:.4f}")
            # Per-epoch decay, as step_size=8 is expressed in epochs.
            scheduler.step()

        print("Training finished!")
        return model

    def attack(self, images, labels):
        """Attack each image independently with its own freshly trained surrogate.

        Args:
            images: Clean input batch; indexed along dim 0.
            labels: Labels for ``images``, used as cross-entropy targets.

        Returns:
            Detached adversarial batch on ``self.device``.
        """
        images = images.clone().detach()
        labels = labels.clone().detach()
        adv_images = torch.zeros_like(images).to(self.device)

        loss_fn = nn.CrossEntropyLoss()

        for i in range(len(images)):
            image = images[i].to(self.device)
            label = labels[i].to(self.device)
            approximator = self._train_approximator(image)
            approximation_function = self._make_approximation_function(approximator=approximator)
            adv_noise = torch.zeros_like(image)
            for _ in range(self.steps):
                # The perturbation must be a leaf that requires grad for Adam to
                # update it. It is re-created each step, so Adam starts fresh.
                adv_noise = adv_noise.detach().requires_grad_()
                optimizer = torch.optim.Adam([adv_noise], lr=1.0)
                x = image + adv_noise
                f_x = self.f(approximation_function.apply(x.unsqueeze(0)))
                # Adam minimises, so negate the loss to ascend the cross-entropy.
                loss = -loss_fn(f_x, label.unsqueeze(0))
                # Gradient w.r.t. the noise only (flows through the surrogate's
                # backward); leaves f's parameter .grad untouched.
                adv_noise.grad = torch.autograd.grad(loss, adv_noise)[0]
                old_adv_noise = adv_noise.clone().detach()
                optimizer.step()
                with torch.no_grad():
                    adv_noise = torch.clamp(adv_noise, old_adv_noise - self.alpha, old_adv_noise + self.alpha)
                    adv_noise = torch.clamp(adv_noise, -self.eps, self.eps)
            adv_images[i] = (image + adv_noise).detach()
        return adv_images
    

class FastSelfLearnedBPDA(Attack):
    """BPDA attack with one surrogate of ``g`` trained up front on a data loader.

    The approximator is trained once in ``__init__`` on noisy and clean batches
    from ``data_loader``. ``attack`` then runs ``g`` on the forward pass and
    differentiates the approximator on the backward pass.

    Args:
        model: Indexable ``(g, f)`` pair, or a single classifier (then ``g``
            is the identity).
        data_loader: Iterable yielding ``(images, labels)`` batches used for
            surrogate training.
        approximator_type: ``'conv'``, ``'unet'`` or ``'transformer'``.
        eps: L-inf radius of the perturbation (also the training-noise scale).
        alpha: Maximum element-wise change of the perturbation per step.
        steps: Number of attack iterations.
        train_epochs: Epochs of surrogate training.
        device: Device the models and tensors are moved to.
        verbose: Stored; not read by this class's methods.
        input_size: ``input_size[0]`` is passed to the conv and transformer
            approximators.
    """

    def __init__(self, model, data_loader, approximator_type='unet', eps=8/255, alpha=2/255, steps=10, train_epochs=20, device='cuda', verbose=False, input_size=(224, 224)):
        super(FastSelfLearnedBPDA, self).__init__()
        try:
            self.g, self.f = model[0], model[1]
        except (IndexError, TypeError):
            # Non-indexable model (TypeError) or one-element container (IndexError):
            # fall back to an identity defence that supports .to().
            print('BDPA should be used in white box setting, setting g = identity')
            self.g = nn.Identity()
            self.f = model
        self.device = device
        self.g.to(device)
        self.f.to(device)

        self.eps = eps
        self.steps = steps
        self.alpha = alpha
        self.verbose = verbose
        self.input_size = input_size
        self.approximator_type = approximator_type
        self.data_loader = data_loader
        self.train_epochs = train_epochs

        self.approximator = self._train_approximator()
        self.approximation_function = self._make_approximation_function(approximator=self.approximator)

    def _make_approximation_function(self, approximator):
        """Build an autograd Function: ``g`` on the forward pass, ``approximator`` on the backward pass.

        Args:
            approximator: Differentiable module trained to mimic ``self.g``.

        Returns:
            A ``torch.autograd.Function`` subclass; use it via ``.apply(x)``.
        """
        g = self.g

        class ApproximationFunction(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input):
                ctx.save_for_backward(input)
                return g(input)

            @staticmethod
            def backward(ctx, grad_output):
                (input,) = ctx.saved_tensors
                # Differentiate the surrogate on a detached copy so the
                # vector-Jacobian product stays local to the approximator graph.
                surrogate_input = input.detach().requires_grad_(True)
                with torch.enable_grad():
                    approx = approximator(surrogate_input)
                grad_input = torch.autograd.grad(approx, surrogate_input, grad_outputs=grad_output, retain_graph=False)[0]
                return grad_input

        return ApproximationFunction

    def _create_approximator(self, input_size, approximator_type='unet'):
        """Instantiate the surrogate network selected by ``approximator_type``.

        Raises:
            ValueError: If ``approximator_type`` is not one of the known types.
        """
        if approximator_type == 'conv':
            Approximator = ConvApproximator(input_size[0], feature_dim=128)
        elif approximator_type == 'unet':
            Approximator = UNetAutoEncoder(in_channels=3, feature_dim=128)
        elif approximator_type == 'transformer':
            Approximator = TransformerAutoEncoder(input_size[0])
        else:
            raise ValueError(f"Unknown approximator type: {approximator_type}")
        return Approximator

    def _make_noisy_batch(self, input_tensor, batch_size, epsilon):
        """Add uniform noise in ``[-epsilon, epsilon]`` to ``input_tensor``.

        ``batch_size`` is unused: the input is already a batch.
        """
        # Create the noise on the input's device so GPU-resident batches also work.
        noise = (torch.rand(input_tensor.shape, device=input_tensor.device) * 2 - 1) * epsilon
        return input_tensor + noise

    def _train_approximator(self, batches=3, lr=1e-3):
        """Train a new approximator to mimic ``self.g`` on ``self.data_loader``.

        For every loader batch, run ``batches`` L1-regression steps on noisy
        copies, then one step on the clean batch. The learning rate is decayed
        by 10x every 5 epochs.

        Returns:
            The trained approximator module.
        """
        model = self._create_approximator(input_size=self.input_size, approximator_type=self.approximator_type).to(self.device)
        criterion = nn.L1Loss()
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

        for epoch in range(self.train_epochs):
            total_loss = 0.0
            n_noisy_steps = 0
            for X, _ in self.data_loader:
                for _ in range(batches):
                    images = self._make_noisy_batch(X, 16, self.eps).to(self.device)
                    outputs = model(images)
                    with torch.no_grad():
                        transformed_outputs = self.g(images)
                    loss = criterion(outputs, transformed_outputs)

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    total_loss += loss.item()
                    n_noisy_steps += 1
                # One additional step on the clean batch.
                images = X.to(self.device)
                outputs = model(images)
                with torch.no_grad():
                    transformed_outputs = self.g(images)
                loss = criterion(outputs, transformed_outputs)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Mean over all noisy steps of the epoch, not just over `batches`.
            avg_loss = total_loss / max(n_noisy_steps, 1)
            print(f"Epoch {epoch+1}/{self.train_epochs} - Loss: {avg_loss:.4f}")
            scheduler.step()

        print("Training finished!")
        return model

    def attack(self, images, labels):
        """Return ``images`` plus an L-inf bounded adversarial perturbation.

        Args:
            images: Clean input batch.
            labels: Labels for ``images``, used as cross-entropy targets.

        Returns:
            Detached adversarial batch on ``self.device``.
        """
        # The models live on self.device, so the inputs must live there too.
        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)

        loss_fn = nn.CrossEntropyLoss()

        adv_noise = torch.zeros_like(images)
        for _ in range(self.steps):
            # Re-create the perturbation leaf each step; Adam therefore starts fresh.
            adv_noise = adv_noise.detach().requires_grad_()
            optimizer = torch.optim.Adam([adv_noise], lr=1.0)
            x = images + adv_noise
            f_x = self.f(self.approximation_function.apply(x))
            # Adam minimises, so negate the loss to ascend the cross-entropy.
            loss = -loss_fn(f_x, labels)
            # Gradient w.r.t. the noise only; leaves f's parameter .grad untouched.
            adv_noise.grad = torch.autograd.grad(loss, adv_noise)[0]
            old_adv_noise = adv_noise.clone().detach()
            optimizer.step()
            with torch.no_grad():
                adv_noise = torch.clamp(adv_noise, old_adv_noise - self.alpha, old_adv_noise + self.alpha)
                adv_noise = torch.clamp(adv_noise, -self.eps, self.eps)

        return (images + adv_noise).detach()
    

class PretrainedBPDA(Attack):
    """BPDA attack using a surrogate of ``g`` that is trained once and cached on disk.

    Surrogate weights are stored in ``approximator_<type>.pth`` in the current
    working directory. At construction they are loaded if that file exists;
    otherwise the surrogate is trained on ``testset`` and saved. With
    ``train=True``, training also continues after a successful load.

    Args:
        model: Indexable ``(g, f)`` pair, or a single classifier (then ``g``
            is the identity).
        testset: Dataset wrapped in a shuffled DataLoader (batch size 16).
            Iterating that loader must yield ``(images, labels)`` batches.
        approximator_type: ``'conv'``, ``'unet'`` or ``'transformer'``.
        eps: L-inf radius of the perturbation (also the training-noise scale).
        alpha: Maximum element-wise change of the perturbation per step.
        steps: Number of attack iterations.
        train: Continue training even when cached weights were loaded.
        train_epochs: Epochs of surrogate training.
        ft_epochs: Epochs used by ``_fine_tune_approximator``.
        device: Device the models and tensors are moved to.
        verbose: Stored; not read by this class's methods.
        input_size: ``input_size[0]`` is passed to the conv and transformer
            approximators.
    """

    def __init__(self, model, testset, approximator_type='transformer', eps=8/255, alpha=2/255, steps=10, train=False, train_epochs=100, ft_epochs=10, device='cuda', verbose=False, input_size=(224, 224)):
        super(PretrainedBPDA, self).__init__()
        try:
            self.g, self.f = model[0], model[1]
        except (IndexError, TypeError):
            # Non-indexable model (TypeError) or one-element container (IndexError):
            # fall back to an identity defence that supports .to().
            print('BDPA should be used in white box setting, setting g = identity')
            self.g = nn.Identity()
            self.f = model
        self.device = device
        self.g.to(device)
        self.f.to(device)

        self.eps = eps
        self.steps = steps
        self.alpha = alpha
        self.verbose = verbose
        self.input_size = input_size
        self.approximator_type = approximator_type
        self.data_loader = torch.utils.data.DataLoader(testset, batch_size=16, shuffle=True)
        self.train_epochs = train_epochs
        self.ft_epochs = ft_epochs
        self.train = train

        self.approximator = self._create_approximator(input_size=self.input_size, approximator_type=self.approximator_type).to(self.device)
        try:
            self._load_approximator()
        except FileNotFoundError:
            # No cached weights: train from scratch (this also saves them).
            self._train_approximator()
        else:
            # Cached weights loaded; optionally keep training from them. Doing
            # this only on the success path avoids training twice when the
            # checkpoint was missing.
            if self.train:
                self._train_approximator()
        self.approximation_function = self._make_approximation_function(approximator=self.approximator)

    def _make_approximation_function(self, approximator):
        """Build an autograd Function: ``g`` on the forward pass, ``approximator`` on the backward pass.

        Args:
            approximator: Differentiable module trained to mimic ``self.g``.

        Returns:
            A ``torch.autograd.Function`` subclass; use it via ``.apply(x)``.
        """
        g = self.g

        class ApproximationFunction(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input):
                ctx.save_for_backward(input)
                return g(input)

            @staticmethod
            def backward(ctx, grad_output):
                (input,) = ctx.saved_tensors
                # Differentiate the surrogate on a detached copy so the
                # vector-Jacobian product stays local to the approximator graph.
                surrogate_input = input.detach().requires_grad_(True)
                with torch.enable_grad():
                    approx = approximator(surrogate_input)
                grad_input = torch.autograd.grad(approx, surrogate_input, grad_outputs=grad_output, retain_graph=False)[0]
                return grad_input

        return ApproximationFunction

    def _create_approximator(self, input_size, approximator_type='unet'):
        """Instantiate the surrogate network selected by ``approximator_type``.

        Raises:
            ValueError: If ``approximator_type`` is not one of the known types.
        """
        if approximator_type == 'conv':
            Approximator = ConvApproximator(input_size[0], feature_dim=128)
        elif approximator_type == 'unet':
            Approximator = UNetAutoEncoder(in_channels=3, feature_dim=128)
        elif approximator_type == 'transformer':
            Approximator = TransformerAutoEncoder(input_size[0])
        else:
            raise ValueError(f"Unknown approximator type: {approximator_type}")
        return Approximator

    def _make_noisy_batch(self, input_tensor, batch_size, epsilon):
        """Add uniform noise in ``[-epsilon, epsilon]`` to ``input_tensor``.

        ``batch_size`` is unused: the input is already a batch.
        """
        # Create the noise on the input's device so GPU-resident batches also work.
        noise = (torch.rand(input_tensor.shape, device=input_tensor.device) * 2 - 1) * epsilon
        return input_tensor + noise

    def _save_approximator(self):
        """Write the approximator's state_dict to ``approximator_<type>.pth``."""
        torch.save(self.approximator.state_dict(), f"approximator_{self.approximator_type}.pth")
        print(f"Approximator saved to approximator_{self.approximator_type}.pth")

    def _load_approximator(self):
        """Load the approximator's state_dict from ``approximator_<type>.pth``.

        Raises:
            FileNotFoundError: If the checkpoint file does not exist.
        """
        # map_location lets a checkpoint saved on one device load on another.
        # NOTE(review): torch.load unpickles; only load checkpoints you trust.
        state = torch.load(f"approximator_{self.approximator_type}.pth", map_location=self.device)
        self.approximator.load_state_dict(state)
        print(f"Approximator loaded from approximator_{self.approximator_type}.pth")

    def _train_approximator(self):
        """Train ``self.approximator`` on noisy loader batches, then save it.

        Uses an L1 regression onto ``self.g``'s outputs. The learning rate is
        decayed by 10x roughly every third of ``train_epochs``.
        """
        model = self.approximator.to(self.device)
        criterion = nn.L1Loss()
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        # StepLR decays when last_epoch % step_size == 0, so step_size must be a
        # positive integer. A float such as 100/3 never divides an integer epoch.
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=max(1, self.train_epochs // 3), gamma=0.1)

        for epoch in range(self.train_epochs):
            total_loss = 0.0
            n_steps = 0
            for X, _ in self.data_loader:
                images = self._make_noisy_batch(X, 16, self.eps).to(self.device)
                outputs = model(images)
                with torch.no_grad():
                    transformed_outputs = self.g(images)
                loss = criterion(outputs, transformed_outputs)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                n_steps += 1

            avg_loss = total_loss / max(n_steps, 1)
            print(f"Epoch {epoch+1}/{self.train_epochs} - Loss: {avg_loss:.4f}")
            scheduler.step()

        print("Training finished!")

        self.approximator = model
        self._save_approximator()

    def _fine_tune_approximator(self, lr=1e-3):
        """Continue training ``self.approximator`` for ``ft_epochs`` epochs (not saved to disk)."""
        model = self.approximator.to(self.device)
        criterion = nn.L1Loss()
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

        for epoch in range(self.ft_epochs):
            total_loss = 0.0
            n_steps = 0
            for X, _ in self.data_loader:
                images = self._make_noisy_batch(X, 16, self.eps).to(self.device)
                outputs = model(images)
                with torch.no_grad():
                    transformed_outputs = self.g(images)
                loss = criterion(outputs, transformed_outputs)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                n_steps += 1

            avg_loss = total_loss / max(n_steps, 1)
            # Report progress against the fine-tuning epoch budget.
            print(f"Epoch {epoch+1}/{self.ft_epochs} - Loss: {avg_loss:.4f}")
            scheduler.step()

        print("Training finished!")
        self.approximator = model

    def attack(self, images, labels):
        """Return ``images`` plus an L-inf bounded adversarial perturbation.

        Args:
            images: Clean input batch.
            labels: Labels for ``images``, used as cross-entropy targets.

        Returns:
            Detached adversarial batch on ``self.device``.
        """
        # The models live on self.device, so the inputs must live there too.
        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)

        loss_fn = nn.CrossEntropyLoss()

        adv_noise = torch.zeros_like(images)
        for _ in range(self.steps):
            # Re-create the perturbation leaf each step; Adam therefore starts fresh.
            adv_noise = adv_noise.detach().requires_grad_()
            optimizer = torch.optim.Adam([adv_noise], lr=1.0)
            x = images + adv_noise
            f_x = self.f(self.approximation_function.apply(x))
            # Adam minimises, so negate the loss to ascend the cross-entropy.
            loss = -loss_fn(f_x, labels)
            # Gradient w.r.t. the noise only; leaves f's parameter .grad untouched.
            adv_noise.grad = torch.autograd.grad(loss, adv_noise)[0]
            old_adv_noise = adv_noise.clone().detach()
            optimizer.step()
            with torch.no_grad():
                adv_noise = torch.clamp(adv_noise, old_adv_noise - self.alpha, old_adv_noise + self.alpha)
                adv_noise = torch.clamp(adv_noise, -self.eps, self.eps)

        return (images + adv_noise).detach()