# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import argparse
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
import os
import torch
from torch.utils.data import DataLoader
import torchaudio
from hyperpyyaml import load_hyperpyyaml
from tqdm import tqdm
from cosyvoice.cli.model import CosyVoiceModel
from cosyvoice.dataset.dataset import Dataset


def get_args(argv=None):
    """Parse the command-line options for the inference script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            the behaviour the script entry point relies on.

    Returns:
        The populated ``argparse.Namespace``. It is also printed to stdout
        so the effective configuration shows up in the run log.
    """
    parser = argparse.ArgumentParser(description='inference with your model')

    # Inputs: model config, prompt data and the text to synthesize.
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--prompt_data', required=True, help='prompt data file')
    parser.add_argument('--prompt_utt2data', required=True, help='prompt utt2data file')
    parser.add_argument('--tts_text', required=True, help='tts input file')

    # Checkpoints for the three sub-models.
    parser.add_argument('--llm_model', required=True, help='llm model file')
    parser.add_argument('--flow_model', required=True, help='flow model file')
    parser.add_argument('--hifigan_model', required=True, help='hifigan model file')

    # Runtime options.
    parser.add_argument('--gpu',
                        type=int,
                        default=-1,
                        help='gpu id for this rank, -1 for cpu')
    parser.add_argument('--mode',
                        default='sft',
                        choices=['sft', 'zero_shot'],
                        help='inference mode')

    # Output location: the generated wav files and wav.scp are written here.
    parser.add_argument('--result_dir', required=True,
                        help='directory for the generated wavs and wav.scp')

    args = parser.parse_args(argv)
    print(args)
    return args


def _build_model_input(batch, mode, device):
    """Assemble the keyword arguments for ``model.tts`` from one batch.

    Only the entries needed by the selected mode are read from ``batch`` and
    moved to ``device``.

    Args:
        batch: Mapping yielded by the data loader; every value read here must
            support ``.to(device)``.
        mode: ``'sft'`` or anything else (treated as zero-shot, matching the
            two choices the CLI accepts).
        device: Target ``torch.device`` for the tensors.

    Returns:
        A dict of keyword arguments to splat into ``model.tts``.
    """
    def on_device(key):
        return batch[key].to(device)

    model_input = {'text': on_device('tts_text_token'),
                   'text_len': on_device('tts_text_token_len')}

    if mode == 'sft':
        # sft conditions both the llm and the flow on the speaker embedding.
        spk_embedding = on_device('spk_embedding')
        model_input['llm_embedding'] = spk_embedding
        model_input['flow_embedding'] = spk_embedding
        return model_input

    # zero_shot: the same prompt text/speech tokens/features feed both the
    # llm and the flow, together with the utterance-level embedding.
    text_token = on_device('text_token')
    text_token_len = on_device('text_token_len')
    speech_token = on_device('speech_token')
    speech_token_len = on_device('speech_token_len')
    speech_feat = on_device('speech_feat')
    speech_feat_len = on_device('speech_feat_len')
    utt_embedding = on_device('utt_embedding')
    model_input.update({
        'prompt_text': text_token, 'prompt_text_len': text_token_len,
        'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
        'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
        'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
        'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding,
    })
    return model_input


def main():
    """Run batch TTS inference and write the results to ``--result_dir``.

    Steps: parse the CLI arguments, load the hyperpyyaml config, build and
    load the CosyVoice model, iterate the inference dataset one utterance at
    a time, save each synthesized waveform as ``<utt>_<tts_index>.wav`` and
    record it in ``wav.scp`` inside the result directory.
    """
    args = get_args()
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')
    # The requested gpu id is exported verbatim (including -1).
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

    # Init cosyvoice models from configs
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    with open(args.config, 'r') as f:
        configs = load_hyperpyyaml(f)

    model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
    model.load(args.llm_model, args.flow_model, args.hifigan_model)

    test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False,
                           tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
    test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)

    # The config is not referenced past this point.
    del configs
    os.makedirs(args.result_dir, exist_ok=True)
    fn = os.path.join(args.result_dir, 'wav.scp')
    # The context manager guarantees wav.scp is closed even if synthesis of
    # some utterance raises part-way through the loop.
    with open(fn, 'w') as scp_file:
        with torch.no_grad():
            for _, batch in tqdm(enumerate(test_data_loader)):
                utts = batch["utts"]
                assert len(utts) == 1, "inference mode only support batchsize 1"
                tts_index = batch["tts_index"]
                model_input = _build_model_input(batch, args.mode, device)

                # model.tts yields its output in pieces; gather them and join
                # along dim 1 to obtain the full waveform tensor.
                tts_speeches = [model_output['tts_speech']
                                for model_output in model.tts(**model_input)]
                tts_speeches = torch.concat(tts_speeches, dim=1)

                tts_key = '{}_{}'.format(utts[0], tts_index[0])
                tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
                # NOTE(review): the sample rate is hard-coded; confirm it
                # matches the output rate of the loaded vocoder.
                torchaudio.save(tts_fn, tts_speeches, sample_rate=22050)
                scp_file.write('{} {}\n'.format(tts_key, tts_fn))
                # Flush per utterance so partial results survive a crash.
                scp_file.flush()
    logging.info('Result wav.scp saved in %s', fn)


# Run inference only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()
