from pathlib import Path
from typing import List, Tuple, Dict, Any
import pandas as pd
import ast
import os
import torch
from torch.utils.data import Dataset
from torchvision.transforms import transforms
from PIL import Image
from torch.utils.data import DataLoader
import json
import random
import cv2
import numpy as np
import math
from src.utils.videoreader_pyav import VideoReaderAV
from src.utils.process_reference import process_reference_image

def read_clip_with_resize(
    video_path: str,
    height: int,
    width: int,
    max_num_frames: int,
) -> torch.Tensor:
    """Load evenly spaced frames from a video and resize them.

    Args:
        video_path: Path handed to ``VideoReaderAV``.
        height: Target height of the resized frames.
        width: Target width of the resized frames.
        max_num_frames: Number of frame indices sampled between the first
            and the last frame of the video (both included).

    Returns:
        A 4-D tensor whose last two dimensions are ``(height, width)``.
        No dtype conversion is performed here, so the dtype follows whatever
        ``get_batch`` returns.
    """
    reader = VideoReaderAV(video_path)

    # Spread the requested number of indices uniformly over the whole video;
    # ``.int()`` truncates the fractional positions to frame numbers.
    final_index = reader.video_num_frames - 1
    frame_ids = torch.linspace(0, final_index, max_num_frames).int().tolist()

    clip = torch.from_numpy(reader.get_batch(frame_ids))
    # Move the last axis to position 1 so interpolate() resizes the two
    # trailing axes (assumes get_batch yields channels-last frames — TODO confirm).
    clip = clip.permute(0, 3, 1, 2).contiguous()

    return torch.nn.functional.interpolate(
        clip,
        size=(height, width),
        mode='bilinear',
        align_corners=False,
    )

class Dataset_cogvideo_lora(Dataset):
    """Dataset that serves random groups of pre-loaded clips of one video.

    All segments listed for ``video_id`` in the metadata JSON are decoded,
    resized and kept in memory at construction time. Each item is a group of
    up to ``clip_num`` segments chosen deterministically from the item index.
    """

    def __init__(
        self,
        json_path: str,
        video_id: str,
        max_num_frames: int,
        height: int,
        width: int,
        clip_num: int,
    ) -> None:
        """Load every segment of ``video_id`` into memory.

        Args:
            json_path: Path to a JSON file mapping video ids to metadata. The
                entry must provide ``"segments"`` (each with
                ``"video_clip_path"`` and ``"caption"``) and ``"reference"``.
            video_id: Key of the entry to load from the JSON file.
            max_num_frames: Number of frames sampled from each segment.
            height: Height the frames are resized to.
            width: Width the frames are resized to.
            clip_num: Maximum number of segments returned per item.

        Raises:
            KeyError: If ``video_id`` or a required metadata key is missing.
        """
        super().__init__()

        # Context manager so the metadata file handle is closed right away
        # instead of being left to the garbage collector.
        with open(json_path) as meta_file:
            meta_info = json.load(meta_file)[video_id]

        self.data_list = []
        for segment in meta_info["segments"]:
            frames = read_clip_with_resize(segment["video_clip_path"], height, width, max_num_frames)
            # Keep the final frame of the clip as a PIL image: move axis 0
            # last and cast to uint8 before handing it to PIL.
            last_frame = frames[-1].permute(1, 2, 0).cpu().numpy().astype('uint8')

            self.data_list.append({
                "video": frames,
                "caption": segment["caption"],
                "last_frame": Image.fromarray(last_frame),
            })

        self.reference_image_path = meta_info["reference"]
        self.clip_num = clip_num

        # A named function (not a lambda) keeps the dataset picklable, which
        # spawn-based DataLoader worker processes require.
        self.__frame_transforms = transforms.Compose([transforms.Lambda(self._normalize)])
        self.__image_transforms = self.__frame_transforms

    @staticmethod
    def _normalize(x: torch.Tensor) -> torch.Tensor:
        """Map values from the ``[0, 255]`` range to ``[-1, 1]``."""
        return x / 255.0 * 2.0 - 1.0

    def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
        """Apply the frame transform to every frame and re-stack along dim 0."""
        return torch.stack([self.__frame_transforms(f) for f in frames], dim=0)

    def image_transform(self, image: torch.Tensor) -> torch.Tensor:
        """Apply the image transform (the same pipeline used for frames)."""
        return self.__image_transforms(image)

    def __len__(self) -> int:
        # Fixed nominal length: items are produced by index-seeded sampling in
        # __getitem__, not by position in data_list, so any index is valid.
        return 100000000

    def _select_segments(self, index: int) -> List[Dict[str, Any]]:
        """Pick up to ``clip_num`` segments, deterministically for ``index``.

        A private ``random.Random`` seeded with ``index`` is used so that the
        process-wide ``random`` state is left untouched. The selection is the
        same one that seeding the global generator with ``index`` would give.
        """
        rng = random.Random(index)
        count = min(self.clip_num, len(self.data_list))
        chosen = rng.sample(self.data_list, count)
        rng.shuffle(chosen)
        return chosen

    def __getitem__(self, index: int) -> Dict[str, Any]:
        """Return one group of clips selected from the pre-loaded segments.

        Args:
            index: Seed for the segment selection; equal indices always give
                the same group.

        Returns:
            A dict with:
              * ``"videos"``: list of transformed clips with the first two
                axes of the stored clip swapped (stored ``[F, C, H, W]``
                becomes ``[C, F, H, W]``),
              * ``"captions"``: list of the matching captions,
              * ``"images"``: list of the matching last-frame PIL images,
              * ``"reference_images"``: result of ``process_reference_image``
                on the reference path.
        """
        clips = []
        captions = []
        last_frames = []
        for segment in self._select_segments(index):
            frames = self.video_transform(segment["video"])
            # Swap the frame and channel axes: [F, C, H, W] -> [C, F, H, W].
            clips.append(frames.permute(1, 0, 2, 3).contiguous())
            captions.append(segment["caption"])
            last_frames.append(segment["last_frame"])

        return {
            "videos": clips,
            "captions": captions,
            "images": last_frames,
            "reference_images": process_reference_image(self.reference_image_path)
        }

    @staticmethod
    def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Merge items produced by ``__getitem__`` into one batch.

        Clips of all items are flattened into a single stacked tensor, while
        captions, images and reference images stay grouped per item.

        Args:
            batch: Items as returned by ``__getitem__``.

        Returns:
            A dict with ``"captions"`` and ``"images"`` as lists of per-item
            lists, ``"reference_images"`` as a per-item list, and ``"videos"``
            as one tensor stacking every clip of the batch (an empty tensor
            when the batch holds no clips).
        """
        all_videos = [clip for item in batch for clip in item['videos']]
        videos_tensor = torch.stack(all_videos) if all_videos else torch.tensor([])

        return {
            "captions": [item['captions'] for item in batch],  # List[List[str]]
            "videos": videos_tensor,  # tensor[N, C, F, H, W], N = total number of clips
            "images": [item['images'] for item in batch],  # List[List[PIL.Image]]
            "reference_images": [item['reference_images'] for item in batch],  # one entry per item
        }