import os
import subprocess

from torch import nn
from collections import OrderedDict
from einops import rearrange
from torch.utils.data import random_split

def exists(val):
    """Report whether ``val`` holds anything other than ``None``.

    Falsy values such as ``0``, ``""`` or ``False`` still count as existing;
    only ``None`` is treated as missing.
    """
    return not (val is None)

def default(val, d, func=None):
    """Return ``val`` (optionally transformed), or the fallback ``d`` if ``val`` is None.

    Args:
        val: The candidate value; ``None`` means "not provided".
        d: Fallback returned unchanged when ``val`` is ``None``.
        func: Optional callable applied to ``val`` when ``val`` is not ``None``.
            The fallback ``d`` is never passed through ``func``.
    """
    # Guard clause: a missing value short-circuits to the fallback regardless of func.
    if not exists(val):
        return d
    if exists(func):
        return func(val)
    return val

def pair(t):
    """Normalize ``t`` to a tuple: tuples pass through, anything else becomes ``(t, t)``.

    Only ``tuple`` instances pass through; lists and other sequences are duplicated.
    """
    if isinstance(t, tuple):
        return t
    return (t, t)

def nn_sequential(**args):
    """Build an ``nn.Sequential`` whose submodules are named by the keyword arguments.

    Keyword-argument order is preserved, so submodules are registered in the
    order they were passed. Each module is reachable as an attribute named
    after its keyword.
    """
    named_layers = OrderedDict()
    for layer_name, layer in args.items():
        named_layers[layer_name] = layer
    return nn.Sequential(named_layers)

def fold(*xs):
    """Merge the two leading axes ``s`` and ``b`` of every input into one ``(s b)`` axis.

    Returns a tuple of the folded inputs followed by the size of axis ``b``
    (presumably the batch size), read from ``xs[0].shape[1]``. The einops
    pattern is the inverse of the one used by ``unfold``.

    Raises:
        IndexError: if called with no inputs, since there is no ``xs[0]``.
    """
    # Rearrange first, then read the size, matching the original evaluation order.
    folded = [rearrange(x, 's b ... -> (s b) ...') for x in xs]
    second_axis = xs[0].shape[1]
    return (*folded, second_axis)

def unfold(*xs, b):
    """Split the leading ``(s b)`` axis of every input back into two axes ``s`` and ``b``.

    Args:
        *xs: Inputs whose first axis has length divisible by ``b``.
        b: Size of the recovered second axis (keyword-only).

    Returns:
        The single unfolded input when exactly one is given, otherwise a list
        of unfolded inputs (an empty list when none are given).
    """
    restored = []
    for tensor in xs:
        restored.append(rearrange(tensor, '(s b) ... -> s b ...', b=b))
    if len(restored) != 1:
        return restored
    return restored[0]

def shuffle_dataset(dataset, *, generator=None):
    """Return a view of ``dataset`` whose items are visited in a random order.

    The result is a ``torch.utils.data.Subset`` over a random permutation of
    all indices. Every item appears exactly once, and the underlying dataset is
    referenced rather than copied.

    Args:
        dataset: Any object supporting ``len()`` and integer indexing.
        generator: Optional ``torch.Generator`` for a reproducible order.
            Defaults to torch's global RNG, as before.

    Returns:
        A ``Subset`` of ``dataset`` containing every index once, shuffled.
    """
    import torch
    from torch.utils.data import Subset

    # This previously used random_split(dataset, [n, 0]) and discarded the
    # empty second split. For n == 1 the lengths [1, 0] sum to 1, so
    # random_split interprets them as fractions and warns that split 1 is
    # empty. Drawing the permutation directly avoids that misuse.
    indices = torch.randperm(len(dataset), generator=generator).tolist()
    return Subset(dataset, indices)

