# --coding:utf-8--
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import torch
import numpy as np
import torch.nn.functional as F
import pandas as pd
from evaluation.cnen_chinese2pinyin import text2pinyin
from evaluation.cnen_text2phone_yoyo import text2phone
import copy

# from Encodec_16k_320.net3 import SoundStream
import sys
# sys.path.append(os.path.join(os.path.dirname(__file__), '../../../'))
sys.path.append(os.path.join(os.path.dirname(__file__), '../'))
import pydub
import struct
import torch
import cv2
import argparse
import numpy as np
import torchvision
from PIL import Image

from nn_ss.sound_synthesis2.utils.io import load_yaml_config
from nn_ss.sound_synthesis.modeling.build import build_model
from nn_ss.sound_synthesis2.utils.misc import get_model_parameters_info
from utils.checkpoint import load_checkpoint
import datetime
import yaml
import pandas as pd
import numpy as np
import typing as tp
from pathlib import Path
import pyloudnorm as pyln
import re
import soundfile as sf
import glob
import time
# from rnnoise_wrapper import RNNoise
# denoiser = RNNoise()

from encodec.utils import convert_audio
import IPython.display as ipd
import torchaudio
import torch
import json


# Device that inference inputs are moved to (see the `.to(device)` calls in
# Vccm_Infer.generate_sample); hard-coded to CUDA device index 0.
device=torch.device('cuda:0')

# Regex alternation of clause delimiters: full-width comma, ASCII comma,
# full-width period, full-width semicolon and the enumeration comma.
# NOTE(review): not referenced anywhere in the visible part of this file.
pattern = r'，|,|。|；|、'


def get_args():
    """Parse the command-line options of the inference script.

    Every option is a plain string with a placeholder default.

    Returns:
        argparse.Namespace: with the attributes ``config_path``,
        ``output_key``, ``checkpoint``, ``index_id_list`` and
        ``text_style_csv``.
    """
    # (flag, default, help text) for each supported option.
    option_specs = (
        ('--config_path', "none", 'the name of config path'),
        ('--output_key', "vccm_base_epoch9", 'the name of output path'),
        ('--checkpoint', "aa", 'checkpoint'),
        ('--index_id_list', "aa", 'index_id_list'),
        ('--text_style_csv', "aa", 'text_style_csv'),
    )

    parser = argparse.ArgumentParser(description='Generated Pretrained Language Model')
    for flag, default_value, help_text in option_specs:
        parser.add_argument(flag, type=str, default=default_value, help=help_text)

    return parser.parse_args()


def loudness_filter(data, rate, target_lufs=-25.0):
    """Normalise audio to a target integrated loudness, then peak-limit it.

    Args:
        data: audio samples as accepted by ``pyln.Meter.integrated_loudness``
            (presumably a float numpy array -- verify against the caller).
        rate: sample rate of ``data`` in Hz, used to build the meter.
        target_lufs: loudness to normalise to; defaults to -25, the value
            that used to be hard-coded.

    Returns:
        The loudness-normalised samples. If the normalised signal peaks
        above 1.0 it is scaled down so that its absolute peak is exactly 1.0.
    """
    meter = pyln.Meter(rate)  # create BS.1770 meter
    loudness = meter.integrated_loudness(data)

    if np.isfinite(loudness):
        wav = pyln.normalize.loudness(data, loudness, target_lufs)
    else:
        # A non-finite measurement (presumably what the meter reports for
        # silent input -- verify) would translate into an infinite/NaN gain,
        # so the signal is passed through unscaled instead.
        wav = data

    # Compute the peak once and reuse it for both the check and the scaling.
    peak = np.abs(wav).max()
    if peak > 1.0:
        wav = wav / peak
    return wav


def convert_audio_tmp(wav: torch.Tensor, sr: int, target_sr: int, target_channels: int):
    """Remix a waveform to ``target_channels`` and resample it to ``target_sr``.

    Args:
        wav: waveform whose first dimension has size 1 or 2 (asserted).
        sr: sample rate of ``wav``.
        target_sr: sample rate of the returned waveform.
        target_channels: requested channel count.

    Returns:
        The remixed waveform after ``torchaudio.transforms.Resample``.

    Channel handling:
        * ``target_channels == 1``: average over dimension 0.
        * ``target_channels == 2``: expand the second-to-last dimension to 2.
        * any other value with a single-channel input: expand dimension 0.
        * any other value with a two-channel input: channels are left
          untouched (no error is raised).
    """
    assert wav.shape[0] in [1, 2], "Audio must be mono or stereo."

    if target_channels == 1:
        remixed = wav.mean(0, keepdim=True)
    elif target_channels == 2:
        *leading_dims, _, n_samples = wav.shape
        remixed = wav.expand(*leading_dims, target_channels, n_samples)
    elif wav.shape[0] == 1:
        remixed = wav.expand(target_channels, -1)
    else:
        remixed = wav

    resampler = torchaudio.transforms.Resample(sr, target_sr)
    return resampler(remixed)



def save_audio(wav: torch.Tensor, path: tp.Union[Path, str],
               sample_rate: int, rescale: bool = False):
    """Write ``wav`` to ``path`` as 16-bit signed PCM via ``torchaudio.save``.

    Args:
        wav: waveform tensor to store.
        path: destination file.
        sample_rate: sample rate written to the file.
        rescale: when True, attenuate the whole signal so that its peak does
            not exceed 0.99 (signals already below that are left as they
            are); when False, hard-clamp every sample to [-0.99, 0.99].
    """
    limit = 0.99
    peak = wav.abs().max()
    # min(..., 1) makes the rescale attenuate only, never amplify.
    bounded = wav * min(limit / peak, 1) if rescale else wav.clamp(-limit, limit)
    torchaudio.save(path, bounded, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)


def build_codec_model(config):
    """Instantiate the codec generator described by ``config``.

    Args:
        config: object exposing ``generator.name`` (an expression naming a
            callable that is resolvable in this module) and
            ``generator.config`` (a mapping of keyword arguments).

    Returns:
        Whatever the resolved callable returns for those keyword arguments.
    """
    # NOTE(review): eval() executes arbitrary expressions taken from the
    # config -- only use this with trusted configuration files.
    generator_factory = eval(config.generator.name)
    return generator_factory(**config.generator.config)


class Vccm_Infer():
    """Inference wrapper: builds a model from a YAML config, loads a
    checkpoint into it and exposes token generation."""

    def __init__(self, config, path):
        """Build the model, load its weights and freeze it for inference.

        Args:
            config: path of the YAML model configuration.
            path: path of the checkpoint handed to ``load_checkpoint``.
        """
        loaded = self.get_model(ema=True, model_path=path, config_path=config)
        self.info = loaded
        self.model = loaded['model']
        self.model_name = loaded['model_name']

        # Move to the GPU, switch to eval mode and disable gradients.
        self.model = self.model.cuda()
        self.model.eval()
        for weight in self.model.parameters():
            weight.requires_grad = False

        self.num_quant = 6

    def get_model(self, ema, model_path, config_path):
        """Build the model described by ``config_path`` and load ``model_path``.

        Args:
            ema: accepted for interface compatibility; not used here.
            model_path: checkpoint path. When it contains ``'OUTPUT'`` the
                model name is taken from its third-from-last path component.
            config_path: YAML config path; otherwise its basename without
                ``.yaml`` is used as the model name.

        Returns:
            dict with the keys ``'model'``, ``'model_name'`` and
            ``'parameter'`` (the result of ``get_model_parameters_info``).
        """
        model_name = (
            model_path.split(os.path.sep)[-3]  # pretrained model
            if 'OUTPUT' in model_path
            else os.path.basename(config_path).replace('.yaml', '')
        )

        config = load_yaml_config(config_path)
        # Build the network (the original note calls it the "dalle model").
        model = build_model(config)
        # Parameter details, printed for reference.
        model_parameters = get_model_parameters_info(model)
        print(model_parameters)
        load_checkpoint(model, model_path)

        return dict(model=model, model_name=model_name, parameter=model_parameters)

    def read_tsv(self, val_path):
        """Group captions by file name from a comma-separated file.

        Only the first two columns are read; they must be named
        ``file_name`` and ``caption``.

        Args:
            val_path: path of the comma-separated file.

        Returns:
            dict mapping each file name to the list of its captions, in
            file order.
        """
        table = pd.read_csv(val_path, sep=',', usecols=[0, 1])

        caps_dict = {}
        for file_name, caption in zip(table['file_name'], table['caption']):
            caps_dict.setdefault(file_name, []).append(caption)
        return caps_dict

    def generate_sample(self, input_text, input_style, input_style_attention):
        """Run the model on one text/style pair and return the predicted tokens.

        Args:
            input_text: text token ids; a leading batch dimension of size 1
                is added here.
            input_style: style tokens, converted to a tensor as given.
            input_style_attention: attention mask for ``input_style``,
                converted to a tensor as given.

        Returns:
            ``model_out['token_pred']`` moved to the CPU.
        """
        text_batch = torch.tensor(input_text).unsqueeze(0).to(device)
        style_tokens = torch.tensor(input_style).to(device)
        style_mask = torch.tensor(input_style_attention).to(device)

        # Length of the last dimension of every row in the text batch.
        text_lengths = [np.shape(row)[-1] for row in text_batch]

        batch = {
            'input_text': text_batch,
            'input_style': style_tokens,
            'input_style_attention': style_mask,
            'input_text_lengths': torch.from_numpy(np.array(text_lengths)).to(device),
        }

        with torch.no_grad():
            model_out = self.model.generate_content_tmp(
                batch=batch,
                filter_ratio=0,
                replicate=1,  # how many times each sample is repeated?
                content_ratio=1,
                return_att_weight=False,
            )  # original note: B x C x H x W -- TODO confirm
            content = model_out['token_pred'].cpu()

        return content
          

if __name__ == "__main__":
    # Script entry point: for up to `max_infer_num` test items, dump the
    # text/style description of each item and then save the tokens the model
    # generates for it.

    # Local import: only this entry point needs it.
    import shutil

    args = get_args()

    config_path = args.config_path
    save_root_ = '/home/disk2/nips/Result/controlspeech/infer'
    key_words = args.output_key
    save_root1 = os.path.join(save_root_, key_words + '_samples')

    # Start from a clean output directory. shutil.rmtree avoids passing a
    # command-line supplied path through a shell ("rm -r <path>").
    shutil.rmtree(save_root1, ignore_errors=True)
    os.makedirs(save_root1, exist_ok=True)

    save_root = os.path.join(save_root1, "style_codec")
    os.makedirs(save_root, exist_ok=True)

    pretrained_model_path = args.checkpoint
    # Keep the instance under its own name so the class stays reachable.
    infer_engine = Vccm_Infer(config=config_path, path=pretrained_model_path)

    # Inference inputs: one <id>.json (style) and <id>.txt (text) per item.
    input_path = "/home/disk2/nips/Data/2024nips/0509/test/tmp/librittsid_1"
    json_paths = glob.glob(os.path.join(input_path, "*.json"))
    input_id = [os.path.basename(json_path).split('.')[0] for json_path in json_paths]

    # Process at most 50 items, and never more than are actually available.
    max_infer_num = 50
    infer_num = min(max_infer_num, len(input_id))
    selected_ids = input_id[:infer_num]

    # First save the style description and the text content of every item.

    # "<short id>|<item name>" lines; blank lines are ignored.
    index_id_list_test = {}
    with open(args.index_id_list, 'r') as fin:
        for line in fin:
            fields = line.strip().split('|')
            if fields == ['']:
                continue
            index_id_list_test[str(fields[0])] = fields[1]

    df = pd.read_csv(args.text_style_csv)

    # Per-item lookups keyed by item name (later rows win on duplicates).
    id_text_list_test = dict(zip(df["item_name"], df["txt"]))
    id_style_list_test = dict(zip(df["item_name"], df["style_prompt"]))
    id_index_list_test = dict(zip(df["item_name"], df["new_id"]))

    with open(os.path.join(save_root1, "libritest_text_style"), 'w') as fout:
        for sample_id in selected_ids:
            # Ids of at most 15 characters are translated to an item name
            # through the index list; longer ids are item names already.
            if len(sample_id) <= 15:
                item_name = index_id_list_test[sample_id]
            else:
                item_name = sample_id

            new_id = id_index_list_test[item_name]
            fout.write(f"{sample_id}|{new_id}|{id_text_list_test[item_name]}\n")
            fout.write(f"{sample_id}|{new_id}|{id_style_list_test[item_name]}\n\n")

    # Generate and store the output for every selected item.
    for sample_id in selected_ids:
        text_path = os.path.join(input_path, sample_id + ".txt")
        with open(text_path, 'r') as fin:
            first_line = fin.readline()

        with open(os.path.join(input_path, sample_id + ".json"), 'r') as fin:
            data_tmp = json.load(fin)

        input_style = data_tmp["style_token"]
        input_style_attention = data_tmp["attention_mask"]

        # The first line of the text file holds whitespace-separated token ids.
        text_tokens = first_line.split()
        if not text_tokens:
            raise ValueError(f"no text tokens found in {text_path}")
        input_text = [int(token) for token in text_tokens]

        # original note: [t, n_q] -- TODO confirm
        output_codec = infer_engine.generate_sample(input_text, input_style, input_style_attention)

        tmp_output_path = os.path.join(save_root, sample_id + ".pth")
        torch.save(output_codec, tmp_output_path)















