import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose, Resize, ToTensor
import json
import os
import pdb
import glob
from torchvision.io import read_video
import torch.nn.functional as F

def resize_long_side(image_tensor, target_long_side):
    """
    Resize an image (or a batch of frames) so its longest side equals ``target_long_side``.

    The aspect ratio is preserved: the shorter side is scaled by the same factor
    and truncated to an integer, but never below one pixel.

    Args:
        image_tensor (torch.Tensor): A single image of shape (C, H, W) or a batch /
            video of shape (N, C, H, W). The last two dimensions are treated as
            height and width. Any dtype is accepted; it is cast to float.
        target_long_side (int): Desired size of the longest spatial side.

    Returns:
        torch.Tensor: A float tensor with the same number of dimensions as the
        input and the new spatial size. Values are not rescaled and the result
        is NOT converted back to the input dtype.
    """
    # Get the original spatial dimensions
    height, width = image_tensor.shape[-2:]

    # Scale the longer side to the target and the shorter side proportionally.
    # The shorter side is clamped to 1 so extreme aspect ratios do not produce
    # a zero-sized output, which F.interpolate rejects.
    if width > height:
        scale = target_long_side / width
        new_width = target_long_side
        new_height = max(1, int(height * scale))
    else:
        scale = target_long_side / height
        new_height = target_long_side
        new_width = max(1, int(width * scale))

    # Bilinear interpolation needs a 4-D (N, C, H, W) input, so a single
    # (C, H, W) image gets a temporary batch dimension.
    is_single_image = image_tensor.dim() == 3
    frames = image_tensor.unsqueeze(0) if is_single_image else image_tensor

    resized_image = F.interpolate(
        frames.float(),
        size=(new_height, new_width),
        mode='bilinear',
        align_corners=False
    )

    # Drop the temporary batch dimension again for single-image input
    if is_single_image:
        resized_image = resized_image.squeeze(0)

    return resized_image

def find_mp4_files(directory):
    """
    Recursively collect the paths of all ``.mp4`` files under ``directory``.

    Symbolic links to directories are followed. Files whose name starts with
    ``._`` are skipped (presumably hidden sidecar files such as the AppleDouble
    entries macOS writes next to real files, which are not playable videos).

    Args:
        directory (str): Root directory to search.

    Returns:
        list[str]: Matching file paths (each joined onto the directory it was
        found in), sorted so the result does not depend on filesystem
        enumeration order. Empty if nothing matches or the directory does not
        exist.
    """
    mp4_files = [
        os.path.join(root, file)
        for root, _dirs, files in os.walk(directory, followlinks=True)
        for file in files
        if file.endswith(".mp4") and not file.startswith('._')
    ]
    # os.walk yields entries in arbitrary order; sorting keeps the
    # index -> file mapping reproducible across runs and machines.
    mp4_files.sort()
    return mp4_files

class StandardVidoDataset(Dataset):
    """
    Dataset over every ``.mp4`` file found recursively under ``video_dir``.

    Args:
        video_dir (str): Root directory searched for videos.
        transform (callable, optional): Applied to the float frame tensor
            (frames scaled to [0, 1], channels moved to dim 1) to produce
            ``sample['video']``. When ``None`` the frames are returned untransformed.
        transform_no_resize (callable, optional): Stored on the instance but not
            applied anywhere in this class.
        load_video (bool): When ``False``, items only carry the file path and
            no video is decoded.
    """

    def __init__(self, video_dir, transform=None, transform_no_resize=None, load_video=True):
        self.video_dir = video_dir
        self.transform = transform
        self.transform_no_resize = transform_no_resize
        self.load_video = load_video
        video_paths = find_mp4_files(video_dir)
        # Defensive: drop any '._'-prefixed files even if the finder let them through.
        self.video_paths = [v for v in video_paths if not os.path.basename(v).startswith('._')]

    def __len__(self):
        """Return the number of discovered video files."""
        return len(self.video_paths)

    def __getitem__(self, idx):
        """
        Load one sample.

        Returns:
            dict | None: With ``load_video=False``, ``{'video_path': path}``.
            Otherwise a dict with keys ``video_name``, ``video`` (transformed
            float frames), ``length`` (frame count after subsampling),
            ``org_video`` (a copy of the subsampled frames exactly as decoded),
            ``video_path`` and ``video_no_resize`` (float frames resized so the
            long side is 512). ``None`` if the file could not be read, so the
            collate function can skip it.
        """
        video_path = self.video_paths[idx]
        if not self.load_video:
            return {'video_path': video_path}

        try:
            video, _, video_metas = read_video(video_path, pts_unit='sec')
            # Keep every `interval`-th frame so the clip ends up at roughly 8 fps.
            interval = round(video_metas['video_fps'] / 8)
            if interval > 1:
                video = video[::interval]
        except Exception as e:
            # Best effort: report which file failed and let the batch continue.
            # Exception (not BaseException) so Ctrl-C / SystemExit still propagate.
            print(f'Failed to read {video_path}: {e}')
            return None

        org_video = video.clone()

        # Move channels from the last dim to dim 1 and scale by 1/255; computed
        # once and shared by both derived views.
        # NOTE(review): assumes read_video returned uint8 (T, H, W, C) frames — confirm.
        frames = video.permute(0, 3, 1, 2).float() / 255
        video_no_resize = resize_long_side(frames, 512)
        # `transform` is optional; without this guard the default None would be called.
        transformed = frames if self.transform is None else self.transform(frames)

        return {
            'video_name': os.path.basename(video_path),
            'video': transformed,
            'length': video.shape[0],
            'org_video': org_video,
            'video_path': video_path,
            'video_no_resize': video_no_resize,
        }

def standard_collate_fn(batch):
    """
    Collate dataset samples into one batch dict, skipping samples that failed to load.

    Args:
        batch (list): Items produced by the dataset. Each is ``None`` (load
            failure, skipped), a path-only dict ``{'video_path': ...}``, or a
            full sample dict with ``video``, ``length``, ``org_video``,
            ``video_path`` and optionally ``video_no_resize``.

    Returns:
        dict: With keys
            - ``video_names``: base names of the samples that carried a video.
            - ``videos``: all ``video`` tensors concatenated along dim 0 when
              every non-``None`` item carried a video and at least one did;
              otherwise the (possibly empty) list of collected tensors.
            - ``video_lengths``: per-sample ``length`` values, which allow the
              concatenated tensor to be split back per video.
            - ``org_videos``: list of per-sample ``org_video`` tensors.
            - ``videos_no_resize``: list of per-sample ``video_no_resize`` tensors.
            - ``video_paths``: paths of all non-``None`` items.
    """
    video_names = []
    videos = []
    videos_no_resize = []
    video_lengths = []
    org_videos = []
    video_paths = []
    with_video = True

    for item in batch:
        if item is None:
            # The dataset returns None for unreadable files.
            continue
        video_paths.append(item['video_path'])
        if 'video' not in item:
            # Path-only sample: nothing to stack for this batch.
            with_video = False
            continue
        video_names.append(os.path.basename(item['video_path']))
        videos.append(item['video'])
        video_lengths.append(item['length'])
        org_videos.append(item['org_video'])
        if 'video_no_resize' in item:
            videos_no_resize.append(item['video_no_resize'])

    # torch.cat raises on an empty list, which happens when every sample in
    # the batch failed to load; leave `videos` as an empty list in that case.
    if with_video and videos:
        videos = torch.cat(videos, dim=0)

    return {
        'video_names': video_names,
        'videos': videos,
        'video_lengths': video_lengths,
        'org_videos': org_videos,
        'videos_no_resize': videos_no_resize,
        'video_paths': video_paths,
    }



if __name__ == '__main__':
    # Smoke test: iterate a DataLoader over a placeholder folder and print the
    # shape of the first two collated batches.
    video_folder = '/path/to/videos'
    transform = Compose([Resize((224, 224))])

    dataset = StandardVidoDataset(video_folder, transform=transform)
    dataloader = DataLoader(
        dataset,
        collate_fn=standard_collate_fn,
        shuffle=True,
        batch_size=2,
    )

    for i, data in enumerate(dataloader):
        videos, org_videos = data['videos'], data['org_videos']
        print(f"Batch {i+1}:")
        print(f"Videos shape: {videos.shape}")
        print(f'len(org_videos): {len(org_videos)}')
        # Two batches are enough for a sanity check; stop before fetching a third.
        if i == 1:
            break
