import os
import json
from typing import Iterable
import torch
import argparse

import jsonlines
from typing import Dict
import nemo.collections.asr as nemo_asr
from more_itertools import chunked
from tqdm import tqdm
from nemo.collections.asr.models import EncDecMultiTaskModel

def write_output_jsonl(output_file, data: Dict):
    """Append ``data`` as one record to the JSON Lines file ``output_file``.

    The file is opened in append mode for this single write and closed
    again before returning, whether or not the write succeeds.
    """
    sink = jsonlines.open(output_file, mode='a')
    try:
        sink.write(data)
    finally:
        sink.close()


def parakeet_bach_infer(model_name: str, dataset: Iterable, output_path: str, batch_size: int = 16):
    """Transcribe audio files with a NeMo ASR model and append the results to a JSONL file.

    One ``{"file": <input item>, "text": <transcription>}`` record is appended
    to ``output_path`` per input. If ``output_path`` already exists, inputs whose
    ``"file"`` value is already recorded there are skipped, so an interrupted
    run can be resumed by calling the function again with the same arguments.

    Args:
        model_name: Pretrained model identifier. ``'nvidia/canary-1b'`` is loaded
            as an ``EncDecMultiTaskModel`` with beam size 1; any other name is
            loaded as an ``EncDecRNNTBPEModel``.
        dataset: Iterable of items accepted by ``model.transcribe`` (the script
            entry point passes audio file paths). It is materialised into a list
            once, so generators are supported.
        output_path: JSONL file that results are appended to.
        batch_size: Batch size forwarded to ``model.transcribe``.

    Returns:
        None. When nothing is left to transcribe the function returns before
        loading the model and leaves ``output_path`` untouched.
    """
    import warnings
    warnings.filterwarnings("ignore", category=FutureWarning)

    # Materialise once: the inputs are consumed by ``transcribe`` and then
    # walked a second time when pairing them with the results.
    pending = list(dataset)

    # Resume support. Done before the (expensive) model load so that an
    # already-finished run costs nothing.
    if os.path.isfile(output_path):
        with open(output_path, 'r', encoding='utf-8') as file:
            # Blank lines are skipped so a stray trailing newline cannot crash the resume.
            done = {json.loads(line)['file'] for line in file if line.strip()}
        pending = [item for item in pending if item not in done]

    if not pending:
        return

    is_canary = model_name == 'nvidia/canary-1b'
    if is_canary:
        model = EncDecMultiTaskModel.from_pretrained(model_name=model_name)
        decode_cfg = model.cfg.decoding
        decode_cfg.beam.beam_size = 1
        model.change_decoding_strategy(decode_cfg)
    else:
        model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(model_name=model_name)

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model.to(device)

    result = model.transcribe(pending, batch_size=batch_size)
    # The canary result is used as-is; for the RNNT model element 0 of the
    # return value is used.
    # NOTE(review): presumably element 0 is the best-hypothesis list of a
    # (best, all) tuple - confirm against the installed NeMo version.
    texts = result if is_canary else result[0]
    for text, item in zip(texts, pending):
        write_output_jsonl(output_file=output_path, data={"file": item, "text": text})


if __name__ == '__main__':
    # Dataset whose name is baked into the default input/output paths below;
    # change the trailing index to select a different one.
    data_name = ["drop", 'narrativeqa', 'quoref', 'ropes', 'squad1.1', 'squad2.0',  'tatqa'][0]
    parser = argparse.ArgumentParser()
    parser.add_argument("--audio_file_folder", type=str,
                        default=f"../local_dataset/chatqa_speech/{data_name}/train_set")
    parser.add_argument("--model_name", type=str, default="nvidia/canary-1b",
                        choices=['nvidia/canary-1b','nvidia/parakeet-tdt-1.1b'])
    # The flag spelling ("bach") is kept as-is so existing command lines keep working.
    parser.add_argument("--bach_size", type=int, default=32)
    parser.add_argument("--output_path",type=str,default=f"../speech_llm/data/chatqa_speech/original_check/{data_name}/{data_name}_check_new_canary.jsonl")
    args = parser.parse_args()
    output_path = args.output_path
    # Select by extension rather than by substring, so files such as
    # "x.wav.json" or "wavelet.txt" are not sent to the ASR model. Sorting
    # makes the processing order reproducible (os.listdir order is arbitrary).
    wav_names = sorted(d for d in os.listdir(args.audio_file_folder) if d.lower().endswith('.wav'))
    # Paths are built with plain "/" concatenation on purpose: the resulting
    # string is what gets stored in the "file" field and compared on resume,
    # so it must stay identical to what earlier runs wrote.
    data_list = [args.audio_file_folder + "/" + d for d in wav_names]
    parakeet_bach_infer(model_name=args.model_name, dataset=data_list,
                        output_path=output_path, batch_size=args.bach_size)