# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional
import os
from pathlib import Path
import torch
import torch.utils.data
import torchvision
from datamodules.video_data_api import VideoDataset
from pytorch_lightning import LightningDataModule
# from pytorchvideo.transforms import Normalize
from torchvision.transforms import Normalize
from transformers import AutoTokenizer, ViTImageProcessor
from torchvision.datasets.folder import default_loader
from torchvision.transforms import Compose, ToTensor
from torch.utils.data import DataLoader, default_collate
from pathlib import Path
import torch
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor, Normalize
from torchvision.datasets.folder import default_loader
from transformers import AutoTokenizer

class UVGDataset(Dataset):
    """Fixed-length frame clips from per-video PNG folders, paired with captions.

    Every sub-directory of ``root`` is treated as one video.  Its ``*.png``
    files are sorted by name, truncated to ``max_frames`` and cut into
    non-overlapping clips of ``num_frames`` consecutive frames; trailing
    frames that do not fill a whole clip are dropped.

    Captions are read from a parallel directory obtained by replacing
    ``"UVG-96"`` with ``"UVG-TXT"`` in the video folder path.  Each caption
    file is named ``imAAAAA-imBBBBB.txt`` and covers a window of
    ``_CAPTION_WINDOW`` frame ids (1-10, 11-20, ...).

    Args:
        root: Directory containing one sub-directory of PNG frames per video.
        num_frames: Number of consecutive frames per clip; must be positive.
        normalize: If true, normalize every frame tensor with mean 0.45 and
            std 0.225 on all three channels.
        clip_name: Name or path handed to ``AutoTokenizer.from_pretrained``.
        max_frames: Only the first ``max_frames`` frames of every video are
            used; ``None`` keeps all of them.

    Raises:
        ValueError: If ``num_frames`` is not positive.
    """

    # Number of consecutive frame ids described by a single caption file.
    _CAPTION_WINDOW = 10

    def __init__(self, root: str, num_frames: int = 5, normalize=False,
                 clip_name="openai/clip-vit-base-patch32",
                 max_frames: Optional[int] = 100):
        if num_frames <= 0:
            raise ValueError(f"num_frames must be positive, got {num_frames}")

        self.root = Path(root)
        self.num_frames = num_frames
        self.normalize = normalize
        self.max_frames = max_frames
        self.tokenizer = AutoTokenizer.from_pretrained(clip_name)

        self.video_clips = []  # List of (folder_path, start_index)
        self.to_tensor = ToTensor()
        # nn.Identity instead of a lambda: lambdas cannot be pickled, which
        # breaks DataLoader workers started with "spawn"/"forkserver".
        self.norm = (
            Normalize(mean=(0.45, 0.45, 0.45), std=(0.225, 0.225, 0.225))
            if normalize
            else torch.nn.Identity()
        )

        # Sorted (and truncated) frame list per folder, built once so that
        # __getitem__ does not have to re-scan the directory on every access.
        self._frames_by_folder = {}
        for folder in sorted(p for p in self.root.iterdir() if p.is_dir()):
            # Only the first ``max_frames`` frames are used (all if None).
            frames = sorted(folder.glob("*.png"))[:max_frames]
            self._frames_by_folder[folder] = frames

            # Non-overlapping windows; an incomplete trailing window is skipped.
            for start in range(0, len(frames) - num_frames + 1, num_frames):
                self.video_clips.append((folder, start))

    def parse_frame_id(self, filename) -> int:
        """Return the integer frame id of a stem such as ``"im00042"``."""
        return int(filename.replace("im", ""))

    def get_best_matching_caption(self, clip_frames, txt_root) -> str:
        """Return the caption whose frame window contains the clip's centre frame.

        Args:
            clip_frames: Frame paths (or file names) of the clip, in order.
            txt_root: Directory holding the ``imAAAAA-imBBBBB.txt`` files.

        Returns:
            The full text of the matching caption file.

        Raises:
            FileNotFoundError: If the expected caption file does not exist.
        """
        clip_frame_ids = [self.parse_frame_id(Path(f).stem) for f in clip_frames]
        center_id = clip_frame_ids[len(clip_frame_ids) // 2]

        # Caption windows are 1-based (1-10, 11-20, ...), so shift to a
        # 0-based id before bucketing; otherwise ids that are exact multiples
        # of the window size (10, 20, ...) land in the *next* window.
        window = self._CAPTION_WINDOW
        start_id = (max(center_id - 1, 0) // window) * window + 1
        end_id = start_id + window - 1

        filename = f"im{start_id:05d}-im{end_id:05d}.txt"
        txt_path = Path(txt_root) / filename

        return txt_path.read_text(encoding="utf-8")

    def __len__(self):
        """Return the number of clips over all videos."""
        return len(self.video_clips)

    def __getitem__(self, idx):
        """Load one clip together with its tokenized caption.

        Returns:
            A tuple ``(video, token_ids, attention_mask, frame_paths)`` where
            ``video`` stacks the clip's frame tensors along a new first
            dimension, ``token_ids`` / ``attention_mask`` are the tokenizer
            outputs padded or truncated to 38 tokens, and ``frame_paths`` is
            the list of frame paths as strings.
        """
        folder, start = self.video_clips[idx]
        frames = self._frames_by_folder.get(folder)
        if frames is None:
            # Folder was added to ``video_clips`` after construction.
            frames = sorted(folder.glob("*.png"))
        clip_frames = frames[start:start + self.num_frames]

        video = [self.norm(self.to_tensor(default_loader(str(f)))) for f in clip_frames]
        video_tensor = torch.stack(video)

        caption_txt_root = Path(str(folder).replace("UVG-96", "UVG-TXT"))
        caption = self.get_best_matching_caption(clip_frames, caption_txt_root)

        tokens = self.tokenizer(caption, padding="max_length", max_length=38, truncation=True, return_tensors="pt")
        token_ids = tokens["input_ids"].squeeze(0)
        attention_mask = tokens["attention_mask"].squeeze(0)

        return video_tensor, token_ids, attention_mask, [str(f) for f in clip_frames]


class UVGTextDataModule(LightningDataModule):
    """Lightning data module that serves ``UVGDataset`` as a test set.

    Args:
        data_dir: Root directory handed to ``UVGDataset``.
        batch_size: Number of clips per batch.
        num_frames: Number of frames per clip.
        normalize: Whether ``UVGDataset`` normalizes the frame tensors.
        num_workers: Number of ``DataLoader`` worker processes.
        pin_memory: Forwarded to ``DataLoader``.
    """

    def __init__(
        self,
        data_dir: str,
        batch_size: int = 1,
        num_frames: int = 50,
        normalize: bool = False,
        num_workers: int = 2,
        pin_memory: bool = True,
    ):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_frames = num_frames
        self.normalize = normalize
        self.num_workers = num_workers
        self.pin_memory = pin_memory

    def setup(self, stage: Optional[str] = None):
        """Build the test dataset; ``stage`` is accepted but not used."""
        self.test_dataset = UVGDataset(
            root=self.data_dir,
            num_frames=self.num_frames,
            normalize=self.normalize
        )

    @staticmethod
    def _collate(batch):
        """Merge ``UVGDataset`` samples into a single batch dictionary.

        Defined on the class rather than inside ``test_dataloader`` so that it
        can be pickled and sent to worker processes started with the
        "spawn"/"forkserver" methods; a nested function cannot be pickled.

        Returns:
            A dict with the stacked videos wrapped in ``VideoDataset``
            (``"video"``), the stacked token ids (``"tokens"``), the stacked
            attention masks (``"attention_masks"``) and a tuple holding each
            sample's list of frame paths (``"clip_names"``).
        """
        videos, tokens, masks, clip_names = zip(*batch)
        return {
            "video": VideoDataset(torch.utils.data.default_collate(videos)),
            "tokens": torch.utils.data.default_collate(tokens),
            "attention_masks": torch.utils.data.default_collate(masks),
            "clip_names": clip_names
        }

    def test_dataloader(self):
        """Return an unshuffled ``DataLoader`` over the test dataset."""
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=False,
            collate_fn=self._collate,
        )
