import numpy as np
from resnet_models import * 
import random
import torch.optim as optim
import argparse
import torchvision.transforms as transforms
import torchvision
import torchvision.models as models
from augmentation import get_augmentation
from save_output import save_output_from_dict
import itertools
from defenses import get_defense
# Run on the GPU when CUDA reports one as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Converter from image tensors back to PIL images.
to_pil = transforms.ToPILImage()

# Seed both Python's and PyTorch's global RNGs with a fixed value.
random.seed(0)
torch.manual_seed(0)


class poisonset(torch.utils.data.Dataset):
    '''Dataset wrapper that stamps a patch onto selected images of a base dataset.

    Only samples whose label is in ``classes_to_poison`` are ever patched.
    If ``num > 0``, ``num`` randomly chosen samples of those classes are
    patched (fewer if the classes contain fewer samples); if ``num <= 0``
    every sample of those classes is patched.

    ``__get_locations__`` must be called once before indexing the dataset,
    because it creates the per-sample patch positions read by ``__getitem__``.

    Attributes set by ``__init__``:
        trainset: the wrapped dataset; indexing it must yield ``(img, label)``.
        classes:  copied from ``trainset.classes``.
        pclasses: labels eligible for poisoning.
        patch:    tensor written into the image; dims 1 and 2 are its extent.
        num:      requested number of poisoned samples (``<= 0`` means all).
        indcs:    list of indices of the samples selected for poisoning.
                  Treat it as read-only: membership tests use a set built
                  from it at construction time.
    '''

    def __init__(self, trainset, classes_to_poison, patch, num):
        self.trainset = trainset
        self.classes = trainset.classes
        self.pclasses = classes_to_poison
        self.patch = patch
        # if self.num <= 0 => poison all images in the class
        self.num = num
        self.indcs = []
        all_samples = list(range(len(self.trainset)))
        random.shuffle(all_samples)
        # choose indices for poisoning
        for i in all_samples:
            # Only a positive budget limits the selection; num <= 0 (including
            # num == 0) means "all samples of the poisoned classes", so indcs
            # always matches what __getitem__ actually patches.
            if num > 0 and len(self.indcs) == num:
                break
            _, label = self.trainset[i]
            if label in self.pclasses:
                self.indcs.append(i)
        # O(1) membership test for __getitem__ (the list scan was O(len(indcs))
        # on every single item fetch).
        self._poison_lookup = frozenset(self.indcs)

    def __get_locations__(self, fixed = False):
        '''Fix the locations where the patches will be placed.

        Fills ``self.x_location`` / ``self.y_location`` with one offset per
        sample of the wrapped dataset (offsets along image dims 1 and 2).

        Args:
            fixed: if True every offset is 0; otherwise each offset is drawn
                with ``random.randint`` so the patch lies fully inside the image.
        '''
        n_samples = len(self.trainset)
        if fixed:
            self.x_location = [0] * n_samples
            self.y_location = [0] * n_samples
            return

        self.x_location = []
        self.y_location = []
        patch_h, patch_w = self.patch.shape[1], self.patch.shape[2]
        for i in range(n_samples):
            # Fetch the sample once; each fetch goes through the wrapped
            # dataset's __getitem__, so fetching per coordinate doubled the cost.
            img = self.trainset[i][0]
            self.x_location.append(random.randint(0, img.shape[1] - patch_h))
            self.y_location.append(random.randint(0, img.shape[2] - patch_w))

    def __len__(self):
        '''Return the number of samples in the wrapped dataset.'''
        return len(self.trainset)

    def __getitem__(self, idx):
        '''Return ``(img, label)`` for ``idx``, with the patch applied if selected.

        The patch is written at the location stored for ``idx`` by
        ``__get_locations__``.
        '''
        img, label = self.trainset[idx]

        if label in self.pclasses:
            # int() so that integer-like indices (e.g. numpy ints or 0-dim
            # tensors) hash like the plain ints stored in the lookup set.
            pos = int(idx)
            if self.num <= 0 or pos in self._poison_lookup:
                row = self.x_location[pos]
                col = self.y_location[pos]
                # Work on a copy so a wrapped dataset that hands out its own
                # stored tensors is not permanently modified.
                img = img.clone()
                img[:, row:row + self.patch.shape[1], col:col + self.patch.shape[2]] = self.patch

        return img, label

def create_patch(size):
    '''Return a random binary patch of shape (3, size, size).

    Each entry is sampled independently with torch.bernoulli using a
    probability of 0.5, so the result contains only 0s and 1s.
    '''
    probabilities = torch.full((3, size, size), 0.5)
    return torch.bernoulli(probabilities)

if __name__ == '__main__':

    def _str2bool(value):
        '''Parse a boolean command-line value.

        argparse's ``type=bool`` calls ``bool()`` on the raw string, so any
        non-empty string -- including "False" -- becomes True. This parser
        maps the usual spellings explicitly and rejects anything else.
        '''
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError('Expected a boolean value, got %r' % value)

    parser = argparse.ArgumentParser()

    parser.add_argument('--lr', default = 0.1, type = float, help="Learning Rate")
    parser.add_argument('--poison', default = False, type = _str2bool, help="Flag for poisoning images")
    parser.add_argument('--batch_size', default = 128, type = int, help="Batch size used during training")
    parser.add_argument('--num', default = -1, type = int, help="NUmber of images in one class to poison. If -1 all images will be poisoned")
    parser.add_argument('--num_epochs', default = 80, type = int, help="Number of training epochs")
    parser.add_argument('--defense', default = None, type = str, help="Defense name if not None")
    parser.add_argument('--augmentation', default = None, type = str, help="Augmentation name if not None")
    parser.add_argument('--patch_size', default = 4, type = int, help="Size of the patch in the backdoor attack")
    parser.add_argument('--n_runs', default = 20, type = int, help="number of runs for the experiment")
    parser.add_argument('--n_warmup', default = 5, type = int, help="number of epochs for warmup")
    parser.add_argument('--fix_patch', default = False, type = _str2bool, help="Flag to fix the position of patch on all images")
    parser.add_argument('--file_name', default = 'results.csv', type = str, help="File Name to write the results ")

    args = parser.parse_args()
    # The defense stage filters the poisoned training set, which only exists
    # when poisoning is enabled; fail before any training instead of hitting a
    # NameError after the feature-extractor stage.
    if args.defense and not args.poison:
        parser.error('--defense requires --poison True')

    # Epochs at which the learning rate is divided by 10.
    schedule = [30, 50, 70]

    # Generate a random binary patch of shape (3, patch_size, patch_size) and
    # normalize it with the same statistics as the images it is pasted onto.
    patch_base = create_patch(args.patch_size)
    patch_train_transform = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    patch_test_transform = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    # Inverse of the normalization above (not used further in this script).
    inv_normalize = transforms.Normalize(mean=[-0.4914/0.2023, -0.4822/0.1994, -0.4465/0.2010], std=[1/0.2023, 1/0.1994, 1/0.2010])
    patch_train = patch_train_transform(patch_base)
    patch_test = patch_test_transform(patch_base)

    # Data
    print('==> Preparing data..')
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    base_trainset = torchvision.datasets.CIFAR10(root='', train=True, download=True, transform=transform_train)
    base_testset = torchvision.datasets.CIFAR10(root='', train=False, download=True, transform=transform_test)

    classes = ('plane', 'car', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck')

    for run in range(args.n_runs):

        # Cycle through the classes so n_runs may exceed the number of classes
        # (the default of 20 runs used to raise IndexError on run 10).
        target_class = [run % len(classes)]
        victim_classes = list(range(len(classes)))

        # for each run we poison the right class in the dataset
        if args.poison:
            print('Poisoning data...')
            poison_trainset = poisonset(base_trainset, target_class, patch_train, num=args.num)
            poison_trainset.__get_locations__(fixed = args.fix_patch)
            poison_testset = poisonset(base_testset, victim_classes, patch_test, num=-1)
            poison_testset.__get_locations__(fixed = args.fix_patch)
            trainloader = torch.utils.data.DataLoader(poison_trainset, batch_size=args.batch_size, shuffle=True, num_workers=16)
            poison_testloader = torch.utils.data.DataLoader(poison_testset, batch_size=args.batch_size, shuffle=True, num_workers=16)
        else:
            print('Using clean data...')
            trainloader = torch.utils.data.DataLoader(base_trainset, batch_size=args.batch_size, shuffle=True, num_workers=16)

        base_testloader = torch.utils.data.DataLoader(base_testset, batch_size=args.batch_size, shuffle=True, num_workers=16)
        # Loader over the unpoisoned training set (not used further in this script).
        base_trainloader = torch.utils.data.DataLoader(base_trainset, batch_size=args.batch_size, shuffle=True, num_workers=16)

        # Model
        print('==> Building model..')
        net = ResNet18().to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(net.parameters(), lr=args.lr,
                          momentum=0.9, weight_decay=5e-4)

        if args.augmentation:
            augment = get_augmentation(args.augmentation)
        if args.defense:
            # Named defense_filter to avoid shadowing the builtin filter().
            defense_filter = get_defense(args.defense)

        def train(epoch, trainloader):
            '''Train the current `net` for one epoch over `trainloader`.

            Divides the learning rate by 10 at the epochs listed in
            `schedule`, and prints the summed batch loss and the training
            accuracy for the epoch.
            '''
            # schedule lr
            if epoch in schedule:
                for params in optimizer.param_groups:
                    params['lr'] = params['lr']/10.

            print('\nEpoch: %d' % epoch)
            net.train()
            train_loss = 0
            correct = 0
            total = 0
            for batch_idx, (inputs, targets) in enumerate(trainloader):
                inputs, targets = inputs.to(device), targets.to(device)

                optimizer.zero_grad()
                # Plain training when no augmentation is configured or during
                # warm-up. The truthiness test matches the condition under
                # which `augment` is defined above.
                if (not args.augmentation) or (epoch < args.n_warmup):
                    outputs = net(inputs)
                    loss = criterion(outputs, targets)
                else:
                    outputs, loss = augment(net, criterion, inputs, targets)

                loss.backward()
                optimizer.step()

                train_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

            print("Train Loss is {}, Train Acc is {}".format(train_loss, 100.*correct/total))

        def base_test():
            '''Return the accuracy (in percent) of `net` on the clean test set.'''
            net.eval()
            correct = 0
            total = 0
            with torch.no_grad():
                for batch_idx, (inputs, targets) in enumerate(base_testloader):
                    inputs, targets = inputs.to(device), targets.to(device)
                    outputs = net(inputs)
                    _, predicted = outputs.max(1)
                    total += targets.size(0)
                    correct += predicted.eq(targets).sum().item()

            print("Test clean Acc:", 100.*correct/total)
            return 100.*correct/total

        def poison_test():
            '''Evaluate `net` on the patched test set.

            Returns two numpy arrays indexed by true class, both in percent:
            the share of samples predicted as the target class, and the share
            predicted correctly.
            '''
            net.eval()
            poison_success = torch.zeros(len(classes)).to(device)
            model_success = torch.zeros(len(classes)).to(device)
            total = torch.zeros(len(classes)).to(device)
            with torch.no_grad():
                for batch_idx, (inputs, targets) in enumerate(poison_testloader):
                    inputs, targets = inputs.to(device), targets.to(device)
                    outputs = net(inputs)
                    _, predicted = outputs.max(1)
                    # Per true class: how many samples were sent to the target class.
                    poison_success += torch.bincount(targets[predicted == target_class[0]], minlength=len(classes))
                    # Per true class: how many samples were classified correctly.
                    model_success += torch.bincount(targets[predicted == targets], minlength=len(classes))

                    total += torch.bincount(targets, minlength=len(classes))

            poison_success = (100. * poison_success / total).cpu().numpy()
            model_success = (100. * model_success / total).cpu().numpy()
            print(f'Poison success: {poison_success}')
            print(f'Model success: {model_success}')
            return poison_success, model_success

        clean_test_acc = []
        poison_acc = []

        if args.defense:
            print('Training Feature Extractor for defense')
            for epoch in range(args.num_epochs):
                train(epoch, trainloader)
                clean_acc = base_test()

            net.eval()
            clean_indices, bad_indices = defense_filter(net, poison_trainset, num_poisons_expected = args.num)
            poison_indices = set(poison_trainset.indcs)
            # Guard against an empty poison set so the ratio cannot divide by zero.
            if poison_indices:
                proportion_unrecognized = len(poison_indices - set(bad_indices))/len(poison_indices)
            else:
                proportion_unrecognized = 0.0
            print('Proportion of unrecognized poisons is: ', proportion_unrecognized)
            # Retrain below only on the samples the defense kept.
            poison_trainset = torch.utils.data.Subset(poison_trainset, clean_indices)
            trainloader = torch.utils.data.DataLoader(poison_trainset, batch_size=args.batch_size, shuffle=True, num_workers=16)

        # define new model and optimizer
        net = ResNet18().to(device)
        optimizer = optim.SGD(net.parameters(), lr=args.lr,
                  momentum=0.9, weight_decay=5e-4)
        for epoch in range(args.num_epochs):
            train(epoch, trainloader)
            clean_acc = base_test()
            if args.poison:
                poison_success, model_success = poison_test()
            # ">=" keeps the final five epochs; ">" kept only four although
            # the summary below reports "Last 5".
            if epoch >= args.num_epochs - 5:
                clean_test_acc.append(clean_acc)
                if args.poison:
                    poison_acc.append(poison_success)

        print('Last 5 test accuracies: ', clean_test_acc)
        print('Last 5 poison success rates', poison_acc)

        results = {}
        results.update(vars(args))
        results['run'] = run
        results['target_class'] = classes[target_class[0]]
        results['test acc clean'] = np.array(clean_test_acc).mean()

        # Without poisoning there are no success rates to average; indexing the
        # mean of an empty array used to crash here. Record NaN so every run
        # writes the same set of keys.
        mean_poison_success = np.array(poison_acc).mean(0) if poison_acc else None
        for k, class_name in enumerate(classes):
            if mean_poison_success is None:
                results['poison success ' + class_name] = float('nan')
            else:
                results['poison success ' + class_name] = mean_poison_success[k]
        if args.defense:
            results['proportion of removed poisons'] = 1-proportion_unrecognized
            results['proportion of removed images'] = (len(base_trainset) - len(poison_trainset))/float(len(base_trainset))
        save_output_from_dict('performance', results, args.file_name)







