import os
from .lt_data import LT_Dataset


class ImageNet_LT(LT_Dataset):
    """ImageNet-LT dataset whose items also carry the human-readable class name.

    Each item is ``(image, label, name)``. ``image`` and ``label`` come from
    ``LT_Dataset.__getitem__``; ``name`` is the class name looked up from
    ``classnames_txt``.
    """

    # NOTE(review): these paths are relative, so they resolve against the
    # current working directory, not this module's location.
    classnames_txt = "./datasets/ImageNet_LT/classnames.txt"
    train_txt = "./datasets/ImageNet_LT/ImageNet_LT_train.txt"
    test_txt = "./datasets/ImageNet_LT/ImageNet_LT_test.txt"

    def __init__(self, root, train=True, transform=None, drop_shot=0):
        """Load the split and attach a class name to every sample.

        Args:
            root: Dataset root, forwarded unchanged to ``LT_Dataset``.
            train: Forwarded to ``LT_Dataset``. Presumably selects between
                ``train_txt`` and ``test_txt`` (exposed as ``self.txt``);
                TODO confirm in ``LT_Dataset``.
            transform: Forwarded to ``LT_Dataset``.
            drop_shot: Keep only samples whose class has more than this many
                entries in ``cls_num_list``. The default of 0 keeps every
                sample of any class with a positive count.
        """
        super().__init__(root, train, transform)

        self.classnames = self.read_classnames()

        # Each annotation line is "<path> <label> ..."; the integer in the
        # second column indexes into ``self.classnames``. This relies on the
        # base class keeping one sample per line in file order, so that
        # ``self.names`` lines up with ``self.img_path`` / ``self.labels``.
        with open(self.txt, encoding="utf-8") as f:
            self.names = [self.classnames[int(line.split()[1])] for line in f]

        self.drop_class_samples(drop_shot=drop_shot)

    def __getitem__(self, index):
        """Return ``(image, label, name)`` for the sample at ``index``."""
        image, label = super().__getitem__(index)
        return image, label, self.names[index]

    @classmethod
    def read_classnames(cls):
        """Read ``cls.classnames_txt`` and return the class names in file order.

        Each line has the form ``"<folder> <class name ...>"``. The first
        space-separated token (the folder id) is discarded. Everything after
        the first space is kept verbatim as the class name, including any
        further spaces. A line with no space gives an empty name.
        """
        with open(cls.classnames_txt, "r", encoding="utf-8") as f:
            # partition(" ")[2] is exactly "join every token after the first
            # with a single space", which is what the format needs.
            return [line.strip().partition(" ")[2] for line in f]

    def drop_class_samples(self, drop_shot):
        """Keep only samples whose class count in ``cls_num_list`` exceeds ``drop_shot``.

        ``img_path``, ``labels`` and ``names`` are filtered together so they
        stay aligned, and each is rebound to a new ``list``.

        NOTE(review): ``cls_num_list`` is presumably the per-class sample
        count computed by ``LT_Dataset``. It is not recomputed here, so after
        filtering it still reports counts for classes whose samples were
        dropped. Labels are not remapped either. Confirm that downstream users
        expect this.
        """
        new_path = []
        new_labels = []
        new_names = []
        counts = self.cls_num_list
        for path, label, name in zip(self.img_path, self.labels, self.names):
            if counts[label] > drop_shot:
                new_path.append(path)
                new_labels.append(label)
                new_names.append(name)
        self.img_path = new_path
        self.labels = new_labels
        self.names = new_names
