import numpy as np
import torch
import cv2
import re
import os
import shutil
import flow_vis
from PIL import Image
from pathlib import Path
from typing import Union


# File suffixes treated as images. Entries are compared against
# ``Path.suffix.lower()``, which always includes the leading dot, so every
# entry here must carry one as well.
img_extensions = ['.png', '.jpg', '.jpeg', '.heif', '.heic']



def read_video(video_path):
    """Decode every frame of a video into RGB images.

    Args:
        video_path: Source handed to ``cv2.VideoCapture``. Path-like objects
            (e.g. ``pathlib.Path``) are converted to ``str`` first; any other
            value is passed through unchanged.

    Returns:
        A list with one RGB frame per successful read, in the order the
        frames were read. The list is empty when the capture could not be
        opened or yielded no frames.
    """
    if isinstance(video_path, os.PathLike):
        video_path = os.fspath(video_path)

    vidcap = cv2.VideoCapture(video_path)
    frames = []
    try:
        while vidcap.isOpened():
            ok, frame = vidcap.read()
            if not ok:
                break
            # OpenCV delivers BGR; convert so callers get RGB
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Release the capture even if a frame conversion raises.
        vidcap.release()
    return frames


def save_pcd(points, save_path):
    """Write ``points`` to ``save_path`` as an Open3D point cloud.

    Args:
        points: Point coordinates, passed to ``o3d.utility.Vector3dVector``.
        save_path: Destination file handed to ``o3d.io.write_point_cloud``.
    """
    # Imported inside the function so this module can be loaded without open3d.
    import open3d as o3d

    xyz = o3d.utility.Vector3dVector(points)
    cloud = o3d.geometry.PointCloud()
    cloud.points = xyz
    o3d.io.write_point_cloud(save_path, cloud)


def read_image(img_path):
    """Load an image file and return it as an RGB array.

    Args:
        img_path: Image location, passed to ``PIL.Image.open``.

    Returns:
        ``np.ndarray`` built from the image after conversion to ``"RGB"``.
    """
    # The context manager closes the underlying file deterministically;
    # ``convert`` yields an independent image, so the array is built from
    # fully loaded pixel data before the file is closed.
    with Image.open(img_path) as img:
        return np.array(img.convert("RGB"))

def read_images(img_paths):
    """Load each path in ``img_paths`` with ``read_image``, keeping their order.

    Returns:
        A list holding one loaded image per entry of ``img_paths``.
    """
    return list(map(read_image, img_paths))

def write_image(image: Union[Image.Image, np.ndarray, torch.Tensor], save_path: Union[str, Path]):
    """Save an image given as a PIL image, NumPy array, or torch tensor.

    Tensors are converted to NumPy arrays, and arrays are converted with
    ``Image.fromarray``; the resulting PIL image is saved to ``save_path``.

    Args:
        image: The image to store. Arrays and tensors must already have a
            layout and dtype that ``Image.fromarray`` accepts; no scaling or
            axis reordering is done here.
        save_path: Destination file; the format is chosen by PIL from the
            file name.
    """
    if isinstance(image, torch.Tensor):
        # A bare ``.numpy()`` raises for tensors that require grad or that do
        # not live on the CPU, so detach and move to host memory first.
        image = image.detach().cpu().numpy()
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    image.save(str(save_path))


def write_video(frames, save_dir, save_fname=None, frame_lists=None):
    """Dump frames as PNG files and encode them into an MP4 with ffmpeg.

    Frames are written to ``<save_dir>/frames`` and the video goes to
    ``<save_dir>/<save_fname>.mp4``. ffmpeg picks the frames up through the
    glob ``<save_dir>/frames/*.png`` at 10 frames per second, so any PNG
    already present in that directory is part of the video as well.

    Args:
        frames: Sequence of images accepted by ``write_image``.
        save_dir: Output directory (``str`` or ``Path``); created if missing.
        save_fname: Video file name without the ``.mp4`` extension.
            NOTE(review): the default ``None`` is formatted verbatim and
            therefore produces ``None.mp4`` -- confirm whether callers rely
            on that before changing it.
        frame_lists: Optional sequence of names (``str`` or ``Path``), one per
            frame. Each frame is stored as ``<stem of name>.png``. When
            omitted, frames are numbered ``000000.png``, ``000001.png``, ...

    Raises:
        ValueError: If ``frame_lists`` is given and its length differs from
            that of ``frames``.

    Warns:
        RuntimeWarning: If ffmpeg cannot be started or exits with a non-zero
            status. Encoding stays best-effort and does not raise.
    """
    import subprocess
    import warnings

    # Accept plain strings too: the ``/`` joins below need a Path.
    save_dir = Path(save_dir)
    save_frame_dir = save_dir / "frames"
    save_frame_dir.mkdir(parents=True, exist_ok=True)

    if frame_lists is not None:
        if len(frames) != len(frame_lists):
            raise ValueError(
                f"got {len(frames)} frames but {len(frame_lists)} frame names"
            )
        for frame, frame_fname in zip(frames, frame_lists):
            write_image(frame, save_frame_dir / f"{Path(frame_fname).stem}.png")
    else:
        for i, frame in enumerate(frames):
            write_image(frame, save_frame_dir / f"{i:06d}.png")

    video_path = save_dir / f"{save_fname}.mp4"

    # Argument list instead of a shell string: paths containing spaces or
    # shell metacharacters are passed through untouched. The glob is expanded
    # by ffmpeg itself (-pattern_type glob), not by a shell.
    command = [
        "/usr/bin/ffmpeg", "-y", "-hide_banner", "-loglevel", "error",
        "-f", "image2", "-framerate", "10",
        "-pattern_type", "glob", "-i", f"{save_frame_dir}/*.png",
        "-c:v", "libx264",
        # round width and height up to even values, as yuv420p requires
        "-vf", "pad=ceil(iw/2)*2:ceil(ih/2)*2",
        "-crf", "20", "-pix_fmt", "yuv420p",
        str(video_path),
    ]
    try:
        result = subprocess.run(command, check=False)
    except OSError as err:
        warnings.warn(f"could not launch ffmpeg: {err}", RuntimeWarning, stacklevel=2)
        return
    if result.returncode != 0:
        warnings.warn(
            f"ffmpeg exited with status {result.returncode}; "
            f"{video_path} may be missing or incomplete",
            RuntimeWarning,
            stacklevel=2,
        )

    # NOTE: the frames directory is left on disk after encoding.
    
    
def read_image_paths(folder: Union[str, Path]):
    """List the image files directly inside ``folder`` in frame-number order.

    A file counts as an image when its lower-cased suffix is listed in
    ``img_extensions``. Files are ordered by the first run of digits in their
    stem, compared as an integer (so ``2.png`` precedes ``10.png``), with the
    file name as tie-breaker. Files whose stem contains no digits are placed
    after all numbered files, ordered by name.

    Args:
        folder: Directory to scan (not recursive).

    Returns:
        A list of ``Path`` objects.
    """
    folder = Path(folder)

    def _frame_order(path):
        # Sort key: numbered files first (by number), then the rest by name.
        match = re.search(r'\d+', path.stem)
        if match is None:
            return (1, 0, path.name)
        return (0, int(match.group()), path.name)

    candidates = (p for p in folder.iterdir() if p.suffix.lower() in img_extensions)
    return sorted(candidates, key=_frame_order)


def load_frames(img_dir: Union[str, Path, list]):
    """Load images from a directory or from an explicit list of paths.

    Args:
        img_dir: Either a directory (``str`` or ``Path``), whose image files
            are discovered with ``read_image_paths``, or an iterable of image
            paths that is used in the given order.

    Returns:
        A tuple ``(images, paths)``: the images loaded with ``read_image``
        and the matching list of ``Path`` objects.
    """
    if isinstance(img_dir, (str, Path)):
        frames = read_image_paths(img_dir)
    else:
        # Explicit list of files: normalise to Path so both branches return
        # the same element type.
        frames = [Path(p) for p in img_dir]
    return [read_image(frame) for frame in frames], frames


def convert_to_flow_image(flow):
    """Render ``flow`` as a colour image via ``flow_vis.flow_to_color``.

    Args:
        flow: Flow field forwarded unchanged to ``flow_vis.flow_to_color``.

    Returns:
        Whatever ``flow_vis.flow_to_color`` returns for ``flow``.
    """
    colored = flow_vis.flow_to_color(flow)
    return colored