# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import json
import random
import re
import tarfile
from subprocess import PIPE, Popen
from urllib.parse import urlparse

import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
from torch.nn.utils.rnn import pad_sequence
import numpy as np
import torch.nn.functional as F
# Import-time side effect: sets the buffer size used by torchaudio's sox
# utilities for every user of this module. The reason for the value 16500 is
# not recorded here — TODO confirm.
torchaudio.utils.sox_utils.set_buffer_size(16500)

# File extensions treated as audio/feature payloads ('npy' and 'pth' hold
# pre-computed tensors rather than encoded audio).
AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma', 'npy', 'pth'}


def url_opener(data):
    """ Give url or local file, return file descriptor
        Inplace operation.

        Each sample dict gets a binary ``stream`` (and, for remote URLs, the
        ``process`` that produces it) added to it. Samples whose source cannot
        be opened are logged and skipped.

        * no scheme     -> opened as a local path
        * ``file://``   -> the URL's path component is opened locally
        * anything else -> streamed through ``wget -q -O - <url>``

        Args:
            data(Iterable[dict]): samples carrying a ``src`` url or local path

        Returns:
            Iterable[{src, stream[, process]}]
    """
    for sample in data:
        assert 'src' in sample
        # TODO(Binbin Zhang): support HTTP
        url = sample['src']
        try:
            pr = urlparse(url)
            if pr.scheme == '':
                # Plain local path.
                stream = open(url, 'rb')
            elif pr.scheme == 'file':
                # 'file:///a/b' must be opened as '/a/b'; open() does not
                # understand URLs.
                stream = open(pr.path, 'rb')
            else:
                # Network file, such as HTTP(HDFS/OSS/S3)/HTTPS/SCP.
                # Pass an argv list (no shell) so the URL is never parsed by
                # a shell — avoids injection through crafted URLs.
                process = Popen(['wget', '-q', '-O', '-', url], stdout=PIPE)
                sample.update(process=process)
                stream = process.stdout
        except Exception as ex:
            logging.warning('Failed to open %s: %s', url, ex)
            continue
        sample.update(stream=stream)
        # Yield outside the try so errors raised downstream are not
        # misreported as "Failed to open".
        yield sample

def _parse_tar_member(postfix, file_obj, example):
    """Decode one tar member into `example` according to its extension."""
    if postfix == 'txt':
        example['txt'] = file_obj.read().decode('utf8').strip()
    elif postfix == 'json':
        meta = json.load(file_obj)
        example['style'] = meta["style_token"]
        example['style_attention'] = meta["attention_mask"]
    elif postfix == 'pth':
        # Durations. NOTE(review): torch.load unpickles arbitrary objects;
        # only feed shards from a trusted source.
        example['dur'] = torch.load(file_obj)
    elif postfix == 'npy':
        # Acoustic tokens. Despite the extension these members are read
        # with torch.load, not numpy.
        example['wav'] = torch.load(file_obj)
    else:
        example[postfix] = file_obj.read()


def _group_tar_members(tar):
    """Yield one example dict per run of consecutive members sharing a prefix."""
    prev_prefix = None
    example = {}
    valid = True
    for tarinfo in tar:
        # Only regular files carry a payload; directory entries etc. would
        # make extractfile() return None.
        if not tarinfo.isfile():
            continue
        name = tarinfo.name
        pos = name.rfind('.')
        if pos <= 0:
            logging.warning('skip tar member without extension: %s', name)
            continue
        prefix, postfix = name[:pos], name[pos + 1:]
        if prev_prefix is not None and prefix != prev_prefix:
            # A new prefix starts: flush the finished group.
            if valid:
                example['key'] = prev_prefix
                yield example
            example = {}
            valid = True
        # Pass the TarInfo itself: extractfile(name) would do an O(n) member
        # lookup for every entry.
        with tar.extractfile(tarinfo) as file_obj:
            try:
                _parse_tar_member(postfix, file_obj, example)
            except Exception as ex:
                valid = False
                logging.warning('error to parse %s: %s', name, ex)
        prev_prefix = prefix
    # The trailing group must obey the same validity rule as the others.
    if prev_prefix is not None and valid:
        example['key'] = prev_prefix
        yield example


def tar_file_and_group(data):
    """ Expand tar shards into a stream of examples grouped by file prefix.

        Members sharing the same name prefix (everything before the last '.')
        are merged into one dict, keyed by their extension:
          * ``.txt``  -> ``txt`` (utf-8 text, stripped)
          * ``.json`` -> ``style`` / ``style_attention`` taken from its
                         ``style_token`` / ``attention_mask`` fields
          * ``.pth``  -> ``dur``  (via ``torch.load``)
          * ``.npy``  -> ``wav``  (via ``torch.load``)
          * other     -> raw bytes under the extension name
        A group in which any member fails to parse is dropped.

        NOTE: the tar is reopened from ``sample['src']`` as a local path
        (random-access mode) instead of being read from ``sample['stream']``,
        presumably because ``torch.load`` needs seekable file objects, which
        streaming ``r|*`` mode does not give. Remote URLs therefore do not
        work here. Unreadable shards are logged and skipped.

        Args:
            data: Iterable[{src, stream}]

        Returns:
            Iterable[{key, txt, style, style_attention, dur, wav, ...}]
    """
    for sample in data:
        assert 'stream' in sample
        try:
            with tarfile.open(sample['src']) as tar:
                yield from _group_tar_members(tar)
        except (tarfile.TarError, OSError) as ex:
            logging.warning('Failed to read tar %s: %s', sample['src'], ex)
        finally:
            # Always release the producer, even on error or early stop.
            if 'process' in sample:
                sample['process'].communicate()
            sample['stream'].close()

def load_raw(data, num_quant=8, sil_frames=15, min_frames=75, max_frames=1500):
    """ Turn grouped tar examples into training targets.

        For each example the first ``num_quant`` rows of ``wav`` are kept as
        acoustic tokens. Slicing and concatenation happen along dim 1, so this
        assumes ``wav`` is ``[quantizer_levels, frames]`` — TODO confirm.
        ``txt`` is parsed into a float tensor of semantic ids, and ``dur``
        holds one frame count per semantic token.

        To make the model more robust, the leading silence is normalised to
        exactly ``sil_frames`` frames:
          * if the first semantic token is 1 (silence), its duration is padded
            (by repeating the first acoustic frame) or trimmed (by dropping
            leading frames) to ``sil_frames``;
          * otherwise, if the last semantic token is 345, a silence token of
            ``sil_frames`` frames is prepended, built from copies of the last
            acoustic frame.

        An example is dropped when:
          * a required field is missing;
          * the text or the durations are empty;
          * the durations do not sum to the acoustic length;
          * the duration count differs from the semantic-token count;
          * the frame count is not strictly between ``min_frames`` and
            ``max_frames``.

        Args:
            data: Iterable[{txt, dur, wav, style, style_attention, ...}]
            num_quant: number of leading acoustic rows to keep.
            sil_frames: target length of the leading silence, in frames.
            min_frames: exclusive lower bound on acoustic frames.
            max_frames: exclusive upper bound on acoustic frames.

        Returns:
            Iterable[{target_semantics, target_durations, target_acoustics,
                      prompt_style, prompt_style_attention}]
    """
    sil_token = 1  # silence unit, per the original author's comment
    # NOTE(review): 345 is treated as "ends with a pause-like unit", so the
    # last frame is reused as silence — meaning not recorded, TODO confirm.
    tail_token = 345
    required = ('dur', 'wav', 'txt', 'style', 'style_attention')

    for sample in data:
        missing = [field for field in required if field not in sample]
        if missing:
            # A group can lack members (e.g. no .pth in the shard); skip it
            # instead of crashing the whole pipeline with KeyError.
            logging.warning('skip sample %s: missing %s', sample.get('key'), missing)
            continue

        # Clone so the in-place edit below doesn't mutate the caller's tensor.
        durations = torch.as_tensor(sample['dur']).clone()
        acoustic = sample['wav'][:num_quant, ...]
        semantic = torch.Tensor([int(item) for item in sample['txt'].split()])
        if semantic.numel() == 0 or durations.numel() == 0:
            continue  # nothing to align; indexing [0] below would raise
        style = sample['style'][0]
        style_attention = sample['style_attention'][0]

        if int(semantic[0]) == sil_token:
            lead = int(durations[0])
            if lead < sil_frames:
                # Too little leading silence: repeat the first frame.
                pad = sil_frames - lead
                durations[0] = sil_frames
                filler = acoustic[:, 0].repeat(pad, 1).transpose(0, 1)
                acoustic = torch.cat([filler, acoustic], dim=1)
            elif lead > sil_frames:
                # Too much leading silence: drop the excess frames.
                acoustic = acoustic[:, lead - sil_frames:]
                durations[0] = sil_frames
        elif int(semantic[-1]) == tail_token:
            # No leading silence: synthesise one from the last frame.
            durations = torch.cat([torch.tensor([sil_frames]), durations], dim=0)
            filler = acoustic[:, -1].repeat(sil_frames, 1).transpose(0, 1)
            acoustic = torch.cat([filler, acoustic], dim=1)
            semantic = torch.cat([torch.ones(1), semantic], dim=0)

        num_frames = acoustic.shape[1]
        if (num_frames == int(torch.sum(durations))
                and min_frames < num_frames < max_frames
                and durations.shape[0] == semantic.shape[0]):
            yield dict(
                target_semantics=semantic,
                target_durations=durations,
                target_acoustics=acoustic,
                prompt_style=style,
                prompt_style_attention=style_attention,
            )

def shuffle(data, shuffle_size=1500):
    """ Approximate shuffling through a fixed-size in-memory pool.

        Items are collected until the pool holds ``shuffle_size`` of them;
        the pool is then shuffled in place, emitted, and emptied. Whatever is
        left when ``data`` runs out is shuffled and emitted the same way.

        Args:
            data: Iterable of samples (any objects)
            shuffle_size: number of samples held before each shuffle

        Returns:
            Iterable of the same samples, locally reordered
    """
    pool = []
    for item in data:
        pool.append(item)
        if len(pool) < shuffle_size:
            continue
        random.shuffle(pool)
        yield from pool
        pool = []
    # Flush the remainder once the source is exhausted.
    random.shuffle(pool)
    yield from pool

def sort(data, sort_size=500):
    """ Sort samples by acoustic length inside fixed-size windows.

        Samples are buffered ``sort_size`` at a time and each window is
        emitted in ascending order of ``target_acoustics.size(1)`` (a stable
        sort, so equal lengths keep their arrival order). Meant to run after
        ``shuffle`` and before batching so that similarly sized utterances
        end up together; ``sort_size`` should be smaller than
        ``shuffle_size``.

        Args:
            data: Iterable[dict] with a ``target_acoustics`` tensor
            sort_size: window size for sorting

        Returns:
            Iterable[dict], the same samples reordered within each window
    """
    def _num_frames(item):
        return item['target_acoustics'].size(1)

    window = []
    for item in data:
        window.append(item)
        if len(window) >= sort_size:
            yield from sorted(window, key=_num_frames)
            window = []
    # Final, possibly short, window.
    yield from sorted(window, key=_num_frames)

def static_batch(data, batch_size=16):
    """ Chop the sample stream into batches of ``batch_size`` items.

        The final batch may be shorter; it is emitted only if non-empty.

        Args:
            data: Iterable of samples
            batch_size: number of samples per batch

        Returns:
            Iterable[List[sample]]
    """
    pending = []
    for item in data:
        pending.append(item)
        if len(pending) < batch_size:
            continue
        yield pending
        pending = []
    if pending:
        yield pending


def dynamic_batch(data, max_frames_in_batch=3500):
    """ Greedily pack samples into batches bounded by a frame budget.

        Each sample costs
        ``acoustic frames + #semantic tokens + #durations + 64``,
        summed over the actual (unpadded) lengths. Samples are appended
        until the next one would push the batch total above
        ``max_frames_in_batch``; then the batch is emitted. A single sample
        that alone exceeds the budget becomes its own batch. Empty batches are
        never emitted, because they would break downstream padding.

        Args:
            data: Iterable[dict] with ``target_acoustics`` (frames on dim 1),
                ``target_semantics`` and ``target_durations`` tensors
            max_frames_in_batch: cost budget per batch

        Returns:
            Iterable[List[dict]]
    """
    # Fixed per-sample allowance for the prompt style; presumably the style
    # prompt is always 64 tokens long — TODO confirm.
    style_frames = 64

    buf = []
    frames_in_batch = 0
    for sample in data:
        assert 'target_acoustics' in sample
        assert isinstance(sample['target_acoustics'], torch.Tensor)

        # Cost of this sample at its real (unpadded) size.
        sample_frames = (sample['target_acoustics'].size(1)
                         + sample['target_semantics'].size(0)
                         + sample['target_durations'].size(0)
                         + style_frames)

        # Only flush a non-empty batch: an oversized first sample used to
        # produce a spurious empty batch here.
        if buf and frames_in_batch + sample_frames > max_frames_in_batch:
            yield buf
            buf = []
            frames_in_batch = 0
        buf.append(sample)
        frames_in_batch += sample_frames
    if buf:
        yield buf


def dynamic_batch_ori(data, max_frames_in_batch=4000):
    """ Greedily pack samples into batches bounded by a frame budget.

        Older variant of ``dynamic_batch``: each sample costs 4x its acoustic
        frame count (``target_acoustics.size(1)``). Samples are appended
        until the next one would push the total above
        ``max_frames_in_batch``. A sample that alone exceeds the budget forms
        its own batch, and empty batches are never emitted.

        Args:
            data: Iterable[dict] with a ``target_acoustics`` tensor
            max_frames_in_batch: cost budget per batch

        Returns:
            Iterable[List[dict]]
    """
    # Per-frame multiplier carried over from the original heuristic; its
    # rationale is not recorded here.
    cost_factor = 4

    buf = []
    frames_in_batch = 0
    for sample in data:
        assert 'target_acoustics' in sample
        assert isinstance(sample['target_acoustics'], torch.Tensor)
        sample_frames = cost_factor * sample['target_acoustics'].size(1)
        # Never flush an empty batch (an oversized first sample used to
        # yield []).
        if buf and frames_in_batch + sample_frames > max_frames_in_batch:
            yield buf
            buf = []
            frames_in_batch = 0
        buf.append(sample)
        frames_in_batch += sample_frames
    if buf:
        yield buf


def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000):
    """ Wrapper for static/dynamic batch.

        Args:
            data: Iterable of samples
            batch_type: 'static' (fixed ``batch_size``) or 'dynamic'
                (frame budget ``max_frames_in_batch``)
            batch_size: samples per batch for 'static'
            max_frames_in_batch: frame budget for 'dynamic'

        Returns:
            Iterable[List[sample]]

        Raises:
            ValueError: if ``batch_type`` is not supported. Previously this
                only logged and returned None, which failed later with an
                unrelated TypeError at the consumer.
    """
    if batch_type == 'static':
        return static_batch(data, batch_size)
    if batch_type == 'dynamic':
        return dynamic_batch(data, max_frames_in_batch)
    logging.fatal('Unsupported batch type {}'.format(batch_type))
    raise ValueError('Unsupported batch type {}'.format(batch_type))


def pad_2D(inputs, PAD):
    """Right-pad tensors along their last axis and stack them.

    Every tensor in ``inputs`` is padded on the right of its last dimension,
    up to the longest last-dimension length in ``inputs``, using the constant
    ``PAD``. The results are stacked with ``np.stack``, so the return value is
    a numpy array (callers wrap it with ``torch.from_numpy``). Works for any
    rank, not only 2-D; leading dimensions must already agree.
    """
    target_len = max(np.shape(item)[-1] for item in inputs)
    padded = []
    for item in inputs:
        gap = target_len - item.shape[-1]
        padded.append(F.pad(item, (0, gap), mode="constant", value=PAD))
    return np.stack(padded)

def padding(data):
    """ Collate each batch of ``load_raw`` samples into padded tensors.

        Within a batch, samples are reordered by descending
        ``target_acoustics`` length. The variable-length fields are then
        right-padded along their last axis:
          * ``target_semantics`` with 503 (semantic "end" id);
          * ``target_durations`` with 0;
          * ``target_acoustics`` with 1024.
        Empty batches are skipped, because padding needs at least one sample.

        Args:
            data: Iterable[List[dict]], each dict with ``target_acoustics``,
                ``target_semantics``, ``target_durations``, ``prompt_style``
                and ``prompt_style_attention``

        Returns:
            Iterable[dict] with keys ``target_acoustics``,
            ``target_semantics``, ``target_durations``, ``prompt_style``,
            ``prompt_style_attention``, ``target_semantics_lengths`` and
            ``target_acoustics_lengths`` (int32 lengths before padding)
    """
    semantic_token_nums = 500
    # Ids above the semantic vocabulary are reserved specials:
    # 500/501 prompt-style start/end, 502/503 target-semantic start/end.
    target_semantic_end_id = semantic_token_nums + 3
    acoustic_token_nums = 1024
    # Acoustic padding uses the first id past the 1024-entry range;
    # presumably the EOS id — TODO confirm against the model.
    target_acoustic_eos = acoustic_token_nums

    for sample in data:
        assert isinstance(sample, list)
        if not sample:
            # Nothing to collate; pad_2D would fail on max() of nothing.
            continue
        target_acoustics_lengths = torch.tensor(
            [x['target_acoustics'].size(1) for x in sample], dtype=torch.int32)
        order = torch.argsort(target_acoustics_lengths, descending=True)
        target_semantics_lengths = torch.tensor(
            [sample[i]['target_semantics'].size(0) for i in order], dtype=torch.int32)

        target_acoustics = [sample[i]['target_acoustics'] for i in order]
        target_semantics = [sample[i]['target_semantics'] for i in order]
        target_durations = [sample[i]['target_durations'] for i in order]
        prompt_style = [sample[i]['prompt_style'] for i in order]
        prompt_style_attention = [sample[i]['prompt_style_attention'] for i in order]

        padded_target_semantics = pad_2D(target_semantics, target_semantic_end_id)
        padding_target_durations = pad_2D(target_durations, 0)
        padding_target_acoustics = pad_2D(target_acoustics, target_acoustic_eos)

        new_samples = {}
        new_samples['target_acoustics'] = torch.from_numpy(padding_target_acoustics)
        new_samples['target_semantics'] = torch.from_numpy(padded_target_semantics)
        new_samples['target_durations'] = torch.from_numpy(padding_target_durations)
        # Style prompts are not padded: torch.tensor() requires every prompt
        # in the batch to have the same length.
        new_samples['prompt_style'] = torch.tensor(prompt_style)
        new_samples['prompt_style_attention'] = torch.tensor(prompt_style_attention)
        # Lengths are reported in the sorted order, before padding.
        new_samples['target_semantics_lengths'] = target_semantics_lengths
        new_samples['target_acoustics_lengths'] = target_acoustics_lengths

        yield new_samples




