import torch
import torch.nn as nn
import os
import copy
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, Dataset
from transformers import RobertaTokenizer
from get_id_list import get_id_list_separate
from transformers import DataCollatorForLanguageModeling
from PIL import Image
from random import shuffle
from utils import *
import yaml


# Hyper-parameters (e.g. "mlm_prob", read by CCImageTextDataset) are loaded
# once at import time from a YAML file resolved relative to the current
# working directory. The encoding is pinned so decoding does not depend on
# the platform's default locale.
with open('./bt_got_mlm_itm_config.yml', encoding='utf-8') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

class CCImageTextDataset(Dataset):
    """Image-text dataset yielding tokenized captions and MLM-masked copies.

    Each item loads one image from ``image_list`` and its caption from
    ``text_list`` (same index), applies the image/text transforms, tokenizes
    the transformed text with a RoBERTa tokenizer, and passes the token ids
    through ``DataCollatorForLanguageModeling`` to obtain masked ids and MLM
    labels.

    Args:
        image_list: sequence of image paths, anything ``PIL.Image.open`` accepts.
        text_list: sequence of captions, indexed in parallel with ``image_list``.
        maxlen: tokenizer ``max_length``; sequences are padded/truncated to it.
        model_name: name or path given to ``RobertaTokenizer.from_pretrained``.
        image_transform: callable applied to the RGB image. ``None`` (default)
            builds ``CCImagePairTransform(train_transform=True)``.
        text_transform: callable applied to the raw caption. ``None`` (default)
            builds ``CCTextPairTransform(train_transform=True)``.
    """

    def __init__(self, image_list, text_list, maxlen=100, model_name='roberta-base',
                 image_transform=None, text_transform=None):
        # Build default transforms per instance rather than in the signature,
        # where they would be created once at class-definition time and shared.
        if image_transform is None:
            image_transform = CCImagePairTransform(train_transform=True)
        if text_transform is None:
            text_transform = CCTextPairTransform(train_transform=True)

        self.maxlen = maxlen
        self.tokenizer = RobertaTokenizer.from_pretrained(model_name)
        self.image_list = image_list
        self.text_list = text_list
        self.image_transform = image_transform
        self.text_transform = text_transform
        # Masking probability comes from the module-level YAML config.
        self.data_collator = DataCollatorForLanguageModeling(
            self.tokenizer, mlm=True, mlm_probability=config["mlm_prob"])

    def __len__(self):
        """Return the number of samples (the length of ``image_list``)."""
        return len(self.image_list)

    def _mask_tokens(self, token_ids):
        """Run the MLM collator over every row of 2-D ``token_ids``.

        The collator expects a list of 1-D sequences (one per example), so the
        tensor is split into its rows. Returns ``(masked_ids, labels)``, each
        with the same number of rows as ``token_ids``.
        """
        mlm_batch = self.data_collator(list(token_ids))
        return mlm_batch['input_ids'], mlm_batch['labels']

    def __getitem__(self, idx):
        """Return image, token ids, attention masks, MLM ids and MLM labels.

        When the text transform yields two views, each text-side field is a
        pair ``(view0, view1)`` of 1-D tensors; otherwise the tokenizer's
        tensors are returned unsplit.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Context manager closes the file once convert() has produced an
        # in-memory copy of the pixels.
        with Image.open(self.image_list[idx]) as img:
            image = img.convert('RGB')
        image_sample = self.image_transform(image)

        aug_text_data = self.text_transform(self.text_list[idx])

        # NOTE(review): list() assumes text_transform returns a sequence of
        # strings (one or two views); a bare str would be split into chars.
        encoded = self.tokenizer(list(aug_text_data),
                                 padding='max_length',  # pad to max_length
                                 truncation=True,  # truncate to max_length
                                 max_length=self.maxlen,
                                 return_tensors='pt')  # torch.Tensor outputs

        token_ids = encoded['input_ids']  # one row per text, maxlen columns
        attn_masks = encoded['attention_mask']  # 1 for real tokens, 0 for padding

        text_mlm_ids, text_mlm_labels = self._mask_tokens(token_ids)

        if len(aug_text_data) == 2:
            return (image_sample,
                    (token_ids[0], token_ids[1]),
                    (attn_masks[0], attn_masks[1]),
                    (text_mlm_ids[0], text_mlm_ids[1]),
                    (text_mlm_labels[0], text_mlm_labels[1]))
        return image_sample, token_ids, attn_masks, text_mlm_ids, text_mlm_labels
