import io
import logging
import sys

import h5py
import kaldiio
import soundfile

from espnet_utils.io_utils import SoundHDF5File


def file_reader_helper(rspecifier: str, filetype: str = 'mat',
                       return_shape: bool = False,
                       segments: str = None):
    """Read uttid and array in kaldi style

    This function might be a bit confusing as "ark" is used
    for HDF5 to imitate "kaldi-rspecifier".

    Args:
        rspecifier: Give as "ark:feats.ark" or "scp:feats.scp"
        filetype: "mat" is kaldi-martix, "hdf5": HDF5
        return_shape: Return the shape of the matrix,
            instead of the matrix. This can reduce IO cost for HDF5.
    Returns:
        Generator[Tuple[str, np.ndarray], None, None]:

    Examples:
        Read from kaldi-matrix ark file:

        >>> for u, array in file_reader_helper('ark:feats.ark', 'mat'):
        ...     array

        Read from HDF5 file:

        >>> for u, array in file_reader_helper('ark:feats.h5', 'hdf5'):
        ...     array

    """
    if filetype == 'mat':
        return KaldiReader(rspecifier, return_shape=return_shape,
                           segments=segments)
    elif filetype == 'hdf5':
        return HDF5Reader(rspecifier, return_shape=return_shape)
    elif filetype == 'sound.hdf5':
        return SoundHDF5Reader(rspecifier, return_shape=return_shape)
    elif filetype == 'sound':
        return SoundReader(rspecifier, return_shape=return_shape)
    else:
        raise NotImplementedError(f'filetype={filetype}')


class KaldiReader:
    def __init__(self, rspecifier, return_shape=False, segments=None):
        self.rspecifier = rspecifier
        self.return_shape = return_shape
        self.segments = segments

    def __iter__(self):
        with kaldiio.ReadHelper(
                self.rspecifier, segments=self.segments) as reader:
            for key, array in reader:
                if self.return_shape:
                    array = array.shape
                yield key, array


class HDF5Reader:
    def __init__(self, rspecifier, return_shape=False):
        if ':' not in rspecifier:
            raise ValueError('Give "rspecifier" such as "ark:some.ark: {}"'
                             .format(self.rspecifier))
        self.rspecifier = rspecifier
        self.ark_or_scp, self.filepath = self.rspecifier.split(':', 1)
        if self.ark_or_scp not in ['ark', 'scp']:
            raise ValueError(f'Must be scp or ark: {self.ark_or_scp}')

        self.return_shape = return_shape

    def __iter__(self):
        if self.ark_or_scp == 'scp':
            hdf5_dict = {}
            with open(self.filepath, 'r', encoding='utf-8') as f:
                for line in f:
                    key, value = line.rstrip().split(None, 1)

                    if ':' not in value:
                        raise RuntimeError(
                            'scp file for hdf5 should be like: '
                            '"uttid filepath.h5:key": {}({})'
                            .format(line, self.filepath))
                    path, h5_key = value.split(':', 1)

                    hdf5_file = hdf5_dict.get(path)
                    if hdf5_file is None:
                        try:
                            hdf5_file = h5py.File(path, 'r')
                        except Exception:
                            logging.error(
                                'Error when loading {}'.format(path))
                            raise
                        hdf5_dict[path] = hdf5_file

                    try:
                        data = hdf5_file[h5_key]
                    except Exception:
                        logging.error('Error when loading {} with key={}'
                                      .format(path, h5_key))
                        raise

                    if self.return_shape:
                        yield key, data.shape
                    else:
                        yield key, data[()]

            # Closing all files
            for k in hdf5_dict:
                try:
                    hdf5_dict[k].close()
                except Exception:
                    pass

        else:
            if self.filepath == '-':
                # Required h5py>=2.9
                filepath = io.BytesIO(sys.stdin.buffer.read())
            else:
                filepath = self.filepath
            with h5py.File(filepath, 'r') as f:
                for key in f:
                    if self.return_shape:
                        yield key, f[key].shape
                    else:
                        yield key, f[key][()]


class SoundHDF5Reader:
    def __init__(self, rspecifier, return_shape=False):
        if ':' not in rspecifier:
            raise ValueError('Give "rspecifier" such as "ark:some.ark: {}"'
                             .format(rspecifier))
        self.ark_or_scp, self.filepath = rspecifier.split(':', 1)
        if self.ark_or_scp not in ['ark', 'scp']:
            raise ValueError(f'Must be scp or ark: {self.ark_or_scp}')
        self.return_shape = return_shape

    def __iter__(self):
        if self.ark_or_scp == 'scp':
            hdf5_dict = {}
            with open(self.filepath, 'r', encoding='utf-8') as f:
                for line in f:
                    key, value = line.rstrip().split(None, 1)

                    if ':' not in value:
                        raise RuntimeError(
                            'scp file for hdf5 should be like: '
                            '"uttid filepath.h5:key": {}({})'
                            .format(line, self.filepath))
                    path, h5_key = value.split(':', 1)

                    hdf5_file = hdf5_dict.get(path)
                    if hdf5_file is None:
                        try:
                            hdf5_file = SoundHDF5File(path, 'r')
                        except Exception:
                            logging.error(
                                'Error when loading {}'.format(path))
                            raise
                        hdf5_dict[path] = hdf5_file

                    try:
                        data = hdf5_file[h5_key]
                    except Exception:
                        logging.error('Error when loading {} with key={}'
                                      .format(path, h5_key))
                        raise

                    # Change Tuple[ndarray, int] -> Tuple[int, ndarray]
                    # (soundfile style -> scipy style)
                    array, rate = data
                    if self.return_shape:
                        array = array.shape
                    yield key, (rate, array)

            # Closing all files
            for k in hdf5_dict:
                try:
                    hdf5_dict[k].close()
                except Exception:
                    pass

        else:
            if self.filepath == '-':
                # Required h5py>=2.9
                filepath = io.BytesIO(sys.stdin.buffer.read())
            else:
                filepath = self.filepath
            for key, (a, r) in SoundHDF5File(filepath, 'r').items():
                if self.return_shape:
                    a = a.shape
                yield key, (r, a)


class SoundReader:
    def __init__(self, rspecifier, return_shape=False):
        if ':' not in rspecifier:
            raise ValueError('Give "rspecifier" such as "scp:some.scp: {}"'
                             .format(rspecifier))
        self.ark_or_scp, self.filepath = rspecifier.split(':', 1)
        if self.ark_or_scp != 'scp':
            raise ValueError('Only supporting "scp" for sound file: {}'
                             .format(self.ark_or_scp))
        self.return_shape = return_shape

    def __iter__(self):
        with open(self.filepath, 'r', encoding='utf-8') as f:
            for line in f:
                key, sound_file_path = line.rstrip().split(None, 1)
                # Assume PCM16
                array, rate = soundfile.read(sound_file_path, dtype='int16')
                # Change Tuple[ndarray, int] -> Tuple[int, ndarray]
                # (soundfile style -> scipy style)
                if self.return_shape:
                    array = array.shape
                yield key, (rate, array)
