from .base_dataset import BaseDataset
import torch
import random
import os


class MMIMDBDataset(BaseDataset):
    """MM-IMDb split that can mark samples as missing a modality.

    Every sample gets a code in ``self.missing_table``:

    * ``0`` - both modalities are present,
    * ``1`` - the text is treated as missing (replaced by ``''``),
    * ``2`` - the image is treated as missing (replaced by zeros).

    The table is loaded from ``missing_info['missing_table_root']`` when a
    file for this split/type/ratio already exists; otherwise it is sampled
    with the global ``random`` module and, if the ratio is positive, saved
    there so later runs reuse the same assignment.
    """

    # split name -> name of the table handed to BaseDataset
    _SPLIT_TABLES = {
        "train": "mmimdb_train",
        "val": "mmimdb_dev",
        "test": "mmimdb_test",
    }

    def __init__(self, *args, split="", missing_info=None, enable_mt=False, text_model='bert-base-uncased', **kwargs):
        """Load the requested split and prepare its missing-modality table.

        Args:
            *args: Forwarded unchanged to ``BaseDataset.__init__``.
            split: One of ``"train"``, ``"val"`` or ``"test"``.
            missing_info: Mapping with the keys ``'simulate_missing'``,
                ``'ratio'`` and ``'type'`` (both indexed by split),
                ``'both_ratio'`` and ``'missing_table_root'``.
            enable_mt: Stored and passed to ``get_text`` for every sample.
            text_model: Forwarded to ``BaseDataset.__init__``.
            **kwargs: Forwarded unchanged to ``BaseDataset.__init__``.

        Raises:
            KeyError: If ``missing_info`` lacks one of the required keys.
            RuntimeError: If a cached missing table exists but its length
                differs from the number of samples in this split.
        """
        assert split in self._SPLIT_TABLES, f"unknown split: {split!r}"
        self.split = split
        self.enable_mt = enable_mt

        names = [self._SPLIT_TABLES[split]]

        super().__init__(
            *args,
            **kwargs,
            names=names,
            text_column_name="plots",
            remove_duplicate=False,
            text_model=text_model,
        )

        # A fresh dict per call; never share a mutable default between instances.
        if missing_info is None:
            missing_info = {}

        # missing modality control
        self.simulate_missing = missing_info['simulate_missing']
        missing_ratio = missing_info['ratio'][split]
        mratio = str(missing_ratio).replace('.', '')
        missing_type = missing_info['type'][split]
        both_ratio = missing_info['both_ratio']
        missing_table_root = missing_info['missing_table_root']
        missing_table_name = f'{names[0]}_missing_{missing_type}_{mratio}.pt'
        missing_table_path = os.path.join(missing_table_root, missing_table_name)

        # use image data to formulate missing table
        total_num = len(self.table['image'])

        if os.path.exists(missing_table_path):
            missing_table = torch.load(missing_table_path)
            if len(missing_table) != total_num:
                # Fail loudly with a traceback instead of print() + exit().
                raise RuntimeError(
                    f"missing table mismatched! '{missing_table_path}' has "
                    f"{len(missing_table)} entries but the split has {total_num} samples"
                )
        else:
            missing_table = torch.zeros(total_num)

            if missing_ratio > 0:
                self._assign_missing(missing_table, missing_ratio, missing_type, both_ratio)
                # torch.save does not create parent directories by itself.
                if missing_table_root:
                    os.makedirs(missing_table_root, exist_ok=True)
                torch.save(missing_table, missing_table_path)

        self.missing_table = missing_table

    @staticmethod
    def _assign_missing(missing_table, missing_ratio, missing_type, both_ratio):
        """Mark a random ``missing_ratio`` share of ``missing_table`` in place.

        ``'text'`` writes code 1 and ``'image'`` writes code 2 to the sampled
        entries. ``'both'`` writes code 1 to all of them and then overwrites a
        ``both_ratio`` share of those with code 2. Any other ``missing_type``
        leaves the table untouched.
        """
        total_num = len(missing_table)
        missing_index = random.sample(range(total_num), int(total_num * missing_ratio))

        if missing_type == 'text':
            missing_table[missing_index] = 1
        elif missing_type == 'image':
            missing_table[missing_index] = 2
        elif missing_type == 'both':
            missing_table[missing_index] = 1
            missing_index_image = random.sample(missing_index, int(len(missing_index) * both_ratio))
            missing_table[missing_index_image] = 2

    def __getitem__(self, index):
        """Return the sample at ``index`` with missing modalities blanked out.

        During training, when ``simulate_missing`` is enabled and the sample
        is complete in the missing table, one of the codes 0/1/2 is drawn
        uniformly at random and applied on top for this access only.

        Returns:
            dict with the keys ``"image"`` (the loaded image, or zeros of the
            same size when the image is missing), ``"text"`` (``''`` when the
            text is missing), ``"label"``, ``"missing_type"`` (table code plus
            the simulated code), ``"filename"`` and ``"index"`` (the row index
            taken from ``index_mapper``).
        """
        image_index, question_index = self.index_mapper[index]

        # Read the table entry once instead of indexing the tensor repeatedly.
        table_code = self.missing_table[image_index].item()

        simulate_missing_type = 0
        if self.split == 'train' and self.simulate_missing and table_code == 0:
            simulate_missing_type = random.choice([0, 1, 2])

        # The image is always loaded because its size defines the dummy image.
        image_tensor = self.get_image(index)["image"]

        # missing image, dummy image is 0
        if table_code == 2 or simulate_missing_type == 2:
            image_tensor = torch.zeros(image_tensor.size()).float()

        # missing text, dummy text is ''
        if table_code == 1 or simulate_missing_type == 1:
            text = ''
        else:
            text = self.get_text(index, enable_mt=self.enable_mt)["text"]

        labels = self.table["label"][image_index].as_py()

        return {
            "image": image_tensor,
            "text": text,
            "label": labels,
            # Simulation only happens when table_code is 0, so the sum is
            # whichever of the two codes is non-zero.
            "missing_type": table_code + simulate_missing_type,
            "filename": self.table["image_id"][image_index].as_py(),
            "index": image_index,
        }
