import torch
from torch.distributions.beta import Beta
from einops import repeat

def get_output_dims(modes):
    """Return a dict assigning an output dimension of 1 to every mode in ``modes``."""
    return {mode: 1 for mode in modes}

def get_input_dims(modes):
    """Return a dict assigning an input dimension of 1 to every mode in ``modes``."""
    return {mode: 1 for mode in modes}

def get_out_bias(mode):
    """Return the output bias for ``mode``; the same value (0.0) is used for every mode."""
    bias = 0.0  # independent of ``mode``
    return bias

def get_input_range(mode):
    """Return the input range for ``mode``; the same value (1.0) is used for every mode."""
    value_range = 1.0  # independent of ``mode``
    return value_range

def sample_nmr(xs: torch.Tensor, ys: torch.Tensor, num_samples: int):
    """Randomly pick ``num_samples`` points per batch row, without replacement.

    For each batch row an independent random permutation of the N points is
    drawn; the first ``num_samples`` indices of that permutation are kept and
    used to gather matching rows from ``xs`` and ``ys``.

    Args:
        xs: Tensor of shape (B, N, Dx).
        ys: Tensor of shape (B, N, Dy), indexed with the same permutation as
            ``xs`` (it must be 3-D for ``torch.gather`` along dim 1).
        num_samples: Number of points to keep per row; for
            ``0 <= num_samples`` the effective count is ``S = min(num_samples, N)``.

    Returns:
        x_nmr: (B, S, Dx) selected rows of ``xs``.
        y_nmr: (B, S, Dy) matching rows of ``ys``.
        ids_restore: (B, N) inverse of the full per-row permutation, so
            ``ids_restore[b, ids_sample[b, j]] == j``.
        ids_sample: (B, S) indices along dim 1 of ``xs``/``ys`` that were kept.
    """
    B, N, Dx = xs.shape
    Dy = ys.shape[-1]
    device = xs.device

    # argsort of i.i.d. uniform noise yields an independent random permutation per row.
    noise = torch.rand(B, N, device=device)
    ids_shuffle = torch.argsort(noise, dim=1)
    ids_restore = torch.argsort(ids_shuffle, dim=1)
    ids_sample = ids_shuffle[:, :num_samples]

    # Broadcast the index over the feature dim (view, no copy) so gather takes whole rows.
    index = ids_sample.unsqueeze(-1)
    x_nmr = torch.gather(xs, dim=1, index=index.expand(-1, -1, Dx))
    y_nmr = torch.gather(ys, dim=1, index=index.expand(-1, -1, Dy))

    return x_nmr, y_nmr, ids_restore, ids_sample


def generate_random_masks(input_tokens,
                          num_sample_min=0,
                          num_sample_max=None,
                          dist=None,
                          ):
    """Build a per-row prefix mask with a random number of unmasked positions.

    For every batch row a count ``k`` is drawn with
    ``num_sample_min <= k <= num_sample_max``. The first ``k`` positions of the
    row are 0 and the remaining ``N - k`` positions are 1. In the random case
    ``k = floor(u * (max + 1 - min) + min)`` with ``u ~ dist``, so the default
    ``Beta(1, 1)`` gives a uniform ``k`` over ``[min, max]``.

    Args:
        input_tokens: Tensor whose first two dims are (B, N); the returned mask
            takes its dtype and device.
        num_sample_min: Lower bound on ``k``; clamped to ``[0, num_sample_max]``.
        num_sample_max: Upper bound on ``k``; ``None`` means N, and values are
            clamped to ``[0, N]``.
        dist: Distribution sampled with ``dist.sample((B, 1))``; its samples are
            expected to lie in ``[0, 1]``. ``None`` (default) means ``Beta(1, 1)``.

    Returns:
        Tensor of shape (B, N), dtype/device of ``input_tokens``, with zeros on
        each row's first ``k`` positions and ones elsewhere.
    """
    if dist is None:
        # Built per call rather than as a shared import-time default object.
        dist = Beta(1.0, 1.0)

    B, N = input_tokens.shape[:2]
    device = input_tokens.device

    num_sample_max = max(0, min(N, N if num_sample_max is None else num_sample_max))
    # Keep min <= max so the sampling span below is never negative.
    num_sample_min = min(max(0, num_sample_min), num_sample_max)

    if num_sample_min == num_sample_max:
        # Exactly ``num_sample_max`` unmasked positions (no randomness needed).
        num_keep = torch.full((B, 1), num_sample_max, dtype=torch.long, device=device)
    else:
        u = dist.sample((B, 1)).to(device)
        span = num_sample_max + 1 - num_sample_min
        num_keep = (u * span + num_sample_min).floor().long()
        # Guards the u == 1.0 edge, which would otherwise give max + 1.
        num_keep = num_keep.clamp(num_sample_min, num_sample_max)

    # (1, N) positions broadcast against the (B, 1) counts.
    positions = torch.arange(N, device=device).unsqueeze(0)
    mask = (positions >= num_keep).long()
    return mask.to(input_tokens)

def to_device(data, device):
    '''
    Load data with arbitrary structure on device.
    from MTP

    Recursively walks tensors nested inside tuples (including namedtuples,
    whose type is preserved), lists and dicts, and returns a new structure
    with every tensor moved via ``Tensor.to(device)``.

    Raises:
        NotImplementedError: if a leaf is not a tensor or one of the
            supported containers.
    '''
    if isinstance(data, torch.Tensor):
        return data.to(device)
    if isinstance(data, tuple):
        moved = [to_device(item, device) for item in data]
        if hasattr(data, '_fields'):
            # namedtuple: rebuild the same type instead of degrading to tuple.
            return type(data)(*moved)
        return tuple(moved)
    if isinstance(data, list):
        return [to_device(item, device) for item in data]
    if isinstance(data, dict):
        return {key: to_device(value, device) for key, value in data.items()}
    raise NotImplementedError(
        f"to_device does not support objects of type {type(data).__name__}"
    )
