import os.path as osp
import pandas as pd
from tqdm import tqdm
import sacrebleu
import librosa
import sys

import tools.Whisper as Whisper
# Base directory for pretrained checkpoints; "<PRETRAINED_ROOT>/whisper/" is
# handed to Whisper.load_model as its download_root.
# NOTE(review): with the empty default this evaluates to "/whisper/" (filesystem
# root) -- set PRETRAINED_ROOT to a real directory before running.
PRETRAINED_ROOT = ''

# Keyword arguments for the loader, collected in one place for readability.
_load_kwargs = {
    'download_root': f"{PRETRAINED_ROOT}/whisper/",
    'device': 'cuda',
}

# The "large" checkpoint, placed on the GPU; used for all decoding below.
model = Whisper.load_model("large", **_load_kwargs)

# NOTE(review): `language` is not read anywhere in the visible script (the
# decoding step passes language='en' directly); kept so the name stays defined.
language = 'fr' # 'en', 'zh'

# Expand "~" explicitly. pandas resolves it on its own, but the audio paths
# derived from `hyps_path` are passed to librosa, which opens paths literally
# and would fail on a leading "~".
table_path = osp.expanduser('~/workspace/data/cvss/covost_v2.fr_en.test.tsv')
hyps_path = osp.expanduser('~/workspace/test/processed/1b')

# quoting=3 is csv.QUOTE_NONE: quote characters in the TSV are treated as
# ordinary text. on_bad_lines='error' aborts on malformed rows.
table = pd.read_csv(table_path, sep='\t', on_bad_lines='error', quoting=3, doublequote=False, encoding='utf-8')

# Evaluate only the first 300 rows of the table.
table = table[:300]

# Each hypothesis audio file is "<hyps_path>/<path column value>.wav".
paths = [osp.join(hyps_path, name) + '.wav' for name in table['path']]
# Reference texts come from the 'translation' column.
labels = table['translation'].to_list()

# Decoding settings are the same for every clip, so build them once instead of
# on each iteration. Change `language` here to decode in another language.
options = Whisper.DecodingOptions(language='en', beam_size=5, without_timestamps=True, fp16=True)

# Transcripts, in the same order as `paths`.
res = []
for audio_path in tqdm(paths, total=len(paths)):
    # Load the clip resampled to 16 kHz; the returned sample rate is not needed.
    audio, _sr = librosa.load(audio_path, sr=16000)
    # presumably pads/truncates to the model's fixed input length -- confirm
    # against tools.Whisper.pad_or_trim
    audio = Whisper.pad_or_trim(audio)
    # 128-bin log-mel features, moved to the model's device.
    mel = Whisper.log_mel_spectrogram(audio, n_mels=128).to(model.device)
    result = Whisper.decode(model, mel, options)
    res.append(result.text.strip())

# Corpus-level BLEU of all transcripts against a single reference stream.
bleu = sacrebleu.corpus_bleu(res, [labels])
print(bleu.score)

# Per-utterance BLEU, one score per (transcript, reference) pair.
sentence_bleu = []
for hyp, ref in zip(res, labels):
    sentence_bleu.append(sacrebleu.sentence_bleu(hyp, [ref]).score)

# One output row per evaluated clip: source path, transcript, reference, score.
_columns = {
    'path': table['path'],
    'hyps': res,
    'refs': labels,
    'bleu': sentence_bleu,
}
output_tsv = pd.DataFrame(_columns)
# quoting=3 is csv.QUOTE_NONE: fields are written without any quoting.
output_tsv.to_csv(f'{hyps_path}/out.tsv', sep='\t', index=False, quoting=3, doublequote=False)
