import numpy as np
import torch

from deptheval.utils.constants import VALID_DEPTH_LB, VALID_DEPTH_UB
from deptheval.utils.torch import reformat_as_torch_tensor


def get_depth_valid(depth, valid_depth_lb=VALID_DEPTH_LB, valid_depth_ub=VALID_DEPTH_UB):
    """Return an element-wise mask of the usable entries of ``depth``.

    An entry is usable when it is finite (not NaN and not +/-inf) and lies in
    the closed interval ``[valid_depth_lb, valid_depth_ub]``.

    Args:
        depth: ``np.ndarray`` or ``torch.Tensor`` of depth values.
        valid_depth_lb: inclusive lower bound on accepted depth.
        valid_depth_ub: inclusive upper bound on accepted depth.

    Returns:
        Boolean mask of the same library type as ``depth`` (NumPy array for
        NumPy input, torch tensor for torch input).

    Raises:
        ValueError: if ``depth`` is neither a NumPy array nor a torch tensor.
    """
    # Pick the backend-specific finiteness test; everything else is shared.
    if isinstance(depth, np.ndarray):
        finite = np.isfinite(depth)
    elif isinstance(depth, torch.Tensor):
        finite = torch.isfinite(depth)
    else:
        raise ValueError(f'{type(depth)=}')

    within_bounds = (depth >= valid_depth_lb) & (depth <= valid_depth_ub)
    return finite & within_bounds


def load_data(depth_f, as_torch=False, *, invalid_fill=1):
    """Load a depth map, camera intrinsics and validity mask from ``depth_f``.

    The file is read with ``np.load`` and must provide the keys ``'depth'``,
    ``'intr'`` and ``'valid'``. Entries of ``depth`` where ``valid`` is false
    are overwritten with ``invalid_fill`` (in place on the loaded array).

    Args:
        depth_f: path or file-like object accepted by ``np.load``.
        as_torch: if True, convert all three outputs with
            ``reformat_as_torch_tensor`` before returning.
        invalid_fill: value written into invalid depth entries (default ``1``).

    Returns:
        Tuple ``(depth, intr, valid)``. ``valid`` keeps the dtype stored in
        the file.
    """
    data = np.load(depth_f)
    try:
        depth, intr, valid = data['depth'], data['intr'], data['valid']
    finally:
        # For .npz input np.load returns an NpzFile that keeps the archive
        # open until closed; release it now instead of relying on GC. A plain
        # ndarray (the .npy case) has no close() and needs no cleanup.
        close = getattr(data, 'close', None)
        if close is not None:
            close()

    # Index with a boolean view of the mask: if it was stored as an integer
    # type (e.g. uint8 0/1), ``~valid`` would be a bitwise NOT (254/255) and
    # the assignment would silently become integer fancy-indexing.
    depth[~valid.astype(bool, copy=False)] = invalid_fill

    if as_torch:
        depth = reformat_as_torch_tensor(depth)
        intr = reformat_as_torch_tensor(intr)
        valid = reformat_as_torch_tensor(valid)
    return depth, intr, valid
