import argparse
from concurrent import futures
from pathlib import Path

import cv2
import numpy as np
import pandas as pd
from PIL import Image
from tqdm.auto import tqdm

# Newer Pillow releases expose the resampling filters on the
# ``Image.Resampling`` enum; older releases only provide the module-level
# ``Image.BICUBIC`` constant. Use the enum when present, else fall back.
resampling_method = getattr(Image, "Resampling", Image).BICUBIC


def sample_frames(frames_videos, vlen):
    """Pick evenly spaced frame indices from a video.

    The index range ``[0, vlen)`` is split into ``min(frames_videos, vlen)``
    contiguous segments of (nearly) equal length, and the middle index of
    each segment is returned.

    Args:
        frames_videos: Number of frames requested.
        vlen: Total number of frames in the video. OpenCV may report 0 or a
            negative count for streams whose length is unknown; such values
            yield an empty list instead of raising.

    Returns:
        A list of plain ``int`` frame indices in increasing order. The list is
        empty when either argument is <= 0.
    """
    n_samples = max(0, min(frames_videos, vlen))
    if n_samples == 0:
        return []

    # Segment boundaries: segment i covers [bounds[i], bounds[i + 1] - 1].
    bounds = np.linspace(start=0, stop=vlen, num=n_samples + 1).astype(int)
    # Midpoint of each segment, converted to Python ints for callers/cv2.
    return [int((lo + hi - 1) // 2) for lo, hi in zip(bounds[:-1], bounds[1:])]


def concat_h_imgs(im_list, resample=resampling_method):
    """Place images side by side after scaling them to a shared height.

    Every image is resized (keeping its aspect ratio, width truncated to an
    int) to the height of the shortest image in ``im_list``. The scaled images
    are then pasted left to right onto a new RGB canvas.

    Args:
        im_list: Non-empty sequence of PIL images.
        resample: Pillow resampling filter passed to ``Image.resize``.

    Returns:
        A new RGB PIL image whose width is the sum of the scaled widths.
    """
    target_h = min(img.height for img in im_list)

    scaled = []
    for img in im_list:
        new_w = int(img.width * target_h / img.height)
        scaled.append(img.resize((new_w, target_h), resample=resample))

    canvas = Image.new("RGB", (sum(part.width for part in scaled), target_h))
    offset = 0
    for part in scaled:
        canvas.paste(part, (offset, 0))
        offset += part.width
    return canvas


def visualize_path_video(path, n_frames=10):
    """Grab evenly spaced frames from a video and join them into one image.

    Args:
        path: Video file path; converted with ``str()`` before being handed to
            ``cv2.VideoCapture``.
        n_frames: Maximum number of frames to sample. Fewer are used when the
            video is shorter or when some frames fail to read.

    Returns:
        ``None`` if the video cannot be opened or no frame could be read.
        When ``n_frames == 1``, the single frame as a PIL image. Otherwise the
        frames concatenated left to right (scaled to a common height) as a
        PIL RGB image.
    """
    video = cv2.VideoCapture(str(path))
    try:
        if not video.isOpened():
            return None

        total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        frames = []
        for frame_idx in sample_frames(n_frames, total_frames):
            video.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
            ok, frame = video.read()
            if not ok:
                # Skip frames that fail to decode instead of aborting the video.
                continue
            # OpenCV decodes to BGR; PIL expects RGB.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(np.uint8(frame)))
    finally:
        # Always free the native capture handle, even if decoding raises.
        video.release()

    if not frames:
        return None
    if n_frames == 1:
        return frames[0]
    return concat_h_imgs(frames)


def extract_frames(video_path, output_dir):
    """Save a 3-frame preview strip of one video as ``<output_dir>/<stem>.png``.

    The file name is the video's stem (file name without extension). When no
    frame can be decoded, a message is printed and nothing is written.
    """
    stem = Path(video_path).stem
    strip = visualize_path_video(video_path, 3)
    if strip is not None:
        strip.save(str(Path(output_dir) / f"{stem}.png"))
        return
    print(f"Failed to extract frames from {video_path}")


def read_video_list(video_list_file):
    """
    Read a plain-text video list and return its entries.

    Each non-blank line is one video path. Surrounding whitespace is stripped,
    and blank lines (e.g. a trailing empty line) are skipped so they do not
    turn into empty path entries.

    Args:
        video_list_file: Path to the UTF-8 text file.

    Returns:
        List of video path strings, in file order.
    """
    with open(video_list_file, "r", encoding="utf-8") as f:
        stripped = (line.strip() for line in f)
        return [entry for entry in stripped if entry]


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Extract middle frames from videos.")
    parser.add_argument("root_dir", help="path to the root directory")
    parser.add_argument(
        "subdirectory", type=int, help="subdirectory to extract videos from"
    )
    parser.add_argument("output_dir", help="path to the output directory")
    parser.add_argument("video_list_file", help="path to the video list file")
    parser.add_argument(
        "--num-workers",
        type=int,
        default=4,
        help="number of worker threads used for extraction (default: 4)",
    )
    args = parser.parse_args()

    # A CSV list holds video paths in two columns ("pth1", "pth2"): merge and
    # de-duplicate them, dropping empty cells (which pandas reads as NaN and
    # which Path() would reject). Any other file is a one-path-per-line list.
    if args.video_list_file.endswith(".csv"):
        df = pd.read_csv(args.video_list_file)
        video_paths = list(
            set(df["pth1"].dropna().tolist()) | set(df["pth2"].dropna().tolist())
        )
    else:
        video_paths = read_video_list(args.video_list_file)

    # Keep only the entries whose parent directory name is the requested shard.
    shard_name = str(args.subdirectory)
    shard_video_paths = [
        pth for pth in video_paths if Path(pth).parent.name == shard_name
    ]
    output_dir = Path(args.output_dir) / shard_name
    # parents=True so a not-yet-created output root does not make mkdir fail.
    output_dir.mkdir(parents=True, exist_ok=True)

    # List entries are expected to lack the extension; ".mp4" is appended here.
    root_dir = Path(args.root_dir)
    candidates = [root_dir / f"{pth}.mp4" for pth in shard_video_paths]
    existing = [path for path in candidates if path.exists()]
    n_missing = len(candidates) - len(existing)
    if n_missing:
        print(f"Skipping {n_missing} video(s) not found under {root_dir}")

    with futures.ThreadPoolExecutor(max_workers=args.num_workers) as executor:
        jobs = {
            executor.submit(extract_frames, path, output_dir): path
            for path in existing
        }
        # Advance the progress bar as each job finishes, and report worker
        # exceptions, which would otherwise be silently discarded.
        for job in tqdm(futures.as_completed(jobs), total=len(jobs)):
            error = job.exception()
            if error is not None:
                print(f"Error while processing {jobs[job]}: {error}")
