import numpy as np
import torch
import h5py
import monai.transforms as motrans
import torchvision.transforms as tctrans
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F

class NSDDataset(Dataset):
    """Dataset pairing fMRI volumes stored with ``torch.save`` and a stimulus tensor.

    Every entry of ``data_list`` is indexed with ``'nii'`` (an iterable of file
    paths) and ``'stim'`` (a single file path). When ``nii_trans`` is given it
    is applied to each loaded volume separately.
    """

    def __init__(self, data_list, nii_trans=None):
        self.nii_trans = nii_trans
        self.data_list = data_list

    def __len__(self):
        """Return the number of entries in ``data_list``."""
        return len(self.data_list)

    def __getitem__(self, idx):
        """Return ``(nii, stim)`` for the entry at ``idx``.

        Each volume gets a new leading axis, is cast to float and divided by
        300. The volumes are then concatenated along that leading axis:
        a single volume is repeated three times, two volumes are arranged as
        (first, second, first), and any other count is concatenated in order.
        """
        sample = self.data_list[idx]

        volumes = []
        for path in sample['nii']:
            volumes.append(torch.load(path)[None].float() / 300.)
        if self.nii_trans is not None:
            volumes = [self.nii_trans(volume) for volume in volumes]

        stim = torch.load(sample['stim'])

        count = len(volumes)
        if count == 2:
            # Re-use the first volume as the third channel.
            volumes = [volumes[0], volumes[1], volumes[0]]
        nii = torch.cat(volumes, dim=0)
        if count == 1:
            # A lone volume is tiled three times along the leading axis.
            nii = nii.repeat(3, 1, 1, 1)
        return nii, stim

class NSDDataset_npy(Dataset):
    """Dataset pairing fMRI volumes stored as ``.npy`` arrays and a stimulus tensor.

    Every entry of ``data_list`` is indexed with ``'nii'`` (an iterable of
    paths readable by ``np.load``) and ``'stim'`` (a path readable by
    ``torch.load``).
    """

    def __init__(self, data_list):
        self.data_list = data_list

    def __len__(self):
        """Return the number of entries in ``data_list``."""
        return len(self.data_list)

    def __getitem__(self, idx):
        """Return ``(nii, stim)`` for the entry at ``idx``.

        Each array is converted to a tensor, given a new leading axis, cast to
        float and divided by 300. One volume is repeated three times along the
        leading axis, two volumes become (first, second, first), and any other
        count is concatenated in order.
        """
        sample = self.data_list[idx]

        volumes = []
        for path in sample['nii']:
            array = np.load(path)
            volumes.append(torch.from_numpy(array)[None].float() / 300.)

        stim = torch.load(sample['stim'])

        count = len(volumes)
        if count == 2:
            # Re-use the first volume as the third channel.
            volumes = [volumes[0], volumes[1], volumes[0]]
        nii = torch.cat(volumes, dim=0)
        if count == 1:
            nii = nii.repeat(3, 1, 1, 1)
        return nii, stim

class NSDDataset_New(Dataset):
    """Base dataset holding the sample list plus optional padding/normalisation state.

    Subclasses implement ``__getitem__``; this class only prepares the shared
    attributes they read.

    Args:
        data_list: Sequence of sample descriptors; its length is the dataset length.
        batch_each_subj: Stored as-is for subclasses to interpret.
        norm: When true, per-subject mean/std tensors are loaded into
            ``self.mean`` / ``self.std`` (keyed ``'subj0<subj>'``).
        padding: When true, ``self.padding_list`` is computed so that every
            dimension of ``fmri_shape`` becomes a multiple of ``patch_size``.
        **kwargs: ``patch_size`` and ``fmri_shape`` are required when
            ``padding`` is true; ``subj_list``, ``nsddir``, ``space`` and
            ``func`` are required when ``norm`` is true.

    Raises:
        AssertionError: If a keyword required by ``padding`` or ``norm`` is
            missing or ``None``.
    """

    def __init__(self, data_list, batch_each_subj=False, norm=False, padding=False, *args, **kwargs):
        self.data_list = data_list
        self.batch_each_subj = batch_each_subj
        self.padding = padding
        self.norm = norm
        if self.padding:
            for key in ('patch_size', 'fmri_shape'):
                assert kwargs.get(key) is not None, f'{key} must be specified for padding'
            patch_size = kwargs['patch_size']
            self.padding_list = []
            # Dimensions are visited last-to-first, which is the order in which
            # torch.nn.functional.pad consumes its (before, after) pairs.
            for dim in reversed(kwargs['fmri_shape']):
                pad_size = (patch_size - dim % patch_size) % patch_size
                before = pad_size // 2
                # Any odd remainder goes on the trailing side.
                self.padding_list.extend([before, pad_size - before])
        if self.norm:
            for key in ('subj_list', 'nsddir', 'space', 'func'):
                assert kwargs.get(key) is not None, f'{key} must be specified for normalization'
            self.mean = {}
            self.std = {}
            # Paths are assembled without re-using the enclosing quote inside an
            # f-string, which is a SyntaxError on Python versions before 3.12.
            stats_dir = f"{kwargs['nsddir']}/nsddata_betas/mean_std"
            space = kwargs['space']
            func = kwargs['func']
            for subj in kwargs['subj_list']:
                prefix = f'{stats_dir}/subj{subj}_{space}_{func}'
                # NOTE(review): torch.load unpickles its input; only point
                # nsddir at trusted files.
                self.mean[f'subj0{subj}'] = torch.load(f'{prefix}_mean.pth')
                self.std[f'subj0{subj}'] = torch.load(f'{prefix}_std.pth')

    def __len__(self):
        """Return the number of entries in ``data_list``."""
        return len(self.data_list)

    def __getitem__(self, idx):
        """Sample access is provided by subclasses.

        Raises:
            NotImplementedError: Always; the base class defines no sample format.
        """
        raise NotImplementedError(
            f'{type(self).__name__} must implement __getitem__'
        )

class NSDDataset_New_1_triggle(NSDDataset_New):
    """Dataset returning one fMRI volume per sample, optionally grouped per subject.

    With ``batch_each_subj`` false, each ``data_list`` entry is indexed with
    ``'nii'``, ``'stim'``, ``'subj'`` and ``'subj_label'``. With it true, each
    entry is a mapping whose values are such records, and ``__getitem__``
    returns parallel tuples with one element per key.
    """

    def __init__(self, data_list, batch_each_subj=False, norm=False, padding=False, *args, **kwargs):
        super().__init__(
            data_list,
            *args,
            batch_each_subj=batch_each_subj,
            norm=norm,
            padding=padding,
            **kwargs,
        )

    def _load_volume(self, record):
        """Load ``record['nii']``, divide by 300, then normalise and pad if enabled."""
        volume = torch.load(record['nii']).float() / 300.
        if self.norm:
            stats_key = f'subj0{record["subj"]}'
            # The small constant guards against division by a zero std.
            volume = (volume - self.mean[stats_key]) / (self.std[stats_key] + 1e-5)
        if self.padding:
            volume = F.pad(volume, pad=self.padding_list, mode='constant', value=0.)
        return volume

    def __getitem__(self, idx):
        """Return ``(nii, stim, subj, subj_label)`` for the entry at ``idx``.

        Every returned volume carries a new leading axis. In per-subject mode
        all four items are tuples with one element per key of the entry.
        """
        sample = self.data_list[idx]

        if not self.batch_each_subj:
            volume = self._load_volume(sample)
            stim = torch.load(sample['stim'])
            return volume[None], stim, sample['subj'], sample['subj_label']

        volumes, stims, subjects, labels = [], [], [], []
        for key in sample.keys():
            record = sample[key]
            volumes.append(self._load_volume(record)[None])
            stims.append(torch.load(record['stim']))
            subjects.append(record['subj'])
            labels.append(record['subj_label'])
        return tuple(volumes), tuple(stims), tuple(subjects), tuple(labels)

class NSDDataset_New_all_triggle(NSDDataset_New):
    """Dataset concatenating all fMRI volumes of a sample along a new leading axis.

    Each ``data_list`` entry is indexed with ``'nii'`` (an iterable of paths),
    ``'stim'`` (a path) and, when normalisation is enabled, ``'subj'``.
    """

    def __init__(self, data_list, norm=False, padding=False, *args, **kwargs):
        super().__init__(data_list, *args, norm=norm, padding=padding, **kwargs)

    def __getitem__(self, idx):
        """Return ``(nii, stim)`` for the entry at ``idx``.

        Volumes are loaded, given a leading axis, cast to float and divided by
        300. One volume is repeated three times along the leading axis, two
        volumes become (first, second, first), and any other count is
        concatenated in order. Normalisation and padding, when enabled, are
        applied to the concatenated tensor.
        """
        sample = self.data_list[idx]

        volumes = []
        for path in sample['nii']:
            volumes.append(torch.load(path)[None].float() / 300.)

        stim = torch.load(sample['stim'])

        count = len(volumes)
        if count == 2:
            # Re-use the first volume as the third channel.
            volumes = [volumes[0], volumes[1], volumes[0]]
        nii = torch.cat(volumes, dim=0)
        if count == 1:
            nii = nii.repeat(3, 1, 1, 1)

        if self.norm:
            stats_key = f'subj0{sample["subj"]}'
            # The small constant guards against division by a zero std.
            nii = (nii - self.mean[stats_key]) / (self.std[stats_key] + 1e-5)
        if self.padding:
            nii = F.pad(nii, pad=self.padding_list, mode='constant', value=0.)
        return nii, stim

class NSDDataset_New_1_triggle_int16(Dataset):
    """Dataset returning one fMRI volume per sample exactly as stored on disk.

    Unlike the float variants, no cast or scaling is applied to the loaded
    volume. With ``batch_each_subj`` true, each ``data_list`` entry is a
    mapping of records and the results are parallel tuples.
    """

    def __init__(self, data_list, batch_each_subj=False):
        super().__init__()
        self.batch_each_subj = batch_each_subj
        self.data_list = data_list

    def __len__(self):
        """Return the number of entries in ``data_list``."""
        return len(self.data_list)

    def __getitem__(self, idx):
        """Return ``(nii, stim, subj, subj_label)`` for the entry at ``idx``.

        Every returned volume carries a new leading axis. In per-subject mode
        all four items are tuples with one element per key of the entry.
        """
        sample = self.data_list[idx]

        if not self.batch_each_subj:
            volume = torch.load(sample['nii'])
            stim = torch.load(sample['stim'])
            return volume[None], stim, sample['subj'], sample['subj_label']

        volumes, stims, subjects, labels = [], [], [], []
        for key in sample.keys():
            record = sample[key]
            volumes.append(torch.load(record['nii'])[None])
            stims.append(torch.load(record['stim']))
            subjects.append(record['subj'])
            labels.append(record['subj_label'])
        return tuple(volumes), tuple(stims), tuple(subjects), tuple(labels)

class NSDDataset_New_all_triggle_int16(Dataset):
    """Dataset stacking all fMRI volumes of a sample exactly as stored on disk.

    Each ``data_list`` entry is indexed with ``'nii'`` (an iterable of paths)
    and ``'stim'`` (a path). No cast or scaling is applied to the volumes.
    """

    def __init__(self, data_list):
        super().__init__()
        self.data_list = data_list

    def __len__(self):
        """Return the number of entries in ``data_list``."""
        return len(self.data_list)

    def __getitem__(self, idx):
        """Return ``(nii, stim)`` for the entry at ``idx``.

        Volumes are stacked along a new leading axis: one volume is repeated
        three times, two volumes become (first, second, first), and any other
        count is stacked in order.
        """
        sample = self.data_list[idx]

        volumes = []
        for path in sample['nii']:
            volumes.append(torch.load(path))
        stim = torch.load(sample['stim'])

        count = len(volumes)
        if count == 2:
            # Re-use the first volume as the third channel.
            volumes = [volumes[0], volumes[1], volumes[0]]
        nii = torch.stack(volumes)
        if count == 1:
            nii = nii.repeat(3, 1, 1, 1)
        return nii, stim



class NSDDataset_New_1_triggle_image_int16(Dataset):
    """Dataset returning one stored fMRI volume plus a stimulus looked up by index.

    Each ``data_list`` entry is indexed with ``'nii'`` (a path), ``'stim_idx'``,
    ``'subj'`` and ``'subj_label'``. ``stimulus`` is indexed with ``'stim_idx'``
    and the result is handed to ``torch.from_numpy``.
    """

    def __init__(self, data_list, stimulus=None):
        super().__init__()
        self.stimulus = stimulus
        self.data_list = data_list

    def __len__(self):
        """Return the number of entries in ``data_list``."""
        return len(self.data_list)

    def __getitem__(self, idx):
        """Return ``(nii, stim, subj, subj_label)`` for the entry at ``idx``.

        The volume is returned as loaded, with a new leading axis added.
        """
        sample = self.data_list[idx]
        volume = torch.load(sample['nii'])
        image = self.stimulus[sample['stim_idx']]
        stim = torch.from_numpy(image)
        return volume[None], stim, sample['subj'], sample['subj_label']

class NSDDataset_New_all_triggle_image_int16(Dataset):
    """Dataset stacking all stored fMRI volumes plus a stimulus looked up by index.

    Each ``data_list`` entry is indexed with ``'nii'`` (an iterable of paths)
    and ``'stim_idx'``. ``stimulus`` is indexed with ``'stim_idx'`` and the
    result is handed to ``torch.from_numpy``.
    """

    def __init__(self, data_list, stimulus=None):
        super().__init__()
        self.stimulus = stimulus
        self.data_list = data_list

    def __len__(self):
        """Return the number of entries in ``data_list``."""
        return len(self.data_list)

    def __getitem__(self, idx):
        """Return ``(nii, stim)`` for the entry at ``idx``.

        Volumes are stacked along a new leading axis: one volume is repeated
        three times, two volumes become (first, second, first), and any other
        count is stacked in order.
        """
        sample = self.data_list[idx]

        volumes = []
        for path in sample['nii']:
            volumes.append(torch.load(path))

        image = self.stimulus[sample['stim_idx']]
        stim = torch.from_numpy(image)

        count = len(volumes)
        if count == 2:
            # Re-use the first volume as the third channel.
            volumes = [volumes[0], volumes[1], volumes[0]]
        nii = torch.stack(volumes)
        if count == 1:
            nii = nii.repeat(3, 1, 1, 1)
        return nii, stim

# class NSDDataset_New_1_triggle_image_text_int16(Dataset):
#     def __init__(self, data_list, stimulus=None):
#         super().__init__()
#         self.data_list = data_list
#         self.stimulus = stimulus

#     def __len__(self):
#         return len(self.data_list)
    
#     def __getitem__(self, idx):
#         data = self.data_list[idx]
#         nii = torch.load(data['nii'])[None]#.float()/300.
#         # stim = torch.load(data['stim'])
#         stim = torch.from_numpy(self.stimulus[data['stim_idx']])
#         coco = data['cocoid']
#         return nii, stim, coco, data['subj'], data['subj_label']
import random
class NSDDataset_New_1_triggle_image_text_int16(Dataset):
    """Dataset returning a stored fMRI volume, a stimulus and one random caption.

    Each ``data_list`` entry is indexed with ``'nii'`` (a path), ``'stim_idx'``,
    ``'cocoid'``, ``'subj'`` and ``'subj_label'``. ``tokenized_captions`` is
    keyed by ``str(cocoid)`` and each value provides parallel ``'input_ids'``
    and ``'attention_mask'`` sequences. ``image_processor``, when given, is
    called on the stimulus tensor and its ``pixel_values`` are used instead.
    """

    def __init__(self, data_list, stimulus=None, tokenized_captions=None, image_processor=None):
        super().__init__()
        self.data_list = data_list
        self.stimulus = stimulus
        self.image_processor = image_processor
        self.tokenized_captions = tokenized_captions

    def __len__(self):
        """Return the number of entries in ``data_list``."""
        return len(self.data_list)

    def __getitem__(self, idx):
        """Return ``(nii, stim, input_ids, attention_mask, subj, subj_label)``.

        The volume is returned as loaded, with a new leading axis. The caption
        index is drawn uniformly with ``random.randint`` and the same index is
        used for both ``input_ids`` and ``attention_mask``.
        """
        sample = self.data_list[idx]
        nii = torch.load(sample['nii'])[None]

        stim = torch.from_numpy(self.stimulus[sample['stim_idx']])
        if self.image_processor is not None:
            processed = self.image_processor(stim, return_tensors="pt", do_rescale=False)
            stim = processed.pixel_values.squeeze(0)

        caption_key = str(sample['cocoid'])
        caption_count = len(self.tokenized_captions[caption_key]['input_ids'])
        choice = random.randint(0, caption_count - 1)
        captions = self.tokenized_captions[caption_key]
        input_ids = captions['input_ids'][choice]
        attention_mask = captions['attention_mask'][choice]
        return nii, stim, input_ids, attention_mask, sample['subj'], sample['subj_label']
