import os
import numpy as np
import socket
import random
import math
import torch
import torch.distributed as dist
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
import torch.utils.data as data
from torch.utils.data import Sampler, DistributedSampler
from collections import defaultdict


# IMAGENET directory
# Host name of the current machine, as reported by socket.gethostname().
hostname = socket.gethostname()
# Root directory of the ImageNet dataset; the loader functions in this module append
# 'train' and 'val' to it.
# NOTE(review): '/your/path/imagenet/' is a placeholder that has to be edited per machine.
# Because it is an absolute path, os.path.join on POSIX discards the preceding component,
# so the location of this file does not influence the result there - confirm whether a
# path relative to this file was intended.
data_folder = os.path.join(os.path.dirname(os.path.abspath(__file__)), '/your/path/imagenet/')


class RASampler(torch.utils.data.Sampler):
    """Distributed sampler with repeated augmentation.

    Every dataset index is repeated ``repetitions`` times and the repeated list is dealt
    out to the processes with a stride of ``num_replicas``, so the copies of one sample
    are spread across processes (GPUs).
    Heavily based on 'torch.utils.data.DistributedSampler'.

    This is borrowed from the DeiT Repo:
    https://github.com/facebookresearch/deit/blob/main/samplers.py

    Args:
        dataset: Sized dataset to sample from.
        num_replicas: Number of participating processes; read from
            ``torch.distributed`` when ``None``.
        rank: Rank of the current process; read from ``torch.distributed`` when ``None``.
        shuffle: Whether to permute the indices (deterministically per epoch).
        seed: Base seed; the permutation of an epoch is seeded with ``seed + epoch``.
        repetitions: How many times each index is repeated.

    Raises:
        RuntimeError: If ``num_replicas`` or ``rank`` is ``None`` and the distributed
            package is not available.
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, repetitions=3):
        must_query_dist = num_replicas is None or rank is None
        if must_query_dist and not dist.is_available():
            raise RuntimeError("Requires distributed package to be available!")
        if num_replicas is None:
            num_replicas = dist.get_world_size()
        if rank is None:
            rank = dist.get_rank()

        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.shuffle = shuffle
        self.seed = seed
        self.repetitions = repetitions
        self.epoch = 0

        dataset_len = len(dataset)
        # Per-process share of the repeated index list, rounded up so all shares are equal.
        self.num_samples = int(math.ceil(dataset_len * float(repetitions) / num_replicas))
        self.total_size = self.num_samples * num_replicas
        # Only this many indices are actually yielded per process: the dataset length is
        # first rounded down to a multiple of 256, then divided among the replicas.
        self.num_selected_samples = int(math.floor(dataset_len // 256 * 256 / num_replicas))

    def __iter__(self):
        """Return an iterator over this process' indices for the current epoch."""
        dataset_len = len(self.dataset)
        if not self.shuffle:
            order = list(range(dataset_len))
        else:
            # Same seed and epoch produce the same permutation on every process.
            rng = torch.Generator()
            rng.manual_seed(self.seed + self.epoch)
            order = torch.randperm(dataset_len, generator=rng).tolist()

        # Repeat each index back-to-back, then pad from the front up to total_size.
        repeated = []
        for idx in order:
            repeated.extend([idx] * self.repetitions)
        shortfall = self.total_size - len(repeated)
        repeated += repeated[:shortfall]
        assert len(repeated) == self.total_size

        # Strided slice: this rank takes every num_replicas-th entry.
        mine = repeated[self.rank : self.total_size : self.num_replicas]
        assert len(mine) == self.num_samples

        return iter(mine[: self.num_selected_samples])

    def __len__(self):
        """Number of indices yielded per epoch by this process."""
        return self.num_selected_samples

    def set_epoch(self, epoch):
        """Record the epoch number that is added to the shuffle seed."""
        self.epoch = epoch


class DistributedFixedClassSampler(DistributedSampler):
    """Sampler yielding a fixed number of randomly chosen indices from every class.

    The selection is drawn once in ``__init__`` and the same index list is returned on
    every iteration.

    NOTE(review): although this subclasses ``DistributedSampler``, ``__iter__`` and
    ``__len__`` are overridden to expose the complete selection. It is not split by
    rank, and ``shuffle`` / ``set_epoch`` do not influence the order. With the same
    ``seed`` every process therefore iterates over identical indices - confirm that
    this is intended.

    Args:
        dataset: Dataset exposing ``samples``, a sequence of ``(item, label)`` pairs;
            it is also handed to ``DistributedSampler.__init__``.
        num_samples_per_class: Number of indices to draw from each class.
        seed: Seed for the selection. ``None`` draws from the process-global ``random``
            generator in its current state.
        shuffle: Forwarded to ``DistributedSampler.__init__``.

    Raises:
        ValueError: If a class has fewer than ``num_samples_per_class`` samples.
    """

    def __init__(self, dataset, num_samples_per_class, seed=None, shuffle=True):
        super().__init__(dataset, shuffle=shuffle)
        self.dataset = dataset
        self.num_samples_per_class = num_samples_per_class
        # Overrides the integer seed stored by DistributedSampler; may be None.
        self.seed = seed
        self.class_indices = self._get_class_indices()
        self.indices = self._sample_indices()

    def _get_class_indices(self):
        """Group dataset indices by label, keeping dataset order inside each group."""
        class_indices = defaultdict(list)
        for idx, (_, label) in enumerate(self.dataset.samples):
            class_indices[label].append(idx)
        return class_indices

    def _sample_indices(self):
        """Draw ``num_samples_per_class`` indices from every class.

        A private ``random.Random(seed)`` is used when a seed is given, so building the
        sampler no longer reseeds the process-global ``random`` generator. The selection
        is the same one that ``random.seed(seed)`` followed by ``random.sample`` yields.
        """
        rng = random if self.seed is None else random.Random(self.seed)

        selected_indices = []
        for label, indices in self.class_indices.items():
            if len(indices) < self.num_samples_per_class:
                raise ValueError(f"Not enough samples for class {label}")
            selected_indices.extend(rng.sample(indices, self.num_samples_per_class))

        return selected_indices

    def __iter__(self):
        # Use the sampled indices instead of the default DistributedSampler iteration
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)


class FixedClassSampler(data.Sampler):
    """Sampler yielding a fixed number of randomly chosen indices from every class.

    The selection is drawn once in ``__init__``; every iteration returns the same
    indices in the same order (grouped by class, in order of first appearance of each
    label in ``dataset.samples``).

    Args:
        dataset: Dataset exposing ``samples``, a sequence of ``(item, label)`` pairs.
        num_samples_per_class: Number of indices to draw from each class.
        seed: Seed for the selection. ``None`` draws from the process-global ``random``
            generator in its current state.

    Raises:
        ValueError: If a class has fewer than ``num_samples_per_class`` samples.
    """

    def __init__(self, dataset, num_samples_per_class, seed=None):
        self.dataset = dataset
        self.num_samples_per_class = num_samples_per_class
        self.seed = seed
        self.class_indices = self._get_class_indices()
        self.indices = self._sample_indices()

    def _get_class_indices(self):
        """Group dataset indices by label, keeping dataset order inside each group."""
        class_indices = defaultdict(list)
        for idx, (_, label) in enumerate(self.dataset.samples):
            class_indices[label].append(idx)
        return class_indices

    def _sample_indices(self):
        """Draw ``num_samples_per_class`` indices from every class.

        A private ``random.Random(seed)`` is used when a seed is given, so building the
        sampler no longer reseeds the process-global ``random`` generator. The selection
        is the same one that ``random.seed(seed)`` followed by ``random.sample`` yields.
        """
        rng = random if self.seed is None else random.Random(self.seed)

        selected_indices = []
        for label, indices in self.class_indices.items():
            if len(indices) < self.num_samples_per_class:
                raise ValueError(f"Not enough samples for class {label}")
            selected_indices.extend(rng.sample(indices, self.num_samples_per_class))

        return selected_indices

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)


class ImageNetCodedTeacher(ImageFolder):
    """Folder dataset whose augmentation is a deterministic function of the sample index.

    ``__getitem__`` returns ``(img, target, index)``. The augmentation pipeline comes from
    ``get_deterministic_transforms`` and is seeded with ``index + seed_offset``.

    Args:
        folder: Root directory handed to ``ImageFolder``.
        mean: Per-channel mean forwarded to ``get_deterministic_transforms``.
        std: Per-channel std forwarded to ``get_deterministic_transforms``.
        org_ResizeShape: Resize argument forwarded to ``get_deterministic_transforms``
            (``0`` skips the resize step there).
        transform: Optional transform applied by ``ImageFolder`` before the
            deterministic pipeline.
        target_transform: Accepted but not forwarded to ``ImageFolder``.
            NOTE(review): confirm whether it should be applied.
    """

    def __init__(self, folder, mean, std, org_ResizeShape, transform=None, target_transform=None):
        super().__init__(folder, transform=transform)
        # Added to the sample index to form the augmentation seed.
        self.seed_offset = 0
        self.org_ResizeShape = org_ResizeShape
        self.mean = mean
        self.std = std

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target, index) where target is class_index of the target class.
        """
        # Load (and apply ``self.transform``) before snapshotting, so randomness consumed
        # by a user transform is not rolled back by the restore below.
        img, target = super().__getitem__(index)

        # get_deterministic_transforms reseeds both torch and Python's ``random``;
        # save both generators and restore them even if the transform raises.
        torch_state = torch.get_rng_state()
        python_state = random.getstate()
        try:
            augment = get_deterministic_transforms(
                index + self.seed_offset, self.org_ResizeShape, self.mean, self.std
            )
            img = augment(img)
        finally:
            torch.set_rng_state(torch_state)
            random.setstate(python_state)

        return img, target, index


def get_deterministic_transforms(seed, org_ResizeShape, mean, std):
    """Seed the global RNGs and build the training augmentation pipeline.

    Both ``torch`` and Python's ``random`` are seeded with ``seed`` before the pipeline
    is created; the caller's RNG state is NOT restored here.

    Args:
        seed: Value passed to ``torch.manual_seed`` and ``random.seed``.
        org_ResizeShape: Size for an initial ``Resize`` step; ``0`` omits that step.
        mean: Per-channel mean for ``Normalize``,
            e.g. (0.485, 0.456, 0.406).
        std: Per-channel std for ``Normalize``,
            e.g. (0.229, 0.224, 0.225).

    Returns:
        A ``transforms.Compose`` of [Resize,] RandomResizedCrop(224),
        RandomHorizontalFlip, ToTensor and Normalize.
    """
    # Set the random seed for reproducibility
    torch.manual_seed(seed)
    random.seed(seed)

    pipeline = []
    if org_ResizeShape != 0:
        pipeline.append(transforms.Resize(org_ResizeShape))
    pipeline.extend(
        [
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ]
    )
    return transforms.Compose(pipeline)


class ImageNet(ImageFolder):
    """Folder dataset that additionally returns the index of each sample.

    Args:
        folder: Root directory handed to ``ImageFolder``.
        transform: Transform handed to ``ImageFolder``.
    """

    def __init__(self, folder, transform):
        super().__init__(folder, transform=transform)
        self.seed_offset = 0

    def __getitem__(self, index):
        """Return ``(image, label, index)`` for the sample stored at ``index``."""
        image, label = super().__getitem__(index)
        return image, label, index


class ImageNetInstanceSample(ImageNet):
    """Folder dataset that can also return indices of contrastive samples.

    ``__getitem__`` returns ``(img, target, index)``, or
    ``(img, target, index, sample_idx)`` when ``is_sample`` is true.

    Args:
        folder: Root directory handed to the parent dataset.
        transform: Transform handed to the parent dataset.
        target_transform: Accepted but not used by this class.
        is_sample: Whether to build the per-class index tables and return ``sample_idx``.
        k: Number of negative indices drawn per item.

    Attributes (only set when ``is_sample`` is true):
        cls_positive: List indexed by class; int32 array of the sample indices
            belonging to that class, in ascending order.
        cls_negative: List indexed by class; int32 array of the sample indices of all
            other classes, ordered by class and then by index.
    """
    def __init__(self, folder, transform=None, target_transform=None, is_sample=False, k=4096):
        super().__init__(folder, transform=transform)

        self.k = k
        self.is_sample = is_sample
        self.seed_offset = 0
        if self.is_sample:
            print('preparing contrastive data...')
            label = np.asarray([target for _, target in self.samples], dtype=np.int32)
            # Derive the class count from the labels instead of assuming 1000 classes,
            # so datasets with more classes do not raise IndexError.
            num_classes = int(label.max()) + 1 if label.size else 0

            # flatnonzero returns ascending indices, matching the original append order.
            self.cls_positive = [
                np.flatnonzero(label == cls).astype(np.int32) for cls in range(num_classes)
            ]

            # Negatives of a class are the positives of every other class, concatenated
            # in class order. Building them directly as arrays avoids the very large
            # intermediate Python lists.
            no_negatives = np.empty(0, dtype=np.int32)
            self.cls_negative = []
            for cls in range(num_classes):
                others = self.cls_positive[:cls] + self.cls_positive[cls + 1:]
                self.cls_negative.append(np.concatenate(others) if others else no_negatives)
            print('done.')

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target, index) where target is class_index of the target class.
            When ``is_sample`` is true a fourth element ``sample_idx`` is appended: an
            array holding ``index`` followed by ``k`` indices drawn with replacement from
            the samples of other classes.
        """
        img, target, index = super().__getitem__(index)

        if not self.is_sample:
            return img, target, index

        # sample contrastive examples
        pos_idx = index
        neg_idx = np.random.choice(self.cls_negative[target], self.k, replace=True)
        sample_idx = np.hstack((np.asarray([pos_idx]), neg_idx))
        return img, target, index, sample_idx


def get_imagenet_train_transform_resizeInput(mean, std):
    """Build the training pipeline that resizes to 256 before the random crop.

    Args:
        mean: Per-channel mean for ``Normalize``.
        std: Per-channel std for ``Normalize``.

    Returns:
        ``transforms.Compose`` of Resize(256), RandomResizedCrop(224),
        RandomHorizontalFlip, ToTensor and Normalize.
    """
    steps = [
        transforms.Resize(256),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
    return transforms.Compose(steps)


def get_imagenet_train_transform(mean, std):
    """Build the default training augmentation pipeline.

    Args:
        mean: Per-channel mean for ``Normalize``.
        std: Per-channel std for ``Normalize``.

    Returns:
        ``transforms.Compose`` of RandomResizedCrop(224), RandomHorizontalFlip,
        ToTensor and Normalize.
    """
    steps = [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
    return transforms.Compose(steps)


def get_imagenet_test_transform(mean, std):
    """Build the evaluation pipeline (no random augmentation).

    Args:
        mean: Per-channel mean for ``Normalize``.
        std: Per-channel std for ``Normalize``.

    Returns:
        ``transforms.Compose`` of Resize(256), CenterCrop(224), ToTensor and Normalize.
    """
    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
    return transforms.Compose(steps)


# dataloader

def get_imagenet_dataloaders_crd(batch_size, val_batch_size, num_workers, k=4096, mean=[0., 0., 0.], std=[1/255., 1/255., 1/255.]):
    """Create single-process loaders for contrastive (CRD-style) training.

    Args:
        batch_size: Batch size of the training loader.
        val_batch_size: Batch size of the validation loader.
        num_workers: Worker count of the training loader (the validation loader is
            built by ``get_imagenet_val_loader``, which is not passed this value).
        k: Number of negative indices per item, forwarded to ``ImageNetInstanceSample``.
        mean: Normalisation mean for the training transform.
        std: Normalisation std for the training transform.

    Returns:
        ``(train_loader, test_loader, num_data)`` where ``num_data`` is the number of
        training samples.

    NOTE(review): the validation loader always uses mean=[0.485, 0.456, 0.406] and
    std=[0.229, 0.224, 0.225], regardless of ``mean``/``std`` - confirm this is intended.
    """
    train_set = ImageNetInstanceSample(
        os.path.join(data_folder, 'train'),
        transform=get_imagenet_train_transform(mean=mean, std=std),
        is_sample=True,
        k=k,
    )

    train_loader = data.DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )
    test_loader = get_imagenet_val_loader(
        val_batch_size,
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )

    return train_loader, test_loader, len(train_set)


def get_imagenet_dataloaders_crd_dist(batch_size, val_batch_size, num_workers, k=4096, mean=[0., 0., 0.], std=[1/255., 1/255., 1/255.]):
    """Create distributed loaders for contrastive (CRD-style) training.

    Both loaders use a ``DistributedSampler`` (shuffled for training, unshuffled for
    validation) and keep the last incomplete batch.

    Args:
        batch_size: Batch size of the training loader.
        val_batch_size: Batch size of the validation loader.
        num_workers: Worker count for both loaders.
        k: Number of negative indices per item, forwarded to ``ImageNetInstanceSample``.
        mean: Normalisation mean for the training transform.
        std: Normalisation std for the training transform.

    Returns:
        ``(train_loader, test_loader, num_data)`` where ``num_data`` is the number of
        training samples.

    NOTE(review): the validation transform always uses mean=[0.485, 0.456, 0.406] and
    std=[0.229, 0.224, 0.225], regardless of ``mean``/``std`` - confirm this is intended.
    """
    train_set = ImageNetInstanceSample(
        os.path.join(data_folder, 'train'),
        transform=get_imagenet_train_transform(mean=mean, std=std),
        is_sample=True,
        k=k,
    )
    test_set = ImageFolder(
        os.path.join(data_folder, 'val'),
        transform=get_imagenet_test_transform(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    )

    train_sampler = DistributedSampler(train_set, shuffle=True)
    test_sampler = DistributedSampler(test_set, shuffle=False)

    # Settings common to both loaders.
    shared = dict(num_workers=num_workers, pin_memory=True, drop_last=False)
    train_loader = data.DataLoader(train_set, batch_size=batch_size, sampler=train_sampler, **shared)
    test_loader = data.DataLoader(test_set, batch_size=val_batch_size, sampler=test_sampler, **shared)

    return train_loader, test_loader, len(train_set)


def get_imagenet_dataloaders_default(batch_size, val_batch_size, num_workers, mean=[0., 0., 0.], std=[1/255., 1/255., 1/255.]):
    """Create the standard single-process training and validation loaders.

    Args:
        batch_size: Batch size of the training loader.
        val_batch_size: Batch size of the validation loader.
        num_workers: Worker count of the training loader (the validation loader is
            built by ``get_imagenet_val_loader``, which is not passed this value).
        mean: Normalisation mean for the training transform.
        std: Normalisation std for the training transform.

    Returns:
        ``(train_loader, test_loader, num_data)`` where ``num_data`` is the number of
        training samples.

    NOTE(review): the validation loader always uses mean=[0.485, 0.456, 0.406] and
    std=[0.229, 0.224, 0.225], regardless of ``mean``/``std`` - confirm this is intended.
    """
    train_set = ImageNet(
        os.path.join(data_folder, 'train'),
        transform=get_imagenet_train_transform(mean=mean, std=std),
    )

    train_loader = data.DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )
    test_loader = get_imagenet_val_loader(
        val_batch_size,
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )

    return train_loader, test_loader, len(train_set)


def get_imagenet_dataloaders_default_dist(batch_size, val_batch_size, num_workers, mean=[0., 0., 0.], std=[1/255., 1/255., 1/255.]):
    """Create the standard training and validation loaders for distributed runs.

    Both loaders use a ``DistributedSampler`` (shuffled for training, unshuffled for
    validation) and keep the last incomplete batch.

    Args:
        batch_size: Batch size of the training loader.
        val_batch_size: Batch size of the validation loader.
        num_workers: Worker count for both loaders.
        mean: Normalisation mean for the training transform.
        std: Normalisation std for the training transform.

    Returns:
        ``(train_loader, test_loader, num_data)`` where ``num_data`` is the number of
        training samples.

    NOTE(review): the validation transform always uses mean=[0.485, 0.456, 0.406] and
    std=[0.229, 0.224, 0.225], regardless of ``mean``/``std`` - confirm this is intended.
    """
    train_set = ImageNet(
        os.path.join(data_folder, 'train'),
        transform=get_imagenet_train_transform(mean=mean, std=std),
    )
    test_set = ImageFolder(
        os.path.join(data_folder, 'val'),
        transform=get_imagenet_test_transform(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    )

    train_sampler = DistributedSampler(train_set, shuffle=True)
    test_sampler = DistributedSampler(test_set, shuffle=False)

    # Settings common to both loaders.
    shared = dict(num_workers=num_workers, pin_memory=True, drop_last=False)
    train_loader = data.DataLoader(train_set, batch_size=batch_size, sampler=train_sampler, **shared)
    test_loader = data.DataLoader(test_set, batch_size=val_batch_size, sampler=test_sampler, **shared)

    return train_loader, test_loader, len(train_set)


# def get_imagenet_dataloaders_coded_dist(batch_size, val_batch_size, num_workers, org_ResizeShape = 0, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
#     train_folder = os.path.join(data_folder, 'train')
    
#     train_set = ImageNetCodedTeacher(train_folder, mean, std, org_ResizeShape, transform=None)
#     num_data = len(train_set)
#     test_transform = get_imagenet_test_transform(mean, std)
#     test_folder = os.path.join(data_folder, 'val')
#     test_set = ImageFolder(test_folder, transform=test_transform)

#     train_sampler = torch.utils.data.distributed.DistributedSampler(train_set, shuffle=True)
#     test_sampler = torch.utils.data.distributed.DistributedSampler(test_set, shuffle=False)
    
#     train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, sampler=train_sampler, num_workers=num_workers, pin_memory=True, drop_last= False,)
#     test_loader = torch.utils.data.DataLoader(test_set, batch_size=val_batch_size, sampler=test_sampler, num_workers=num_workers, pin_memory=True, drop_last= False,)

#     return train_loader, test_loader, num_data


def get_imagenet_val_loader(val_batch_size, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], num_workers=16):
    """Create the sequential, single-process validation loader.

    Args:
        val_batch_size: Batch size of the loader.
        mean: Normalisation mean for the test transform.
        std: Normalisation std for the test transform.
        num_workers: Number of loader worker processes. Defaults to 16, the value that
            was previously hard-coded, so existing callers are unaffected.

    Returns:
        A ``DataLoader`` over the ``'val'`` sub-folder of ``data_folder`` that visits the
        samples in order.
    """
    test_transform = get_imagenet_test_transform(mean, std)
    test_folder = os.path.join(data_folder, 'val')
    test_set = ImageFolder(test_folder, transform=test_transform)

    sampler = torch.utils.data.SequentialSampler(test_set)
    test_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=val_batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        sampler=sampler,
    )

    return test_loader

