import torchaudio
import os
import resampy
import soundfile as sf


# Target sample rate (Hz) for the resampling steps. In the code visible here it
# is referenced only by resample calls that are currently commented out.
resample_rate = 16000
# Running counter used to number the exported files ('<idx>.wav').
idx = 0
# for dirname in os.listdir('/nas/datasets/SLAKH/slakh2100/test'):
#     print(idx)
#     onepath = os.sep.join(['/nas/datasets/SLAKH/slakh2100/test', dirname])
#     mixture = None
#     if os.path.exists(os.sep.join([onepath, 'bass.wav'])):
#         bass, sample_rate = torchaudio.load(os.sep.join([onepath, 'bass.wav']))
#         if mixture is None:
#             mixture = bass
#         else:
#             mixture += bass
#     if os.path.exists(os.sep.join([onepath, 'drums.wav'])):
#         drums, sample_rate = torchaudio.load(os.sep.join([onepath, 'drums.wav']))
#         if mixture is None:
#             mixture = drums
#         else:
#             if drums.shape[-1] != mixture.shape[-1]:
#                 min_len = min(drums.shape[-1], mixture.shape[-1])
#                 mixture = mixture[:, 0:min_len] + drums[:, 0:min_len]
#             else:
#                 mixture += drums
#     if os.path.exists(os.sep.join([onepath, 'guitar.wav'])):
#         guitar, sample_rate = torchaudio.load(os.sep.join([onepath, 'guitar.wav']))
#         if mixture is None:
#             mixture = guitar
#         else:
#             if guitar.shape[-1] != mixture.shape[-1]:
#                 min_len = min(guitar.shape[-1], mixture.shape[-1])
#                 mixture = mixture[:, 0:min_len] + guitar[:, 0:min_len]
#             else:
#                 mixture += guitar
#     if os.path.exists(os.sep.join([onepath, 'piano.wav'])):
#         piano, sample_rate = torchaudio.load(os.sep.join([onepath, 'piano.wav']))
#         if mixture is None:
#             mixture = piano
#         else:
#             if piano.shape[-1] != mixture.shape[-1]:
#                 min_len = min(piano.shape[-1], mixture.shape[-1])
#                 mixture = mixture[:, 0:min_len] + piano[:, 0:min_len]
#             else:
#                 mixture += piano
#     # print(sample_rate)
#     mixture = torchaudio.functional.resample(mixture, sample_rate, resample_rate)
#     torchaudio.save('/nas/datasets/SLAKH/fad/background/{}.wav'.format(idx), mixture, resample_rate)
#     idx += 1

# Export loop: copy the 'mixture.wav' found in every sub-directory of
# `work_path` into the flat directory `save_path`, naming the copies
# '0.wav', '1.wav', ... in processing order.
idx = 0
work_path = '/nas/datasets/SLAKH/output/partial_generating/B/opt_no_restrict/sum/'
# work_path = '/nas/datasets/SLAKH/output/partial_generating/B/opt_0.05/sum/'
save_path = '/nas/datasets/SLAKH/fad/opt_no_restrict'
# save_path = '/nas/datasets/SLAKH/fad/reproduce_96'
# save_path = '/nas/datasets/SLAKH/fad/eval_B_msdm'

# exist_ok=True replaces the separate os.path.exists() check, which could race
# with another process creating the directory between the check and the call.
os.makedirs(save_path, exist_ok=True)

# os.listdir() yields entries in arbitrary order; sorting keeps the mapping
# from source directory to output index reproducible across runs.
for dirname in sorted(os.listdir(work_path)):
    onepath = os.path.join(work_path, dirname)
    mixture_file = os.path.join(onepath, 'mixture.wav')
    if not os.path.isfile(mixture_file):
        # A stray file or an incomplete output directory would otherwise make
        # sf.read() raise and abort the whole export; skip it without
        # consuming an index.
        print('skipping {}: no mixture.wav found'.format(onepath))
        continue

    print(idx)
    # if idx == 96:
    #     break
    mixture, sample_rate = sf.read(mixture_file, dtype="float32")
    # Resampling is disabled, so the copy is written at the source file's own
    # sample rate rather than at `resample_rate`.
    # mixture = resampy.resample(mixture, sample_rate, resample_rate)
    # if idx < 1500:
    sf.write(os.path.join(save_path, '{}.wav'.format(idx)), mixture, sample_rate)

    # Disabled alternative: export the ground-truth mixture instead.
    # mixture, sample_rate = sf.read(os.sep.join([onepath, 'gt_mixture.wav']), dtype="float32")
    # # mixture = resampy.resample(mixture, sample_rate, resample_rate)
    # sf.write('/nas/datasets/SLAKH/fad/background/{}.wav'.format(idx), mixture, sample_rate)
    idx += 1