import csv
import contextlib
from pathlib import Path
import numpy as np

import torch
from torch.utils.data import Dataset
from torchvision import transforms

# from .transforms import *
from .input_dataset import collate_fn, pad_or_cut_img_tensors

from otter.biovil_encoder.image.data.io import load_image
from otter.biovil_encoder.image.data.transforms import (
    create_chest_xray_transform_for_inference,
    create_chest_xray_augmentation_for_training
)


@contextlib.contextmanager
def numpy_seed(seed, *addl_seeds):
    """Temporarily seed NumPy's global RNG and restore its previous state on exit.

    Args:
        seed: Base seed. When ``None`` the context does nothing and the global
            RNG state is left untouched.
        *addl_seeds: Extra values combined with ``seed`` through ``hash`` and
            reduced modulo 1e6, giving a distinct stream per combination
            (e.g. ``(seed, epoch)``). Note that ``hash`` of ``str`` values is
            salted per process, so only hashable-deterministic values (such as
            ints) give reproducible seeds across runs.
    """
    if seed is not None:
        if addl_seeds:
            seed = int(hash((seed,) + addl_seeds) % 1e6)
        saved_state = np.random.get_state()
        np.random.seed(seed)
        try:
            yield
        finally:
            # Always put the caller's RNG state back, even on exceptions.
            np.random.set_state(saved_state)
    else:
        yield

class MIMICCXRDataset(Dataset):
    """Map-style dataset over the processed MIMIC-CXR ``master.csv`` file.

    Each CSV row is unpacked (see ``process_image_text_pair``) as: unique id,
    patient id, study id, comma-separated image paths (relative to the CSV's
    directory), comma-separated view labels, findings, impression, report and
    split name. Only rows whose last field equals ``args.split`` are kept.
    """

    def __init__(self, args):
        super().__init__()
        assert args.split in ["train", "validate", "test"]
        self.args = args
        self.separator = ","
        # Image tensors are padded/cut to this many images per sample in collate().
        self.num_images_per_sample = 2
        self.epoch = 0
        self.patch_image_size = args.patch_image_size
        self.med_patch_image_size = 480
        self.vision_encode_mode = args.vision_encode_mode
        # Token budget passed to the tokenizer; BOS/EOS are added afterwards.
        self.max_src_length = 350
        self.file_path = args.dataset_path
        self.tokenizer = args.tokenizer
        self.instruction = args.instruction
        # When set, image tensors are replaced by zeros (images are still loaded).
        self.dummy = args.dummy
        self.seed = args.seed
        # Random augmentation is only applied to the training split so that
        # validation/test images are deterministic.
        self.apply_augmentation = args.split == "train"

        self.bos_item = torch.LongTensor([args.tokenizer.bos_token_id])
        self.eos_item = torch.LongTensor([args.tokenizer.eos_token_id])
        self.bos_mask = torch.LongTensor([1])
        self.eos_mask = torch.LongTensor([1])

        # Transform for the medical encoder; values from biovil_encoder/image/utils.py.
        self.med_patch_resize_transform = create_chest_xray_transform_for_inference(
            resize=512, center_crop_size=self.med_patch_image_size
        )
        # Resize to fit the original Flamingo vision encoder.
        self.resize_transform = transforms.Resize(self.patch_image_size, antialias=True)
        # Augmentation values from https://arxiv.org/abs/2204.09817 Table E.1:
        # Fine-tuning for Downstream Tasks.
        self.training_image_augmentation = create_chest_xray_augmentation_for_training(
            degree=45, shear=25, brightness=0.2, contrast=0.2, flip_p=0.5
        )

        # NOTE(review): rows are read line by line, which assumes no CSV field
        # (e.g. a report) contains an embedded newline — confirm for master.csv.
        with open(self.file_path, encoding="utf-8") as f:
            self.dataset = [
                line
                for line in f
                if self._belongs_to_split(line, args.split, self.separator)
            ]

    @staticmethod
    def _belongs_to_split(line, split, separator=","):
        """Return True if the last ``separator``-delimited field of ``line`` equals ``split``.

        The whole field is compared (rather than using ``str.endswith``) so a
        row whose split field merely ends with the requested name is not kept.
        """
        last_field = line.rstrip("\n\t ").rsplit(separator, 1)[-1]
        return last_field.strip() == split

    def process_image_text_pair(self, index):
        """Build the training example for CSV row ``index``.

        Returns:
            A one-element list holding the example dict, or ``None`` when the
            row has no PA/AP image (``__getitem__`` skips such rows).
        """
        (
            uniq_id,
            _patient_id,
            _study_id,
            image_paths,
            image_labels,
            _findings,
            _impression,
            report,
            _split,
        ) = next(csv.reader([self.dataset[index].rstrip("\n")]))

        image_paths = image_paths.split(",")
        image_labels = image_labels.split(",")

        # Only use frontal views (PA and AP).
        frontal_paths = [
            path
            for path, label in zip(image_paths, image_labels)
            if label in ("PA", "AP")
        ]
        if not frontal_paths:
            return None

        # Image paths in the CSV are relative to the directory holding it.
        image_root = Path(self.file_path).parent
        patch_images, med_patch_images = [], []
        for rel_path in frontal_paths:
            image = load_image(image_root / rel_path)
            image = self.med_patch_resize_transform(image)
            if self.apply_augmentation:
                image = self.training_image_augmentation(image)
            med_patch_images.append(image)
            patch_images.append(self.resize_transform(image))

        patch_images = torch.stack(patch_images)  # (T,C,H,W)
        med_patch_images = torch.stack(med_patch_images)

        patch_mask = torch.tensor([True])
        conf = torch.tensor([1.0])

        # The report is part of the source text; no separate target item is
        # built because it is not used downstream.
        src_text = self.tokenizer(
            f"<image> {self.instruction} <answer> {report}<|endofchunk|>",
            return_tensors="pt",
            add_special_tokens=False,
            max_length=self.max_src_length,
            truncation=True
        )
        src_item = src_text["input_ids"].squeeze(0)
        src_item = torch.cat([self.bos_item, src_item, self.eos_item])
        src_item_mask = src_text["attention_mask"].squeeze(0)
        src_item_mask = torch.cat([self.bos_mask, src_item_mask, self.eos_mask])

        if self.dummy:
            patch_images = torch.zeros_like(patch_images)
            med_patch_images = torch.zeros_like(med_patch_images)

        example = {
            "id": uniq_id,
            "source": src_item,
            "text_mask": src_item_mask,
            "patch_image": patch_images,
            "med_patch_image": med_patch_images,
            "patch_mask": patch_mask,
            "conf": conf,
        }

        return [example]

    def collate(self, samples):
        """Collate a list of samples into one padded mini-batch.

        Image tensors of each sample are padded/cut in place to
        ``num_images_per_sample`` images before batching.

        Args:
            samples (List[List[dict]]): items returned by ``__getitem__``; only
                the first example of each item is used.
        Returns:
            The batch produced by ``collate_fn``.
        """
        batch_examples = []
        for sample in samples:
            example = sample[0]
            example["patch_image"] = pad_or_cut_img_tensors(
                example["patch_image"],
                self.patch_image_size,
                self.num_images_per_sample
            )
            example["med_patch_image"] = pad_or_cut_img_tensors(
                example["med_patch_image"],
                self.med_patch_image_size,
                self.num_images_per_sample
            )
            batch_examples.append(example)

        return collate_fn(
            batch_examples,
            pad_idx=self.tokenizer.pad_token_id,
            eos_idx=self.tokenizer.eos_token_id,
        )

    def __getitem__(self, index):
        """Return the example list for ``index``, skipping rows without a PA/AP image.

        Unusable rows are skipped by moving forward to the next row, wrapping
        around at the end of the dataset (iteratively, so long runs of
        unusable rows cannot exhaust the recursion limit).

        Raises:
            IndexError: if ``index`` is outside ``[-len(self), len(self))``.
            RuntimeError: if no row in the dataset yields an example.
        """
        num_rows = len(self)
        if not -num_rows <= index < num_rows:
            raise IndexError(
                f"index {index} out of range for dataset of size {num_rows}"
            )
        for offset in range(num_rows):
            candidate = (index + offset) % num_rows
            with numpy_seed(self.seed, self.epoch):
                pair_samples = self.process_image_text_pair(candidate)
            if pair_samples is not None:
                return pair_samples
        raise RuntimeError("no row in the dataset contains a PA or AP image")

    def __str__(self):
        return f"type: {type(self)}, length: {len(self)}"

    def __len__(self):
        return len(self.dataset)

    def set_epoch(self, epoch, **unused):
        """Record the current epoch; it is mixed into the per-item NumPy seed."""
        self.epoch = epoch