import argparse
import math
import os
import shlex
import subprocess
from pathlib import Path

import numpy as np


def get_args():
    """Parse command-line arguments for the video-grid concatenation script.

    Returns:
        argparse.Namespace with ``video_dir`` (str), ``output_video`` (str)
        and ``num_rows`` (int or None).

    Exits (via ``parser.error``) when ``--video_dir`` is missing or
    ``--num_rows`` is not a positive integer.
    """
    parser = argparse.ArgumentParser(description="Concatenate videos.")
    # Required: without it the directory search would receive None and fail
    # with an error that does not mention the missing flag.
    parser.add_argument(
        "--video_dir", type=str, required=True,
        help="Directory containing videos to concatenate",
    )
    parser.add_argument(
        "--output_video", type=str,
        default="/home/skorokho/coding/voi_gs/tmp/videos/concated_video.mp4", help="Output video file path"
    )
    parser.add_argument(
        "--num_rows", type=int, default=None, help="Number of rows in the grid layout"
    )
    args = parser.parse_args()
    # A zero/negative row count would otherwise surface later as a
    # ZeroDivisionError or an invalid grid.
    if args.num_rows is not None and args.num_rows < 1:
        parser.error("--num_rows must be a positive integer")
    return args

def get_grid_layout(n):
    """Return a near-square ``(rows, cols)`` grid that can hold ``n`` tiles.

    ``cols`` is ``ceil(sqrt(n))``. ``rows`` is the fewest rows needed to fit
    ``n`` tiles in that many columns, so ``rows <= cols`` always holds.

    Args:
        n: Number of tiles (videos) to place.

    Returns:
        Tuple ``(rows, cols)``.

    Raises:
        ValueError: if ``n`` is less than 1, because there is no grid for
            zero videos.
    """
    if n < 1:
        raise ValueError(f"need at least one tile to build a grid, got n={n}")
    cols = math.ceil(math.sqrt(n))
    rows = -(-n // cols)  # exact integer ceil(n / cols), no float rounding
    return rows, cols

def generate_ffmpeg_grid_command(input_videos, output_path="output.mp4", num_rows=None):
    """Build an ffmpeg shell command that tiles videos into one grid video.

    Videos are placed row-major with the ``xstack`` filter. A single video is
    forwarded through the ``null`` filter instead, because ``xstack`` requires
    at least two inputs.

    Args:
        input_videos: Iterable of input video paths (str or path-like).
        output_path: Path of the video ffmpeg writes.
        num_rows: Number of grid rows. When ``None``, the near-square layout
            from ``get_grid_layout`` is used. Otherwise the column count is
            ``ceil(n / num_rows)``.

    Returns:
        The ffmpeg invocation as one shell-command string. Every argument is
        escaped with ``shlex.quote``, so file names containing spaces, quotes
        or ``$`` are passed through literally.

    Raises:
        ValueError: if ``input_videos`` is empty or ``num_rows`` is < 1.

    Note:
        Tile offsets are expressed in terms of the first input's size
        (``w0``/``h0``), so all inputs are assumed to share one resolution.
        Differently sized inputs will overlap or leave gaps.
    """
    videos = [str(video) for video in input_videos]
    n = len(videos)
    if n == 0:
        raise ValueError("at least one input video is required")
    if num_rows is not None and num_rows < 1:
        raise ValueError(f"num_rows must be >= 1, got {num_rows}")

    if n == 1:
        # xstack rejects inputs=1, so pass the lone video stream through unchanged.
        filter_complex = "[0:v]null[v]"
    else:
        if num_rows is not None:
            cols = math.ceil(n / num_rows)
        else:
            _, cols = get_grid_layout(n)

        def offset(unit, count):
            # Offset of the count-th tile along one axis, e.g. "w0+w0" for count=2.
            return "+".join([unit] * count) if count else "0"

        cells = []
        for idx in range(n):
            row, col = divmod(idx, cols)
            cells.append(f"{offset('w0', col)}_{offset('h0', row)}")
        # Explicit labels bind tile i to the video stream of input file i.
        labels = "".join(f"[{i}:v]" for i in range(n))
        filter_complex = f"{labels}xstack=inputs={n}:layout={'|'.join(cells)}[v]"

    argv = ["ffmpeg"]
    for video in videos:
        argv += ["-i", video]
    argv += ["-filter_complex", filter_complex, "-map", "[v]", str(output_path)]
    return " ".join(shlex.quote(arg) for arg in argv)

def find_videos(directory, extensions=(".mp4", ".avi", ".mov", ".mkv")):
    """Recursively collect video files under ``directory``.

    Args:
        directory: Root directory to search, walked with ``os.walk``. A
            missing directory yields an empty result rather than an error.
        extensions: File suffixes treated as videos. Matching is
            case-insensitive, so ``clip.MP4`` and ``clip.mp4`` are both found.

    Returns:
        List of matching file paths (``os.path.join`` of each walked root
        and file name), sorted in plain string order.
    """
    suffixes = tuple(ext.lower() for ext in extensions)
    video_files = [
        os.path.join(root, name)
        for root, _, files in os.walk(directory)
        for name in files
        if name.lower().endswith(suffixes)
    ]
    return sorted(video_files)


def main():
    """Find videos under ``--video_dir`` and tile them into one grid video with ffmpeg.

    Exits with a message if no videos are found. If ffmpeg fails, exits with
    ffmpeg's non-zero return code.
    """
    args = get_args()
    video_files = find_videos(args.video_dir)
    if not video_files:
        # Fail clearly here instead of deep inside the layout/command code.
        raise SystemExit(f"No videos found in {args.video_dir}")
    for video_file in video_files:
        print(f"Found video: {video_file}")
    cmd = generate_ffmpeg_grid_command(video_files, args.output_video, args.num_rows)
    # Run ffmpeg from an argv list instead of through a shell (os.system).
    # Characters such as `$` or backticks in file names are then never
    # interpreted, and the exit status is not silently discarded.
    result = subprocess.run(shlex.split(cmd))
    if result.returncode != 0:
        raise SystemExit(result.returncode)

if __name__ == "__main__":
    # Run the CLI only when this file is executed directly, not on import.
    main()