from pathlib import Path
from tqdm import tqdm
from typing import Literal, Optional, List
import tyro
import ffmpeg
from PIL import Image
import torch
from vhap.data.image_folder_dataset import ImageFolderDataset
from torch.utils.data import DataLoader


def _parse_frame_rate(rate) -> float:
    """Convert an ffprobe rate string such as '30000/1001' or '25' to a float.

    Returns 0.0 when the value is missing, malformed, or has a zero
    denominator, so callers can treat 0.0 as "unknown".
    """
    if not rate:
        return 0.0
    numerator, _, denominator = str(rate).partition('/')
    try:
        num = float(numerator)
        den = float(denominator) if denominator else 1.0
    except ValueError:
        return 0.0
    return num / den if den else 0.0


def video2frames(video_path: Path, image_dir: Path, keep_video_name: bool=False, target_fps: int=30, n_downsample: int=1):
    """Extract the frames of a video into JPEG files using ffmpeg.

    The video is resampled to ``target_fps`` and scaled down by the integer
    factor ``n_downsample`` before the frames are written to ``image_dir``
    (created if missing) as ``frame_00001.jpg``, ``frame_00002.jpg``, ...

    Args:
        video_path: Video file to read.
        image_dir: Directory that receives the extracted frames.
        keep_video_name: When True, each file name is prefixed with
            ``<video stem>_`` so frames of several videos can share a folder.
        target_fps: Frame rate passed to ffmpeg's ``fps`` filter.
        n_downsample: Integer divisor applied to the width and the height.

    Raises:
        ValueError: If ``n_downsample`` is smaller than 1, the file has no
            video stream, or no valid frame rate can be determined.
    """
    # Validate before touching the filesystem or launching ffprobe.
    if n_downsample < 1:
        raise ValueError(f'n_downsample must be >= 1, got {n_downsample}')

    print(f'Converting video {video_path} to frames with downsample scale {n_downsample}')
    image_dir.mkdir(parents=True, exist_ok=True)
    file_path_stem = video_path.stem + '_' if keep_video_name else ''

    probe = ffmpeg.probe(str(video_path))

    # Read all metadata from the video stream; stream 0 may be audio or data.
    video = next((stream for stream in probe['streams'] if stream.get('codec_type') == 'video'), None)
    if video is None:
        raise ValueError(f'No video stream found in {video_path}')

    # The frame count is only used for the informational printout below.
    try:
        num_frames = int(video['nb_frames'])
    except (KeyError, ValueError):
        num_frames = None

    # Rates are fractions ("30000/1001"), so the denominator must be honoured.
    video_fps = _parse_frame_rate(video.get('r_frame_rate'))
    if video_fps == 0:
        video_fps = _parse_frame_rate(video.get('avg_frame_rate'))
    if video_fps == 0 and num_frames is not None:
        # Last resort: nb_frames / duration
        try:
            duration = float(video['duration'])
        except (KeyError, ValueError):
            duration = 0.0
        if duration > 0:
            video_fps = num_frames / duration
    if video_fps <= 0:
        raise ValueError('Cannot get valid video fps')

    W = int(video['width'])
    H = int(video['height'])
    w = W // n_downsample
    h = H // n_downsample
    num_frames_text = 'unknown' if num_frames is None else num_frames
    target_frames_text = 'unknown' if num_frames is None else round(num_frames * target_fps / video_fps)
    print(f'[Video]  FPS: {video_fps:g} | number of frames: {num_frames_text} | resolution: {W}x{H}')
    print(f'[Target] FPS: {target_fps} | number of frames: {target_frames_text} | resolution: {w}x{h}')

    (ffmpeg
    .input(str(video_path))
    .filter('fps', fps=f'{target_fps}')
    .filter('scale', width=w, height=h)
    .output(
        # file_path_stem is empty unless keep_video_name is set.
        str(image_dir / f'{file_path_stem}frame_%05d.jpg'),
        start_number=1,
        qscale=1,  # lower values mean higher quality (1 is the best, 31 is the worst).
    )
    .overwrite_output()
    .run(quiet=True)
    )

def downsample_frames(image_dir: Path, n_downsample: int):
    """Shrink every ``*.jpg`` file in ``image_dir`` in place.

    Each image is resized to ``(width // n_downsample, height // n_downsample)``
    and written back to the same path, replacing the original file.

    Args:
        image_dir: Directory whose ``*.jpg`` files are processed (not recursive).
        n_downsample: Integer divisor for both dimensions; must be 2, 4 or 8.

    Raises:
        AssertionError: If ``n_downsample`` is not one of 2, 4 or 8.
    """
    print(f'Downsample frames in {image_dir} by {n_downsample}')
    assert n_downsample in [2, 4, 8], f'n_downsample must be 2, 4 or 8, got {n_downsample}'

    image_paths = sorted(image_dir.glob('*.jpg'))
    for image_path in tqdm(image_paths):
        # Downsample the resolution of the image. The source file is closed
        # before the result is written back over the same path.
        with Image.open(image_path) as img:
            W, H = img.size
            resized = img.resize((W // n_downsample, H // n_downsample))
        resized.save(image_path)
    
def concat_3(
        input: Path,
        render_input: Path,
        strands_input: Path,
        target_fps: int=25,
        downsample_scales: List[int]=[],
    ):
    """Build three-panel comparison frames and encode one video per camera.

    For every ``cam*`` directory under ``render_input / 'renders'`` and every
    ``frame*.png`` inside it, three images sharing the same frame name are
    pasted left to right onto one canvas:

    1. ``input / 'image_masks_4' / 'hair' / <cam> / <frame>.jpg``
    2. ``strands_input / <cam> / <frame>.png``
    3. the rendered frame itself

    The canvas is saved to ``render_input / 'concat' / <cam> / <frame>.png``.
    Afterwards each ``cam*`` directory under ``render_input / 'concat'`` is
    encoded to ``render_input / 'concat' / 'video' / <cam>.mp4``.

    Args:
        input: Root directory holding ``image_masks_4/hair``.
        render_input: Directory holding ``renders``; also receives ``concat``.
        strands_input: Directory holding the per-camera strand images.
        target_fps: Value passed to ffmpeg as the input ``framerate``.
        downsample_scales: Currently unused by this function.
    """
    concat_output_path = render_input / 'concat'
    concat_output_path.mkdir(parents=True, exist_ok=True)
    image_ori_paths = input / 'image_masks_4' / 'hair'
    image_render_paths = render_input / 'renders'

    for cam_path in tqdm(image_render_paths.glob('cam*')):
        cam_name = cam_path.stem
        tqdm.write(f"Processing cam {cam_path}")
        concat_cam_output_path = concat_output_path / cam_name
        concat_cam_output_path.mkdir(parents=True, exist_ok=True)
        for image_render_path in sorted(cam_path.glob('frame*.png')):
            image_render_name = image_render_path.stem
            # Context managers release the three file handles for every frame.
            with Image.open(image_render_path) as image_render, \
                    Image.open(image_ori_paths / cam_name / f'{image_render_name}.jpg') as image_ori, \
                    Image.open(strands_input / cam_name / f'{image_render_name}.png') as image_strands:
                # Panel width comes from the original image, height from the render.
                concatenated_img = Image.new('RGB', (image_ori.width*3, image_render.height))
                concatenated_img.paste(image_ori, (0, 0))
                concatenated_img.paste(image_strands, (image_ori.width, 0))
                concatenated_img.paste(image_render, (2*image_ori.width, 0))
            concatenated_img.save(concat_cam_output_path / f'{image_render_name}.png')

    # Encode the concatenated frames of every camera into a video.
    output_video_file = concat_output_path / 'video'
    output_video_file.mkdir(parents=True, exist_ok=True)
    for video_path in tqdm(concat_output_path.glob('cam*')):
        tqdm.write(f'Processing video {video_path}')
        image_dirs = str(video_path / 'frame_%05d.png')
        output_video_path = str(output_video_file / f'{video_path.stem}.mp4')
        tqdm.write(image_dirs)
        tqdm.write(output_video_path)
        # No start number is passed, so ffmpeg's image-sequence defaults
        # decide which frame index the video begins at.
        (ffmpeg
        .input(image_dirs, framerate=target_fps)
        .output(output_video_path, vcodec='libx264', pix_fmt='yuv420p')
        .overwrite_output()  # a re-run must not stop at ffmpeg's overwrite prompt
        .run()
        )
        
def concat_2(
        input: Path,
        render_input: Path,
        strands_input: Path,
        target_fps: int=1,
        downsample_scales: List[int]=[],
        start_number: int=21,
    ):
    """Build two-panel comparison frames and encode one video per camera.

    For every ``cam*`` directory under ``render_input / 'renders'`` and every
    ``frame*.png`` inside it, the image
    ``input / 'image_masks_4' / 'hair' / <cam> / <frame>.jpg`` is pasted to the
    left of the rendered frame and the result is saved to
    ``render_input / 'concat' / <cam> / <frame>.png``. Afterwards each ``cam*``
    directory under ``render_input / 'concat'`` is encoded to
    ``render_input / 'concat' / 'video' / <cam>.mp4``.

    Args:
        input: Root directory holding ``image_masks_4/hair``.
        render_input: Directory holding ``renders``; also receives ``concat``.
        strands_input: Currently unused by this function.
        target_fps: Value passed to ffmpeg as the input ``framerate``.
        downsample_scales: Currently unused by this function.
        start_number: Index of the first ``frame_%05d.png`` file handed to
            ffmpeg; the default of 21 matches the previously hard-coded value.
    """
    concat_output_path = render_input / 'concat'
    concat_output_path.mkdir(parents=True, exist_ok=True)
    image_ori_paths = input / 'image_masks_4' / 'hair'
    image_render_paths = render_input / 'renders'

    for cam_path in tqdm(image_render_paths.glob('cam*')):
        cam_name = cam_path.stem
        tqdm.write(f"Processing cam {cam_path}")
        concat_cam_output_path = concat_output_path / cam_name
        concat_cam_output_path.mkdir(parents=True, exist_ok=True)
        for image_render_path in sorted(cam_path.glob('frame*.png')):
            image_render_name = image_render_path.stem
            # Context managers release both file handles for every frame.
            with Image.open(image_render_path) as image_render, \
                    Image.open(image_ori_paths / cam_name / f'{image_render_name}.jpg') as image_ori:
                # Panel width comes from the original image, height from the render.
                concatenated_img = Image.new('RGB', (image_ori.width*2, image_render.height))
                concatenated_img.paste(image_ori, (0, 0))
                concatenated_img.paste(image_render, (image_ori.width, 0))
            concatenated_img.save(concat_cam_output_path / f'{image_render_name}.png')

    # Encode the concatenated frames of every camera into a video.
    output_video_file = concat_output_path / 'video'
    output_video_file.mkdir(parents=True, exist_ok=True)
    for video_path in tqdm(concat_output_path.glob('cam*')):
        tqdm.write(f'Processing video {video_path}')
        image_dirs = str(video_path / 'frame_%05d.png')
        output_video_path = str(output_video_file / f'{video_path.stem}.mp4')
        tqdm.write(image_dirs)
        tqdm.write(output_video_path)
        (ffmpeg
        .input(image_dirs, framerate=target_fps, start_number=start_number)
        .output(output_video_path, vcodec='libx264', pix_fmt='yuv420p')
        .overwrite_output()  # a re-run must not stop at ffmpeg's overwrite prompt
        .run()
        )
            
        



# Script entry point: tyro builds the command-line interface from concat_2's
# signature, so concat_2 is the only function reachable from the command line.
if __name__ == '__main__':
    tyro.cli(concat_2)