import os

import cv2
from insightface.app import FaceAnalysis
import torch

# Face detector / recognizer used to embed every generated frame.
# The same GPU index feeds onnxruntime's CUDA provider option and
# insightface's ctx_id so both refer to one device.
_GPU_ID = 0
app = FaceAnalysis(
    root='/root/models/insightface_models/',
    providers=['CUDAExecutionProvider'],
    provider_options=[{'device_id': _GPU_ID}],
)
app.prepare(ctx_id=_GPU_ID, det_size=(512, 512))

# --- Evaluation inputs -------------------------------------------------------
# ground_truth_dir: one sub-directory per sample, each holding 'fi_0_fid.pth'.
# generated_dir:    one '<sample>.mp4' per sample.
# To evaluate another run, swap in one of the commented alternatives below.

# TalkingHead-1KH (active)
ground_truth_dir = '/root/programs/TalkingHead-1KH/selected_test_samples/'
generated_dir = '/root/programs/Moore-AnimateAnyone/output/talkinghead_1kh/184000/'
#   alt generated_dir (Wav2Lip):
#   generated_dir = '/root/programs/Wav2Lip/output/talkinghead_1kh/'

# AVSpeech (inactive)
#   ground_truth_dir = '/root/datasets/AVSpeech/test/processed/'
#   generated_dir = '/root/programs/Wav2Lip/output/avspeech/'
#   generated_dir = '/root/programs/Moore-AnimateAnyone/output/avspeech/184000/'

# Identity-preservation metric: for every sample, compare the face embedding of
# each generated frame against the sample's stored reference embedding
# ('fi_0_fid.pth', presumably an insightface embedding of the source identity --
# confirm against the preprocessing script) and average the cosine similarities.
# Frames where the detector finds zero or several faces score 0.
dataset_sims = []  # per-sample mean similarity, in processing order
counter = 0        # number of samples actually evaluated so far
# sorted() makes the per-sample print order reproducible across runs.
for sample in sorted(os.listdir(ground_truth_dir)):
    sample_dir = os.path.join(ground_truth_dir, sample)
    if not os.path.isdir(sample_dir):
        # Stray files (e.g. lists, logs) are not samples.
        continue

    ref_fid = torch.load(os.path.join(sample_dir, 'fi_0_fid.pth'), map_location='cpu')
    # Convert once per sample (not per frame). as_tensor accepts the numpy
    # array the original code assumed, and also an already-saved tensor;
    # .float() keeps both operands of cosine_similarity in the same dtype.
    ref_emb = torch.as_tensor(ref_fid).float()

    video_path = os.path.join(generated_dir, sample + '.mp4')
    video_capture = cv2.VideoCapture(video_path)
    sims = []
    try:
        # Stream frames instead of buffering the whole video in memory.
        while video_capture.isOpened():
            ret, frame = video_capture.read()
            if not ret:
                break
            face_info = app.get(frame)
            if len(face_info) != 1:
                sims.append(0.)
                continue
            fid = torch.as_tensor(face_info[0].embedding).float()
            # .item() keeps every entry a plain float (the original mixed
            # Python floats with 0-d tensors).
            sims.append(torch.cosine_similarity(ref_emb, fid, dim=0).item())
    finally:
        video_capture.release()

    if not sims:
        # Missing or unreadable video: previously a ZeroDivisionError.
        print(f'warning: no frames read from {video_path}; skipping {sample}')
        continue

    sample_sim = sum(sims) / len(sims)
    print(counter, sample_sim)
    counter += 1
    dataset_sims.append(sample_sim)

if dataset_sims:
    print(sum(dataset_sims) / len(dataset_sims))
else:
    print('warning: no samples evaluated')
