from .mnist_data_module import MNISTDataModule
from .cifar10_data_module import CIFAR10dataModule
from torchvision import transforms
from torchvision.transforms import v2
from .discrete_rotation import DiscreteRotation

# Per-dataset (mean, std) tuples passed to transforms.Normalize.
# The CIFAR-10 entry (mean 0, std 1) makes Normalize an identity mapping;
# it is kept so every dataset goes through the same pipeline shape.
_NORMALIZATION_STATS = {
    'mnist': ((0.1307,), (0.3081,)),
    'cifar10': ((0.0, 0.0, 0.0), (1.0, 1.0, 1.0)),
}


def _build_transforms(params, mean, std):
    """Return the ``(train, test, un_augmented)`` transform pipelines.

    Every pipeline is ``Pad(params.padding) -> [augmentations] -> ToTensor ->
    Normalize(mean, std)``. Only the test pipeline carries augmentations: a
    ``RandomRotation(params.test_rotation)`` with bicubic interpolation, and,
    when ``params.test_flip`` is truthy, a ``RandomHorizontalFlip``.
    The train and un-augmented pipelines are identical in content but are
    distinct ``Compose`` objects, as callers receive them separately.
    """

    def pipeline(*augmentations):
        # Augmentations run on the padded image, before conversion to tensor.
        return transforms.Compose([
            transforms.Pad(params.padding),
            *augmentations,
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

    test_augmentations = [
        transforms.RandomRotation(params.test_rotation,
                                  interpolation=v2.InterpolationMode.BICUBIC),
    ]
    if params.test_flip:
        print("Using test flip")
        test_augmentations.append(transforms.RandomHorizontalFlip())

    train_transform = pipeline()
    test_transform = pipeline(*test_augmentations)
    un_augmented_transform = pipeline()
    return train_transform, test_transform, un_augmented_transform


def get_data_module(params):
    """Build the data module for the dataset named by ``params.data_set``.

    Args:
        params: Configuration object. Reads ``data_set`` ('mnist' or
            'cifar10'), ``padding``, ``test_rotation``, ``test_flip`` (used to
            build the transforms) and ``data_dir``, ``batch_size``, ``ntrain``,
            ``train_discard_classes``, ``test_discard_classes``,
            ``num_workers``, ``val_ratio``, ``test_batch`` (forwarded to the
            data-module constructor).

    Returns:
        An ``MNISTDataModule`` or ``CIFAR10dataModule`` instance configured
        with train, test and un-augmented transforms.

    Raises:
        ValueError: if ``params.data_set`` is not a supported dataset name.
    """
    # Resolve the dataset first so an unknown name fails before any
    # transform is constructed.
    if params.data_set == 'mnist':
        module_cls = MNISTDataModule
    elif params.data_set == 'cifar10':
        module_cls = CIFAR10dataModule
    else:
        raise ValueError(f'Dataset {params.data_set} not found')

    mean, std = _NORMALIZATION_STATS[params.data_set]
    train_transform, test_transform, un_augmented_transform = _build_transforms(
        params, mean, std)

    return module_cls(data_dir=params.data_dir,
                      batch_size=params.batch_size,
                      ntrain=params.ntrain,
                      train_discard_classes=params.train_discard_classes,
                      train_transform=train_transform,
                      test_discard_classes=params.test_discard_classes,
                      test_transform=test_transform,
                      num_workers=params.num_workers,
                      val_ratio=params.val_ratio,
                      un_augmented_transform=un_augmented_transform,
                      test_batch=params.test_batch)
