import re
import os
import numpy as np
import pandas as pd
import cv2
import tqdm
import pickle
import numpy.random as random
# import random
import torch
import torch.utils.data as data
import json
from PIL import Image
from nltk.tokenize import RegexpTokenizer
from transformers import AutoTokenizer
from CARZero.constants import *
import ast
from CARZero.models.CARZero_model_dqn_wo_self_atten_gl_mlp import CARZeroDQNWOSAGLMLP
from CARZero import CARZero
from sklearn.preprocessing import MultiLabelBinarizer

class OpenIImageDataset(data.Dataset):
    """Classification dataset over the OpenI images with shared class prompts.

    Construction reads the label CSV (``OpenI_LABEL_CSV``) and the image-path
    CSV (``OpenI_TRAIN_CSV``), derives one binary column per pathology by
    substring matching on the ``labels_automatic`` column, holds out 40% of
    the rows as a test split (persisted to CSV when missing) and keeps
    ``cfg.data.frac`` of the remaining rows for training.  The class prompts
    are tokenised once and the same tensors are returned with every sample.

    Args:
        cfg: experiment config.  Reads ``data.frac``, ``data.text.path``,
            ``data.text.word_num``, ``data.image.imsize``,
            ``model.text.bert_type`` and ``model.ckpt_path``.
        split: accepted for interface compatibility; not used by this class.
        transform: callable applied to the PIL image inside ``get_imgs``.

    Raises:
        RuntimeError: if ``OpenI_DATA_DIR`` is ``None``.
    """

    # Enumeration markers such as "1." / "23." used to split prompts into points.
    # Raw string: "\." in a normal string literal is an invalid escape sequence.
    _ENUM_SPLITTER = re.compile(r"[0-9]+\.")

    def __init__(self, cfg, split="train", transform=None):

        if OpenI_DATA_DIR is None:
            raise RuntimeError("OpenI data path empty")

        self.cfg = cfg
        self.cfg_classify = cfg
        self.transform = transform

        # Create labels: a pathology is positive when its (lower-cased) name,
        # or one of its synonyms below, occurs in "labels_automatic".
        mapping = {
            "Pleural_Thickening": ["pleural thickening"],
            "Infiltration": ["Infiltrate"],
            "Atelectasis": ["Atelectases"],
        }
        csv = pd.read_csv(OpenI_LABEL_CSV)
        csv = csv.replace(np.nan, "-1")
        gt = []
        for pathology in OpenI_pathologies:
            mask = csv["labels_automatic"].str.contains(pathology.lower())
            for syn in mapping.get(pathology, ()):
                mask |= csv["labels_automatic"].str.contains(syn.lower())
            gt.append(mask.values)
        gt = np.asarray(gt).T.astype(np.float32)
        # Rows without any positive finding get the last column set to 1; that
        # last column is then excluded from the label matrix.
        gt[np.where(np.sum(gt, axis=1) == 0), -1] = 1
        label = gt[:, :-1]

        # Rename pathologies (these become the label column names).
        pathologies = np.char.replace(OpenI_pathologies, "Opacity", "Lung Opacity")
        pathologies = np.char.replace(pathologies, "Lesion", "Lung Lesion")
        pathologies = np.char.replace(pathologies, "Pleural_Thickening", "pleural thickening")
        pathologies = np.char.replace(pathologies, "Infiltration", "Infiltrate")
        pathologies = np.char.replace(pathologies, "Atelectasis", "Atelectases")

        # Read the image-path CSV and rewrite the path prefix to the local root.
        self.df = pd.read_csv(OpenI_TRAIN_CSV)
        self.df[OpenI_PATH_COL] = self.df[OpenI_PATH_COL].apply(
            lambda x: x.replace("/defaultShare/OpenI/NLMCXR_png/", OpenI_ABS_PATH)
        )

        # Drop the first two rows from both tables (original comment: "clean
        # nan data") and rebuild both frames with a fresh 0..n-1 index.
        delete_index = {0, 1}
        kept_rows = [row for i, row in enumerate(self.df.values.tolist()) if i not in delete_index]
        kept_labels = [row for i, row in enumerate(label) if i not in delete_index]
        self.df = pd.DataFrame(kept_rows, columns=[OpenI_PATH_COL])
        label = pd.DataFrame(kept_labels, columns=pathologies.tolist()[:-1])

        # Sample data.
        # NOTE(review): inputs and labels are sampled independently with the
        # same seed; they stay row-aligned only if both frames have the same
        # length and index - confirm the two CSVs always match.
        self.test_input = self.df.sample(frac=0.4, random_state=42)
        self.test_label = label.sample(frac=0.4, random_state=42)
        self.train_input = self.df.drop(self.test_input.index)
        self.train_label = label.drop(self.test_label.index)
        self.train_input = self.train_input.sample(frac=self.cfg_classify.data.frac, random_state=42)
        self.train_label = self.train_label.sample(frac=self.cfg_classify.data.frac, random_state=42)

        frac_suffix = f'_{self.cfg_classify.data.frac}.csv'
        test_input_save_path = OpenI_TEST_INPUT_CSV.replace('.csv', frac_suffix)
        test_label_save_path = OpenI_TEST_LABEL_CSV.replace('.csv', frac_suffix)

        # Persist each held-out file that is missing.  Checking the two files
        # independently means a lone missing file is regenerated, while an
        # existing file is never overwritten.
        if not os.path.exists(test_input_save_path):
            self.test_input.to_csv(test_input_save_path, index=False, header=self.test_input.columns.tolist())
        if not os.path.exists(test_label_save_path):
            self.test_label.to_csv(test_label_save_path, index=False, header=self.test_label.columns.tolist())

        # Text data: tokenise every class prompt once and concatenate.
        with open(self.cfg_classify.data.text.path, 'r') as f:
            cls_prompts = json.load(f)
        bert_type = self.cfg_classify.model.text.bert_type
        self.tokenizer = AutoTokenizer.from_pretrained(bert_type)
        self.idxtoword = {v: k for k, v in self.tokenizer.get_vocab().items()}
        processed_txt = {
            cls_name: self.process_text(prompts, "cpu")
            for cls_name, prompts in cls_prompts.items()
        }
        # First element of each class's prompt entry, in class order.
        self.oral_texts = [prompts[0] for prompts in cls_prompts.values()]
        self.text_batch = {
            key: torch.cat([txts[key] for txts in processed_txt.values()], dim=0)
            for key in ("caption_ids", "attention_mask", "token_type_ids")
        }

    def __getitem__(self, index):
        """Return ``(image, label, text, oral_text)`` for training row ``index``.

        ``text`` and ``oral_text`` are the shared prompt tensors / prompt
        strings and are the same objects for every index.
        """
        row_input = self.train_input.iloc[index]
        row_label = self.train_label.iloc[index]
        image = self.get_imgs(row_input[OpenI_PATH_COL], self.transform)
        label = torch.tensor(row_label.tolist())
        return image, label, self.text_batch, self.oral_texts

    def __len__(self):
        """Number of training rows kept after the split and ``frac`` sampling."""
        return len(self.train_input)

    def get_imgs(self, img_path, transform=None):
        """Load ``img_path`` as grayscale, square-resize it and convert to RGB.

        ``transform`` is applied only when ``cfg.model.ckpt_path`` exists on
        disk; for the literal checkpoint name ``'medclip'`` the PIL image is
        returned untransformed.

        Raises:
            FileNotFoundError: if ``img_path`` is not an existing file.
            OSError: if the file exists but OpenCV cannot decode it.
            ValueError: if ``cfg.model.ckpt_path`` is neither an existing path
                nor ``'medclip'``.
        """
        # cv2.imread signals failure by returning None instead of raising, so
        # fail early with the offending path rather than on ``None.shape``.
        if not os.path.isfile(img_path):
            raise FileNotFoundError(f"Image file not found: {img_path}")
        x = cv2.imread(img_path, 0)
        if x is None:
            raise OSError(f"cv2 could not decode image: {img_path}")

        # Transform images.
        x = self._resize_img(x, self.cfg_classify.data.image.imsize)
        img = Image.fromarray(x).convert("RGB")

        ckpt_path = self.cfg.model.ckpt_path
        if os.path.exists(ckpt_path):
            img = transform(img)
        elif ckpt_path == 'medclip':
            pass
        else:
            raise ValueError(
                f"cfg.model.ckpt_path must be an existing path or 'medclip', got {ckpt_path!r}"
            )

        return img

    def _resize_img(self, img, scale):
        """
        Args:
            img - image as numpy array (cv2)
            scale - desired output image-size as scale x scale
        Return:
            image resized to scale x scale with shortest dimension 0-padded
        """
        size = img.shape
        # Index of the longest side: 0 -> height (also for square), 1 -> width.
        max_ind = size.index(max(size))

        # Resizing: the longest side becomes ``scale``, the other keeps ratio.
        if max_ind == 0:
            ratio = scale / float(size[0])
            desireable_size = (scale, int(float(size[1]) * float(ratio)))
        else:
            ratio = scale / float(size[1])
            desireable_size = (int(float(size[0]) * float(ratio)), scale)
        # cv2.resize takes (width, height), hence the reversed tuple.
        resized_img = cv2.resize(img, desireable_size[::-1], interpolation=cv2.INTER_AREA)

        # Padding: split the missing pixels evenly (floor before, ceil after).
        if max_ind == 0:
            pad_size = scale - resized_img.shape[1]
            pad_width = [(0, 0), (pad_size // 2, pad_size - pad_size // 2)]
        else:
            pad_size = scale - resized_img.shape[0]
            pad_width = [(pad_size // 2, pad_size - pad_size // 2), (0, 0)]

        return np.pad(resized_img, pad_width, "constant", constant_values=0)

    def process_text(self, text, device):
        """Clean and tokenise one prompt or a list of prompts.

        Each prompt is split on enumeration markers and periods, every piece
        is lower-cased and reduced to ASCII word tokens, pieces with at most
        one token are dropped, and the rejoined string is encoded with
        ``self.tokenizer`` padded/truncated to ``cfg.data.text.word_num``.

        Args:
            text: a single string or a sequence of strings.
            device: device the returned tensors are moved to.

        Returns:
            dict with ``caption_ids``, ``attention_mask``, ``token_type_ids``
            (singleton dimensions squeezed) and ``cap_lens``.
        """
        if isinstance(text, str):
            text = [text]

        word_tokenizer = RegexpTokenizer(r"\w+")
        max_length = self.cfg_classify.data.text.word_num

        processed_text_tensors = []
        for prompt in text:
            # Use space instead of newline, then split into sentences.
            flattened = prompt.replace("\n", " ")
            sentences = [
                sent
                for point in self._ENUM_SPLITTER.split(flattened)
                for sent in point.split(".")
            ]

            all_sents = []
            for sentence in sentences:
                sentence = sentence.replace("\ufffd\ufffd", " ")
                tokens = word_tokenizer.tokenize(sentence.lower())
                if len(tokens) <= 1:
                    continue
                # Strip non-ASCII characters; drop tokens that become empty.
                ascii_tokens = (tok.encode("ascii", "ignore").decode("ascii") for tok in tokens)
                all_sents.append(" ".join(tok for tok in ascii_tokens if tok))

            text_tensors = self.tokenizer(
                " ".join(all_sents),
                return_tensors="pt",
                truncation=True,
                padding="max_length",
                max_length=max_length,
            )
            text_tensors["sent"] = [
                self.idxtoword[ix] for ix in text_tensors["input_ids"][0].tolist()
            ]
            processed_text_tensors.append(text_tensors)

        caption_ids = torch.stack([x["input_ids"] for x in processed_text_tensors])
        attention_mask = torch.stack([x["attention_mask"] for x in processed_text_tensors])
        token_type_ids = torch.stack([x["token_type_ids"] for x in processed_text_tensors])

        if len(text) == 1:
            # Single prompt: drop only the stacking dimension.
            caption_ids = caption_ids.squeeze(0).to(device)
            attention_mask = attention_mask.squeeze(0).to(device)
            token_type_ids = token_type_ids.squeeze(0).to(device)
        else:
            caption_ids = caption_ids.squeeze().to(device)
            attention_mask = attention_mask.squeeze().to(device)
            token_type_ids = token_type_ids.squeeze().to(device)

        # NOTE(review): this iterates over the *characters* of each raw prompt
        # and counts those that are not "[" - kept as-is; confirm intent.
        cap_lens = [len([w for w in txt if not w.startswith("[")]) for txt in text]

        return {
            "caption_ids": caption_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
            "cap_lens": cap_lens,
        }


class ChestXray14ImageDataset(data.Dataset):
    """Classification dataset over the ChestXray14 images with shared class prompts.

    Construction reads the space-separated list file ``ChestXray14_TRAIN_DATA``
    (first column: image path, remaining columns: one value per entry of
    ``ChestXray14_pathologies``), holds out 40% of the rows as a test split
    (persisted to CSV when missing) and keeps ``cfg.data.frac`` of the
    remaining rows for training.  The class prompts are tokenised once and the
    same tensors are returned with every sample.

    Args:
        cfg: experiment config.  Reads ``data.frac``, ``data.text.path``,
            ``data.text.word_num``, ``data.image.imsize`` and
            ``model.text.bert_type``.
        split: accepted for interface compatibility; not used by this class.
        transform: optional callable applied to the PIL image in ``get_imgs``.

    Raises:
        RuntimeError: if ``ChestXray14_DATA_DIR`` is ``None``.
    """

    # Enumeration markers such as "1." / "23." used to split prompts into points.
    # Raw string: "\." in a normal string literal is an invalid escape sequence.
    _ENUM_SPLITTER = re.compile(r"[0-9]+\.")

    def __init__(self, cfg, split="train", transform=None):

        if ChestXray14_DATA_DIR is None:
            raise RuntimeError("ChestXray14 data path empty")

        self.cfg_classify = cfg
        self.transform = transform

        # Read label and data.  The local is deliberately not called ``data``
        # so it does not shadow the ``torch.utils.data`` module alias.
        table = pd.read_csv(
            ChestXray14_TRAIN_DATA,
            sep=' ',
            names=[ChestXray14_PATH_COL] + ChestXray14_pathologies,
        )
        self.label = pd.DataFrame(table[ChestXray14_pathologies].values, columns=ChestXray14_pathologies)
        self.df = pd.DataFrame(table[ChestXray14_PATH_COL].values, columns=[ChestXray14_PATH_COL])

        # Get path: drop the first 11 characters of the listed path and prepend
        # the local image root.
        # NOTE(review): hard-coded absolute root; presumably should come from
        # CARZero.constants - confirm.
        self.df[ChestXray14_PATH_COL] = self.df[ChestXray14_PATH_COL].apply(
            lambda x: '/mnt/nvme_share/wuwl/dataset/ChestXray14/images/' + x[11:]
        )

        # Sample data.
        # NOTE(review): inputs and labels are sampled independently with the
        # same seed; they stay row-aligned because both frames are built from
        # the same table and share length and index.
        self.test_input = self.df.sample(frac=0.4, random_state=42)
        self.test_label = self.label.sample(frac=0.4, random_state=42)
        self.train_input = self.df.drop(self.test_input.index)
        self.train_label = self.label.drop(self.test_label.index)
        self.train_input = self.train_input.sample(frac=self.cfg_classify.data.frac, random_state=42)
        self.train_label = self.train_label.sample(frac=self.cfg_classify.data.frac, random_state=42)

        frac_suffix = f'_{self.cfg_classify.data.frac}.csv'
        test_input_save_path = ChestXray14_TEST_INPUT_CSV.replace('.csv', frac_suffix)
        test_label_save_path = ChestXray14_TEST_LABEL_CSV.replace('.csv', frac_suffix)

        # Persist each held-out file that is missing.  Checking the two files
        # independently means a lone missing file is regenerated, while an
        # existing file is never overwritten.
        if not os.path.exists(test_input_save_path):
            self.test_input.to_csv(test_input_save_path, index=False, header=self.test_input.columns.tolist())
        if not os.path.exists(test_label_save_path):
            self.test_label.to_csv(test_label_save_path, index=False, header=self.test_label.columns.tolist())

        # Text data: tokenise every class prompt once and concatenate.
        with open(self.cfg_classify.data.text.path, 'r') as f:
            cls_prompts = json.load(f)
        bert_type = self.cfg_classify.model.text.bert_type
        self.tokenizer = AutoTokenizer.from_pretrained(bert_type)
        self.idxtoword = {v: k for k, v in self.tokenizer.get_vocab().items()}
        processed_txt = {
            cls_name: self.process_text(prompts, "cpu")
            for cls_name, prompts in cls_prompts.items()
        }
        self.text_batch = {
            key: torch.cat([txts[key] for txts in processed_txt.values()], dim=0)
            for key in ("caption_ids", "attention_mask", "token_type_ids")
        }

    def __getitem__(self, index):
        """Return ``(image, label, text)`` for training row ``index``.

        ``text`` is the shared prompt-tensor dict, the same object for every
        index.
        """
        row_input = self.train_input.iloc[index]
        row_label = self.train_label.iloc[index]
        image = self.get_imgs(row_input[ChestXray14_PATH_COL], self.transform)
        label = torch.tensor(row_label.tolist())
        return image, label, self.text_batch

    def __len__(self):
        """Number of training rows kept after the split and ``frac`` sampling."""
        return len(self.train_input)

    def get_imgs(self, img_path, transform=None):
        """Load ``img_path`` as grayscale, square-resize it and convert to RGB.

        ``transform`` is applied to the PIL image when it is not ``None``.

        Raises:
            FileNotFoundError: if ``img_path`` is not an existing file.
            OSError: if the file exists but OpenCV cannot decode it.
        """
        # cv2.imread signals failure by returning None instead of raising, so
        # fail early with the offending path rather than on ``None.shape``.
        if not os.path.isfile(img_path):
            raise FileNotFoundError(f"Image file not found: {img_path}")
        x = cv2.imread(img_path, 0)
        if x is None:
            raise OSError(f"cv2 could not decode image: {img_path}")

        # Transform images.
        x = self._resize_img(x, self.cfg_classify.data.image.imsize)
        img = Image.fromarray(x).convert("RGB")

        if transform is not None:
            img = transform(img)

        return img

    def _resize_img(self, img, scale):
        """
        Args:
            img - image as numpy array (cv2)
            scale - desired output image-size as scale x scale
        Return:
            image resized to scale x scale with shortest dimension 0-padded
        """
        size = img.shape
        # Index of the longest side: 0 -> height (also for square), 1 -> width.
        max_ind = size.index(max(size))

        # Resizing: the longest side becomes ``scale``, the other keeps ratio.
        if max_ind == 0:
            ratio = scale / float(size[0])
            desireable_size = (scale, int(float(size[1]) * float(ratio)))
        else:
            ratio = scale / float(size[1])
            desireable_size = (int(float(size[0]) * float(ratio)), scale)
        # cv2.resize takes (width, height), hence the reversed tuple.
        resized_img = cv2.resize(img, desireable_size[::-1], interpolation=cv2.INTER_AREA)

        # Padding: split the missing pixels evenly (floor before, ceil after).
        if max_ind == 0:
            pad_size = scale - resized_img.shape[1]
            pad_width = [(0, 0), (pad_size // 2, pad_size - pad_size // 2)]
        else:
            pad_size = scale - resized_img.shape[0]
            pad_width = [(pad_size // 2, pad_size - pad_size // 2), (0, 0)]

        return np.pad(resized_img, pad_width, "constant", constant_values=0)

    def process_text(self, text, device):
        """Clean and tokenise one prompt or a list of prompts.

        Each prompt is split on enumeration markers and periods, every piece
        is lower-cased and reduced to ASCII word tokens, pieces with at most
        one token are dropped, and the rejoined string is encoded with
        ``self.tokenizer`` padded/truncated to ``cfg.data.text.word_num``.

        Args:
            text: a single string or a sequence of strings.
            device: device the returned tensors are moved to.

        Returns:
            dict with ``caption_ids``, ``attention_mask``, ``token_type_ids``
            (singleton dimensions squeezed) and ``cap_lens``.
        """
        if isinstance(text, str):
            text = [text]

        word_tokenizer = RegexpTokenizer(r"\w+")
        max_length = self.cfg_classify.data.text.word_num

        processed_text_tensors = []
        for prompt in text:
            # Use space instead of newline, then split into sentences.
            flattened = prompt.replace("\n", " ")
            sentences = [
                sent
                for point in self._ENUM_SPLITTER.split(flattened)
                for sent in point.split(".")
            ]

            all_sents = []
            for sentence in sentences:
                sentence = sentence.replace("\ufffd\ufffd", " ")
                tokens = word_tokenizer.tokenize(sentence.lower())
                if len(tokens) <= 1:
                    continue
                # Strip non-ASCII characters; drop tokens that become empty.
                ascii_tokens = (tok.encode("ascii", "ignore").decode("ascii") for tok in tokens)
                all_sents.append(" ".join(tok for tok in ascii_tokens if tok))

            text_tensors = self.tokenizer(
                " ".join(all_sents),
                return_tensors="pt",
                truncation=True,
                padding="max_length",
                max_length=max_length,
            )
            text_tensors["sent"] = [
                self.idxtoword[ix] for ix in text_tensors["input_ids"][0].tolist()
            ]
            processed_text_tensors.append(text_tensors)

        caption_ids = torch.stack([x["input_ids"] for x in processed_text_tensors])
        attention_mask = torch.stack([x["attention_mask"] for x in processed_text_tensors])
        token_type_ids = torch.stack([x["token_type_ids"] for x in processed_text_tensors])

        if len(text) == 1:
            # Single prompt: drop only the stacking dimension.
            caption_ids = caption_ids.squeeze(0).to(device)
            attention_mask = attention_mask.squeeze(0).to(device)
            token_type_ids = token_type_ids.squeeze(0).to(device)
        else:
            caption_ids = caption_ids.squeeze().to(device)
            attention_mask = attention_mask.squeeze().to(device)
            token_type_ids = token_type_ids.squeeze().to(device)

        # NOTE(review): this iterates over the *characters* of each raw prompt
        # and counts those that are not "[" - kept as-is; confirm intent.
        cap_lens = [len([w for w in txt if not w.startswith("[")]) for txt in text]

        return {
            "caption_ids": caption_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
            "cap_lens": cap_lens,
        }


class ChestXDet10ImageDataset(data.Dataset):
    """Classification dataset over the ChestXDet10 images with shared class prompts.

    Construction reads the image-path CSV ``ChestXDet10_TRAIN_INPUT`` and the
    JSON annotation file ``ChestXDet10_TRAIN_LABEL`` (the ``syms`` list of each
    entry is binarised against ``ChestXDet10_pathologies``), holds out 40% of
    the rows as a test split (persisted to CSV when missing) and keeps
    ``cfg.data.frac`` of the remaining rows for training.  The class prompts
    are tokenised once and the same tensors are returned with every sample.

    Args:
        cfg: experiment config.  Reads ``data.frac``, ``data.text.path``,
            ``data.text.word_num``, ``data.image.imsize`` and
            ``model.text.bert_type``.
        split: accepted for interface compatibility; not used by this class.
        transform: optional callable applied to the PIL image in ``get_imgs``.

    Raises:
        RuntimeError: if ``ChestXray14_DATA_DIR`` is ``None`` (see the review
            note in ``__init__``).
    """

    # Enumeration markers such as "1." / "23." used to split prompts into points.
    # Raw string: "\." in a normal string literal is an invalid escape sequence.
    _ENUM_SPLITTER = re.compile(r"[0-9]+\.")

    def __init__(self, cfg, split="train", transform=None):

        # NOTE(review): this guards on the ChestXray14 constant although the
        # message (and everything below) concerns ChestXDet10 - looks like a
        # copy/paste slip; confirm which constant CARZero.constants provides.
        if ChestXray14_DATA_DIR is None:
            raise RuntimeError("ChestXDet10 data path empty")

        self.cfg_classify = cfg
        self.transform = transform

        # Read the image paths and rewrite the path prefix to the local root.
        # NOTE(review): hard-coded absolute root; presumably should come from
        # CARZero.constants - confirm.
        self.df = pd.read_csv(ChestXDet10_TRAIN_INPUT)
        self.df[ChestXDet10_PATH_COL] = self.df[ChestXDet10_PATH_COL].apply(
            lambda x: x.replace("/defaultShare/ChestX-Det10-Dataset/", "/mnt/nvme_share/wuwl/dataset/ChestXDet10/")
        )

        # Read the annotations.  The local is deliberately not called ``data``
        # so it does not shadow the ``torch.utils.data`` module alias.
        with open(ChestXDet10_TRAIN_LABEL, 'r') as f:
            annotations = json.load(f)
        all_label = [entry['syms'] for entry in annotations]
        mlb = MultiLabelBinarizer(classes=ChestXDet10_pathologies)
        self.label = pd.DataFrame(np.asarray(mlb.fit_transform(all_label)))

        # Sample data.
        # NOTE(review): inputs and labels are sampled independently with the
        # same seed; they stay row-aligned only if the CSV and the JSON file
        # have the same number of entries in the same order - confirm.
        self.test_input = self.df.sample(frac=0.4, random_state=42)
        self.test_label = self.label.sample(frac=0.4, random_state=42)
        self.train_input = self.df.drop(self.test_input.index)
        self.train_label = self.label.drop(self.test_label.index)
        self.train_input = self.train_input.sample(frac=self.cfg_classify.data.frac, random_state=42)
        self.train_label = self.train_label.sample(frac=self.cfg_classify.data.frac, random_state=42)

        # Give all four frames their final column names.
        self.train_input.columns = [ChestXDet10_PATH_COL]
        self.train_label.columns = ChestXDet10_pathologies
        self.test_input.columns = [ChestXDet10_PATH_COL]
        self.test_label.columns = ChestXDet10_pathologies

        frac_suffix = f'_{self.cfg_classify.data.frac}.csv'
        test_input_save_path = ChestXDet10_TEST_INPUT_CSV.replace('.csv', frac_suffix)
        test_label_save_path = ChestXDet10_TEST_LABEL_CSV.replace('.csv', frac_suffix)

        # Persist each held-out file that is missing.  Checking the two files
        # independently means a lone missing file is regenerated, while an
        # existing file is never overwritten.
        if not os.path.exists(test_input_save_path):
            self.test_input.to_csv(test_input_save_path, index=False, header=self.test_input.columns.tolist())
        if not os.path.exists(test_label_save_path):
            self.test_label.to_csv(test_label_save_path, index=False, header=self.test_label.columns.tolist())

        # Text data: tokenise every class prompt once and concatenate.
        with open(self.cfg_classify.data.text.path, 'r') as f:
            cls_prompts = json.load(f)
        bert_type = self.cfg_classify.model.text.bert_type
        self.tokenizer = AutoTokenizer.from_pretrained(bert_type)
        self.idxtoword = {v: k for k, v in self.tokenizer.get_vocab().items()}
        processed_txt = {
            cls_name: self.process_text(prompts, "cpu")
            for cls_name, prompts in cls_prompts.items()
        }
        self.text_batch = {
            key: torch.cat([txts[key] for txts in processed_txt.values()], dim=0)
            for key in ("caption_ids", "attention_mask", "token_type_ids")
        }

    def __getitem__(self, index):
        """Return ``(image, label, text)`` for training row ``index``.

        ``text`` is the shared prompt-tensor dict, the same object for every
        index.
        """
        row_input = self.train_input.iloc[index]
        row_label = self.train_label.iloc[index]
        image = self.get_imgs(row_input[ChestXDet10_PATH_COL], self.transform)
        label = torch.tensor(row_label.tolist())
        return image, label, self.text_batch

    def __len__(self):
        """Number of training rows kept after the split and ``frac`` sampling."""
        return len(self.train_input)

    def get_imgs(self, img_path, transform=None):
        """Load ``img_path`` as grayscale, square-resize it and convert to RGB.

        ``transform`` is applied to the PIL image when it is not ``None``.

        Raises:
            FileNotFoundError: if ``img_path`` is not an existing file.
            OSError: if the file exists but OpenCV cannot decode it.
        """
        # cv2.imread signals failure by returning None instead of raising, so
        # fail early with the offending path rather than on ``None.shape``.
        if not os.path.isfile(img_path):
            raise FileNotFoundError(f"Image file not found: {img_path}")
        x = cv2.imread(img_path, 0)
        if x is None:
            raise OSError(f"cv2 could not decode image: {img_path}")

        # Transform images.
        x = self._resize_img(x, self.cfg_classify.data.image.imsize)
        img = Image.fromarray(x).convert("RGB")

        if transform is not None:
            img = transform(img)

        return img

    def _resize_img(self, img, scale):
        """
        Args:
            img - image as numpy array (cv2)
            scale - desired output image-size as scale x scale
        Return:
            image resized to scale x scale with shortest dimension 0-padded
        """
        size = img.shape
        # Index of the longest side: 0 -> height (also for square), 1 -> width.
        max_ind = size.index(max(size))

        # Resizing: the longest side becomes ``scale``, the other keeps ratio.
        if max_ind == 0:
            ratio = scale / float(size[0])
            desireable_size = (scale, int(float(size[1]) * float(ratio)))
        else:
            ratio = scale / float(size[1])
            desireable_size = (int(float(size[0]) * float(ratio)), scale)
        # cv2.resize takes (width, height), hence the reversed tuple.
        resized_img = cv2.resize(img, desireable_size[::-1], interpolation=cv2.INTER_AREA)

        # Padding: split the missing pixels evenly (floor before, ceil after).
        if max_ind == 0:
            pad_size = scale - resized_img.shape[1]
            pad_width = [(0, 0), (pad_size // 2, pad_size - pad_size // 2)]
        else:
            pad_size = scale - resized_img.shape[0]
            pad_width = [(pad_size // 2, pad_size - pad_size // 2), (0, 0)]

        return np.pad(resized_img, pad_width, "constant", constant_values=0)

    def process_text(self, text, device):
        """Clean and tokenise one prompt or a list of prompts.

        Each prompt is split on enumeration markers and periods, every piece
        is lower-cased and reduced to ASCII word tokens, pieces with at most
        one token are dropped, and the rejoined string is encoded with
        ``self.tokenizer`` padded/truncated to ``cfg.data.text.word_num``.

        Args:
            text: a single string or a sequence of strings.
            device: device the returned tensors are moved to.

        Returns:
            dict with ``caption_ids``, ``attention_mask``, ``token_type_ids``
            (singleton dimensions squeezed) and ``cap_lens``.
        """
        if isinstance(text, str):
            text = [text]

        word_tokenizer = RegexpTokenizer(r"\w+")
        max_length = self.cfg_classify.data.text.word_num

        processed_text_tensors = []
        for prompt in text:
            # Use space instead of newline, then split into sentences.
            flattened = prompt.replace("\n", " ")
            sentences = [
                sent
                for point in self._ENUM_SPLITTER.split(flattened)
                for sent in point.split(".")
            ]

            all_sents = []
            for sentence in sentences:
                sentence = sentence.replace("\ufffd\ufffd", " ")
                tokens = word_tokenizer.tokenize(sentence.lower())
                if len(tokens) <= 1:
                    continue
                # Strip non-ASCII characters; drop tokens that become empty.
                ascii_tokens = (tok.encode("ascii", "ignore").decode("ascii") for tok in tokens)
                all_sents.append(" ".join(tok for tok in ascii_tokens if tok))

            text_tensors = self.tokenizer(
                " ".join(all_sents),
                return_tensors="pt",
                truncation=True,
                padding="max_length",
                max_length=max_length,
            )
            text_tensors["sent"] = [
                self.idxtoword[ix] for ix in text_tensors["input_ids"][0].tolist()
            ]
            processed_text_tensors.append(text_tensors)

        caption_ids = torch.stack([x["input_ids"] for x in processed_text_tensors])
        attention_mask = torch.stack([x["attention_mask"] for x in processed_text_tensors])
        token_type_ids = torch.stack([x["token_type_ids"] for x in processed_text_tensors])

        if len(text) == 1:
            # Single prompt: drop only the stacking dimension.
            caption_ids = caption_ids.squeeze(0).to(device)
            attention_mask = attention_mask.squeeze(0).to(device)
            token_type_ids = token_type_ids.squeeze(0).to(device)
        else:
            caption_ids = caption_ids.squeeze().to(device)
            attention_mask = attention_mask.squeeze().to(device)
            token_type_ids = token_type_ids.squeeze().to(device)

        # NOTE(review): this iterates over the *characters* of each raw prompt
        # and counts those that are not "[" - kept as-is; confirm intent.
        cap_lens = [len([w for w in txt if not w.startswith("[")]) for txt in text]

        return {
            "caption_ids": caption_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
            "cap_lens": cap_lens,
        }


def Classify_collate_fn(batch):
    """Collate classification samples into a single batch dictionary.

    Each sample is ``(image, label, text)`` or ``(image, label, text,
    oral_text)``.  The prompt tensors (and ``oral_text``, when present) are
    taken from the first sample only, since the datasets in this module return
    the same prompt objects for every index.

    Args:
        batch: non-empty sequence of samples as described above.

    Returns:
        dict with keys ``caption_ids``, ``token_type_ids``, ``attention_mask``
        (from the first sample's ``text``), ``images`` (stacked tensor, or the
        plain list of images when they cannot be stacked, e.g. untransformed
        PIL images), ``texts`` (the first sample's ``oral_text`` or ``None``
        for 3-element samples) and ``labels`` (stacked tensor).
    """
    # Flatten: index instead of tuple-unpacking so both sample layouts work.
    images = [sample[0] for sample in batch]
    labels = [sample[1] for sample in batch]

    # Stack images when possible; non-tensor images (TypeError) or images of
    # differing shapes (RuntimeError) are deliberately passed through as a list.
    try:
        images = torch.stack(images)
    except (TypeError, RuntimeError):
        pass

    labels = torch.stack(labels)

    first = batch[0]
    text = first[2]

    # Add to dictionary.
    return {
        "caption_ids": text["caption_ids"],
        "token_type_ids": text["token_type_ids"],
        "attention_mask": text["attention_mask"],
        "images": images,
        "texts": first[3] if len(first) > 3 else None,
        "labels": labels,
    }
    

class MultimodalPretrainingXHDataset(data.Dataset):
    def __init__(self, cfg, split="train", transform=None):
        """Set up the dataset from the master CSV and the caption cache.

        Args:
            cfg: experiment configuration. Reads ``experiment_name``,
                ``data.text.captions_per_image`` and ``model.text.bert_type``.
            split: value of ``MIMIC_SPLIT_COL`` selecting the rows to serve.
            transform: optional callable applied to each loaded PIL image.

        Raises:
            RuntimeError: if ``MIMIC_DATA_DIR`` is ``None``.
        """
        if MIMIC_DATA_DIR is None:
            raise RuntimeError(
                "MIMIC data path empty\n"
                "Make sure to download data from:\n"
                "    https://stanfordmlgroup.github.io/competitions/MIMIC/"
                " and update MIMIC_DATA_DIR in ./CARZero/constants.py"
            )

        self.cfg = cfg
        self.transform = transform
        self.max_word_num = self.cfg.data.text.captions_per_image

        # Master table; only frontal views are kept in either mode.
        self.df = pd.read_csv(os.path.join(MIMIC_DATA_DIR, MIMIC_MASTER_CSV_XH))
        frontal_rows = self.df[self.df[MIMIC_VIEW_COL] == "Frontal"]

        if "sent_label" not in self.cfg.experiment_name:
            self.df = frontal_rows
        else:
            # Remember the original row labels so the sentence / label files
            # (one row per master-CSV row) can be filtered the same way.
            kept_rows = frontal_rows.index
            self.df = frontal_rows.reset_index(drop=True)

            raw_sents = pd.read_csv(
                os.path.join(MIMIC_DATA_DIR, SENT_Path), header=None
            )
            raw_sents = raw_sents.loc[kept_rows].reset_index(drop=True).values.tolist()

            raw_labels = pd.read_csv(
                os.path.join(MIMIC_DATA_DIR, LABEL_Path), header=None
            )
            raw_labels = raw_labels.loc[kept_rows].reset_index(drop=True).values.tolist()

            # First column holds the textual list of chopped sentences.
            self.sent = [
                self.parse_string_to_list(row[0])
                for row in tqdm.tqdm(raw_sents, desc='Loading chopped sentences')
            ]
            # Keep at most the first 100 label cells and drop zero entries
            # (either the string '0' or the number 0).
            self.label = [
                [cell for cell in row[:100] if cell != '0' and cell != 0]
                for row in tqdm.tqdm(raw_labels, desc='Loading sentences labels')
            ]

        # File names of the split plus the path -> sentences / labels mappings.
        (
            self.filenames,
            self.path2sent,
            self.label2sent,
            self.label_ids,
        ) = self.load_text_data(split)

        # BERT tokenizer used by get_caption.
        self.tokenizer = AutoTokenizer.from_pretrained(self.cfg.model.text.bert_type)

    def parse_string_to_list(self, string):
        """Parse the textual form of a Python list into a real list.

        Single-quoted items are first rewritten to double-quoted ones so that
        items containing an apostrophe (``['patient's lungs', ...]``) become
        valid literals. If that rewritten text cannot be parsed (for example
        because an item itself contains a double quote), the untouched input
        is tried as a fallback.

        Args:
            string: text such as ``"['first sentence', 'second sentence']"``.

        Returns:
            The parsed list, or ``None`` when ``string`` is not a ``str`` or
            neither variant evaluates to a list. A message is printed in the
            failure case.
        """
        if not isinstance(string, str):
            # e.g. NaN read from an empty CSV cell; re.sub would otherwise
            # raise an uncaught TypeError.
            print(
                f"Error parsing string: expected str, got {type(string).__name__}"
            )
            return None

        normalized = re.sub(r"(?<=\[|,)\s*'(.*?)'\s*(?=,|\])", r'"\1"', string)
        candidates = [normalized]
        if normalized != string:
            candidates.append(string)

        first_error = None
        for candidate in candidates:
            try:
                result = ast.literal_eval(candidate)
            except (SyntaxError, ValueError) as e:
                error = e
            else:
                if isinstance(result, list):
                    return result
                error = ValueError("The parsed result is not a list.")
            if first_error is None:
                first_error = error

        print(f"Error parsing string: {first_error}")
        return None

    def load_text_data(self, split):
        """Load (or build and cache) the caption mappings and list the split's images.

        The caption mappings are cached as pickle files inside
        ``MIMIC_DATA_DIR``; a missing cache is created from ``self.df`` and
        written back. Image paths that were flagged for removal, or whose
        base name appears in ``cxr_report_noise_sample.pickle``, are dropped.

        Args:
            split: value of ``MIMIC_SPLIT_COL`` selecting the rows to keep.

        Returns:
            Tuple ``(filenames, path2sent, label2sent, label_ids)``.
            ``label2sent`` and ``label_ids`` are ``None`` unless the experiment
            name contains ``'sent_label'`` or ``'test_medclip'``.
        """
        label2sent, label_ids = None, None

        uses_sentence_labels = (
            'sent_label' in self.cfg.experiment_name
            or 'test_medclip' in self.cfg.experiment_name
        )

        # NOTE: pickle.load executes arbitrary code from the file; the caches
        # below must only ever be files produced by this project.
        if uses_sentence_labels:
            filepath = os.path.join(MIMIC_DATA_DIR, "captions_sent_label_with_LLM.pickle")
            if not os.path.isfile(filepath):
                print(f"Caption file {filepath} does not exit. Creating captions...")
                # NOTE(review): self.sent / self.label are only set by __init__
                # when 'sent_label' is in the experiment name - confirm that a
                # pure 'test_medclip' run always finds an existing cache.
                path2sent, label2sent, to_remove, label_ids = self.create_path_2_sent_mapping_sent_label(
                    self.df, self.sent, self.label
                )
                with open(filepath, "wb") as f:
                    pickle.dump([path2sent, label2sent, to_remove, label_ids], f, protocol=2)
                    print("Save to: ", filepath)
            else:
                with open(filepath, "rb") as f:
                    print(f"Loading captions from {filepath}")
                    path2sent, label2sent, to_remove, label_ids = pickle.load(f)
        else:
            filepath = os.path.join(MIMIC_DATA_DIR, "captions_without_XH.pickle")
            if not os.path.isfile(filepath):
                print(f"Caption file {filepath} does not exit. Creating captions...")
                path2sent, to_remove = self.create_path_2_sent_mapping(
                    self.df, self.max_word_num
                )
                with open(filepath, "wb") as f:
                    pickle.dump([path2sent, to_remove], f, protocol=2)
                    print("Save to: ", filepath)
            else:
                with open(filepath, "rb") as f:
                    print(f"Loading captions from {filepath}")
                    path2sent, to_remove = pickle.load(f)

        with open(os.path.join(MIMIC_DATA_DIR, 'cxr_report_noise_sample.pickle'), 'rb') as f:
            noise_sample = pickle.load(f)

        # Sets give O(1) membership tests; the previous list lookups made the
        # filtering quadratic in the number of images.
        removed_paths = set(to_remove)
        noisy_basenames = {os.path.basename(f) for f in noise_sample}

        # Keep the images of the requested split that were neither flagged
        # for removal nor listed (by base name) as noisy samples.
        split_paths = self.df[self.df[MIMIC_SPLIT_COL] == split][
            MIMIC_PATH_COL
        ].tolist()
        filenames = [
            f
            for f in split_paths
            if f not in removed_paths
            and os.path.basename(f) not in noisy_basenames
        ]

        return filenames, path2sent, label2sent, label_ids

    def get_caption(self, path):

        series_sents = self.path2sent[path]

        if 'sent_label' in self.cfg.experiment_name or 'test_medclip' in self.cfg.experiment_name:
            series_labels = self.label2sent[path]
            if len(series_sents) != len(series_labels):
                raise Exception("no sentence for path")

        if self.cfg.data.text.full_report is True:
            sent = " ".join(series_sents)
        else:
            sent_ix = random.randint(0, len(series_sents))
            sent = series_sents[sent_ix]
            if 'sent_label_with_CL_loss' in self.cfg.experiment_name or 'sent_label_plus_with_CL_loss' in self.cfg.experiment_name \
                    or 'sent_label_gl' in self.cfg.experiment_name or 'sent_label_plus_gl' in self.cfg.experiment_name or 'CL_sent_label_plus' in self.cfg.experiment_name:
                label_sample = series_labels[sent_ix]
                label_list = series_labels
            elif 'sent_label_with_classify_loss' in self.cfg.experiment_name or 'test_medclip' in self.cfg.experiment_name:
                label_list = series_labels

        tokens = self.tokenizer(
            sent,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=self.cfg.data.text.word_num,
        )
        x_len = len([t for t in tokens["input_ids"][0] if t != 0])

        if 'sent_label_with_CL_loss' in self.cfg.experiment_name or 'sent_label_plus_with_CL_loss' in self.cfg.experiment_name \
                or 'sent_label_gl' in self.cfg.experiment_name or 'sent_label_plus_gl' in self.cfg.experiment_name or 'CL_sent_label_plus' in self.cfg.experiment_name:
            return tokens, x_len, label_sample, label_list
        elif 'sent_label_with_classify_loss' in self.cfg.experiment_name or 'test_medclip' in self.cfg.experiment_name:
            return tokens, x_len, label_list
        else:
            return tokens, x_len

    def get_imgs(self, img_path, transform=None):
        """Load one image, square it to the configured size and transform it.

        Args:
            img_path: path as stored in the master CSV. The
                ``"/defaultShare/MIMIC-CXR/"`` prefix is stripped and the
                remainder is resolved against ``PWD_Path``.
            transform: optional callable applied to the RGB PIL image.

        Returns:
            The PIL image (RGB), or whatever ``transform`` returns for it.

        Raises:
            FileNotFoundError: if the image cannot be read.
        """
        img_abs_path = os.path.join(PWD_Path, img_path.replace("/defaultShare/MIMIC-CXR/", ""))

        # Flag 0 loads the image as single-channel grayscale.
        x = cv2.imread(str(img_abs_path), 0)
        if x is None:
            # cv2.imread signals a missing or undecodable file by returning
            # None instead of raising; fail here with the offending path
            # rather than later with an AttributeError on `.shape`.
            raise FileNotFoundError(f"Could not read image: {img_abs_path}")

        # Resize the longer side to imsize and zero-pad to a square.
        x = self._resize_img(x, self.cfg.data.image.imsize)
        img = Image.fromarray(x).convert("RGB")

        if transform is not None:
            img = transform(img)

        return img

    def __getitem__(self, index):

        key = self.filenames[index]

        imgs = self.get_imgs(key, self.transform)

        # randomly select a sentence
        if 'sent_label_with_classify_loss' in self.cfg.experiment_name or 'test_medclip' in self.cfg.experiment_name:
            caps, cap_len, label_list = self.get_caption(key)
            return imgs, caps, cap_len, key, label_list, self.label_ids
        elif 'sent_label_with_CL_loss' in self.cfg.experiment_name or 'sent_label_plus_with_CL_loss' in self.cfg.experiment_name \
                or 'sent_label_gl' in self.cfg.experiment_name or 'sent_label_plus_gl' in self.cfg.experiment_name or 'CL_sent_label_plus' in self.cfg.experiment_name:
            caps, cap_len, label_sample, label_list = self.get_caption(key)
            return imgs, caps, cap_len, key, label_sample, label_list, self.label_ids
        else:
            caps, cap_len = self.get_caption(key)
            return imgs, caps, cap_len, key

    def __len__(self):
        return len(self.filenames)

    def create_path_2_sent_mapping(self, df, max_word_num):
        """Build the image-path -> cleaned-sentences mapping from the reports.

        For every row the texts in ``MIMIC_REPORT_COL`` and
        ``MIMIC_XH_REPORT_COL`` are concatenated, split into sentences,
        lower-cased, tokenized on ``\\w+`` and stripped of non-ASCII
        characters. Sentences with at most one token are skipped.

        Args:
            df: table with the columns ``MIMIC_REPORT_COL``,
                ``MIMIC_XH_REPORT_COL`` and ``MIMIC_PATH_COL``.
            max_word_num: sentence collection for a study stops once the
                running token count equals this value.

        Returns:
            Tuple ``(path2sent, to_remove)``: ``path2sent`` maps an image path
            to its list of cleaned sentences; ``to_remove`` lists the paths
            whose image file is missing, whose report is empty or that yielded
            no sentence (a path may appear more than once).
        """
        # Built once instead of once per row / per sentence. The raw string
        # avoids the invalid "\." escape sequence of the former pattern.
        point_splitter = re.compile(r"[0-9]+\.")
        # Picks out runs of word characters as tokens, drops everything else.
        word_tokenizer = RegexpTokenizer(r"\w+")

        sent_lens, num_sents, to_remove = [], [], []
        path2sent = {}
        for _, row in tqdm.tqdm(df.iterrows(), total=df.shape[0]):
            path = row[MIMIC_PATH_COL]

            # Concatenate the available report texts (non-strings, e.g. NaN,
            # are ignored).
            captions = ""
            if isinstance(row[MIMIC_REPORT_COL], str):
                captions += row[MIMIC_REPORT_COL]
            if isinstance(row[MIMIC_XH_REPORT_COL], str):
                captions += row[MIMIC_XH_REPORT_COL]

            img_abs_path = os.path.join(PWD_Path, path.replace("/defaultShare/MIMIC-CXR/", ""))
            if not os.path.exists(img_abs_path):
                to_remove.append(path)

            # remove empty reports
            if len(captions) == 0:
                to_remove.append(path)

            # use space instead of newline
            captions = captions.replace("\n", " ")

            # Split on enumeration markers ("1.", "2.", ...) and then on
            # periods.
            sentences = [
                sent
                for point in point_splitter.split(captions)
                for sent in point.split(".")
            ]

            cnt = 0
            study_sent = []
            for cap in sentences:
                if len(cap) == 0:
                    continue

                cap = cap.replace("\ufffd\ufffd", " ")
                tokens = word_tokenizer.tokenize(cap.lower())

                # TODO: < 3 has instances of ['no', 'pneumothorax'], ['clear', 'lung']
                if len(tokens) <= 1:
                    continue

                # Drop non-ASCII characters; tokens that end up empty vanish.
                ascii_tokens = (
                    t.encode("ascii", "ignore").decode("ascii") for t in tokens
                )
                included_tokens = [t for t in ascii_tokens if len(t) > 0]
                study_sent.append(" ".join(included_tokens))

                # NOTE(review): the limit only triggers on an exact match, so
                # a running count that jumps past max_word_num never stops the
                # loop. Left unchanged because `>=` would alter the generated
                # captions - confirm the intended behaviour.
                cnt += len(included_tokens)
                if cnt == max_word_num:
                    break

                sent_lens.append(len(included_tokens))
            num_sents.append(len(study_sent))

            # remove paths without sentences
            if len(study_sent) > 0:
                path2sent[path] = study_sent
            else:
                to_remove.append(path)

        # Report word / sentence statistics; min()/percentile() raise on
        # empty arrays, so empty statistics are skipped.
        sent_lens = np.array(sent_lens)
        num_sents = np.array(num_sents)
        if sent_lens.size > 0:
            print(
                f"sent lens: {sent_lens.min()},{sent_lens.mean()},{sent_lens.max()} [{np.percentile(sent_lens, 5)}, {np.percentile(sent_lens, 95)}]"
            )
        if num_sents.size > 0:
            print(
                f"num sents: {num_sents.min()},{num_sents.mean()},{num_sents.max()} [{np.percentile(num_sents, 5)}, {np.percentile(num_sents, 95)}]"
            )

        return path2sent, to_remove

    def create_path_2_sent_mapping_sent_label(self, df, sent, label):
        """Build the path -> sentences and path -> labels mappings.

        Args:
            df: table with the columns ``MIMIC_REPORT_COL`` and
                ``MIMIC_PATH_COL``. Its index values are used to index
                ``sent`` and ``label``, so they must be valid positions in
                those sequences.
            sent: per-row list of sentences; an entry may be ``None`` (see
                ``parse_string_to_list``) or empty, in which case the row is
                skipped and its path flagged for removal.
            label: per-row list of labels, paired with the row's sentences.

        Returns:
            Tuple ``(path2sent, label2sent, to_remove, label_ids)``.
            ``path2sent`` / ``label2sent`` map an image path to its cleaned
            sentences / their labels (as strings), ``to_remove`` lists paths
            with a missing image, an empty report or no sentences (possibly
            with duplicates) and ``label_ids`` maps each key of the
            ``MIMIC_label`` JSON file to the result of ``process_text`` for
            its prompts.
        """
        # Encode the class prompts once. The local is deliberately not named
        # `CARZero`, which would shadow the imported module of that name.
        model = CARZeroDQNWOSAGLMLP(self.cfg)
        device = "cpu"
        with open(os.path.join(MIMIC_DATA_DIR, MIMIC_label), 'r') as f:
            cls_prompts = json.load(f)
        label_ids = {k: model.process_text(v, device) for k, v in cls_prompts.items()}
        del model

        # Built once instead of once per sentence.
        word_tokenizer = RegexpTokenizer(r"\w+")

        sent_lens, num_sents, to_remove = [], [], []
        path2sent = {}
        label2sent = {}
        for idx, row in tqdm.tqdm(df.iterrows(), total=df.shape[0], desc='Creating captions'):
            path = row[MIMIC_PATH_COL]

            # The report text is only used to detect empty reports.
            captions = ""
            if isinstance(row[MIMIC_REPORT_COL], str):
                captions += row[MIMIC_REPORT_COL]

            img_abs_path = os.path.join(PWD_Path, path.replace("/defaultShare/MIMIC-CXR/", ""))
            if not os.path.exists(img_abs_path):
                to_remove.append(path)

            # remove empty reports
            if len(captions) == 0:
                to_remove.append(path)

            sentences = sent[idx]
            if sentences is None or len(sentences) == 0:
                # None marks a row whose sentence list could not be parsed;
                # it is treated like an empty list instead of crashing on
                # len(None).
                to_remove.append(path)
                continue

            study_sent = []
            study_label = []
            # zip stops at the shorter sequence, so both lists end up with
            # the same length.
            for cap, lab in zip(sentences, label[idx]):
                tokens = word_tokenizer.tokenize(cap.lower())
                ascii_tokens = (
                    t.encode("ascii", "ignore").decode("ascii") for t in tokens
                )
                included_tokens = [t for t in ascii_tokens if len(t) > 0]
                study_sent.append(" ".join(included_tokens))
                # The statistic is the character length of the raw sentence.
                sent_lens.append(len(cap))
                # Labels are stored as strings.
                study_label.append(str(lab))
            num_sents.append(len(study_sent))
            path2sent[path] = study_sent
            label2sent[path] = study_label

        # Report sentence statistics; min()/percentile() raise on empty
        # arrays, so empty statistics are skipped.
        sent_lens = np.array(sent_lens)
        num_sents = np.array(num_sents)
        if sent_lens.size > 0:
            print(
                f"sent lens: {sent_lens.min()},{sent_lens.mean()},{sent_lens.max()} [{np.percentile(sent_lens, 5)}, {np.percentile(sent_lens, 95)}]"
            )
        if num_sents.size > 0:
            print(
                f"num sents: {num_sents.min()},{num_sents.mean()},{num_sents.max()} [{np.percentile(num_sents, 5)}, {np.percentile(num_sents, 95)}]"
            )

        return path2sent, label2sent, to_remove, label_ids

    def _resize_img(self, img, scale):
        """
        Fit an image into a scale x scale square while keeping its aspect ratio.

        Args:
            img - image as numpy array (cv2)
            scale - edge length of the square output
        Return:
            image whose longest side was resized to scale, with the other
            side 0-padded (split evenly between both ends) up to scale
        """
        dims = img.shape
        # Axis 0 (rows) wins ties, i.e. square images count as "tall".
        rows_are_longest = dims.index(max(dims)) == 0

        # Target shape in numpy order (rows, cols).
        if rows_are_longest:
            ratio = scale / float(dims[0])
            new_rows, new_cols = scale, int((float(dims[1]) * float(ratio)))
        else:
            ratio = scale / float(dims[1])
            new_rows, new_cols = int((float(dims[0]) * float(ratio))), scale

        # cv2.resize expects the size as (cols, rows).
        shrunk = cv2.resize(img, (new_cols, new_rows), interpolation=cv2.INTER_AREA)

        # Distribute the missing pixels over both ends of the short axis; an
        # odd remainder goes to the trailing end.
        if rows_are_longest:
            half, remainder = divmod(scale - shrunk.shape[1], 2)
            padding = [(0, 0), (int(half), int(half + remainder))]
        else:
            half, remainder = divmod(scale - shrunk.shape[0], 2)
            padding = [(int(half), int(half + remainder)), (0, 0)]

        return np.pad(shrunk, padding, "constant", constant_values=0)
    

def multimodal_collate_fn(batch):
    """Collate image/caption samples, ordered by descending caption length.

    Args:
        batch: non-empty sequence of samples, each one of

            * ``(img, cap, cap_len, path)``
            * ``(img, cap, cap_len, path, label_list, label_ids)``
            * ``(img, cap, cap_len, path, label_sample, label_list, label_ids)``

            where ``cap`` provides ``"input_ids"``, ``"token_type_ids"`` and
            ``"attention_mask"`` tensors.

    Returns:
        dict with the keys ``"caption_ids"``, ``"token_type_ids"``,
        ``"attention_mask"``, ``"imgs"``, ``"cap_lens"``, ``"path"``,
        ``"label_sample"``, ``"label_list"`` and ``"label_ids"``. All
        per-sample entries follow the same order (longest caption first).
        ``"label_sample"`` / ``"label_list"`` are empty lists when the samples
        carry no such data; ``"label_ids"`` is taken from the first sample of
        a 6-element batch and is ``None`` otherwise.

    Raises:
        ValueError: if a sample does not have 4, 6 or 7 elements.
    """
    imgs, cap_len, ids, tokens, attention, path = [], [], [], [], [], []
    label_sample, label_list = [], []

    for b in batch:
        if len(b) not in (4, 6, 7):
            raise ValueError(
                f"unexpected sample with {len(b)} elements; expected 4, 6 or 7"
            )
        # The first four fields are shared by every sample layout.
        img, cap, cap_l, p = b[:4]
        imgs.append(img)
        cap_len.append(cap_l)
        ids.append(cap["input_ids"])
        tokens.append(cap["token_type_ids"])
        attention.append(cap["attention_mask"])
        path.append(p)
        if len(b) == 6:
            label_list.append(b[4])
        elif len(b) == 7:
            label_sample.append(b[4])
            label_list.append(b[5])

    # NOTE(review): 7-element samples also carry label_ids, but only
    # 6-element batches forward them - confirm this is intended.
    label_ids = batch[0][-1] if len(batch[0]) == 6 else None

    # Stack. Only the per-sample leading axis added by the tokenizer (dim 1
    # after stacking) is removed: a bare squeeze() would also drop the batch
    # axis for a batch of one.
    imgs = torch.stack(imgs)
    ids = torch.stack(ids).squeeze(1)
    tokens = torch.stack(tokens).squeeze(1)
    attention = torch.stack(attention).squeeze(1)

    # Sort by caption length (descending) and reorder every per-sample field
    # - including the paths - consistently.
    sorted_cap_lens, sorted_cap_indices = torch.sort(torch.tensor(cap_len), 0, True)
    order = sorted_cap_indices.tolist()
    return_dict = {
        "caption_ids": ids[sorted_cap_indices],
        "token_type_ids": tokens[sorted_cap_indices],
        "attention_mask": attention[sorted_cap_indices],
        "imgs": imgs[sorted_cap_indices],
        "cap_lens": sorted_cap_lens,
        "path": [path[i] for i in order],
        "label_sample": [label_sample[i] for i in order] if label_sample else [],
        "label_list": [label_list[i] for i in order] if label_list else [],
        "label_ids": label_ids,
    }

    return return_dict