import os
import re
import sys
import tempfile

import cv2
import numpy as np
from moviepy.editor import VideoFileClip
from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip

def yuv_to_rgb(yuv_file, width, height, fps):
    """Transcode a headerless 8-bit YUV 4:2:0 (I420) file into a temporary MP4.

    Each frame in ``yuv_file`` is read as a full-resolution Y plane followed by
    a (height/2, width/2) U plane and a (height/2, width/2) V plane. Frames are
    converted to BGR, the channel order ``cv2.VideoWriter`` expects, and encoded
    with OpenCV's ``mp4v`` codec.

    Args:
        yuv_file: Path to the raw ``.yuv`` file.
        width: Frame width in pixels; must be a positive even number.
        height: Frame height in pixels; must be a positive even number.
        fps: Frame rate written into the output container.

    Returns:
        Path to a newly created temporary MP4 file. The caller owns it and is
        responsible for deleting it.

    Raises:
        ValueError: If ``width`` or ``height`` is not a positive even number.
        OSError: If OpenCV cannot open a video writer for the temporary file.
    """
    if width <= 0 or height <= 0 or width % 2 or height % 2:
        raise ValueError(
            f"YUV 4:2:0 frames need positive even dimensions, got {width}x{height}"
        )

    frame_size = width * height * 3 // 2
    # A trailing partial frame (file size not a multiple of frame_size) is ignored.
    total_frames = os.path.getsize(yuv_file) // frame_size

    # A unique temp path avoids clobbering a shared file in the current working
    # directory when several conversions run at once.
    fd, temp_video_file = tempfile.mkstemp(suffix='.mp4')
    os.close(fd)

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(temp_video_file, fourcc, fps, (width, height))
    if not out.isOpened():
        # Without this check, write() silently drops every frame.
        os.remove(temp_video_file)
        raise OSError(f"cv2.VideoWriter could not open {temp_video_file}")

    # NOTE(review): the mp4v intermediate is lossy, so frames are compressed
    # twice before the final H.264 encode. A lossless intermediate (or feeding
    # the raw YUV straight to ffmpeg) would preserve more quality.
    completed = False
    try:
        with open(yuv_file, 'rb') as f:
            for _ in range(total_frames):
                # The Y, U and V planes are contiguous, so one frame is a single
                # (height * 3/2, width) image: the layout COLOR_YUV2BGR_I420 takes.
                i420 = np.frombuffer(f.read(frame_size), dtype=np.uint8)
                i420 = i420.reshape((height * 3 // 2, width))
                # The I420 conversion upsamples chroma itself and uses BT.601
                # limited-range (video) coefficients. COLOR_YUV2BGR on stacked
                # planes would instead apply full-range analog-YUV formulas,
                # distorting contrast and saturation.
                # NOTE(review): HD/UHD sources are often BT.709; confirm the matrix.
                bgr = cv2.cvtColor(i420, cv2.COLOR_YUV2BGR_I420)
                out.write(bgr)
        completed = True
    finally:
        out.release()
        if not completed:
            # Don't leave a half-written temp file behind on failure.
            os.remove(temp_video_file)

    return temp_video_file

def convert_to_mp4_high_quality(temp_video_file, output_file, *, remove_temp=True):
    """Re-encode ``temp_video_file`` to H.264 at ``output_file`` using moviepy.

    Encoding uses libx264 with the ``slow`` preset and ``-crf 18`` (lower CRF
    values mean higher quality and larger files).

    Args:
        temp_video_file: Path of the intermediate video to read.
        output_file: Destination path for the H.264 MP4.
        remove_temp: When true (the default), delete ``temp_video_file``
            afterwards. Deletion happens even when encoding fails, so large
            intermediates do not pile up.
    """
    try:
        clip = VideoFileClip(temp_video_file)
        try:
            clip.write_videofile(output_file, codec='libx264', preset='slow', ffmpeg_params=['-crf', '18'])
        finally:
            # Release moviepy's ffmpeg reader subprocess and its handle on the
            # input. Otherwise the reader leaks, and os.remove can fail on
            # platforms that lock open files.
            clip.close()
    finally:
        if remove_temp:
            try:
                os.remove(temp_video_file)
            except FileNotFoundError:
                # Nothing to clean up (e.g. the input never existed). Don't mask
                # the original error with a secondary one.
                pass

def process_directory(input_directory):
    """Convert every ``.yuv`` file in ``input_directory`` to an ``.mp4`` next to it.

    Frame geometry is parsed from a ``<width>x<height>`` token in the filename.
    The frame rate comes from a ``<N>fps`` token when the name has one;
    otherwise 1920x1080 defaults to 120 fps and 3840x2160 to 50 fps.

    A file is skipped, with a message, when:

    - its resolution or frame rate cannot be determined, or
    - its name declares a bit depth other than 8 (frames are decoded as one
      byte per sample).

    An error while converting one file is reported and the remaining files are
    still processed.

    Args:
        input_directory: Directory scanned (non-recursively) for ``.yuv`` files.
    """
    # Fallback frame rates for names without an explicit "<N>fps" token.
    default_fps = {(1920, 1080): 120, (3840, 2160): 50}

    for filename in sorted(os.listdir(input_directory)):
        if not filename.endswith('.yuv'):
            continue
        input_file = os.path.join(input_directory, filename)
        # Swap only the final extension; str.replace would also rewrite any
        # '.yuv' that appears earlier in the name.
        output_file = os.path.join(input_directory, filename[:-len('.yuv')] + '.mp4')

        size_match = re.search(r'(?<!\d)(\d+)x(\d+)(?!\d)', filename)
        if size_match is None:
            print(f"Unknown resolution in filename {filename}.")
            continue
        width, height = int(size_match.group(1)), int(size_match.group(2))

        # NOTE(review): assumes names follow a "..._<N>fps_..._<B>bit..." pattern
        # (as UVG sequence names appear to). Names without these tokens fall
        # back to the defaults above.
        fps_match = re.search(r'(\d+)fps', filename, re.IGNORECASE)
        fps = int(fps_match.group(1)) if fps_match else default_fps.get((width, height))
        if not fps:
            print(f"Unknown frame rate in filename {filename}.")
            continue

        depth_match = re.search(r'(\d+)bit', filename, re.IGNORECASE)
        if depth_match and int(depth_match.group(1)) != 8:
            print(f"Unsupported bit depth in filename {filename}; only 8-bit YUV is handled.")
            continue

        try:
            temp_video_file = yuv_to_rgb(input_file, width, height, fps)
            convert_to_mp4_high_quality(temp_video_file, output_file)
        except Exception as exc:  # batch boundary: report and continue with the next file
            print(f"Failed to convert {filename}: {exc}")

if __name__ == "__main__":
    # Pass a directory as the first command-line argument to override the
    # default dataset location.
    input_directory = sys.argv[1] if len(sys.argv) > 1 else '/data22/aho/UVG_dataset/download/'
    process_directory(input_directory)
