from distutils.command.config import config
import torch
import random
import numpy as np
import os
import sys
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset, Sampler
from PIL import Image
class LT_Dataset(Dataset):
    """Image dataset described by a whitespace-separated list file.

    Every non-blank line of the list file supplies an image path (joined
    onto ``root``) followed by an integer class label. Any additional
    fields on a line are ignored.
    """

    def __init__(self, root, txt, transform=None):
        """Read the list file and record each image path and label.

        Args:
            root: Directory that every listed image path is joined onto.
            txt: Path of the list file to read.
            transform: Optional callable applied to each loaded PIL image
                in ``__getitem__``.

        Raises:
            IndexError: If a non-blank line has fewer than two fields.
            ValueError: If a label field cannot be parsed by ``int``.
        """
        self.img_path = []
        self.labels = []
        self.transform = transform
        with open(txt) as f:
            for line in f:
                fields = line.split()
                # Blank / whitespace-only lines (e.g. a trailing newline)
                # carry no sample; skip them instead of crashing.
                if not fields:
                    continue
                self.img_path.append(os.path.join(root, fields[0]))
                self.labels.append(int(fields[1]))
        # ``targets`` is the same list object as ``labels``, not a copy.
        self.targets = self.labels

    def __len__(self):
        """Return the number of samples read from the list file."""
        return len(self.labels)

    def __getitem__(self, index):
        """Load the image at ``index`` and return ``(sample, label)``.

        The image is opened from disk, converted to RGB, and passed through
        ``self.transform`` when one was supplied.
        """
        path = self.img_path[index]
        label = self.labels[index]
        # ``convert`` runs inside the ``with`` block, while the file handle
        # is still open.
        with open(path, 'rb') as f:
            sample = Image.open(f).convert('RGB')
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, label
class ImageNetLTDataLoader(DataLoader):
    def __init__(self, shuffle=True, training=True, data_dir="", config_dir=""):
        """Build the ImageNet-LT dataset(s) and initialise the DataLoader.

        Args:
            shuffle: Forwarded to ``DataLoader`` as its ``shuffle`` argument.
            training: When true, load ``ImageNet_LT_train.txt`` with the
                augmenting transform and ``ImageNet_LT_val.txt`` as the
                validation set. Otherwise load ``ImageNet_LT_test.txt`` and
                leave ``val_dataset`` as ``None``.
            data_dir: Root directory joined onto every image path found in
                the list files. The default ``""`` keeps paths as listed.
            config_dir: Directory containing the ``ImageNet_LT_*.txt`` list
                files. The default ``""`` looks them up relative to the
                current working directory.

        Raises:
            AssertionError: If the loaded split does not contain exactly
                1000 distinct labels.
        """
        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        # Randomised augmentation pipeline, used for the training split only.
        train_trsfm = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0),
            transforms.ToTensor(),
            normalize,
        ])
        # Deterministic pipeline, used for evaluation splits.
        test_trsfm = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])

        def list_file(name):
            # os.path.join("", name) == name, so the default config_dir
            # resolves to the bare file name.
            return os.path.join(config_dir, name)

        if training:
            dataset = LT_Dataset(data_dir, list_file('ImageNet_LT_train.txt'), train_trsfm)
            # Validation must not go through random crop/flip/jitter, or
            # the measured metrics change from run to run.
            val_dataset = LT_Dataset(data_dir, list_file('ImageNet_LT_val.txt'), test_trsfm)
        else:
            dataset = LT_Dataset(data_dir, list_file('ImageNet_LT_test.txt'), test_trsfm)
            val_dataset = None

        # These attributes are assigned before ``super().__init__`` runs,
        # as in the original ordering (DataLoader is expected to reject
        # re-assigning ``dataset`` once initialised).
        self.dataset = dataset
        self.val_dataset = val_dataset
        self.n_samples = len(self.dataset)

        num_classes = len(np.unique(dataset.targets))
        assert num_classes == 1000, (
            "expected 1000 distinct labels, found %d" % num_classes
        )
        self.num_classes = num_classes

        # Per-class sample counts, indexed by label.
        cls_num_list = [0] * num_classes
        for label in dataset.targets:
            cls_num_list[label] += 1
        self.cls_num_list = cls_num_list

        self.shuffle = shuffle
        self.init_kwargs = {
            'shuffle': self.shuffle
        }
        super().__init__(dataset=self.dataset, **self.init_kwargs)
    def split_validation(self):
        """Return the validation loader, which is currently always ``None``.

        ``self.val_dataset`` is not wrapped in a loader here; the statement
        that used to build one sat after an unconditional ``return None``
        and could never execute, so it has been removed.
        """
        return None
    def classifity_index_label(self):
        list_label2indices = [[] for _ in range(self.num_classes)]
        for idx, label in enumerate(self.dataset):
            list_label2indices[label[1]].append(idx)