import os
import random
import tarfile
import shutil
from scipy.io import loadmat
from PIL import ImageFilter
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Dataset, Subset


class GaussianBlur(object):
    """
    Gaussian blur augmentation from: https://arxiv.org/abs/2002.05709

    Each call blurs the input with a radius drawn uniformly from
    [sigma[0], sigma[1]].

    Args:
        sigma: two-element sequence (low, high) bounding the blur radius.
            Defaults to [0.1, 2.0] when omitted or None.
    """

    def __init__(self, sigma=None):
        # A literal list default would be a single object shared by every
        # instance, so mutating one instance's `sigma` would silently change
        # all others. Build a fresh list per instance instead.
        self.sigma = [.1, 2.] if sigma is None else sigma

    def __call__(self, image):
        # `image` must expose PIL's `filter` method; the radius is re-sampled on every call.
        radius = random.uniform(self.sigma[0], self.sigma[1])
        return image.filter(ImageFilter.GaussianBlur(radius=radius))


class TwoCropsTransform:
    """
    Apply the wrapped transform twice to the same input.

    Calling an instance returns a two-element list holding the results of the
    two applications, in call order.
    """

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, x):
        first_view = self.transform(x)
        second_view = self.transform(x)
        return [first_view, second_view]
    

class DropLabelDataset(Dataset):
    """
    View over a dataset of two-element items that yields only the first element.

    Each item of the wrapped dataset is unpacked into exactly two values and
    the second one (the label, for the datasets used in this file) is discarded.
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        sample, _label = self.dataset[index]
        return sample
    

def prepare_ImageNet(logger, root='./data', num_classes=10, labels=None):
    """
    Extract the ILSVRC2012 archives found in `root` into per-class train/val folders.

    The three archives (train tar, val tar, devkit tar.gz) must already be present in
    `root`. Output goes to `<root>/imagenet<num_classes>/{train,val,devkit}`
    (`imagenet1k` when `num_classes` is 1000). Extraction is skipped when the class
    folders already present match the request.

    Args:
        logger: object with an `info(msg)` method used for progress messages.
        root: directory containing the archives; created if missing.
        num_classes: number of classes to keep. 1000 extracts every class. Otherwise,
            when `labels` is None, the first `num_classes` class archives in
            alphabetical order are used.
        labels: optional iterable of WordNet IDs selecting exactly which classes to
            extract (used when `num_classes` is not 1000).

    Returns:
        Tuple `(train_directory, validation_directory)`.

    Raises:
        FileNotFoundError: if one of the three archives is missing from `root`.
        ValueError: if the number of extracted validation images differs from the
            number of ground-truth labels, since labels are assigned by sorted position.
    """
    tar_suffix = '.tar'

    os.makedirs(root, exist_ok=True)

    archives = {
        'train': 'ILSVRC2012_img_train.tar',
        'val':   'ILSVRC2012_img_val.tar',
        'devkit':'ILSVRC2012_devkit_t12.tar.gz'
    }

    for filename in archives.values():
        if not os.path.isfile(os.path.join(root, filename)):
            raise FileNotFoundError(f"Missing file: {filename}. Please download it from the official ImageNet website and place it in {root}.")

    subfolder_name = f'imagenet{num_classes}' if num_classes != 1000 else 'imagenet1k'

    train_directory = os.path.join(root, subfolder_name, 'train')
    validation_directory = os.path.join(root, subfolder_name, 'val')
    devkit_directory = os.path.join(root, subfolder_name, 'devkit')

    os.makedirs(train_directory, exist_ok=True)
    os.makedirs(validation_directory, exist_ok=True)
    os.makedirs(devkit_directory, exist_ok=True)

    train_classes = [folder for folder in os.listdir(train_directory) if os.path.isdir(os.path.join(train_directory, folder))]
    validation_classes = [folder for folder in os.listdir(validation_directory) if os.path.isdir(os.path.join(validation_directory, folder))]

    # Skip the (slow) extraction when a previous run already produced the requested classes.
    if labels is not None:
        if set(train_classes) == set(labels) and set(validation_classes) == set(labels):
            logger.info(f"Detected {num_classes} class folders in both {train_directory} and {validation_directory}. Skipping extraction..")
            return train_directory, validation_directory

    else:
        if len(train_classes) == num_classes and len(validation_classes) == num_classes:
            logger.info(f"Detected {num_classes} class folders in both {train_directory} and {validation_directory}. Skipping extraction..")
            return train_directory, validation_directory

    # NOTE(review): the archives are extracted without a tarfile extraction filter;
    # this assumes they are the trusted official ImageNet files.

    # Extract the devkit archive
    logger.info("Extracting devkit archive...")
    devkit_tar_path = os.path.join(root, archives['devkit'])

    with tarfile.open(devkit_tar_path) as tar:
        tar.extractall(path=devkit_directory)

    logger.info("Devkit archive extracted.")

    # Extract the train dataset archive (a tar of per-class tars)
    logger.info("Extracting training tarball...")
    train_tar_path = os.path.join(root, archives['train'])

    with tarfile.open(train_tar_path) as tar:
        if num_classes == 1000:
            tar.extractall(path=train_directory)

        else:
            class_members = [member for member in tar.getmembers() if member.name.endswith(tar_suffix)]

            if labels is not None:
                target_labels = set(labels)
            else:
                # If labels are not provided, sort the class tar files alphabetically and select the first `num_classes` classes.
                class_members.sort(key=lambda m: m.name)
                target_labels = {member.name[:-len(tar_suffix)] for member in class_members[:num_classes]}

            for member in class_members:
                # Strip only the trailing suffix; every member here ends with it.
                if member.name[:-len(tar_suffix)] in target_labels:
                    tar.extract(member, path=train_directory)

    logger.info("Training tarball extracted.")

    logger.info("Moving training images to their respective folders...")
    for filename in os.listdir(train_directory):
        if not filename.endswith(tar_suffix):
            continue

        file_path = os.path.join(train_directory, filename)
        class_folder = os.path.join(train_directory, filename[:-len(tar_suffix)])
        os.makedirs(class_folder, exist_ok=True)

        with tarfile.open(file_path) as tar:
            tar.extractall(path=class_folder)

        # The per-class tar is no longer needed once its images are unpacked.
        os.remove(file_path)

    logger.info("Training images moved to their respective folders.")

    # Extract the validation dataset archive
    logger.info("Extracting validation tarball...")
    validation_tar_path = os.path.join(root, archives['val'])

    with tarfile.open(validation_tar_path) as tar:
        tar.extractall(path=validation_directory)

    devkit_data_directory = os.path.join(devkit_directory, 'ILSVRC2012_devkit_t12', 'data')
    validation_ground_truth = os.path.join(devkit_data_directory, 'ILSVRC2012_validation_ground_truth.txt')

    # Ground-truth IDs are 1-based; convert to 0-based indices. Blank lines are ignored.
    with open(validation_ground_truth, 'r') as f:
        validation_label_idx = [int(line.strip()) - 1 for line in f if line.strip()]

    # Labels are matched to images by sorted position, so only regular files may take
    # part: class folders left over from an earlier run must not enter the pairing.
    validation_images = sorted(
        name for name in os.listdir(validation_directory)
        if os.path.isfile(os.path.join(validation_directory, name))
    )

    if len(validation_images) != len(validation_label_idx):
        # A count mismatch means the positional pairing would mislabel images.
        raise ValueError(
            f"Found {len(validation_images)} validation images in {validation_directory} "
            f"but {len(validation_label_idx)} ground-truth labels; cannot assign labels by position."
        )

    # Load WordNet IDs (WNIDs) from meta.mat
    meta_path = os.path.join(devkit_data_directory, 'meta.mat')
    meta = loadmat(meta_path, squeeze_me=True)['synsets']
    idx_to_synset = [str(entry['WNID']) for entry in meta[:1000]]

    # Set for O(1) membership tests across all validation images.
    extracted_train_classes = {
        f for f in os.listdir(train_directory)
        if os.path.isdir(os.path.join(train_directory, f))
    }

    for filename, label_idx in zip(validation_images, validation_label_idx):
        synset = idx_to_synset[label_idx]
        original_file_path = os.path.join(validation_directory, filename)

        # Drop validation images whose class was not extracted for training.
        if synset not in extracted_train_classes:
            os.remove(original_file_path)
            continue

        synset_folder = os.path.join(validation_directory, synset)
        os.makedirs(synset_folder, exist_ok=True)

        shutil.move(original_file_path, os.path.join(synset_folder, filename))

    logger.info("Validation tarball extracted.")

    return train_directory, validation_directory


def prepare_tinyimagenet_val_folder(val_dir):
    """
    Sort the Tiny ImageNet validation images into one sub-folder per class so that
    `val_dir` can be read by torchvision.datasets.ImageFolder.

    Reads `<val_dir>/val_annotations.txt` (tab-separated; first field is the image
    file name, second the class id), moves each listed image from `<val_dir>/images`
    into `<val_dir>/<class id>/`, then deletes the `images` folder. Does nothing
    when `<val_dir>/images` no longer exists.
    """

    flat_images_dir = os.path.join(val_dir, 'images')
    annotations_path = os.path.join(val_dir, 'val_annotations.txt')

    # A missing `images` folder means an earlier call already did the work.
    if not os.path.exists(flat_images_dir):
        return

    print("Reorganizing Tiny ImageNet validation images...")

    with open(annotations_path, 'r') as annotations:
        for record in annotations:
            fields = record.strip().split('\t')
            image_name = fields[0]
            wnid = fields[1]

            # The class folder is created even if the image itself is absent.
            target_dir = os.path.join(val_dir, wnid)
            os.makedirs(target_dir, exist_ok=True)

            source_path = os.path.join(flat_images_dir, image_name)
            if not os.path.exists(source_path):
                continue

            shutil.move(source_path, os.path.join(target_dir, image_name))

    # Removes the flat folder together with anything the annotations did not list.
    shutil.rmtree(flat_images_dir)
    print("Validation folder reorganized.")



# Class labels obtained from https://github.com/HobbitLong/CMC
_IMAGENET100_LABELS = (
    'n02869837', 'n01749939', 'n02488291', 'n02107142', 'n13037406', 'n02091831', 'n04517823', 'n04589890', 'n03062245', 'n01773797',
    'n01735189', 'n07831146', 'n07753275', 'n03085013', 'n04485082', 'n02105505', 'n01983481', 'n02788148', 'n03530642', 'n04435653',
    'n02086910', 'n02859443', 'n13040303', 'n03594734', 'n02085620', 'n02099849', 'n01558993', 'n04493381', 'n02109047', 'n04111531',
    'n02877765', 'n04429376', 'n02009229', 'n01978455', 'n02106550', 'n01820546', 'n01692333', 'n07714571', 'n02974003', 'n02114855',
    'n03785016', 'n03764736', 'n03775546', 'n02087046', 'n07836838', 'n04099969', 'n04592741', 'n03891251', 'n02701002', 'n03379051',
    'n02259212', 'n07715103', 'n03947888', 'n04026417', 'n02326432', 'n03637318', 'n01980166', 'n02113799', 'n02086240', 'n03903868',
    'n02483362', 'n04127249', 'n02089973', 'n03017168', 'n02093428', 'n02804414', 'n02396427', 'n04418357', 'n02172182', 'n01729322',
    'n02113978', 'n03787032', 'n02089867', 'n02119022', 'n03777754', 'n04238763', 'n02231487', 'n03032252', 'n02138441', 'n02104029',
    'n03837869', 'n03494278', 'n04136333', 'n03794056', 'n03492542', 'n02018207', 'n04067472', 'n03930630', 'n03584829', 'n02123045',
    'n04229816', 'n02100583', 'n03642806', 'n04336792', 'n03259280', 'n02116738', 'n02108089', 'n03424325', 'n01855672', 'n02090622',
)


def _build_base_transform(crop_size, normalize_transform, augmentation, plain_resize=None):
    """
    Build the per-image transform for one dataset.

    With `augmentation`, returns the random crop / flip / colour-jitter / grayscale /
    blur pipeline at `crop_size`. Otherwise returns ToTensor + normalisation, preceded
    by Resize(plain_resize) + CenterCrop(crop_size) when `plain_resize` is given.
    """
    if augmentation:
        return transforms.Compose([
            transforms.RandomResizedCrop(crop_size),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomApply([GaussianBlur(sigma=[.1, 2.])], p=0.5),
            transforms.ToTensor(),
            normalize_transform
        ])

    steps = []
    if plain_resize is not None:
        steps.extend([transforms.Resize(plain_resize), transforms.CenterCrop(crop_size)])
    steps.extend([transforms.ToTensor(), normalize_transform])
    return transforms.Compose(steps)


def _build_loaders(args, batch_size, supervised, base_transform, make_dataset):
    """
    Create the (train_loader, test_loader) pair from a dataset factory.

    `make_dataset(train, transform)` must return the train split when `train` is True
    and the test split otherwise.
    """
    if supervised:
        train_dataset = make_dataset(True, base_transform)
    else:
        # Unsupervised: each item becomes two transformed views and the label is dropped.
        train_dataset = DropLabelDataset(make_dataset(True, TwoCropsTransform(base_transform)))

    if args.fast_debug:
        # Debug mode: train on the first `batch_size` samples only, in fixed order.
        train_dataset = Subset(train_dataset, list(range(batch_size)))
        train_shuffle = False
    else:
        train_shuffle = True

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=train_shuffle, drop_last=True, num_workers=args.num_workers, pin_memory=True)

    # NOTE(review): the test split reuses `base_transform`, so with augmentation=True the
    # test images are randomly augmented too. Kept as-is; confirm this is intended.
    test_dataset = make_dataset(False, base_transform)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True)

    return train_loader, test_loader


def get_dataloaders(args, logger, batch_size, augmentation=False, supervised=False):
    """
    Build the train and test DataLoaders for the dataset named by `args.dataset`.

    Supported names: "CIFAR10", "CIFAR100", "TinyImageNet", "ImageNet100", "ImageNet1000".

    Args:
        args: namespace providing `dataset`, `fast_debug`, `num_workers` and, for
            ImageNet100 / ImageNet1000, `num_classes`.
        logger: object with an `info(msg)` method.
        batch_size: batch size for both loaders; also the size of the training subset
            when `args.fast_debug` is set.
        augmentation: use the random augmentation pipeline instead of the plain transform.
        supervised: when False, training items are two transformed views of the image
            with the label removed; when True, the training dataset is used unchanged.

    Returns:
        Tuple `(train_loader, test_loader)`. The train loader drops the last incomplete batch.

    Raises:
        ValueError: if `args.dataset` is not one of the supported names.
    """
    dataset_name = args.dataset

    if "CIFAR" in dataset_name:
        if dataset_name == "CIFAR10":
            base_dataset_class = datasets.CIFAR10
            mean, std = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)

        elif dataset_name == "CIFAR100":
            base_dataset_class = datasets.CIFAR100
            mean, std = (0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762)

        else:
            # Previously fell through to an unbound-variable error.
            raise ValueError(f"Unsupported dataset: {args.dataset}")

        base_transform = _build_base_transform(32, transforms.Normalize(mean, std), augmentation)

        def make_dataset(train, transform):
            return base_dataset_class(root='./data', train=train, download=True, transform=transform)

    elif "ImageNet" in dataset_name:
        if dataset_name == "TinyImageNet":
            # url = "http://cs231n.stanford.edu/tiny-imagenet-200.zip"
            root = './data/tiny-imagenet-200'

            train_dataset_directory = os.path.join(root, "train")
            test_dataset_directory = os.path.join(root, "val")
            prepare_tinyimagenet_val_folder(test_dataset_directory)

            mean, std = (0.4802, 0.4481, 0.3975), (0.2764, 0.2689, 0.2816)
            base_transform = _build_base_transform(64, transforms.Normalize(mean, std), augmentation)

        elif dataset_name in ("ImageNet100", "ImageNet1000"):
            # `root` must be bound before prepare_ImageNet is called; assigning it
            # afterwards made both ImageNet branches raise UnboundLocalError.
            root = './data'
            labels = list(_IMAGENET100_LABELS) if dataset_name == "ImageNet100" else None
            train_dataset_directory, test_dataset_directory = prepare_ImageNet(logger, root, args.num_classes, labels)

            mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
            base_transform = _build_base_transform(224, transforms.Normalize(mean, std), augmentation, plain_resize=256)

        else:
            # Previously fell through to an unbound-variable error.
            raise ValueError(f"Unsupported dataset: {args.dataset}")

        def make_dataset(train, transform):
            directory = train_dataset_directory if train else test_dataset_directory
            return datasets.ImageFolder(directory, transform=transform)

    else:
        raise ValueError(f"Unsupported dataset: {args.dataset}")

    train_loader, test_loader = _build_loaders(args, batch_size, supervised, base_transform, make_dataset)

    logger.info(f"{args.dataset} dataset loaded.")

    return train_loader, test_loader