import argparse
import os
import shutil
import sys

# The Amphion checkout must be importable before `models.codec` is imported.
sys.path.append("/home/disk1/nips/speech/code/Amphion")

import librosa
import pandas as pd
import soundfile as sf
import torch
from einops import rearrange

from models.codec.ns3_codec import FACodecEncoderV2, FACodecDecoderV2

def get_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``argparse`` falls back to ``sys.argv[1:]``, which
            is the behaviour callers of ``get_args()`` already rely on.

    Returns:
        argparse.Namespace with a single attribute, ``output_key``: the
        experiment key that the script splices into its input and output
        directory names.
    """
    parser = argparse.ArgumentParser(description='Generated Pretrained Language Model')
    parser.add_argument('--output_key',
                        type=str,
                        default="vccm_base_epoch9",
                        help='the name of output path')

    return parser.parse_args(argv)

# Parse the command line once; args.output_key selects the experiment
# directories used further down in the script.
args = get_args()

# Every model and tensor in this script is placed on the first CUDA device.
device = torch.device('cuda:0')

# Encoder / decoder hyper-parameters must match the checkpoints loaded below.
fa_encoder_v2 = FACodecEncoderV2(
    ngf=32,
    up_ratios=[2, 4, 5, 5],
    out_channels=256,
).to(device)

fa_decoder_v2 = FACodecDecoderV2(
    in_channels=256,
    upsample_initial_channel=1024,
    ngf=32,
    up_ratios=[5, 5, 4, 2],
    vq_num_q_c=2,
    vq_num_q_p=1,
    vq_num_q_r=3,
    vq_dim=256,
    codebook_dim=8,
    codebook_size_prosody=10,
    codebook_size_content=10,
    codebook_size_residual=10,
    use_gr_x_timbre=True,
    use_gr_residual_f0=True,
    use_gr_residual_phone=True,
).to(device)


encoder_v2_ckpt = "/home/disk1/nips/speech/code/Amphion/checkpoints/ns3_facodec_encoder_v2.bin"
decoder_v2_ckpt = "/home/disk1/nips/speech/code/Amphion/checkpoints/ns3_facodec_decoder_v2.bin"

# map_location makes loading independent of the device index the checkpoints
# were saved from (e.g. a checkpoint written on cuda:3).
fa_encoder_v2.load_state_dict(torch.load(encoder_v2_ckpt, map_location=device))
fa_decoder_v2.load_state_dict(torch.load(decoder_v2_ckpt, map_location=device))

# This script only runs inference: leave training mode so that any
# training-only behaviour inside the modules (such as dropout) is disabled
# and the outputs are deterministic.
fa_encoder_v2.eval()
fa_decoder_v2.eval()


# Directory holding the codec tensors to convert, and the directory the
# synthesized wavs are written to. Both are selected by --output_key.
path = "/home/disk2/nips/Result/controlspeech/infer/"+args.output_key+"_samples/style_codec"

outpath = "/home/disk2/nips/Result/controlspeech/infer/"+args.output_key+"_samples/wav_prompt"

# Text file with one "<id>|<item_name>" record per line.
index_id_path = "/home/disk2/nips/Data/2024nips/0509/emo_idlist_all"
with open(index_id_path,'r') as fin:
    index_id_list = fin.readlines()

# Maps the first '|'-separated field of each line to the second one.
index_id_test_list = {}

for line in index_id_list:
    # A blank (e.g. trailing) line has no second field; skip it instead of
    # failing with IndexError.
    if not line.strip():
        continue
    index_id_test_list[line.split('|')[0]] = line.strip().split('|')[1]


# Start from an empty output directory. Done through the standard library
# rather than a shell command so that paths containing spaces or shell
# metacharacters (the path embeds a command-line value) are handled safely.
# A missing directory is not an error, matching the previous best-effort
# behaviour of `rm -r`.
shutil.rmtree(outpath, ignore_errors=True)
os.makedirs(outpath, exist_ok=True)

# breakpoint()

# File names of all codec tensors found in the input directory.
input = os.listdir(path)

# Root directories of the raw prompt audio.
vc_path_raw_emo = "/home/disk2/nips/Data/emotion/"
vc_path_raw_libri = "/home/disk2/nips/Data/audio"

# Find the prompt audio that corresponds to each test item. The two CSV
# files are read as row-aligned tables: row k of the prompt CSV is taken as
# the prompt for row k of the test CSV.
input_test_csv_path = "/home/disk2/nips/Data/2024nips/0509/libri_test_addid.csv"
input_test_csv_prompt_path = "/home/disk2/nips/Data/2024nips/0509/libri_test_prompt.csv"

input_test_csv_prompt = pd.read_csv(input_test_csv_prompt_path)
input_test_csv = pd.read_csv(input_test_csv_path)

input_test_prompt = input_test_csv_prompt["item_name"]
input_test = input_test_csv["item_name"]

# prompt_dict:          test item name -> prompt item name
# prompt_id_index_dict: test item name -> value of the "new_id" column
prompt_dict = {}
prompt_id_index_dict = {}

for row, item_name in enumerate(input_test):
    prompt_dict[item_name] = input_test_prompt[row]
    prompt_id_index_dict[item_name] = input_test_csv["new_id"][row]

def _prompt_wav_path(prompt_id):
    """Return the wav path of the prompt utterance named ``prompt_id``.

    The directory layout depends on the corpus the name belongs to, which is
    recognised from the name itself:

    * names starting with "CREMA": split on every '-'; the first two fields
      form the top directory, the third a sub-directory, the fourth the file.
    * names starting with "RAVDE": split on the first two '-' only, giving
      directory / sub-directory / file.
    * names starting with an ASCII digit: stored under ``vc_path_raw_libri``
      in a directory named after the text before the first '_'.
    * anything else: every '-' in the name is a directory separator below
      ``vc_path_raw_emo``.
    """
    if prompt_id[:5] == "CREMA":
        parts = prompt_id.split('-')
        return vc_path_raw_emo + parts[0] + '-' + parts[1] + '/' + parts[2] + '/' + parts[3] + '.wav'
    if prompt_id[:5] == "RAVDE":
        parts = prompt_id.split('-', 2)
        return vc_path_raw_emo + parts[0] + "/" + parts[1] + "/" + parts[2] + ".wav"
    if '0' <= prompt_id[0] <= '9':
        return os.path.join(vc_path_raw_libri, prompt_id.split('_')[0] + "/" + prompt_id + ".wav")
    return vc_path_raw_emo + prompt_id.replace('-', '/') + ".wav"


# For every stored codec tensor: look up its prompt utterance, extract the
# speaker embedding from the prompt wav, and decode the stored codes with
# that embedding.
for i, codec_name in enumerate(input):
    print(i)

    # Text before the first '.' identifies the utterance.
    stem = codec_name.split('.')[0]

    # The stored tensor is 2-D; it is transposed and given a singleton middle
    # axis. The original author's note gives the result as [n_q, 1, t].
    test_wav_path = os.path.join(path, codec_name)
    vq_id_a = rearrange(
        torch.load(test_wav_path, map_location=device), "n m -> m n"
    ).unsqueeze(1).to(device)

    # Two naming schemes are handled. Stems of at most 15 characters are
    # treated as index ids that must first be translated to an item name;
    # longer stems are used as item names directly.
    if len(stem) <= 15:
        item_name = index_id_test_list[stem]
        input_raw_id = prompt_dict[item_name]
        # str() keeps this working when pandas parsed "new_id" as a number.
        outpath_i = os.path.join(outpath, str(prompt_id_index_dict[item_name]) + ".wav")
    else:
        input_raw_id = prompt_dict[stem]
        outpath_i = os.path.join(outpath, stem + ".wav")

    # Locate the prompt wav that provides the timbre. The already resolved
    # prompt name is used for every corpus, so index-id stems no longer
    # trigger a second lookup of the raw stem in prompt_dict.
    vc_path_input = _prompt_wav_path(input_raw_id)

    wav_b = librosa.load(vc_path_input, sr=16000)[0]
    wav_b = torch.from_numpy(wav_b).float().to(device)
    # Add two leading singleton axes before feeding the encoder.
    wav_b = wav_b.unsqueeze(0).unsqueeze(0)

    with torch.no_grad():
        enc_out_b = fa_encoder_v2(wav_b)
        prosody_b = fa_encoder_v2.get_prosody_feature(wav_b)

        # Only the last of the five decoder outputs (the speaker embedding
        # of the prompt) is needed here.
        _, _, _, _, spk_embs_b = fa_decoder_v2(
            enc_out_b, prosody_b, eval_vq=False, vq=True
        )

        # Turn the stored codes into embeddings and synthesize them with the
        # prompt's speaker embedding.
        vq_post_emb_a_to_b = fa_decoder_v2.vq2emb(vq_id_a, use_residual=False)
        recon_wav_a_to_b = fa_decoder_v2.inference(vq_post_emb_a_to_b, spk_embs_b)

        sf.write(outpath_i, recon_wav_a_to_b[0][0].cpu().numpy(), 16000)
