from .base import *

class CUBirds(BaseDataset2):
    """CUB-200-2011 birds dataset with a custom train / gallery / query split.

    Every image listed in ``<root>/<raw_path>/CUB_200_2011/images.txt`` is
    pooled: images flagged 1 in ``train_test_split.txt`` come first and images
    flagged 0 come after, each group in file order. The pool is then re-split
    deterministically:

    * every 10th pooled image (positions 0, 10, 20, ...) is a *query* image;
    * of the remaining images, every 3rd (positions 0, 3, 6, ...) is a
      *train* image and the rest form the *gallery*.

    Afterwards, each split keeps only samples whose label is in
    ``self.classes``, i.e. the first ``int(200 * seen_rate)`` class ids.

    Args:
        root: dataset root directory; images are read from
            ``<root>/<raw_path>/CUB_200_2011``.
        mode: one of ``'train'``, ``'gallery'`` or ``'query'``.
        transform: optional transform, forwarded to ``BaseDataset2``.
        seen_rate: fraction of the 200 classes to keep.

    Attributes:
        label_mat: only set when ``mode == 'query'``. It is a float
            ``(num_query, num_gallery)`` matrix whose entry ``(i, j)`` is 1.0
            when query sample ``i`` and gallery sample ``j`` share a label,
            and 0.0 otherwise. Both axes use the class-filtered splits, so
            they line up with the query/gallery datasets built with the same
            ``root`` and ``seen_rate``. It is ``None`` for other modes.

    Raises:
        ValueError: if ``mode`` is not one of the supported values.
    """

    raw_path = 'raw'

    # Total number of classes in CUB-200-2011; ``seen_rate`` is a fraction of it.
    _NUM_CLASSES = 200
    # Accepted values for ``mode``.
    _MODES = ('train', 'gallery', 'query')

    def __init__(self, root, mode, transform=None, seen_rate=1.0):
        self.root = root
        self.mode = mode
        self.transform = transform
        self.label_mat = None

        self.classes = range(0, int(self._NUM_CLASSES * seen_rate))

        BaseDataset2.__init__(self, self.root, self.mode, self.transform)

        # Use a real exception rather than ``assert`` so the check survives ``python -O``.
        if self.mode not in self._MODES:
            raise ValueError(
                "mode must be one of {}, got {!r}".format(self._MODES, self.mode))

        cub_dir = os.path.join(self.root, self.raw_path, 'CUB_200_2011')
        all_paths, all_labels = self._load_pooled_images(cub_dir)

        # Split on the full pool first, then drop unseen classes, so the split
        # assignment of each image does not depend on ``seen_rate``.
        class_ids = np.asarray(list(self.classes), dtype=int)
        splits = {
            name: self._keep_classes(paths, labels, class_ids)
            for name, (paths, labels) in self._split(all_paths, all_labels).items()
        }

        if self.mode == 'query':
            # Build the relevance matrix from the *filtered* splits so that
            # rows/columns match the samples actually stored in each dataset.
            _, query_labels = splits['query']
            _, gallery_labels = splits['gallery']
            self.label_mat = (
                query_labels[:, None] == gallery_labels[None, :]
            ).astype(np.float64)

        paths, labels = splits[self.mode]
        for index, (path, y) in enumerate(zip(paths, labels)):
            self.ys += [int(y)]
            self.I += [index]
            self.im_paths.append(str(path))

    @staticmethod
    def _load_pooled_images(cub_dir):
        """Return ``(paths, labels)`` arrays with train-flagged images first, then test-flagged ones."""
        images_dir = os.path.join(cub_dir, 'images')
        names = np.genfromtxt(os.path.join(cub_dir, 'images.txt'), dtype=str)
        flags = np.genfromtxt(os.path.join(cub_dir, 'train_test_split.txt'), dtype=int)

        train_samples, test_samples = [], []
        for fname, flag in zip(names[:, 1], flags[:, 1]):
            # The first three characters of the relative path encode the 1-based class id.
            sample = (os.path.join(images_dir, fname), int(fname[0:3]) - 1)
            if flag == 1:
                train_samples.append(sample)
            elif flag == 0:
                test_samples.append(sample)

        pooled = train_samples + test_samples
        paths = np.array([path for path, _ in pooled])
        labels = np.array([label for _, label in pooled], dtype=int)
        return paths, labels

    @staticmethod
    def _split(paths, labels):
        """Deterministically split pooled samples into train / gallery / query subsets."""
        is_query = np.arange(len(labels)) % 10 == 0
        rest_paths, rest_labels = paths[~is_query], labels[~is_query]

        is_train = np.arange(len(rest_labels)) % 3 == 0
        return {
            'train': (rest_paths[is_train], rest_labels[is_train]),
            'gallery': (rest_paths[~is_train], rest_labels[~is_train]),
            'query': (paths[is_query], labels[is_query]),
        }

    @staticmethod
    def _keep_classes(paths, labels, class_ids):
        """Keep only the samples whose label appears in ``class_ids``."""
        mask = np.isin(labels, class_ids)
        return paths[mask], labels[mask]