import re, os
import pickle
import json
import contextlib
from pathlib import Path
import pandas as pd
from ast import literal_eval
from nltk.tokenize import RegexpTokenizer
from tqdm import tqdm
import numpy as np

import torch
from PIL import ImageFile
from torchvision import transforms

from monai.transforms import (
    EnsureChannelFirstd,
    Compose,
    Resized,
    LoadImaged,
    Orientationd,
    RandSpatialCropd,
    ScaleIntensityRanged,
    ToTensord
)

from .input_dataset import collate_fn, pad_or_cut_img_tensors

class BIMCVCOVID19Dataset(torch.utils.data.Dataset):
    """Image-series / radiology-report dataset for BIMCV-COVID19.

    ``args.dataset_path`` points to a JSON object mapping a dataset name to a
    dict with ``"report"`` (CSV path), ``"data_path"`` (image root directory)
    and ``"format"``. Every CSV row of the requested split (columns used:
    ``Split``, ``Filepath``, ``Filename``, ``Report``) whose report yields at
    least one usable sentence becomes one sample.

    Args:
        args: Config namespace; reads ``dataset_path``, ``tokenizer``,
            ``feature_path``, ``split``, ``patch_image_size``, ``instruction``,
            ``vision_encode_mode`` and ``dummy``.
        transform: Stored on ``self.transform`` (reset to ``None`` when
            ``args.feature_path`` is set); not used by this class's methods.
        data_pct: Fraction of rows kept for the ``"train"`` split (sampled
            with ``random_state=42``); other splits are never subsampled.
        max_words: ``max_length`` passed to the tokenizer (with truncation).

    Raises:
        RuntimeError: If ``args.dataset_path`` does not exist.
    """

    # Splits report text at numbered points such as "1." or "12.".
    _POINT_SPLITTER = re.compile(r"[0-9]+\.")

    def __init__(self, args,
        transform=None,
        data_pct=1.0,
        max_words=400,
    ):
        super().__init__()
        if not os.path.exists(args.dataset_path):
            raise RuntimeError(f"{args.dataset_path} does not exist!")
        # Context manager so the index file is closed (it used to be leaked).
        with open(args.dataset_path, 'r') as fin:
            list_dataset = json.load(fin)

        self.data_pct = data_pct
        self.transform = transform
        self.tokenizer = args.tokenizer
        self.max_words = max_words
        self.dataset_path = args.dataset_path
        self.feature_path = args.feature_path
        if self.feature_path is not None:
            self.transform = None

        assert args.split in ["train", "validate", "test"]
        # 32 slices / 128 px match the RandSpatialCropd roi_size=[128, 128, 32] below.
        self.num_images_per_sample = 32
        self.patch_image_size = args.patch_image_size
        self.instruction = args.instruction
        self.med_patch_image_size = 128
        self.vision_encode_mode = args.vision_encode_mode
        self.split = args.split
        self.dummy = args.dummy

        self.bos_item = torch.LongTensor([args.tokenizer.bos_token_id])
        self.eos_item = torch.LongTensor([args.tokenizer.eos_token_id])
        self.bos_mask = torch.LongTensor([1])
        self.eos_mask = torch.LongTensor([1])

        # Transform for the 3D medical encoder (from xiaoxuan): load, make
        # channel-first, reorient to RAS, clip/scale intensities from
        # [-1000, 3000] to [0, 1], resize, then take a random fixed-size crop.
        self.med_patch_resize_transform = Compose([
            LoadImaged(keys=["image"]),
            EnsureChannelFirstd(keys=["image"]),
            Orientationd(keys=["image"], axcodes="RAS"),
            ScaleIntensityRanged(
                keys=["image"], a_min=-1000, a_max=3000, b_min=0.0, b_max=1.0, clip=True
            ),
            Resized(keys="image", spatial_size=(160, 160, 64)),
            RandSpatialCropd(keys="image", roi_size=[128, 128, 32], random_size=False),
            ToTensord(keys=["image"])
        ])

        # Resize to fit in the original Flamingo encoder.
        self.resize_transform = transforms.Resize(self.patch_image_size, antialias=True)

        self.filenames = []  # one dict per sample: {'path', 'img', 'format'}
        self.path2sent = {}  # series path -> list of reports, each a list of sentences
        for k, v in list_dataset.items():
            self.read_per_dataset(k, v)

    def read_per_dataset(self, key, value):
        """Load one dataset entry of the JSON index and append its samples.

        Args:
            key: Dataset name; part of the caption cache file name.
            value: Dict with ``"report"`` (CSV path), ``"data_path"`` (root
                joined with each row's ``Filepath``) and ``"format"``.
        """
        df = pd.read_csv(value["report"])
        self.df = df.loc[df["Split"] == self.split].reset_index(drop=True)
        root = value['data_path']
        img_format = value['format']

        # Build/load captions from the *full* split so the on-disk cache is
        # never restricted to a random subset.
        self._load_or_create_captions(key, root)
        # Subsample before collecting filenames; previously sampling ran after
        # self.filenames was filled, so data_pct had no effect at all.
        if self.data_pct != 1.0 and self.split == "train":
            self.df = self.df.sample(frac=self.data_pct, random_state=42)
        self.df.reset_index(drop=True, inplace=True)
        self._collect_filenames(root, img_format)

    def load_text_data(self, name, root, format):
        """Load captions for dataset ``name`` and register rows of ``self.df``.

        Args:
            name: Dataset name used in the caption cache file name.
            root: Directory joined with each row's ``Filepath``.
            format: Stored verbatim in every appended ``self.filenames`` entry.
        """
        self._load_or_create_captions(name, root)
        self._collect_filenames(root, format)

    def _load_or_create_captions(self, name, root):
        """Fill ``self.path2sent`` from the per-dataset pickle cache, creating it if absent."""
        # TODO: check this
        filepath = f"{os.path.dirname(self.dataset_path)}/captions_{name}_{self.split}.pkl"
        if not os.path.isfile(filepath):
            print(
                f"Caption file {filepath} does not exist. Creating captions...")
            # Cache only this dataset's captions, not everything accumulated in
            # self.path2sent from previously loaded datasets.
            mapping = self.create_path_2_sent_mapping(root)
            with open(filepath, "wb") as f:
                pickle.dump(mapping, f, protocol=2)
            print("Save to: ", filepath)
        else:
            # NOTE: pickle.load can execute arbitrary code; only load cache
            # files produced by this class.
            with open(filepath, "rb") as f:
                self.path2sent.update(pickle.load(f))

    def _collect_filenames(self, root, img_format):
        """Append one ``self.filenames`` entry per ``self.df`` row that has captions."""
        for rel_path, filename in zip(self.df["Filepath"], self.df["Filename"]):
            # Join with root exactly once, matching how path2sent keys are built.
            path = os.path.join(root, rel_path)
            # Skip missing or empty caption lists: get_caption cannot build a
            # caption from an empty list and would raise IndexError.
            if self.path2sent.get(path):
                self.filenames.append({'path': path, 'img': filename, 'format': img_format})

    def create_path_2_sent_mapping(self, root):
        """Tokenize every ``Report`` in ``self.df`` into cleaned sentences.

        Each ``Report`` cell is parsed with ``literal_eval`` and iterated as a
        list of report strings. A report is split at numbered points and at
        '.', lower-cased, reduced to ``\\w+`` tokens with non-ASCII characters
        dropped, and sentences with at most one token are discarded. A report
        is kept only when its sentences total at least 3 tokens.

        Args:
            root: Directory joined with each row's ``Filepath`` to form the key.

        Returns:
            The new ``{path: [report sentences, ...]}`` mapping for these rows;
            it is also merged into ``self.path2sent``.
        """
        word_tokenizer = RegexpTokenizer(r"\w+")
        mapping = {}
        sent_lens, num_sents = [], []
        # iterrows is not faster than itertuples ...  but it is ok
        for _, row in tqdm(self.df.iterrows(), total=self.df.shape[0]):
            report = literal_eval(row["Report"])
            report_list = []
            for text in report:
                # use space instead of newline, then split at numbered points and '.'
                points = self._POINT_SPLITTER.split(text.replace("\n", " "))
                captions = [sent for point in points for sent in point.split(".")]

                cnt = 0
                study_sent = []
                for cap in captions:
                    if len(cap) == 0:
                        continue
                    cap = cap.replace("\ufffd\ufffd", " ")
                    tokens = word_tokenizer.tokenize(cap.lower())
                    # TODO: < 3 has instances of ['no', 'pneumothorax'], ['clear', 'lung']
                    if len(tokens) <= 1:
                        continue
                    # drop non-ASCII characters; discard tokens that become empty
                    included_tokens = [
                        t for t in (tok.encode("ascii", "ignore").decode("ascii") for tok in tokens)
                        if len(t) > 0
                    ]
                    if len(included_tokens) > 0:
                        study_sent.append(" ".join(included_tokens))
                    cnt += len(included_tokens)
                if cnt >= 3:
                    sent_lens.append(cnt)
                    num_sents.append(len(study_sent))
                    report_list.append(study_sent)
            # Rows with an empty report list get no entry (as before).
            if report:
                mapping[os.path.join(root, row['Filepath'])] = report_list
        self.path2sent.update(mapping)
        sent_lens = np.array(sent_lens)
        num_sents = np.array(num_sents)
        print(sent_lens, num_sents)
        return mapping

    def __len__(self):
        return len(self.filenames)

    def get_caption(self, path, instruction):
        """Tokenize the prompt + caption for the series at ``path``.

        All reports except the last are used; when there is only one report,
        that one is used. Non-empty sentences of the selected reports are
        joined with ','.

        Returns:
            ``(tokens, x_len)``: the tokenizer output (``return_tensors="pt"``)
            and the number of tokens, taken from the attention mask.

        Raises:
            KeyError: If ``path`` has no captions.
            IndexError: If ``path`` maps to an empty report list.
        """
        reports = self.path2sent[path]
        series_sents = reports[:-1] or [reports[-1]]
        # Join across *all* selected reports; the old per-report concatenation
        # glued the last word of one report onto the first word of the next.
        full_sent = ','.join(
            sent for report in series_sents for sent in report if sent != ""
        )

        tokens = self.tokenizer(
            f"<image> {instruction} <answer> {full_sent}<|endofchunk|>",
            return_tensors="pt",
            truncation=True,
            add_special_tokens=False,
            max_length=self.max_words,
        )
        # Count tokens via the mask: id 0 can be a real token in many vocabularies.
        x_len = int(tokens["attention_mask"][0].sum())

        return tokens, x_len

    def __getitem__(self, index):
        """Return ``[example]`` for sample ``index``.

        ``example`` holds the caption wrapped in BOS/EOS (``source``,
        ``text_mask``), either precomputed ``features`` (when ``feature_path``
        is set) or ``patch_image``/``med_patch_image`` tensors, plus constant
        ``patch_mask`` and ``conf`` entries.
        """
        sample = self.filenames[index]
        src_text, _ = self.get_caption(sample["path"], self.instruction)
        src_item = torch.cat([self.bos_item, src_text["input_ids"].squeeze(0), self.eos_item])
        src_item_mask = torch.cat(
            [self.bos_mask, src_text["attention_mask"].squeeze(0), self.eos_mask]
        )

        # 'path' is the series directory; 'img' is the image file inside it.
        path = os.path.join(sample["path"], sample["img"])

        feats, patch_images, med_patch_images = None, None, None
        if self.feature_path:
            assert self.vision_encode_mode != "original"
            # Feature file is named after the first two '_'-separated fields of the image name.
            feature_file_name = "_".join(os.path.basename(path).split('_')[:2]) + ".npy"
            feats = torch.from_numpy(np.load(os.path.join(self.feature_path, feature_file_name)))
        elif self.vision_encode_mode != "original":
            med_patch_images = self.med_patch_resize_transform({'image': path})['image']
            # Move the last axis to the front: presumably (C, H, W, D) -> (D, C, H, W).
            med_patch_images = torch.permute(med_patch_images, (3, 0, 1, 2))
            patch_images = self.resize_transform(med_patch_images)
            # Triple dim 1 (presumably grayscale -> 3 channels for the 2D encoder).
            patch_images = torch.repeat_interleave(patch_images, 3, dim=1)
        else:
            # Implement image loading independent of medical encoder monai transforms
            raise NotImplementedError

        if self.dummy:
            # Replace image tensors with zeros of the same shape; in feature
            # mode they are None and must be left alone (used to crash here).
            if patch_images is not None:
                patch_images = torch.zeros(patch_images.shape)
            if med_patch_images is not None:
                med_patch_images = torch.zeros(med_patch_images.shape)

        example = {
            "id": index,
            "source": src_item,
            "text_mask": src_item_mask,
            "patch_image": patch_images,
            "med_patch_image": med_patch_images,
            "features": feats,
            "patch_mask": torch.tensor([True]),
            "conf": torch.tensor([1.0]),
        }
        return [example]

    def collate(self, samples):
        """Pad/cut each sample's image stacks and merge the samples into one batch.

        Args:
            samples: List of ``__getitem__`` results (single-element lists).
                Their image entries are replaced in place.
        Returns:
            Whatever ``collate_fn`` builds from the examples.
        """
        # NOTE(review): in feature mode patch_image/med_patch_image are None;
        # whether pad_or_cut_img_tensors accepts None is not visible here - verify.
        for sample in samples:
            example = sample[0]
            example["patch_image"] = pad_or_cut_img_tensors(
                example["patch_image"],
                self.patch_image_size,
                self.num_images_per_sample
            )
            example["med_patch_image"] = pad_or_cut_img_tensors(
                example["med_patch_image"],
                self.med_patch_image_size,
                self.num_images_per_sample
            )

        samples_v1 = [sample_tuple[0] for sample_tuple in samples]  # image-text pairs
        return collate_fn(
            samples_v1,
            pad_idx=self.tokenizer.pad_token_id,
            eos_idx=self.tokenizer.eos_token_id,
        )

    def __str__(self):
        return f"type: {type(self)}, length: {len(self)}"

    def set_epoch(self, epoch, **unused):
        self.epoch = epoch