import cv2
import random
import os
import numpy as np
from pydub import AudioSegment
from moviepy.editor import VideoFileClip

def extract_frames(video_path, mode='train', num_frames=10):
    """Sample evenly spaced frames from a video file.

    Args:
        video_path: Path of the video to open with OpenCV.
        mode: Not used by this function; kept so existing callers that pass
            it keep working.
        num_frames: Number of sample positions spread across the video.

    Returns:
        A list with the frames that could be read (it may hold fewer than
        ``num_frames`` entries, or be empty, when reads fail), or ``None``
        when the video cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)

    if not cap.isOpened():
        print(f"Error: Unable to open video file {video_path}")
        cap.release()
        return None

    extracted_frames = []
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        if total_frames < 10:
            print(f"\n\n LOW frame count in {video_path} - {total_frames}")

        # Sampling stops two frames before the reported count, as before
        # (presumably because the trailing frames are often unreadable --
        # TODO confirm). Clamp at 0 so that videos reporting fewer than two
        # frames do not produce negative seek positions.
        last_index = max(total_frames - 2, 0)
        frame_indices = np.linspace(0, last_index, num=num_frames, dtype=int)

        # Short videos yield repeated indices; they are kept on purpose so
        # the number of sampled positions stays equal to num_frames.
        for frame_index in frame_indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(frame_index))
            ret, frame = cap.read()
            if ret:
                extracted_frames.append(frame)
            else:
                print(f"Error: Failed to read frame at index {frame_index}")
    finally:
        # Always free the capture handle, even if seeking/reading raises.
        cap.release()

    return extracted_frames

def extract_audio_from_video(video_path, output_audio_folder, filename):
    """Write the audio track of a video to ``<output_audio_folder>/<filename>.wav``.

    This is best-effort: any failure is reported on stdout and swallowed so
    that the caller can continue with the next video.

    Args:
        video_path: Path of the video to load with MoviePy.
        output_audio_folder: Directory that receives the WAV file.
        filename: Base name (without extension) of the WAV file.
    """
    video_clip = None
    try:
        video_clip = VideoFileClip(video_path)

        audio = video_clip.audio
        if audio is None:
            # Without this check the failure would only show up as an
            # AttributeError on None, which hides the real cause.
            print(f"Error extracting audio: no audio track in {video_path}")
            return

        audio_output_path = os.path.join(output_audio_folder, f"{filename}.wav")
        audio.write_audiofile(audio_output_path)

    except Exception as e:
        print(f"Error extracting audio: {e}")
    finally:
        # Release the clip's reader resources; otherwise they stay open for
        # every video processed in a long run.
        if video_clip is not None:
            video_clip.close()

def process_video_from_txt(txt_line, video_directory, output_video_folder, output_audio_folder, mode='train'):
    """Extract frames and audio for the video named on one list-file line.

    The line is comma-separated. The first field is the video file name
    without the ``.mp4`` extension; at least three fields are required
    (the third one was read by the previous implementation but never used).

    Args:
        txt_line: One line of the list file.
        video_directory: Root directory; the video is looked up at
            ``<video_directory>/<mode>/<filename>.mp4``.
        output_video_folder: Directory under which a ``<filename>/`` folder
            with ``frame_<n>.jpg`` images is created.
        output_audio_folder: Directory that receives ``<filename>.wav``.
        mode: Sub-directory of ``video_directory`` holding the video.

    Raises:
        ValueError: If the line has fewer than three comma-separated fields.
        OSError: If the video file cannot be opened.
    """
    parts = txt_line.strip().split(',')
    if len(parts) < 3:
        raise ValueError(
            f"Malformed line (expected at least 3 comma-separated fields): {txt_line!r}"
        )
    filename = parts[0]

    # Construct the video file path
    video_path = os.path.join(video_directory, mode, f"{filename}.mp4")

    # extract_frames returns None when the video cannot be opened; fail with
    # a clear error before creating any output directory.
    frames = extract_frames(video_path, mode=mode)
    if frames is None:
        raise OSError(f"Unable to open video file {video_path}")

    # Save the frames as jpg in a per-video folder
    frame_dir = os.path.join(output_video_folder, filename)
    os.makedirs(frame_dir, exist_ok=True)

    for idx, frame in enumerate(frames, start=1):
        frame_filename = os.path.join(frame_dir, f"frame_{idx}.jpg")
        # cv2.imwrite signals failure through its return value, not an
        # exception, so report it instead of dropping the frame silently.
        if not cv2.imwrite(frame_filename, frame):
            print(f"Error: Failed to write frame {frame_filename}")

    # Extract audio from the video and save to output_audio_folder
    extract_audio_from_video(video_path, output_audio_folder, filename)

def process_dataset(txt_file_path, video_directory, output_base_folder, mode):
    """Process every video listed in a text file for one dataset split.

    Creates ``<output_base_folder>/<mode>/video`` and
    ``<output_base_folder>/<mode>/audio``, then handles each non-blank line
    of the list file with ``process_video_from_txt``. A line that raises is
    reported on stdout and appended (unmodified) to the module-level
    ``failed`` list; processing then continues with the next line.

    Args:
        txt_file_path: Path of the list file, one video entry per line.
        video_directory: Root directory containing the source videos.
        output_base_folder: Root directory for the generated output.
        mode: Name of the split; used as a sub-directory name.
    """
    # Create output folders for video and audio
    output_video_folder = os.path.join(output_base_folder, mode, 'video')
    output_audio_folder = os.path.join(output_base_folder, mode, 'audio')
    os.makedirs(output_video_folder, exist_ok=True)
    os.makedirs(output_audio_folder, exist_ok=True)

    with open(txt_file_path, 'r') as txt_file:
        # Blank lines carry no entry; skipping them avoids spurious failures.
        entries = (raw_line for raw_line in txt_file if raw_line.strip())
        for counter, line in enumerate(entries, start=1):
            # Strip only for display so the trailing newline does not split
            # the log message across two lines.
            shown = line.strip()
            print(f"{counter} - Processing: {shown}")
            try:
                process_video_from_txt(line, video_directory, output_video_folder, output_audio_folder, mode)
            except Exception as e:
                failed.append(line)
                print(f"Failed to process line {shown} due to: {e}")

# Example usage: list files for each split and the root directory of the videos
train_txt_file = "./train_ks.txt"
test_txt_file = "./test_ks.txt"
video_directory = "/PATH/TO/VIDEO/DIRECTORY"

# Root of the generated output; process_dataset creates <mode>/video and
# <mode>/audio beneath it
output_base_folder = "/PATH/TO/TARGET/DIRECTORY"

# Lines that could not be processed; process_dataset appends to this list
failed = []

# Run the train split first, then the test split
for split_txt, split_mode in ((train_txt_file, 'train'), (test_txt_file, 'test')):
    process_dataset(split_txt, video_directory, output_base_folder, mode=split_mode)

# Report every line that failed
for f in failed:
    print(f)