import librosa
import pandas as pd
from tools.speaker_verification.models.ecapa_tdnn import ECAPA_TDNN_SMALL
import os
import torch.nn.functional as F
import torch
from tqdm import tqdm

def init_model(model_name, checkpoint=None):
    """Create an ``ECAPA_TDNN_SMALL`` network, optionally restoring weights.

    Args:
        model_name: Accepted from the caller but never read by this
            function; the architecture is always ``ECAPA_TDNN_SMALL`` with
            ``feat_dim=40`` and ``feat_type='fbank'``.
        checkpoint: Path handed to ``torch.load``. When ``None`` the freshly
            constructed network is returned without loading anything.

    Returns:
        The constructed network, with the checkpoint's ``'model'`` entry
        loaded into it when ``checkpoint`` is given.
    """
    net = ECAPA_TDNN_SMALL(feat_dim=40, feat_type='fbank')
    if checkpoint is None:
        return net

    def keep_storage(storage, loc):
        # Hand every storage back unchanged so torch.load does not relocate it.
        return storage

    saved = torch.load(checkpoint, map_location=keep_storage)
    # strict=False: a key mismatch between checkpoint and network does not raise.
    net.load_state_dict(saved['model'], strict=False)
    return net

# Build the speaker-verification network once, move it to the default CUDA
# device and switch it to inference mode.
model = init_model(
    model_name='ecapa_tdnn',
    checkpoint='/root/s2st/tools/speaker_verification/ecapa-tdnn.pth'
).cuda().eval()
df = pd.read_csv('/data2/wenetspeech/train.tsv', sep='\t', quoting=3, doublequote=False)
df['confidences'] = None


def _embed_row(row):
    """Return the model output for the audio segment described by *row*.

    *row* must provide ``path`` (relative to the ``train_l`` directory,
    without the ``.wav`` suffix), ``offset`` and ``duration``; the latter two
    are forwarded to ``librosa.load`` to cut the segment out of the file.
    """
    wav_path = os.path.join('/data2/wenetspeech/train_l', row['path']) + '.wav'
    samples, _ = librosa.load(wav_path, sr=16000, offset=row['offset'], duration=row['duration'])
    # Add a leading batch dimension of size 1 before moving to the GPU.
    batch = torch.from_numpy(samples).unsqueeze(0).float().cuda()
    with torch.no_grad():
        return model(batch)


# Resolve the target column by name instead of assuming it is the last one:
# if the TSV already contained a 'confidences' column, the assignment above
# overwrote it in place and position -1 would point at a different column.
confidence_col = df.columns.get_loc('confidences')

# Each row is compared with the row directly before it, so row 0 keeps None.
# The previous row's embedding is carried across iterations; recomputing it
# would read and embed every segment twice.
prev_emb = _embed_row(df.iloc[0]) if len(df) > 1 else None
for i in tqdm(range(1, len(df))):
    emb = _embed_row(df.iloc[i])
    sim = F.cosine_similarity(prev_emb, emb)
    # .item() requires a single-element tensor, i.e. one score per pair.
    df.iloc[i, confidence_col] = sim.item()
    prev_emb = emb

    # Periodic checkpoint so a crash does not lose all progress.
    if i % 1000000 == 0:
        print(df.iloc[i-5:i])
        df.to_csv('/data2/wenetspeech/train_confidence.tsv', sep='\t', quoting=3, doublequote=False, index=False)
df.to_csv('/data2/wenetspeech/train_confidence.tsv', sep='\t', quoting=3, doublequote=False, index=False)

# import librosa
# import pandas as pd
# import os
# import torch
# import torch.nn.functional as F
# from tools.speaker_verification.models.ecapa_tdnn import ECAPA_TDNN_SMALL
# from concurrent.futures import ThreadPoolExecutor
# from functools import partial
# from tqdm import tqdm

# CHECKPOINT_PATH = '/root/s2st/tools/speaker_verification/ecapa-tdnn.pth'

# def init_model(checkpoint):
#     model = ECAPA_TDNN_SMALL(feat_dim=40, feat_type='fbank')
#     if checkpoint is not None:
#         state_dict = torch.load(checkpoint, map_location=lambda storage, loc: storage)
#         model.load_state_dict(state_dict['model'], strict=False)
#     return model

# def compute_similarity_chunk(thread_id, df_chunk):
#     device_id = thread_id % torch.cuda.device_count()
#     model = init_model(
#         checkpoint=CHECKPOINT_PATH
#     ).cuda(device_id).eval()
#     similarities = []
#     for i in tqdm(range(1, len(df_chunk))):
#         path_pre = os.path.join('/data2/wenetspeech/train_l', df_chunk.iloc[i-1]['path']) + '.wav'
#         path = os.path.join('/data2/wenetspeech/train_l', df_chunk.iloc[i]['path']) + '.wav'
#         wav_pre, sr = librosa.load(path_pre, sr=16000)
#         wav, sr = librosa.load(path, sr=16000)
#         wav1 = torch.from_numpy(wav_pre).unsqueeze(0).float().cuda(device_id)
#         wav2 = torch.from_numpy(wav).unsqueeze(0).float().cuda(device_id)
#         with torch.no_grad():
#             emb1 = model(wav1)
#             emb2 = model(wav2)
#         sim = F.cosine_similarity(emb1, emb2)
#         similarities.append(sim.item())
#     return similarities

# df = pd.read_csv('/data2/wenetspeech/train.tsv', sep='\t', quoting=3, doublequote=False)
# df['confidences'] = None


# # Split the data
# num_threads = 4  # adjust the number of threads as needed
# chunk_size = len(df) // num_threads
# chunks = [df.iloc[i * chunk_size : (i + 1) * chunk_size] for i in range(num_threads)]

# # Make sure every chunk except the first starts with the last row of the previous chunk
# for i in range(1, num_threads):
#     chunks[i] = pd.concat([df.iloc[(i * chunk_size) - 1: i * chunk_size], chunks[i]])

# # Use ThreadPoolExecutor
# with ThreadPoolExecutor(max_workers=num_threads) as executor:
#     futures = [executor.submit(compute_similarity_chunk, i, chunk) for i, chunk in enumerate(chunks)]
#     results = [future.result() for future in futures]

# # Merge the results
# for i in range(num_threads):
#     start_index = i * chunk_size + 1 if i > 0 else 1
#     df.iloc[start_index:start_index + len(results[i]), df.columns.get_loc('confidences')] = results[i]

# # Output or save the results
# df.to_csv('/data2/wenetspeech/train_confidence.tsv', sep='\t', quoting=3, doublequote=False, index=False)