from MNIST.digits_model import *

import numpy as np
import matplotlib.pyplot as plt

import torch
import torchvision
import torch.optim as optim
from torch.autograd import Variable
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader, TensorDataset
from MNIST.data import CustomTensorDataset,NormalizeRangeTanh,UnNormalizeRangeTanh
from torchvision.utils import make_grid
import urllib
from torch.utils.data.dataloader import default_collate

# Module-level defaults.
# NOTE(review): none of these names is read by the code visible in this file
# (the DataLoaders below hard-code their own batch sizes and worker counts) --
# confirm whether other modules import them before changing or removing them.
num_workers = 0
batch_size = 20
basepath = 'some/base/path'
# create_data_loaders builds its own local `transform`, which shadows this one there.
transform = transforms.ToTensor()

def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    '''
    Function for visualizing images: shows the first `num_images` images of a batch
    in a uniform grid with 5 images per row.

    Parameters:
        image_tensor: batch of images accepted by torchvision's make_grid
            (presumably shaped (N, C, H, W) -- confirm against callers)
        num_images: maximum number of images taken from the start of the batch
        size: unused; kept so that existing callers passing it keep working
    '''
    images = image_tensor.detach().cpu()
    # The original passed range=(0.0, 1.0). That keyword only has an effect together
    # with normalize=True, and it was renamed to value_range and dropped from newer
    # torchvision releases, so leaving it out yields the same grid on every version.
    image_grid = make_grid(images[:num_images], nrow=5)
    # make_grid returns channels first; matplotlib expects channels last.
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()

    
class classifierFTest(): 
    def __init__(self, eps, adversarial=False, use_gpu=True):
        '''
        Record the run configuration and create empty placeholders.

        Parameters:
            eps: perturbation size handed to the attack methods
            adversarial: whether adversarial examples are used by the methods that check it
            use_gpu: whether batches are moved to the GPU by the methods that check it
        '''
        # Training history; the keys are fixed here and filled in by other methods.
        self.log = {
            'best_model': None,
            'train_loss': [],
            'val_loss': [],
            'train_accuracy': [],
            'val_accuracy': [],
        }

        # Configuration supplied by the caller.
        self.eps, self.adversarial, self.use_gpu = eps, adversarial, use_gpu

        # Everything below stays None until the matching create_* method is called.
        placeholders = (
            'model',
            'loss_function',
            'optimizer',
            'train_loader',
            'val_loader',
            'test_loader',
            'dataset',
        )
        for attribute in placeholders:
            setattr(self, attribute, None)
    
    def create_data_loaders(self, dataset):
        '''
        Build the train / validation / test DataLoaders for one of the supported datasets.

        The torchvision training split is divided in order: the first 80% of the samples
        form the training subset and the last 20% the validation subset. Images are
        resized to 20x20 and converted to tensors.

        NOTE(review): confirm that the model and the precomputed matrices used by the
        attack methods expect 20x20 inputs.

        Parameters:
            dataset: 'MNIST', 'FMNIST' or 'KMNIST'. Any other value leaves the loaders
                untouched; only self.dataset is updated.
        '''
        # Local imports keep this method self-contained: Subset is not in the file's
        # import list, and importing urllib.request explicitly removes the need for six.
        import urllib.request
        from torch.utils.data import Subset

        self.dataset = dataset

        # Some dataset mirrors reject the default urllib user agent.
        opener = urllib.request.build_opener()
        opener.addheaders = [('User-agent', 'Mozilla/5.0')]
        urllib.request.install_opener(opener)

        transform = transforms.Compose([
            torchvision.transforms.Resize((20, 20)),
            transforms.ToTensor(),
        ])

        # name -> (dataset class, root of the training split, root of the test split)
        sources = {
            'MNIST': (datasets.MNIST, '/home/ubuntu/datasets/', '/var/local/'),
            'FMNIST': (datasets.FashionMNIST, '/home/ubuntu/datasets/', '/home/ubuntu/datasets/'),
            'KMNIST': (datasets.KMNIST, '/home/ubuntu/datasets//', '/home/ubuntu/datasets/'),
        }

        if dataset in sources:
            dataset_cls, train_root, test_root = sources[dataset]

            train_set = dataset_cls(root=train_root, train=True, download=True, transform=transform)
            size = len(train_set)
            val_size = int(size * 0.2)
            split = size - val_size
            # torchvision datasets do not support slicing; Subset selects by index instead.
            train_subset = Subset(train_set, range(split))
            val_subset = Subset(train_set, range(split, size))

            self.train_loader = DataLoader(train_subset, batch_size=256, shuffle=True, num_workers=8)
            self.val_loader = DataLoader(val_subset, batch_size=256, shuffle=True, num_workers=8)

            test_set = dataset_cls(root=test_root, train=False, download=True, transform=transform)
            self.test_loader = DataLoader(test_set, batch_size=128, shuffle=False, num_workers=8)

        print(self.train_loader)
    def visualize_single_batch(self):
        '''
        Display the first batch drawn from the training loader.

        Without self.adversarial the raw images are shown. With it, two grids are shown:
        the batch after SSIM_attack, then the batch after fast_pgd, both using self.eps.
        An empty loader shows nothing.
        '''
        first_batch = next(iter(self.train_loader), None)
        if first_batch is None:
            return

        inputs, labels = first_batch
        if self.use_gpu:
            inputs, labels = inputs.cuda(), labels.cuda()

        if not self.adversarial:
            show_tensor_images(inputs)
            return

        ssim_images, labels = self.SSIM_attack(inputs, labels, self.eps)
        show_tensor_images(ssim_images)
        pgd_images, labels = self.fast_pgd(inputs, labels, self.eps)
        show_tensor_images(pgd_images)
        
    def create_model(self):
        '''
        Instantiate the LeNet5 classifier and store it in self.model.

        The model is moved to the GPU only when self.use_gpu is set; previously it was
        moved unconditionally, which ignored the flag and failed on CPU-only machines.
        '''
        model = LeNet5()
        if self.use_gpu:
            model = model.cuda()
        self.model = model
    
    def create_loss_function(self):
        self.loss_function = nn.CrossEntropyLoss()
        #if self.use_gpu:
        #    self.loss_function.type(torch.cuda.FloatTensor)
        
    def get_gradient(self, outputs, inputs):
        # Take the gradient of the scores with respect to the images
        gradient = torch.autograd.grad(
            # Note: You need to take the gradient of outputs with respect to inputs.
            # This documentation may be useful, but it should not be necessary:
            # https://pytorch.org/docs/stable/autograd.html#torch.autograd.grad
            #### START CODE HERE ####
            inputs=inputs,
            outputs=outputs,
            #### END CODE HERE ####
            # These other parameters have to do with the pytorch autograd engine works
            grad_outputs=torch.ones_like(outputs), 
            create_graph=True,
            retain_graph=True,
        )[0]
        return gradient

    def gradient_penalty(self,gradient):
        '''
        Return the gradient penalty, given a gradient.
        Given a batch of image gradients, you calculate the magnitude of each image's gradient
        and penalize the mean quadratic distance of each magnitude to 1.
        Parameters:
            gradient: the gradient of the critic's scores, with respect to the mixed image
        Returns:
            penalty: the gradient penalty
        '''
        # Flatten the gradients so that each row captures one image
        gradient = gradient.view(len(gradient), -1)

        # Calculate the magnitude of every row
        gradient_norm = gradient.norm(2, dim=1)

        # Penalize the mean squared distance of the gradient norms from 1
        #### START CODE HERE ####
        loss = torch.nn.MSELoss()
        penalty = loss(gradient_norm,torch.ones_like(gradient_norm))
        #### END CODE HERE ####
        return penalty

    def create_optimizer(self):
        #self.optimizer = optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)
        self.optimizer = optim.Adam(self.model.parameters(), lr = 0.001)
    
    def train_model(self, start_epoch, num_epochs, **kwargs):
        '''
        Train self.model for `num_epochs` epochs, then report accuracy on the test loader.

        After every epoch the training loss and accuracy are printed and appended to
        self.log, the log (including the current state_dict under 'best_model') is saved
        under ./models/, and the accuracy on the validation loader is printed.

        From the second epoch of a call onwards (epoch index > 0), and only when
        self.adversarial is set, each batch is replaced by adversarial examples produced
        by SSIM_attack or SSIM_rev_attack, chosen at random per batch.

        Parameters:
            start_epoch: offset added to the epoch index in messages and checkpoint names
            num_epochs: number of epochs to run
            **kwargs: accepted and ignored
        '''
        import os  # local import: only needed to create the checkpoint directory

        for epoch in range(num_epochs):  # loop over the dataset multiple times
            running_loss = 0.0
            correct = 0
            total = 0
            for inputs, labels in self.train_loader:
                if self.use_gpu:
                    inputs = inputs.cuda()
                    labels = labels.cuda()

                inputs.requires_grad = True

                if epoch > 0 and self.adversarial:
                    if torch.randn(1) > 0:
                        batch_inputs, targets = self.SSIM_attack(inputs, labels, self.eps)
                    else:
                        batch_inputs, targets = self.SSIM_rev_attack(inputs, labels, self.eps)
                else:
                    batch_inputs, targets = inputs, labels

                # Clear gradients only after the adversarial batch has been built, so that
                # anything an attack may have accumulated on the model parameters cannot
                # leak into this optimisation step.
                self.optimizer.zero_grad()

                # forward + backward + optimize
                outputs = self.model(batch_inputs)
                loss = self.loss_function(outputs, targets)
                loss.backward()
                self.optimizer.step()

                running_loss += loss.cpu().detach().numpy()
                total += labels.size(0)
                _, predicted = torch.max(outputs.data, 1)
                correct += (predicted == labels.data).sum()

            correct = 1. * correct / total
            running_loss = running_loss / len(self.train_loader)
            print('[%dth epoch]' % (start_epoch+epoch))
            print('training loss: %.4f   accuracy: %.3f%%' % (running_loss, 100 * correct))
            self.log['train_loss'].append(running_loss)
            self.log['train_accuracy'].append(correct)
            self.log['best_model'] = self.model.state_dict()

            # torch.save does not create missing directories.
            os.makedirs('./models/', exist_ok=True)
            checkpoint = './models/'+self.dataset+"_lenet_" + str(start_epoch+epoch) + '.tar'
            torch.save(self.log, checkpoint)

            acc = self.test_model(self.val_loader)
            print('validation accuracy: %.3f%%' % (100 * acc))
        print('Finished Training')

        acc = self.test_model(self.test_loader)
        print('Test accuracy: %.3f%%' % (100 * acc))
   
    def test_model(self, dataloader):
        with torch.no_grad():
            running_loss = 0.0
            correct = 0
            total = 0
            for i, data in enumerate(dataloader, 0):
                inputs, labels = data
                #inputs = torch.cat((inputs, inputs, inputs), 1)    

                inputs = inputs.cuda()
                labels = labels.cuda()

                outputs = self.model(inputs)
                loss = self.loss_function(outputs, labels)
                running_loss += loss.cpu().detach().numpy()
                total += labels.size(0)
                _, predicted = torch.max(outputs.data, 1)
                correct += (predicted == labels.data).sum()

            correct = 1. * correct / total
            running_loss = running_loss / len(self.test_loader)
    #         self.log['val_loss'].append(running_loss)
    #         self.log['val_accuracy'].append(correct)
        return correct
   
        
    
    def SSIM_attack(self, x_batch, y_batch, eps):
        '''
        Perturb every image by `eps` along a per-class preconditioned loss gradient.

        For each class c in 0..9, the loss gradient of the samples labelled c is passed
        through torch.cholesky_solve with fixed_hessians[c], rescaled to unit L2 norm per
        sample, multiplied by eps and added to the image. The whole batch is then clamped
        to [0, 1]. Classes that do not occur in the batch are skipped.

        NOTE(review): torch.cholesky_solve expects a Cholesky factor (lower-triangular by
        default) as its second argument -- confirm that 'fixed_hessians.pt' stores factors
        rather than the Hessians themselves.

        Parameters:
            x_batch: batch of images; each sample must flatten to the side length of the
                matrices in fixed_hessians
            y_batch: integer class labels
            eps: L2 length of the perturbation applied to each sample
        Returns:
            (x, y_batch): the perturbed images as a leaf tensor with requires_grad=True,
                and the labels, both on the model's device
        '''
        device = next(self.model.parameters()).device

        # Load the matrices once per instance instead of once per batch.
        fixed_hessians = getattr(self, '_fixed_hessians_cache', None)
        if fixed_hessians is None:
            # NOTE: torch.load unpickles the file -- only use files from a trusted source.
            fixed_hessians = torch.load('fixed_hessians.pt', map_location=device)
            self._fixed_hessians_cache = fixed_hessians

        y_batch = y_batch.to(device)
        # Move to the device before enabling grad so that x stays a leaf tensor.
        x = x_batch.clone().detach().to(device).requires_grad_(True)

        loss = self.loss_function(self.model(x), y_batch)
        # autograd.grad yields the input gradient without accumulating anything into the
        # .grad fields of the model parameters.
        grad = torch.autograd.grad(loss, x)[0]

        x_adv = x.detach().clone()
        for c in range(10):
            indices = torch.where(y_batch == c)
            size = len(indices[0])
            if size == 0:
                continue

            hessian_fixed = fixed_hessians[c].to(device).unsqueeze(0).repeat(size, 1, 1)
            x_grad = grad[indices].reshape((size, -1)).unsqueeze(2)

            update = torch.cholesky_solve(x_grad, hessian_fixed)
            # Per-sample L2 normalisation; the clamp avoids dividing by zero.
            norms = torch.linalg.norm(update, dim=1, keepdim=True).clamp_min(1e-12)
            # Reshape to the samples' own shape instead of a hard-coded 28x28.
            update = (update / norms).reshape(x_adv[indices].shape)

            x_adv[indices] = x_adv[indices] + eps * update

        x_adv.clamp_(min=0.0, max=1.0)
        return x_adv.requires_grad_(True), y_batch
    
    def SSIM_rev_attack(self, x_batch, y_batch, eps):
        '''
        Perturb every image by `eps` along the loss gradient multiplied by a per-class matrix.

        For each class c in 0..9, the loss gradient of the samples labelled c is multiplied
        by fixed_hessians[c], rescaled to unit L2 norm per sample, multiplied by eps and
        added to the image; the updated samples are clamped to [0, 1]. Classes that do not
        occur in the batch are skipped.

        Parameters:
            x_batch: batch of images; each sample must flatten to the side length of the
                matrices in fixed_hessians
            y_batch: integer class labels
            eps: L2 length of the perturbation applied to each sample
        Returns:
            (x, y_batch): the perturbed images as a leaf tensor with requires_grad=True,
                and the labels, both on the model's device
        '''
        device = next(self.model.parameters()).device

        # Load the matrices once per instance instead of once per batch.
        fixed_hessians = getattr(self, '_fixed_hessians_cache', None)
        if fixed_hessians is None:
            # NOTE: torch.load unpickles the file -- only use files from a trusted source.
            fixed_hessians = torch.load('fixed_hessians.pt', map_location=device)
            self._fixed_hessians_cache = fixed_hessians

        y_batch = y_batch.to(device)
        # Move to the device before enabling grad so that x stays a leaf tensor.
        x = x_batch.clone().detach().to(device).requires_grad_(True)

        loss = self.loss_function(self.model(x), y_batch)
        # autograd.grad yields the input gradient without accumulating anything into the
        # .grad fields of the model parameters.
        grad = torch.autograd.grad(loss, x)[0]

        x_adv = x.detach().clone()
        for c in range(10):
            indices = torch.where(y_batch == c)
            size = len(indices[0])
            if size == 0:
                continue

            hessian_fixed = fixed_hessians[c].to(device).unsqueeze(0).repeat(size, 1, 1)
            x_grad = grad[indices].reshape((size, -1)).unsqueeze(2)

            update = torch.bmm(hessian_fixed, x_grad)
            # Per-sample L2 normalisation; the clamp avoids dividing by zero.
            norms = torch.linalg.norm(update, dim=1, keepdim=True).clamp_min(1e-12)
            # Reshape to the samples' own shape instead of a hard-coded 28x28.
            update = (update / norms).reshape(x_adv[indices].shape)

            x_adv[indices] = (x_adv[indices] + eps * update).clamp(min=0.0, max=1.0)

        return x_adv.requires_grad_(True), y_batch
        
    def fast_pgd(self, x_batch, y_batch, eps, max_iter=10, clamp_min=-1.0, clamp_max=1.0):
        """
        Generates adversarial examples by iterated gradient ascent on the loss.

        Every iteration adds eps times the loss gradient, scaled to unit L2 norm over the
        whole batch, and clamps the result to [clamp_min, clamp_max].

        NOTE(review): the gradient is normalised over the whole batch rather than per
        sample, and there is no projection back onto an eps-ball around the original
        input -- confirm both are intended.
        NOTE(review): the default clamp range [-1, 1] is unchanged from before; for inputs
        scaled to [0, 1] pass clamp_min=0.0.

        Input:
            - x_batch : batch images to compute adversaries
            - y_batch : labels of the batch
            - eps : step size of every iteration
            - max_iter : # of iterations to generate adversarial example
            - clamp_min, clamp_max : valid value range of the inputs

        Output:
            - x : batch containing adversarial examples (leaf tensor with requires_grad=True)
            - y_batch : the labels, on the model's device
        """
        device = next(self.model.parameters()).device
        y_batch = y_batch.to(device)
        # Move to the device before enabling grad so that x stays a leaf tensor.
        x = x_batch.clone().detach().to(device).requires_grad_(True)

        alpha = eps

        for _ in range(max_iter):
            loss = self.loss_function(self.model(x), y_batch)
            # autograd.grad yields the input gradient without accumulating anything into
            # the .grad fields of the model parameters.
            noise = torch.autograd.grad(loss, x)[0]
            # The clamp avoids dividing by zero when the gradient vanishes.
            noise = noise / torch.norm(noise).clamp_min(1e-12)

            # Step and clamp to the valid range, outside of autograd tracking.
            x.data = (x.data + alpha * noise).clamp_(min=clamp_min, max=clamp_max)

        return x, y_batch
    
    def log_losses(self, train_acc, val_acc):
        '''
        Appends the given training and validation accuracy to the log.

        The values go to the 'train_accuracy' and 'val_accuracy' lists that __init__
        creates (the previous keys 'train_acc' / 'val_acc' were never created, so every
        call raised KeyError).

        Parameters:
            train_acc: training accuracy to record
            val_acc: validation accuracy to record
        '''
        self.log.setdefault('train_accuracy', []).append(train_acc)
        self.log.setdefault('val_accuracy', []).append(val_acc)
        
    def log_best_model(self):
        '''
        Writes a deep copy of the current model to the log under 'best_model'.

        The copy is independent of self.model, so later training does not change it.
        '''
        # Local import: `copy` is not among this file's explicit imports.
        import copy

        self.log['best_model'] = copy.deepcopy(self.model)
 