import csv
import os

import numpy as np
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
from PIL import Image
from .DataAugmentation import *


class GTSRB(data.Dataset):
    """GTSRB image dataset indexed through its ';'-separated annotation CSVs.

    Image paths and integer labels are collected once at construction time;
    images themselves are opened lazily in ``__getitem__``.

    Args:
        train: If true, read ``<data_dir>/GTSRB/Train`` (one sub-folder and one
            CSV per class); otherwise read ``<data_dir>/GTSRB/Test`` (a single
            ``GT-final_test.csv``).
        transforms: Callable applied to every PIL image before it is returned.
        data_dir: Directory that contains the ``GTSRB`` folder.
    """

    # Number of class sub-folders expected under the Train directory.
    NUM_CLASSES = 43

    def __init__(self, train, transforms, data_dir='/home/harry/dataset'):
        super(GTSRB, self).__init__()
        if train:
            self.data_folder = os.path.join(data_dir, "GTSRB/Train")
            self.images, self.labels = self._get_data_train_list()
        else:
            self.data_folder = os.path.join(data_dir, "GTSRB/Test")
            self.images, self.labels = self._get_data_test_list()

        self.transforms = transforms

    @staticmethod
    def _read_annotations(csv_path, image_prefix):
        """Parse one annotation CSV into ``(image_paths, labels)``.

        The first row is treated as a header and skipped. For every other
        non-empty row, column 0 is appended to ``image_prefix`` to form the
        image path and column 7 is parsed as the integer label.
        """
        images = []
        labels = []
        # The context manager guarantees the handle is released even when a
        # malformed row raises half-way through the file.
        with open(csv_path, newline="") as gt_file:
            reader = csv.reader(gt_file, delimiter=";")
            next(reader, None)  # header; tolerate a completely empty file
            for row in reader:
                if not row:
                    # csv.reader yields [] for blank lines (e.g. at EOF).
                    continue
                images.append(image_prefix + row[0])
                labels.append(int(row[7]))
        return images, labels

    def _get_data_train_list(self):
        """Collect paths and labels from every per-class training CSV."""
        images = []
        labels = []
        for c in range(self.NUM_CLASSES):
            class_name = format(c, "05d")
            prefix = self.data_folder + "/" + class_name + "/"
            class_images, class_labels = self._read_annotations(
                prefix + "GT-" + class_name + ".csv", prefix
            )
            images.extend(class_images)
            labels.extend(class_labels)
        return images, labels

    def _get_data_test_list(self):
        """Collect paths and labels from the single test annotation CSV."""
        csv_path = os.path.join(self.data_folder, "GT-final_test.csv")
        return self._read_annotations(csv_path, self.data_folder + "/")

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.images)

    def __getitem__(self, index):
        """Return ``(transformed_image, label)`` for the sample at ``index``."""
        image = Image.open(self.images[index])
        image = self.transforms(image)
        label = self.labels[index]
        return image, label


class GTSRB_Poisoned_dataset(GTSRB):
    """GTSRB variant in which a random subset of samples is replaced by
    precomputed perturbed images (and their stored labels) loaded from ``.npy``
    files.

    Args:
        data_dir: Directory that contains the ``GTSRB`` folder.
        train: Selects the train or test split, and the matching ``.npy`` files.
        transform: Callable applied to every PIL image before it is returned.
        poison_rate: Fraction of samples served from the perturbed arrays;
            ``int(n * poison_rate)`` indices are drawn without replacement from
            NumPy's global RNG.
        epsilon: Only used for ``attack='PPT'``; ``str(epsilon)`` is part of the
            file name.
        clean_label: Only used for the train split; selects the ``clean_``
            files instead of the default ones.
        attack: One of ``'PPT'``, ``'TooBad'``, ``'WANET'``, ``'MARKSMAN'``.
        ae_dir: Root directory of the perturbation arrays. The default
            reproduces the previously hard-coded location.

    Raises:
        ValueError: If ``attack`` is not supported, or if the number of
            perturbed samples differs from the number of indexed images.
    """

    def __init__(self, data_dir='/home/harry/dataset', train=True, transform=None, poison_rate=1.0, epsilon=12.75, clean_label=False, attack='PPT', ae_dir='/home/harry/nnet/ImplicitBackdoor/AEs/GTSRB'):
        super(GTSRB_Poisoned_dataset, self).__init__(train, transform, data_dir)

        tensor_path, label_path = self._perturbation_paths(ae_dir, train, epsilon, clean_label, attack)
        # assumes the array holds one image per sample in a layout/dtype that
        # torchvision's ToPILImage accepts -- TODO confirm against the producer.
        self.perturb_tensor = torch.from_numpy(np.load(tensor_path))
        self.perturb_label = np.load(label_path).astype(np.uint8)

        # poisoned indices
        n = self.perturb_tensor.shape[0]
        if n != len(self.images):
            # Explicit raise instead of assert so the check survives `python -O`.
            raise ValueError(
                'Perturbation array has %d samples but the dataset has %d images'
                % (n, len(self.images))
            )
        indices = list(range(0, n))
        self.poison_idx = sorted(np.random.choice(indices, int(n * poison_rate), replace=False).tolist())
        # __getitem__ runs once per sample; a hashed lookup avoids scanning the
        # whole index list each time. `poison_idx` stays a sorted list for callers.
        self._poison_lookup = frozenset(self.poison_idx)
        print('Poison samples: %d/%d' % (len(self.poison_idx), n))

    @staticmethod
    def _perturbation_paths(ae_dir, train, epsilon, clean_label, attack):
        """Return ``(image_array_path, label_array_path)`` for the given setup."""
        if attack == 'PPT':
            if train:
                split = 'clean_' if clean_label else ''
            else:
                split = 'test_'
            stem = os.path.join(ae_dir, 'preresnet18', split + 'pgd_' + str(epsilon) + '_10_')
            return stem + 'adv.npy', stem + 'label.npy'
        if attack in ('TooBad', 'WANET', 'MARKSMAN'):
            if train:
                split = 'clean_' if clean_label else 'dirty_'
            else:
                split = 'test_dirty_'
            stem = os.path.join(ae_dir, attack, split)
            return stem + 'adv.npy', stem + 'adv_label.npy'
        raise ValueError('Attack not supported!')

    def __getitem__(self, index):
        """Return ``(transformed_image, label)``; poisoned indices are served
        from the perturbed arrays, all others from the image files on disk."""
        if int(index) in self._poison_lookup:
            image = transforms.ToPILImage()(self.perturb_tensor[index])
            image = self.transforms(image)
            # Plain int, so both branches yield the same label type.
            label = int(self.perturb_label[index])
        else:
            image = Image.open(self.images[index])
            image = self.transforms(image)
            label = self.labels[index]
        return image, label
           
    

def get_transform(train=True, c=0, k=0):
    """Assemble the image preprocessing pipeline.

    The pipeline always starts with a resize to 32x32 and ends with
    ``ToTensor``. In between, a padded random crop and a random rotation are
    inserted when ``train`` is true, followed by ``ColorDepthShrinking(c)``
    when ``c > 0`` and ``Smoothing(k)`` when ``k > 0``.

    Returns:
        A ``transforms.Compose`` wrapping the selected steps in that order.
    """
    pipeline = [transforms.Resize((32, 32))]

    if train:
        pipeline += [
            transforms.RandomCrop((32, 32), padding=5),
            transforms.RandomRotation(10),
        ]

    # Optional steps, enabled only for positive parameters.
    pipeline += [ColorDepthShrinking(c)] if c > 0 else []
    pipeline += [Smoothing(k)] if k > 0 else []

    pipeline.append(transforms.ToTensor())
    return transforms.Compose(pipeline)


def GTSRB_dataloader(data_dir='/home/harry/dataset', batch_size=128, train=True, c=0, k=0):
    """Build a ``DataLoader`` over the clean GTSRB split selected by ``train``.

    ``c`` and ``k`` are forwarded to ``get_transform``. The loader uses eight
    worker processes and shuffles only for the training split.
    """
    gtsrb = GTSRB(train, get_transform(train, c=c, k=k), data_dir=data_dir)
    return torch.utils.data.DataLoader(
        gtsrb,
        batch_size=batch_size,
        num_workers=8,
        shuffle=train,  # shuffle exactly when training
    )


def GTSRB_poisoned_dataloader(data_dir='/home/harry/dataset', batch_size=128, poison_rate=0.1, epsilon=12.75, clean_label=False, c=0, k=0, attack='PPT'):
    """Build the ``(train_loader, test_loader)`` pair over poisoned GTSRB data.

    The training set is poisoned at ``poison_rate`` and honours
    ``clean_label``; the test set is always built with ``poison_rate=1.0``.
    Both loaders pin memory and use eight workers; only the training loader
    shuffles.
    """
    transform_for_train = get_transform(True, c=c, k=k)
    transform_for_test = get_transform(False, c=c, k=k)

    # Arguments common to both splits.
    attack_kwargs = dict(epsilon=epsilon, attack=attack)
    poisoned_train = GTSRB_Poisoned_dataset(
        data_dir, True, transform_for_train,
        poison_rate=poison_rate, clean_label=clean_label, **attack_kwargs
    )
    poisoned_test = GTSRB_Poisoned_dataset(
        data_dir, False, transform_for_test,
        poison_rate=1.0, **attack_kwargs
    )

    # Loader settings common to both splits.
    loader_kwargs = dict(batch_size=batch_size, pin_memory=True, num_workers=8)
    return (
        torch.utils.data.DataLoader(dataset=poisoned_train, shuffle=True, **loader_kwargs),
        torch.utils.data.DataLoader(dataset=poisoned_test, shuffle=False, **loader_kwargs),
    )


if __name__ == '__main__':
    # Manual smoke test: print one sample from the poisoned training loader,
    # then the sample at the same stream position from the clean loader.
    #
    # NOTE: both training loaders are created with shuffling enabled, so
    # `i * 128 + j` is a position in the shuffled stream rather than a dataset
    # index; the two printouts are not guaranteed to show the same image.
    train_loader, _ = GTSRB_poisoned_dataloader()
    poisoned_positions = set(train_loader.dataset.poison_idx)

    # Stays None when no position of the first batch matches, so the second
    # loop cannot hit an undefined name.
    idx = None
    for i, (x, y) in enumerate(train_loader):
        for j in range(len(x)):
            if (i*128+j) in poisoned_positions:
                print(i*128+j)
                print('poisoned')
                print(x.shape)
                print(y.shape)
                print(x[j][0][14])
                print(y[j])
                idx = i * 128 + j
                break
        break

    train_loader = GTSRB_dataloader()
    if idx is not None:
        for i, (x, y) in enumerate(train_loader):
            for j in range(len(x)):
                if (i*128 + j) == idx:
                    print(x.shape)
                    print(y.shape)
                    print(x[j][0][14])
                    print(y[j])
                    break
            break
 
