# --coding:utf-8--
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '4'
import torch
import numpy as np
import torch.nn.functional as F
import pandas as pd
from evaluation.cnen_chinese2pinyin import text2pinyin
from evaluation.cnen_text2phone_yoyo import text2phone

# from Encodec_16k_320.net3 import SoundStream
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '../../../'))
import pydub
import struct
import torch
import cv2
import argparse
import numpy as np
import torchvision
from PIL import Image
import soundfile
from nn_ss.sound_synthesis2.utils.io import load_yaml_config
from nn_ss.sound_synthesis.modeling.build import build_model
from nn_ss.sound_synthesis2.utils.misc import get_model_parameters_info
from utils.checkpoint import load_checkpoint
import datetime
import yaml
import pandas as pd
import numpy as np
import typing as tp
from pathlib import Path
import pyloudnorm as pyln
import re
import soundfile as sf
# from rnnoise_wrapper import RNNoise
# denoiser = RNNoise()

from encodec import EncodecModel
from encodec.utils import convert_audio
import IPython.display as ipd
import torchaudio
import torch
# EnCodec 24 kHz model, built at import time; generate_sample encodes the prompt
# audio with it to obtain the prompt acoustic tokens.
model_encodec = EncodecModel.encodec_model_24khz()
from vocos import Vocos
# from vocos.pretrained import Vocos
# Regex used by the __main__ script to split input text into segments on
# full-width comma, ASCII comma, Chinese full stop, full-width semicolon and
# the ideographic enumeration comma.
pattern = r'，|,|。|；|、'

# Local config and weights for the Vocos (EnCodec-24kHz variant) vocoder that
# generate_sample uses to turn EnCodec codes back into audio. Loaded at import time.
config_path = "/home/disk2/gongxuefei/Project/1_vocder_upgrade/AudioLM/SoundStorm_durprompt16k_soundstorm_V4_5_encoder_wenet/vocos-encodec-24khz/config.yaml"
model_path = "/home/disk2/gongxuefei/Project/1_vocder_upgrade/AudioLM/SoundStorm_durprompt16k_soundstorm_V4_5_encoder_wenet/vocos-encodec-24khz/pytorch_model.bin"
# NOTE(review): `from_pretrained0724` is not an upstream Vocos API; presumably a
# project-local patch that loads from explicit config/weights paths — confirm.
vocos = Vocos.from_pretrained0724(config_path,model_path)


def loudness_filter(data, rate, target_lufs=-25.0):
    """Normalize ``data`` to a target integrated loudness, then guard against clipping.

    Args:
        data: audio samples as a NumPy array (passed directly to pyloudnorm).
        rate: sample rate of ``data`` in Hz.
        target_lufs: target integrated loudness in LUFS. Defaults to -25, the
            value this function always used.

    Returns:
        The loudness-normalized array. If normalization pushes the peak above
        1.0, the array is instead peak-scaled into [-1, 1].

        If the measured loudness is not finite (e.g. an all-zero / silent
        signal), ``data`` is returned unchanged. Otherwise the required gain
        would be infinite and ``0 * inf`` would turn every sample into NaN.
    """
    meter = pyln.Meter(rate)  # ITU-R BS.1770 loudness meter
    loudness = meter.integrated_loudness(data)
    if not np.isfinite(loudness):
        # Silence measures -inf LUFS; normalizing it would yield an all-NaN signal.
        return data
    wav = pyln.normalize.loudness(data, loudness, target_lufs)

    peak = np.abs(wav).max()
    if peak > 1.0:
        # The loudness gain pushed samples past full scale: peak-normalize rather than clip.
        wav = wav / peak
    return wav


def convert_audio_tmp(wav: torch.Tensor, sr: int, target_sr: int, target_channels: int):
    """Remix ``wav`` to ``target_channels`` and resample it from ``sr`` to ``target_sr``.

    ``wav`` must have 1 or 2 channels in dim 0.

    - ``target_channels == 1``: the channels are averaged.
    - ``target_channels == 2``: the signal is broadcast along the second-to-last
      axis, so a mono input is duplicated.
    - Any other count: a mono input is broadcast to that many channels, and a
      stereo input is passed through unchanged.
    """
    channels_in = wav.shape[0]
    assert channels_in in [1, 2], "Audio must be mono or stereo."
    if target_channels == 1:
        remixed = wav.mean(0, keepdim=True)
    elif target_channels == 2:
        *lead, _, n_samples = wav.shape
        remixed = wav.expand(*lead, 2, n_samples)
    elif channels_in == 1:
        remixed = wav.expand(target_channels, -1)
    else:
        remixed = wav
    return torchaudio.transforms.Resample(sr, target_sr)(remixed)



def save_audio(wav: torch.Tensor, path: tp.Union[Path, str],
               sample_rate: int, rescale: bool = False):
    """Write ``wav`` to ``path`` as 16-bit signed PCM with samples kept within +/-0.99.

    With ``rescale`` the whole signal is scaled down (never up) so its peak is at
    most 0.99. Otherwise, samples beyond +/-0.99 are hard-clipped.
    """
    ceiling = 0.99
    peak = wav.abs().max()
    if not rescale:
        bounded = wav.clamp(-ceiling, ceiling)
    else:
        bounded = wav * min(ceiling / peak, 1)
    torchaudio.save(path, bounded, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)


def build_codec_model(config):
    """Instantiate the generator named by ``config.generator.name``.

    The name is resolved with ``eval`` in this module's namespace. The result is
    called with the mapping ``config.generator.config`` as keyword arguments.
    """
    # SECURITY NOTE: eval() executes arbitrary code from the config; only use trusted configs.
    generator_cls = eval(config.generator.name)
    generator_kwargs = config.generator.config
    return generator_cls(**generator_kwargs)


class Diffsound:
    """Inference wrapper around a SoundStorm-style token model.

    The model is built from a YAML config, loaded from a checkpoint, moved to the
    GPU and frozen. ``generate_sample`` turns (prompt wav, prompt text, target text)
    into waveforms. It uses the module-level EnCodec tokenizer (``model_encodec``)
    and the Vocos vocoder (``vocos``).
    """

    def __init__(self, config, path):
        """Build the model from the YAML file ``config`` and load weights from ``path``."""
        self.info = self.get_model(ema=True, model_path=path, config_path=config)
        self.model = self.info['model']
        self.model_name = self.info['model_name']
        self.model = self.model.cuda()
        self.model.eval()
        # Inference only: freeze every weight.
        for param in self.model.parameters():
            param.requires_grad = False
        # Number of EnCodec codebooks kept from the prompt. Must match the bandwidth
        # selected when vocoding (bandwidth_id 2 -> 8 codebooks in _codes_to_wav).
        self.num_quant = 8

    def get_model(self, ema, model_path, config_path):
        """Build the model described by ``config_path`` and load ``model_path`` into it.

        Args:
            ema: ignored by the current implementation; kept for interface compatibility.
            model_path: checkpoint path handed to ``load_checkpoint``.
            config_path: YAML config path handed to ``load_yaml_config``.

        Returns:
            dict with keys ``'model'``, ``'model_name'`` and ``'parameter'``.
            ``'parameter'`` holds the summary from ``get_model_parameters_info``.
        """
        if 'OUTPUT' in model_path:  # pretrained model: use the third-from-last path component
            model_name = model_path.split(os.path.sep)[-3]
        else:
            model_name = os.path.basename(config_path).replace('.yaml', '')
        config = load_yaml_config(config_path)
        model = build_model(config)
        model_parameters = get_model_parameters_info(model)
        print(model_parameters)
        load_checkpoint(model, model_path)
        return {'model': model, 'model_name': model_name, 'parameter': model_parameters}

    def read_tsv(self, val_path):
        """Group captions by file name.

        Despite the name, ``val_path`` is read as comma-separated. Only its first
        two columns are read; they must be named ``file_name`` and ``caption``.

        Returns:
            ``{file_name: [caption, ...]}`` with captions in file order.
        """
        table = pd.read_csv(val_path, sep=',', usecols=[0, 1])
        caps_dict = {}
        for name, cap in zip(table['file_name'], table['caption']):
            caps_dict.setdefault(name, []).append(cap)
        return caps_dict

    @staticmethod
    def _lengths(tokens):
        """Return the last-axis length of each batch item of ``tokens`` as an integer tensor."""
        return torch.from_numpy(np.array([row.shape[-1] for row in tokens]))

    @staticmethod
    def _codes_to_wav(codes):
        """Vocode an EnCodec code matrix with the module-level Vocos model.

        Returns a detached CPU tensor.
        """
        features = vocos.codes_to_features(codes)
        bandwidth_id = torch.tensor([2])  # 6 kbps  [1.5, 3.0, 6.0->n_q=8, 12.0].
        return vocos.decode(features, bandwidth_id=bandwidth_id).detach().cpu()

    def _encode_prompt(self, prompt_wav):
        """Load ``prompt_wav`` and return its first ``num_quant`` EnCodec codebooks.

        The result is placed on the GPU with shape [1, num_quant, T].
        Called under ``torch.no_grad()`` from ``generate_sample``.
        """
        wav, sr = torchaudio.load(prompt_wav)
        if sr > 16000:
            # Downmix and resample to 16 kHz from the file's *actual* rate. A
            # hard-coded source rate would mis-resample e.g. 44.1/48 kHz prompts.
            wav = convert_audio(wav, sr, 16000, 1)
            sr = 16000
        # Loudness-normalize the first channel, then restore the channel axis.
        wav = torch.from_numpy(loudness_filter(wav[0, :].numpy(), sr)).unsqueeze(0)
        wav = convert_audio(wav, sr, model_encodec.sample_rate, model_encodec.channels)
        encoded_frames = model_encodec.encode(wav.unsqueeze(0))
        prompt_acoustic_tokens = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1)  # [B, n_q, T]
        return prompt_acoustic_tokens[:, :self.num_quant, :].cuda()

    def generate_sample(self, prompt_wav, prompt_text, text, truncation_rate, save_root, fast=False):
        """Synthesize ``text`` conditioned on the prompt recording ``prompt_wav``.

        Args:
            prompt_wav: path to the prompt recording.
            prompt_text: transcript of ``prompt_wav``.
            text: text to synthesize.
            truncation_rate: embedded into the sampler string ``"top<rate>..."``.
            save_root: unused; kept for interface compatibility.
            fast: ``False`` for default sampling. Otherwise an int ``k``, which
                appends ``",fast<k-1>"`` to the sampler string.

        Returns:
            ``(coarse_wav, real_wav)``, both detached CPU tensors:
            - ``coarse_wav``: the vocoded generated tokens.
            - ``real_wav``: the vocoded prompt tokens (a re-synthesis of the prompt).
        """
        # Everything below is inference; no_grad also covers the EnCodec encode.
        with torch.no_grad():
            # The prompt phones drop their last id and the target phones their first,
            # so the two join into one sequence (presumably end/start markers —
            # confirm against text_to_sequence_seplevel0_345).
            prompt_semanstic_tokens = text2phone.text_to_sequence_seplevel0_345(text=prompt_text)[:-1]
            prompt_semanstic_tokens = torch.from_numpy(np.array(prompt_semanstic_tokens)).unsqueeze(0).cuda()
            target_semanstic_tokens = text2phone.text_to_sequence_seplevel0_345(text=text)[1:]
            target_semanstic_tokens = torch.from_numpy(np.array(target_semanstic_tokens)).unsqueeze(0).cuda()
            print("target_semanstic_tokens:", target_semanstic_tokens.shape, target_semanstic_tokens)
            all_semantic_tokens = torch.cat((prompt_semanstic_tokens, target_semanstic_tokens), dim=1)

            prompt_acoustic_tokens = self._encode_prompt(prompt_wav)

            add_string = 'r,fast' + str(fast - 1) if fast != False else 'r'
            new_samples = {
                'all_semantics': all_semantic_tokens,
                'prompt_semantics': prompt_semanstic_tokens,
                'target_semantics': target_semanstic_tokens,
                'prompt_acoustics': prompt_acoustic_tokens,
                'all_semantics_lengths': self._lengths(all_semantic_tokens).cuda(),
                'prompt_acoustics_lengths': self._lengths(prompt_acoustic_tokens),
                'prompt_semantics_lengths': self._lengths(prompt_semanstic_tokens),
                'target_semantics_lengths': self._lengths(target_semanstic_tokens),
            }

            model_out = self.model.generate_content_tmp(
                batch=new_samples,
                filter_ratio=0,
                replicate=1,  # how many times each sample is repeated
                content_ratio=1,
                return_att_weight=False,
                sample_type="top" + str(truncation_rate) + add_string,
            )
            # token_pred is transposed on dims 1/2 and the first batch item taken;
            # presumably [B, T, n_q] -> [n_q, T] for the vocoder — TODO confirm.
            codes = model_out['token_pred'].transpose(1, 2)[0, :, :].cpu()
            coarse_wav = self._codes_to_wav(codes)
            # Re-synthesize the prompt codes through the same vocoder for reference.
            real_wav = self._codes_to_wav(prompt_acoustic_tokens[0, :, :].cpu())

        return coarse_wav, real_wav


if __name__ == "__main__":
    # Batch inference script. For every prompt recording in `promptPath`:
    # synthesize each text segment of `lines`, concatenate the results along
    # time, and write one wav per prompt.

    # Token-model config. Named distinctly so the module-level Vocos `config_path` is not rebound.
    soundstorm_config_path = 'config/soundstorm8.yaml'
    pretrained_model_path = '/home/disk2/gongxuefei/Project/1_vocder_upgrade/AudioLM/SoundStorm_durprompt16k_soundstorm_V4_5_encoder_wenet/exp_output8/ep2_pos_25cate/checkpoint-84000.pt'
    save_root_ = './generated_sample_aishell'
    key_words = 'instructtts_diffsound'
    save_root = os.path.join(save_root_, key_words + '_samples')
    os.makedirs(save_root, exist_ok=True)

    # Lowercase instance name so the `Diffsound` class itself is not shadowed.
    diffsound = Diffsound(config=soundstorm_config_path, path=pretrained_model_path)
    # NOTE(review): this single transcript is used as `prompt_text` for every prompt
    # file in `promptPath` — confirm that all prompts share it.
    # prompt_text = '大家好我是龚雪飞我是江西人老家是江西南昌'
    # prompt_text = '甚至出现交易几乎停滞的情况'
    prompt_text = '在一个美丽的森林里生活着许许多多的小动物小猫花花和小熊是一对形影不离的好朋友'
    # prompt_text = '下一站幸福任光烯穿高根些跑步时的音乐'
    # prompt_text = '普通的酒我怎么拿的出手呢'
    # prompt_text = '强迫自己进行深度的超出环境的思考读书是为数不多的能够和优秀的人对话交流的方式'
    # prompt_text = '下面我用一段适配给大家展示一下这种理论创新和以及工程创新带来的效果'
    # lines = '习近平总书记#2对民营经济发展高度重视，多次作出重要指示#2和批示，为实现民营经济#2健康发展#2高质量发展#2注入强大信心和动力。广大民营企业#2和各地各部门深入贯彻#2总书记重要指示精神，加强自主创新，持续优化民营经济#2发展环境，为构建#2新发展格局、推动高质量发展#2作出更大贡献'
    lines = '习近，我是，中国，人民，天下，左右，人生，大侠，天意，的话，啊啊，哈哈，嘻嘻，呵呵，嗯嗯'

    # Split on the module-level punctuation `pattern`. Drop empty pieces (e.g. from a
    # trailing or doubled delimiter) so the model is never asked to speak "".
    result_list = [seg for seg in re.split(pattern, lines) if seg.strip()]
    print("result_list:", result_list)
    promptPath = '/home/disk2/gongxuefei/DATA/soundstorm/prompt_nips2'
    # Sorted so the processing order is reproducible (os.listdir order is arbitrary).
    for item in sorted(os.listdir(promptPath)):
        prompt_wav = os.path.join(promptPath, item)
        print("prompt_wav:", prompt_wav)
        segments = []
        for text in result_list:
            coarse_wav, _real_wav = diffsound.generate_sample(
                prompt_wav, prompt_text, text, truncation_rate=0.85, save_root=save_root, fast=False)
            segments.append(coarse_wav)
        if not segments:
            continue
        # A single concatenation along time, instead of re-copying a growing tensor per segment.
        result_wav = torch.cat(segments, dim=1)
        prompt_name = os.path.basename(prompt_wav).split('.')[0]
        save_audio(result_wav.cpu(),
                   os.path.join(save_root, prompt_name + '_infer_24k_25cate_ep284knorm_pos.wav'),
                   24000)
        # save_audio(real_wav.cpu(), save_root + '/' + prompt_name + '_real.wav', 24000)



















