import torch 

def cat_tx(t, x):
    """Append the scalar ``t`` to ``x`` as one extra trailing feature.

    Args:
        t: Number accepted by ``torch.scalar_tensor``.
        x: Tensor of any rank; features live on the last dimension.

    Returns:
        Tensor whose leading dimensions match ``x`` and whose last dimension is
        one larger; the final slot holds ``t`` for every leading index.
    """
    # torch.hstack joins along dim 1 for inputs of rank >= 2, which breaks for
    # x of rank 3 or more. Promote 0-D input to 1-D (as hstack did) and join
    # on the last axis so every rank is handled uniformly.
    x = torch.atleast_1d(x)
    t_col = torch.scalar_tensor(t).to(x.device).expand(*x.shape[:-1], 1)
    return torch.cat([x, t_col], dim=-1)

def cat_tx2(t, x):
    """Append the tensor ``t`` to ``x`` as one extra trailing feature.

    Args:
        t: Tensor that can be expanded to ``(*x.shape[:-1], 1)`` (for example
            a 0-D tensor, or one value per row shaped ``(N, 1)``).
        x: Tensor of any rank; features live on the last dimension.

    Returns:
        Tensor whose leading dimensions match ``x`` and whose last dimension is
        one larger; the final slot holds the expanded ``t``.
    """
    # torch.hstack joins along dim 1 for inputs of rank >= 2, which breaks for
    # x of rank 3 or more. Promote 0-D input to 1-D (as hstack did) and join
    # on the last axis so every rank is handled uniformly.
    x = torch.atleast_1d(x)
    t_col = t.to(x.device).expand(*x.shape[:-1], 1)
    return torch.cat([x, t_col], dim=-1)

def cat_stx(s, t, x):
    """Append the scalars ``t`` and ``s`` (in that order) as trailing features of ``x``.

    Args:
        s: Number accepted by ``torch.scalar_tensor``; becomes the last column.
        t: Number accepted by ``torch.scalar_tensor``; becomes the second-to-last
            column.
        x: Tensor of any rank; features live on the last dimension.

    Returns:
        Tensor whose leading dimensions match ``x`` and whose last dimension is
        two larger, laid out as ``[x, t, s]``.
    """
    # torch.hstack joins along dim 1 for inputs of rank >= 2, which breaks for
    # x of rank 3 or more. Promote 0-D input to 1-D (as hstack did) and join
    # on the last axis so every rank is handled uniformly.
    x = torch.atleast_1d(x)
    col_shape = (*x.shape[:-1], 1)
    t_col = torch.scalar_tensor(t).to(x.device).expand(col_shape)
    s_col = torch.scalar_tensor(s).to(x.device).expand(col_shape)
    return torch.cat([x, t_col, s_col], dim=-1)

def pad_zeros_upfi(x):
    """Prepend a zero slice to ``x`` along its last dimension.

    The zero slice is built with ``zeros_like`` from ``x[..., :1]``, so it
    shares the dtype and device of ``x``.
    """
    zero_head = torch.zeros_like(x[..., :1])
    return torch.cat((zero_head, x), dim=-1)

def sample_batch(X, batch_size, add_noise = False, noise_level = 0.01, replacement = False):
    """Draw a random batch of rows from every tensor in ``X`` and stack them.

    Args:
        X: Iterable of tensors; rows are selected along dimension 0 of each.
        batch_size: Number of rows requested per tensor.
        add_noise: When true, Gaussian noise scaled by ``noise_level`` is added
            to the sampled batch.
        noise_level: Multiplier applied to the standard-normal noise.
        replacement: When true, row indices come from ``torch.randint`` (repeats
            possible); otherwise from the first ``batch_size`` entries of a
            random permutation, so at most ``x.shape[0]`` rows are returned.

    Returns:
        Tensor with one leading entry per element of ``X``, each holding the
        rows sampled from that element.
    """
    picked = []
    for x in X:
        n_rows = x.shape[0]
        if replacement:
            idx = torch.randint(n_rows, (batch_size, ))
        else:
            idx = torch.randperm(n_rows)[:batch_size]
        # Add a leading axis so the per-tensor samples concatenate into a batch.
        picked.append(x[idx, :].unsqueeze(0))
    X_batch = torch.cat(picked)
    if add_noise:
        noise = torch.randn_like(X_batch) * noise_level
        X_batch += noise
    return X_batch

def sample_batch_upfi(X, m_ratios, batch_size, **kwargs):
    """Sample a batch from ``X`` and prepend a ``log(m_ratios / batch_size)`` feature.

    Args:
        X: Passed through to ``sample_batch``.
        m_ratios: Tensor with one value per element of ``X``
            (NOTE(review): presumably relative masses per snapshot -- confirm).
        batch_size: Rows sampled per element of ``X``; also the divisor inside
            the logarithm.
        **kwargs: Forwarded unchanged to ``sample_batch``.

    Returns:
        The sampled batch with one extra leading feature on dimension 2 that
        holds ``log(m_ratios / batch_size)``, repeated for every sampled row.
    """
    log_weight = torch.log(m_ratios / batch_size)
    # One copy of the log-weight per sampled row, shaped to line up with the batch.
    weight_feature = log_weight.unsqueeze(1).tile((1, batch_size)).unsqueeze(2)
    samples = sample_batch(X, batch_size, **kwargs)
    return torch.cat([weight_feature, samples], dim = 2)

def get_flow(v, s, D):
    """Return ``v - (D / 2) * s``.

    NOTE(review): presumably ``v`` is a drift/velocity, ``s`` a score and ``D``
    a diffusivity -- confirm against callers.
    """
    half_D = D / 2
    correction = half_D * s
    return v - correction

def _cos_dist_pair(u, v):
    """Return half of (1 - cosine similarity) for a single pair of 1-D vectors."""
    cosine = torch.dot(u, v) / (u.norm() * v.norm())
    return 0.5 * (1 - cosine)


# Batched cosine distance: maps the pairwise function over the leading
# dimension of both arguments.
cos_dist = torch.vmap(_cos_dist_pair)

def _l2_dist_pair(u, v):
    """Return the squared norm of ``u - v`` for a single pair of tensors."""
    diff = u - v
    return diff.norm() ** 2


# Batched squared L2 distance: maps the pairwise function over the leading
# dimension of both arguments.
l2_dist = torch.vmap(_l2_dist_pair)

