import pandas as pd
import pickle
import os
import torch
import numpy as np
from PIL import Image
from pathlib import Path
from sklearn.model_selection import train_test_split
from torchvision import transforms
from torch.utils.data import Dataset
import torchvision.datasets as dset
from torch.utils.data import Dataset
from transformers import AutoTokenizer

# Names of the supported input modalities. MSCOCODataset.__getitem__ tests
# `opt.modality` for these same two strings when `opt.multimodal` is false.
MODALITY = ["vision", "text"]

class MSCOCODataset(Dataset):
    """MS-COCO dataset that serves images, tokenized captions and labels.

    Each sample is a dict whose keys depend on the parser flags in ``opt``:
    ``opt.multimodal`` returns both the image and the BERT token tensors,
    otherwise ``opt.modality`` selects the image ("vision") or the token
    tensors ("text"). ``opt.exclusive`` switches between a single class index
    per image and a row of ``one_hot_labels``.
    """

    def __init__(self,
                 root,
                 annFile_cap,
                 annFile_det,
                 image_dict_template,
                 label_columns,
                 opt = None,
                 transform=None,
                 image_size = 224,
                 img_ids_list = None,
                 use_img_ids_list = False,
                 one_hot_labels = None,
                 text_format = 'raw',
                 debug = False,
                ):
        '''
        Args:
            root: directory that contains the image files.
            annFile_cap: path of the COCO caption annotation file.
            annFile_det: path of the COCO detection (instances) annotation file.
            image_dict_template: mapping category -> iterable of entries whose
                first item is an image path and second item an image id.
            label_columns: the categories to keep; their position defines the
                class index (0 .. len(label_columns) - 1).
            opt: parser flags. Required in practice: ``exclusive``,
                ``multimodal`` and ``modality`` are read from it.
            transform: optional callable applied to the resized PIL image.
            image_size: images are resized to (image_size, image_size).
            img_ids_list: image ids to serve when ``use_img_ids_list`` is True.
            use_img_ids_list: if False, every image id of the caption file is used.
            one_hot_labels: array indexed as ``[index, :]`` in multilabel mode.
            text_format: how to encode the notes
                "raw": build the text from the caption strings
                "bert": reserved for precomputed BERT encodings (not implemented)
            debug: if True, ``__getitem__`` returns (img_id, label, path) only.

        Raises:
            ValueError: if ``opt`` is None, or ``use_img_ids_list`` is True
                without an ``img_ids_list``.
        '''
        super(MSCOCODataset, self).__init__()

        # Fail fast, before the (large) annotation files are parsed.
        if opt is None:
            raise ValueError("opt must be provided: its `exclusive`, `multimodal` "
                             "and `modality` flags are read by the dataset")
        if use_img_ids_list and img_ids_list is None:
            raise ValueError("img_ids_list must be provided when use_img_ids_list is True")

        # Imported here rather than at module level, so pycocotools is only
        # needed once a dataset is actually constructed.
        from pycocotools.coco import COCO

        self.root = os.path.expanduser(root)
        # Text captions
        self.coco_cap = COCO(annFile_cap)
        self.labels = one_hot_labels
        # Object detection
        self.coco_det = COCO(annFile_det)
        self.label_columns = label_columns
        print("MSCOCO labels are {}".format(self.label_columns))

        # MSCOCO numerical labels map to order from 0 to len(label_columns) - 1
        self.label_to_torch = {
            category: position
            for position, category in enumerate(self.label_columns)
        }

        # List of image ids in dataset
        if use_img_ids_list:
            self.ids = img_ids_list
        else:
            self.ids = list(self.coco_cap.imgs.keys())

        # Transformations to perform for dataset
        self.transform = transform
        # Size of image
        self.image_size = image_size

        # Store the parser flags
        self.opt = opt

        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        self.text_format = text_format
        self.one_hot_labels = one_hot_labels

        self.debug = debug

        # Position of every image id inside this dataset
        img_id_to_index = {img_id: position for position, img_id in enumerate(self.ids)}

        image_dict = {}
        for category, entries in image_dict_template.items():
            # Only if category is selected
            if category not in self.label_columns:
                continue
            # Keep the images that belong to this dataset, replacing the
            # image id by the image's index in the dataset
            image_dict[self.label_to_torch[category]] = [
                [entry[0], img_id_to_index[entry[1]]]
                for entry in entries
                if entry[1] in img_id_to_index
            ]

        self.image_dict = image_dict
        # run to compute the image_list and rest from image_dict
        self.init_setup()

    # Get len of dataset
    def __len__(self):
        return len(self.ids)

    def init_setup(self):
        """Derive ``n_files``, ``avail_classes`` and ``image_list`` from ``image_dict``.

        ``image_list`` holds one ``[path, label]`` entry per distinct path,
        ordered by dataset index. ``label`` is a single class index when
        ``opt.exclusive`` is set, otherwise a NumPy array of class indices.

        Raises:
            ValueError: if no selected class contains any image.
        """
        self.n_files = np.sum([len(entries) for entries in self.image_dict.values()])
        self.avail_classes = sorted(self.image_dict)

        # Explicit column names keep the layout independent of which class comes
        # first; classes without images are skipped so they cannot disturb it.
        frames = [
            pd.DataFrame(entries, columns=['path', 'idx']).assign(label=class_idx)
            for class_idx, entries in self.image_dict.items()
            if len(entries) > 0
        ]
        if not frames:
            raise ValueError("No images found for the selected labels")

        df = pd.concat(frames).sort_values(by='idx', ascending=True)
        # First occurrence of every path, in dataset-index order
        ordered_paths = df.drop_duplicates(subset=['path']).set_index('path')['idx'].index
        labels_per_path = df.groupby('path').agg({'label': lambda x: list(x)})
        agg_list = labels_per_path.loc[ordered_paths].reset_index().to_numpy().tolist()

        if self.opt.exclusive:
            for entry in agg_list:
                assert len(entry[-1]) == 1 # one label per image
                entry[-1] = entry[-1][0]
        else:
            for entry in agg_list:
                entry[-1] = np.array(entry[-1])

        self.image_list = agg_list
        self.image_paths = self.image_list
        self.is_init = True

    @staticmethod
    def _join_captions(anns):
        """Merge the captions of one image into a single lower-cased string."""
        # Space-separated, so the last word of a caption is not fused with the
        # first word of the next one before tokenization.
        return " ".join(ann['caption'] for ann in anns).lower()

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            dict with 'labels' and 'idx', plus 'x' (image) and/or
            'input_ids', 'token_type_ids', 'attention_mask' depending on
            ``opt.multimodal`` / ``opt.modality``. In debug mode the tuple
            (img_id, label, path) is returned instead.

        Raises:
            NotImplementedError: if ``text_format`` is not 'raw'.
            ValueError: if ``opt.multimodal`` is false and ``opt.modality``
                contains neither "vision" nor "text".
        """
        img_id = self.ids[index]

        if self.text_format == 'raw':
            # All caption annotations of this image
            ann_ids_cap = self.coco_cap.getAnnIds(imgIds=img_id)
            caption = self._join_captions(self.coco_cap.loadAnns(ann_ids_cap))
        elif self.text_format == 'bert':
            raise NotImplementedError()
        else:
            raise NotImplementedError("Text format for caption is not supported")

        text_enc = self.tokenizer.encode_plus(caption, return_tensors="pt", max_length=512,
                                              padding='max_length', truncation=True)
        # Drop the batch dimension added by return_tensors="pt"
        text_fields = {
            'input_ids': text_enc['input_ids'][0, :],
            'token_type_ids': text_enc['token_type_ids'][0, :],
            'attention_mask': text_enc['attention_mask'][0, :],
        }

        if self.opt.exclusive:
            ann_ids_det = self.coco_det.getAnnIds(imgIds=img_id)
            anns_det = self.coco_det.loadAnns(ann_ids_det)
            # Multiclass -> single integer (category of the first annotation)
            label = self.label_to_torch[anns_det[0]['category_id']]
        else:
            # Multilabel -> one hot
            label = self.one_hot_labels[index, :]

        # Image path
        path = self.coco_cap.loadImgs(img_id)[0]['file_name']

        # For debugging purposes only
        if self.debug:
            return img_id, label, path

        # Always convert to RGB (PIL's `size` is (width, height) for every
        # image, so the conversion was never conditional on the channel count).
        img = Image.open(os.path.join(self.root, path)).convert('RGB')
        img = transforms.Resize([self.image_size, self.image_size])(img)
        # Apply additional transform
        if self.transform is not None:
            img = self.transform(img)

        if self.opt.multimodal:
            return {'labels': label, 'x': img, **text_fields, 'idx': index}
        # To match with MultiLabel Dataset
        if "vision" in self.opt.modality:
            return {'labels': label, 'x': img, 'idx': index}
        if "text" in self.opt.modality:
            return {'labels': label, **text_fields, 'idx': index}
        raise ValueError("opt.modality must contain 'vision' or 'text', got {!r}".format(self.opt.modality))

    
def _load_pickle(path):
    """Unpickle and return the object stored at ``path``."""
    # NOTE: pickle can execute arbitrary code while loading; only point this
    # at preprocessed files you produced yourself.
    with open(path, 'rb') as f:
        return pickle.load(f)


def Give(opt, label_names, datapath):
    """Build the MS-COCO train/validation/test/evaluation datasets.

    Args:
        opt: parser flags. Reads ``multimodal``, ``exclusive``,
            ``tv_split_perc``, ``augmentation`` and ``not_pretrained``;
            ``opt.label_names`` is overwritten with ``label_names``.
        label_names: names of the labels to keep; images without any of
            these labels are dropped.
        datapath: directory containing the ``coco`` folder.

    Returns:
        dict with the keys 'training', 'validation', 'testing', 'evaluation'
        (training ids with the evaluation transform) and 'evaluation_train'
        (training ids with the training transform), each a MSCOCODataset.
    """
    opt.label_names = label_names
    datapath = Path(datapath)
    preprocessed = datapath/"coco/preprocessed"

    if opt.multimodal:
        print("MSCOCO multimodal setting")

    # Multiclass or multilabel version of the labels
    if opt.exclusive:
        labels_file = "train2017_label_one_hot_multiclass.pkl"
    else:
        labels_file = "train2017_label_one_hot.pkl"
    one_hot_labels = _load_pickle(preprocessed/labels_file)

    # Conversion from label name to label id
    name_to_id = _load_pickle(preprocessed/"train2017_name_to_id.pkl")

    # Only get the images with labels that we are interested in
    one_hot_labels = one_hot_labels[[name_to_id[x] for x in opt.label_names]]
    has_label = ~(one_hot_labels == 0).all(axis=1)
    one_hot_labels = one_hot_labels[has_label]
    print("Total len of dataset: {}".format(len(one_hot_labels)))

    # Load the template
    label_to_path = _load_pickle(preprocessed/"train2017_image_dict_template_v2.pkl")

    # List of all image ids in the dataset
    img_ids_list = one_hot_labels.index.astype(int)

    # NOTE: no random_state is passed, so the split depends on NumPy's global
    # RNG state at call time.
    train_idx, test_idx = train_test_split(np.arange(len(img_ids_list)), test_size=(1-opt.tv_split_perc))
    val_idx, test_idx = train_test_split(test_idx, test_size=0.5)

    train_ids, val_ids, test_ids = img_ids_list[train_idx], img_ids_list[val_idx], img_ids_list[test_idx]
    train_label = one_hot_labels.iloc[train_idx, :]
    val_label = one_hot_labels.iloc[val_idx, :]
    test_label = one_hot_labels.iloc[test_idx, :]

    # The square size to center-crop the images to
    image_size = 224

    # Deterministic steps shared by the training and evaluation pipelines, so
    # that evaluation inputs are scaled exactly like training inputs.
    # (MSCOCODataset already resizes to image_size x image_size, so the
    # center crop does not remove any pixels.)
    base_transforms = [
        transforms.CenterCrop([image_size, image_size]),
        transforms.ToTensor(),
    ]
    # normalize the values of the images
    if not opt.not_pretrained:
        base_transforms.append(transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]))

    # Augmentation is applied to the training pipeline only
    augmentations = []
    if opt.augmentation:
        augmentations = [
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
        ]

    train_transform = transforms.Compose(augmentations + base_transforms)
    eval_transform = transforms.Compose(base_transforms)

    def build_dataset(ids, labels, transform):
        # NOTE: every MSCOCODataset parses both annotation files in its
        # constructor, so they are read once per dataset built here.
        return MSCOCODataset(root = datapath/"coco/train2017",
                             annFile_cap = datapath/"coco/annotations/captions_train2017.json",
                             annFile_det = datapath/"coco/annotations/instances_train2017.json",
                             opt = opt,
                             one_hot_labels = labels.to_numpy(),
                             text_format = 'raw',
                             transform = transform,
                             image_size = image_size,
                             img_ids_list = ids.tolist(),
                             use_img_ids_list = True,
                             debug = False,
                             image_dict_template = label_to_path,
                             label_columns = labels.columns)

    return {
        'training': build_dataset(train_ids, train_label, train_transform),
        'validation': build_dataset(val_ids, val_label, eval_transform),
        'testing': build_dataset(test_ids, test_label, eval_transform),
        # Training images seen through the evaluation transform
        'evaluation': build_dataset(train_ids, train_label, eval_transform),
        'evaluation_train': build_dataset(train_ids, train_label, train_transform),
    }
