#!/usr/bin/env python3
import argparse
from distutils.util import strtobool
import logging

import kaldiio
import numpy

from espnet_transform.cmvn import CMVN
from espnet_utils.cli_readers import file_reader_helper
from espnet_utils.cli_utils import get_commandline_args
from espnet_utils.cli_utils import is_scipy_wav_style
from espnet_utils.cli_writers import file_writer_helper


def get_parser():
    """Build the command-line parser for this script.

    Positional arguments are, in order, the stats input (rspecifier or
    plain file name), the input rspecifier and the output wspecifier.

    Returns:
        argparse.ArgumentParser: the configured parser; defaults are shown
        in ``--help`` through ``ArgumentDefaultsHelpFormatter``.
    """
    # NOTE(review): boolean options are parsed with distutils' ``strtobool``;
    # distutils was removed from the standard library in Python 3.12
    # (PEP 632), so the file-level import needs a replacement there.
    parser = argparse.ArgumentParser(
        description='apply mean-variance normalization to files',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('--verbose', '-V', default=0, type=int,
                        help='Verbose option')
    parser.add_argument('--in-filetype', type=str, default='mat',
                        choices=['mat', 'hdf5', 'sound.hdf5', 'sound'],
                        help='Specify the file format for the rspecifier. '
                             '"mat" is the matrix format in kaldi')
    parser.add_argument('--stats-filetype', type=str, default='mat',
                        choices=['mat', 'hdf5', 'npy'],
                        help='Specify the file format for the input stats. '
                             '"mat" is the matrix format in kaldi')
    parser.add_argument('--out-filetype', type=str, default='mat',
                        choices=['mat', 'hdf5'],
                        help='Specify the file format for the wspecifier. '
                             '"mat" is the matrix format in kaldi')

    # Fixed: this help text used to be a copy of the --norm-vars one.
    parser.add_argument('--norm-means', type=strtobool, default=True,
                        help='Do mean normalization or not.')
    parser.add_argument('--norm-vars', type=strtobool, default=False,
                        help='Do variance normalization or not.')
    parser.add_argument('--reverse', type=strtobool, default=False,
                        help='Do reverse mode or not')
    parser.add_argument('--spk2utt', type=str,
                        help='A text file of speaker to utterance-list map. '
                             '(Don\'t give rspecifier format, such as '
                             '"ark:spk2utt")')
    parser.add_argument('--utt2spk', type=str,
                        help='A text file of utterance to speaker map. '
                             '(Don\'t give rspecifier format, such as '
                             '"ark:utt2spk")')
    parser.add_argument('--write-num-frames', type=str,
                        help='Specify wspecifier for utt2num_frames')
    parser.add_argument('--compress', type=strtobool, default=False,
                        help='Save in compressed format')
    parser.add_argument('--compression-method', type=int, default=2,
                        help='Specify the method(if mat) or '
                             'gzip-level(if hdf5)')
    parser.add_argument('stats_rspecifier_or_rxfilename',
                        help='Input stats. e.g. ark:stats.ark or stats.mat')
    parser.add_argument('rspecifier', type=str,
                        help='Read specifier id. e.g. ark:some.ark')
    parser.add_argument('wspecifier', type=str,
                        help='Write specifier id. e.g. ark:some.ark')
    return parser


def main():
    """Run the script: load stats, normalize every input item, write it out.

    Parses the command line, configures logging, loads the statistics into
    a dict, builds a ``CMVN`` transform from them and applies it to each
    ``(utt, mat)`` pair yielded by the input reader, storing the result in
    the output writer under the same key.
    """
    args = get_parser().parse_args()

    # logging info: INFO when --verbose is positive, WARN otherwise
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    log_level = logging.INFO if args.verbose > 0 else logging.WARN
    logging.basicConfig(level=log_level, format=logfmt)
    logging.info(get_commandline_args())

    stats_source = args.stats_rspecifier_or_rxfilename
    # A ':' in the argument is taken to mean an rspecifier (keyed stats);
    # anything else is read as a single stats file.
    keyed_stats = ':' in stats_source
    if keyed_stats:
        # 'npy' is read through the 'hdf5' reader in the rspecifier case.
        if args.stats_filetype == 'npy':
            reader_filetype = 'hdf5'
        else:
            reader_filetype = args.stats_filetype
        stats_dict = dict(file_reader_helper(stats_source, reader_filetype))
    elif args.stats_filetype == 'mat':
        stats_dict = {None: kaldiio.load_mat(stats_source)}
    else:
        stats_dict = {None: numpy.load(stats_source)}

    cmvn = CMVN(stats=stats_dict,
                norm_means=args.norm_means,
                norm_vars=args.norm_vars,
                utt2spk=args.utt2spk,
                spk2utt=args.spk2utt,
                reverse=args.reverse)

    with file_writer_helper(
            args.wspecifier,
            filetype=args.out_filetype,
            write_num_frames=args.write_num_frames,
            compress=args.compress,
            compression_method=args.compression_method) as writer:
        for utt, feats in file_reader_helper(args.rspecifier,
                                             args.in_filetype):
            if is_scipy_wav_style(feats):
                # Sound input arrives as a (rate, array) pair; only the
                # array part is normalized.
                _, feats = feats
            # Keyed stats are looked up by utterance id, a single stats
            # file is stored under the None key.
            stats_key = utt if keyed_stats else None
            writer[utt] = cmvn(feats, stats_key)


if __name__ == "__main__":
    main()
