import sys

from torchvision import datasets, transforms
from base import BaseDataLoader
from data_loader.cifar10 import get_cifar10
from data_loader.cifar100 import get_cifar100
from data_loader.mini_imagenet import get_miniimagenet
from parse_config import ConfigParser
from PIL import Image

class TwoCropTransform:
    """Wrap a transform so every input produces two transformed views.

    The wrapped callable is invoked twice on the same input, in order, and the
    two results are returned together as a list.

    Adapted from
    https://github.com/HobbitLong/SupContrast/blob/a8a275b3a8b9b9bdc9c527f199d5b9be58148543/util.py
    """

    def __init__(self, transform):
        # Callable applied to the input once per view.
        self.transform = transform

    def __call__(self, x):
        """Return ``[first, second]`` from two separate ``self.transform(x)`` calls."""
        first_view = self.transform(x)
        second_view = self.transform(x)
        return [first_view, second_view]


class CIFAR10DataLoader(BaseDataLoader):
    """CIFAR-10 loader whose augmentation pipeline depends on the run mode.

    The pipeline is picked in this order:

    1. ``config['trainer']['do_adv']`` is truthy: crop + flip + ``ToTensor``,
       without normalization.
    2. ``do_sup_con`` is truthy: crop + flip + colour jitter + random
       grayscale + ``ToTensor`` + normalization.
    3. otherwise (CE): crop + flip + ``ToTensor`` + normalization.

    Args:
        data_dir: Stored on ``self.data_dir``. The datasets themselves are
            loaded from ``config['data_loader']['args']['data_dir']``.
        batch_size, shuffle, validation_split, num_workers, pin_memory:
            Forwarded to ``BaseDataLoader.__init__``.
        num_batches: Accepted but not used by this class.
        training: Forwarded to ``get_cifar10`` as ``train``.
        do_sup_con: Selects the supervised-contrastive augmentation (unless
            ``do_adv`` is set in the trainer config).
        ce: When ``do_sup_con`` is truthy and ``ce`` is falsy, the training
            transform is wrapped in ``TwoCropTransform``.
    """

    def __init__(
        self,
        data_dir,
        batch_size,
        shuffle=True,
        validation_split=0.0,
        num_batches=0,
        training=True,
        num_workers=4,
        pin_memory=True,
        do_sup_con=False,
        ce=False,
    ):
        config = ConfigParser.get_instance()
        cfg_trainer = config['trainer']
        adversarial = cfg_trainer["do_adv"]

        channel_mean = (0.4914, 0.4822, 0.4465)
        channel_std = (0.2023, 0.1994, 0.2010)

        # Every training pipeline starts with the same geometric augmentations.
        train_steps = [
            transforms.RandomResizedCrop(size=32, scale=(0.2, 1.)),
            transforms.RandomHorizontalFlip(),
        ]
        eval_steps = [transforms.ToTensor()]

        if adversarial:
            # The adversarial pipelines stop at ToTensor (no normalization).
            print("Doint adv. attack")
            train_steps.append(transforms.ToTensor())
        else:
            if do_sup_con:
                print("Using Cifar10 dataset with supervised contrastive loss based augmentation")
                colour_jitter = transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
                train_steps.append(transforms.RandomApply([colour_jitter], p=0.8))
                train_steps.append(transforms.RandomGrayscale(p=0.2))
            else:
                print("Using Cifar10 dataset with CE loss based augmentation")
            train_steps.append(transforms.ToTensor())
            train_steps.append(transforms.Normalize(channel_mean, channel_std))
            eval_steps.append(transforms.Normalize(channel_mean, channel_std))

        transform_train = transforms.Compose(train_steps)
        transform_val = transforms.Compose(eval_steps)

        self.data_dir = data_dir

        # Only the "sup-con without CE" mode needs two views per training sample.
        if do_sup_con and not ce:
            transform_train = TwoCropTransform(transform_train)

        self.train_dataset, self.val_dataset = get_cifar10(
            config['data_loader']['args']['data_dir'],
            cfg_trainer,
            train=training,
            transform_train=transform_train,
            transform_val=transform_val,
        )

        super().__init__(
            self.train_dataset,
            batch_size,
            shuffle,
            validation_split,
            num_workers,
            pin_memory,
            val_dataset=self.val_dataset,
        )


class CIFAR100DataLoader(BaseDataLoader):
    """CIFAR-100 loader whose augmentation pipeline depends on the run mode.

    The pipeline is picked in this order:

    1. ``config['trainer']['do_adv']`` is truthy: ``RandomCrop`` + flip +
       ``ToTensor``, without normalization.
    2. ``do_sup_con`` is truthy: pipeline chosen by
       ``config['trainer']['augment_level']`` (defaults to 3 when absent):

       * 1 = Normalization
       * 2 = Normalization + RandomCrop
       * 3 = Normalization + RandomCrop + RandomHorizontalFlip
       * 4 = Normalization + RandomCrop + RandomHorizontalFlip + ColorJitter
       * 5 = Normalization + RandomCrop + RandomHorizontalFlip + ColorJitter
         + RandomGrayscale

    3. otherwise (CE): ``RandomCrop`` + flip + ``ToTensor`` + normalization.

    Args:
        data_dir: Stored on ``self.data_dir``. The datasets themselves are
            loaded from ``config['data_loader']['args']['data_dir']``.
        batch_size, shuffle, validation_split, num_workers, pin_memory:
            Forwarded to ``BaseDataLoader.__init__``.
        num_batches: Accepted but not used by this class.
        training: Forwarded to ``get_cifar100`` as ``train``.
        do_sup_con: Selects the augment-level based pipeline (unless
            ``do_adv`` is set in the trainer config).
        ce: When ``do_sup_con`` is truthy and ``ce`` is falsy, the training
            transform is wrapped in ``TwoCropTransform``.

    Raises:
        ValueError: If ``do_sup_con`` is used with an augment level other
            than 1-5.
    """

    def __init__(
        self,
        data_dir,
        batch_size,
        shuffle=True,
        validation_split=0.0,
        num_batches=0,
        training=True,
        num_workers=4,
        pin_memory=True,
        do_sup_con=False,
        ce=False,
    ):
        config = ConfigParser.get_instance()
        cfg_trainer = config['trainer']

        augment_level = cfg_trainer["augment_level"] if "augment_level" in cfg_trainer else 3
        adversarial = cfg_trainer["do_adv"]

        channel_mean = (0.5071, 0.4867, 0.4408)
        channel_std = (0.2675, 0.2565, 0.2761)

        if adversarial:
            # The adversarial pipelines stop at ToTensor (no normalization).
            print("Doint adv. attack")
            train_steps = [
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
            ]
            eval_steps = [transforms.ToTensor()]
        elif do_sup_con:
            print("Using Cifar100 dataset with supervised contrastive loss based augmentation")
            print(f"Augment Level is {augment_level}")

            supported_levels = (1, 2, 3, 4, 5)
            if augment_level not in supported_levels:
                raise ValueError("Augment Level not implemented")

            # Level k enables the first k - 1 optional augmentations, in this order.
            optional_steps = [
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
                transforms.RandomGrayscale(p=0.2),
            ]
            enabled_count = supported_levels.index(augment_level)
            train_steps = optional_steps[:enabled_count] + [
                transforms.ToTensor(),
                transforms.Normalize(channel_mean, channel_std),
            ]
            eval_steps = [
                transforms.ToTensor(),
                transforms.Normalize(channel_mean, channel_std),
            ]
        else:
            print("Using Cifar100 dataset CE loss based augmentation")
            train_steps = [
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(channel_mean, channel_std),
            ]
            eval_steps = [
                transforms.ToTensor(),
                transforms.Normalize(channel_mean, channel_std),
            ]

        transform_train = transforms.Compose(train_steps)
        transform_val = transforms.Compose(eval_steps)

        self.data_dir = data_dir

        # Only the "sup-con without CE" mode needs two views per training sample.
        if do_sup_con and not ce:
            transform_train = TwoCropTransform(transform_train)

        self.train_dataset, self.val_dataset = get_cifar100(
            config['data_loader']['args']['data_dir'],
            cfg_trainer,
            train=training,
            transform_train=transform_train,
            transform_val=transform_val,
        )

        super().__init__(
            self.train_dataset,
            batch_size,
            shuffle,
            validation_split,
            num_workers,
            pin_memory,
            val_dataset=self.val_dataset,
        )
        
        
        
####### Added ####### 
##################### 
class MiniImageNetDataLoader(BaseDataLoader):
    """MiniImageNet loader whose augmentation pipeline depends on the run mode.

    With ``do_sup_con`` truthy, the training pipeline is chosen by
    ``config['trainer']['augment_level']`` (defaults to 3 when absent):

    * 1 = Normalization
    * 2 = Normalization + RandomCrop
    * 3 = Normalization + RandomCrop + RandomHorizontalFlip
    * 4 = Normalization + RandomCrop + RandomHorizontalFlip + ColorJitter
    * 5 = Normalization + RandomCrop + RandomHorizontalFlip + ColorJitter
      + RandomGrayscale

    Otherwise (CE) the pipeline is ``RandomCrop`` + flip + ``ToTensor`` +
    normalization. Crops are 84 pixels with a padding of 4.

    Args:
        data_dir: Stored on ``self.data_dir``. The datasets themselves are
            loaded from ``config['data_loader']['args']['data_dir']``.
        batch_size, shuffle, validation_split, num_workers, pin_memory:
            Forwarded to ``BaseDataLoader.__init__``.
        num_batches: Accepted but not used by this class.
        training: Forwarded to ``get_miniimagenet`` as ``train``.
        do_sup_con: Selects the augment-level based pipeline.
        ce: When ``do_sup_con`` is truthy and ``ce`` is falsy, the training
            transform is wrapped in ``TwoCropTransform``.

    Raises:
        ValueError: If ``do_sup_con`` is used with an augment level other
            than 1-5.
    """

    def __init__(
        self,
        data_dir,
        batch_size,
        shuffle=True,
        validation_split=0.0,
        num_batches=0,
        training=True,
        num_workers=0,
        pin_memory=True,
        do_sup_con=False,
        ce=False,
    ):
        config = ConfigParser.get_instance()
        cfg_trainer = config['trainer']

        augment_level = cfg_trainer["augment_level"] if "augment_level" in cfg_trainer else 3

        channel_mean = (0.485, 0.456, 0.406)
        channel_std = (0.229, 0.224, 0.225)

        if do_sup_con:
            print("Using MiniImagenet dataset with supervised contrastive loss based augmentation")
            print(f"Augment Level is {augment_level}")

            supported_levels = (1, 2, 3, 4, 5)
            if augment_level not in supported_levels:
                raise ValueError("Augment Level not implemented")

            # Level k enables the first k - 1 optional augmentations, in this order.
            optional_steps = [
                transforms.RandomCrop(84, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
                transforms.RandomGrayscale(p=0.2),
            ]
            enabled_count = supported_levels.index(augment_level)
            train_steps = optional_steps[:enabled_count] + [
                transforms.ToTensor(),
                transforms.Normalize(channel_mean, channel_std),
            ]
        else:
            print("Using MiniImagenet dataset CE loss based augmentation")
            train_steps = [
                transforms.RandomCrop(84, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(channel_mean, channel_std),
            ]

        # Both modes validate on normalized tensors without augmentation.
        eval_steps = [
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ]

        transform_train = transforms.Compose(train_steps)
        transform_val = transforms.Compose(eval_steps)

        self.data_dir = data_dir

        # Only the "sup-con without CE" mode needs two views per training sample.
        if do_sup_con and not ce:
            transform_train = TwoCropTransform(transform_train)

        self.train_dataset, self.val_dataset = get_miniimagenet(
            config['data_loader']['args']['data_dir'],
            cfg_trainer,
            train=training,
            transform_train=transform_train,
            transform_val=transform_val,
        )

        super().__init__(
            self.train_dataset,
            batch_size,
            shuffle,
            validation_split,
            num_workers,
            pin_memory,
            val_dataset=self.val_dataset,
        )
        
        
# class MiniImageNetDataLoader(BaseDataLoader):
#     def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_batches=0, training=True, num_workers=0, pin_memory=True):
#         config = ConfigParser.get_instance()
#         cfg_trainer = config['trainer']
        
#         transform_train = transforms.Compose([
#             transforms.RandomCrop(84, padding=4),
#             transforms.RandomHorizontalFlip(),
#             transforms.ToTensor(),
#             transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
#         ])
#         transform_val = transforms.Compose([
#             transforms.ToTensor(),
#             transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
#         ])
        
#         self.data_dir = data_dir

#         self.train_dataset, self.val_dataset = get_miniimagenet(config['data_loader']['args']['data_dir'], cfg_trainer, train=training,
#                                                            transform_train=transform_train, transform_val=transform_val)

#         super().__init__(self.train_dataset, batch_size, shuffle, validation_split, num_workers, pin_memory,
#                          val_dataset = self.val_dataset)