import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import json
import numpy as np
import os
from PIL import Image
import pdb
class LoadVisualGenomeDataset(Dataset):
    """Visual Genome images paired with per-object descriptions, classes and relations."""

    def __init__(self, dataset_json_file, image_conf=None, relation_num=115):
        """
        Dataset that manages a set of paired images and object-relation lists in the Visual Genome dataset

        :param dataset_json_file: path to a JSON file providing the keys 'data' (list of samples)
            and 'image_base_path' (directory that each sample's 'image' entry is joined onto)
        :param image_conf: optional dictionary; recognised keys are 'center_crop' (default False),
            'crop_size' (default 224), 'RGB_mean' and 'RGB_std' (ImageNet statistics by default)
        :param relation_num: number of relation labels - default=115 after filtering the dataset
            (only stored on the instance; nothing in this class reads it back)
        """
        with open(dataset_json_file, 'r') as fp:
            data_json = json.load(fp)
        self.data = data_json['data']
        self.image_base_path = data_json['image_base_path']
        self.relation_num = relation_num

        self.image_conf = image_conf if image_conf else {}

        center_crop = self.image_conf.get('center_crop', False)
        crop_size = self.image_conf.get('crop_size', 224)

        # 'crop_size' used to be read and then ignored (224 was hard-coded in both
        # pipelines). It now drives the crop; with the default of 224 the transforms
        # are exactly what they were before.
        resize_size = 256 if center_crop else crop_size
        self.image_resize_and_crop = transforms.Compose(
            [transforms.Resize(resize_size), transforms.CenterCrop(crop_size), transforms.ToTensor()])

        RGB_mean = self.image_conf.get('RGB_mean', [0.485, 0.456, 0.406])
        RGB_std = self.image_conf.get('RGB_std', [0.229, 0.224, 0.225])
        self.image_normalize = transforms.Normalize(mean=RGB_mean, std=RGB_std)

        # tokenizer for Bert model
        # NOTE(review): BertTokenizer is not imported anywhere in the visible part of this
        # file; it must be brought into scope by whichever BERT package the project uses,
        # otherwise this line raises NameError.
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def _LoadObject(self, object_name_list, object_class_list, relation_list):
        """
        Tokenize the object descriptions of one image and build padded label tensors.

        At most 32 objects are kept per image and at most 10 tokens per object;
        anything beyond that is truncated, anything shorter is zero padded.

        returns:
            object_input_all: (max_num_object, max_num_tokens) tokenized object descriptions
            segments_tensors_all: (max_num_object, max_num_tokens) 1 for rows of valid objects; 0 for padding objects
            input_token_mask_all: (max_num_object, max_num_tokens) 1 for valid tokens; 0 for padding tokens
            input_mask: (max_num_object,) 1 for valid objects; 0 for padding objects
            n_object: int; number of valid objects
            object_target: (max_num_object,) ground truth label for object class; 0 for padding objects
            relation_target: (max_num_object, max_num_object) ground truth label for pair-wise relation; 0 for padding object pairs
        """
        max_num_object = 32  # each image with at most 32 objects
        max_num_tokens = 10  # each object with at most 10 tokens in its description

        # Everything starts out as zero padding; valid entries are filled in below.
        object_input_all = torch.zeros(max_num_object, max_num_tokens, dtype=torch.long)
        segments_tensors_all = torch.zeros(max_num_object, max_num_tokens, dtype=torch.long)
        input_token_mask_all = torch.zeros(max_num_object, max_num_tokens, dtype=torch.long)
        input_mask = torch.zeros(max_num_object, dtype=torch.long)
        object_target = torch.zeros(max_num_object, dtype=torch.long)
        relation_target = torch.zeros(max_num_object, max_num_object, dtype=torch.long)

        # objects beyond the cap are dropped
        n_object = min(len(object_name_list), max_num_object)
        input_mask[:n_object] = 1

        # create labels for object and relation (skipped for images without objects,
        # where the label lists may be empty and cannot be sliced in 2-D)
        if n_object > 0:
            class_labels = torch.from_numpy(np.asarray(object_class_list))
            relation_labels = torch.from_numpy(np.asarray(relation_list))
            object_target[:n_object] = class_labels[:n_object]
            relation_target[:n_object, :n_object] = relation_labels[:n_object, :n_object]

        # create input for each object
        for row, object_name in enumerate(object_name_list[:n_object]):
            caption = '[CLS] ' + object_name + ' [SEP]'
            tokenized_text = self.tokenizer.tokenize(caption)
            # cut sequence if too long
            token_ids = self.tokenizer.convert_tokens_to_ids(tokenized_text)[:max_num_tokens]
            nw = len(token_ids)
            object_input_all[row, :nw] = torch.tensor(token_ids, dtype=torch.long)
            segments_tensors_all[row, :] = 1
            input_token_mask_all[row, :nw] = 1
        return object_input_all, segments_tensors_all, input_token_mask_all, input_mask, n_object, object_target, relation_target

    def _LoadImage(self, impath):
        """
        Load an image from disk as RGB.

        returns: the resized/cropped image tensor (values as produced by ToTensor),
            and the same tensor after mean/std normalization for model input
        """
        img = Image.open(impath).convert('RGB')
        img_original = self.image_resize_and_crop(img)
        img = self.image_normalize(img_original)
        return img_original, img

    def __getitem__(self, index):
        """
        returns: combination of visual and language input, as the tuple
            (img_original, img) from _LoadImage followed by the seven values of _LoadObject
        """
        datum = self.data[index]
        imgpath = os.path.join(self.image_base_path, datum['image'])
        img_original, img = self._LoadImage(imgpath)
        object_outputs = self._LoadObject(
            datum['object_description'], datum['object_class'], datum['relation'])
        return (img_original, img) + tuple(object_outputs)

    def __len__(self):
        """Number of samples listed under 'data' in the dataset JSON."""
        return len(self.data)

class LoadCoCoDataset(Dataset):
    """COCO images paired with caption files (one caption per line)."""

    def __init__(self, dataset_json_file, image_conf=None):
        """
        Dataset that manages a set of paired images and captions in the COCO dataset

        :param dataset_json_file: path to a JSON file providing the keys 'data' (list of samples),
            'image_base_path' and 'caption_base_path' (directories the per-sample 'image' and
            'caption' entries are joined onto)
        :param image_conf: optional dictionary; recognised keys are 'center_crop' (default False),
            'crop_size' (default 224), 'RGB_mean' and 'RGB_std' (ImageNet statistics by default)
        """
        with open(dataset_json_file, 'r') as fp:
            data_json = json.load(fp)
        self.data = data_json['data']
        self.image_base_path = data_json['image_base_path']
        self.caption_base_path = data_json['caption_base_path']

        # image preprocessing options, e.g. center cropping
        self.image_conf = image_conf if image_conf else {}

        center_crop = self.image_conf.get('center_crop', False)
        crop_size = self.image_conf.get('crop_size', 224)

        # 'crop_size' used to be read and then ignored (224 was hard-coded in both
        # pipelines). It now drives the crop; with the default of 224 the transforms
        # are exactly what they were before.
        resize_size = 256 if center_crop else crop_size
        self.image_resize_and_crop = transforms.Compose(
            [transforms.Resize(resize_size), transforms.CenterCrop(crop_size), transforms.ToTensor()])

        RGB_mean = self.image_conf.get('RGB_mean', [0.485, 0.456, 0.406])
        RGB_std = self.image_conf.get('RGB_std', [0.229, 0.224, 0.225])
        self.image_normalize = transforms.Normalize(mean=RGB_mean, std=RGB_std)

        # tokenizer for Bert model
        # NOTE(review): BertTokenizer is not imported anywhere in the visible part of this
        # file; it must be brought into scope by whichever BERT package the project uses,
        # otherwise this line raises NameError.
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def _LoadCaption(self, caption):
        """
        Pick one caption at random from a caption file and tokenize it.

        :param caption: path to a text file holding one caption per line
        :raises ValueError: if the caption file contains no lines

        returns:
            tokens_tensor: (sequence_max,) tokenized caption, zero padded
            segments_tensors: (sequence_max,) all 1s
            input_mask: (sequence_max,) 1 for valid tokens; 0 for padding tokens
            nw: int; number of valid tokens in the input caption (after truncation)
        """
        sequence_max = 32
        with open(caption) as f:
            lines = f.read().splitlines()
        if not lines:
            # np.random.randint(0, 0) would raise an opaque "low >= high" ValueError
            raise ValueError('caption file contains no captions: {}'.format(caption))

        # randomly choose one of the available captions in each iteration
        sentence = '[CLS] ' + lines[np.random.randint(0, len(lines))] + ' [SEP]'
        # tokenize the sequence
        tokenized_text = self.tokenizer.tokenize(sentence)
        # cut sequence if too long
        token_ids = self.tokenizer.convert_tokens_to_ids(tokenized_text)[:sequence_max]
        nw = len(token_ids)

        tokens_tensor = torch.zeros(sequence_max, dtype=torch.long)
        tokens_tensor[:nw] = torch.tensor(token_ids, dtype=torch.long)
        segments_tensors = torch.ones(sequence_max, dtype=torch.long)
        # The mask has 1 for real tokens and 0 for padding tokens. Only real tokens are attended to.
        input_mask = torch.zeros(sequence_max, dtype=torch.long)
        input_mask[:nw] = 1

        return tokens_tensor, segments_tensors, input_mask, nw

    def _LoadImage(self, impath):
        """
        Load an image from disk as RGB.

        returns: the resized/cropped image tensor (values as produced by ToTensor),
            and the same tensor after mean/std normalization for model input
        """
        img = Image.open(impath).convert('RGB')
        img_original = self.image_resize_and_crop(img)
        img = self.image_normalize(img_original)
        return img_original, img  # un-normalized image, normalized image

    def __getitem__(self, index):
        """
        returns: combination of visual and language input, as the tuple
            (img_original, img, caption, token_type, input_mask, nwords)
        """
        datum = self.data[index]
        caption_path = os.path.join(self.caption_base_path, datum['caption'])  # caption file location
        imgpath = os.path.join(self.image_base_path, datum['image'])
        caption, token_type, input_mask, nwords = self._LoadCaption(caption_path)
        img_original, img = self._LoadImage(imgpath)
        return img_original, img, caption, token_type, input_mask, nwords

    def __len__(self):
        """Number of samples listed under 'data' in the dataset JSON."""
        return len(self.data)

# sys libs
import os
import pickle
import six.moves
import tarfile

# other libs
import numpy as np
import PIL.Image
import torch
import torch.utils.data

def pil_loader(path):
    """Open the image file at `path` and return it converted to RGB mode."""
    # The conversion happens inside the `with` block, before the file handle is closed.
    with open(path, 'rb') as handle:
        return PIL.Image.open(handle).convert('RGB')
    
# Maps a backbone architecture name to the file holding its pre-encoded prototype
# images. All files live in the same (machine-specific, absolute) directory and only
# differ in the architecture suffix.
Encoded_Prototypes_Path = {
    arch: "/data/fengyi/wangjiaqi/datasets/CUB_VL/CUB_200_2011/new_datasets/cub200_cropped/train_cropped/prototype_image_{}.pt".format(arch)
    for arch in ("vgg16", "vgg19", "resnet34", "resnet152", "densenet121", "densenet161")
}
class CUB200(torch.utils.data.Dataset):
    """
    CUB200 dataset (image / class-label pairs).

    Variables
    ----------
        _root, str: Root directory of the dataset.
        _train, bool: Load train/test data.
        _transform, callable: A function/transform that takes in a PIL.Image
            and transforms it.
        _train_data, list of str: image paths of the train split (train mode only).
        _train_labels, list of int: zero-based labels of the train split (train mode only).
        _test_data, list of str: image paths of the test split (test mode only).
        _test_labels, list of int: zero-based labels of the test split (test mode only).
    """
    def __init__(self, root, train=True, transform=None, cropped = True, resize=224):
        """
        Load the dataset.

        Args
        ----------
        root: str
            Root directory of the dataset; the annotation files are read from
            `<root>/CUB_200_2011/`.
        train: bool
            train/test data split.
        transform: callable
            A function/transform that takes in a PIL.Image and transforms it.
        cropped: bool
            When true, the paths of the cropped training images and of their
            descriptions file are recorded on the instance.
        resize: int
            Stored as `newsize`; not used by this class itself.

        """
        self._root = root
        self._train = train
        self._transform = transform
        self.loader = pil_loader
        self.newsize = resize
        # 15 key points provided by CUB
        self.num_kps = 15
        if cropped:
            cropped_dir = self._root + '/CUB_200_2011/new_datasets/cub200_cropped/train_cropped'
            self.train_path = cropped_dir
            self.train_descriptions = cropped_dir + "/descriptions.json"

        # Only the file lists are read here; the images themselves are opened lazily
        # in __getitem__.
        if self._train:
            self._train_data, self._train_labels = self._get_file_list(train=True)
            assert (len(self._train_data) == 5994 and len(self._train_labels) == 5994), \
                "expected 5994 training samples"
        else:
            self._test_data, self._test_labels = self._get_file_list(train=False)
            assert (len(self._test_data) == 5794 and len(self._test_labels) == 5794), \
                "expected 5794 test samples"

    def __getitem__(self, index):
        """
        Retrieve data samples.

        Args
        ----------
        index: int
            Index of the sample.

        Returns
        ----------
        image:
            RGB image of the given index, after `transform` when one was given
            (otherwise the PIL image itself).
        target: int
            Label of the given index.
        """
        # load the variables according to the current index and split
        if self._train:
            image_path = self._train_data[index]
            target = self._train_labels[index]
        else:
            image_path = self._test_data[index]
            target = self._test_labels[index]

        # The loader already yields an RGB PIL image, so it is handed to the
        # transform directly instead of being copied through a numpy array first.
        image = self.loader(image_path)

        # apply transformation
        if self._transform is not None:
            image = self._transform(image)
        return image, target

    def __len__(self):
        """Return the length of the dataset."""
        if self._train:
            return len(self._train_data)
        return len(self._test_data)

    def _get_file_list(self, train=True):
        """
        Read the CUB annotation files and return the requested split.

        `images.txt` supplies the relative image path in its second column and
        `train_test_split.txt` supplies the split flag (1 = train) in its second column.

        Returns
        ----------
        (paths, labels): image paths and zero-based class labels of the train
            split when `train` is true, otherwise of the test split.
        """
        image_dir = self._root + '/CUB_200_2011/images/'
        id2name = np.genfromtxt(self._root + '/CUB_200_2011/images.txt', dtype=str)
        id2train = np.genfromtxt(self._root + '/CUB_200_2011/train_test_split.txt', dtype=int)

        want_train = bool(train)
        paths = []
        labels = []
        for relative_path, split_flag in zip(id2name[:, 1], id2train[:, 1]):
            if (split_flag == 1) != want_train:
                continue
            paths.append(os.path.join(image_dir, relative_path))
            # The first three characters of the relative path are the one-based
            # class number; labels start with 0.
            labels.append(int(relative_path[:3]) - 1)
        return paths, labels
        
class CUB200_Concept(torch.utils.data.Dataset):
    """
    CUB200 dataset.

    Variables
    ----------
        _root, str: Root directory of the dataset.
        _train, bool: Load train/test data.
        _transform, callable: A function/transform that takes in a PIL.Image
            and transforms it.
        _train_data, list of np.array.
        _train_labels, list of int.
        _train_parts, list np.array.
        _train_boxes, list np.array.
        _test_data, list of np.array.
        _test_labels, list of int.
        _test_parts, list np.array.
        _test_boxes, list np.array.
    """
    def __init__(self, root, arch, train=True, transform=None, cropped = True):
        """
        Load the dataset.

        Args
        ----------
        root: str
            Root directory of the dataset.
        train: bool
            train/test data split.
        transform: callable
            A function/transform that takes in a PIL.Image and transforms it.
        resize: int
            Length of the shortest of edge of the resized image. Used for transforming landmarks and bounding boxes.

        """
        self._root = root #/data/fengyi/wangjiaqi/datasets/CUB_VL/CUB_200_2011
        self.arch = arch
        self._train = train
        self._transform = transform
        self.loader = pil_loader

        # 15 key points provided by CUB
        self.num_kps = 15
        if cropped == True:
            self.train_path = self._root + '/CUB_200_2011/new_datasets/cub200_cropped/train_cropped'
            with open( self._root + "/CUB_200_2011/new_datasets/cub200_cropped/train_cropped/descriptions.json","r") as f:
                self.descriptions = json.load(f)
            self.train_encoded_prototype = torch.tensor(torch.load(Encoded_Prototypes_Path[self.arch])) #[59940,512]
    

        # Load all data into memory for best IO efficiency. This might take a while
        if self._train:
            self._train_data, self._train_labels = self._get_file_list(train=True)
            self.train_descriptions = self.descriptions["train_images"]
            assert (len(self._train_data) == 5994 and len(self._train_labels) == 5994)
            #print(len(self._train_data),len(self._train_labels))
        else:
            self._test_data, self._test_labels = self._get_file_list(train=False)
            self.test_descriptions = self.descriptions["test_images"]
            assert (len(self._test_data) == 5794 and len(self._test_labels) == 5794)
            


    def __getitem__(self, index):
        """
        Retrieve data samples.

        Args
        ----------
        index: int
            Index of the sample.

        Returns
        ----------
        image: torch.FloatTensor, [3, H, W]
            Image of the given index.
        target: int
            Label of the given index.
        parts: torch.FloatTensor, [15, 4]
            Landmark annotations.
        boxes: torch.FloatTensor, [5, ]
            Bounding box annotations.
        """
        # load the variables according to the current index and split
        if self._train:
            image_path = self._train_data[index]
            target = self._train_labels[index]
            encoded_img = self.train_encoded_prototype[10*index:(10)*(index+1)]

        else:
            image_path = self._test_data[index]
            target = self._test_labels[index]
            

        # load the image
        image = self.loader(image_path)
        image = np.array(image)
        
        #load caption
        #caption, token_type, input_mask, nwords = self._LoadCaption(caption_path) # revise here

        # convert the image into a PIL image for transformation
        image = PIL.Image.fromarray(image)

        # apply transformation
        if self._transform is not None:
            image = self._transform(image)
        if self._train:
            return image, target, encoded_img # , caption, token_type, input_mask, nwords
        else:
            return image, target
    
    def __len__(self):
        """Return the length of the dataset."""
        if self._train:
            return len(self._train_data)
        return len(self._test_data)

    def _get_file_list(self, train=True):
        """Prepare the data for train/test split and save onto disk."""

        # load the list into numpy arrays
        image_path = self._root + '/CUB_200_2011/images/'
        #cpation_path = self._root + "/CUB_200_2011/text_c10"
        id2name = np.genfromtxt(self._root + '/CUB_200_2011/images.txt', dtype=str)
        id2train = np.genfromtxt(self._root + '/CUB_200_2011/train_test_split.txt', dtype=int)
        #id2part = np.genfromtxt(self._root + '/CUB_200_2011/parts/part_locs.txt', dtype=float)
        #id2box = np.genfromtxt(self._root + '/CUB_200_2011/bounding_boxes.txt', dtype=float)

        # creat empty lists
        train_data = []
        train_caption = []
        train_labels = []
        #train_parts = []
        #train_boxes = []
        
        test_data = []
        test_caption = []
        test_labels = []
        #test_parts = []
        #test_boxes = []

        # iterating all samples in the whole dataset
        for id_ in range(id2name.shape[0]):

            # load each variable
            image = os.path.join(image_path, id2name[id_, 1])

            # Label starts with 0
            label = int(id2name[id_, 1][:3]) - 1


            # training split
            if id2train[id_, 1] == 1:
                
                train_data.append(image)
                #train_caption.append(caption)
                train_labels.append(label)
                #train_parts.append(parts)
                #train_boxes.append(boxes)
            # testing split
            else:
                test_data.append(image)
                #test_caption.append(caption)
                test_labels.append(label)
                #test_parts.append(parts)
                #test_boxes.append(boxes)

        # return accoring to different splits
        if train == True:
            return train_data, train_labels
        else:
            return test_data, test_labels


if __name__ == "__main__":
    # Smoke test: build the concept dataset and print the shapes of its first sample.
    import torchvision.transforms as transforms
    train_transforms = transforms.Compose([
            transforms.Resize(size=224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(0.1),
            transforms.RandomCrop(size=224),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
            ])
    test_transforms = transforms.Compose([
            transforms.Resize(size=224),
            transforms.CenterCrop(size=224),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
            ])
    # CUB200 accepts no `arch` argument and yields only (image, target); the class
    # that takes `arch` and yields the encoded prototypes is CUB200_Concept, which
    # in turn has no `resize` parameter.
    cub200_datasets = CUB200_Concept(root="/data/fengyi/wangjiaqi/datasets/CUB_VL", arch="vgg16", train=True, transform=train_transforms, cropped=True)

    image, target, encoded_img = cub200_datasets[0]
    print(image.size(), target, encoded_img.size())
    # output noted by the original author: [3,224,224] 0 [10,512]