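'''
This is example code for poisoned training with the ISSBA backdoor attack.
The dataset is selected with --dataset: MNIST, CIFAR-10 or GTSRB ("celeba" is accepted by the option parser, but no training routine is run for it).
The attack method is ISSBA.
'''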


import os
import argparse
import numpy as np
import csv 

from PIL import Image

import cv2
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import torchvision
from torchvision.datasets import DatasetFolder
from torchvision import transforms
from torchvision.transforms import Compose, ToTensor, PILToTensor, RandomHorizontalFlip, Normalize

import core


# Seed used by this script: applied to torch below and passed to core.ISSBA as `seed`.
global_seed = 666
# Passed to core.ISSBA as `deterministic` by the train_on_* functions.
deterministic = False
# Seeds torch's RNG only; the np.random calls in this file are not seeded here.
torch.manual_seed(global_seed)


def get_transform(opt, train=True, attack=False, pretensor_transform=False):
    """Build the torchvision preprocessing pipeline for one dataset split.

    Args:
        opt: Options object; ``opt.input_height`` and ``opt.input_width``
            give the resize target.
        train (bool): When true, ``RandomHorizontalFlip`` is appended after
            ``ToTensor``.
        attack: Accepted but not used by this function.
        pretensor_transform: Accepted but not used by this function.

    Returns:
        Compose: ``Resize`` -> ``ToTensor`` (-> ``RandomHorizontalFlip`` when
        ``train`` is true). No normalisation step is applied.
    """
    steps = [
        transforms.Resize((opt.input_height, opt.input_width)),
        ToTensor(),
    ]
    if train:
        # Augmentation is only added for the training split.
        steps.append(RandomHorizontalFlip())
    return Compose(steps)


class GTSRB(Dataset):
    """GTSRB dataset whose samples are listed in ';'-separated CSV files.

    Image paths and integer labels are collected once at construction time;
    images themselves are opened lazily in ``__getitem__``.

    Args:
        opt: Options object; ``opt.data_root`` is the directory containing
            the ``GTSRB`` folder.
        train (bool): Read ``GTSRB/Train`` (one CSV per class folder) when
            true, otherwise ``GTSRB/Test`` (single ``GT-final_test.csv``).
        transforms: Callable applied to each opened PIL image.
    """

    def __init__(self, opt, train, transforms):
        super(GTSRB, self).__init__()
        if train:
            self.data_folder = os.path.join(opt.data_root, "GTSRB/Train")
            self.images, self.labels = self._get_data_train_list()
        else:
            self.data_folder = os.path.join(opt.data_root, "GTSRB/Test")
            self.images, self.labels = self._get_data_test_list()

        self.transforms = transforms

    @staticmethod
    def _read_annotations(csv_path, image_prefix):
        """Parse one annotation CSV into (image paths, labels).

        The first row is treated as a header and skipped. Column 0 holds the
        image file name (prefixed with ``image_prefix``) and column 7 the
        integer class label. The file is always closed, even on error.
        """
        images = []
        labels = []
        # newline="" is what the csv module expects for files it reads.
        with open(csv_path, newline="") as gt_file:
            reader = csv.reader(gt_file, delimiter=";")
            next(reader)  # skip the header row
            for row in reader:
                if not row:
                    # Blank lines (e.g. a trailing one) carry no sample.
                    continue
                images.append(image_prefix + row[0])
                labels.append(int(row[7]))
        return images, labels

    def _get_data_train_list(self):
        """Collect samples from the 43 per-class folders ``00000`` .. ``00042``."""
        images = []
        labels = []
        for c in range(0, 43):
            class_id = format(c, "05d")
            prefix = self.data_folder + "/" + class_id + "/"
            class_images, class_labels = self._read_annotations(
                prefix + "GT-" + class_id + ".csv", prefix
            )
            images.extend(class_images)
            labels.extend(class_labels)
        return images, labels

    def _get_data_test_list(self):
        """Collect samples listed in ``GT-final_test.csv`` of the test folder."""
        csv_path = os.path.join(self.data_folder, "GT-final_test.csv")
        return self._read_annotations(csv_path, self.data_folder + "/")

    def __len__(self):
        """Return the number of listed samples."""
        return len(self.images)

    def __getitem__(self, index):
        """Open the image at ``index``, apply the transforms, return (image, label)."""
        image = Image.open(self.images[index])
        image = self.transforms(image)
        label = self.labels[index]
        return image, label


def get_dataset(opt, train=True):
    """Create the dataset object selected by ``opt.dataset``.

    Args:
        opt: Options object; reads ``opt.dataset`` and ``opt.data_root``
            (plus whatever ``get_transform`` reads from it).
        train (bool): Select the training split when true, else the test split.

    Returns:
        The dataset instance for "mnist", "gtsrb", "cifar10" or "celeba".

    Raises:
        Exception: If ``opt.dataset`` is none of the names above.
    """
    name = opt.dataset

    if name == "mnist":
        # download=True is passed to torchvision for MNIST.
        return torchvision.datasets.MNIST(
            root=opt.data_root,
            train=train,
            transform=get_transform(opt, train),
            download=True,
        )

    if name == "gtsrb":
        return GTSRB(opt, train=train, transforms=get_transform(opt, train))

    if name == "cifar10":
        # download=True is passed to torchvision for CIFAR-10.
        return torchvision.datasets.CIFAR10(
            root=opt.data_root,
            train=train,
            transform=get_transform(opt, train),
            download=True,
        )

    if name == "celeba":
        # NOTE(review): CelebA_attr and ToNumpy are not defined or imported in
        # the visible part of this file -- confirm they exist before relying
        # on this branch.
        split = "train" if train else "test"
        resize = transforms.Resize((opt.input_height, opt.input_width))
        return CelebA_attr(
            opt,
            split,
            transforms=transforms.Compose([resize, ToNumpy()]),
        )

    raise Exception("Invalid dataset")


class GetPoisonedDataset(torch.utils.data.Dataset):
    """Dataset view over two parallel lists.

    Each item is returned as a pair of ``torch.FloatTensor`` objects built
    from the entries stored at the same index.

    Args:
        data_list (list): the list of data.
        labels (list): the list of label.
    """

    def __init__(self, data_list, labels):
        self.data_list, self.labels = data_list, labels

    def __len__(self):
        """Return the number of entries in ``data_list``."""
        return len(self.data_list)

    def __getitem__(self, index):
        """Return ``(data, label)`` at ``index``, both converted to FloatTensor."""
        return (
            torch.FloatTensor(self.data_list[index]),
            torch.FloatTensor(self.labels[index]),
        )


def get_config(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv (list[str] | None): Argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what a
            call without arguments has always done.

    Returns:
        argparse.Namespace: The parsed options.
    """
    parser = argparse.ArgumentParser(description='PyTorch Backdoor Training')

    parser.add_argument('-n', '--net', default='res18', type=str,
                        help='network structure choice')
    # The help text states the actual default (0).
    parser.add_argument('-j', '--workers', default=0, type=int, metavar='N',
                        help='number of data loading workers (default: 0)')

    # Optimization options
    parser.add_argument('--epochs', default=50, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('--train_batch', default=32, type=int, metavar='N',
                        help='train batchsize')
    parser.add_argument('--test_batch', default=32, type=int, metavar='N',
                        help='test batchsize')
    parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
                        metavar='LR', help='initial learning rate')
    parser.add_argument('--schedule', type=int, nargs='+', default=[150, 250],
                        help='Decrease learning rate at these epochs.')
    parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma on schedule.')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--weight_decay', '--wd', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)')

    # Checkpoints
    parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH',
                        help='path to save checkpoint (default: checkpoint)')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')

    # Miscs
    parser.add_argument('--manualSeed', type=int, help='manual seed')
    # Device options
    parser.add_argument('--gpu-id', default='0', type=str,
                        help='id(s) for CUDA_VISIBLE_DEVICES')

    # data path
    parser.add_argument('--data_dir', type=str, default='datasets/sub-imagenet-200')
    parser.add_argument('--bd_data_dir', type=str, default='datasets/sub-imagenet-200-bd/inject_a/')

    # backdoor setting
    parser.add_argument('--bd_label', type=int, default=0, help='backdoor label.')
    parser.add_argument('--bd_ratio', type=float, default=0.001, help='backdoor training sample ratio.')

    parser.add_argument("--data_root", type=str, default="/root/projects/AttackDefence/data/AttackDefence") # put data under this folder
    parser.add_argument("--checkpoints", type=str, default="/root/projects/AttackDefence/checkpoints") # save models in this folder
    parser.add_argument("--device", type=str, default="gpu")
    parser.add_argument("--dataset", type=str, default="mnist")

    opt = parser.parse_args(argv)
    return opt

opt = get_config()

# Attach the input geometry (height, width, channels) of the selected dataset
# to the options object; unknown dataset names abort the script.
if opt.dataset in ("cifar10", "gtsrb"):
    opt.input_height, opt.input_width, opt.input_channel = 32, 32, 3
elif opt.dataset == "mnist":
    opt.input_height, opt.input_width, opt.input_channel = 28, 28, 1
elif opt.dataset == "celeba":
    opt.input_height, opt.input_width, opt.input_channel = 64, 64, 3
else:
    raise Exception("Invalid Dataset")


def train_on_mnist(opt):
    """Configure ``core.ISSBA`` for MNIST and call its ``train`` method.

    Args:
        opt: Parsed options; ``opt.bd_ratio`` is used as the poisoned rate and
            the object is forwarded to ``get_dataset``.
    """
    trainset = get_dataset(opt, train=True)
    testset = get_dataset(opt, train=False)

    secret_size = 20

    # Pair every image of both splits (train first, then test) with a freshly
    # drawn random bit list of length ``secret_size``.
    steg_images = []
    steg_secrets = []
    for split in (trainset, testset):
        for img, _ in split:
            steg_images.append(img.tolist())
            steg_secrets.append(np.random.binomial(1, .5, secret_size).tolist())

    train_steg_set = GetPoisonedDataset(steg_images, steg_secrets)

    # Training schedule handed to core.ISSBA (constructor and train()).
    schedule = dict(
        device='GPU',
        CUDA_VISIBLE_DEVICES='0',
        GPU_num=1,

        benign_training=False,
        batch_size=1000,
        num_workers=8,

        lr=0.1,
        momentum=0.9,
        weight_decay=5e-4,
        gamma=0.1,
        schedule=[30, 50],

        epochs=100,

        log_iteration_interval=100,
        test_epoch_interval=10,
        save_epoch_interval=100,

        save_dir='experiments',
        experiment_name='train_poison_DataFolder_MNIST_ISSBA',
    )

    # Configure the attack scheme
    attack = core.ISSBA(
        dataset_name="mnist",
        train_dataset=trainset,
        test_dataset=testset,
        train_steg_set=train_steg_set,
        model=core.models.BaselineMNISTNetwork(),
        loss=nn.CrossEntropyLoss(),
        y_target=2,
        poisoned_rate=opt.bd_ratio,      # taken from the --bd_ratio option
        secret_size=secret_size,
        enc_height=28,
        enc_width=28,
        enc_in_channel=1,
        enc_total_epoch=20,
        enc_secret_only_epoch=2,
        enc_use_dis=False,
        encoder=None,
        schedule=schedule,
        seed=global_seed,
        deterministic=deterministic,
    )

    attack.train(schedule=schedule)

    # poisoned_train_dataset, poisoned_test_dataset = ISSBA.get_poisoned_dataset()


def train_on_cifar(opt):
    """Configure ``core.ISSBA`` for CIFAR-10 and call its ``train`` method.

    Args:
        opt: Parsed options; ``opt.bd_ratio`` is used as the poisoned rate and
            the object is forwarded to ``get_dataset``.
    """
    trainset = get_dataset(opt, train=True)
    testset = get_dataset(opt, train=False)

    secret_size = 20

    # Pair every image of both splits (train first, then test) with a freshly
    # drawn random bit list of length ``secret_size``.
    steg_images = []
    steg_secrets = []
    for split in (trainset, testset):
        for img, _ in split:
            steg_images.append(img.tolist())
            steg_secrets.append(np.random.binomial(1, .5, secret_size).tolist())

    train_steg_set = GetPoisonedDataset(steg_images, steg_secrets)

    # Training schedule handed to core.ISSBA (constructor and train()).
    schedule = dict(
        device='GPU',
        CUDA_VISIBLE_DEVICES='1',
        GPU_num=1,

        benign_training=False,
        batch_size=1000,
        num_workers=8,

        lr=0.1,
        momentum=0.9,
        weight_decay=5e-4,
        gamma=0.1,
        schedule=[150, 180],

        epochs=50,

        log_iteration_interval=100,
        test_epoch_interval=10,
        save_epoch_interval=100,

        save_dir='experiments',
        experiment_name='train_poison_DataFolder_CIFAR10_ISSBA',
    )

    # Configure the attack scheme
    attack = core.ISSBA(
        dataset_name="cifar10",
        train_dataset=trainset,
        test_dataset=testset,
        train_steg_set=train_steg_set,
        model=core.models.ResNet(18),
        loss=nn.CrossEntropyLoss(),
        y_target=2,
        poisoned_rate=opt.bd_ratio,      # taken from the --bd_ratio option
        secret_size=secret_size,
        enc_height=32,
        enc_width=32,
        enc_in_channel=3,
        enc_total_epoch=20,
        enc_secret_only_epoch=2,
        enc_use_dis=False,
        encoder=None,
        schedule=schedule,
        seed=global_seed,
        deterministic=deterministic,
    )

    attack.train(schedule=schedule)


def train_on_gtsrb(opt):
    """Configure ``core.ISSBA`` for GTSRB and call its ``train`` method.

    Args:
        opt: Parsed options; ``opt.bd_ratio`` is used as the poisoned rate and
            the object is forwarded to ``get_dataset``.
    """
    trainset = get_dataset(opt, train=True)
    testset = get_dataset(opt, train=False)

    secret_size = 20

    # Pair every image of both splits (train first, then test) with a freshly
    # drawn random bit list of length ``secret_size``.
    train_data_set = []
    train_secret_set = []
    for split in (trainset, testset):
        for img, _ in split:
            train_data_set.append(img.tolist())
            secret = np.random.binomial(1, .5, secret_size).tolist()
            train_secret_set.append(secret)

    train_steg_set = GetPoisonedDataset(train_data_set, train_secret_set)

    # Training schedule handed to core.ISSBA (constructor and train()).
    schedule = {
        'device': 'GPU',
        'CUDA_VISIBLE_DEVICES': '1',
        'GPU_num': 1,

        'benign_training': False,
        'batch_size': 1000,
        'num_workers': 8,

        'lr': 0.1,
        'momentum': 0.9,
        'weight_decay': 5e-4,
        'gamma': 0.1,
        'schedule': [150, 180],

        'epochs': 50,

        'log_iteration_interval': 100,
        'test_epoch_interval': 10,
        'save_epoch_interval': 100,

        'save_dir': 'experiments',
        'experiment_name': 'train_poison_DataFolder_GTSRB_ISSBA'
    }

    # Configure the attack scheme
    ISSBA = core.ISSBA(
        dataset_name="gtsrb",
        train_dataset=trainset,
        test_dataset=testset,
        train_steg_set=train_steg_set,
        model=core.models.ResNet(18, 43),
        loss=nn.CrossEntropyLoss(),
        y_target=1,
        poisoned_rate=opt.bd_ratio,      # taken from the --bd_ratio option
        secret_size=secret_size,
        enc_height=32,
        enc_width=32,
        enc_in_channel=3,
        enc_total_epoch=20,
        enc_secret_only_epoch=2,
        enc_use_dis=False,
        encoder=None,
        schedule=schedule,
        seed=global_seed,
        deterministic=deterministic
    )

    # Start training, as train_on_mnist and train_on_cifar do; without this
    # call the attack is only configured and the function has no effect.
    ISSBA.train(schedule=schedule)


def train_on_celeba(opt):
    """Placeholder for CelebA training: does nothing and returns ``None``."""
    return None

# Run the training routine registered for the selected dataset; any other
# value (including "celeba") runs nothing.
_trainers = {
    'mnist': train_on_mnist,
    'cifar10': train_on_cifar,
    'gtsrb': train_on_gtsrb,
}
if opt.dataset in _trainers:
    _trainers[opt.dataset](opt)