import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import json
import numpy as np
import os
from PIL import Image
from transformers import BertTokenizer, BertModel, BertForMaskedLM
import os
# other libs
import numpy as np
import PIL.Image
import torch
import torch.utils.data

def pil_loader(path):
    """Open the image stored at *path* and return it as an RGB ``PIL.Image``.

    The RGB conversion is done while the file handle is still open, and the
    converted image is what gets returned.
    """
    with open(path, 'rb') as handle:
        return PIL.Image.open(handle).convert('RGB')

class CUB200_CAPTION(torch.utils.data.Dataset):
    """
    CUB-200-2011 dataset that pairs every image with one of its captions.

    Each sample is the (optionally transformed) image, its class label, the
    15 landmark annotations and the bounding box (both rescaled by
    ``resize / shortest image edge``), plus a BERT-tokenized caption picked
    at random from the image's caption file.

    Variables
    ----------
        _root, str: Root directory of the dataset.
        _train, bool: Load train/test data. Only the selected split's
            ``_train_*`` or ``_test_*`` lists are populated.
        _transform, callable: A function/transform that takes in a PIL.Image
            and transforms it.
        _train_data, list of str: Image file paths.
        _train_caption, list of str: Caption file paths, one per image.
        _train_labels, list of int: Class labels starting at 0.
        _train_parts, list of np.array: [15, 4] landmark rows
            <landmark_id, column, row, existence>.
        _train_boxes, list of np.array: [5] box rows
            <image_id, col of top-left corner, row of top-left corner, width, height>.
        _test_data, _test_caption, _test_labels, _test_parts, _test_boxes:
            Same as above, for the test split.
        tokenizer: BertTokenizer ('bert-base-uncased') used for captions.
    """
    def __init__(self, root, train=True, transform=None, resize=448):
        """
        Load the dataset index (file paths and annotations).

        Args
        ----------
        root: str
            Root directory of the dataset.
        train: bool
            train/test data split.
        transform: callable
            A function/transform that takes in a PIL.Image and transforms it.
        resize: int
            Length of the shortest edge of the resized image. Used for transforming landmarks and bounding boxes.

        """
        self._root = root
        self._train = train
        self._transform = transform
        self.loader = pil_loader
        self.newsize = resize
        # 15 key points provided by CUB
        self.num_kps = 15

        # NOTE: kept for backward compatibility; an empty root still fails as
        # soon as the annotation files are read below.
        if not os.path.isdir(root):
            os.mkdir(root)

        # Only paths and annotations are cached; images and captions are read
        # lazily in __getitem__.
        if self._train:
            (self._train_data, self._train_caption, self._train_labels,
             self._train_parts, self._train_boxes) = self._get_file_list(train=True)
            assert (len(self._train_data) == 5994
                    and len(self._train_labels) == 5994)
        else:
            (self._test_data, self._test_caption, self._test_labels,
             self._test_parts, self._test_boxes) = self._get_file_list(train=False)
            assert (len(self._test_data) == 5794
                    and len(self._test_labels) == 5794)

        # tokenizer for Bert model
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def __getitem__(self, index):
        """
        Retrieve data samples.

        Args
        ----------
        index: int
            Index of the sample.

        Returns
        ----------
        image: transformed image (torch.FloatTensor [3, H, W] with the usual
            ToTensor pipeline; the PIL.Image itself when transform is None).
        target: int
            Label of the given index.
        parts: torch.FloatTensor, [15, 4]
            Landmark annotations; visible landmarks rescaled to the new size.
        boxes: torch.FloatTensor, [5, ]
            Bounding box annotations; every field except image_id rescaled.
        caption, token_type, input_mask, nwords:
            Outputs of ``_LoadCaption`` for this image's caption file.
        """
        # load the variables according to the current index and split
        if self._train:
            image_path = self._train_data[index]
            caption_path = self._train_caption[index]
            target = self._train_labels[index]
            parts = self._train_parts[index]
            boxes = self._train_boxes[index]
        else:
            image_path = self._test_data[index]
            caption_path = self._test_caption[index]
            target = self._test_labels[index]
            parts = self._test_parts[index]
            boxes = self._test_boxes[index]

        image = self.loader(image_path)
        caption, token_type, input_mask, nwords = self._LoadCaption(caption_path)

        # .float() converts float64 -> float32, which makes a copy, so the
        # in-place rescaling below never touches the cached annotations.
        parts = torch.from_numpy(parts).float()
        boxes = torch.from_numpy(boxes).float()

        # the resize factor is driven by the shortest image edge
        width, height = image.size
        factor = self.newsize / min(width, height)

        # landmark rows are <landmark_id, column, row, existence>; only rescale
        # the coordinates of landmarks that exist
        visible = parts[:, -1].abs() > 1e-5
        parts[visible, -3:-1] *= factor

        # box rows are <image_id, col, row, width, height>; keep image_id
        boxes[1:] *= factor

        if self._transform is not None:
            image = self._transform(image)

        return image, target, parts, boxes, caption, token_type, input_mask, nwords

    def _LoadCaption(self, caption, sequence_max=32):
        """
        Pick one caption line at random and encode it as fixed-length BERT input.

        Args
        ----------
        caption: str
            Path to a text file holding one caption per line.
        sequence_max: int
            Output length; longer captions are truncated (keeping the final
            [SEP] token), shorter ones are zero-padded.

        Returns
        ----------
        tokens_tensor: torch.LongTensor, [sequence_max]
            Token ids of ``[CLS] <caption> [SEP]``, zero-padded.
        segments_tensors: torch.LongTensor, [sequence_max]
            All 1s.
        input_mask: torch.LongTensor, [sequence_max]
            1 for valid tokens; 0 for padding tokens.
        nw: int
            Number of valid tokens (at most ``sequence_max``).

        Raises
        ----------
        ValueError
            If the caption file contains no lines.
        """
        with open(caption, encoding='utf-8') as f:
            lines = f.read().splitlines()
        if not lines:
            raise ValueError('caption file {} is empty'.format(caption))
        # torch's RNG is re-seeded per DataLoader worker while NumPy's global
        # RNG is not, so np.random would make all workers pick the same lines.
        choice = int(torch.randint(len(lines), (1,)).item())
        text = '[CLS] ' + lines[choice] + ' [SEP]'

        tokenized_text = self.tokenizer.tokenize(text)
        if len(tokenized_text) > sequence_max:
            # truncate but keep the trailing [SEP] so the input stays well formed
            tokenized_text = tokenized_text[:sequence_max - 1] + tokenized_text[-1:]
        nw = len(tokenized_text)
        pad = sequence_max - nw

        indexed_tokens = list(self.tokenizer.convert_tokens_to_ids(tokenized_text))
        tokens_tensor = torch.tensor(indexed_tokens + [0] * pad, dtype=torch.long)
        # The mask has 1 for real tokens and 0 for padding tokens. Only real tokens are attended to.
        input_mask = torch.tensor([1] * nw + [0] * pad, dtype=torch.long)
        segments_tensors = torch.ones(sequence_max, dtype=torch.long)

        return tokens_tensor, segments_tensors, input_mask, nw

    def __len__(self):
        """Return the number of samples in the loaded split."""
        return len(self._train_data if self._train else self._test_data)

    def _get_file_list(self, train=True):
        """
        Read the CUB metadata files and collect one split.

        Nothing is written to disk.

        Returns
        ----------
        (image paths, caption paths, labels, parts, boxes) for the train split
        when ``train`` is true, otherwise for the test split.
        """
        meta_root = os.path.join(self._root, 'CUB_200_2011')
        image_root = os.path.join(meta_root, 'images')
        caption_root = os.path.join(meta_root, 'text_c10')
        id2name = np.genfromtxt(os.path.join(meta_root, 'images.txt'), dtype=str)
        id2train = np.genfromtxt(os.path.join(meta_root, 'train_test_split.txt'), dtype=int)
        id2part = np.genfromtxt(os.path.join(meta_root, 'parts', 'part_locs.txt'), dtype=float)
        id2box = np.genfromtxt(os.path.join(meta_root, 'bounding_boxes.txt'), dtype=float)

        k = self.num_kps
        want_train = bool(train)
        data, captions, labels, parts, boxes = [], [], [], [], []

        for id_, rel_path in enumerate(id2name[:, 1]):
            # split flag 1 marks a training image; anything else is test
            if (id2train[id_, 1] == 1) != want_train:
                continue
            # rel_path looks like "<class_dir>/<image_file>"
            pieces = rel_path.split('/')
            class_name, image_file = pieces[0], pieces[1]

            data.append(os.path.join(image_root, rel_path))
            captions.append(os.path.join(
                caption_root, class_name, os.path.splitext(image_file)[0] + '.txt'))
            # class directory starts with a 1-based 3-digit id; labels start at 0
            labels.append(int(rel_path[:3]) - 1)
            # part_locs.txt holds num_kps consecutive rows per image; drop image_id
            parts.append(id2part[id_ * k:(id_ + 1) * k, 1:])
            boxes.append(id2box[id_])

        return data, captions, labels, parts, boxes

class CUB200_CAPTION_AUG(torch.utils.data.Dataset):
    """
    CUB-200-2011 dataset with offline-augmented training images and captions.

    Training samples come from a directory of pre-generated augmented copies
    (every file whose name contains an original training image's file name).
    Test samples are the original CUB test images. Each sample is the
    (optionally transformed) image, its class label and a BERT-tokenized
    caption picked at random from the original image's caption file. No
    landmark or bounding-box annotations are returned.

    Variables
    ----------
        _root, str: Root directory of the dataset.
        _train, bool: Load train/test data. Only the selected split's lists
            are populated.
        _transform, callable: A function/transform that takes in a PIL.Image
            and transforms it.
        train_aug_path, str: Directory with one sub-directory of augmented
            images per class.
        _train_data, list of str: Augmented image file paths.
        _train_caption, list of str: Caption file path of each augmented
            image's original image.
        _train_labels, list of int: Class labels starting at 0.
        _test_data, _test_caption, _test_labels: Same as above for the
            test split (original images).
        tokenizer: BertTokenizer ('bert-base-uncased') used for captions.
    """
    def __init__(self, root, train=True, transform=None, cropped=True, resize=448, aug_root=None):
        """
        Load the dataset index (file paths and labels).

        Args
        ----------
        root: str
            Root directory of the dataset.
        train: bool
            train/test data split.
        transform: callable
            A function/transform that takes in a PIL.Image and transforms it.
        cropped: bool
            Use 'cub200_cropped/train_cropped_augmented_2/' under ``aug_root``
            when true, otherwise 'cub200_nocropped/train_augmented_2/'.
        resize: int
            Stored as ``self.newsize``; not used by this class.
        aug_root: str or None
            Directory containing the augmented-image folders. Defaults to
            '/mnt/workspace/datasets/CUB_200_2011/CUB_200_2011/new_datasets/'.

        """
        self._root = root
        self._train = train
        self._transform = transform
        self.loader = pil_loader
        self.newsize = resize
        # 15 key points provided by CUB (kept for parity with CUB200_CAPTION)
        self.num_kps = 15

        if aug_root is None:
            aug_root = '/mnt/workspace/datasets/CUB_200_2011/CUB_200_2011/new_datasets/'
        if cropped:
            aug_subdir = 'cub200_cropped/train_cropped_augmented_2/'
        else:
            aug_subdir = 'cub200_nocropped/train_augmented_2/'
        self.train_aug_path = os.path.join(aug_root, aug_subdir)

        # NOTE: kept for backward compatibility; an empty root still fails as
        # soon as the annotation files are read below.
        if not os.path.isdir(root):
            os.mkdir(root)

        # Only paths and labels are cached; images and captions are read
        # lazily in __getitem__.
        if self._train:
            self._train_data, self._train_caption, self._train_labels = self._get_file_list(train=True)
            assert (len(self._train_data) == 239760
                    and len(self._train_labels) == 239760)
        else:
            self._test_data, self._test_caption, self._test_labels = self._get_file_list(train=False)
            assert (len(self._test_data) == 5794
                    and len(self._test_labels) == 5794)

        # tokenizer for Bert model
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def __getitem__(self, index):
        """
        Retrieve data samples.

        Args
        ----------
        index: int
            Index of the sample.

        Returns
        ----------
        image: transformed image (torch.FloatTensor [3, H, W] with the usual
            ToTensor pipeline; the PIL.Image itself when transform is None).
        target: int
            Label of the given index.
        caption, token_type, input_mask, nwords:
            Outputs of ``_LoadCaption`` for this sample's caption file.
        """
        if self._train:
            image_path = self._train_data[index]
            caption_path = self._train_caption[index]
            target = self._train_labels[index]
        else:
            image_path = self._test_data[index]
            caption_path = self._test_caption[index]
            target = self._test_labels[index]

        image = self.loader(image_path)
        caption, token_type, input_mask, nwords = self._LoadCaption(caption_path)

        if self._transform is not None:
            image = self._transform(image)

        return image, target, caption, token_type, input_mask, nwords

    def _LoadCaption(self, caption, sequence_max=32):
        """
        Pick one caption line at random and encode it as fixed-length BERT input.

        Args
        ----------
        caption: str
            Path to a text file holding one caption per line.
        sequence_max: int
            Output length; longer captions are truncated (keeping the final
            [SEP] token), shorter ones are zero-padded.

        Returns
        ----------
        tokens_tensor: torch.LongTensor, [sequence_max]
            Token ids of ``[CLS] <caption> [SEP]``, zero-padded.
        segments_tensors: torch.LongTensor, [sequence_max]
            All 1s.
        input_mask: torch.LongTensor, [sequence_max]
            1 for valid tokens; 0 for padding tokens.
        nw: int
            Number of valid tokens (at most ``sequence_max``).

        Raises
        ----------
        ValueError
            If the caption file contains no lines.
        """
        with open(caption, encoding='utf-8') as f:
            lines = f.read().splitlines()
        if not lines:
            raise ValueError('caption file {} is empty'.format(caption))
        # torch's RNG is re-seeded per DataLoader worker while NumPy's global
        # RNG is not, so np.random would make all workers pick the same lines.
        choice = int(torch.randint(len(lines), (1,)).item())
        text = '[CLS] ' + lines[choice] + ' [SEP]'

        tokenized_text = self.tokenizer.tokenize(text)
        if len(tokenized_text) > sequence_max:
            # truncate but keep the trailing [SEP] so the input stays well formed
            tokenized_text = tokenized_text[:sequence_max - 1] + tokenized_text[-1:]
        nw = len(tokenized_text)
        pad = sequence_max - nw

        indexed_tokens = list(self.tokenizer.convert_tokens_to_ids(tokenized_text))
        tokens_tensor = torch.tensor(indexed_tokens + [0] * pad, dtype=torch.long)
        # The mask has 1 for real tokens and 0 for padding tokens. Only real tokens are attended to.
        input_mask = torch.tensor([1] * nw + [0] * pad, dtype=torch.long)
        segments_tensors = torch.ones(sequence_max, dtype=torch.long)

        return tokens_tensor, segments_tensors, input_mask, nw

    def __len__(self):
        """Return the number of samples in the loaded split."""
        return len(self._train_data if self._train else self._test_data)

    def _get_file_list(self, train=True):
        """
        Read the CUB metadata files and collect one split.

        For the train split, each original training image is replaced by every
        file in ``train_aug_path/<class_dir>`` whose name contains the original
        image's file name (in sorted order, so indexing is deterministic). For
        the test split, the original images are used. Nothing is written to disk.

        Returns
        ----------
        (image paths, caption paths, labels) for the requested split.
        """
        meta_root = os.path.join(self._root, 'CUB_200_2011')
        image_root = os.path.join(meta_root, 'images')
        caption_root = os.path.join(meta_root, 'text_c10')
        id2name = np.genfromtxt(os.path.join(meta_root, 'images.txt'), dtype=str)
        id2train = np.genfromtxt(os.path.join(meta_root, 'train_test_split.txt'), dtype=int)

        want_train = bool(train)
        data, captions, labels = [], [], []
        # directory listing per augmented class folder, read once per class
        # instead of once per training image
        aug_listing = {}

        for id_, rel_path in enumerate(id2name[:, 1]):
            # split flag 1 marks a training image; anything else is test
            if (id2train[id_, 1] == 1) != want_train:
                continue
            # rel_path looks like "<class_dir>/<image_file>"
            pieces = rel_path.split('/')
            class_name, image_file = pieces[0], pieces[1]
            caption = os.path.join(
                caption_root, class_name, os.path.splitext(image_file)[0] + '.txt')
            # class directory starts with a 1-based 3-digit id; labels start at 0
            label = int(rel_path[:3]) - 1

            if want_train:
                class_dir = os.path.join(self.train_aug_path, class_name)
                if class_dir not in aug_listing:
                    aug_listing[class_dir] = sorted(os.listdir(class_dir))
                for aug_name in aug_listing[class_dir]:
                    if image_file in aug_name:
                        data.append(os.path.join(class_dir, aug_name))
                        captions.append(caption)
                        labels.append(label)
            else:
                data.append(os.path.join(image_root, rel_path))
                captions.append(caption)
                labels.append(label)

        return data, captions, labels

if __name__ == "__main__":
    # Smoke test: build the augmented training split with a typical
    # 224px training pipeline and report its size.
    import torchvision.transforms as transforms
    train_transforms = transforms.Compose([
            transforms.Resize(size=224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(0.1),
            transforms.RandomCrop(size=224),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
            ])
    # evaluation-time pipeline, kept for building the test split
    test_transforms = transforms.Compose([
            transforms.Resize(size=224),
            transforms.CenterCrop(size=224),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
            ])
    cub200_datasets = CUB200_CAPTION_AUG(root="/mnt/workspace/datasets/CUB_200_2011", train=True, transform=train_transforms, cropped=True, resize=224)
    print(len(cub200_datasets))
    