# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import pickle
import traceback
import warnings

import numpy as np
import torch
from decord import VideoReader, cpu
from torch.utils.data import Dataset as TorchDataset
from torchvision import transforms as T
from torchvision.transforms.v2 import UniformTemporalSubsample
from pathlib import Path
from cosmos_predict2.data.dataset_utils import Resize_Preprocess, ToTensorVideo
from imaginaire.utils import log
import random

"""
Test the dataset with the following command:
python -m cosmos_predict2.data.dataset_video
"""


class Dataset(TorchDataset):
    def __init__(
        self,
        dataset_dir,
        num_frames,
        video_size,
        target_fps,
    ):
        """Dataset class for loading image-text-to-video generation data.

        Args:
            dataset_dir (str): Base path to the dataset directory
            num_frames (int): Number of frames to load per sequence
            video_size (list): Target size [H,W] for video frames

        Returns dict with:
            - video: RGB frames tensor [T,C,H,W]
            - video_name: Dict with episode/frame metadata
        """

        super().__init__()
        self.dataset_dir = dataset_dir
        self.sequence_length = num_frames
        self.target_fps = target_fps
        self.repeat = 1000


        video_dir = os.path.join(self.dataset_dir, "videos")
        self.t5_dir = os.path.join(self.dataset_dir, "t5_xxl")

        self.video_paths = [os.path.join(video_dir, f) for f in os.listdir(video_dir) if f.endswith(".mp4")]
        self.video_paths = sorted(self.video_paths)
        # remove video paths that does not have t5_embedding
        self.video_paths = [
            path
            for path in self.video_paths
            if os.path.exists(os.path.join(self.t5_dir, os.path.basename(path).replace(".mp4", ".pickle")))
        ]
        log.info(f"{len(self.video_paths)} videos in total")

        self.wrong_number = 0
        self.preprocess = T.Compose([ToTensorVideo(), Resize_Preprocess(tuple(video_size))])

    def __str__(self):
        return f"{len(self.video_paths)} samples from {self.dataset_dir}"

    def __len__(self):
        return len(self.video_paths) * self.repeat

    def _load_video(self, video_path) -> tuple[np.ndarray, float]:
        vr = VideoReader(video_path, ctx=cpu(0), num_threads=2)
        video_fps = vr.get_avg_fps()
        target_fps = self.target_fps
        frame_stride = video_fps / target_fps
        total_frames = len(vr)
        seg_num_frames = int(self.sequence_length * frame_stride)
        if total_frames <= seg_num_frames:
            # If there are not enough frames, repeat them
            frame_ids = np.linspace(0, total_frames - 1, self.sequence_length, dtype=int)
            frame_ids = frame_ids.tolist()
            act_fps = self.sequence_length / (total_frames / video_fps)
        else:
            # randomly sample a sequence of frames
            max_start_idx = total_frames - seg_num_frames
            start_frame = np.random.randint(0, max_start_idx)
            end_frame = start_frame + seg_num_frames
            frame_ids = np.linspace(start_frame, end_frame, self.sequence_length, dtype=int)
            frame_ids = frame_ids.tolist()
            act_fps = self.sequence_length / (seg_num_frames / video_fps)

        frame_data = vr.get_batch(frame_ids).asnumpy()
        vr.seek(0)  # set video reader point back to 0 to clean up cache

        del vr  # delete the reader to avoid memory leak
        return frame_data, act_fps

    def _get_frames(self, video_path: str) -> tuple[torch.Tensor, float]:
        frames, fps = self._load_video(video_path)
        frames = frames.astype(np.uint8)
        frames = torch.from_numpy(frames).permute(0, 3, 1, 2)  # [T, C, H, W]
        frames = self.preprocess(frames)
        frames = torch.clamp(frames * 255.0, 0, 255).to(torch.uint8)
        return frames, fps

    def __getitem__(self, index):
        index = index % len(self.video_paths)
        try:
            data = dict()
            video, fps = self._get_frames(self.video_paths[index])
            video = video.permute(1, 0, 2, 3)  # Rearrange from [T, C, H, W] to [C, T, H, W]
            video_path = self.video_paths[index]
            t5_embedding_path = os.path.join(
                self.t5_dir,
                os.path.basename(video_path).replace(".mp4", ".pickle"),
            )
            data["video"] = video
            data["video_name"] = {
                "video_path": video_path,
                "t5_embedding_path": t5_embedding_path,
                # "start_frame_id": '0',
            }

            # Just add these to fit the interface
            # t5_embedding = np.load(sample["t5_embedding_path"])[0]
            with open(t5_embedding_path, "rb") as f:
                t5_embedding = pickle.load(f)[0]  # [n_tokens, 1024]
            n_tokens = t5_embedding.shape[0]
            if n_tokens < 512:
                t5_embedding = np.concatenate(
                    [t5_embedding, np.zeros((512 - n_tokens, 1024), dtype=np.float32)], axis=0
                )
            t5_text_mask = torch.zeros(512, dtype=torch.int64)
            t5_text_mask[:n_tokens] = 1

            data["t5_text_embeddings"] = torch.from_numpy(t5_embedding)
            data["t5_text_mask"] = t5_text_mask
            data["fps"] = fps
            data["image_size"] = torch.tensor([704, 1280, 704, 1280])
            data["num_frames"] = self.sequence_length
            data["padding_mask"] = torch.zeros(1, 704, 1280)

            return data
        except Exception as e:
            warnings.warn(
                f"Invalid data encountered: {self.video_paths[index]}. Skipped "
                f"(by randomly sampling another sample in the same dataset)."
            )
            warnings.warn("FULL TRACEBACK:")
            warnings.warn(traceback.format_exc())
            self.wrong_number += 1
            log.info(self.wrong_number, rank0_only=False)
            raise e
            return self[np.random.randint(len(self.video_paths))]




class WanDataset(Dataset):
    def __init__(
        self,
        dataset_dir,
        num_frames,
        video_size,
        target_fps,
    ):
        super().__init__(dataset_dir=dataset_dir, num_frames=num_frames, video_size=video_size, target_fps=target_fps)
        self.dataset_dir = dataset_dir
        self.sequence_length = num_frames

        video_dir = os.path.join(self.dataset_dir, "videos")
        self.t5_dir = os.path.join(self.dataset_dir, "t5_xxl")

        self.video_paths = [os.path.join(video_dir, f) for f in os.listdir(video_dir) if f.endswith(".mp4")]
        self.video_paths = sorted(self.video_paths)
        log.info(f"{len(self.video_paths)} videos in total")
        self.wan_umt5_dir = os.path.join(self.dataset_dir, "umt5_wan")

    def __getitem__(self, index):
        index = index % len(self.video_paths)
        try:
            data = dict()
            video, fps = self._get_frames(self.video_paths[index])
            video = video.permute(1, 0, 2, 3)  # Rearrange from [T, C, H, W] to [C, T, H, W]
            video_path = self.video_paths[index]
            umt5_path = os.path.join(
                self.wan_umt5_dir,
                os.path.basename(video_path).replace(".mp4", ".pt"),
            )
            data["video"] = video
            data["video_name"] = {
                "video_path": video_path,
                "umt5_path": umt5_path,
                # "start_frame_id": '0',
            }

            # Just add these to fit the interface
            umt5_embedding = torch.load(umt5_path)
            n_tokens = umt5_embedding.shape[0]
            if n_tokens < 512:
                umt5_embedding = torch.cat([umt5_embedding, torch.zeros((512 - n_tokens, umt5_embedding.shape[1]), dtype=umt5_embedding.dtype)], dim=0)
            t5_text_mask = torch.zeros(512, dtype=torch.int64)
            t5_text_mask[:n_tokens] = 1

            data["umt5_embedding"] = umt5_embedding
            data["t5_text_mask"] = t5_text_mask
            data["fps"] = fps
            data["image_size"] = torch.tensor([704, 1280, 704, 1280])
            data["num_frames"] = self.sequence_length
            data["padding_mask"] = torch.zeros(1, 704, 1280)

            return data
        except Exception as e:
            warnings.warn(
                f"Invalid data encountered: {self.video_paths[index]}. Skipped "
                f"(by randomly sampling another sample in the same dataset)."
            )
            warnings.warn("FULL TRACEBACK:")
            warnings.warn(traceback.format_exc())
            self.wrong_number += 1
            log.info(self.wrong_number, rank0_only=False)
            raise e
            return self[np.random.randint(len(self.video_paths))]
        


class WanDatasetMask(WanDataset):
    def __init__(
        self,
        dataset_dir,
        num_frames,
        video_size,
        target_fps,
    ):
        super().__init__(dataset_dir=dataset_dir, num_frames=num_frames, video_size=video_size, target_fps=target_fps)
        self.mask_dir = os.path.join(self.dataset_dir, "mask_videos")
        self.video_dir = os.path.join(self.dataset_dir, "videos")
        self.mask_paths = [os.path.join(self.mask_dir, f) for f in os.listdir(self.mask_dir) if f.endswith(".mp4")]
        self.mask_paths = sorted(self.mask_paths)
        for mask_path, video_path in zip(self.mask_paths, self.video_paths):
            assert Path(mask_path).stem == Path(video_path).stem
        log.info(f"{len(self.video_paths)} videos in total")

    def __getitem__(self, index):
        index = index % len(self.video_paths)
        try:
            data = dict()
            video, fps = self._get_frames(self.video_paths[index])
            masks, fps = self._get_frames(self.mask_paths[index])
            video = video.permute(1, 0, 2, 3)  # Rearrange from [T, C, H, W] to [C, T, H, W]
            masks = masks.permute(1, 0, 2, 3)  # Rearrange from [T, C, H, W] to [C, T, H, W]
            video_path = self.video_paths[index]
            umt5_path = os.path.join(
                self.wan_umt5_dir,
                os.path.basename(video_path).replace(".mp4", ".pt"),
            )
            data["video"] = video
            data["masks"] = masks
            data["video_name"] = {
                "video_path": video_path,
                "umt5_path": umt5_path,
                # "start_frame_id": '0',
            }

            # Just add these to fit the interface
            umt5_embedding = torch.load(umt5_path)
            n_tokens = umt5_embedding.shape[0]
            if n_tokens < 512:
                umt5_embedding = torch.cat([umt5_embedding, torch.zeros((512 - n_tokens, umt5_embedding.shape[1]), dtype=umt5_embedding.dtype)], dim=0)
            t5_text_mask = torch.zeros(512, dtype=torch.int64)
            t5_text_mask[:n_tokens] = 1

            data["umt5_embedding"] = umt5_embedding
            data["t5_text_mask"] = t5_text_mask
            data["fps"] = fps
            data["image_size"] = torch.tensor([704, 1280, 704, 1280])
            data["num_frames"] = self.sequence_length
            data["padding_mask"] = torch.zeros(1, 704, 1280)

            return data
        except Exception as e:
            warnings.warn(
                f"Invalid data encountered: {self.video_paths[index]}. Skipped "
                f"(by randomly sampling another sample in the same dataset)."
            )
            warnings.warn("FULL TRACEBACK:")
            warnings.warn(traceback.format_exc())
            self.wrong_number += 1
            log.info(self.wrong_number, rank0_only=False)
            raise e
            return self[np.random.randint(len(self.video_paths))]
 

class WanDatasetCFG(Dataset):
    def __init__(
        self,
        dataset_dir,
        num_frames,
        video_size,
        target_fps,
        uncond_rate,
        uncond_path="datasets/umt5_uncond.pt"
    ):
        super().__init__(dataset_dir=dataset_dir, num_frames=num_frames, video_size=video_size, target_fps=target_fps)
        self.dataset_dir = dataset_dir
        self.sequence_length = num_frames
        self.uncond_rate = uncond_rate
        self.uncond_path = uncond_path
        assert uncond_rate >=0 and uncond_rate < 1

        video_dir = os.path.join(self.dataset_dir, "videos")
        self.t5_dir = os.path.join(self.dataset_dir, "t5_xxl")

        self.video_paths = [os.path.join(video_dir, f) for f in os.listdir(video_dir) if f.endswith(".mp4")]
        self.video_paths = sorted(self.video_paths)
        log.info(f"{len(self.video_paths)} videos in total")
        self.wan_umt5_dir = os.path.join(self.dataset_dir, "umt5_wan")

    def __getitem__(self, index):
        index = index % len(self.video_paths)
        try:
            data = dict()
            video, fps = self._get_frames(self.video_paths[index])
            video = video.permute(1, 0, 2, 3)  # Rearrange from [T, C, H, W] to [C, T, H, W]
            video_path = self.video_paths[index]
            umt5_path = os.path.join(
                self.wan_umt5_dir,
                os.path.basename(video_path).replace(".mp4", ".pt"),
            )
            data["video"] = video
            data["video_name"] = {
                "video_path": video_path,
                "umt5_path": umt5_path,
                # "start_frame_id": '0',
            }

            # Just add these to fit the interface
            if random.random() < self.uncond_rate:
                umt5_embedding = torch.load(self.uncond_path)
            else:
                umt5_embedding = torch.load(umt5_path)
            n_tokens = umt5_embedding.shape[0]
            if n_tokens < 512:
                umt5_embedding = torch.cat([umt5_embedding, torch.zeros((512 - n_tokens, umt5_embedding.shape[1]), dtype=umt5_embedding.dtype)], dim=0)
            t5_text_mask = torch.zeros(512, dtype=torch.int64)
            t5_text_mask[:n_tokens] = 1

            data["umt5_embedding"] = umt5_embedding
            data["t5_text_mask"] = t5_text_mask
            data["fps"] = fps
            data["image_size"] = torch.tensor([704, 1280, 704, 1280])
            data["num_frames"] = self.sequence_length
            data["padding_mask"] = torch.zeros(1, 704, 1280)

            return data
        except Exception as e:
            warnings.warn(
                f"Invalid data encountered: {self.video_paths[index]}. Skipped "
                f"(by randomly sampling another sample in the same dataset)."
            )
            warnings.warn("FULL TRACEBACK:")
            warnings.warn(traceback.format_exc())
            self.wrong_number += 1
            log.info(self.wrong_number, rank0_only=False)
            raise e
            return self[np.random.randint(len(self.video_paths))]


class WanDatasetCFGMask(WanDatasetCFG):
    def __init__(
        self,
        dataset_dir,
        num_frames,
        video_size,
        target_fps,
        uncond_rate,
        uncond_path="datasets/umt5_uncond.pt",
        padding_frames: int = 0
    ):
        super().__init__(dataset_dir=dataset_dir, num_frames=num_frames, video_size=video_size, target_fps=target_fps, uncond_rate=uncond_rate, uncond_path=uncond_path)
        self.mask_dir = os.path.join(self.dataset_dir, "mask_videos")
        self.mask_paths = [os.path.join(self.mask_dir, f) for f in os.listdir(self.mask_dir) if f.endswith(".mp4")]
        self.mask_paths = sorted(self.mask_paths)
        for mask_path, video_path in zip(self.mask_paths, self.video_paths):
            assert Path(mask_path).stem == Path(video_path).stem
        log.info(f"{len(self.video_paths)} videos in total")
        self.padding_frames = padding_frames

    def _load_video(self, video_path) -> tuple[np.ndarray, float]:
        vr = VideoReader(video_path, ctx=cpu(0), num_threads=2)
        video_fps = vr.get_avg_fps()
        target_fps = self.target_fps
        frame_stride = video_fps / target_fps
        actual_total_frames = len(vr)
        total_frames = actual_total_frames + self.padding_frames
        
        seg_num_frames = int(self.sequence_length * frame_stride)
        if total_frames <= seg_num_frames:
            # If there are not enough frames, repeat them
            frame_ids = np.linspace(0, total_frames - 1, self.sequence_length, dtype=int)
            frame_ids = frame_ids.tolist()
            act_fps = self.sequence_length / (total_frames / video_fps)
        else:
            # randomly sample a sequence of frames
            max_start_idx = total_frames - seg_num_frames
            start_frame = np.random.randint(0, max_start_idx)
            end_frame = start_frame + seg_num_frames
            frame_ids = np.linspace(start_frame, end_frame, self.sequence_length, dtype=int)
            frame_ids = frame_ids.tolist()
            act_fps = self.sequence_length / (seg_num_frames / video_fps)
        frame_ids_adj = []
        for frame_id in frame_ids:
            if frame_id < actual_total_frames:
                frame_ids_adj.append(frame_id)
            else:
                frame_ids_adj.append(actual_total_frames - 1)
        
        frame_data = vr.get_batch(frame_ids_adj).asnumpy()
        vr.seek(0)  # set video reader point back to 0 to clean up cache

        del vr  # delete the reader to avoid memory leak
        return frame_data, act_fps

    def __getitem__(self, index):
        index = index % len(self.video_paths)
        try:
            data = dict()
            video, fps = self._get_frames(self.video_paths[index])
            masks, fps = self._get_frames(self.mask_paths[index])
            video = video.permute(1, 0, 2, 3)  # Rearrange from [T, C, H, W] to [C, T, H, W]
            masks = masks.permute(1, 0, 2, 3)  # Rearrange from [T, C, H, W] to [C, T, H, W]
            video_path = self.video_paths[index]
            umt5_path = os.path.join(
                self.wan_umt5_dir,
                os.path.basename(video_path).replace(".mp4", ".pt"),
            )
            data["video"] = video
            data["masks"] = masks
            data["video_name"] = {
                "video_path": video_path,
                "umt5_path": umt5_path,
                # "start_frame_id": '0',
            }

            # Just add these to fit the interface
            if random.random() < self.uncond_rate:
                umt5_embedding = torch.load(self.uncond_path)
            else:
                umt5_embedding = torch.load(umt5_path)
            n_tokens = umt5_embedding.shape[0]
            if n_tokens < 512:
                umt5_embedding = torch.cat([umt5_embedding, torch.zeros((512 - n_tokens, umt5_embedding.shape[1]), dtype=umt5_embedding.dtype)], dim=0)
            t5_text_mask = torch.zeros(512, dtype=torch.int64)
            t5_text_mask[:n_tokens] = 1

            data["umt5_embedding"] = umt5_embedding
            data["t5_text_mask"] = t5_text_mask
            data["fps"] = fps
            data["num_frames"] = self.sequence_length

            return data
        except Exception as e:
            warnings.warn(
                f"Invalid data encountered: {self.video_paths[index]}. Skipped "
                f"(by randomly sampling another sample in the same dataset)."
            )
            warnings.warn("FULL TRACEBACK:")
            warnings.warn(traceback.format_exc())
            self.wrong_number += 1
            log.info(self.wrong_number, rank0_only=False)
            raise e
            return self[np.random.randint(len(self.video_paths))]

        
if __name__ == "__main__":
    dataset = Dataset(
        dataset_dir="datasets/benchmark_train/gr1",
        num_frames=93,
        video_size=[480, 832],
    )

    indices = [0, 13, -1]
    for idx in indices:
        data = dataset[idx]
        log.info(
            (
                f"{idx=} "
                f"{data['video'].sum()=}\n"
                f"{data['video'].shape=}\n"
                f"{data['video_name']=}\n"
                f"{data['t5_text_embeddings'].shape=}\n"
                "---"
            )
        )
