# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import tqdm
from statistics import fmean
from eval.syncnet import SyncNetEval
from eval.syncnet_detect import SyncNetDetector
from latentsync.utils.util import red_text
import torch


def syncnet_eval(syncnet, syncnet_detector, video_path, temp_dir, detect_results_dir="detect_results"):
    syncnet_detector(video_path=video_path, min_track=50)
    crop_videos = os.listdir(os.path.join(detect_results_dir, "crop"))
    if crop_videos == []:
        raise Exception(red_text(f"Face not detected in {video_path}"))
    av_offset_list = []
    conf_list = []
    for video in crop_videos:
        av_offset, _, conf = syncnet.evaluate(
            video_path=os.path.join(detect_results_dir, "crop", video), temp_dir=temp_dir
        )
        av_offset_list.append(av_offset)
        conf_list.append(conf)
    av_offset = int(fmean(av_offset_list))
    conf = fmean(conf_list)
    print(f"Input video: {video_path}\nSyncNet confidence: {conf:.2f}\nAV offset: {av_offset}")
    return av_offset, conf


def main():
    parser = argparse.ArgumentParser(description="SyncNet")
    parser.add_argument("--initial_model", type=str, default="checkpoints/auxiliary/syncnet_v2.model", help="")
    parser.add_argument("--video_path", type=str, default=None, help="")
    parser.add_argument("--videos_dir", type=str, default="/root/processed")
    parser.add_argument("--temp_dir", type=str, default="temp", help="")

    args = parser.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"

    syncnet = SyncNetEval(device=device)
    syncnet.loadParameters(args.initial_model)

    syncnet_detector = SyncNetDetector(device=device, detect_results_dir="detect_results")

    if args.video_path is not None:
        syncnet_eval(syncnet, syncnet_detector, args.video_path, args.temp_dir)
    else:
        sync_conf_list = []
        video_names = sorted([f for f in os.listdir(args.videos_dir) if f.endswith(".mp4")])
        for video_name in tqdm.tqdm(video_names):
            try:
                _, conf = syncnet_eval(
                    syncnet, syncnet_detector, os.path.join(args.videos_dir, video_name), args.temp_dir
                )
                sync_conf_list.append(conf)
            except Exception as e:
                print(e)
        print(f"The average sync confidence is {fmean(sync_conf_list):.02f}")


if __name__ == "__main__":
    main()
