import os
import sys
import argparse
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname("."), "..")))
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
from peft import PeftModel
import soundfile as sf
import torch


# ---------------------------------------------------------------------------
# User-editable settings for the triple-LoRA fusion run.
# ---------------------------------------------------------------------------

# Prefer the first CUDA device; fall back to CPU so the script still runs
# (more slowly) on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Base checkpoint that the three LoRA adapters are attached to.
base_model_name = "parler-tts/parler-tts-mini-v1"

# Text tokenized and passed to generate() as `prompt_input_ids`.
prompt = "PLEASE WRITE YOUR TRANSCRIPTION HERE."

# Text tokenized and passed to generate() as `input_ids`.
description = "PLEASE WRITE YOUR DESCRIPTION HERE."

# Locations of the three PEFT/LoRA adapters to be fused.
lora_path_a = "SET YOUR PEFT MODEL A PATH HERE"
lora_path_b = "SET YOUR PEFT MODEL B PATH HERE"
lora_path_c = "SET YOUR PEFT MODEL C PATH HERE"

# Names under which the adapters are registered on the PEFT model; they are
# also embedded in the output file name below.
adapter_name_a = "LORA_A"
adapter_name_b = "LORA_B"
adapter_name_c = "LORA_C"

output_file = f"./{adapter_name_a}_{adapter_name_b}_{adapter_name_c}_triple.wav"

# Sampling temperature; forwarded to generate() through gen_kwargs.
temperature = 1.0

# Forwarded to add_weighted_adapter() as `density`.
# NOTE(review): per the PEFT docs this only matters for pruning-based
# combination types (e.g. "ties", "dare_*"), not for "cat" -- confirm.
density = 0.5

# One weight per adapter, in the order A, B, C.
adapter_weights = [0.5, 0.5, 0.5]

# How add_weighted_adapter() combines the adapters.
combination_type = "cat"

# Extra keyword arguments forwarded verbatim to generate().
gen_kwargs = {
    "do_sample": True,
    # Use the setting above instead of a duplicated literal so that editing
    # `temperature` actually changes generation.
    "temperature": temperature,
}

# Load the base checkpoint onto the target device, plus its tokenizer.
base_model = ParlerTTSForConditionalGeneration.from_pretrained(base_model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(base_model_name)

# Ensure the output directory exists. os.path.dirname() returns "" when
# output_file has no directory component, and os.makedirs("") raises
# FileNotFoundError, so only create the directory when there is one.
output_dir = os.path.dirname(output_file)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

# Tokenize the description; it is fed to generate() as input_ids.
tokenized_input = tokenizer(description, return_tensors="pt")
input_ids = tokenized_input.input_ids.to(device)
attention_mask = tokenized_input.attention_mask.to(device)

# Tokenize the prompt; it is fed to generate() as prompt_input_ids.
tokenized_prompt = tokenizer(prompt, return_tensors="pt")
prompt_input_ids = tokenized_prompt.input_ids.to(device)
prompt_attention_mask = tokenized_prompt.attention_mask.to(device)

# Attach adapter A while wrapping the base model, then register B and C on
# the same PEFT model under their own names.
fused_model = PeftModel.from_pretrained(base_model, lora_path_a, adapter_name=adapter_name_a)
fused_model.load_adapter(lora_path_b, adapter_name=adapter_name_b)
fused_model.load_adapter(lora_path_c, adapter_name=adapter_name_c)

# Build a new adapter that combines A, B and C with the configured weights.
weighted_adapter_name = "merge"
fused_model.add_weighted_adapter(
    adapters=[adapter_name_a, adapter_name_b, adapter_name_c],
    weights=adapter_weights,
    density=density,
    adapter_name=weighted_adapter_name,
    combination_type=combination_type,
)

# Activate the combined adapter before merging, then replace the PEFT wrapper
# with the model returned by merge_and_unload().
fused_model.set_adapter(weighted_adapter_name)
fused_model = fused_model.merge_and_unload()

# Inference only: no gradients are needed.
with torch.no_grad():
    fused_generation = fused_model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        prompt_input_ids=prompt_input_ids,
        prompt_attention_mask=prompt_attention_mask,
        **gen_kwargs,
    )

# Move the result to host memory, drop singleton dimensions, and write it out
# at the sampling rate taken from the model config.
fused_audio_arr = fused_generation.cpu().numpy().squeeze()
sf.write(output_file, fused_audio_arr, base_model.config.sampling_rate)