
import os
import sys
import copy
import json
from dataclasses import dataclass
from typing import Dict, Sequence

import numpy as np

import torch
from einops import rearrange
from torch.nn import functional as F
from PIL.Image import Image

from torch.utils.data.dataset import Dataset
import transformers

from llava.constants import IGNORE_INDEX
from llava.train import rank0_print, DataArguments, preprocess_multimodal, preprocess, DataCollatorForSupervisedDataset
from llava.datasets.fmri_vit3d_datasets import calculate_total_mean_variance_from_std


@dataclass
class fMRIDataCollatorForSupervisedDataset(object):
    """Batch tokenized fMRI/conversation samples for supervised fine-tuning.

    ``input_ids`` are right-padded with the tokenizer's pad id and ``labels``
    with ``IGNORE_INDEX``; both are then cut to ``tokenizer.model_max_length``.
    The attention mask marks every position whose id differs from the pad id.
    Optional per-sample tensors (``image``, ``vision_embeds``) are stacked into
    one tensor when every sample has a non-None tensor of identical shape, and
    are otherwise passed through as a plain list.
    """

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        pad_id = self.tokenizer.pad_token_id
        max_len = self.tokenizer.model_max_length

        token_seqs = [sample["input_ids"] for sample in instances]
        target_seqs = [sample["labels"] for sample in instances]

        padded_ids = torch.nn.utils.rnn.pad_sequence(
            token_seqs, batch_first=True, padding_value=pad_id)[:, :max_len]
        padded_targets = torch.nn.utils.rnn.pad_sequence(
            target_seqs, batch_first=True, padding_value=IGNORE_INDEX)[:, :max_len]

        batch = {
            "input_ids": padded_ids,
            "labels": padded_targets,
            "attention_mask": padded_ids.ne(pad_id),
        }

        # Presence of an optional field is decided by the first sample only.
        for field_name in ("image", "vision_embeds"):
            if field_name not in instances[0]:
                continue
            tensors = [sample[field_name] for sample in instances]
            uniform = all(t is not None and t.shape == tensors[0].shape
                          for t in tensors)
            batch[field_name] = torch.stack(tensors) if uniform else tensors

        return batch


class fMRILazySupervisedDataset(Dataset):
    """Lazily-loaded supervised fine-tuning dataset pairing fMRI volumes with conversations.

    The JSON file at ``data_path`` must contain:

    * ``conversations``: the list of samples. Each sample has a
      ``conversations`` turn list, optionally an ``fmri`` ``.npy`` path and a
      ``vision_embeds`` ``.npy`` path.
    * ``info``: paths to the voxel-wise ``fmri_mean`` / ``fmri_std`` ``.npy``
      files and, when ``data_args.select_brain_region`` is set, an ``atlas``
      mapping from subject id to atlas JSON path.

    fMRI files are only read in ``__getitem__`` (lazy loading).
    """

    def __init__(
        self,
        data_path: str,
        tokenizer: transformers.PreTrainedTokenizer,
        data_args: DataArguments,
        simple_conversation_only: bool = False
    ):
        """Load sample metadata and normalization statistics.

        Args:
            data_path: path to the SFT JSON file described in the class docstring.
            tokenizer: tokenizer handed to ``preprocess`` for every sample.
            data_args: data configuration. ``select_brain_region`` is read here;
                ``image_processor``, ``is_multimodal`` and
                ``requires_vision_embeds`` are read in ``__getitem__``.
            simple_conversation_only: keep only samples whose
                ``conversation_type`` is ``'briefly_descriptions'``.
        """
        super().__init__()
        with open(data_path, "r") as f:
            data_dict = json.load(f)

        rank0_print("Formatting inputs...Skip in lazy mode")
        self.tokenizer = tokenizer
        self.list_data_dict = data_dict['conversations']
        self.data_args = data_args

        mean = torch.tensor(np.load(data_dict['info']['fmri_mean']))
        std = torch.tensor(np.load(data_dict['info']['fmri_std']))
        # Shape of the raw voxel statistics; reused for the zero placeholder
        # given to text-only samples in __getitem__.
        self.fmri_shape = mean.shape

        # Project helper; its outputs are what the image processor receives.
        self.mean, self.std = calculate_total_mean_variance_from_std(mean, std)

        if simple_conversation_only:
            self.list_data_dict = [
                sample for sample in self.list_data_dict
                if sample['conversation_type'] == 'briefly_descriptions'
            ]

        if data_args.select_brain_region:
            subj_paths = data_dict['info']['atlas']
            rank0_print(subj_paths)
            self.subj_mask = self._load_subject_masks(subj_paths)
        else:
            self.subj_mask = None

    @staticmethod
    def _load_subject_masks(subj_paths):
        """Return ``{subject_id: uint8 mask}`` selecting the ``nsdgeneral`` region.

        Each atlas JSON is read from the listed path with ``"atlas"`` replaced
        by ``"atlas_general"``. Note that ``str.replace`` substitutes every
        occurrence in the path. The JSON is a pair ``[volume, region_name2ids]``.
        The volume is reordered from ``z y x`` to ``x y z``, presumably to match
        the fMRI arrays (TODO confirm).
        """
        subj_mask = {}
        for subj, atlas_path in subj_paths.items():
            with open(atlas_path.replace("atlas", "atlas_general"), "r") as f:
                atlas_json = json.load(f)

            atlas_volume, region_name2ids = atlas_json[0], atlas_json[1]
            atlas = rearrange(torch.tensor(atlas_volume), 'z y x -> x y z')
            # 1 inside the nsdgeneral region, 0 elsewhere.
            subj_mask[subj] = (atlas == region_name2ids['nsdgeneral']).to(torch.uint8)
        return subj_mask

    def __len__(self):
        return len(self.list_data_dict)

    @property
    def lengths(self):
        """Approximate per-sample lengths used for length-grouped sampling.

        Each length is the whitespace-split word count of all turns, plus a fixed
        288 tokens for samples that carry an fMRI.
        """
        length_list = []
        for sample in self.list_data_dict:
            # Keyed on 'fmri' like the rest of this class (previously 'image',
            # which fMRI samples do not need to carry).
            img_tokens = 288 if 'fmri' in sample else 0   # TODO: is 288 the right token count?
            length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens)
        return length_list

    @property
    def modality_lengths(self):
        """Word counts per sample: positive for fMRI samples, negative for text-only ones."""
        length_list = []
        for sample in self.list_data_dict:
            cur_len = sum(len(conv['value'].split()) for conv in sample['conversations'])
            length_list.append(cur_len if 'fmri' in sample else -cur_len)
        return length_list

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Tokenize sample ``i`` and attach its fMRI tensor.

        Returns a dict containing:

        * ``input_ids`` / ``labels`` from ``preprocess`` (un-batched when ``i``
          is an ``int``).
        * ``image``: the processed fMRI. For text-only samples with
          ``data_args.is_multimodal``, this is instead a zero tensor of the raw
          fMRI shape, which may differ from the processor's output shape.
        * ``vision_embeds``, when the sample lists one and
          ``data_args.requires_vision_embeds`` is set.
        """
        sources = self.list_data_dict[i]
        if isinstance(i, int):
            sources = [sources]
        assert len(sources) == 1, "Don't know why it is wrapped to a list"  # FIXME

        # Use the same sample dict for every membership test below.
        sample = sources[0]
        has_fmri = 'fmri' in sample

        if has_fmri:
            processor = self.data_args.image_processor
            fmri_file = sample['fmri']

            fmri = torch.tensor(np.load(fmri_file))

            if self.subj_mask is not None:
                # Subject id is the third path component from the end.
                subj = fmri_file.split("/")[-3]
                fmri = fmri * self.subj_mask[subj]

            fmri = processor(fmri, mean=self.mean, std=self.std)

            sources = preprocess_multimodal(
                copy.deepcopy([e["conversations"] for e in sources]), self.data_args)
        else:
            sources = copy.deepcopy([e["conversations"] for e in sources])
            print("no fMRIs in ", sources[0])

        data_dict = preprocess(
            sources,
            self.tokenizer,
            has_image=has_fmri)

        if isinstance(i, int):
            data_dict = dict(input_ids=data_dict["input_ids"][0],
                             labels=data_dict["labels"][0])

        if has_fmri:
            # Stored under "image" to stay consistent with the multimodal model.
            data_dict['image'] = fmri
        elif self.data_args.is_multimodal:
            # No fMRI for this sample, but the model still expects an input.
            data_dict['image'] = torch.zeros(*self.fmri_shape)

        if 'vision_embeds' in sample and self.data_args.requires_vision_embeds:
            data_dict['vision_embeds'] = torch.tensor(np.load(sample['vision_embeds']))

        return data_dict


def make_supervised_fmri_data_module(tokenizer: transformers.PreTrainedTokenizer,
                                     data_args) -> Dict:
    """Make the fMRI train/eval datasets and collator for supervised fine-tuning.

    Args:
        tokenizer: tokenizer shared by both datasets and the collator.
        data_args: must provide ``data_path``. ``val_data_path`` is optional;
            when it is missing or empty, no eval dataset is built.

    Returns:
        A dict with ``train_dataset``, ``eval_dataset`` (``None`` without a
        validation path) and ``data_collator``. These key names mirror the
        keyword arguments of ``transformers.Trainer``.
    """
    train_dataset = fMRILazySupervisedDataset(tokenizer=tokenizer,
                                              data_path=data_args.data_path,
                                              data_args=data_args)

    # Previously an absent validation path crashed on open(None).
    val_data_path = getattr(data_args, 'val_data_path', None)
    if val_data_path:
        eval_dataset = fMRILazySupervisedDataset(tokenizer=tokenizer,
                                                 data_path=val_data_path,
                                                 data_args=data_args)
    else:
        eval_dataset = None

    data_collator = fMRIDataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset,
                eval_dataset=eval_dataset,
                data_collator=data_collator)


if __name__ == '__main__':
    # Smoke test: build one SFT split and print its first sample.
    # Usage: python <this file> [path/to/sft.json]
    # Without an argument, the original subject-01 training file is used.
    default_sft_json = '/mnt/NSD_dataset/datasets/nsd/sft_data/subj01/sft_subj01_tr.json'
    sft_json_path = sys.argv[1] if len(sys.argv) > 1 else default_sft_json

    dataset = fMRILazySupervisedDataset(
        data_path=sft_json_path,
        tokenizer=transformers.AutoTokenizer.from_pretrained('gpt2'),
        data_args=DataArguments()
    )

    for data in dataset:
        print(data)
        break
