import os
import json
from more_itertools import chunked
from tqdm import tqdm
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
import scipy
import argparse
def sort_dataset_by_query_length(dataset, tokenizer):
    """Return the samples of ``dataset`` ordered from longest to shortest.

    The length of a sample is the number of ``input_ids`` that ``tokenizer``
    produces for the sample's ``'response'`` text. The sort is stable, so
    samples with the same token count keep their relative input order.
    A new list is returned; ``dataset`` itself is not reordered.
    """
    def response_token_count(sample):
        # Each sample is tokenized exactly once, in input order.
        return len(tokenizer(sample['response'])['input_ids'])

    return sorted(dataset, key=response_token_count, reverse=True)

if __name__ == '__main__':
    # Script entry point: synthesize one wav file per dataset record with a
    # Parler-TTS model. Each record's 'response' text is spoken using the
    # description of the speaker whose id is encoded in the record's 'file' name.
    # Records whose output file already exists in the save directory are skipped,
    # so an interrupted run can be resumed.

    # `import scipy` alone is not guaranteed to load the `scipy.io.wavfile`
    # submodule, so import it explicitly before it is used below.
    import scipy.io.wavfile

    data_name = ["drop", 'narrativeqa', 'quoref', 'ropes', 'squad1.1', 'squad2.0',  'tatqa'][0]
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default="../local_model/parler-tts-large-v1")
    parser.add_argument("--save_dir", type=str, default=f"../local_dataset/{data_name}/train_set")
    parser.add_argument("--input_jsonl_file_path", type=str, default="../data/train_set_tts/drop.json")
    parser.add_argument("--speaker_path", type=str, default="../speaker_id.json")
    # Inclusive range of speaker ids to synthesize in this run.
    parser.add_argument("--speaker_min_num", type=int, default=161)
    parser.add_argument("--speaker_max_num", type=int, default=200)
    # Number of records passed to a single `model.generate` call.
    parser.add_argument("--batch_size", type=int, default=1)
    args = parser.parse_args()
    if args.batch_size < 1:
        parser.error("--batch_size must be >= 1")

    save_path = args.save_dir
    # makedirs also creates missing parent directories and tolerates an
    # already-existing directory (os.mkdir does neither).
    os.makedirs(save_path, exist_ok=True)
    model_path = args.model_path
    dataset_path = args.input_jsonl_file_path

    model = ParlerTTSForConditionalGeneration.from_pretrained(model_path, attn_implementation="sdpa").to("cuda")
    tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_path)

    # A ".jsonl" input holds one JSON object per line; anything else is read
    # as a single JSON document.
    with open(dataset_path, 'r', encoding='utf-8') as file:
        if dataset_path.split(".")[-1] == "jsonl":
            # Skip blank lines (e.g. a trailing newline) instead of crashing.
            data = [json.loads(line) for line in file if line.strip()]
        else:
            data = json.load(file)

    with open(args.speaker_path, 'r', encoding='utf-8') as file:
        speaker = json.load(file)

    apply_speaker_range = args.speaker_min_num is not None and args.speaker_max_num is not None
    selected = []
    for d in data:
        # The speaker id is the number between the last "_" and ".wav" in the
        # record's file name.
        speaker_id = int(d['file'].split("_")[-1].replace(".wav", ""))
        if apply_speaker_range and not (args.speaker_min_num <= speaker_id <= args.speaker_max_num):
            continue
        # Attach the description of this record's own speaker. Looking it up
        # per record keeps description and record aligned even when the range
        # filter drops records in between.
        d['speaker'] = speaker[speaker_id]['description']
        selected.append(d)
    data = selected

    # Resume support: drop records whose output file was already written.
    data = [d for d in data if not os.path.isfile(f"{save_path}/{d['file']}")]

    # Longest responses first, so records of similar length share a batch.
    dataset = sort_dataset_by_query_length(data, tokenizer=tokenizer)

    # Materialize the batches so tqdm can show the total count.
    dataset = list(chunked(dataset, args.batch_size))
    for datas in tqdm(dataset):
        input_text = [d['response'] for d in datas]
        description = [d['speaker'] for d in datas]
        # The speaker description conditions the voice; the response text is
        # the prompt that gets spoken.
        inputs = tokenizer(description, return_tensors="pt", padding=True, return_attention_mask=True).to('cuda')
        prompt = tokenizer(input_text, return_tensors="pt", padding=True, return_attention_mask=True).to('cuda')
        generation = model.generate(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            prompt_input_ids=prompt.input_ids,
            prompt_attention_mask=prompt.attention_mask,
            do_sample=True,
            return_dict_in_generate=True,
        )

        for i, record in enumerate(datas):
            out_path = f"{save_path}/{record['file']}"

            # Trim the i-th sequence to its reported audio length.
            audio = generation.sequences[i, :generation.audios_length[i]]
            try:
                scipy.io.wavfile.write(out_path, rate=feature_extractor.sampling_rate,
                                       data=audio.cpu().numpy().squeeze())
            except Exception as exc:
                # Best effort: report the failing record and keep going, but
                # let KeyboardInterrupt/SystemExit propagate.
                print(f"Failed to write {out_path}: {exc!r}")
                print(record)
        # Release the generation output before the next batch is produced.
        del generation