"""
Copyright (c) Facebook, Inc. and its affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

import pathlib
import random
import xmltodict

import numpy as np
import h5py
from torch.utils.data import Dataset


class SliceData(Dataset):
    """
    A PyTorch Dataset that provides access to MR image slices.

    Every example is one slice of one HDF5 volume found directly under ``root``.
    Volumes can be filtered by fat-suppression contrast and sub-sampled.
    """

    def __init__(self, root, transform, challenge, sample_rate=1, contrast_type='both'):
        """
        Args:
            root (pathlib.Path): Path to the dataset. Every entry of this directory
                is opened as an HDF5 file.
            transform (callable): A callable object that pre-processes the raw data into
                appropriate form. It is called as
                ``transform(kspace, mask, target, attrs, fname, slice)``. 'mask' is None
                when the file has no 'mask' dataset and 'target' is None when the file
                has no reconstruction for the chosen challenge (e.g. test data).
            challenge (str): "singlecoil" or "multicoil" depending on which challenge to use.
            sample_rate (float, optional): A float between 0 and 1. This controls what fraction
                of the (contrast-filtered) volumes should be loaded. Values >= 1 keep
                every volume.
            contrast_type (str, optional): 'fs' keeps only volumes whose contrast flag
                (see `export_attrs`) starts with 'fs', 'nfs' keeps only those whose
                flag starts with 'non', and 'both' keeps every volume.

        Raises:
            ValueError: If `contrast_type` or `challenge` is not one of the accepted values.
        """
        # Explicit raise instead of `assert`: asserts vanish under `python -O`.
        if contrast_type not in ('fs', 'nfs', 'both'):
            raise ValueError('contrast_type should be one of "fs", "nfs" or "both"')
        if challenge not in ('singlecoil', 'multicoil'):
            raise ValueError('challenge should be either "singlecoil" or "multicoil"')

        self.transform = transform
        self.recons_key = 'reconstruction_esc' if challenge == 'singlecoil' \
            else 'reconstruction_rss'

        self.examples = []
        files = list(pathlib.Path(root).iterdir())
        total_slices = 0

        # First pass: count all slices and keep the volumes with the requested contrast.
        filtered_files = []
        for fname in sorted(files):
            # The context manager closes the HDF5 handle; only plain Python
            # values are carried out of the `with` block.
            with h5py.File(fname, 'r') as data:
                total_slices += data['kspace'].shape[0]
                contrast, _ = export_attrs(data['ismrmrd_header'], data.attrs['acquisition'])
            if contrast.startswith('fs'):
                contrast = 'fs'
            elif contrast.startswith('non'):
                contrast = 'nfs'
            if contrast == contrast_type or contrast_type == 'both':
                filtered_files.append(fname)
        files = filtered_files

        if sample_rate < 1:
            random.shuffle(files)
            num_files = round(len(files) * sample_rate)
            files = files[:num_files]

        # Second pass: register one example per slice of every selected volume.
        for fname in sorted(files):
            with h5py.File(fname, 'r') as data:
                # Compute the size of zero padding in k-space
                # We really should have stored this as an attribute in the hdf5 file
                padding_left, padding_right = self._compute_padding(
                    data['ismrmrd_header'][()])
                num_slices = data['kspace'].shape[0]
            self.examples += [(fname, slice_idx, padding_left, padding_right)
                              for slice_idx in range(num_slices)]

        print('Slices added: ', len(self.examples), '/', total_slices)

    @staticmethod
    def _compute_padding(xml_header):
        """
        Derive the k-space zero-padding bounds from a raw ISMRMRD XML header.

        Any failure (missing `ismrmrd` package, malformed header) propagates
        to the caller.

        Args:
            xml_header: The header document as read from the 'ismrmrd_header' dataset.

        Returns:
            tuple: (padding_left, padding_right)
        """
        # Imported lazily so the module can be imported without `ismrmrd` installed.
        import ismrmrd
        hdr = ismrmrd.xsd.CreateFromDocument(xml_header)
        enc = hdr.encoding[0]
        step_1 = enc.encodingLimits.kspace_encoding_step_1
        padding_left = enc.encodedSpace.matrixSize.y // 2 - step_1.center
        padding_right = padding_left + step_1.maximum + 1
        return padding_left, padding_right

    def __len__(self):
        """Return the number of slices registered in the dataset."""
        return len(self.examples)

    def __getitem__(self, i):
        """
        Load example `i` from disk and return the result of applying the transform.

        Args:
            i (int): Index into the list of registered slices.

        Returns:
            Whatever ``self.transform`` returns for this slice.
        """
        fname, slice_idx, padding_left, padding_right = self.examples[i]
        with h5py.File(fname, 'r') as data:
            kspace = data['kspace'][slice_idx]
            mask = np.asarray(data['mask']) if 'mask' in data else None
            target = data[self.recons_key][slice_idx] if self.recons_key in data else None
            # Copy the file attributes so the padding bounds can be added to them.
            attrs = dict(data.attrs)
            attrs['padding_left'] = padding_left
            attrs['padding_right'] = padding_right
            return self.transform(kspace, mask, target, attrs, fname.name, slice_idx)

    def set_augment_strength(self, p):
        """Forward the augmentation strength `p` to the transform."""
        self.transform.set_augment_strength(p)
        

# Code from https://github.com/pputzky/irim_fastMRI
def export_attrs(ismrmrd_header, acquisition):
    """
    Flatten an ISMRMRD XML header and classify the scan by contrast and scanner.

    Args:
        ismrmrd_header: Object whose ``[()]`` item is the UTF-8 encoded XML header
            (e.g. the 'ismrmrd_header' dataset of an HDF5 file).
        acquisition (str): Acquisition name. 'CORPDFS_FBKREPEAT' is treated as
            'CORPDFS_FBK'.

    Returns:
        tuple: ``(flag, features)`` where `flag` is a string such as
        'fs_3T_Skyra' or 'non_fs_1_5T_Aera' and `features` is an 8-element
        one-hot list of ints identifying the same class.

    Raises:
        ValueError: If the acquisition / system-model combination is not one of
            the eight known classes.
        KeyError: If the acquisition is known but the header carries no
            acquisitionSystemInformation/systemModel entry.
    """
    xml_header = ismrmrd_header[()].decode('UTF-8')
    dict_header = xmltodict.parse(xml_header)
    useful_info = ('studyInformation', 'measurementInformation', 'acquisitionSystemInformation',
                   'experimentalConditions', 'encoding', 'sequenceParameters', 'userParameters')
    # Repeated elements of which only the first entry is recorded.
    first_entry_only = (('acquisitionSystemInformation', 'coilLabel'),
                        ('userParameters', 'userParameterDouble'))

    flat_header = {}
    for root in dict_header.values():
        for section, fields in root.items():
            if section not in useful_info:
                continue
            for name, node in fields.items():
                if section != 'encoding':
                    if (section, name) in first_entry_only:
                        # xmltodict yields a list for repeated elements but a
                        # plain mapping when the element occurs exactly once.
                        first = node[0] if isinstance(node, list) else node
                        for sub_name, sub_value in first.items():
                            flat_header[section + '_' + name + '_' + sub_name] = sub_value
                    else:
                        flat_header[section + '_' + name] = node
                elif name == 'trajectory':
                    flat_header[name] = node
                elif name == 'parallelImaging':
                    for pi_name, pi_node in node.items():
                        if pi_name == 'calibrationMode':
                            flat_header[name + '_' + pi_name] = pi_node
                        else:
                            for leaf_name, leaf_value in pi_node.items():
                                flat_header[name + '_' + pi_name + '_' + leaf_name] = leaf_value
                else:
                    # Remaining encoding entries are two levels deep.
                    for group_name, group_node in node.items():
                        for leaf_name, leaf_value in group_node.items():
                            flat_header[name + '_' + group_name + '_' + leaf_name] = leaf_value

    if acquisition == 'CORPDFS_FBKREPEAT':
        acquisition = 'CORPDFS_FBK'

    # (acquisition, system model, flag); the row index is the one-hot position.
    known_classes = (
        ('CORPDFS_FBK', 'Aera', 'fs_1_5T_Aera'),
        ('CORPD_FBK', 'Aera', 'non_fs_1_5T_Aera'),
        ('CORPDFS_FBK', 'Biograph_mMR', 'fs_3T_Biograph'),
        ('CORPD_FBK', 'Biograph_mMR', 'non_fs_3T_Biograph'),
        ('CORPDFS_FBK', 'Prisma_fit', 'fs_3T_Prisma'),
        ('CORPD_FBK', 'Prisma_fit', 'non_fs_3T_Prisma'),
        ('CORPDFS_FBK', 'Skyra', 'fs_3T_Skyra'),
        ('CORPD_FBK', 'Skyra', 'non_fs_3T_Skyra'),
    )

    if not any(acquisition == known_acq for known_acq, _, _ in known_classes):
        raise ValueError('unsupported acquisition: {!r}'.format(acquisition))

    # Looked up only for a known acquisition, so a missing entry surfaces as KeyError.
    system_model = flat_header['acquisitionSystemInformation_systemModel']
    for index, (known_acq, known_model, flag) in enumerate(known_classes):
        if acquisition == known_acq and system_model == known_model:
            features = [0] * len(known_classes)
            features[index] = 1
            return flag, features

    # Previously this fell through to an UnboundLocalError on `flag`.
    raise ValueError('unsupported system model {!r} for acquisition {!r}'.format(
        system_model, acquisition))
