# --coding:utf-8--
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import torch
import numpy as np
import torch.nn.functional as F
import pandas as pd
from evaluation.cnen_chinese2pinyin import text2pinyin
from evaluation.cnen_text2phone_yoyo import text2phone
from txt_processors.en import TxtProcessor
from txt_processors.text_encoder import is_sil_phoneme
import copy

# from Encodec_16k_320.net3 import SoundStream
import sys
# sys.path.append(os.path.join(os.path.dirname(__file__), '../../../'))
sys.path.append(os.path.join(os.path.dirname(__file__), '../'))
import pydub
import struct
import torch
import cv2
import argparse
import numpy as np
import torchvision
from PIL import Image

from nn_ss.sound_synthesis2.utils.io import load_yaml_config
from nn_ss.sound_synthesis.modeling.build import build_model
from nn_ss.sound_synthesis2.utils.misc import get_model_parameters_info
from utils.checkpoint import load_checkpoint
import datetime
import yaml
import pandas as pd
import numpy as np
import typing as tp
from pathlib import Path
import pyloudnorm as pyln
import re
import soundfile as sf
import glob
import time
# from rnnoise_wrapper import RNNoise
# denoiser = RNNoise()

from encodec.utils import convert_audio
import IPython.display as ipd
import torchaudio
import torch
import json
from transformers import BertTokenizer


# Every tensor built by this script is moved to the first visible CUDA device
# (CUDA_VISIBLE_DEVICES is pinned to '0' at the very top of the file).
device=torch.device('cuda:0')

# Alternation of Chinese/ASCII clause punctuation marks. Not referenced in the
# visible part of this file.
pattern = r'，|,|。|；|、'

# BERT tokenizer (lower-casing enabled) loaded from a local checkpoint
# directory; it is used to encode the style prompts.
tokenizer = BertTokenizer.from_pretrained('/home/disk2/nips/Data/bert-base-uncased', do_lower_case=True)

# English text front-end used by txt_to_ph().
txt_processor = TxtProcessor()

# Load the phoneme-symbol -> integer-id table. The file is opened with a
# context manager so the handle is closed as soon as the JSON is parsed
# (the previous json.load(open(...)) left it open until garbage collection).
with open("/home/disk1/nips/speech/code/controlspeech/baker_mfa_cnen_mapper_yoyo345.json") as mapper_file:
    json_data = json.load(mapper_file)
phone_map = json_data['symbol_to_id']

def txt_to_ph(txt_raw):
    """Convert raw text into phoneme-level and word-level string views.

    The text is run through the module-level ``txt_processor``; every entry
    of the returned structure is indexed as ``entry[0]`` (the word) and
    ``entry[1]`` (that word's phoneme sequence).

    Returns a 5-tuple:
        * all phonemes joined by single spaces,
        * the processed text returned by the processor,
        * all words joined by single spaces,
        * a list mapping each phoneme to its 1-based word index
          (index 0 is reserved for padding),
        * per-word phoneme groups: phonemes of one word joined by ``_``,
          words joined by single spaces.
    """
    txt_struct, txt = txt_processor.process(txt_raw)

    phonemes = []
    word_list = []
    phoneme_to_word = []
    grouped = []
    # Word ids start at 1 because 0 is kept free for padding.
    for word_id, entry in enumerate(txt_struct, start=1):
        word_phonemes = entry[1]
        word_list.append(entry[0])
        phonemes.extend(word_phonemes)
        grouped.append("_".join(word_phonemes))
        phoneme_to_word.extend([word_id] * len(word_phonemes))

    return " ".join(phonemes), txt, " ".join(word_list), phoneme_to_word, " ".join(grouped)

### Build the word-grouped phoneme text for a sentence (nothing is written to disk here)
def gen_lab_file(text):
    """Return the word-grouped phoneme string of ``text``, with and without silences.

    The first returned value is the last element of ``txt_to_ph(text)``:
    phonemes of one word joined by ``_`` and words separated by spaces.
    The second value is the same string after dropping every word group and
    every individual phoneme for which ``is_sil_phoneme`` is true.
    """
    ph_gb_word = txt_to_ph(text)[-1]

    kept_groups = []
    for group in ph_gb_word.split(" "):
        # A whole group that counts as silence is dropped outright.
        if is_sil_phoneme(group):
            continue
        voiced = [ph for ph in group.split("_") if not is_sil_phoneme(ph)]
        kept_groups.append("_".join(voiced))
    ph_gb_word_nosil = " ".join(kept_groups)

    return ph_gb_word, ph_gb_word_nosil
    


def get_args():
    """Parse the command-line options of the inference script.

    All options are strings with defaults:
    ``--config_path``, ``--output_key``, ``--checkpoint``,
    ``--index_id_list`` and ``--text_style_csv``.

    Returns the ``argparse.Namespace`` produced from ``sys.argv``.
    """
    # (flag, default value, help text) for every supported option.
    option_table = (
        ('--config_path', "none", 'the name of config path'),
        ('--output_key', "vccm_base_epoch9", 'the name of output path'),
        ('--checkpoint', "aa", 'checkpoint'),
        ('--index_id_list', "aa", 'index_id_list'),
        ('--text_style_csv', "aa", 'text_style_csv'),
    )

    cli = argparse.ArgumentParser(description='Generated Pretrained Language Model')
    for flag, fallback, help_text in option_table:
        cli.add_argument(flag, type=str, default=fallback, help=help_text)

    return cli.parse_args()


def loudness_filter(data, rate):
    """Loudness-normalize ``data`` to a target of -25 and guard against clipping.

    A pyloudnorm BS.1770 meter built for ``rate`` measures the integrated
    loudness of ``data``; the signal is then normalized from that measured
    loudness to -25 (presumably LUFS, per pyloudnorm's API -- confirm).
    If the result exceeds 1.0 in absolute value it is divided by its peak.

    Returns the normalized waveform.
    """
    measured = pyln.Meter(rate).integrated_loudness(data)
    wav = pyln.normalize.loudness(data, measured, -25)

    # Peak-normalize only when the loudness gain pushed samples past full scale.
    peak = np.abs(wav).max()
    if peak > 1.0:
        wav = wav / peak
    return wav


def convert_audio_tmp(wav: torch.Tensor, sr: int, target_sr: int, target_channels: int):
    """Adapt the channel count of ``wav`` and resample it from ``sr`` to ``target_sr``.

    Channel handling:
        * ``target_channels == 1``: average over dim 0 (dim kept).
        * ``target_channels == 2``: expand the second-to-last dim to 2.
        * otherwise, if dim 0 has size 1: expand dim 0 to ``target_channels``.

    Raises ``AssertionError`` when ``wav.shape[0]`` is neither 1 nor 2.
    Returns the resampled tensor.
    """
    # NOTE(review): the check inspects dim 0, while the stereo branch treats
    # dim -2 as the channel axis; these agree only for 2-D input -- confirm.
    assert wav.shape[0] in [1, 2], "Audio must be mono or stereo."

    if target_channels == 1:
        wav = wav.mean(0, keepdim=True)
    elif target_channels == 2:
        *leading_dims, _, num_samples = wav.shape
        wav = wav.expand(*leading_dims, target_channels, num_samples)
    elif wav.shape[0] == 1:
        wav = wav.expand(target_channels, -1)

    resampler = torchaudio.transforms.Resample(sr, target_sr)
    return resampler(wav)



def save_audio(wav: torch.Tensor, path: tp.Union[Path, str],
               sample_rate: int, rescale: bool = False):
    """Write ``wav`` to ``path`` as 16-bit signed PCM, keeping samples within +/-0.99.

    With ``rescale=True`` the waveform is scaled down so that its peak does
    not exceed 0.99 (it is never scaled up); otherwise samples are clamped
    to ``[-0.99, 0.99]``.
    """
    ceiling = 0.99
    peak = wav.abs().max()
    # Either attenuate the whole signal or hard-clip the outliers.
    wav = wav * min(ceiling / peak, 1) if rescale else wav.clamp(-ceiling, ceiling)
    torchaudio.save(path, wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)


def build_codec_model(config):
    """Instantiate the generator described by ``config.generator``.

    ``config.generator.name`` is evaluated as a Python expression to obtain
    a callable, which is then invoked with ``config.generator.config``
    unpacked as keyword arguments. The resulting object is returned.
    """
    # NOTE(review): eval() executes whatever expression the config holds, so
    # only trusted configs may be used here. The name must also resolve in
    # this module's scope (the SoundStream import above is commented out).
    return eval(config.generator.name)(**config.generator.config)


class Vccm_Infer():
    """Inference wrapper: builds the model from a YAML config, loads a
    checkpoint, and generates token predictions for text + style inputs."""

    def __init__(self, config, path):
        """Build the model from ``config``, load weights from ``path`` and freeze it.

        The model is moved to CUDA, put in eval mode, and every parameter
        has ``requires_grad`` switched off.
        """
        self.info = self.get_model(ema=True, model_path=path, config_path=config)
        self.model_name = self.info['model_name']

        network = self.info['model'].cuda()
        network.eval()
        for weight in network.parameters():
            weight.requires_grad = False
        self.model = network

        self.num_quant = 6

    def get_model(self, ema, model_path, config_path):
        """Create the model, print its parameter summary and load the checkpoint.

        ``ema`` is accepted but not used by this method.

        Returns a dict with keys ``'model'``, ``'model_name'`` and
        ``'parameter'`` (the parameter summary).
        """
        # For paths containing 'OUTPUT' (pretrained model) the name is the
        # third-from-last path component; otherwise it is the config file
        # name without the '.yaml' suffix.
        model_name = (
            model_path.split(os.path.sep)[-3]
            if 'OUTPUT' in model_path
            else os.path.basename(config_path).replace('.yaml', '')
        )

        # Load the dalle model described by the config.
        model = build_model(load_yaml_config(config_path))
        # Parameter details.
        parameter_summary = get_model_parameters_info(model)
        print(parameter_summary)
        load_checkpoint(model, model_path)

        return {'model': model, 'model_name': model_name, 'parameter': parameter_summary}

    def read_tsv(self, val_path):
        """Group captions by file name from a comma-separated file.

        Only the first two columns are read; they must be named
        ``file_name`` and ``caption``.

        Returns a dict mapping each file name to the list of its captions,
        in file order.
        """
        table = pd.read_csv(val_path, sep=',', usecols=[0, 1])
        names = table['file_name']
        captions = table['caption']

        caps_dict = {}
        for name, caption in zip(names, captions):
            caps_dict.setdefault(name, []).append(caption)
        return caps_dict

    def generate_sample(self, input_text, input_style, input_style_attention):
        """Run the model on one text/style pair and return its predicted tokens.

        ``input_text`` gets a leading dimension of size 1 added; the style
        inputs are converted to tensors as given. Everything is placed on the
        module-level ``device``.

        Returns ``model_out['token_pred']`` moved to the CPU.
        """
        text_ids = torch.tensor(input_text).unsqueeze(0).to(device)
        style_ids = torch.tensor(input_style).to(device)
        style_mask = torch.tensor(input_style_attention).to(device)

        # Length of every row along its last axis.
        text_lengths = np.array([np.shape(row)[-1] for row in text_ids])

        batch = {
            'input_text': text_ids,
            'input_style': style_ids,
            'input_style_attention': style_mask,
            'input_text_lengths': torch.from_numpy(text_lengths).to(device),
        }

        with torch.no_grad():
            model_out = self.model.generate_content_tmp(
                batch=batch,
                filter_ratio=0,
                replicate=1,  # how many times is each sample repeated?
                content_ratio=1,
                return_att_weight=False)

        return model_out['token_pred'].cpu()
          

if __name__ == "__main__":
    # Script entry point: for every "<id>|<style prompt>" line of the input
    # list, generate token predictions for one fixed English sentence spoken
    # in that style and save them to <save_root>/<id>.pth.

    # Only this entry point needs shutil, so it is imported locally.
    import shutil

    args = get_args()

    config_path = args.config_path
    save_root_ = '/home/disk2/nips/Result/controlspeech/infer'
    key_words = args.output_key
    save_root1 = os.path.join(save_root_, key_words + '_samples')
    # Start from an empty output directory. shutil.rmtree replaces the former
    # os.system("rm -r ...") call, which interpolated a CLI argument into a
    # shell command; ignore_errors keeps "directory does not exist" non-fatal.
    shutil.rmtree(save_root1, ignore_errors=True)
    os.makedirs(save_root1, exist_ok=True)

    save_root = os.path.join(save_root1, "style_codec")
    os.makedirs(save_root, exist_ok=True)

    pretrained_model_path = args.checkpoint
    # Use a distinct name so the Vccm_Infer class is not shadowed by its instance.
    infer_engine = Vccm_Infer(config=config_path, path=pretrained_model_path)

    input_df_path = "/home/disk2/nips/Data/2024nips/0509/timbre_similar_list"

    with open(input_df_path, 'r') as fin:
        input_df_raw = fin.readlines()

    # Each non-empty line is "<id>|<style prompt>[|...]"; blank lines
    # (e.g. a trailing newline at the end of the file) are skipped.
    id_all = []
    style_all = []
    for raw_line in input_df_raw:
        stripped = raw_line.strip()
        if not stripped:
            continue
        fields = stripped.split('|')
        id_all.append(fields[0])
        style_all.append(fields[1])

    # The sentence is the same for every sample, so its phoneme ids are
    # computed once instead of once per loop iteration.
    _, ph_nosil = gen_lab_file("The patient and the surgeon are both recuperating from the lengthy operation")
    # The sequence is framed by ids 1 and 345 (presumably start/end markers of
    # the phoneme vocabulary -- confirm against the mapper file).
    input_text = [1]
    input_text.extend(
        phone_map[phoneme]
        for word_group in ph_nosil.split(' ')
        for phoneme in word_group.split('_')
    )
    input_text.append(345)

    for sample_id, style_prompt in zip(id_all, style_all):
        encoded_dict = tokenizer.encode_plus(
            style_prompt,                  # input text
            add_special_tokens=True,       # add '[CLS]' and '[SEP]'
            max_length=64,                 # fixed length of the encoded prompt
            # Explicit replacement for the deprecated pad_to_max_length=True.
            # truncation=True spells out the truncation that max_length alone
            # used to trigger implicitly when `padding` was left unset.
            padding='max_length',
            truncation=True,
            return_attention_mask=True,    # also return the attention mask
            return_tensors='pt',           # return PyTorch tensors
        )

        # Per the original author's note the result is shaped [t, n_q].
        output_codec = infer_engine.generate_sample(
            input_text,
            encoded_dict['input_ids'].tolist(),
            encoded_dict['attention_mask'].tolist(),
        )

        tmp_output_path = os.path.join(save_root, sample_id + ".pth")
        torch.save(output_codec, tmp_output_path)
















