# from memory_profiler import profile

import os
import json
import random
import math
import torch
from torch.utils.data import DataLoader
import copy
import pickle
import gc
import numpy as np

from utils import *
from loading import *
import torchvision
import torch
from collections import defaultdict
import json



def load_video_frames(video_path, transpose=False):
    """
    Decode every frame of a video file into a list of RGB arrays.

    Args:
        video_path: Path handed to ``cv2.VideoCapture``.
        transpose: When True, the first two axes of each frame are swapped
            (``np.transpose`` with axes ``(1, 0, 2)``), so frames come out
            as (width, height, 3) instead of (height, width, 3).

    Returns:
        List of frames in read order, each converted from BGR to RGB.

    Raises:
        IOError: If the capture cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)
    # The capture is released on every exit path, including the IOError
    # below and any failure inside the decode loop.
    try:
        if not cap.isOpened():
            raise IOError(f"Cannot open video file {video_path}")

        frames = []
        while True:
            ret, frame = cap.read()
            if not ret:
                # No more frames (or a read error): stop decoding.
                break
            # Convert frame from BGR to RGB color space
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            if transpose:
                # Transpose frame to shape (width, height, 3)
                frame_rgb = np.transpose(frame_rgb, (1, 0, 2))
            frames.append(frame_rgb)
        return frames
    finally:
        cap.release()

def load_video(video_path, start_frame, end_frame):
    """
    Read a whole video with OpenCV and return the slice [start_frame, end_frame).

    NOTE: a second ``load_video`` is defined further down in this module and
    rebinds the name, so this earlier definition is not the one callers get
    once the module has finished importing.

    Args:
        video_path: Path of the video file.
        start_frame: Index of the first frame to keep (Python slice start).
        end_frame: Index one past the last frame to keep (Python slice stop).

    Returns:
        ``(video_window, orig_height, orig_width)`` where ``video_window`` is
        the selected frames stacked with ``np.stack``. When the path does not
        exist or no frame could be read, ``([], 0, 0)`` is returned so the
        result can always be unpacked into three values.

    Raises:
        ValueError: From ``np.stack`` when frames were read but the requested
            slice selects none of them.
    """
    if not os.path.exists(video_path):
        print("video path does not exist")
        return [], 0, 0

    cap = cv2.VideoCapture(video_path)
    video = []
    try:
        while cap.isOpened():
            # Capture frames in the video
            ret, frame = cap.read()
            if not ret:
                break
            video.append(frame)
    finally:
        # Always hand the decoder back, even if reading raised.
        cap.release()

    if not video:
        # Nothing decodable: there is no frame to take dimensions from.
        return [], 0, 0

    orig_height, orig_width = video[0].shape[:2]
    video_window = np.stack(video[start_frame:end_frame])
    return video_window, orig_height, orig_width




# Example helper function to create a mask from a bounding box
def bbox_to_mask(frame_height, frame_width, bbox):
    """
    Rasterise a bounding box into a 3-channel binary mask.

    Args:
        frame_height: Number of rows of the mask.
        frame_width: Number of columns of the mask.
        bbox: ``[xmin, ymin, xmax, ymax]``; each value is truncated with
            ``int`` and then clamped into the frame.

    Returns:
        ``np.uint8`` array of shape (frame_height, frame_width, 3) holding 1
        inside the box (rows ymin..ymax-1, columns xmin..xmax-1) and 0
        everywhere else.
    """
    def clamp(value, upper):
        # Keep the coordinate inside [0, upper].
        return max(0, min(upper, value))

    left, top, right, bottom = (int(coord) for coord in bbox)

    columns = slice(clamp(left, frame_width), clamp(right, frame_width))
    rows = slice(clamp(top, frame_height), clamp(bottom, frame_height))

    canvas = np.zeros((frame_height, frame_width, 3), dtype=np.uint8)
    canvas[rows, columns, :] = 1
    return canvas


def load_video(video_path, start_frame, end_frame):
    """
    Load frames ``[start_frame, end_frame)`` from a video file using OpenCV.

    Frames are returned exactly as ``cap.read()`` delivers them; no colour
    conversion is applied here.

    Args:
        video_path: Path of the video file.
        start_frame: Index of the first frame to keep.
        end_frame: Index one past the last frame to keep, or -1 to load
            until the end of the video.

    Returns:
        ``(list_of_frames, orig_height, orig_width)``. If the path does not
        exist or the capture cannot be opened, ``([], 0, 0)`` is returned.
        If the capture opens but no frame can be read, the height and width
        are ``None``.
    """
    if not os.path.exists(video_path):
        print("video path does not exist:", video_path)
        return [], 0, 0

    cap = cv2.VideoCapture(video_path)
    all_frames = []
    orig_height = None
    orig_width = None

    # try/finally so the capture is released on the early return below and
    # on any exception raised while decoding.
    try:
        if not cap.isOpened():
            print("Cannot open video file:", video_path)
            return [], 0, 0

        frame_index = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break  # no more frames or read error

            # Dimensions come from the first frame that is read, even when
            # that frame itself falls outside the requested range.
            if orig_height is None or orig_width is None:
                orig_height, orig_width = frame.shape[:2]

            # Stop as soon as the requested window has been passed.
            if end_frame != -1 and frame_index >= end_frame:
                break
            if frame_index >= start_frame:
                all_frames.append(frame)

            frame_index += 1
    finally:
        cap.release()

    return all_frames, orig_height, orig_width


class VidVRDDataset(torch.utils.data.Dataset):
    """
    Dataset over per-sample JSON annotation files and their videos.

    Layout read by this class, relative to ``dataset_dir``:
      * ``<phase>/*.json``      one annotation file per sample
      * ``videos/<video_id>.mp4``
      * ``info/objects.txt``, ``info/predicates.txt``, ``info/rel_cache.json``
      * ``sgcls/<video_id>_mask/{metadata.json,masks.npz}`` (optional)
      * ``sgdet/<video_id>_mask/{metadata.json,masks.npz}`` (optional)
      * ``finetune/*.json``     sample-id lists, only when ``ft_split`` is set

    Each item is ``(datapoint, frames)``; see ``__getitem__``.
    """

    # Predicates removed from the vocabulary (compared after underscores
    # have been replaced by spaces).
    _EXCLUDED_PREDICATES = ("creep beneath", "lie with", "swim behind")

    def __init__(
        self,
        dataset_dir,
        device,
        data_percentage=100,
        phase="train",
        max_vid_len=99999,  # load all frames by default
        skip_videos=None,
        only_videos=None,
        splice_size=1,
        splice_start=0,
        ft_split=None,
        load_sgdet=False,
    ):
        """
        Args:
            dataset_dir: Root directory of the dataset.
            device: Stored on the instance as ``self.device``.
            data_percentage: Percentage of the (sorted) sample list to keep,
                taken from the front.
            phase: ``"train"`` or ``"test"``; selects the annotation folder.
            max_vid_len: Maximum number of frames used per video.
            skip_videos: Video ids to exclude (``None`` means none).
            only_videos: If non-empty, only these video ids are kept.
            splice_size: Number of contiguous partitions to cut the sample
                list into; partitioning is applied only when > 1.
            splice_start: Index of the partition this instance keeps.
            ft_split: ``None``, ``"eval"`` or a value interpolated into
                ``finetune_<ft_split>percent.json``; restricts the samples
                to the ids listed in that file (first 200 for ``"eval"``).
            load_sgdet: When True, detection ("dt") data is read from the
                sgdet folder and GT masks are built from bounding boxes;
                when False, GT masks are read from the sgcls folder if
                available.
        """
        super().__init__()
        # Reject an unsupported phase before touching the filesystem.
        assert phase == "train" or phase == "test"

        skip_videos = skip_videos or []
        only_videos = only_videos or []

        self.phase = phase
        self.data_path = os.path.join(dataset_dir, phase)
        self.video_path = os.path.join(dataset_dir, "videos")
        self.samples = sorted(
            f for f in os.listdir(self.data_path) if f.endswith(".json")
        )

        dataset_root = os.path.abspath(dataset_dir)

        if ft_split is not None:
            finetune_dir = os.path.join(dataset_root, "finetune")
            if ft_split == "eval":
                split_file = os.path.join(finetune_dir, "eval_random.json")
            else:
                split_file = os.path.join(finetune_dir, f"finetune_{ft_split}percent.json")
            with open(split_file, "r") as f:
                ft_sample_ids = json.load(f)
            # Keep the order given by the split file; the set only speeds
            # up the membership test.
            available = set(self.samples)
            self.samples = [s for s in ft_sample_ids if s in available]
            if ft_split == "eval":
                self.samples = self.samples[:200]

        self.device = device
        self.max_vid_len = max_vid_len

        info_dir = os.path.join(dataset_root, "info")
        objects_path = os.path.join(info_dir, "objects.txt")
        predicates_path = os.path.join(info_dir, "predicates.txt")
        rel_weights_path = os.path.join(info_dir, "rel_cache.json")

        # sgcls and sgdet are sibling folders directly under the dataset root.
        self.sgcls_path = os.path.join(dataset_root, "sgcls")
        self.sgdet_path = os.path.join(dataset_root, "sgdet")

        with open(objects_path, 'r') as file:
            self.objects = [line.strip().replace("_", " ") for line in file]
        with open(predicates_path, 'r') as file:
            predicates = [line.strip().replace("_", " ") for line in file]
        self.predicates = [p for p in predicates if p not in self._EXCLUDED_PREDICATES]

        if os.path.exists(rel_weights_path):
            with open(rel_weights_path, 'r') as file:
                self.rel_weights = json.load(file)
        else:
            self.rel_weights = {}

        # store whether we load sgdet
        self.load_sgdet = load_sgdet

        # potentially filter out some fraction if data_percentage < 100
        use_count = int(len(self.samples) * (data_percentage / 100.0))
        self.samples = self.samples[:use_count]
        # Snapshot taken before the video-id filter and the splice below.
        self.samples_all = self.samples

        # Single pass over the annotation files applying both filters
        # (each file has to be opened to learn its video id).
        if skip_videos or only_videos:
            kept = []
            for s in self.samples:
                with open(os.path.join(self.data_path, s), "r") as f:
                    video_id = json.load(f)["video_id"]
                if video_id in skip_videos:
                    continue
                if only_videos and video_id not in only_videos:
                    continue
                kept.append(s)
            self.samples = kept

        if splice_size > 1:
            # Contiguous partitions; the first `remainder` partitions get
            # one extra sample each.
            total_samples = len(self.samples)
            partition_size = total_samples // splice_size
            remainder = total_samples % splice_size
            start_index = splice_start * partition_size + min(splice_start, remainder)
            end_index = start_index + partition_size
            if splice_start < remainder:
                end_index += 1
            self.samples = self.samples[start_index:end_index]
            print(f"[Rank {splice_start}] Dataset spliced from {start_index} to {end_index} ")

    def __len__(self):
        """Number of samples left after all filtering and splicing."""
        return len(self.samples)

    def process_val(self, x, max_val):
        """Clamp ``x`` into the closed interval ``[0, max_val]``."""
        x = max(0, x)
        x = min(x, max_val)
        return x

    @staticmethod
    def _load_mask_store(store_dir, frames_needed):
        """
        Read ``metadata.json`` and ``masks.npz`` from a precomputed-mask folder.

        Returns ``(status, meta, masks, error)`` where ``status`` is one of
        ``"missing"`` (folder or a file absent), ``"error"`` (loading raised;
        ``error`` holds the exception), ``"short"`` (metadata has fewer
        entries than ``frames_needed``) or ``"ok"``.
        """
        metadata_json_path = os.path.join(store_dir, "metadata.json")
        masks_npz_path = os.path.join(store_dir, "masks.npz")
        if not (os.path.exists(store_dir)
                and os.path.exists(metadata_json_path)
                and os.path.exists(masks_npz_path)):
            return "missing", [], {}, None

        meta, masks = [], {}
        try:
            with open(metadata_json_path, 'r') as f:
                meta = json.load(f)
            with np.load(masks_npz_path, allow_pickle=False) as data:
                masks = {key: data[key] for key in data.files}
            status = "ok" if len(meta) >= frames_needed else "short"
        except Exception as e:
            # Deliberately broad: any failure here means "fall back", and
            # the caller reports it.
            return "error", meta, masks, e
        return status, meta, masks, None

    @staticmethod
    def _sgcls_mask(sgcls_meta, sgcls_masks, video_id, frame_id, tid, bbox, height, width):
        """Return the (H, W, 3) SGCLS mask whose metadata bbox equals ``bbox``, or None."""
        meta_frame = sgcls_meta[frame_id]
        mask_tid = next(
            (key for key, entry in meta_frame.items() if entry['bbox'] == bbox),
            None,
        )
        mask_2d = sgcls_masks.get(f"{frame_id}_{mask_tid}")
        if mask_2d is None:
            # No metadata entry matched, or no mask stored under that key.
            return None
        if mask_2d.shape[0] != height or mask_2d.shape[1] != width:
            print(f"[Error] SGCLS mask shape mismatch for video {video_id}, "
                  f"frame {frame_id}, tid {tid}. Expected ({height}, {width}), "
                  f"got {mask_2d.shape}. Falling back to bbox mask.")
            return None
        return np.repeat(mask_2d[:, :, None], 3, axis=2)

    def _detections_for_frame(self, meta_frame, sgdet_masks, video_id, frame_id, height, width):
        """Collect ``(ids, clamped_bboxes, masks)`` for the detections of one frame."""
        dt_instance_ids, dt_bboxes, dt_masks = [], [], []
        for obj_id_str, info_dict in meta_frame.items():
            dt_bbox = info_dict["bbox"]
            if dt_bbox is None:
                continue

            dt_bbox_clamped = [
                self.process_val(dt_bbox[0], width),
                self.process_val(dt_bbox[1], height),
                self.process_val(dt_bbox[2], width),
                self.process_val(dt_bbox[3], height),
            ]

            dt_mask_2d = sgdet_masks.get(f"{frame_id}_{obj_id_str}")
            if dt_mask_2d is None:
                # A detection without a stored mask is dropped.
                continue
            if not dt_mask_2d.any():
                # An all-False mask is dropped as well.
                continue
            if dt_mask_2d.shape[0] != height or dt_mask_2d.shape[1] != width:
                print(f"[Error] SGDET mask shape mismatch for video {video_id}, "
                      f"frame {frame_id}, dt obj {obj_id_str}. "
                      f"Expected ({height}, {width}), got {dt_mask_2d.shape}. Skipping.")
                continue

            dt_instance_ids.append(obj_id_str)
            dt_bboxes.append(dt_bbox_clamped)
            dt_masks.append(np.repeat(dt_mask_2d[:, :, None], 3, axis=2))
        return dt_instance_ids, dt_bboxes, dt_masks

    def __getitem__(self, i):
        """
        Build the datapoint for sample ``i``.

        Steps:
          1. Load the sample's JSON file.
          2. Build a tid -> category map from its ``"subject/objects"`` list.
          3. Load up to ``self.max_vid_len`` frames of ``<video_id>.mp4``.
          4. Load precomputed masks: SGCLS when ``load_sgdet`` is False,
             SGDET when it is True. Any failure is printed and ignored.
          5. Per frame, build GT annotations (SGCLS mask when one matches the
             box, otherwise a bounding-box mask) and, if SGDET loaded, the
             detection ("dt") annotations; dt lists are empty otherwise.
          6. Expand relation instances into one list per frame.

        Returns:
            ``(datapoint, frames)`` where ``datapoint`` is a dict with keys
            ``video_id``, ``frame_count``, ``fps``, ``width``, ``height``,
            ``annotations``, ``relations``, ``binary_kwords``, ``gpt_spec``,
            ``neg_gpt_spec`` and ``neg_kws`` (the last three are ``None``),
            and ``frames`` is the list of loaded frames.
        """
        # -------------------------
        # Step 1: Load the annotation file
        # -------------------------
        sample_file = self.samples[i]
        with open(os.path.join(self.data_path, sample_file), 'r') as f:
            meta = json.load(f)

        video_id = meta["video_id"]
        trajectories = meta["trajectories"]       # List of frames, each is a list of bbox dicts
        relation_instances = meta.get("relation_instances", [])

        # -------------------------
        # Step 2: Build tid->category map
        # -------------------------
        tid_to_category = {
            obj_meta["tid"]: obj_meta["category"]
            for obj_meta in meta.get("subject/objects", [])
        }

        # -------------------------
        # Step 3: Load video frames
        # -------------------------
        video_file = os.path.join(self.video_path, video_id + ".mp4")
        # Stop decoding at max_vid_len instead of decoding the whole video
        # and discarding the tail. A negative limit keeps "load everything".
        frame_limit = self.max_vid_len if self.max_vid_len >= 0 else -1
        all_frames, orig_height, orig_width = load_video(video_file, 0, frame_limit)

        num_frames_to_use = min(len(all_frames), self.max_vid_len)
        truncated_frames = all_frames[:num_frames_to_use]
        truncated_trajectories = trajectories[:num_frames_to_use]

        # -------------------------
        # Step 4a: SGCLS masks (only when not using sgdet)
        # -------------------------
        use_sgcls = False
        sgcls_masks_dict = {}
        sgcls_meta = []
        if not self.load_sgdet:
            status, sgcls_meta, sgcls_masks_dict, err = self._load_mask_store(
                os.path.join(self.sgcls_path, f"{video_id}_mask"), num_frames_to_use
            )
            if status == "ok":
                use_sgcls = True
            elif status == "short":
                print(f"[Error] SGCLS for video {video_id} has fewer frames ({len(sgcls_meta)}) "
                      f"than needed ({num_frames_to_use}). Falling back to bounding boxes.")
            elif status == "error":
                print(f"[Error] Loading SGCLS data for {video_id} failed: {err}. "
                      f"Falling back to bounding boxes.")
            else:
                print(f"[Warning] SGCLS data not found for video {video_id}. "
                      f"Falling back to bounding-box masks.")

        # -------------------------
        # Step 4b: SGDET metadata and masks (only when load_sgdet is True)
        # -------------------------
        use_sgdet = False
        sgdet_masks_dict = {}
        sgdet_meta = []
        if self.load_sgdet:
            # The metadata is indexed by frame_id below and each entry is
            # iterated with .items() -- presumably a list of per-frame dicts,
            # like the SGCLS metadata; TODO confirm against the generator.
            status, sgdet_meta, sgdet_masks_dict, err = self._load_mask_store(
                os.path.join(self.sgdet_path, f"{video_id}_mask"), num_frames_to_use
            )
            if status == "ok":
                use_sgdet = True
            elif status == "short":
                print(f"[Error] SGDET metadata for video {video_id} has fewer frames ({len(sgdet_meta)}) "
                      f"than needed ({num_frames_to_use}). Using empty dt data.")
            elif status == "error":
                print(f"[Error] Loading SGDET data for {video_id} failed: {err}. "
                      f"Using empty dt data.")
            else:
                print(f"[Warning] SGDET data not found for video {video_id}. Using empty dt data.")

        # -------------------------
        # Step 5: Build annotation structure
        # -------------------------
        annotations = []
        for frame_id in range(num_frames_to_use):
            # The trajectory list may be shorter than the video.
            if frame_id < len(truncated_trajectories):
                frame_bboxes = truncated_trajectories[frame_id]
            else:
                frame_bboxes = []

            gt_instance_ids = []
            gt_bboxes = []
            gt_labels = []
            gt_masks = []

            for box_info in frame_bboxes:
                tid = box_info["tid"]
                cat = tid_to_category.get(tid, "unknown").replace('_', ' ')

                raw_box = box_info["bbox"]
                bbox = [
                    self.process_val(raw_box["xmin"], orig_width),
                    self.process_val(raw_box["ymin"], orig_height),
                    self.process_val(raw_box["xmax"], orig_width),
                    self.process_val(raw_box["ymax"], orig_height),
                ]

                mask_3d = None
                if use_sgcls:
                    mask_3d = self._sgcls_mask(
                        sgcls_meta, sgcls_masks_dict, video_id,
                        frame_id, tid, bbox, orig_height, orig_width,
                    )
                # No usable SGCLS mask: fall back to the filled bounding box.
                if mask_3d is None:
                    mask_3d = bbox_to_mask(orig_height, orig_width, bbox)

                gt_instance_ids.append(tid)
                gt_bboxes.append(bbox)
                gt_labels.append(cat)
                gt_masks.append(mask_3d)

            if use_sgdet:
                dt_instance_ids, dt_bboxes, dt_masks = self._detections_for_frame(
                    sgdet_meta[frame_id], sgdet_masks_dict, video_id,
                    frame_id, orig_height, orig_width,
                )
            else:
                dt_instance_ids, dt_bboxes, dt_masks = [], [], []

            annotations.append({
                "gt_instance_ids": gt_instance_ids,
                "gt_bboxes": gt_bboxes,
                "gt_labels": gt_labels,
                "gt_masks": gt_masks,   # 3D masks (H, W, 3)

                # detection fields
                "dt_instance_ids": dt_instance_ids,
                "dt_bboxes": dt_bboxes,
                "dt_masks": dt_masks,
            })

        # -------------------------
        # Step 6: Build relations per frame
        # -------------------------
        relations_per_frame = [[] for _ in range(num_frames_to_use)]
        for rel in relation_instances:
            from_id = rel['subject_tid']
            to_id = rel['object_tid']
            rel_name = rel['predicate'].replace('_', ' ').replace('/', '_')

            # [begin_fid, end_fid) intersected with the frames actually used.
            first_fid = max(rel['begin_fid'], 0)
            last_fid = min(rel['end_fid'], num_frames_to_use)
            for fid in range(first_fid, last_fid):
                relations_per_frame[fid].append((from_id, to_id, rel_name))

        # -------------------------
        # Step 7: Build final datapoint
        # -------------------------
        datapoint = {
            "video_id": video_id,
            "frame_count": meta.get("frame_count", -1),
            "fps": meta.get("fps", -1),
            "width": meta.get("width", -1),
            "height": meta.get("height", -1),
            "annotations": annotations,       # includes GT and DT
            "relations": relations_per_frame,
            "binary_kwords": self.predicates,
            "gpt_spec": None,
            "neg_gpt_spec": None,
            "neg_kws": None,
        }

        # Drop the mask caches before returning and run a collection pass.
        del sgcls_masks_dict
        del sgdet_masks_dict
        gc.collect()

        return (datapoint, truncated_frames)

    @staticmethod
    def collate_fn(batch):
        """
        Flatten a list of ``(datapoint, frames)`` items into one dict of lists.

        Per-object entries are tagged with ``(data_id, frame_id, ...)`` where
        ``data_id`` is the item's position in ``batch``.
        ``batched_video_splits`` holds the cumulative frame count after each
        item. ``batched_obj_pairs`` contains an ordered GT pair only when the
        frame has at least one relation from the first id to the second.
        ``batched_videos`` and ``batched_binary_predicates`` are returned
        empty. ``batched_neg_gpt_specs`` / ``batched_neg_kws`` are included
        only when non-empty.
        """
        batched_videos = []
        batched_reshaped_raw_videos = []
        batched_captions = []
        batched_obj_pairs = []
        batched_ids = []
        batched_video_splits = []
        batched_gpt_specs = []

        batched_gt_bboxes = []
        batched_gt_masks = []
        batched_gt_obj_names = []
        batched_gt_object_rels = []
        batched_object_ids = []
        frame_ct_in_video = 0

        batched_neg_gpt_specs = []
        batched_neg_kws = []
        batched_binary_predicates = []

        batched_dt_object_ids = []
        batched_dt_bboxes = []
        batched_dt_masks = []

        for data_id, (datapoint, reshaped_raw_video) in enumerate(batch):
            batched_reshaped_raw_videos += reshaped_raw_video
            batched_ids.append(datapoint['video_id'])

            gpt_spec = datapoint.get('gpt_spec')
            if gpt_spec:
                batched_captions.append(gpt_spec['caption'])
                batched_gpt_specs.append(gpt_spec)
            else:
                batched_captions.append("")

            if 'neg_gpt_spec' in datapoint:
                batched_neg_gpt_specs.append(datapoint['neg_gpt_spec'])
            if 'neg_kws' in datapoint:
                batched_neg_kws.append(datapoint['neg_kws'])

            relations = datapoint['relations']
            batched_gt_object_rels.append(relations)

            for frame_id, frame in enumerate(datapoint['annotations']):
                # ---- GT data ----
                gt_ids = frame['gt_instance_ids']

                batched_gt_bboxes += frame['gt_bboxes']
                batched_gt_masks += frame['gt_masks']
                batched_object_ids += [(data_id, frame_id, obj_id) for obj_id in gt_ids]
                batched_gt_obj_names += [
                    (data_id, frame_id, label) for label in frame['gt_labels']
                ]

                # TODO: for sgcls only
                # Ordered (subject, object) pairs that carry any relation in
                # this frame; built once so each pair test is a set lookup.
                related = {(r[0], r[1]) for r in relations[frame_id]}
                for oid1 in gt_ids:
                    for oid2 in gt_ids:
                        if oid1 != oid2 and (oid1, oid2) in related:
                            batched_obj_pairs.append((data_id, frame_id, (oid1, oid2)))

                # ---- DT data ----
                batched_dt_object_ids += [
                    (data_id, frame_id, dt_id) for dt_id in frame['dt_instance_ids']
                ]
                batched_dt_bboxes += frame['dt_bboxes']
                batched_dt_masks += frame['dt_masks']

            # Running total of frames seen so far across the batch.
            frame_ct_in_video += len(reshaped_raw_video)
            batched_video_splits.append(frame_ct_in_video)

        gc.collect()

        res = {
            'batched_ids': batched_ids,
            'batched_captions': batched_captions,
            'batched_gt_masks': batched_gt_masks,
            'batched_gt_bboxes': batched_gt_bboxes,
            'batched_obj_pairs': batched_obj_pairs,
            'batched_object_ids': batched_object_ids,  # GT IDs
            'batched_video_splits': batched_video_splits,
            'batched_reshaped_raw_videos': batched_reshaped_raw_videos,
            'batched_gt_obj_names': batched_gt_obj_names,
            'batched_gt_object_rels': batched_gt_object_rels,
            'batched_gpt_specs': batched_gpt_specs,
            'batched_videos': batched_videos,
            'batched_binary_predicates': batched_binary_predicates,
            'batched_dt_object_ids': batched_dt_object_ids,
            'batched_dt_bboxes': batched_dt_bboxes,
            'batched_dt_masks': batched_dt_masks,
        }

        if batched_neg_gpt_specs:
            res['batched_neg_gpt_specs'] = batched_neg_gpt_specs
        if batched_neg_kws:
            res['batched_neg_kws'] = batched_neg_kws

        return res



def open_vidvrd_loader(dataset_dir, batch_size, device, cache_path=None, 
                       dataset_name=None, dataloader_worker_ct=0, training_percentage=100, 
                       testing_percentage=100, max_video_len=8, neg_spec=False, neg_kws=False, 
                       neg_example_ct=5, require_gpt_spec=True, neg_example_file_name="neg_examples.json", 
                       set_norm_x=None, set_norm_y=None, backbone_model="violet", sampler=None, 
                       splice_start=0, splice_size=1, skip_videos=[], only_videos = [], ft_split=None):
    """
    Build the train / validation / test datasets and their DataLoaders.

    Parameters actually read by this function: ``dataset_dir``,
    ``batch_size``, ``device``, ``dataloader_worker_ct``, ``sampler``,
    ``splice_start``, ``splice_size``, ``skip_videos``, ``only_videos`` and
    ``ft_split``. The remaining parameters are accepted but not used here.

    * train: phase "train", honours ``ft_split``, ``skip_videos`` and
      ``only_videos``. If ``sampler`` is given it is called with the train
      dataset and the result is passed to the DataLoader as ``sampler``;
      otherwise the loader runs with ``shuffle=False``.
    * valid: phase "test", honours ``skip_videos``, ``only_videos`` and the
      splice arguments.
    * test: phase "test", honours ``only_videos`` and the splice arguments
      (``skip_videos`` is not forwarded).

    Every loader uses ``VidVRDDataset.collate_fn`` and ``drop_last=True``.

    Returns:
        ``(train_dataset, valid_dataset, test_dataset,
        train_loader, valid_loader, test_loader)``
    """
    # Options common to all three loaders.
    loader_options = dict(
        batch_size=batch_size,
        drop_last=True,
        collate_fn=VidVRDDataset.collate_fn,
        num_workers=dataloader_worker_ct,
    )
    # Options common to both evaluation datasets.
    eval_options = dict(
        device=device,
        phase="test",
        only_videos=only_videos,
        splice_start=splice_start,
        splice_size=splice_size,
    )

    train_dataset = VidVRDDataset(
        dataset_dir,
        device=device,
        phase="train",
        only_videos=only_videos,
        ft_split=ft_split,
        skip_videos=skip_videos,
    )
    if sampler is None:
        train_loader = DataLoader(train_dataset, shuffle=False, **loader_options)
    else:
        # e.g. a distributed sampler factory supplied by the caller
        train_sampler = sampler(train_dataset)
        train_loader = DataLoader(train_dataset, sampler=train_sampler, **loader_options)

    valid_dataset = VidVRDDataset(dataset_dir, skip_videos=skip_videos, **eval_options)
    valid_loader = DataLoader(valid_dataset, shuffle=False, **loader_options)

    test_dataset = VidVRDDataset(dataset_dir, **eval_options)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)

    return (
        train_dataset,
        valid_dataset,
        test_dataset,
        train_loader,
        valid_loader,
        test_loader,
    )