import os
import cv2 as cv
import numpy as np
import torch
import imageio
import glob

from torchvision.utils import make_grid
from torchvision.transforms import transforms
from torchvision.transforms.functional import to_pil_image


def prepare_video_to_grid(path, grid_count, grid_size, pad):

    video = cv.VideoCapture(path)
    if grid_count == -1:
        frame_count = int(video.get(cv.CAP_PROP_FRAME_COUNT))
    else:
        frame_count = min(grid_count * pad * grid_size**2, int(video.get(cv.CAP_PROP_FRAME_COUNT)))

    transform = transforms.Compose([
        transforms.ConvertImageDtype(dtype=torch.float),
    ])
    success = True

    max_grid_area = 512*512* grid_size**2
    grids = []
    frames = []

    total_grid = grid_size**2
    for idx in range(frame_count):
        success,image = video.read()
        assert success, 'Video read failed'
        if idx % pad == 0:
            rgb_img = cv.cvtColor(image, cv.COLOR_BGR2RGB)
            rgb_img = np.transpose(rgb_img, (2, 0, 1))
            frames.append(transform(torch.from_numpy(rgb_img)))
            
            if len(frames) == total_grid:
                grid = make_grid(frames, nrow=grid_size, padding=0)
                pil_image = (to_pil_image(grid))
                # no resize
                w,h = pil_image.size
                a = float(np.sqrt((w*h/max_grid_area)))
                w1 = int((w//a)//(grid_size*8))*grid_size*8
                h1 = int((h//a)//(grid_size*8))*grid_size*8
                pil_image= pil_image.resize((w1, h1))
                grids.append(pil_image)

                frames = []

    return grids # list of frames

def prepare_video_to_frames(path, grid_count, grid_size, pad, format='gif'):
    video = cv.VideoCapture(path)
    
    if grid_count == -1:
        frame_count = int(video.get(cv.CAP_PROP_FRAME_COUNT))
        
    else:
        frame_count = min(grid_count * pad * grid_size**2, int(video.get(cv.CAP_PROP_FRAME_COUNT)))
        
    frame_idx = 0
    frames = []
    frames_grid = []

    dir_path = os.path.dirname(path)
    video_name = path.split('/')[-1].split('.')[0]
    os.makedirs(os.path.join(dir_path, 'frames/', video_name), exist_ok=True)
    os.makedirs(os.path.join(dir_path, 'video/', video_name), exist_ok=True)

    for idx in range(frame_count):
        success,image = video.read()
        assert success, 'Video read failed'
        if idx % pad == 0:
            frames.append(image)

    for frame in frames[:(len(frames)//(grid_size**2)*(grid_size**2))]:
        frames_grid.append(frame)
        cv.imwrite(os.path.join(dir_path, 'frames/', video_name, f'{str(frame_idx).zfill(5)}.png'), frame)
        frame_idx += 1


    if format == 'gif':
        with imageio.get_writer(os.path.join(dir_path, 'video/', f'{video_name}_fc{frame_idx}_pad{pad}_grid{grid_size}.gif'), mode='I') as writer:
            for frame in frames_grid:
                frame = cv.cvtColor(frame, cv.COLOR_BGR2RGB)
                writer.append_data(frame)
    elif format == 'mp4':
        image_files = sorted(glob.glob(os.path.join(dir_path, 'frames/', video_name, '*.png')))
        images = [imageio.imread(image_file) for image_file in image_files]
        save_file_path = os.path.join(dir_path, 'video/', f'{video_name}_fc{frame_idx}_pad{pad}_grid{grid_size}.mp4')
        imageio.mimsave(save_file_path, images, fps=20)

    return frame_idx # number of frames

