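"""Validate poisons crafted via our method by training CIFAR-10 models on
    them and reporting the resulting clean test accuracy. Code adapted from a
    popular CIFAR-10 training repository (see the main body of the paper).
"""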

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import numpy as np

import torchvision
import torchvision.transforms as transforms

import os
import argparse

from PIL import Image

from models import *


class CIFAR_load(torch.utils.data.Dataset):
    """Training dataset that serves (poisoned) CIFAR-10 images from disk.

    If ``root`` is the literal string ``'~/data'``, the dataset forwards every
    request to ``baseset``, after installing the training augmentation on it.

    Otherwise every file in ``<root>/data`` is one sample. The part of its
    filename before the first ``.`` must be the integer index of the matching
    image in ``baseset``. That index is used only to look up the label; the
    pixels come from the file on disk.

    Args:
        root: directory holding a ``data`` sub-folder of images, or
            ``'~/data'`` to use ``baseset`` as-is.
        baseset: clean dataset that supplies the labels.
        dummy_root, split, download, **kwargs: accepted for interface
            compatibility; unused.
    """

    def __init__(self, root, baseset, dummy_root='~/data', split='train', download=False, **kwargs):
        self.baseset = baseset
        self.root = root
        # Build the augmentation once so both branches can use it. Previously it
        # was only assigned in the first branch, so the '~/data' branch raised
        # UnboundLocalError.
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
        if root != '~/data':
            self.transform = transform_train
            # Sorted so the idx -> file mapping does not depend on filesystem order.
            self.samples = sorted(os.listdir(os.path.join(root, 'data')))
        else:
            self.baseset.transform = transform_train

    def __len__(self):
        if self.root == '~/data':
            return len(self.baseset)
        return len(self.samples)

    def _label(self, index):
        """Return the label of ``baseset[index]``.

        When the base set exposes a ``targets`` list and has no
        ``target_transform``, the label is read directly. This avoids decoding
        and augmenting the clean image only to discard it. Otherwise it falls
        back to indexing the base set.
        """
        targets = getattr(self.baseset, 'targets', None)
        if targets is not None and getattr(self.baseset, 'target_transform', None) is None:
            return targets[index]
        _, label = self.baseset[index]
        return label

    def __getitem__(self, idx):
        if self.root == '~/data':
            return self.baseset[idx]
        filename = self.samples[idx]
        true_index = int(filename.split('.')[0])
        # The context manager closes the file handle. convert('RGB') forces the
        # pixel data to load before close and guarantees the 3 channels
        # expected by Normalize.
        with Image.open(os.path.join(self.root, 'data', filename)) as raw:
            image = raw.convert('RGB')
        return self.transform(image), self._label(true_index)

parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--runs', default=10, type=int, help='num runs')
parser.add_argument('--epochs', default=100, type=int, help='num epochs')
# choices= makes a typo fail immediately with a usage message, instead of a
# NameError after the datasets have been loaded. argparse does not check the
# default against choices, so omitting --net still parses.
parser.add_argument('--net', default='', type=str,
                    choices=['ResNet', 'VGG', 'GoogLeNet', 'DenseNet', 'MobileNet'],
                    help='architecture to train')
parser.add_argument('--load_path', type=str, default='',
                    help='directory whose data/ sub-folder holds the poisoned training '
                         'images; empty trains on clean CIFAR-10')
parser.add_argument('--scheduler', type=str, default='cosine',
                    choices=['cosine', 'linear'],
                    help="LR schedule: 'cosine' annealing or 'linear' (a 3-step MultiStepLR)")
parser.add_argument('--write', type=str, default='',
                    help='if set, append a results row to <write>.txt')
# NOTE(review): --resume is parsed but never read in the code shown here.
parser.add_argument('--resume', '-r', action='store_true',
                    help='resume from checkpoint')
args = parser.parse_args()
print(args)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

# Data
print('==> Preparing data..')
# Random crop + flip augmentation for training.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Deterministic pipeline for every evaluation pass.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Clean training set with augmentation. It is trained on directly when
# --load_path is empty, and otherwise supplies the labels for CIFAR_load.
clean_trainset = torchvision.datasets.CIFAR10(
    root='~/data', train=True, download=False, transform=transform_train)
# The same images with the deterministic transform. Clean-train accuracy is
# then measured on the actual images rather than random crops/flips of them.
clean_trainset_eval = torchvision.datasets.CIFAR10(
    root='~/data', train=True, download=False, transform=transform_test)
clean_trainloader = torch.utils.data.DataLoader(
    clean_trainset_eval, batch_size=128, shuffle=False, num_workers=2)

if args.load_path:
    trainset = CIFAR_load(args.load_path, clean_trainset)
else:
    trainset = clean_trainset

trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=2)

# Same root as the training split (previously './data'), so the CIFAR-10
# archive already in ~/data is reused instead of downloaded a second time.
testset = torchvision.datasets.CIFAR10(
    root='~/data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

# Model
def _build_net(name):
    """Instantiate the architecture selected by ``--net``.

    Raises:
        ValueError: if ``name`` is not a supported architecture. Previously
            this surfaced later as a NameError on ``net``.
    """
    if name == 'ResNet':
        return ResNet18()
    if name == 'VGG':
        return VGG('VGG19')
    if name == 'GoogLeNet':
        return GoogLeNet()
    if name == 'DenseNet':
        return DenseNet121()
    if name == 'MobileNet':
        return MobileNetV2()
    raise ValueError(f"Unknown --net {name!r}; expected one of "
                     "'ResNet', 'VGG', 'GoogLeNet', 'DenseNet', 'MobileNet'")


def _build_scheduler(optimizer, name, epochs):
    """Create the LR scheduler selected by ``--scheduler``.

    'cosine' anneals over ``epochs`` scheduler steps. 'linear' is, despite its
    name, a step schedule: it multiplies the LR by 0.1 at about 37.5%, 62.5%
    and 87.5% of ``epochs``.

    Raises:
        ValueError: for any other ``name``. Previously this left ``scheduler``
            undefined.
    """
    if name == 'cosine':
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    if name == 'linear':
        # int(): floor division by a float yields floats; epochs are integers.
        milestones = [int(epochs // 2.667), int(epochs // 1.6), int(epochs // 1.142)]
        return torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)
    raise ValueError(f"Unknown --scheduler {name!r}; expected 'cosine' or 'linear'")


def _train_one_epoch(net, loader, criterion, optimizer):
    """Run one optimisation pass over ``loader``.

    Returns:
        (sum of per-batch losses, training accuracy in percent).
    """
    net.train()
    train_loss = 0.0
    correct = 0
    total = 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
    return train_loss, (100. * correct / total if total else 0.0)


def _evaluate(net, loader):
    """Return the top-1 accuracy (percent) of ``net`` on ``loader`` in eval mode."""
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    return 100. * correct / total if total else 0.0


test_accs = []            # per run: test accuracy at every evaluated epoch
clean_trainset_accs = []  # per run: clean-train accuracy after the last epoch
final_accs = []           # per run: test accuracy after the last epoch
last_epoch = start_epoch + args.epochs - 1
for i in range(args.runs):
    print('==> Building model..')
    net = _build_net(args.net)
    if device == 'cuda':
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True
    net = net.to(device)
    test_accs.append([])

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=args.lr,
                          momentum=0.9, weight_decay=5e-4)
    scheduler = _build_scheduler(optimizer, args.scheduler, args.epochs)

    for epoch in range(start_epoch, start_epoch + args.epochs):
        _train_one_epoch(net, trainloader, criterion, optimizer)
        # Evaluate every 10 epochs AND after the last epoch. Previously the
        # reported "final" accuracy came from the last multiple of 10
        # (e.g. epoch 90 of 100), not from the fully trained model.
        if epoch % 10 == 0 or epoch == last_epoch:
            test_accs[i].append(_evaluate(net, testloader))
            print(test_accs)
        # Previously hard-coded to epoch 199, so it never ran for shorter schedules.
        if epoch == last_epoch:
            clean_trainset_accs.append(_evaluate(net, clean_trainloader))
        scheduler.step()
    final_accs = [accs[-1] for accs in test_accs if accs]
    # Standard error over the runs completed so far (not args.runs).
    print(f'{args.net} Mean {np.mean(np.array(final_accs))}, Std_error: {np.std(np.array(final_accs))/np.sqrt(len(final_accs))}')
    print(f'Clean trainset accs: {clean_trainset_accs}')

# Append one CSV-style results row to <write>.txt, creating it with a header
# the first time.
if args.write:
    out_path = f'{args.write}.txt'
    # Only emit the header when the file is being created, so repeated
    # invocations accumulate rows under a single header line.
    write_header = not os.path.exists(out_path)
    mean_acc = np.mean(np.array(final_accs))
    std_error = np.std(np.array(final_accs)) / np.sqrt(args.runs)
    # The context manager guarantees the rows are flushed and the file closed.
    with open(out_path, 'a+') as f:
        if write_header:
            f.write('Net, epochs, runs, acc, std_error, path, scheduler \n')
        f.write(f'{args.net}, {args.epochs}, {args.runs}, {mean_acc}, {std_error}, {args.load_path}, {args.scheduler}\n')
