# See Appendix C.2 in the paper for more details
import os
import torch
import argparse
from torchvision.io import read_video

# L2-distance cutoff shared by the checks in this file. Frames are scaled to
# [0, 1] and flattened before distances are taken, so the value is expressed
# in those units.
THRESHOLD = 55
def process_video(video_path, theoretical_position, N):
    """Score a single video against a shifted copy of its own beginning.

    The video is decoded, scaled to [0, 1], and processed in two steps:

    1. Static-video filtering: eight frames spread over the clip are compared
       pairwise; if their mean L2 distance is below ``THRESHOLD`` the video is
       treated as static and skipped.
    2. Offset search and comparison: within ``theoretical_position +/- N`` the
       frame closest (L2) to frame 0 is located; its index is ``offset``.
       Every frame ``i`` is then compared with frame ``i + offset``, and the
       pairs whose L2 distance is at least ``THRESHOLD`` are counted.

    Args:
        video_path: Path of the video file passed to ``read_video``.
        theoretical_position: Frame index around which frame 0 is expected to
            reappear (centre of the search window).
        N: Half-width, in frames, of the search window.

    Returns:
        ``(offset + n_different) / T`` as a float, where ``T`` is the number of
        frames and ``n_different`` the number of compared pairs at or above
        ``THRESHOLD``. ``None`` if the video has no frames, is classified as
        static, or the search window contains fewer than two frames.
    """
    video, _, _ = read_video(video_path, output_format="TCHW")
    T = video.shape[0]
    if T == 0:
        # Indexing the sampled frames below would raise on a zero-frame clip.
        print(f"Empty video: {video_path}")
        return None
    video = video.float().div(255)

    # ====== Step 1: static-video filtering ======
    idxs = [0, T//8, T//4, (3*T)//8, T//2, 3*T//4, 7*T//8, T-1]
    sampled = video[idxs].flatten(1)
    dmat = torch.cdist(sampled, sampled, p=2)
    # Strict upper triangle: each unordered pair once, no self-distances.
    iu = torch.triu_indices(len(idxs), len(idxs), 1)
    mean_pairwise = dmat[iu[0], iu[1]].mean().item()

    print(mean_pairwise)
    if mean_pairwise < THRESHOLD:
        print('static video')
        return None

    # ====== Step 2: repeated-frame ratio calculation ======
    # Search around the dominant-frequency period for the frame with the
    # minimal L2 distance to the first frame. The window is taken from the
    # function's own parameters rather than from any module-level state.
    start = max(0, theoretical_position - N)
    end = min(theoretical_position + N, T - 1)
    if start >= end:
        print(f"Invalid initial window for {video_path}")
        return None
    first_frame = video[0].flatten().unsqueeze(0)
    window_frames = video[start:end + 1].flatten(1)
    dists = torch.cdist(first_frame, window_frames, p=2)
    # argmin is relative to the window, so shift it back to a frame index.
    # start <= offset <= end <= T - 1, hence T - offset >= 1 below.
    offset = start + torch.argmin(dists).item()

    # Compare each frame i with frame i + offset; a pair whose L2 distance
    # reaches THRESHOLD is counted as different (i.e. not a repetition).
    n_different = 0
    for i in range(T - offset):
        frame_i = video[i].flatten().unsqueeze(0)
        frame_j = video[i + offset].flatten().unsqueeze(0)
        if torch.cdist(frame_i, frame_j, p=2).item() >= THRESHOLD:
            n_different += 1

    return (offset + n_different) / T

def main(directory, theoretical_position, N):
    """Score every ``.mp4`` under ``directory`` and print the average.

    Walks ``directory`` recursively, calls ``process_video`` on each file
    whose name ends with ``.mp4``, prints each score as a percentage, and
    finally prints how many videos produced a score and their mean.

    Args:
        directory: Root directory to search.
        theoretical_position: Forwarded unchanged to ``process_video``.
        N: Forwarded unchanged to ``process_video``.

    Returns:
        None. All results are written to stdout.
    """
    total = 0
    cnt = 0
    for root, _, files in os.walk(directory):
        for file in files:
            if not file.endswith(".mp4"):
                continue
            video_path = os.path.join(root, file)
            mean = process_video(video_path, theoretical_position, N)
            # None means the video was skipped; it must not affect the mean.
            if mean is None:
                continue
            print(f"{video_path}: {mean*100:.4f}")
            cnt += 1
            total += mean

    print(f"Processed videos: {cnt}")
    if cnt == 0:
        # Nothing to average; dividing by cnt would raise ZeroDivisionError.
        print("Global mean: n/a (no video produced a score)")
        return

    global_mean = total / cnt
    print(f"Global mean: {100*global_mean:.4f}")


if __name__ == "__main__":
    # Command-line entry point: all three options are mandatory.
    parser = argparse.ArgumentParser()
    for flag, value_type in (
        ('--directory', str),
        ('--N', int),
        ('--theoretical_position', int),
    ):
        parser.add_argument(flag, type=value_type, required=True)

    # Keep the parsed namespace bound to the module-level name `args`;
    # other code in this file may look it up by that name.
    args = parser.parse_args()

    main(args.directory, args.theoretical_position, args.N)