import torch
from torch.utils.data import Subset, ConcatDataset, TensorDataset, Dataset
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10 as TorchvisionCIFAR10
from torchvision.datasets import MNIST as TorchvisionMNIST
from torchvision.datasets import QMNIST as TorchvisionQMNIST
from torchvision.datasets import SVHN as SVHN_
import os

# Names used to pre-populate the ``splits`` mapping of AdvRobDataset.
# NOTE(review): the dataset classes in this file store their held-out split
# under the key 'validation', not 'val' -- confirm which spelling callers use.
SPLITS = [
    'train',
    'val',
    'test',
]

# Names of the primary dataset classes defined in this module.
DATASETS = [
    'CIFAR10',
    'MNIST',
    'SVHN',
]

class AdvRobDataset:
    """Common base for the dataset wrappers in this module.

    Per-dataset configuration lives in class attributes; every instance
    carries a ``splits`` mapping and the ``device`` it was created with.
    """

    N_WORKERS = 8            # default value; a subclass may replace it
    INPUT_SHAPE = None       # expected to be set by each subclass
    NUM_CLASSES = None       # expected to be set by each subclass
    N_EPOCHS = None          # expected to be set by each subclass
    CHECKPOINT_FREQ = None   # expected to be set by each subclass
    LOG_INTERVAL = None      # expected to be set by each subclass
    HAS_LR_SCHEDULE = False  # default value; a subclass may replace it
    ON_DEVICE = False        # default value; a subclass may replace it

    def __init__(self, device):
        """Store ``device`` and create one ``None`` placeholder per split name."""
        self.device = device
        self.splits = {split_name: None for split_name in SPLITS}

class CIFAR10Base(AdvRobDataset):
    """Loads the torchvision CIFAR-10 train and test sets.

    This class only stores the two raw datasets; subclasses decide how they
    are arranged into ``self.splits``.
    """

    INPUT_SHAPE = (3, 32, 32)
    NUM_CLASSES = 10
    N_EPOCHS = 200
    CHECKPOINT_FREQ = 10
    LOG_INTERVAL = 5
    HAS_LR_SCHEDULE = True

    def __init__(self, root, device):
        """Create the train/test datasets under ``root`` (downloading if needed)."""
        super().__init__(device)

        # Training images: random 32x32 crop after 4px padding, random
        # horizontal flip, then conversion to a tensor.
        augmented = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
        # Test images are only converted to tensors.
        plain = transforms.ToTensor()

        self.original_train_data = TorchvisionCIFAR10(
            root, train=True, transform=augmented, download=True)
        self.original_test_data = TorchvisionCIFAR10(
            root, train=False, transform=plain, download=True)

    @staticmethod
    def adjust_lr(optimizer, epoch, hparams):
        """Write the step-decayed learning rate into every param group.

        The base rate is ``hparams['learning_rate']``; it is multiplied by
        0.1 from epoch 100 onwards and by 0.01 from epoch 150 onwards.
        """
        base_lr = hparams['learning_rate']
        if epoch >= 150:
            lr = base_lr * 0.01
        elif epoch >= 100:
            lr = base_lr * 0.1
        else:
            lr = base_lr

        for group in optimizer.param_groups:
            group['lr'] = lr

class CIFAR10TrainOnTest(CIFAR10Base):
    """CIFAR-10 variant whose training split also contains the test set."""
    HAS_LR_SCHEDULE = True

    def __init__(self, root, device):
        """Build the splits: train = first 49000 train images + all test images."""
        super().__init__(root, device)

        first_49k = Subset(self.original_train_data, range(49000))
        last_1k = Subset(self.original_train_data, range(49000, 50000))

        self.splits = dict(
            train=ConcatDataset([first_49k, self.original_test_data]),
            validation=last_1k,
            test=self.original_test_data,
        )

class CIFAR10(CIFAR10Base):
    """CIFAR-10 with the last 1000 training images held out for validation."""
    HAS_LR_SCHEDULE = True

    def __init__(self, root, device):
        """Split training indices into 0-48999 (train) and 49000-49999 (validation)."""
        super().__init__(root, device)

        source = self.original_train_data
        self.splits = dict(
            train=Subset(source, range(49000)),
            validation=Subset(source, range(49000, 50000)),
            test=self.original_test_data,
        )

class MNISTBase(AdvRobDataset):
    """Loads the torchvision MNIST train and test sets.

    This class only stores the two raw datasets; subclasses decide how they
    are arranged into ``self.splits``.
    """

    INPUT_SHAPE = (1, 28, 28)
    NUM_CLASSES = 10
    N_EPOCHS = 30
    CHECKPOINT_FREQ = 10
    LOG_INTERVAL = 10

    def __init__(self, root, device):
        """Create the MNIST train/test datasets.

        Args:
            root: Directory handed to torchvision as the dataset root.
            device: Forwarded to the base class, which stores it.
        """
        super(MNISTBase, self).__init__(device)
        # download=True mirrors the CIFAR-10 and SVHN loaders in this file:
        # torchvision only fetches the files when they are missing from
        # ``root``, so an existing copy is used unchanged.
        self.original_train_data = TorchvisionMNIST(
            root=root,
            train=True,
            transform=transforms.ToTensor(),
            download=True)
        self.original_test_data = TorchvisionMNIST(
            root=root,
            train=False,
            transform=transforms.ToTensor(),
            download=True)

    @staticmethod
    def adjust_lr(optimizer, epoch, hparams):
        """Write the step-decayed learning rate into every param group.

        The base rate is ``hparams['learning_rate']``; it is multiplied by
        0.1 from epoch 20 onwards and by 0.01 from epoch 25 onwards.
        """
        lr = hparams['learning_rate']
        if epoch >= 20:
            lr = hparams['learning_rate'] * 0.1
        if epoch >= 25:
            lr = hparams['learning_rate'] * 0.01
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

class MNISTTensor(MNISTBase):
    """MNIST held as two tensors (images, labels) placed on ``device``."""

    N_WORKERS = 0       # Needs to be zero so we don't fetch from GPU
    HAS_LR_SCHEDULE = True
    ON_DEVICE = True

    def __init__(self, root, device):
        """Concatenate the train and test sets into on-device tensors.

        Args:
            root: Directory handed to torchvision as the dataset root.
            device: Device the image and label tensors are moved to.
        """
        super(MNISTTensor, self).__init__(root, device)

        # ``.data`` is read directly, so the ToTensor transform configured in
        # the base class is never applied. Scale by 255 here so the images
        # are in [0, 1] like the other MNIST classes in this file (the
        # generated-data variant does the same when it reads ``.data``).
        all_imgs = torch.cat((
            self.original_train_data.data,
            self.original_test_data.data)
        ).reshape(-1, 1, 28, 28).float().div_(255.0).to(self.device)
        all_labels = torch.cat((
            self.original_train_data.targets,
            self.original_test_data.targets)
        ).to(self.device)

        # NOTE(review): all three splits wrap the same tensors (train and
        # test images concatenated), so 'validation' and 'test' are not held
        # out from 'train' -- confirm this is intended.
        self.splits = {
            'train': TensorDataset(all_imgs, all_labels),
            'validation': TensorDataset(all_imgs, all_labels),
            'test': TensorDataset(all_imgs, all_labels)
        }

class MNISTWithGeneratedData(MNISTBase):
    """MNIST whose training split is extended with samples loaded from disk."""

    HAS_LR_SCHEDULE = True

    # Default file holding the generated samples; the path is relative to
    # the current working directory.
    GENERATED_DATA_PATH = './advbench/data/generated_mnist/400K.pth'

    def __init__(self, root, device, generated_data_path=None):
        """Build the splits from real MNIST plus the generated samples.

        Args:
            root: Directory handed to torchvision as the dataset root.
            device: Forwarded to the base class, which stores it.
            generated_data_path: File to read the generated samples from.
                When ``None`` (the default) ``GENERATED_DATA_PATH`` is used,
                which reproduces the previously hard-coded location.
        """
        super(MNISTWithGeneratedData, self).__init__(root, device)

        if generated_data_path is None:
            generated_data_path = self.GENERATED_DATA_PATH

        # Real images are taken from ``.data`` (bypassing the transform), so
        # they are reshaped to (N, 1, 28, 28) and scaled by 255 by hand.
        real_train_data = TensorDataset(
            self.original_train_data.data.reshape(-1, 1, 28, 28).float().div_(255.0),
            self.original_train_data.targets
        )

        # The loaded object is unpacked into the tensors of a TensorDataset
        # (presumably images and labels -- confirm against the file).
        # NOTE: torch.load unpickles the file, so only point this at data
        # from a trusted source.
        gen_train_dataset = TensorDataset(
            *torch.load(generated_data_path)
        )

        train_dataset = ConcatDataset([
            Subset(real_train_data, range(54000)),
            gen_train_dataset
        ])

        self.splits = {
            'train': train_dataset,
            'validation': Subset(self.original_train_data, range(54000, 60000)),
            'test': self.original_test_data
        }

class MNIST(MNISTBase):
    """MNIST with the last 6000 training images held out for validation."""

    HAS_LR_SCHEDULE = True

    def __init__(self, root, device):
        """Split training indices into 0-53999 (train) and 54000-59999 (validation)."""
        super().__init__(root, device)

        source = self.original_train_data
        self.splits = dict(
            train=Subset(source, range(54000)),
            validation=Subset(source, range(54000, 60000)),
            test=self.original_test_data,
        )

class MNISTSubset(MNISTBase):
    """MNIST whose training split is the first ``num_train_data`` training images.

    Validation is always training indices 54000-59999, so a value of
    ``num_train_data`` above 54000 makes the two splits overlap.
    """

    HAS_LR_SCHEDULE = True

    def __init__(self, root, device, num_train_data):
        """Build the splits using ``range(num_train_data)`` as training indices."""
        super().__init__(root, device)

        source = self.original_train_data
        self.splits = dict(
            train=Subset(source, range(num_train_data)),
            validation=Subset(source, range(54000, 60000)),
            test=self.original_test_data,
        )

class MNIST100(MNISTSubset):
    """MNISTSubset fixed to the first 100 training images."""

    def __init__(self, root, device):
        super().__init__(root, device, num_train_data=100)

class MNIST1000(MNISTSubset):
    """MNISTSubset fixed to the first 1000 training images."""

    def __init__(self, root, device):
        super().__init__(root, device, num_train_data=1000)

class MNIST10000(MNISTSubset):
    """MNISTSubset fixed to the first 10000 training images."""

    def __init__(self, root, device):
        super().__init__(root, device, num_train_data=10000)

class MNIST25000(MNISTSubset):
    """MNISTSubset fixed to the first 25000 training images."""

    def __init__(self, root, device):
        super().__init__(root, device, num_train_data=25000)

class _QuantizedDataset(Dataset):
    """Wraps a dataset of ``(image, label)`` pairs and binarizes each image.

    Defined at module level rather than inside ``QuantizedMNIST.__init__``:
    a class created inside a function cannot be pickled, and DataLoader
    worker processes started with the ``spawn`` method need to pickle the
    dataset they are given.
    """

    def __init__(self, data):
        self.data = data

    def __getitem__(self, index):
        x, y = self.data[index]
        # Pixels strictly above 0.5 become 1.0, everything else 0.0.
        quantized_x = torch.where(x > 0.5, 1.0, 0.0)
        return (quantized_x, y)

    def __len__(self):
        return len(self.data)


class QuantizedMNIST(MNISTBase):
    """MNIST whose images are thresholded at 0.5 into {0.0, 1.0} on access."""

    HAS_LR_SCHEDULE = True

    def __init__(self, root, device):
        """Build binarized train/validation/test splits.

        Args:
            root: Directory handed to torchvision as the dataset root.
            device: Forwarded to the base class, which stores it.
        """
        super(QuantizedMNIST, self).__init__(root, device)

        train_data = Subset(self.original_train_data, range(54000))
        validation_data = Subset(self.original_train_data, range(54000, 60000))
        test_data = self.original_test_data

        self.splits = {
            'train': _QuantizedDataset(train_data),
            'validation': _QuantizedDataset(validation_data),
            'test': _QuantizedDataset(test_data)
        }

class MNISTandQMNIST(MNISTBase):
    """MNIST whose training split is extended with QMNIST 'train' images."""

    HAS_LR_SCHEDULE = True

    def __init__(self, root, device):
        """Build the splits from MNIST plus part of the QMNIST 'train' set.

        Args:
            root: Directory handed to torchvision as the dataset root.
            device: Forwarded to the base class, which stores it.
        """
        super(MNISTandQMNIST, self).__init__(root, device)

        # download=True mirrors the CIFAR-10 and SVHN loaders in this file:
        # torchvision only fetches the files when they are missing from
        # ``root``.
        qmnist_train_data = TorchvisionQMNIST(
            root=root,
            what='train',
            transform=transforms.ToTensor(),
            download=True)

        # NOTE(review): QMNIST's 'train' split is described upstream as a
        # reconstruction of the MNIST training set. If it keeps the same
        # ordering, indices 54000-59999 used here would correspond to the
        # MNIST images held out for validation below -- confirm.
        train_dataset = ConcatDataset([
            Subset(self.original_train_data, range(54000)),
            Subset(qmnist_train_data, range(10000, 60000))
        ])

        self.splits = {
            'train': train_dataset,
            'validation': Subset(self.original_train_data, range(54000, 60000)),
            'test': self.original_test_data
        }

class MNISTTrainOnTest(MNISTBase):
    """MNIST variant whose training split also contains the test set."""

    HAS_LR_SCHEDULE = True

    def __init__(self, root, device):
        """Build the splits: train = first 54000 train images + all test images."""
        super().__init__(root, device)

        first_54k = Subset(self.original_train_data, range(54000))
        last_6k = Subset(self.original_train_data, range(54000, 60000))

        self.splits = dict(
            train=ConcatDataset([first_54k, self.original_test_data]),
            validation=last_6k,
            test=self.original_test_data,
        )


class SVHN(AdvRobDataset):
    """Torchvision SVHN with a 'train' and a 'test' split."""

    INPUT_SHAPE = (3, 32, 32)
    NUM_CLASSES = 10
    N_EPOCHS = 115
    CHECKPOINT_FREQ = 10
    LOG_INTERVAL = 100
    HAS_LR_SCHEDULE = False

    def __init__(self, root, device):
        """Create the train/test datasets under ``root`` (downloading if needed)."""
        super().__init__(device)

        # Training images: random 32x32 crop after 4px padding, random
        # horizontal flip, then conversion to a tensor.
        augmented = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
        # Test images are only converted to tensors.
        plain = transforms.ToTensor()

        # Only these two entries of the mapping created by the base class
        # are filled in; every other entry keeps its None placeholder.
        # NOTE(review): no 'validation' entry is created here, unlike the
        # CIFAR-10 and MNIST classes in this file -- confirm callers cope.
        self.splits['train'] = SVHN_(
            root, split='train', transform=augmented, download=True)
        self.splits['test'] = SVHN_(
            root, split='test', transform=plain, download=True)

    @staticmethod
    def adjust_lr(optimizer, epoch, hparams):
        """Write the step-decayed learning rate into every param group.

        The base rate is ``hparams['learning_rate']``; it is multiplied by
        0.1 from epoch 55, by 0.01 from epoch 75 and by 0.001 from epoch 90.
        """
        base_lr = hparams['learning_rate']
        if epoch >= 90:
            lr = base_lr * 0.001
        elif epoch >= 75:
            lr = base_lr * 0.01
        elif epoch >= 55:
            lr = base_lr * 0.1
        else:
            lr = base_lr

        for group in optimizer.param_groups:
            group['lr'] = lr