import os
import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
from peft import PeftModel
import soundfile as sf
import torch

# --- Inference configuration -------------------------------------------------
# Prefer the first CUDA device, but fall back to the CPU so the script still
# runs (slowly) on machines without a usable GPU instead of crashing on `.to()`.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Identifier passed to `from_pretrained` for both the base model and tokenizer.
base_model_name = "parler-tts/parler-tts-mini-v1"

# Placeholder: replace with the path of the trained PEFT adapter to load.
peft_model_path = "SET YOUR PEFT MODEL PATH HERE"

# Placeholder: text that is tokenized and fed to the model as `input_ids`.
description = "PLEASE WRITE YOUR DESCRIPTION HERE."

# Placeholder: text that is tokenized and fed to the model as `prompt_input_ids`.
prompt = "PLEASE WRITE YOUR TRANSCRIPTION HERE."

# Name under which the adapter is registered on the PEFT model.
adapter_name_a = "lora_a"

# Destination path for the generated audio, derived from the adapter name.
output_file = f"./{adapter_name_a}_single.wav"

# Create the output directory before the expensive model load so that path
# problems surface early. `os.path.dirname` returns "" for a bare file name,
# and `os.makedirs("")` raises FileNotFoundError, hence the "." fallback.
os.makedirs(os.path.dirname(output_file) or ".", exist_ok=True)

# Load the base checkpoint onto the target device, plus its tokenizer.
base_model = ParlerTTSForConditionalGeneration.from_pretrained(base_model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(base_model_name)

# Tokenize the description; ids and mask are moved to the model's device.
tokenized_input = tokenizer(description, return_tensors="pt")
input_ids = tokenized_input.input_ids.to(device)
attention_mask = tokenized_input.attention_mask.to(device)

# Tokenize the prompt (transcription text) the same way.
tokenized_prompt = tokenizer(prompt, return_tensors="pt")
prompt_input_ids = tokenized_prompt.input_ids.to(device)
prompt_attention_mask = tokenized_prompt.attention_mask.to(device)

# Sampling settings forwarded verbatim to `model.generate`.
gen_kwargs = dict(do_sample=True, temperature=1.0)

# Wrap the base model with the adapter stored at `peft_model_path`,
# registering it under `adapter_name_a`.
peft_model = PeftModel.from_pretrained(
    base_model,
    peft_model_path,
    adapter_name=adapter_name_a,
)

# Explicitly select that adapter as the active one before merging.
peft_model.set_adapter(adapter_name_a)

# Merge the adapter into the wrapped model and drop the PEFT wrapper; the
# returned object is what is used for generation below.
model = peft_model.merge_and_unload()

print("Start generate...")
# Inference only: disable autograd bookkeeping while generating.
with torch.no_grad():
    generation = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        prompt_input_ids=prompt_input_ids,
        prompt_attention_mask=prompt_attention_mask,
        **gen_kwargs,
    )

print("Finish generate...")

# Move the result to host memory and drop singleton dimensions before writing.
audio_arr = generation.cpu().numpy().squeeze()

# Write to the configured `output_file` (whose directory is created earlier in
# the script) rather than a hard-coded name, so the adapter-specific path that
# was prepared is the one actually used.
sf.write(output_file, audio_arr, model.config.sampling_rate)