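# Prompt with our own chosen timbre (用自己的音色进行prompt): decode generated
# codec tokens using speaker embeddings taken from fixed reference utterances.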


import argparse
import os
import shutil
import sys

import librosa
import pandas as pd
import soundfile as sf
import torch
from einops import rearrange

# Make the Amphion repo importable (provides models.codec.ns3_codec).
sys.path.append("/home/disk1/nips/speech/code/Amphion")

from models.codec.ns3_codec import FACodecEncoderV2, FACodecDecoderV2

def get_args():
    """Parse this script's command-line options.

    Returns:
        argparse.Namespace with one attribute, ``output_key`` (str). It is the
        experiment tag the script uses to build its input and output
        directory paths. Defaults to ``"vccm_base_epoch9"``.
    """
    cli = argparse.ArgumentParser(
        description='Generated Pretrained Language Model',
    )
    cli.add_argument(
        '--output_key',
        default="vccm_base_epoch9",
        type=str,
        help='the name of output path',
    )
    return cli.parse_args()

args = get_args()

# Prefer the first GPU; fall back to CPU so the script still runs (slowly)
# on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# FACodec v2 encoder/decoder. These hyper-parameters must match the
# checkpoints loaded below: load_state_dict is strict by default, so a
# mismatch raises instead of silently loading partial weights.
# (The decoder's up_ratios are the encoder's in reverse order.)
fa_encoder_v2 = FACodecEncoderV2(
    ngf=32,
    up_ratios=[2, 4, 5, 5],
    out_channels=256,
).to(device)

fa_decoder_v2 = FACodecDecoderV2(
    in_channels=256,
    upsample_initial_channel=1024,
    ngf=32,
    up_ratios=[5, 5, 4, 2],
    vq_num_q_c=2,
    vq_num_q_p=1,
    vq_num_q_r=3,
    vq_dim=256,
    codebook_dim=8,
    codebook_size_prosody=10,
    codebook_size_content=10,
    codebook_size_residual=10,
    use_gr_x_timbre=True,
    use_gr_residual_f0=True,
    use_gr_residual_phone=True,
).to(device)


encoder_v2_ckpt = "/home/disk1/nips/speech/code/Amphion/checkpoints/ns3_facodec_encoder_v2.bin"
decoder_v2_ckpt = "/home/disk1/nips/speech/code/Amphion/checkpoints/ns3_facodec_decoder_v2.bin"

# map_location="cpu": deserialize on CPU whatever device the checkpoint was
# saved from. load_state_dict then copies the tensors into the parameters,
# which already live on `device`.
fa_encoder_v2.load_state_dict(torch.load(encoder_v2_ckpt, map_location="cpu"))
fa_decoder_v2.load_state_dict(torch.load(decoder_v2_ckpt, map_location="cpu"))

# This script only runs inference. nn.Module starts in training mode, so
# switch both models to eval mode to disable any training-only behaviour
# (e.g. dropout) inside them.
fa_encoder_v2.eval()
fa_decoder_v2.eval()


path = "/home/disk2/nips/Result/controlspeech/infer/" + args.output_key + "_samples/style_codec"

outpath = "/home/disk2/nips/Result/controlspeech/infer/" + args.output_key + "_samples/wav_promptself"

# `path` holds the generated codec-token files for this experiment;
# `outpath` receives the re-synthesized wavs.
# Start from an empty output directory. This is done in-process rather than
# via os.system, so output_key (a CLI argument) is never parsed by a shell:
# no command injection, and paths containing spaces work. As with the old
# `rm -r`, a failed removal (e.g. the directory does not exist yet) is ignored.
shutil.rmtree(outpath, ignore_errors=True)
os.makedirs(outpath, exist_ok=True)

# Generated codec files to decode, sorted so the processing order is
# reproducible (os.listdir order is arbitrary).
# NOTE: this name shadows the builtin input(); it is kept because the
# decoding loop below refers to it.
input = sorted(os.listdir(path))

# Roots of the raw corpora the prompt (reference) utterances are read from.
# Keep the trailing slash on vc_path_raw_emo: the loop builds prompt paths
# from it by plain string concatenation.
vc_path_raw_libri = "/home/disk2/nips/Data/audio"
vc_path_raw_emo = "/home/disk2/nips/Data/emotion/"

# Find the matching prompt audio: these utterance ids supply the speaker
# timbre applied to every generated utterance. The j-th id produces output
# files with the suffix "_{j+1}".
input_df = [
    "1353_121397_000097_000001",
    "2496_156083_000037_000000",
    "7255_291500_000014_000002",
    "6695_252334_000009_000002",
]



def _resolve_prompt_path(raw_id):
    """Map a prompt utterance id to the path of its source wav file.

    The id formats handled come from the string parsing below:
      * prefix "CREMA": split on '-'; fields 0 and 1 (re-joined with '-')
        form the corpus directory, then fields 2 and 3 form the subdirectory
        and file name. Presumably CREMA-D; confirm against the data layout.
      * prefix "RAVDE": split on the first two '-' into three path components.
        Presumably RAVDESS; confirm against the data layout.
      * first character is an ASCII digit: "<libri_root>/<speaker>/<id>.wav",
        where <speaker> is the text before the first '_'.
      * anything else: every '-' becomes a path separator under the emotion root.
    """
    if raw_id[:5] == "CREMA":
        parts = raw_id.split('-')
        return vc_path_raw_emo + parts[0] + '-' + parts[1] + '/' + parts[2] + '/' + parts[3] + '.wav'
    if raw_id[:5] == "RAVDE":
        parts = raw_id.split('-', 2)
        return vc_path_raw_emo + parts[0] + "/" + parts[1] + "/" + parts[2] + ".wav"
    if '0' <= raw_id[:1] <= '9':
        return os.path.join(vc_path_raw_libri, raw_id.split('_')[0] + "/" + raw_id + ".wav")
    return vc_path_raw_emo + raw_id.replace('-', '/') + ".wav"


def _prompt_speaker_embedding(wav_path):
    """Encode a prompt wav (resampled to 16 kHz) with FACodec v2 and return
    the speaker embedding from the last element of the decoder's output."""
    wav = librosa.load(wav_path, sr=16000)[0]
    wav = torch.from_numpy(wav).float().to(device)
    wav = wav.unsqueeze(0).unsqueeze(0)
    with torch.no_grad():
        enc_out = fa_encoder_v2(wav)
        prosody = fa_encoder_v2.get_prosody_feature(wav)
        _, _, _, _, spk_embs = fa_decoder_v2(
            enc_out, prosody, eval_vq=False, vq=True
        )
    return spk_embs


# A prompt's speaker embedding does not depend on which generated utterance
# it is paired with, so each prompt is resolved, loaded and encoded once,
# not once per generated file.
prompt_spk_embs = [
    _prompt_speaker_embedding(_resolve_prompt_path(raw_id)) for raw_id in input_df
]

for i, codec_name in enumerate(input):
    stem = codec_name.split('.')[0]

    # The codes (and their embedding) depend only on the generated file, so
    # they are loaded once and reused for every prompt speaker. Per the
    # original author's note, vq2emb expects [n_q, 1, t]; the stored tensor is
    # transposed to that layout (presumably stored as [t, n_q] — confirm).
    codes = torch.load(os.path.join(path, codec_name), map_location="cpu")
    vq_id_a = rearrange(codes, "n m -> m n").unsqueeze(1).to(device)

    with torch.no_grad():
        vq_post_emb_a = fa_decoder_v2.vq2emb(vq_id_a, use_residual=False)

        for j, spk_embs_b in enumerate(prompt_spk_embs):
            print(i, j)
            # Decode the generated content with prompt j's speaker timbre.
            recon_wav = fa_decoder_v2.inference(vq_post_emb_a, spk_embs_b)
            outpath_i = os.path.join(outpath, f"{stem}_{j + 1}.wav")
            sf.write(outpath_i, recon_wav[0][0].cpu().numpy(), 16000)