import pandas as pd
import librosa
import os.path as osp
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import sys
sys.path.append(osp.expanduser('~/workspace/code'))  # sys.path entries are not tilde-expanded by the import system
from tools.silero_vad.utils_vad import get_speech_timestamps


# Directories are tilde-expanded up front: librosa.load does not expand '~' itself.
SRC_CLIP_DIR = osp.expanduser('~/workspace/data/cvss/extracted/fr/clips')
TGT_WAV_DIR = osp.expanduser('~/workspace/test/processed/1b')

table_path = '~/workspace/data/cvss/covost_v2.fr_en.test.tsv'
# quoting=3 is csv.QUOTE_NONE: quote characters in the TSV are read as literal text.
table = pd.read_csv(osp.expanduser(table_path), sep='\t', on_bad_lines='error', quoting=3,
                    doublequote=False, encoding='utf-8')
# Evaluate only the first 300 rows; .copy() gives an independent frame for the new columns.
table = table.iloc[:300].copy()
table['src_path'] = table['path'].map(lambda name: f'{SRC_CLIP_DIR}/{name}')
# Target files are named after the source clip with an extra '.wav' suffix appended.
table['tgt_path'] = table['path'].map(lambda name: f'{TGT_WAV_DIR}/{name}.wav')


def load_vad_model(vad_model_path='~/workspace/code/tools/silero_vad'):
    """Load the Silero VAD model from a local checkout of its torch.hub repo.

    Args:
        vad_model_path: directory of the local silero_vad hub repo. ``~`` is
            expanded; a relative path is resolved against the parent of the
            directory containing this script.

    Returns:
        The model returned by ``torch.hub.load``; the accompanying utils are
        discarded because this script imports ``get_speech_timestamps`` directly.
    """
    vad_model_path = osp.expanduser(vad_model_path)
    if not osp.isabs(vad_model_path):
        root = osp.dirname(osp.dirname(osp.abspath(__file__)))
        vad_model_path = osp.join(root, vad_model_path)
    # force_reload only affects GitHub downloads; with source='local' the directory is used as-is.
    model, _utils = torch.hub.load(repo_or_dir=vad_model_path,
                                   model='silero_vad',
                                   source='local')
    return model

SAMPLE_RATE = 16000  # every clip is resampled to this rate, and the VAD is told the same rate

# Prefer the GPU when present, but keep the script usable on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = load_vad_model().to(device)


def speech_span_seconds(audio):
    """Return seconds from the first detected speech start to the last speech end.

    Returns None when the VAD reports no speech segments at all.
    """
    timestamps = get_speech_timestamps(torch.tensor(audio, device=device), model,
                                       sampling_rate=SAMPLE_RATE)
    if not timestamps:
        return None
    # start/end are assumed to be sample offsets, hence the division by SAMPLE_RATE.
    return (timestamps[-1]['end'] - timestamps[0]['start']) / SAMPLE_RATE


src_durations = []
tgt_durations = []
for src_path, tgt_path in tqdm(zip(table['src_path'], table['tgt_path']), total=len(table)):
    src_audio, _ = librosa.load(src_path, sr=SAMPLE_RATE)
    tgt_audio, _ = librosa.load(tgt_path, sr=SAMPLE_RATE)
    try:
        src_duration = speech_span_seconds(src_audio)
        tgt_duration = speech_span_seconds(tgt_audio)
    except Exception as err:
        # Best effort: one failing pair should not abort the evaluation, but it is reported.
        # Catching Exception (not a bare except) lets KeyboardInterrupt still stop the run.
        print(f'VAD failed for {src_path} / {tgt_path}: {err!r}; using full clip lengths',
              file=sys.stderr)
        src_duration = tgt_duration = None
    if src_duration is None or tgt_duration is None:
        # Fall back to untrimmed lengths for BOTH clips so the ratio compares like with like.
        src_duration = len(src_audio) / SAMPLE_RATE
        tgt_duration = len(tgt_audio) / SAMPLE_RATE
    src_durations.append(src_duration)
    tgt_durations.append(tgt_duration)


duration_ratios = np.array(tgt_durations) / np.array(src_durations)
print(duration_ratios.tolist())


def speech_length_compliance(ratios, low, high):
    """Return the fraction of ``ratios`` lying strictly between ``low`` and ``high``.

    An empty ``ratios`` yields nan (mean of an empty array). Non-finite ratios
    (e.g. from a zero-length source clip) count as non-compliant.
    """
    ratios = np.asarray(ratios)
    return ((low < ratios) & (ratios < high)).mean()


# Histogram of target/source duration ratios: 20 bins of width 0.1 covering [0, 2].
# np.linspace includes the end point, whereas np.arange(0, 2, 0.1) stops at 1.9.
# Ratios outside [0, 2] (and nan/inf) are not drawn.
plt.hist(duration_ratios, bins=np.linspace(0, 2, 21))
plt.xlabel('target / source speech duration')
plt.ylabel('count')
plt.savefig('duration_ratios.png')
plt.close()

# SLC_2 / SLC_4: share of pairs whose duration ratio lies within (0.8, 1.2) / (0.6, 1.4).
SLC_2 = speech_length_compliance(duration_ratios, 0.8, 1.2)
SLC_4 = speech_length_compliance(duration_ratios, 0.6, 1.4)

print(f"SLC_2: {SLC_2}")
print(f"SLC_4: {SLC_4}")