import os
import pickle
import torch
import numpy as np
from math import ceil
from model_vc import Generator
import librosa
import librosa.display
import pylab
import matplotlib.pyplot as plt
def pad_seq(x, base=32):
    """Zero-pad ``x`` along its first axis up to the next multiple of ``base``.

    Only axis 0 is padded (with zeros, at the end); axis 1 is left as-is.
    ``x`` must be 2-D because the pad-width spec below has two entries.

    Returns:
        A tuple ``(padded, pad_rows)`` where ``padded`` is the padded array
        and ``pad_rows`` is how many rows were appended, so callers can strip
        them off again later.
    """
    n_rows = x.shape[0]
    n_blocks = ceil(float(n_rows) / base)
    target_rows = int(base * n_blocks)
    pad_rows = target_rows - n_rows
    assert pad_rows >= 0
    pad_widths = ((0, pad_rows), (0, 0))
    padded = np.pad(x, pad_widths, 'constant')
    return padded, pad_rows


# --- Model setup -------------------------------------------------------------
# Fall back to CPU when CUDA is unavailable instead of crashing on a
# hard-coded 'cuda:0'.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
dim_content = 32
dim_speaker = 256
# NOTE(review): the 32 / 256 literals below look like they mirror
# dim_content / dim_speaker — confirm against model_vc.Generator's signature
# before replacing them with the named constants.
G = Generator(80, 32, 256, 512).eval().to(device)

# map_location remaps the saved tensors onto `device`, so the checkpoint loads
# regardless of which GPU (if any) it was saved from.
g_checkpoint = torch.load('ckpt/1000000.pth', map_location=device)
G.load_state_dict(g_checkpoint['model'])
path = '/home/ttsdev/nastts/AVCT/vctk_test_mel_16000'
# List of (label, spectrogram) pairs that is pickled to results.pkl at the end.
spect_vc = []

# --- Input selection ---------------------------------------------------------
source_utterance = 'p261_003.npy'
source_speaker = 'p261'
target_speaker = 'p268'
target_utterance = 'p268_003.npy'
x_org = np.load(os.path.join(path, source_speaker, source_utterance))
x_trg = np.load(os.path.join(path, target_speaker, target_utterance))
print("test x_org:", x_org.shape, type(x_org))
# Pad the source's first axis to a multiple of 32 (pad_seq's default);
# len_pad records how many rows to strip from the output afterwards.
# NOTE(review): x_trg is fed unpadded — confirm the Generator accepts
# arbitrary target lengths.
x_org, len_pad = pad_seq(x_org)

# Add a leading batch axis of size 1 before moving to the device.
uttr_org = torch.from_numpy(x_org[np.newaxis, :, :]).to(device)
uttr_trg = torch.from_numpy(x_trg[np.newaxis, :, :]).to(device)

# --- Inference ---------------------------------------------------------------
with torch.no_grad():
    x_identic_psnt = G(uttr_org, uttr_trg, dim_content)
    print("x_identic_psnt shape:", x_identic_psnt.shape)
    # Drop the padded tail. An explicit end index behaves like `:-len_pad`
    # but also covers len_pad == 0 (where `:-0` would yield an empty slice)
    # without a separate branch. The [0, 0, ...] indexing assumes a 4-D
    # output with time on axis 2 — TODO confirm against Generator.
    # The converted spectrogram deliberately reuses the name uttr_trg.
    n_keep = x_identic_psnt.shape[2] - len_pad
    uttr_trg = x_identic_psnt[0, 0, :n_keep, :].cpu().numpy()

spect_vc.append(('{}x{}'.format(source_speaker, target_speaker), uttr_trg))

with open('results.pkl', 'wb') as handle:
    pickle.dump(spect_vc, handle)
