import os

import cv2
from insightface.app import FaceAnalysis
import torch
import numpy as np

# Shared InsightFace detector, constructed once at import time and used for
# every frame below. It is created with the CUDA execution provider and
# provider option device_id=0; `root` presumably points at the local model
# directory — confirm against the insightface FaceAnalysis docs.
app = FaceAnalysis(
    root='/root/models/insightface_models/',
    providers=['CUDAExecutionProvider'],
    provider_options=[{'device_id': 0}],
)
# NOTE(review): ctx_id=0 presumably selects the same GPU as device_id above,
# and det_size the detector input resolution — verify in insightface.
app.prepare(det_size=(512, 512), ctx_id=0)

# Directory pair to evaluate.
#   ground_truth_dir: one sub-directory per sample, each holding `kps.pth`.
#   generated_dir:    one `<sample>.mp4` per ground-truth sample.
# The defaults below are the AVSpeech test set vs. Moore-AnimateAnyone
# outputs. Set KPS_EVAL_GT_DIR / KPS_EVAL_GEN_DIR in the environment to
# evaluate another pair without editing this file. Other pairs used with
# this script:
#   TalkingHead-1KH ground truth: /root/programs/TalkingHead-1KH/selected_test_samples/
#     Wav2Lip:             /root/programs/Wav2Lip/output/talkinghead_1kh/
#     Moore-AnimateAnyone: /root/programs/Moore-AnimateAnyone/output/talkinghead_1kh/184000/
#   AVSpeech generated, Wav2Lip: /root/programs/Wav2Lip/output/avspeech/
ground_truth_dir = os.environ.get(
    'KPS_EVAL_GT_DIR', '/root/datasets/AVSpeech/test/processed/'
)
generated_dir = os.environ.get(
    'KPS_EVAL_GEN_DIR', '/root/programs/Moore-AnimateAnyone/output/avspeech/184000/'
)

def _detect_kps(video_path, num_kps=3):
    """Detect face keypoints in every frame of a video.

    Frames are streamed from ``video_path`` with OpenCV and passed one at a
    time to the module-level ``app`` detector, so the whole video is never
    held in memory. A frame contributes ``face_info[0].kps[:num_kps]`` only
    when the detector returns exactly one face; all other frames are dropped.

    Returns a float32 tensor stacking the kept per-frame keypoint arrays
    along a new leading (frame) axis, or ``None`` when no frame was kept
    (this includes a video that could not be opened).
    """
    capture = cv2.VideoCapture(video_path)
    kept = []
    try:
        while capture.isOpened():
            ok, frame = capture.read()
            if not ok:
                break
            face_info = app.get(frame)
            if len(face_info) != 1:
                continue
            kept.append(np.asarray(face_info[0].kps[:num_kps], dtype=np.float32))
    finally:
        # Release the decoder handle even if detection raises.
        capture.release()
    if not kept:
        return None
    # np.stack first: building a tensor from a list of ndarrays is very slow.
    return torch.from_numpy(np.stack(kept))


def _resample_time(kps, length):
    """Linearly resample a (frames, K, D) tensor to ``length`` frames.

    ``interpolate`` with mode='linear' expects (N, C, L) input, so the frame
    axis is moved last, resampled, and moved back to the front.
    """
    return torch.nn.functional.interpolate(
        kps.permute(1, 2, 0),
        size=length,
        mode='linear',
    ).permute(2, 0, 1)


# Per-sample scores. Despite the "sims" naming, each value is a mean
# Euclidean keypoint distance, so lower is better.
dataset_sims = []
skipped_samples = []
counter = 0
for sample in sorted(os.listdir(ground_truth_dir)):
    sample_dir = os.path.join(ground_truth_dir, sample)
    if not os.path.isdir(sample_dir):
        continue

    # NOTE(review): assumes kps.pth holds per-frame keypoints shaped
    # (frames, K, D) in the same layout as the detector's kps[:3] — confirm.
    tgt_kps = torch.load(os.path.join(sample_dir, 'kps.pth'), map_location='cpu')
    # Cast to float: linear interpolation rejects integer tensors.
    tgt_kps = torch.as_tensor(tgt_kps, dtype=torch.float32)

    # To sanity-check the metric, evaluate the ground-truth video instead:
    # os.path.join(sample_dir, 'processed_vid.mp4')
    gen_kps = _detect_kps(os.path.join(generated_dir, sample + '.mp4'))
    if gen_kps is None:
        print(f'skipping {sample}: no frame with exactly one detected face')
        skipped_samples.append(sample)
        continue
    if tgt_kps.shape[1:] != gen_kps.shape[1:]:
        raise ValueError(
            f'{sample}: keypoint layouts differ (ground truth '
            f'{tuple(tgt_kps.shape)}, generated {tuple(gen_kps.shape)})'
        )
    if tgt_kps.shape[0] == 0:
        print(f'skipping {sample}: empty ground-truth keypoints')
        skipped_samples.append(sample)
        continue

    # Stretch the shorter sequence to the longer one's frame count.
    # NOTE(review): frames without exactly one face were dropped, so this
    # assumes drops are spread uniformly in time; clustered drops misalign
    # frames. Tracking kept frame indices would avoid that.
    num_frames = max(tgt_kps.shape[0], gen_kps.shape[0])
    if tgt_kps.shape[0] < num_frames:
        tgt_kps = _resample_time(tgt_kps, num_frames)
    elif gen_kps.shape[0] < num_frames:
        gen_kps = _resample_time(gen_kps, num_frames)

    # Euclidean distance per keypoint over the last axis, averaged over all
    # frames and keypoints.
    dist = (gen_kps - tgt_kps).pow(2).sum(dim=-1).sqrt().mean().item()
    print(counter, generated_dir, dist)
    counter += 1
    dataset_sims.append(dist)

if skipped_samples:
    print(f'skipped {len(skipped_samples)} sample(s):', skipped_samples)
if dataset_sims:
    print(generated_dir, sum(dataset_sims) / len(dataset_sims))
else:
    print(generated_dir, 'no samples evaluated')
