# Efficient Speech Language Modeling via Energy Distance in Continuous Latent Space

## Usage
**We provide the training and inference code.**

### Installation
``` sh
cd SLED-TTS
pip install -e ./
```

We currently use the sum of the embeddings from the first 8 codebooks of [Encodec_24khz](https://huggingface.co/facebook/encodec_24khz) as the continuous latent vector. Before training or inference, make sure [Encodec_24khz](https://huggingface.co/facebook/encodec_24khz) has been downloaded into your Hugging Face cache directory.
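If you are unsure whether the model is already cached, the snippet below is a minimal sketch that prints where the Hugging Face Hub cache would keep `facebook/encodec_24khz` and reports whether that directory exists. It assumes the standard `huggingface_hub` cache layout (`$HF_HUB_CACHE`, else `$HF_HOME/hub`, else `~/.cache/huggingface/hub`, with `$XDG_CACHE_HOME` replacing `~/.cache` when set); `encodec_cache_dir` is a hypothetical helper, not part of this repository.

``` sh
# Hypothetical helper (not part of SLED-TTS): print the directory in which the
# Hugging Face Hub cache stores facebook/encodec_24khz, following the default
# lookup order HF_HUB_CACHE -> HF_HOME/hub -> XDG_CACHE_HOME/huggingface/hub.
encodec_cache_dir() {
    hf_home=${HF_HOME:-${XDG_CACHE_HOME:-$HOME/.cache}/huggingface}
    hub_cache=${HF_HUB_CACHE:-$hf_home/hub}
    echo "$hub_cache/models--facebook--encodec_24khz"
}

# Report whether the model directory is present in the cache.
if [ -d "$(encodec_cache_dir)" ]; then
    echo "Encodec_24khz is cached at $(encodec_cache_dir)"
else
    echo "Encodec_24khz is not cached yet (expected at $(encodec_cache_dir))"
fi
```

If it is missing, `huggingface-cli download facebook/encodec_24khz` (from the `huggingface_hub` package) fetches it into that cache.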


### Training

***Data Processing***

Convert the LibriHeavy manifests into a JSON Lines file in which every line is one cut (a Lhotse `MonoCut`) in the format shown below.
``` json
{"id": "large/10022/essayoncriticism_1505_librivox_64kb_mp3/essayoncriticism_01_pope_64kb_5", "start": 610.32, "duration": 19.76, "channel": 0, "supervisions": [{"id": "large/10022/essayoncriticism_1505_librivox_64kb_mp3/essayoncriticism_01_pope_64kb_5", "recording_id": "large/10022/essayoncriticism_1505_librivox_64kb_mp3/essayoncriticism_01_pope_64kb", "start": 0, "duration": 19.76, "channel": 0, "language": "English", "speaker": "10022", "text": "Hail! bards triumphant! born in happier days; Immortal heirs of universal praise! Whose honors with increase of ages grow, As streams roll down, enlarging as they flow; Nations unborn your mighty names shall sound, [193] And worlds applaud that must not yet be found!"}], "recording": {"id": "large/10022/essayoncriticism_1505_librivox_64kb_mp3/essayoncriticism_01_pope_64kb", "sources": [{"type": "file", "channels": [0], "source": "download/librilight/large/10022/essayoncriticism_1505_librivox_64kb_mp3/essayoncriticism_01_pope_64kb.flac"}], "sampling_rate": 16000, "num_samples": 10575221, "duration": 660.9513125, "channel_ids": [0]}, "type": "MonoCut"}
```
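Before training, it can help to confirm that every audio file referenced by the manifest exists. The function below is a hypothetical pure-shell sanity check (not part of this repository): it extracts the `"source"` path from each line with `sed` and reports lines whose file is missing under an audio root directory. It assumes that `source` paths are relative to that root and formatted as in the example above (one `"source": "..."` entry per line, with a space after the colon); a JSON-aware tool is more robust for other layouts.

``` sh
# Hypothetical sanity check (not part of SLED-TTS).
# Usage: check_manifest_audio MANIFEST [AUDIO_ROOT]
# Prints the number of problem lines to stdout and details to stderr;
# returns non-zero if any line lacks a source or points to a missing file.
check_manifest_audio() {
    manifest=$1
    audio_root=${2:-.}
    bad=0
    lineno=0
    # The "|| [ -n ... ]" clause also processes a last line lacking a newline.
    while IFS= read -r line || [ -n "$line" ]; do
        lineno=$((lineno + 1))
        # Extract the value of "source": "..." (the recording's audio path).
        src=$(printf '%s\n' "$line" | sed -n 's/.*"source": "\([^"]*\)".*/\1/p')
        if [ -z "$src" ]; then
            echo "line $lineno: no \"source\" field" >&2
            bad=$((bad + 1))
        elif [ ! -f "$audio_root/$src" ]; then
            echo "line $lineno: missing $audio_root/$src" >&2
            bad=$((bad + 1))
        fi
    done < "$manifest"
    echo "$bad"
    [ "$bad" -eq 0 ]
}
```

For example, `check_manifest_audio cuts.jsonl /path/to/data` (both names are placeholders) prints `0` when every referenced file is present.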

***Training the Offline Model***
``` sh
OUTPUT_DIR=./runs/libriheavy
mkdir -p $OUTPUT_DIR
LOG_FILE=${OUTPUT_DIR}/log

BATCH_SIZE=8
UPDATE_FREQ=8
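# WORLD_SIZE is the number of nodes (8 processes per node); RANK (this node's index),
# MASTER_ADDR and MASTER_PORT must also be set in the environment.
# Global batch = WORLD_SIZE * 8 * BATCH_SIZE * UPDATE_FREQ = 512, so the values above assume WORLD_SIZE=1.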

torchrun --nnodes ${WORLD_SIZE} --node_rank ${RANK} --nproc_per_node 8 --master_addr ${MASTER_ADDR} --master_port ${MASTER_PORT} \
    ./scripts/train_libriheavy.py \
    --training_cfg 0.1 \
    --num_hidden_layers 12 --diffloss_d 6 --noise_channels 128 \
    --dataloader_num_workers 8 \
    --dataloader_pin_memory True \
    --remove_unused_columns False \
    --label_names audio_inputs \
    --group_by_speech_length \
    --do_train \
    --do_eval \
    --eval_strategy steps \
    --eval_steps 10000 \
    --prediction_loss_only \
    --per_device_train_batch_size ${BATCH_SIZE} \
    --per_device_eval_batch_size 24 \
    --gradient_accumulation_steps ${UPDATE_FREQ} \
    --bf16 \
    --learning_rate 5e-4 \
    --weight_decay 0.01 \
    --adam_beta1 0.9 \
    --adam_beta2 0.999 \
    --adam_epsilon 1e-8 \
    --max_grad_norm 1.0 \
    --max_steps 300000 \
    --lr_scheduler_type "linear" \
    --warmup_steps 32000 \
    --logging_first_step \
    --logging_steps 100 \
    --save_steps 10000 \
    --save_total_limit 10 \
    --output_dir ${OUTPUT_DIR} \
    --report_to tensorboard \
    --disable_tqdm True \
    --ddp_timeout 3600 --overwrite_output_dir 2>&1 | tee ${LOG_FILE}
```
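The command reads `WORLD_SIZE` (number of nodes), `RANK` (this node's index), `MASTER_ADDR` and `MASTER_PORT` from the environment, and the comment fixes the global batch at 512 sequences. The sketch below, which assumes 8 GPUs per node as the command does, fills in single-node defaults and derives `UPDATE_FREQ` from the node count so the global batch stays at 512. `compute_update_freq` is a hypothetical helper, and port 29500 is simply torchrun's conventional default.

``` sh
# Single-node defaults (assumption: one machine with 8 GPUs); on a cluster,
# the scheduler usually provides these variables instead.
WORLD_SIZE=${WORLD_SIZE:-1}
RANK=${RANK:-0}
MASTER_ADDR=${MASTER_ADDR:-127.0.0.1}
MASTER_PORT=${MASTER_PORT:-29500}

# Hypothetical helper: print the UPDATE_FREQ for which
# NODES * PROCS_PER_NODE * BATCH_SIZE * UPDATE_FREQ == TARGET, or fail if
# TARGET is not divisible by NODES * PROCS_PER_NODE * BATCH_SIZE.
compute_update_freq() {
    target=$1 nodes=$2 procs=$3 batch=$4
    per_step=$((nodes * procs * batch))
    if [ $((target % per_step)) -ne 0 ]; then
        echo "global batch $target is not divisible by $per_step" >&2
        return 1
    fi
    echo $((target / per_step))
}

BATCH_SIZE=8
UPDATE_FREQ=$(compute_update_freq 512 "$WORLD_SIZE" 8 "$BATCH_SIZE")
echo "nodes=$WORLD_SIZE batch/gpu=$BATCH_SIZE UPDATE_FREQ=$UPDATE_FREQ"
```

With two nodes, for example, this yields `UPDATE_FREQ=4`; the rest of the launch command stays the same.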

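***Training the Streaming Model***

The streaming model is fine-tuned from the final offline checkpoint passed via `--finetune_path`, so train the offline model first.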
``` sh
OUTPUT_DIR=./runs/libriheavy_stream
mkdir -p $OUTPUT_DIR
LOG_FILE=${OUTPUT_DIR}/log

BATCH_SIZE=8
UPDATE_FREQ=8
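# WORLD_SIZE is the number of nodes (8 processes per node); RANK (this node's index),
# MASTER_ADDR and MASTER_PORT must also be set in the environment.
# Global batch = WORLD_SIZE * 8 * BATCH_SIZE * UPDATE_FREQ = 512, so the values above assume WORLD_SIZE=1.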

torchrun --nnodes ${WORLD_SIZE} --node_rank ${RANK} --nproc_per_node 8 --master_addr ${MASTER_ADDR} --master_port ${MASTER_PORT} \
    ./scripts/train_libriheavy_stream.py \
    --finetune_path ./runs/libriheavy/checkpoint-300000/model.safetensors \
    --stream_n 5 --stream_m 45 \
    --training_cfg 0.1 \
    --num_hidden_layers 12 --diffloss_d 6 --noise_channels 128 \
    --dataloader_num_workers 8 \
    --dataloader_pin_memory True \
    --remove_unused_columns False \
    --label_names audio_inputs \
    --group_by_speech_length \
    --do_train \
    --do_eval \
    --eval_strategy steps \
    --eval_steps 10000 \
    --prediction_loss_only \
    --per_device_train_batch_size ${BATCH_SIZE} \
    --per_device_eval_batch_size 24 \
    --gradient_accumulation_steps ${UPDATE_FREQ} \
    --bf16 \
    --learning_rate 3e-4 \
    --weight_decay 0.01 \
    --adam_beta1 0.9 \
    --adam_beta2 0.999 \
    --adam_epsilon 1e-8 \
    --max_grad_norm 1.0 \
    --max_steps 100000 \
    --lr_scheduler_type "linear" \
    --warmup_steps 10000 \
    --logging_first_step \
    --logging_steps 100 \
    --save_steps 10000 \
    --save_total_limit 10 \
    --output_dir ${OUTPUT_DIR} \
    --report_to tensorboard \
    --disable_tqdm True \
    --ddp_timeout 3600 --overwrite_output_dir 2>&1 | tee ${LOG_FILE}
```
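Because `--finetune_path` must point at weights produced by the offline run, a quick existence check avoids launching every rank only to fail on a missing file. `require_file` is a hypothetical helper; the path is the one used in the command above.

``` sh
# Hypothetical helper (not part of SLED-TTS): fail with a message if $1 is
# not a regular file.
require_file() {
    if [ ! -f "$1" ]; then
        echo "error: $1 not found" >&2
        return 1
    fi
}

FINETUNE_PATH=./runs/libriheavy/checkpoint-300000/model.safetensors
require_file "$FINETUNE_PATH" && echo "found $FINETUNE_PATH"
```

If the check fails, train the offline model first, or point `FINETUNE_PATH` at whichever offline checkpoint you want to start from.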

### Inference
The commands below share the following variables; `CFG` is the classifier-free guidance scale.
``` sh
CHECKPOINT=/path/to/checkpoint
CFG=2.0
SEED=0
```
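`CHECKPOINT` is a placeholder. Assuming you want the most recent `checkpoint-<step>` directory written by one of the training runs above (whether the inference scripts expect exactly that directory layout is an assumption here), the hypothetical helper below picks the directory with the largest step, comparing steps numerically so that `checkpoint-100000` ranks above `checkpoint-90000`.

``` sh
# Hypothetical helper (not part of SLED-TTS): print the checkpoint-<step>
# directory under $1 with the largest step, or nothing if there is none.
latest_checkpoint() {
    for d in "$1"/checkpoint-*; do
        [ -d "$d" ] || continue          # skip the literal pattern if nothing matched
        echo "${d##*checkpoint-} $d"     # emit "<step> <path>" for numeric sorting
    done | sort -n | tail -n 1 | cut -d' ' -f2-
}

CHECKPOINT=$(latest_checkpoint ./runs/libriheavy)
echo "CHECKPOINT=${CHECKPOINT:-<none found>}"
```

Point it at `./runs/libriheavy_stream` instead to pick up the latest checkpoint of the streaming run.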
***Offline Inference***
``` sh
python scripts/run_offline.py \
    --model_name_or_path ${CHECKPOINT} \
    --cfg ${CFG} \
    --input "My remark pleases him, but I soon prove to him that it is not the right way to speak. However perfect may have been the language of that ancient writer." \
    --seed ${SEED}
```
***Streaming Inference***
``` sh
python scripts/run_stream.py \
    --model_name_or_path ${CHECKPOINT} \
    --cfg ${CFG} \
    --input "My remark pleases him, but I soon prove to him that it is not the right way to speak. However perfect may have been the language of that ancient writer." \
    --seed ${SEED}
# Note: run_stream.py simulates streaming generation so that its output quality can be evaluated;
# the current code does not provide an actual streaming API.
```
***Voice Cloning***

To clone a voice, pass a reference recording with `--prompt_audio` and its transcript with `--prompt_text`.
``` sh
python scripts/run_voice_clone.py \
    --prompt_text "Were I in the warm room with all the splendor and magnificence!" \
    --prompt_audio "example_prompt.flac" \
    --model_name_or_path ${CHECKPOINT} \
    --cfg ${CFG} \
    --input "Perhaps the other trees from the forest will come to look at me!" \
    --seed ${SEED}
```