from .base import *
import contextlib

@contextlib.contextmanager
def temp_seed(seed):
    """Run the ``with`` body with NumPy's global RNG seeded to ``seed``.

    The global ``np.random`` state that was active on entry is put back on
    exit, whether the body finishes normally or raises. Code outside the
    block therefore sees an undisturbed random stream.
    """
    with contextlib.ExitStack() as restorer:
        # Register the restore before reseeding so it always runs on unwind.
        restorer.callback(np.random.set_state, np.random.get_state())
        np.random.seed(seed)
        yield


class CUBirds(BaseDataset2):
    """CUB-200-2011 split into ``train`` / ``gallery`` / ``query`` subsets.

    Labels are the 3-digit class prefix of each image's relative path, minus 1.
    The first ``int(200 * seen_rate)`` labels are treated as *seen* classes and
    the remaining ones as *unseen*.

    * ``train``: a reproducible (seed 0) random 30% of all seen-class images.
      The images are pooled from both halves of ``train_test_split.txt``.
    * ``gallery``: the remaining seen-class images, plus the unseen-class
      images flagged ``1`` in ``train_test_split.txt``.
    * ``query``: the unseen-class images flagged ``0`` in
      ``train_test_split.txt``. In this mode ``self.label_mat[q, g]`` is
      ``1.0`` when query ``q`` and gallery item ``g`` share a label, else ``0.0``.

    Samples are appended to ``self.ys`` / ``self.I`` / ``self.im_paths``.
    Those attributes are presumably initialised by ``BaseDataset2.__init__``;
    confirm against the base class.
    """

    raw_path = 'raw'
    # Number of bird classes in CUB-200-2011.
    _NUM_CLASSES = 200
    _MODES = ('train', 'gallery', 'query')

    def __init__(self, root, mode, transform=None, seen_rate=0.5):
        """Build the requested split.

        Args:
            root: dataset root; images are read from
                ``<root>/<raw_path>/CUB_200_2011``.
            mode: one of ``'train'``, ``'gallery'``, ``'query'``.
            transform: forwarded to ``BaseDataset2``.
            seen_rate: fraction of the 200 classes treated as seen.

        Raises:
            ValueError: if ``mode`` is unknown, or if the two metadata files
                have a different number of rows.
        """
        self.root = root
        self.mode = mode
        self.transform = transform
        self.label_mat = None
        self.seen_rate = seen_rate

        self.classes = range(0, int(self._NUM_CLASSES * seen_rate))

        BaseDataset2.__init__(self, self.root, self.mode, self.transform)

        # Validate explicitly: an ``assert`` would be stripped under ``python -O``.
        if self.mode not in self._MODES:
            raise ValueError(
                "mode must be one of %r, got %r" % (self._MODES, self.mode))

        cub_dir = os.path.join(self.root, self.raw_path, 'CUB_200_2011')
        images_dir = os.path.join(cub_dir, 'images')
        # Column 1 of images.txt is the image path relative to images/.
        # Column 1 of train_test_split.txt is a 1/0 train/test flag.
        # Rows of the two files are paired by position.
        all_images_list = np.genfromtxt(
            os.path.join(cub_dir, 'images.txt'), dtype=str)
        train_test_list = np.genfromtxt(
            os.path.join(cub_dir, 'train_test_split.txt'), dtype=int)
        if len(all_images_list) != len(train_test_list):
            raise ValueError(
                "images.txt has %d rows but train_test_split.txt has %d"
                % (len(all_images_list), len(train_test_list)))

        n_target_classes = len(self.classes)

        self.all_labels = []

        # (path, label, is_seen) records for rows flagged 1 and 0 respectively.
        flagged_train, flagged_test = [], []
        for (_, fname), (_, is_train) in zip(all_images_list, train_test_list):
            label = int(fname[0:3]) - 1
            record = (os.path.join(images_dir, fname), label,
                      label < n_target_classes)
            if is_train == 1:
                flagged_train.append(record)
            elif is_train == 0:
                flagged_test.append(record)

        # Pool every seen-class image and carve out a reproducible 30% for
        # training. A set gives O(1) membership instead of scanning an array.
        seen_imgs = [(path, label)
                     for path, label, is_seen in flagged_train + flagged_test
                     if is_seen]
        n_train = int(0.3 * len(seen_imgs))
        with temp_seed(0):
            train_idxs = set(
                np.random.permutation(len(seen_imgs))[:n_train].tolist())

        train_imgs = [img for i, img in enumerate(seen_imgs) if i in train_idxs]
        seen_gallery_imgs = [img for i, img in enumerate(seen_imgs)
                             if i not in train_idxs]
        unseen_gallery_imgs = [(path, label)
                               for path, label, is_seen in flagged_train
                               if not is_seen]
        gallery_imgs = seen_gallery_imgs + unseen_gallery_imgs

        if self.mode == 'train':
            imgs = train_imgs
        elif self.mode == 'gallery':
            imgs = gallery_imgs
        else:  # 'query' -- mode was validated above
            imgs = [(path, label)
                    for path, label, is_seen in flagged_test
                    if not is_seen]
            query_labels = np.array([label for _, label in imgs], dtype=int)
            gallery_labels = np.array([label for _, label in gallery_imgs],
                                      dtype=int)
            # Same float64 0/1 matrix as one_hot(query) @ one_hot(gallery).T,
            # without the dense one-hot intermediates; also valid when empty.
            self.label_mat = (query_labels[:, None]
                              == gallery_labels[None, :]).astype(np.float64)

        for index, (path, label) in enumerate(imgs):
            self.ys += [label]
            self.I += [index]
            self.im_paths.append(path)

    def rebuild_imgs(self, pseudo_labels):
        """Replace the labels in ``self.imgs`` with ``pseudo_labels``, keeping paths.

        Entries are paired position-wise. Any surplus on either side is
        dropped (``zip`` semantics). An empty ``self.imgs`` yields ``[]``.
        """
        # NOTE(review): ``self.imgs`` is never assigned in this class;
        # presumably BaseDataset2 provides it -- confirm.
        paths = [entry[0] for entry in self.imgs]
        self.imgs = list(zip(paths, pseudo_labels))