import math
from os import listdir, path
from random import choices, randint
from typing import Any, Callable, Dict, List, Tuple, Optional

import cv2 as cv
import json
import torch
import torch.nn.functional as F
from einops import rearrange
from lightning import LightningDataModule
from torch import Tensor
from torch.utils.data import DataLoader, Dataset, IterableDataset
from torch.utils.data import get_worker_info
from tqdm import tqdm


def exists(var) -> bool:
    """Tell whether ``var`` holds a value, i.e. is anything other than ``None``."""
    if var is None:
        return False
    return True


def default(var, val) -> Any:
    """Return ``var`` when it is set (not ``None``); otherwise return the fallback ``val``."""
    if exists(var):
        return var
    return val


def default_worker_init_fn(worker_id: int) -> None:
    """
    Seed a DataLoader worker and restrict its dataset copy to that worker's share of the range.

    The worker's dataset must expose ``_start`` and ``_end`` attributes; both are overwritten
    in place with the sub-range assigned to this worker. When called outside of a worker
    process (``get_worker_info()`` returns ``None``) only the torch seed is touched.

    Args:
        worker_id: Worker index handed over by the DataLoader, used to offset the torch seed.
    """
    torch.manual_seed(torch.initial_seed() + worker_id)
    worker_info = get_worker_info()
    if worker_info is None:
        return

    dataset = worker_info.dataset
    glob_start = dataset._start
    glob_end = dataset._end

    # Round up: truncating here would leave the trailing (range % num_workers)
    # items unassigned to any worker.
    per_worker = math.ceil((glob_end - glob_start) / worker_info.num_workers)
    worker_id = worker_info.id

    # Clamp both bounds so that surplus workers get an empty range instead of start > end
    dataset._start = min(glob_start + worker_id * per_worker, glob_end)
    dataset._end = min(dataset._start + per_worker, glob_end)


class LightningDataset(LightningDataModule):
    """
    Abstract LightningDataModule that represents a dataset we can train a Lightning module on.

    This class only stores the loader options and turns ``train_dataset`` / ``val_dataset`` /
    ``test_dataset`` into DataLoaders; the three attributes start as ``None`` and have to be
    assigned elsewhere before the loaders are requested.
    """

    def __init__(
            self,
            *args,
            batch_size: int = 8,
            num_workers: int = 16,
            train_shuffle: bool = True,
            val_shuffle: bool = False,
            val_batch_size: int = None,
            worker_init_fn: Callable = None,
            collate_fn: Callable = None,
            train_sampler: Callable = None,
            test_sampler: Callable = None,
            val_sampler: Callable = None
    ) -> None:
        """
        Args:
            *args: Accepted for signature compatibility and ignored.
            batch_size: Batch size of the training loader.
            num_workers: Number of worker processes used by every loader.
            train_shuffle: Whether the training loader shuffles.
            val_shuffle: Whether the validation and test loaders shuffle.
            val_batch_size: Batch size of the validation and test loaders; falls back to ``batch_size``.
            worker_init_fn: Optional worker init function passed to every loader.
            collate_fn: Optional collate function; a key-mismatch tolerant default is used otherwise.
            train_sampler: Optional sampler of the training loader.
            test_sampler: Optional sampler of the test loader.
            val_sampler: Optional sampler of the validation loader.
        """
        super().__init__()
        self.train_dataset = None
        self.test_dataset = None
        self.val_dataset = None

        val_batch_size = default(val_batch_size, batch_size)

        self.num_workers = num_workers
        self.batch_size = batch_size
        self.val_batch_size = val_batch_size
        self.train_shuffle = train_shuffle
        self.val_shuffle = val_shuffle
        self.train_sampler = train_sampler
        self.test_sampler = test_sampler
        self.val_sampler = val_sampler
        self.collate_fn = collate_fn
        self.worker_init_fn = worker_init_fn

    @staticmethod
    def _safe_collate(batch):
        """
        Collate with ``default_collate``; on ``KeyError`` keep only the items whose key set
        equals the first item's key set and collate again.

        Defined as a static method (not a closure) so it can be pickled for worker processes.
        """
        try:
            return torch.utils.data.default_collate(batch)
        except KeyError:
            keys = set(batch[0].keys())
            batch = [b for b in batch if set(b.keys()) == keys]
            return torch.utils.data.default_collate(batch)

    def _make_dataloader(self, dataset, sampler, batch_size: int, shuffle: bool) -> DataLoader:
        """Build a DataLoader for ``dataset`` using the options stored on this module."""
        is_iterable = isinstance(dataset, IterableDataset)
        if is_iterable:
            # Iterable datasets need per-worker sharding, so install the default one if none was given
            worker_init_fn = default(self.worker_init_fn, default_worker_init_fn)
        else:
            worker_init_fn = self.worker_init_fn

        # DataLoader raises when shuffle is requested together with a sampler
        # or an IterableDataset, so the flag is dropped in those cases.
        if sampler is not None or is_iterable:
            shuffle = False

        return DataLoader(
            dataset,
            sampler=sampler,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=self.collate_fn or self._safe_collate,
            num_workers=self.num_workers,
            worker_init_fn=worker_init_fn
        )

    def train_dataloader(self) -> DataLoader:
        """Return the DataLoader over ``train_dataset``."""
        return self._make_dataloader(
            self.train_dataset, self.train_sampler, self.batch_size, self.train_shuffle
        )

    def val_dataloader(self) -> DataLoader:
        """Return the DataLoader over ``val_dataset``."""
        return self._make_dataloader(
            self.val_dataset, self.val_sampler, self.val_batch_size, self.val_shuffle
        )

    def test_dataloader(self) -> DataLoader:
        """Return the DataLoader over ``test_dataset`` (uses the validation batch size and shuffle flag)."""
        return self._make_dataloader(
            self.test_dataset, self.test_sampler, self.val_batch_size, self.val_shuffle
        )


class VideoDataset(Dataset):
    """
    Map-style dataset over the ``.mp4`` / ``.webm`` files located directly inside one folder.

    Every item is a dict ``{"videos": Tensor}`` holding one clip decoded with OpenCV,
    center-cropped to a square, resized to ``resolution`` and arranged as ``output_format``.
    """

    def __init__(
            self,
            split_path: str,
            padding: str = "repeat",
            randomize: bool = False,
            resolution: int = 256,
            num_frames: int = 16,
            output_format: str = "t h w c",
            color_aug: bool = True
    ) -> None:
        """
        Args:
            split_path: Folder whose video files make up the dataset.
            padding: How short clips are extended: "none", "repeat", "zero" or "random".
            randomize: Start each clip at a random frame instead of frame 0.
            resolution: Side length of the square output frames.
            num_frames: Number of frames per returned clip.
            output_format: einops axis pattern of the returned tensor.
            color_aug: Apply a random brightness shift to every clip.
        """
        super().__init__()
        self.padding = padding
        self.randomize = randomize
        self.resolution = resolution
        self.num_frames = num_frames
        self.output_format = output_format
        self.color_aug = color_aug

        # Sorted so that the index -> file mapping does not depend on the directory listing order
        self.file_names = [
            path.join(split_path, file_name)
            for file_name in sorted(listdir(split_path))
            if file_name.endswith((".mp4", ".webm"))
        ]

    def __len__(self) -> int:
        return len(self.file_names)

    def __getitem__(self, idx: int) -> Dict:
        """Return the clip of file ``idx``; on a loading error another random file is tried."""
        video_path = self.file_names[idx]
        while True:
            try:
                video = self.load_video_slice(
                    video_path,
                    self.num_frames,
                    None if self.randomize else 0
                )
                return self.build_data_dict(video)
            except Exception:
                # Only regular errors trigger the fallback; KeyboardInterrupt/SystemExit propagate
                video_path = self.file_names[randint(0, len(self) - 1)]

    def _pad_frames(self, frames: List[Tensor], target_len: int) -> None:
        """
        Extend ``frames`` in place to ``target_len`` entries according to ``self.padding``.

        Raises:
            ValueError: If the padding mode is unknown, or padding is requested without any frame.
        """
        if self.padding not in ("none", "repeat", "zero", "random"):
            raise ValueError(f"Invalid padding type: {self.padding}")
        if self.padding == "none":
            return
        if not frames:
            raise ValueError("Cannot pad a clip without any decoded frame")

        last = frames[-1]
        if self.padding == "repeat":
            filler = last
        elif self.padding == "zero":
            filler = torch.zeros_like(last)
        elif last.is_floating_point():
            filler = torch.rand_like(last)
        else:
            # Decoded frames are integer typed, for which rand_like is not usable:
            # draw integer noise over the full 0..255 range instead.
            filler = torch.randint(0, 256, last.shape, dtype=last.dtype)
        frames.extend([filler] * (target_len - len(frames)))

    def load_video_slice(
            self,
            video_path: str,
            num_frames: int,
            start_frame: int = None,
            frame_skip: int = 1
    ) -> Tensor:
        """
        Decode ``num_frames`` frames (after frame skipping) from ``video_path``.

        Args:
            video_path: File to decode.
            num_frames: Number of frames to return.
            start_frame: First frame to read; a random valid start is drawn when ``None``.
            frame_skip: Temporal stride; overridden by the path-based rules below.

        Returns:
            Clip scaled by 1/255, square-cropped, resized and arranged as ``self.output_format``.

        Raises:
            ValueError: If no frame can be read or the padding mode is invalid.
        """
        # The stride is selected from substrings of the path
        if "retro" in video_path:
            frame_skip = 4
        elif "procgen" not in video_path and "ssv2" not in video_path and "mira" not in video_path:
            frame_skip = 2
        num_frames = num_frames * frame_skip

        frames = []
        cap = cv.VideoCapture(video_path)
        try:
            total_frames = int(cap.get(cv.CAP_PROP_FRAME_COUNT))
            if not exists(start_frame):
                start_frame = randint(0, max(0, total_frames - num_frames))
            cap.set(cv.CAP_PROP_POS_FRAMES, start_frame)
            for _ in range(num_frames):
                ret, frame = cap.read()
                if not ret:
                    # Reached the end of the video: pad and stop reading
                    self._pad_frames(frames, num_frames)
                    break
                frame = cv.cvtColor(frame, cv.COLOR_BGR2RGB)
                frames.append(torch.from_numpy(frame))
        finally:
            # Release the capture handle even when decoding or padding fails
            cap.release()

        if not frames:
            raise ValueError(f"No frame could be read from: {video_path}")

        video = torch.stack(frames[::frame_skip]) / 255.0

        # Crop the video to be square
        if video.shape[1] != video.shape[2]:
            square_len = min(video.shape[1], video.shape[2])
            h_crop = (video.shape[1] - square_len) // 2
            w_crop = (video.shape[2] - square_len) // 2
            video = video[:, h_crop:h_crop + square_len, w_crop:w_crop + square_len]

        if video.shape[-2] != self.resolution or video.shape[-3] != self.resolution:
            # interpolate resizes the two trailing axes, hence the temporary "c t h w" layout
            video = rearrange(video, "t h w c -> c t h w")
            video = F.interpolate(video, self.resolution, mode="bicubic")
            video = rearrange(video, f"c t h w -> {self.output_format}")
        else:
            video = rearrange(video, f"t h w c -> {self.output_format}")
        return video

    def build_data_dict(self, video: Tensor) -> Dict:
        """Wrap ``video`` into the sample dict, applying the brightness jitter when enabled."""
        if self.color_aug:
            # Brightness jitter: one shared offset in [-0.1, 0.1) for the whole clip
            video = (video + torch.rand(1) * 0.2 - 0.1).clamp(0, 1)

        data_dict = {
            "videos": video
        }
        return data_dict


class OriginalVideoDataset(Dataset):
    """
    Map-style dataset over the ``.mp4`` files of one split, gathered from one or several
    environment folders below ``data_root``.

    Every item is a dict ``{"videos": Tensor}`` holding one clip decoded with OpenCV,
    center-cropped to a square, resized to ``resolution`` and arranged as ``output_format``.
    """

    # Named source groups -> top-level folders whose sub-folders each contain the splits
    _SOURCE_GROUPS = {
        "procgen": ("procgen",),
        "retro": ("retro",),
        "game": ("procgen", "retro"),
        "robot": ("openx",),
    }

    def __init__(
            self,
            data_root: str,
            env_source: str = "libero",
            split: str = "train",
            padding: str = "repeat",
            randomize: bool = False,
            resolution: int = 256,
            num_frames: int = 16,
            output_format: str = "t h w c",
            color_aug: bool = True
    ) -> None:
        """
        Args:
            data_root: Root folder of all sources.
            env_source: One of the named groups ("procgen", "retro", "game", "robot") or the
                name of a folder such that ``data_root/env_source/split`` exists.
            split: Name of the split sub-folder.
            padding: How short clips are extended: "none", "repeat", "zero" or "random".
            randomize: Start each clip at a random frame instead of frame 0.
            resolution: Side length of the square output frames.
            num_frames: Number of frames per returned clip.
            output_format: einops axis pattern of the returned tensor.
            color_aug: Apply a random brightness shift to every clip.

        Raises:
            ValueError: If ``env_source`` is neither a named group nor an existing folder.
        """
        super().__init__()
        self.padding = padding
        self.randomize = randomize
        self.resolution = resolution
        self.num_frames = num_frames
        self.output_format = output_format
        self.color_aug = color_aug

        # Get all the file path based on the split
        folders = []
        if env_source in self._SOURCE_GROUPS:
            for group in self._SOURCE_GROUPS[env_source]:
                for env in sorted(listdir(path.join(data_root, group))):
                    folders.append(path.join(data_root, group, env, split))
        elif path.exists(path.join(data_root, env_source, split)):
            folders.append(path.join(data_root, env_source, split))
        else:
            raise ValueError(f"Invalid source: {env_source}")

        # Sorted so that the index -> file mapping does not depend on the directory listing order
        self.file_names = [
            path.join(folder, f)
            for folder in folders
            for f in sorted(listdir(folder)) if f.endswith(".mp4")
        ]

    def __len__(self) -> int:
        return len(self.file_names)

    def __getitem__(self, idx: int) -> Dict:
        """Return the clip of file ``idx``; on a loading error another random file is tried."""
        video_path = self.file_names[idx]
        while True:
            try:
                video = self.load_video_slice(
                    video_path,
                    self.num_frames,
                    None if self.randomize else 0
                )
                return self.build_data_dict(video)
            except Exception:
                # Only regular errors trigger the fallback; KeyboardInterrupt/SystemExit propagate
                video_path = self.file_names[randint(0, len(self) - 1)]

    def _pad_frames(self, frames: List[Tensor], target_len: int) -> None:
        """
        Extend ``frames`` in place to ``target_len`` entries according to ``self.padding``.

        Raises:
            ValueError: If the padding mode is unknown, or padding is requested without any frame.
        """
        if self.padding not in ("none", "repeat", "zero", "random"):
            raise ValueError(f"Invalid padding type: {self.padding}")
        if self.padding == "none":
            return
        if not frames:
            raise ValueError("Cannot pad a clip without any decoded frame")

        last = frames[-1]
        if self.padding == "repeat":
            filler = last
        elif self.padding == "zero":
            filler = torch.zeros_like(last)
        elif last.is_floating_point():
            filler = torch.rand_like(last)
        else:
            # Decoded frames are integer typed, for which rand_like is not usable:
            # draw integer noise over the full 0..255 range instead.
            filler = torch.randint(0, 256, last.shape, dtype=last.dtype)
        frames.extend([filler] * (target_len - len(frames)))

    def load_video_slice(
            self,
            video_path: str,
            num_frames: int,
            start_frame: int = None,
            frame_skip: int = 1
    ) -> Tensor:
        """
        Decode ``num_frames`` frames (after frame skipping) from ``video_path``.

        Args:
            video_path: File to decode.
            num_frames: Number of frames to return.
            start_frame: First frame to read; a random valid start is drawn when ``None``.
            frame_skip: Temporal stride; overridden by the path-based rules below.

        Returns:
            Clip scaled by 1/255, square-cropped, resized and arranged as ``self.output_format``.

        Raises:
            ValueError: If no frame can be read or the padding mode is invalid.
        """
        # The stride is selected from substrings of the path
        if "retro" in video_path:
            frame_skip = 4
        elif "procgen" not in video_path and "ssv2" not in video_path and "mira" not in video_path:
            frame_skip = 2
        num_frames = num_frames * frame_skip

        frames = []
        cap = cv.VideoCapture(video_path)
        try:
            total_frames = int(cap.get(cv.CAP_PROP_FRAME_COUNT))
            if not exists(start_frame):
                start_frame = randint(0, max(0, total_frames - num_frames))
            cap.set(cv.CAP_PROP_POS_FRAMES, start_frame)
            for _ in range(num_frames):
                ret, frame = cap.read()
                if not ret:
                    # Reached the end of the video: pad and stop reading
                    self._pad_frames(frames, num_frames)
                    break
                frame = cv.cvtColor(frame, cv.COLOR_BGR2RGB)
                frames.append(torch.from_numpy(frame))
        finally:
            # Release the capture handle even when decoding or padding fails
            cap.release()

        if not frames:
            raise ValueError(f"No frame could be read from: {video_path}")

        video = torch.stack(frames[::frame_skip]) / 255.0

        # Crop the video to be square
        if video.shape[1] != video.shape[2]:
            square_len = min(video.shape[1], video.shape[2])
            h_crop = (video.shape[1] - square_len) // 2
            w_crop = (video.shape[2] - square_len) // 2
            video = video[:, h_crop:h_crop + square_len, w_crop:w_crop + square_len]

        if video.shape[-2] != self.resolution or video.shape[-3] != self.resolution:
            # interpolate resizes the two trailing axes, hence the temporary "c t h w" layout
            video = rearrange(video, "t h w c -> c t h w")
            video = F.interpolate(video, self.resolution, mode="bicubic")
            video = rearrange(video, f"c t h w -> {self.output_format}")
        else:
            video = rearrange(video, f"t h w c -> {self.output_format}")
        return video

    def build_data_dict(self, video: Tensor) -> Dict:
        """Wrap ``video`` into the sample dict, applying the brightness jitter when enabled."""
        if self.color_aug:
            # Brightness jitter: one shared offset in [-0.1, 0.1) for the whole clip
            video = (video + torch.rand(1) * 0.2 - 0.1).clamp(0, 1)

        data_dict = {
            "videos": video
        }
        return data_dict


class MultiSourceSamplerDataset(Dataset):
    """
    Dataset that draws every item at random from a weighted mixture of ``VideoDataset`` subsets.

    One subset is created per environment folder; its sampling probability is derived from the
    subset sizes according to ``sampling_strategy``. The reported length is the fixed
    ``samples_per_epoch`` rather than the number of underlying files.
    """

    def __init__(
            self,
            data_root: str,
            env_source: str = "libero",
            split: str = "train",
            samples_per_epoch: int = 1000000,
            sampling_strategy: str = "sample",
            color_aug: bool = True,
            **kwargs
    ) -> None:
        """
        Args:
            data_root: Root folder of all sources.
            env_source: One of the named groups ("procgen", "retro", "game", "robot") or the
                name of a folder such that ``data_root/env_source/split`` exists.
            split: Name of the split sub-folder.
            samples_per_epoch: Value reported by ``len()``.
            sampling_strategy: "sample", "dataset", "log" or "pi" (see ``_sampling_weights``).
            color_aug: Forwarded to every ``VideoDataset``.
            **kwargs: Remaining options forwarded to every ``VideoDataset``.

        Raises:
            ValueError: On an unknown source or strategy, or when every subset has zero weight.
        """
        super().__init__()
        self.samples_per_epoch = samples_per_epoch

        # Create all subsets
        folders = []
        if env_source == "procgen":
            for env in listdir(path.join(data_root, "procgen")):
                folders.append(path.join(data_root, "procgen", env, split))
        elif env_source == "retro":
            for env in listdir(path.join(data_root, "retro")):
                folders.append(path.join(data_root, "retro", env, split))
        elif env_source == "game":
            for env in listdir(path.join(data_root, "procgen")):
                folders.append(path.join(data_root, "procgen", env, split))
            for env in listdir(path.join(data_root, "retro")):
                folders.append(path.join(data_root, "retro", env, split))
        elif env_source == "robot":
            for env in listdir(path.join(data_root, "openx")):
                folders.append(path.join(data_root, "openx", env, split))
        elif path.exists(path.join(data_root, env_source, split)):
            folders.append(path.join(data_root, env_source, split))
        else:
            raise ValueError(f"Invalid source: {env_source}")
        self.subsets = []
        for folder in tqdm(folders, desc="Loading subsets..."):
            # The parent folder of the split is the environment name (separator independent)
            print("Subset:", path.basename(path.dirname(folder)))
            self.subsets.append(VideoDataset(split_path=folder, color_aug=color_aug, **kwargs))
        print("Number of subsets:", len(self.subsets))

        probs = self._sampling_weights([len(d) for d in self.subsets], sampling_strategy)
        total_prob = sum(probs)
        # Explicit raise instead of assert so the check survives `python -O`
        if not total_prob > 0:
            raise ValueError("No sample is available")
        self.sample_probs = [x / total_prob for x in probs]

    @staticmethod
    def _sampling_weights(sizes, sampling_strategy: str):
        """
        Turn subset sizes into unnormalized sampling weights.

        Empty subsets always get weight 0 so they can never be selected.

        Raises:
            ValueError: If ``sampling_strategy`` is unknown.
        """
        if sampling_strategy == "sample":
            # Sample uniformly from all samples
            return [n for n in sizes]
        if sampling_strategy == "dataset":
            # Sample uniformly from all non-empty datasets
            return [1 if n else 0 for n in sizes]
        if sampling_strategy == "log":
            # Weight grows with the logarithm of the dataset size
            return [math.log(n) if n else 0 for n in sizes]
        if sampling_strategy == "pi":
            # Weight grows with the dataset size raised to the power 0.43
            return [n ** 0.43 for n in sizes]
        raise ValueError(f"Unavailable sampling strategy: {sampling_strategy}")

    def __len__(self) -> int:
        return self.samples_per_epoch

    def __getitem__(self, idx: int) -> Dict:
        """
        Args:
        idx (int): Index (ignored since we sample randomly).

        Returns:
        Dict: The item returned by the randomly selected subset.
        """

        # Randomly select a subset based on weights
        subset = choices(self.subsets, self.sample_probs)[0]

        # Sample a valid sample with a random index
        sample_idx = randint(0, len(subset) - 1)
        sample_item = subset[sample_idx]
        return sample_item


class LightningVideoDataset(LightningDataset):
    """
    Data module that selects and builds the video datasets for the "fit" and "test" stages.

    Depending on the configuration, training uses LeRobot sources, ego/exo pairs, both combined,
    or the folder-based multi-source sampler.
    """

    def __init__(
            self,
            data_root: str,
            env_source: str = "libero",
            padding: str = "repeat",
            randomize: bool = False,
            resolution: int = 256,
            num_frames: int = 16,
            output_format: str = "t h w c",
            samples_per_epoch: int = 1000000,
            sampling_strategy: str = "sample",
            lerobot: Dict = None,
            multiview: Dict = None,
            egoinexo: Dict = None,
            **kwargs
    ) -> None:
        """
        Args:
            data_root: Root folder of the folder-based sources.
            env_source: Source name; the value "lerobot" enables the LeRobot path.
            padding: Padding mode forwarded to the datasets.
            randomize: Whether clips start at a random frame.
            resolution: Side length of the square output frames.
            num_frames: Number of frames per clip.
            output_format: einops axis pattern of the returned clips.
            samples_per_epoch: Nominal number of training samples per epoch.
            sampling_strategy: Subset weighting strategy forwarded to the samplers.
            lerobot: Optional LeRobot configuration dict (keys read here: "enabled", "sources",
                "default_camera_index", "default_paired_camera_index", "default_video_backend").
            multiview: Stored on the module; not read in this class.
            egoinexo: Optional ego/exo configuration dict (keys read here: "enabled", "root",
                "json_path", "view1_key", "view2_key").
            **kwargs: Loader options forwarded to ``LightningDataset``.
        """
        super(LightningVideoDataset, self).__init__(**kwargs)
        self.data_root = data_root
        self.env_source = env_source
        self.padding = padding
        self.randomize = randomize
        self.resolution = resolution
        self.num_frames = num_frames
        self.output_format = output_format
        self.samples_per_epoch = samples_per_epoch
        self.sampling_strategy = sampling_strategy
        self.lerobot = lerobot
        self.multiview = multiview
        self.egoinexo = egoinexo

        self.save_hyperparameters()

    def _video_kwargs(self) -> Dict:
        """Clip options shared by every dataset built by this module."""
        return dict(
            padding=self.padding,
            randomize=self.randomize,
            resolution=self.resolution,
            num_frames=self.num_frames,
            output_format=self.output_format,
        )

    def _lerobot_enabled(self) -> bool:
        """True when LeRobot data is requested via ``env_source`` or the ``lerobot`` config."""
        return self.env_source == "lerobot" or bool(self.lerobot and self.lerobot.get("enabled", False))

    def _lerobot_requires_pairing(self) -> bool:
        """True when a paired camera is configured globally or on at least one source."""
        cfg = self.lerobot or {}
        default_idx = cfg.get("default_paired_camera_index")
        # Index 0 is a valid camera, while -1 is the "no paired camera" sentinel
        # (the LeRobot sampler maps -1 to None), so plain truthiness is wrong for both.
        if default_idx is not None and default_idx != -1:
            return True
        return any(
            (s.get("paired_camera_key") is not None) or (s.get("paired_camera_index") is not None)
            for s in cfg.get("sources", [])
        )

    def _build_lerobot_sampler(self, samples_per_epoch: int, require_paired: bool):
        """Create the LeRobot sampler dataset from the ``lerobot`` configuration."""
        # Tolerate env_source == "lerobot" without a config dict
        cfg = self.lerobot or {}
        return MultiLeRobotSamplerDataset(
            specs=cfg.get("sources", []),
            samples_per_epoch=samples_per_epoch,
            sampling_strategy=self.sampling_strategy,
            default_camera_index=cfg.get("default_camera_index", None),
            default_paired_camera_index=cfg.get("default_paired_camera_index", None),
            default_video_backend=cfg.get("default_video_backend", None),
            require_paired=require_paired,
            **self._video_kwargs(),
        )

    def _build_egoinexo_sampler(self, samples_per_epoch: int):
        """Create the ego/exo sampler dataset from the ``egoinexo`` configuration."""
        return MultiEgoExoSamplerDataset(
            root=self.egoinexo.get("root"),
            json_path=self.egoinexo.get("json_path"),
            samples_per_epoch=samples_per_epoch,
            sampling_strategy=self.sampling_strategy,
            color_aug=True,
            view1_key=self.egoinexo.get("view1_key", "exo"),
            view2_key=self.egoinexo.get("view2_key", "ego"),
            **self._video_kwargs(),
        )

    def setup(self, stage: str) -> None:
        """
        Build the datasets of ``stage``.

        Args:
            stage: "fit" (sets the train and validation datasets) or "test" (sets the test dataset).

        Raises:
            ValueError: For any other stage, or when LeRobot testing has no configured source.
        """
        if stage == "fit":
            lerobot_enabled = self._lerobot_enabled()
            egoinexo_enabled = bool(self.egoinexo and self.egoinexo.get("enabled", False))
            if lerobot_enabled and egoinexo_enabled:
                # Both families enabled: each gets half of the nominal epoch and
                # LeRobot is restricted to paired-camera sources.
                half_epoch = self.samples_per_epoch // 2
                lr_ds = self._build_lerobot_sampler(half_epoch, require_paired=True)
                ee_ds = self._build_egoinexo_sampler(half_epoch)
                self.train_dataset = CombinedSamplerDataset(
                    subsets=[lr_ds, ee_ds],
                    samples_per_epoch=self.samples_per_epoch,
                    sampling_strategy=self.sampling_strategy,
                )
                self.val_dataset = self.train_dataset
            elif lerobot_enabled:
                self.train_dataset = self._build_lerobot_sampler(
                    self.samples_per_epoch,
                    require_paired=self._lerobot_requires_pairing(),
                )
                self.val_dataset = self.train_dataset
            elif egoinexo_enabled:
                self.train_dataset = self._build_egoinexo_sampler(self.samples_per_epoch)
                self.val_dataset = self.train_dataset
            else:
                self.train_dataset = MultiSourceSamplerDataset(
                    data_root=self.data_root,
                    env_source=self.env_source,
                    split="train",
                    samples_per_epoch=self.samples_per_epoch,
                    sampling_strategy=self.sampling_strategy,
                    **self._video_kwargs(),
                )
                # Validation always uses the procgen test split, at 1/1000 of the epoch size
                self.val_dataset = MultiSourceSamplerDataset(
                    data_root=self.data_root,
                    env_source="procgen",
                    split="test",
                    samples_per_epoch=self.samples_per_epoch // 1000,
                    sampling_strategy=self.sampling_strategy,
                    color_aug=False,
                    **self._video_kwargs(),
                )
        elif stage == "test":
            if self._lerobot_enabled():
                cfg = self.lerobot or {}
                sources = cfg.get("sources", [])
                if not sources:
                    raise ValueError("LeRobot testing requires at least one entry in lerobot['sources']")
                # Only the first configured source is used for testing
                first = sources[0]
                # NOTE(review): the training sampler joins `root` with `repo_id` before passing it
                # on, whereas `root` is forwarded unchanged here — confirm which one is intended.
                self.test_dataset = LeRobotVideoDataset(
                    repo_id=first.get("repo_id"),
                    root=first.get("root"),
                    camera_key=first.get("camera_key"),
                    # Same camera/backend resolution as in training (per-source value, then default)
                    camera_index=first.get("camera_index", cfg.get("default_camera_index", None)),
                    video_backend=first.get("video_backend", cfg.get("default_video_backend", None)),
                    color_aug=False,
                    **self._video_kwargs(),
                )
            else:
                self.test_dataset = OriginalVideoDataset(
                    data_root=self.data_root,
                    env_source=self.env_source,
                    split="test",
                    color_aug=False,
                    **self._video_kwargs(),
                )
        else:
            raise ValueError(f"Invalid stage: {stage}")


class LeRobotVideoDataset(Dataset):
    """
    Episode-indexed dataset over one LeRobot repository.

    Item ``idx`` is a clip cut from episode ``idx``: ``{"videos": clip}`` for a single camera,
    or ``{"view1": {"videos": clip}, "view2": {"videos": paired_clip}}`` when a paired camera
    is configured.
    """

    def __init__(
            self,
            repo_id: str,
            root: str = None,
            camera_key: Optional[str] = None,
            paired_camera_key: Optional[str] = None,
            camera_index: Optional[int] = None,
            paired_camera_index: Optional[int] = None,
            video_backend: Optional[str] = None,
            padding: str = "repeat",
            randomize: bool = False,
            resolution: int = 256,
            num_frames: int = 16,
            output_format: str = "t h w c",
            color_aug: bool = True
    ) -> None:
        """
        Args:
            repo_id: LeRobot repository identifier.
            root: Dataset location forwarded to LeRobot.
            camera_key: Key of the main camera; takes precedence over ``camera_index``.
            paired_camera_key: Key of the second camera; takes precedence over ``paired_camera_index``.
            camera_index: Index into the metadata camera keys (first camera when unset).
            paired_camera_index: Index into the metadata camera keys for the second camera.
            video_backend: Forwarded to LeRobot when given.
            padding: How short clips are extended: "repeat", "zero" or "random"; anything else
                leaves the clip short.
            randomize: Start each clip at a random frame of the episode.
            resolution: Side length of the square output frames.
            num_frames: Number of frames per clip.
            output_format: einops axis pattern of the returned clips.
            color_aug: Apply a random brightness shift to every clip.
        """
        super().__init__()
        # Imported lazily so the module can be used without lerobot installed
        from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
        self.padding = padding
        self.randomize = randomize
        self.resolution = resolution
        self.num_frames = num_frames
        self.output_format = output_format
        self.color_aug = color_aug
        if exists(video_backend):
            self.dataset = LeRobotDataset(repo_id, root=root, video_backend=video_backend)
        else:
            self.dataset = LeRobotDataset(repo_id, root=root)
        self.meta = LeRobotDatasetMetadata(repo_id, root=root)
        print("All keys: ", self.meta.camera_keys)
        print(f"view1 camera_index: {camera_index}, view2 camera_index: {paired_camera_index}")
        if camera_key is not None:
            self.camera_key = camera_key
        elif exists(camera_index):
            self.camera_key = self.meta.camera_keys[int(camera_index)]
        else:
            self.camera_key = self.meta.camera_keys[0]
        if paired_camera_key is not None:
            self.paired_camera_key = paired_camera_key
        elif exists(paired_camera_index):
            self.paired_camera_key = self.meta.camera_keys[int(paired_camera_index)]
        else:
            self.paired_camera_key = None
        print("view1 camera_key: {}, view2 camera_key: {}\n".format(self.camera_key, self.paired_camera_key))
        episodes = self.meta.episodes
        # (first, one-past-last) global frame index of every episode
        self.ep_bounds: List[Tuple[int, int]] = [
            (episodes["dataset_from_index"][i], episodes["dataset_to_index"][i])
            for i in range(self.meta.total_episodes)
        ]

    def __len__(self) -> int:
        return len(self.ep_bounds)

    def _load_view(self, camera_key: str, start: int, end: int) -> Tensor:
        """Load frames ``[start, end)`` of one camera as a square "t h w c" clip at ``self.resolution``."""
        frames = [self.dataset[idx][camera_key] for idx in range(start, end)]
        video = torch.stack(frames)  # no / 255.0: per the original author, already in [0, 1]
        video = rearrange(video, "t c h w -> t h w c")
        if video.shape[1] != video.shape[2]:
            # Center-crop to a square
            square_len = min(video.shape[1], video.shape[2])
            h_crop = (video.shape[1] - square_len) // 2
            w_crop = (video.shape[2] - square_len) // 2
            video = video[:, h_crop:h_crop + square_len, w_crop:w_crop + square_len]
        if video.shape[-2] != self.resolution or video.shape[-3] != self.resolution:
            # interpolate resizes the two trailing axes, hence the temporary "c t h w" layout
            video = rearrange(video, "t h w c -> c t h w")
            video = F.interpolate(video, self.resolution, mode="bicubic")
            video = rearrange(video, "c t h w -> t h w c")
        return video

    def _to_output_format(self, video: Tensor) -> Tensor:
        """Rearrange a "t h w c" clip into ``self.output_format``."""
        return rearrange(video, f"t h w c -> {self.output_format}")

    def _load_frames(self, start: int, end: int) -> Tensor:
        """Load frames ``[start, end)`` of the main camera, arranged as ``self.output_format``."""
        return self._to_output_format(self._load_view(self.camera_key, start, end))

    def _load_paired_frames(self, start: int, end: int) -> Optional[Tensor]:
        """Same as ``_load_frames`` for the paired camera; ``None`` when no paired camera is set."""
        if not exists(self.paired_camera_key):
            return None
        return self._to_output_format(self._load_view(self.paired_camera_key, start, end))

    def _pad_time(self, video: Tensor, pad_len: int) -> Tensor:
        """Append ``pad_len`` frames along axis 0 according to ``self.padding``."""
        if self.padding == "repeat":
            pad = video[-1:].repeat(pad_len, 1, 1, 1)
        elif self.padding == "zero":
            pad = torch.zeros_like(video[-1:]).repeat(pad_len, 1, 1, 1)
        elif self.padding == "random":
            pad = torch.rand_like(video[-1:]).repeat(pad_len, 1, 1, 1)
        else:
            # Any other mode leaves the clip short
            return video
        return torch.cat([video, pad], dim=0)

    def __getitem__(self, idx: int) -> Dict:
        """
        Return a clip from episode ``idx``.

        Loading is attempted up to 5 times; after each failure the clip end is pulled in by
        one more frame, and once the window cannot shrink further another episode is drawn.
        If every attempt fails, an all-zero clip is returned.
        """
        shrink = 0  # frames removed from the clip end by previous failed attempts
        for _ in range(5):
            start, end = self.ep_bounds[idx]
            total = end - start
            s = randint(0, max(0, total - self.num_frames)) if self.randomize else 0
            e = min(total, s + self.num_frames)
            # clamp to dataset global num_frames to avoid backend invalid frame index
            if hasattr(self.dataset, "num_frames"):
                max_len = max(0, self.dataset.num_frames - start)
                e = min(e, max_len)
            # The window is recomputed on every attempt, so the shrink has to be re-applied here
            e -= shrink
            try:
                video = self._load_view(self.camera_key, start + s, start + e)
                paired = None
                if exists(self.paired_camera_key):
                    paired = self._load_view(self.paired_camera_key, start + s, start + e)
            except Exception:
                # shrink end to avoid out-of-bound frames; if too small, resample episode
                if e - s > 1:
                    shrink += 1
                else:
                    idx = randint(0, len(self) - 1)
                    shrink = 0
                continue

            # Pad while the clip is still time-first, so any output_format is handled correctly
            pad_len = self.num_frames - (e - s)
            if pad_len > 0:
                video = self._pad_time(video, pad_len)
                if paired is not None:
                    paired = self._pad_time(paired, pad_len)
            if self.color_aug:
                # Brightness jitter, drawn independently for each view
                video = (video + torch.rand(1) * 0.2 - 0.1).clamp(0, 1)
                if paired is not None:
                    paired = (paired + torch.rand(1) * 0.2 - 0.1).clamp(0, 1)

            video = self._to_output_format(video)
            if paired is None:
                return {"videos": video}
            return {"view1": {"videos": video}, "view2": {"videos": self._to_output_format(paired)}}

        # If all attempts failed, return a zero tensor to avoid crashing the loader
        # NOTE(review): with a paired camera, real samples use "view1"/"view2" keys while this
        # fallback does not — confirm that the collate function in use tolerates the mismatch.
        dummy = torch.zeros(self.num_frames, self.resolution, self.resolution, 3)
        return {"videos": self._to_output_format(dummy)}


class MultiLeRobotSamplerDataset(Dataset):
    """
    Fixed-length dataset that draws random items from several LeRobotVideoDataset subsets.

    One LeRobotVideoDataset is built per entry of ``specs``. ``__getitem__`` ignores the
    requested index: it picks a subset according to ``sample_probs`` and returns a
    uniformly random item of that subset.

    Args:
        specs: One dict per subset. The keys read are ``repo_id``, ``root``,
            ``camera_key``, ``camera_index``, ``paired_camera_key``,
            ``paired_camera_index`` and ``video_backend``. Each subset is loaded from
            ``root/repo_id``.
        samples_per_epoch: Value reported by ``__len__``.
        sampling_strategy: How subsets are weighted. ``"sample"`` weights by subset
            length, ``"dataset"`` weights every non-empty subset equally, ``"log"``
            weights by ``log(length)`` and ``"pi"`` by ``length ** 0.43``.
        require_paired: If True, keep only subsets whose ``paired_camera_key`` is set.
        **kwargs: ``default_camera_index``, ``default_paired_camera_index`` (``-1`` is
            treated as ``None``) and ``default_video_backend`` are consumed here and used
            as fallbacks for specs that omit the corresponding key. All remaining
            keyword arguments are forwarded to every LeRobotVideoDataset.

    Raises:
        ValueError: If ``require_paired`` leaves no subset, if ``sampling_strategy`` is
            unknown, or if the chosen strategy gives every subset zero weight (for
            example when ``specs`` is empty or all subsets are empty).
    """

    def __init__(
            self,
            specs: List[Dict],
            samples_per_epoch: int = 1000000,
            sampling_strategy: str = "sample",
            require_paired: bool = False,
            **kwargs
    ) -> None:
        super().__init__()
        self.samples_per_epoch = samples_per_epoch
        default_cam = kwargs.pop("default_camera_index", None)
        default_paired = kwargs.pop("default_paired_camera_index", None)
        # -1 is the "no paired camera" sentinel for the default paired index.
        default_paired = default_paired if default_paired != -1 else None
        default_backend = kwargs.pop("default_video_backend", None)
        self.subsets: List["LeRobotVideoDataset"] = [
            LeRobotVideoDataset(
                repo_id=s.get("repo_id"),
                root=path.join(s.get("root"), s.get("repo_id")),
                camera_key=s.get("camera_key"),
                camera_index=s.get("camera_index", default_cam),
                paired_camera_key=s.get("paired_camera_key"),
                paired_camera_index=s.get("paired_camera_index", default_paired),
                video_backend=s.get("video_backend", default_backend),
                **kwargs
            )
            for s in specs
        ]
        if require_paired:
            self.subsets = [d for d in self.subsets if d.paired_camera_key is not None]
            if len(self.subsets) == 0:
                raise ValueError("No paired-camera LeRobot datasets available; please set paired_camera_index/key or disable multiview training.")
        weights = self._strategy_weights([len(d) for d in self.subsets], sampling_strategy)
        total_weight = sum(weights)
        if total_weight <= 0:
            # Fail at construction instead of dividing by zero here or crashing later
            # inside __getitem__ when nothing can be sampled.
            raise ValueError(
                f"Sampling strategy '{sampling_strategy}' assigns zero total weight to the "
                f"{len(self.subsets)} available subset(s); nothing can be sampled."
            )
        self.sample_probs = [w / total_weight for w in weights]

    @staticmethod
    def _strategy_weights(sizes: List[int], strategy: str) -> List[float]:
        """Return one unnormalised sampling weight per subset size for ``strategy``."""
        if strategy == "sample":
            # Every individual sample is equally likely.
            return list(sizes)
        if strategy == "dataset":
            # Every subset is equally likely; empty subsets get zero weight because
            # drawing from them would fail in randint(0, -1).
            return [1 if n else 0 for n in sizes]
        if strategy == "log":
            # NOTE: log(1) == 0, so subsets with a single item are never drawn.
            return [math.log(n) if n else 0 for n in sizes]
        if strategy == "pi":
            return [n ** 0.43 for n in sizes]
        raise ValueError(f"Unavailable sampling strategy: {strategy}")

    def __len__(self) -> int:
        """Return the nominal epoch length, not the number of underlying items."""
        return self.samples_per_epoch

    def __getitem__(self, idx: int) -> Dict:
        """Return a random item from a randomly chosen subset; ``idx`` is ignored."""
        subset = choices(self.subsets, self.sample_probs)[0]
        sample_idx = randint(0, len(subset) - 1)
        return subset[sample_idx]


class EgoExoPairedVideoDataset(Dataset):
    """
    Dataset of time-aligned exo/ego video clip pairs listed in a JSON manifest.

    The manifest at ``json_path`` is loaded once; each entry must provide
    ``exo_filename`` and ``ego_filename``, both resolved relative to ``root``.
    Every item reads the same frame window from both videos, centre-crops to a
    square, resizes to ``resolution`` and returns
    ``{"view1": {"videos": ...}, "view2": {"videos": ...}}``.

    Args:
        root: Directory the manifest file names are joined onto.
        json_path: Path of the JSON manifest.
        padding: How clips shorter than ``num_frames`` are extended: ``"repeat"``
            repeats the last frame, ``"zero"`` appends zero frames, ``"random"``
            appends one uniform-noise frame repeated. Any other value behaves like
            ``"repeat"``.
        randomize: If truthy, the window start is drawn at random; otherwise it is 0.
        resolution: Side length of the square output frames.
        num_frames: Number of frames per returned clip.
        output_format: einops axis pattern of the returned tensors, expressed in
            terms of ``t h w c``.
        color_aug: If True, add one random offset in [-0.1, 0.1) per clip and clamp
            to [0, 1].
        view1_key: ``"exo"`` maps the exo clip to ``view1``; any other value maps the
            ego clip to ``view1``.
        view2_key: Stored on the instance; not read by this class.
    """

    # Temporal stride used by __getitem__; matches the default of _load_video_slice.
    _FRAME_SKIP = 2

    def __init__(
        self,
        root: str,
        json_path: str,
        padding: str = "repeat",
        randomize: bool = True,
        resolution: int = 256,
        num_frames: int = 16,
        output_format: str = "t h w c",
        color_aug: bool = True,
        view1_key: str = "exo",
        view2_key: str = "ego",
    ) -> None:
        super().__init__()
        self.root = root
        self.padding = padding
        self.randomize = randomize
        self.resolution = resolution
        self.num_frames = num_frames
        self.output_format = output_format
        self.color_aug = color_aug
        self.view1_key = view1_key
        self.view2_key = view2_key
        with open(json_path, "r") as f:
            self.pairs = json.load(f)

    def __len__(self) -> int:
        """Return the number of clip pairs in the manifest."""
        return len(self.pairs)

    @staticmethod
    def _count_frames(video_path: str) -> int:
        """Return the frame count OpenCV reports for ``video_path``."""
        cap = cv.VideoCapture(video_path)
        try:
            return int(cap.get(cv.CAP_PROP_FRAME_COUNT))
        finally:
            # Release the handle even if the property query raises.
            cap.release()

    @staticmethod
    def _pad_time(video: Tensor, num_frames: int, mode: str = "repeat") -> Tensor:
        """Extend a ``t h w c`` clip along dim 0 until it has ``num_frames`` frames."""
        missing = num_frames - video.shape[0]
        if missing <= 0:
            return video
        last = video[-1:]
        if mode == "zero":
            filler = torch.zeros_like(last)
        elif mode == "random":
            filler = torch.rand_like(last)
        else:
            # "repeat" and unrecognised modes both repeat the last frame.
            filler = last
        return torch.cat([video, filler.repeat(missing, 1, 1, 1)], dim=0)

    def _read_clip(self, video_path: str, num_frames: int, start_frame: int = None, frame_skip: int = 2) -> Tensor:
        """
        Read up to ``num_frames`` frames (every ``frame_skip``-th) as a float ``t h w c`` clip.

        The clip is scaled to [0, 1] before resizing, centre-cropped to a square and
        resized to ``self.resolution``. It may hold fewer than ``num_frames`` frames when
        the video ends early. If no frame can be read at all, an all-zero clip of
        ``self.num_frames`` frames is returned instead.
        """
        span = num_frames * frame_skip
        frames = []
        cap = cv.VideoCapture(video_path)
        try:
            total_frames = int(cap.get(cv.CAP_PROP_FRAME_COUNT))
            if start_frame is None:
                start_frame = randint(0, max(0, total_frames - span))
            cap.set(cv.CAP_PROP_POS_FRAMES, start_frame)
            for _ in range(span):
                ok, frame = cap.read()
                if not ok:
                    break
                # OpenCV decodes to BGR; convert to RGB.
                frames.append(torch.from_numpy(cv.cvtColor(frame, cv.COLOR_BGR2RGB)))
        finally:
            # Release the capture even when decoding or conversion fails.
            cap.release()
        if not frames:
            return torch.zeros(self.num_frames, self.resolution, self.resolution, 3)
        video = torch.stack(frames[::frame_skip]).float() / 255.0
        height, width = video.shape[1], video.shape[2]
        if height != width:
            side = min(height, width)
            top = (height - side) // 2
            left = (width - side) // 2
            video = video[:, top:top + side, left:left + side]
        if video.shape[1] != self.resolution or video.shape[2] != self.resolution:
            # F.interpolate resizes the last two axes of a 4-D tensor, so move h and w there.
            video = rearrange(video, "t h w c -> c t h w")
            video = F.interpolate(video, self.resolution, mode="bicubic")
            video = rearrange(video, "c t h w -> t h w c")
        return video

    def _load_video_slice(self, video_path: str, num_frames: int, start_frame: int = None, frame_skip: int = 2) -> Tensor:
        """
        Load a clip and return it in ``self.output_format``.

        ``start_frame=None`` draws a random start such that the whole window of
        ``num_frames * frame_skip`` frames fits when the video is long enough. The
        result is not padded and may contain fewer than ``num_frames`` frames.
        """
        clip = self._read_clip(video_path, num_frames, start_frame, frame_skip)
        return rearrange(clip, f"t h w c -> {self.output_format}")

    def __getitem__(self, idx: int) -> Dict:
        """Return the ``idx``-th exo/ego pair as two clips read from the same frame window."""
        entry = self.pairs[idx]
        exo_path = path.join(self.root, entry["exo_filename"])
        ego_path = path.join(self.root, entry["ego_filename"])
        shortest = min(self._count_frames(exo_path), self._count_frames(ego_path))
        # The loader consumes num_frames * frame_skip source frames, so the last valid
        # start must leave room for that whole window, not just num_frames.
        window = self.num_frames * self._FRAME_SKIP
        max_start = max(0, shortest - window)
        start = randint(0, max_start) if self.randomize else 0

        clips = []
        for video_path in (exo_path, ego_path):
            clip = self._read_clip(video_path, self.num_frames, start, self._FRAME_SKIP)
            # Pad while the clip is still "t h w c" so the time axis is always dim 0,
            # whatever output_format is.
            clip = self._pad_time(clip, self.num_frames, self.padding)
            if self.color_aug:
                clip = (clip + torch.rand(1) * 0.2 - 0.1).clamp(0, 1)
            clips.append(rearrange(clip, f"t h w c -> {self.output_format}"))
        v_exo, v_ego = clips

        if self.view1_key == "exo":
            return {"view1": {"videos": v_exo}, "view2": {"videos": v_ego}}
        return {"view1": {"videos": v_ego}, "view2": {"videos": v_exo}}


class MultiEgoExoSamplerDataset(Dataset):
    """
    Fixed-length random sampler over a single EgoExoPairedVideoDataset.

    The wrapped dataset is built from the constructor arguments and exposed as
    ``self.dataset``. ``__len__`` reports ``samples_per_epoch`` and ``__getitem__``
    ignores its index, returning a uniformly random item of the wrapped dataset.

    ``sampling_strategy`` is only stored on the instance; no method of this class
    reads it.
    """

    def __init__(
        self,
        root: str,
        json_path: str,
        samples_per_epoch: int = 1000000,
        sampling_strategy: str = "sample",
        padding: str = "repeat",
        randomize: bool = True,
        resolution: int = 256,
        num_frames: int = 16,
        output_format: str = "t h w c",
        color_aug: bool = True,
        view1_key: str = "exo",
        view2_key: str = "ego",
    ) -> None:
        self.samples_per_epoch = samples_per_epoch
        # Everything except the two sampler-level options goes to the wrapped dataset.
        paired_kwargs = dict(
            root=root,
            json_path=json_path,
            padding=padding,
            randomize=randomize,
            resolution=resolution,
            num_frames=num_frames,
            output_format=output_format,
            color_aug=color_aug,
            view1_key=view1_key,
            view2_key=view2_key,
        )
        self.dataset = EgoExoPairedVideoDataset(**paired_kwargs)
        self.sampling_strategy = sampling_strategy

    def __len__(self) -> int:
        """Return the nominal epoch length, not the size of the wrapped dataset."""
        return self.samples_per_epoch

    def __getitem__(self, idx: int) -> Dict:
        """Return a uniformly random item of the wrapped dataset; ``idx`` is ignored."""
        last_index = len(self.dataset) - 1
        return self.dataset[randint(0, last_index)]


class CombinedSamplerDataset(Dataset):
    """
    Fixed-length dataset that draws random items from several already-built subsets.

    ``__getitem__`` ignores the requested index: it picks a subset according to
    ``sample_probs`` and returns a uniformly random item of that subset.

    Args:
        subsets: Datasets to mix; each must support ``len()`` and integer indexing.
        samples_per_epoch: Value reported by ``__len__``.
        sampling_strategy: How subsets are weighted. ``"sample"`` weights by subset
            length, ``"dataset"`` weights every non-empty subset equally, ``"log"``
            weights by ``log(length)`` and ``"pi"`` by ``length ** 0.43``.

    Raises:
        ValueError: If ``sampling_strategy`` is unknown, or if the chosen strategy gives
            every subset zero weight (for example when ``subsets`` is empty).
    """

    def __init__(
        self,
        subsets: List[Dataset],
        samples_per_epoch: int = 1000000,
        sampling_strategy: str = "sample",
    ) -> None:
        super().__init__()
        self.samples_per_epoch = samples_per_epoch
        self.subsets = subsets
        weights = self._strategy_weights([len(d) for d in self.subsets], sampling_strategy)
        total_weight = sum(weights)
        if total_weight <= 0:
            # Explicit exception rather than assert: asserts vanish under `python -O`.
            raise ValueError(
                f"Sampling strategy '{sampling_strategy}' assigns zero total weight to the "
                f"{len(self.subsets)} given subset(s); nothing can be sampled."
            )
        self.sample_probs = [w / total_weight for w in weights]

    @staticmethod
    def _strategy_weights(sizes: List[int], strategy: str) -> List[float]:
        """Return one unnormalised sampling weight per subset size for ``strategy``."""
        if strategy == "sample":
            # Every individual sample is equally likely.
            return list(sizes)
        if strategy == "dataset":
            # Every subset is equally likely; empty subsets get zero weight because
            # drawing from them would fail in randint(0, -1).
            return [1 if n else 0 for n in sizes]
        if strategy == "log":
            # NOTE: log(1) == 0, so subsets with a single item are never drawn.
            return [math.log(n) if n else 0 for n in sizes]
        if strategy == "pi":
            return [n ** 0.43 for n in sizes]
        raise ValueError(f"Unavailable sampling strategy: {strategy}")

    def __len__(self) -> int:
        """Return the nominal epoch length, not the number of underlying items."""
        return self.samples_per_epoch

    def __getitem__(self, idx: int) -> Dict:
        """Return a random item from a randomly chosen subset; ``idx`` is ignored."""
        subset = choices(self.subsets, self.sample_probs)[0]
        sample_idx = randint(0, len(subset) - 1)
        return subset[sample_idx]
