import copy
import io
import gc
from typing import Dict, Union, Any, Optional, List
import math
import os
import random
import sys
import time
import contextlib
from typing import Literal
from collections.abc import Sequence, Mapping
from pathlib import Path
from omegaconf import OmegaConf
from einops import rearrange
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import IterableDataset
import webdataset as wds
try:
    import torchvision.transforms.v2.functional as TF
    from torchvision.io import read_image
except:
    pass

try:
    from torchcodec.decoders import VideoDecoder, VideoStreamMetadata
except Exception as e:
    print(e)
    pass

try:
    import av
except:
    pass


def nested_to(d, **to_kwargs):
    """Recursively apply ``Tensor.to(**to_kwargs)`` to every tensor inside ``d``.

    Mappings are rebuilt as plain dicts and non-string sequences as lists.
    Strings and byte-like values (``bytes``, ``bytearray``, ``memoryview``) are
    leaves and are returned unchanged, as is any other non-tensor value.
    """
    # bytes/bytearray/memoryview are Sequences too; without this exclusion raw
    # payloads would be exploded into lists of ints.
    if isinstance(d, Sequence) and not isinstance(d, (str, bytes, bytearray, memoryview)):
        return [nested_to(d_, **to_kwargs) for d_ in d]
    elif isinstance(d, Mapping):
        return {k: nested_to(v, **to_kwargs) for k, v in d.items()}
    elif isinstance(d, torch.Tensor):
        return d.to(**to_kwargs)
    else:
        return d

# ===========================================================================================
# Transform and augmentation functions
# ===========================================================================================
def black_border_crop(video, threshold=12):
    """Find the bounding box of non-black content in a video.

    Up to 10 evenly spaced frames are sampled. A pixel counts as content if any
    channel exceeds ``threshold`` in any sampled frame.

    Args:
        video: 4-D tensor indexed as ``(frames, channels, height, width)``.
        threshold: intensity above which a pixel is considered non-black.

    Returns:
        ``(top, left, height, width)`` of the content box, as Python ints.
        If no pixel exceeds the threshold, the full frame ``(0, 0, H, W)`` is
        returned instead of a negative-size box.
    """
    n_frames, _, H, W = video.shape
    n_samples = min(10, n_frames)
    idcs = torch.linspace(0, n_frames - 1, n_samples, dtype=torch.long)
    # (H, W) mask: True where any sampled frame / channel is brighter than threshold.
    mask = (video[idcs] > threshold).any(dim=0).any(dim=0)
    rows = torch.nonzero(mask.any(dim=1)).flatten()
    cols = torch.nonzero(mask.any(dim=0)).flatten()
    if rows.numel() == 0 or cols.numel() == 0:
        # Entirely black: there is no border to crop.
        return 0, 0, H, W
    top, bottom = int(rows[0]), int(rows[-1]) + 1
    left, right = int(cols[0]), int(cols[-1]) + 1
    return top, left, bottom - top, right - left


def resize(imgs: torch.Tensor, image_size: int) -> torch.Tensor:
    """Center-crop to a square and resize to ``image_size`` x ``image_size``.

    When the square side is at least twice ``image_size``, it is first
    downscaled by an integer factor with average pooling. Bilinear
    interpolation then fixes any remaining size mismatch.

    Args:
        imgs: a ``(3, H, W)`` image or ``(N, 3, H, W)`` batch. It is presumably
            floating point, since ``avg_pool2d``/``interpolate`` are applied.
        image_size: side length of the output.

    Returns:
        Tensor with the same leading dims and spatial size ``image_size``.
    """
    ss = imgs.shape
    assert ss[-3] == 3

    H, W = ss[-2:]

    if len(ss) == 3:
        imgs = imgs.unsqueeze(0)

    side = min(H, W)
    factor = side // image_size

    # Center crop to side x side; offsets rounded the same way as torchvision's center_crop.
    top = int(round((H - side) / 2.0))
    left = int(round((W - side) / 2.0))
    imgs = imgs[..., top:top + side, left:left + side]
    if factor > 1:
        imgs = F.avg_pool2d(imgs, factor)

    # Use the size after cropping/pooling, not the original input shape.
    H, W = imgs.shape[-2:]
    if H != image_size or W != image_size:
        imgs = F.interpolate(imgs, [image_size, image_size], mode="bilinear")

    if len(ss) == 3:
        imgs = imgs[0]

    return imgs


def file_reader(source):
    """Turn ``{'url': path}`` entries into webdataset-style samples.

    Each file is read fully as bytes and stored under its extension (without
    the leading dot). ``__key__`` is the path without its extension. ``__url__``
    is the path itself, so downstream stages that read ``__url__`` work the
    same as for tar-based samples.
    """
    for sample in source:
        path = sample['url']
        with open(path, 'rb') as fp:
            raw = fp.read()
        name, ext = os.path.splitext(path)
        yield {
            '__key__': name,
            '__url__': path,
            ext[1:]: raw,
        }

def get_frames_at_with_pyav(container, frames_at):
    """Decode the frames with the given indices from a PyAV container.

    The container's first video stream is decoded to the end. Frames whose
    running index (0-based, in decode order) appears in ``frames_at`` are kept.
    A requested index past the end of the stream, e.g. when the header
    over-reports the frame count, is served with the final decoded frame.

    Args:
        container: an opened PyAV input container.
        frames_at: sequence of frame indices; order and duplicates are preserved.

    Returns:
        dict with ``'data'``: tensor ``(T, C, H, W)`` built from ``rgb24``
        arrays, and ``'pts_seconds'``: float32 tensor ``(T,)``, where
        ``T == len(frames_at)``.

    Raises:
        ValueError: if the stream yields no frames at all.
    """
    wanted = set(frames_at)  # O(1) membership while decoding
    frames = {}
    last_frame = None
    for index, frame in enumerate(container.decode(video=0)):
        if index in wanted:
            frames[index] = {
                'data': frame.to_ndarray(format="rgb24"),
                'pts_seconds': frame.time,
            }
        last_frame = frame
    if last_frame is None:
        raise ValueError("no video frames could be decoded")

    fallback = None
    data = []
    pts_seconds = []
    for i in frames_at:
        item = frames.get(i)
        if item is None:
            # Every undecoded requested index lies beyond the end of the stream,
            # so the nearest available frame is the final one.
            if fallback is None:
                fallback = {
                    'data': last_frame.to_ndarray(format="rgb24"),
                    'pts_seconds': last_frame.time,
                }
            item = fallback
        data.append(item['data'])
        pts_seconds.append(item['pts_seconds'])
    frames_batch = {
        'data': torch.from_numpy(np.stack(data)).permute(0, 3, 1, 2),  # (T, C, H, W)
        'pts_seconds': torch.tensor(pts_seconds, dtype=torch.float32),
    }
    return frames_batch


class ResampledShardLists(IterableDataset):
    """Iterable over shard paths, optionally repeated, shuffled and resumed.

    Yields ``{'url': path}`` dicts. Paths are sorted once at construction. Each
    of the ``n_repeats`` passes iterates a copy of the list, shuffled in place
    with a seeded ``random.Random`` when ``shuffle`` is set. The first
    ``continue_from_step`` yields, counted across passes, are skipped.
    """

    def __init__(self, tar_paths, n_repeats=1, shuffle=False, seed=1337, continue_from_step=0):
        super().__init__()
        if isinstance(tar_paths, str):
            # A single string is expanded into a list of URLs by webdataset.
            self.tar_paths = wds.shardlists.expand_urls(tar_paths)
        else:
            self.tar_paths = list(tar_paths)
        self.tar_paths = sorted(self.tar_paths)
        self.n_repeats = n_repeats
        self.shuffle = shuffle
        if self.shuffle:
            self.rng = random.Random(seed)
        self.from_step = continue_from_step

    def __iter__(self):
        # With no shards, a huge n_repeats (e.g. sys.maxsize) would spin
        # forever without ever yielding.
        if not self.tar_paths:
            return
        count = 0
        for _ in range(self.n_repeats):
            tar_paths = self.tar_paths.copy()
            if self.shuffle:
                self.rng.shuffle(tar_paths)
            for tar_path in tar_paths:
                if count >= self.from_step:
                    yield dict(url=tar_path)
                count += 1


def collate_fn(x):
    """Identity collate function: return the list of samples unchanged.

    Passed to ``WebLoader`` so worker output is only grouped, not stacked;
    splitting the groups back into samples is left to later pipeline stages.
    """
    samples = x
    return samples

class VideoLoader:
    VIDEO_EXTENSIONS = ['mp4', 'avi', 'webm']

    def __init__(
        self, 
        tar_paths,
        batch_size=1,
        num_workers=0,
        mode: Literal['wds', 'files'] = 'wds',
        sequence_len: int | None = None, # in frames
        n_sequences_per_video: int | None = None, # if set, extract a constant number of sequences of each video rather than extracting consecutive snippets
        sequence_sample_window=None, # in seconds
        sequence_shift=False,
        extensions: List[str] | None = None,
        fps=None,
        fps_multiplier=1,
        frame_skip=1,
        min_num_frames=0,
        min_fps=None,
        max_fps=None,
        width=None,
        height=None,
        min_width=None,
        min_height=None,
        max_size=None,
        max_aspect_ratio=None,
        min_aspect_ratio=None,
        crop_black_borders=False,
        trim=0,
        pin_memory=True,
        partial=True,
        n_repeats=1,
        epoch_size=None,
        shuffle=0,
        shardshuffle=False,
        node_splitting=True,
        split_strategy='shards',
        transfer_device='cpu',
        transfer_prefetch_factor=2,
        decoding_device='cpu',
        deterministic=True,
        seed=1337,
        use_pyav=False,
        *args, **kwargs
    ):
        super().__init__()
        self.mode = mode

        if self.mode == 'wds' and isinstance(tar_paths, str):
            self.tar_paths = [str(tar) for tar in Path(tar_paths).rglob('*.tar')]
        elif isinstance(tar_paths, (list, tuple)):
            self.tar_paths = list(tar_paths)
        else: 
            raise Exception("file_paths must be str or sequence")

        self.extensions = None if extensions is None else tuple(extensions)

        assert sequence_len is None or sequence_len > 0
        self.seq_len = sequence_len
        if self.seq_len is not None:
            if int(self.seq_len) != self.seq_len:
                print(f"WARN: Rounding sequence_len {sequence_len} to {int(sequence_len)}")
            self.seq_len = int(self.seq_len)
        assert not sequence_shift or n_sequences_per_video is None, "parameters sequence_shift and n_sequences_per_video are exclusive"
        self.seq_shift = sequence_shift
        assert n_sequences_per_video is None or n_sequences_per_video > 0
        self.n_seqs_per_vid = n_sequences_per_video
        self.seq_sample_window = sequence_sample_window
        if self.seq_len is not None:
            self.min_n_frames = max(self.seq_len, min_num_frames)
        else:
            self.min_n_frames = min_num_frames
        
        if frame_skip != 1:
            assert self.seq_sample_window is None
        self.frame_skip = frame_skip
        self.fps = fps
        self.fps_multiplier = fps_multiplier or 1
        assert fps_multiplier == 1, "fps_multiplier currently not supported"
        assert trim == 0, "trim currently not supported"
        self.trim = trim
        self.min_fps = min_fps
        self.max_fps = max_fps
        self.min_width = min_width
        self.min_height = min_height
        self.width = width
        self.height = height
        self.max_size = max_size
        if width is not None and height is not None:
            ratio = width / height
            self.max_aspect_ratio = ratio
            self.min_aspect_ratio = ratio
        else:
            self.max_aspect_ratio = max_aspect_ratio
            self.min_aspect_ratio = min_aspect_ratio
        self.crop_black_borders = crop_black_borders

        self.epoch_size = epoch_size
        if self.epoch_size is None or self.epoch_size < 1:
            self.epoch_size = None
        self.batch_size = batch_size
        self.partial = partial
        self.n_repeats = n_repeats
        self.shardshuffle = shardshuffle
        self.shuffle_size = shuffle
        self.deterministic = deterministic
        assert split_strategy in ['shards', 'videos', 'snippets']
        self.split_strategy = split_strategy
        self.node_splitting = node_splitting
        self.pin_memory = pin_memory

        self.device = torch.device(transfer_device)
        self.prefetch_factor = transfer_prefetch_factor

        self.n_workers = int(num_workers)
        self.dec_device = torch.device(decoding_device)
        self.seed = seed or 1337

        self.reset_state()

        self.use_pyav = use_pyav

    def __iter__(self):
        return self.iterator()

    def reset_state(self):
        self.global_rng = random.Random(self.seed)
        rank, world_size, worker, num_workers = wds.utils.pytorch_worker_info(group=None)
        seed_list = [self.seed, rank]
        if not self.deterministic:
            seed_list.extend([os.getpid(), time.time_ns(), os.urandom(4)])
        self.local_rng = random.Random(wds.utils.make_seed(*seed_list))
        self.loader = None
        self.world_size = world_size
        self.rank = rank
        self._n_seen_videos = None
        self._n_used_videos = None
        print(f'[RANK {rank}] Initialized video loader state')

    def iterator(self):
        if self.epoch_size is not None:
            n_batches_per_rank = self.epoch_size // (self.world_size * self.batch_size)
        else:
            n_batches_per_rank = None
        
        def source_fn():
            n_repeats = sys.maxsize if n_batches_per_rank is not None else self.n_repeats
            return ResampledShardLists(self.tar_paths, n_repeats=n_repeats, shuffle=self.shardshuffle, seed=int.from_bytes(self.global_rng.randbytes(4)))

        if self.mode == 'wds':
            def to_samples_fn():
                return wds.tarfile_to_samples(handler=wds.warn_and_continue)
        elif self.mode == 'files':
            def to_samples_fn():
                return wds.pipelinefilter(file_reader)()

        def nop(src):
            yield from src

        self._n_seen_videos = 0
        self._n_used_videos = 0
        if self.split_strategy == 'shards':
            dataset = wds.DataPipeline(
                source_fn(),
                wds.split_by_node if self.node_splitting else nop,
                wds.split_by_worker,
                to_samples_fn(),
                wds.pipelinefilter(self.decode_and_cut)(),
            )
        elif self.split_strategy == 'videos':
            dataset = wds.DataPipeline(
                source_fn(),
                to_samples_fn(),
                wds.split_by_node if self.node_splitting else nop,
                wds.split_by_worker,
                wds.pipelinefilter(self.decode_and_cut)(),
            )
        elif self.split_strategy == 'snippets':
            dataset = wds.DataPipeline(
                source_fn(),
                to_samples_fn(),
                wds.pipelinefilter(self.decode_and_cut)(),
                wds.split_by_node if self.node_splitting else nop,
                wds.split_by_worker,
            )
        else:
            raise f"Unsupported splitting strategy: {self.split_strategy}"

        loader = wds.WebLoader(
            dataset, 
            batch_size=8,
            num_workers=self.n_workers,
            pin_memory=self.pin_memory and (self.device.type == 'cuda'),
            prefetch_factor=2 if self.n_workers > 0 else None,
            drop_last=False,
            collate_fn=collate_fn
        )
        loader.append(wds.pipelinefilter(self.unbatch)())
        if self.shuffle_size > 0:
            loader.append(wds.shuffle(
                self.shuffle_size, 
                initial=self.shuffle_size,
                rng=self.local_rng
            ))
        loader.append(wds.pipelinefilter(self.batch_and_transfer)())

        if n_batches_per_rank is not None:
            loader.with_epoch(n_batches_per_rank)

        yield from loader

    def make_sample_infos(self, video_name: str, metadata: VideoStreamMetadata):
        snippet_infos = []
        # video_duration = min(metadata.duration_seconds, metadata.duration_seconds_from_header)
        begin_secs = metadata.begin_stream_seconds + 1e-3
        end_secs = metadata.end_stream_seconds - 1e-3
        video_duration = end_secs - begin_secs
        n_video_frames = metadata.num_frames
        
        if self.min_n_frames > n_video_frames:
            return snippet_infos

        # full videos
        if self.seq_sample_window is None and self.seq_len is None:
            assert self.n_seqs_per_vid is None or self.n_seqs_per_vid == 1
            if self.fps is None:
                snippet_infos.append({'frames': {'frames_in_range': (0, n_video_frames)}})
            else:
                frames_played_at = torch.arange(
                    int(video_duration * self.fps)
                ).float().div(self.fps).add(begin_secs).clamp_max(end_secs)
                snippet_infos.append({'frames': {'frames_played_at': frames_played_at.tolist()}})
        
        # snippets of specific duration in seconds
        elif self.seq_sample_window is not None and self.seq_len is None:
            if video_duration < self.seq_sample_window:
                return None
            
            n_snippets = self.n_seqs_per_vid or int(video_duration / self.seq_sample_window)
            for snippet_idx in range(n_snippets):
                if self.n_seqs_per_vid is not None:
                    offset = begin_secs + (self.local_rng.random() * (end_secs - self.seq_sample_window))
                else:
                    offset = begin_secs + (snippet_idx * self.seq_sample_window)
                    if self.seq_shift:
                        offset += (self.local_rng.random() * self.seq_sample_window) - (self.seq_sample_window / 2)
                offset = max(offset, begin_secs)
                offset = min(offset, end_secs - self.seq_sample_window)
                
                info = {}
                if self.fps is None:
                    info['frames_played_in_range'] = (offset, offset + self.seq_sample_window)
                else:
                    info['frames_played_at'] = torch.arange(int(self.seq_sample_window * self.fps)).float().div(self.fps).add(offset)
                snippet_infos.append({'frames': info})

        # snippets with a certain number of frames
        elif self.seq_sample_window is None and self.seq_len is not None:
            if self.fps is not None:
                n_video_frames = int(video_duration * self.fps)
            
            if n_video_frames < self.seq_len:
                return None

            # n_video_frames = int(video_duration * video_fps) # conservative lower guess
            n_snippets = self.n_seqs_per_vid or int(n_video_frames / self.seq_len)
            for snippet_idx in range(n_snippets):
                if self.n_seqs_per_vid is not None:
                    offset = self.local_rng.random() * (n_video_frames - self.seq_len)
                else:
                    offset = snippet_idx * self.seq_len
                    if self.seq_shift:
                        offset += (self.local_rng.random() * self.seq_len) - (self.seq_len / 2)
                offset = int(offset)
                offset = min(max(offset, 0), n_video_frames - self.seq_len)
                
                info = {}
                if self.fps is None:
                    info['frames_in_range'] = (offset, offset + (self.frame_skip * self.seq_len), self.frame_skip)
                else:
                    frames_played_at = torch.arange(self.seq_len).add(offset).float().div(self.fps).add(begin_secs).clamp_max(end_secs)
                    info['frames_played_at'] = frames_played_at.tolist()
                snippet_infos.append({'frames': info})

        return snippet_infos

    def process_sample(self, sample):
        keys_to_ignore = ['video_name', 'video_url']
        if self.extensions is not None:
            keys_to_ignore.extend(self.extensions)
        for k in sample.keys():
            if k in keys_to_ignore:
                continue
            
            frames = sample[k]['data']

            # Some optional cropping
            top, left, height, width = 0, 0, frames.shape[-2], frames.shape[-1]
            if self.crop_black_borders:
                top, left, height, width = black_border_crop(frames, 6)

            ratio = width / height
            if self.max_aspect_ratio is not None and ratio > self.max_aspect_ratio: # e.g. 4/3
                new_width = int(round(height * self.max_aspect_ratio))
                left = left + int((width - new_width) / 2)
                width = new_width
            elif self.min_aspect_ratio is not None and ratio < self.min_aspect_ratio: # e.g. 3/4
                new_height = int(round(width / self.min_aspect_ratio))
                top = top + int((height - new_height) / 2)
                height = new_height
            
            # discard small samples
            if (
                (self.min_width is not None and width < self.min_width) or 
                (self.min_height is not None and height < self.min_height)
            ):
                return None

            # apply crop
            if top != 0 or left != 0 or width != frames.shape[-1] or height != frames.shape[-2]:
                frames = TF.crop(frames, top=top, left=left, height=height, width=width)

            # optionally resize
            if self.width is not None and self.height is not None:
                frames = TF.resize(frames, size=(self.height, self.width), interpolation=TF.InterpolationMode.BICUBIC, antialias=True)
            elif self.max_size is not None:
                max_size = max(height, width)
                if max_size > self.max_size:
                    new_min_size = int(min(height, width) * (self.max_size / max_size))
                    frames = TF.resize(frames, size=new_min_size, interpolation=TF.InterpolationMode.BICUBIC, antialias=True)

            sample[k]['data'] = frames
        
        return sample

    def construct_batch(self, samples):
        keys_to_ignore = ['video_name', 'video_url']
        if self.extensions is not None:
            keys_to_ignore.extend(self.extensions)
    
        batch = { 
            'video_name': [s['video_name'] for s in samples],
            'video_url': [s['video_url'] for s in samples],
        }
        if self.extensions is not None:
            for ext in self.extensions:
                if all([ext in s for s in samples]):
                    batch[ext] = [s[ext] for s in samples]
        
        for sample in samples:
            for k, v in sample.items():
                if k in keys_to_ignore:
                    continue
                if k not in batch:
                    batch[k] = { 'data': [], 'pts': [] }
                batch[k]['data'].append(v['data'])
                batch[k]['pts'].append(v['pts'])
        return batch

    def decode_and_cut(self, source):
        for entry in source:
            decoder = None
            video_name = entry['__key__']
            video_url = entry['__url__']
            self._n_seen_videos += 1
            try:
                # get encoded video bytes
                enc_video = None
                for ext in self.VIDEO_EXTENSIONS:
                    if ext in entry:
                        assert enc_video is None
                        enc_video = entry[ext]
                assert enc_video is not None

                # (optional) get additional extensions
                extensions = {}
                if self.extensions is not None:
                    for ext in self.extensions:
                        if ext in entry:
                            extensions[ext] = entry[ext]

                # initialize decoder and extract metadata
                if self.use_pyav:
                    av_decoder = av.open(io.BytesIO(enc_video))
                    av_stream = av_decoder.streams.video[0]
                    video_fps = float(av_stream.average_rate)
                    height = av_stream.height
                    width = av_stream.width
                    metadata = VideoStreamMetadata(
                        num_frames_from_header=av_stream.frames,
                        num_frames_from_content=av_stream.frames,
                        duration_seconds_from_header=float(av_stream.duration * av_stream.time_base),
                        begin_stream_seconds=float(av_stream.start_time * av_stream.time_base),
                        end_stream_seconds=float((av_stream.start_time + av_stream.duration) * av_stream.time_base),
                        average_fps_from_header=video_fps,
                        bit_rate=av_stream.bit_rate,
                        codec=av_stream.codec.name,
                        stream_index=0,
                        height=height,
                        width=width,
                    )
                else:
                    decoder = VideoDecoder(enc_video)
                    metadata = decoder.metadata
                    video_fps = metadata.average_fps_from_header
                    height = metadata.height
                    width = metadata.width

                # discard videos with low fps
                if self.min_fps is not None and video_fps < self.min_fps:
                    continue

                # discard videos with low fps
                if self.max_fps is not None and video_fps > self.max_fps:
                    continue

                # discard videos with small width or height
                if (
                    (self.min_width is not None and width < self.min_width) or 
                    (self.min_height is not None and height < self.min_height)
                ):
                    continue

                # generate sample infos
                snippet_infos = self.make_sample_infos(video_name, metadata)
                if snippet_infos is None:
                    continue

                # decode samples
                for snippet_info in snippet_infos:
                    try:
                        sample = { 'video_name': video_name, 'video_url': video_url }
                        sample.update(extensions)
                        frames_at = []
                        frames_played_at = []
                        for entry_name, entry in snippet_info.items():
                            if entry.get('frames_at', None) is not None:
                                frames_at.append((entry_name, entry.get('frames_at')))
                            if entry.get('frames_played_at', None) is not None:
                                frames_played_at.append((entry_name, entry.get('frames_played_at')))
                            if entry.get('frames_in_range', None) is not None:
                                if self.use_pyav:
                                    frames_at.append((entry_name, list(range(*entry.get('frames_in_range')))))
                                else:
                                    frame_batch = decoder.get_frames_in_range(*entry.get('frames_in_range'))
                                    sample[entry_name] = {
                                        'data': frame_batch.data,
                                        'pts': frame_batch.pts_seconds.float()
                                    }
                            if entry.get('frames_played_in_range', None) is not None:
                                frame_batch = decoder.get_frames_played_in_range(*entry.get('frames_played_in_range'))
                                sample[entry_name] = {
                                    'data': frame_batch.data,
                                    'pts': frame_batch.pts_seconds.float()
                                }

                        if len(frames_at) > 0:
                            start = 0
                            if self.use_pyav:
                                frame_batch = get_frames_at_with_pyav(av_decoder, sum([e[1] for e in frames_at], []))
                                for entry_name, idcs in frames_at:
                                    end = start + len(idcs)
                                    sample[entry_name] = {
                                        'data': frame_batch['data'][start:end],
                                        'pts': frame_batch['pts_seconds'][start:end],
                                    }
                                    start = end
                            else:
                                frame_batch = decoder.get_frames_at(sum([e[1] for e in frames_at], []))
                                for entry_name, idcs in frames_at:
                                    end = start + len(idcs)
                                    sample[entry_name] = {
                                        'data': frame_batch.data[start:end].data,
                                        'pts': frame_batch.pts_seconds[start:end].float(),
                                    }
                                    start = end

                        if len(frames_played_at) > 0:
                            frame_batch = decoder.get_frames_played_at(sum([e[1] for e in frames_played_at], []))
                            start = 0
                            for entry_name, idcs in frames_played_at:
                                end = start + len(idcs)
                                sample[entry_name] = {
                                    'data': frame_batch.data[start:end].data,
                                    'pts': frame_batch.pts_seconds[start:end].float(),
                                }
                                start = end

                        sample = self.process_sample(sample)

                        if sample is None:
                            continue

                        yield sample
                    
                    except Exception as e:
                        print(f"[VideoLoader] {video_name}: {e}")

                self._n_used_videos += 1
            except Exception as e:
                print(f"[VideoLoader] {video_name}: {e}")

    def unbatch(self, source):
        for batch in source:
            yield from batch


    def batch_and_transfer(self, source):
        use_side_stream = False

        if use_side_stream:
            stream = torch.cuda.Stream(self.device)
        else:
            stream = None
        ctx = contextlib.nullcontext() if stream is None else torch.cuda.stream(stream)

        def collate(samples):
            event = None
            with ctx:
                batch = self.construct_batch(samples)
            if stream is not None:
                event = stream.record_event()
            return batch, event

        sample_buf = []
        batch_buf = []
        for sample in source:
            if self.device.type == 'cuda':
                with ctx:
                    sample = nested_to(sample, device=self.device, non_blocking=True)
            
            sample_buf.append(sample)

            if len(sample_buf) >= self.batch_size:
                batch_buf.append(collate(sample_buf))
                sample_buf = []
            
            if len(batch_buf) >= self.prefetch_factor:
                batch, event = batch_buf.pop(0)
                if event is not None:
                    torch.cuda.default_stream(self.device).wait_event(event)
                yield batch

        while len(batch_buf) > 0:
            batch, event = batch_buf.pop(0)
            if event is not None:
                torch.cuda.default_stream(self.device).wait_event(event)
            yield batch

        if self.partial and len(sample_buf) > 0:
            batch, event = collate(sample_buf)
            if event is not None:
                torch.cuda.default_stream(self.device).wait_event(event)
            yield batch
    

import pytorch_lightning as pl
class VideoDataModule(pl.LightningDataModule):
    """Lightning data module that builds a :class:`VideoLoader` for each split.

    ``train``/``validation``/``test`` are OmegaConf configs with ``tar_base``
    and ``tars`` entries plus optional ``device`` and ``resampled`` keys. All
    remaining keys are forwarded to :class:`VideoLoader` as keyword arguments.
    """

    def __init__(self, train=None, validation=None, test=None):
        super().__init__()
        self.dataset_configs = {"train": train, "validation": validation, "test": test}

    def get_loader(self, config):
        """Create a :class:`VideoLoader` from one split config.

        Shard resolution depends on the types of ``tar_base`` / ``tars``:

        * str / str: joined into one path string, which is passed through as is.
        * str / list: each pattern is joined to the base and expanded with
          ``wds.shardlists.expand_urls``.
        * list / list (or list / str): each base directory is searched with
          ``rglob`` for its pattern; a single str pattern applies to all bases.

        Raises:
            FileNotFoundError: if a base directory does not exist.
            Exception: if the configuration is invalid or no shards are found.
        """
        config = copy.deepcopy(config)

        # Default to the trainer's device unless the config overrides it.
        device = self.trainer.strategy.root_device
        device = config.pop('device', None) or device

        # Gather shards. str is itself a Sequence, so str bases are handled first.
        tar_base = config.pop('tar_base')
        tars = config.pop('tars')
        if isinstance(tar_base, str) and isinstance(tars, str):
            # NOTE(review): VideoLoader (mode='wds') rglobs a str for '*.tar',
            # i.e. treats this joined path as a directory.
            shards = os.path.join(tar_base, tars)
        elif isinstance(tar_base, str) and isinstance(tars, Sequence):
            shards = []
            for pattern in tars:
                shards += wds.shardlists.expand_urls(os.path.join(tar_base, pattern))
        elif isinstance(tar_base, Sequence) and isinstance(tars, Sequence):
            patterns = [tars] * len(tar_base) if isinstance(tars, str) else tars
            shards = []
            for i, base in enumerate(tar_base):
                base_path = Path(base)
                if not base_path.exists():
                    raise FileNotFoundError(f"Path does not exist: {base_path}")
                shards += [str(match) for match in base_path.rglob(patterns[i])]
        else:
            raise Exception("Invalid tar configuration!")
        
        if len(shards) == 0:
            raise Exception("No tar files found!")

        # Create loader; "resampled" means repeat the shard list indefinitely.
        resampled = config.pop("resampled", True)
        loader = VideoLoader(
            tar_paths = shards,
            partial = False,
            n_repeats = sys.maxsize if resampled else 1,
            transfer_device = device,
            **OmegaConf.to_container(config),
        )
        return loader

    def setup(self, stage: str):
        pass

    def train_dataloader(self):
        return self.get_loader(self.dataset_configs["train"])

    def val_dataloader(self, device=None):
        return self.get_loader(self.dataset_configs["validation"])

    def test_dataloader(self):
        return self.get_loader(self.dataset_configs["test"])

    def teardown(self, stage: str):
        pass
