import os
import torch
import numpy as np
import cv2
from torch.utils.data import Dataset
from PIL import Image
from io import BytesIO
import av
import bisect

from tqdm import tqdm


class EvalDataSet(Dataset):
    """Evaluation dataset that serves whole videos as lists of RGB frames.

    The constructor scans ``dataset_path`` for task sub-directories and
    records every ``*.mp4`` file found directly inside them. Each item is the
    full list of decoded frames of one video together with the path it was
    read from.

    Expected layout (taken from the original notes, not verified here):
    ``{task_name}/episode_{episode_idx}.mp4``, written with
    ``cv2.VideoWriter(..., cv2.VideoWriter_fourcc(*'mp4v'), fps=30,
    (width=640, height=720))``. Sibling files such as
    ``episode_{episode_idx}_qpos.pt`` or ``{task_name}.json`` may exist in the
    same tree, but this class only looks at files ending in ``.mp4``.

    Symlinked videos are resolved to their target. A symlink whose target is
    missing, or whose target cannot be decoded, is reported on stdout and left
    out of the dataset, because it could only ever yield an empty frame list.
    """

    def __init__(self, args, dataset_path, disable_pbar=False, type="train", preprocessor=None):
        """Index all videos under ``dataset_path``.

        Args:
            args: Unused by this class; kept for interface compatibility.
            dataset_path: Root directory holding one sub-directory per task.
            disable_pbar: If true, the per-task tqdm progress bar is hidden.
            type: Split label; only stored on the instance as ``self.type``.
            preprocessor: Optional object providing
                ``set_augmentation_progress(progress)`` and
                ``process_image(image)``. When given, its augmentation
                progress is set to 0 here and every frame is passed through
                ``process_image`` in ``__getitem__``.
        """
        super().__init__()
        self.data = []
        self.dataset_path = dataset_path
        self.type = type
        self.height = 720  # 480 + 240
        self.width = 640
        self.video_path = []
        self.preprocessor = preprocessor
        if self.preprocessor is not None:
            self.preprocessor.set_augmentation_progress(0)

        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> video mapping reproducible across runs and machines.
        for task_name in sorted(os.listdir(dataset_path)):
            task_path = os.path.join(dataset_path, task_name)
            if not os.path.isdir(task_path):
                continue
            file_names = sorted(os.listdir(task_path))
            for file_name in tqdm(file_names, desc=f"Loading videos from {task_name}", disable=disable_pbar):
                if not file_name.endswith('.mp4'):
                    continue
                usable = self._usable_video_path(os.path.join(task_path, file_name))
                if usable is not None:
                    self.video_path.append(usable)

    @staticmethod
    def _can_read_first_frame(path):
        """Return whether OpenCV can decode at least one frame from ``path``."""
        cap = cv2.VideoCapture(path)
        try:
            ok, _ = cap.read()
        finally:
            cap.release()
        return ok

    @staticmethod
    def _usable_video_path(video_path):
        """Return the path to index for ``video_path``, or ``None`` to skip it.

        Regular files are returned unchanged and are not probed. A symlink is
        resolved one level to an absolute target path (relative targets are
        interpreted against the directory containing the link); the target is
        returned only if it exists and its first frame can be decoded.
        """
        if not os.path.islink(video_path):
            return video_path

        target = os.readlink(video_path)
        target_abs = os.path.abspath(os.path.join(os.path.dirname(video_path), target))
        if not os.path.exists(target_abs):
            print(f"Broken symlink: {target_abs}")
            return None
        if not EvalDataSet._can_read_first_frame(target_abs):
            print(f"Unreadable mp4: {target_abs}")
            return None
        return target_abs

    def __len__(self):
        """Return the number of indexed videos."""
        return len(self.video_path)

    def get_images(self, idx):
        """Decode every frame of video ``idx`` into a list of RGB PIL images.

        Returns an empty list if OpenCV cannot read any frame from the file.
        """
        cap = cv2.VideoCapture(self.video_path[idx])
        frames = []
        try:
            while True:
                success, frame = cap.read()
                if not success:
                    break
                # OpenCV decodes to BGR; convert before wrapping in PIL.
                frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
        finally:
            # Release the decoder even if conversion of a frame fails.
            cap.release()
        return frames

    def __getitem__(self, idx):
        """Return ``(frames, video_path)`` for video ``idx``.

        ``frames`` is the list from ``get_images``; when a preprocessor was
        supplied, each frame is replaced by ``preprocessor.process_image(frame)``.
        """
        images = self.get_images(idx)
        if self.preprocessor is not None:
            images = [self.preprocessor.process_image(image) for image in images]
        return images, self.video_path[idx]
