import re, os
import pickle
import json
import contextlib
from pathlib import Path
import pandas as pd
from ast import literal_eval
from nltk.tokenize import RegexpTokenizer
from tqdm import tqdm
import numpy as np

import torch
from PIL import ImageFile
from torchvision import transforms

from monai.transforms import (
    EnsureChannelFirstd,
    Compose,
    Resized,
    LoadImaged,
    Orientationd,
    RandSpatialCropd,
    ScaleIntensityRanged,
    ToTensord
)

from .input_dataset import collate_fn, pad_or_cut_img_tensors

class CustomDataset(torch.utils.data.Dataset):
    """Image/report dataset backed by a tab-separated index file.

    Each row of ``args.dataset_path`` is expected to provide (at least) the
    columns ``Split``, ``Report`` and ``Image``. Images are either loaded as
    volumes through MONAI transforms, or, when ``args.feature_path`` is set,
    replaced by precomputed ``.npy`` feature arrays.
    """

    def __init__(self, args,
        transform=None,
        data_pct=1.0,
        max_words=400,
    ):
        """Read and filter the dataset index and build the image transforms.

        Args:
            args: namespace-like config. Attributes read here:
                ``dataset_path``, ``tokenizer``, ``feature_path``, ``split``,
                ``patch_image_size``, ``instruction``, ``vision_encode_mode``
                and ``dummy``.
            transform: stored as ``self.transform`` (forced to ``None`` when
                ``args.feature_path`` is set); not used by the methods of
                this class.
            data_pct: fraction of rows to keep, sampled with a fixed seed;
                only applied to the ``"train"`` split.
            max_words: maximum number of caption tokens produced by the
                tokenizer (BOS/EOS are added on top in ``__getitem__``).

        Raises:
            RuntimeError: if ``args.dataset_path`` does not exist.
        """
        super().__init__()
        if not os.path.exists(args.dataset_path):
            raise RuntimeError(f"{args.dataset_path} does not exist!")

        self.data_pct = data_pct
        self.transform = transform
        self.tokenizer = args.tokenizer
        self.max_words = max_words
        self.dataset_path = args.dataset_path
        self.feature_path = args.feature_path
        if self.feature_path is not None:
            # Precomputed features bypass image loading entirely.
            self.transform = None

        assert args.split in ["train", "validate", "test"]
        # Passed to ``pad_or_cut_img_tensors`` in ``collate``; presumably the
        # number of slices each sample is padded/cut to — confirm.
        self.num_images_per_sample = 32
        self.patch_image_size = args.patch_image_size
        self.instruction = args.instruction
        # Matches the 128x128 in-plane crop size of the medical transform below.
        self.med_patch_image_size = 128
        self.vision_encode_mode = args.vision_encode_mode
        self.split = args.split
        self.dummy = args.dummy

        self.bos_item = torch.LongTensor([args.tokenizer.bos_token_id])
        self.eos_item = torch.LongTensor([args.tokenizer.eos_token_id])
        self.bos_mask = torch.LongTensor([1])
        self.eos_mask = torch.LongTensor([1])

        # Transform for the 3D medical encoder (from xiaoxuan): load the file,
        # reorient to RAS, linearly map intensities in [-1000, 3000] to [0, 1]
        # with clipping, resize to (160, 160, 64), then randomly crop a
        # (128, 128, 32) region.
        self.med_patch_resize_transform = Compose([
            LoadImaged(keys=["image"]),
            EnsureChannelFirstd(keys=["image"]),
            Orientationd(keys=["image"], axcodes="RAS"),
            ScaleIntensityRanged(
                keys=["image"], a_min=-1000, a_max=3000, b_min=0.0, b_max=1.0, clip=True
            ),
            Resized(keys="image", spatial_size=(160, 160, 64)),
            RandSpatialCropd(keys="image", roi_size=[128, 128, 32], random_size=False),
            ToTensord(keys=["image"])
        ])

        # Resize to fit in the original Flamingo vision encoder.
        self.resize_transform = transforms.Resize(self.patch_image_size, antialias=True)

        df = pd.read_csv(args.dataset_path, sep='\t')
        df = df[df["Split"] == self.split]
        if self.data_pct != 1.0 and self.split == "train":
            df = df.sample(frac=self.data_pct, random_state=42)
        self.df = df.reset_index(drop=True)

    def get_caption(self, report, instruction):
        """Tokenize one instruction/report pair in the prompt format.

        The text ``"<image> {instruction} <answer> {report}<|endofchunk|>"`` is
        tokenized without special tokens and truncated to ``self.max_words``.

        Returns:
            tuple: the tokenizer output (``return_tensors="pt"``) and the
            number of token ids in it that are not 0.
        """
        tokens = self.tokenizer(
            f"<image> {instruction} <answer> {report}<|endofchunk|>",
            return_tensors="pt",
            truncation=True,
            add_special_tokens=False,
            max_length=self.max_words,
        )
        # Counts ids != 0 (presumably 0 is treated as padding — confirm).
        x_len = int((tokens["input_ids"][0] != 0).sum())

        return tokens, x_len

    def collate(self, samples):
        """Pad/cut image tensors and merge samples into a single mini-batch.

        Args:
            samples (List[List[dict]]): items as returned by ``__getitem__``;
                only the first example dict of each item is used. The dicts
                are modified in place (image entries are padded/cut).

        Returns:
            The result of ``collate_fn`` over the example dicts, built with
            the tokenizer's pad and EOS token ids.
        """
        # NOTE(review): when ``feature_path`` is set, ``patch_image`` and
        # ``med_patch_image`` are None here; confirm pad_or_cut_img_tensors
        # accepts None, or skip this step for feature-based samples.
        for sample in samples:
            example = sample[0]
            example["patch_image"] = pad_or_cut_img_tensors(
                example["patch_image"],
                self.patch_image_size,
                self.num_images_per_sample
            )
            example["med_patch_image"] = pad_or_cut_img_tensors(
                example["med_patch_image"],
                self.med_patch_image_size,
                self.num_images_per_sample
            )

        return collate_fn(
            [sample[0] for sample in samples],  # image-text pairs
            pad_idx=self.tokenizer.pad_token_id,
            eos_idx=self.tokenizer.eos_token_id,
        )

    def set_epoch(self, epoch, **unused):
        """Record the current epoch on the dataset (extra kwargs are ignored)."""
        self.epoch = epoch

    def __str__(self):
        return f"type: {type(self)}, length: {len(self)}"

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        """Build the example for row ``index``.

        Returns:
            list: a one-element list holding a dict with keys ``id``,
            ``source`` (BOS + caption ids + EOS), ``text_mask``,
            ``patch_image``, ``med_patch_image``, ``features``,
            ``patch_mask`` and ``conf``. When ``feature_path`` is set the
            image entries are None; otherwise ``features`` is None.

        Raises:
            NotImplementedError: if images must be loaded while
                ``vision_encode_mode == "original"``.
        """
        sample = self.df.iloc[index]
        src_text, _ = self.get_caption(sample["Report"], self.instruction)
        src_item = torch.cat([self.bos_item, src_text["input_ids"].squeeze(0), self.eos_item])
        src_item_mask = torch.cat(
            [self.bos_mask, src_text["attention_mask"].squeeze(0), self.eos_mask]
        )

        feats, patch_images, med_patch_images = None, None, None
        if self.feature_path:
            assert self.vision_encode_mode != "original"
            # The feature file name is derived from the image file name: its
            # first two "_"-separated fields + ".npy" (e.g. "a_b_x.nii.gz" ->
            # "a_b.npy"). NOTE(review): assumes features are keyed by the
            # ``Image`` column — confirm against the feature-extraction script.
            image_name = os.path.basename(sample["Image"])
            feature_file_name = "_".join(image_name.split('_')[:2]) + ".npy"
            feats = torch.from_numpy(np.load(os.path.join(self.feature_path, feature_file_name)))
        elif self.vision_encode_mode != "original":
            med_patch_images = self.med_patch_resize_transform({'image': sample["Image"]})['image']
            # Move the last axis first; presumably (C, H, W, D) -> (D, C, H, W)
            # so each depth slice becomes one 2D image — confirm.
            med_patch_images = torch.permute(med_patch_images, (3, 0, 1, 2))
            patch_images = self.resize_transform(med_patch_images)
            # Repeat along dim 1 three times (presumably 1 -> 3 channels for an RGB encoder).
            patch_images = torch.repeat_interleave(patch_images, 3, dim=1)
        else:
            # Image loading independent of the medical encoder's MONAI
            # transforms is not implemented yet.
            raise NotImplementedError

        if self.dummy:
            # Replace loaded images with zeros of the same shape; nothing to
            # replace when precomputed features were used.
            if patch_images is not None:
                patch_images = torch.zeros(patch_images.shape)
            if med_patch_images is not None:
                med_patch_images = torch.zeros(med_patch_images.shape)

        patch_mask = torch.tensor([True])
        conf = torch.tensor([1.0])
        example = {
            "id": index,
            "source": src_item,
            "text_mask": src_item_mask,
            "patch_image": patch_images,
            "med_patch_image": med_patch_images,
            "features": feats,
            "patch_mask": patch_mask,
            "conf": conf,
        }
        return [example]