import numpy as np
import os
import torch
import SimpleITK as sitk
import nibabel as nib
import torch.nn as nn
import pandas as pd
import glob
from .metrics import compute_dice, compute_hd95
import GPUtil
import torch.nn.functional as F

def gradient(input_coords, output, grad_outputs=None):
    """Compute the gradient of ``output`` w.r.t. ``input_coords`` with autograd.

    Args:
        input_coords: tensor that ``output`` was computed from; it must require grad.
        output: tensor to differentiate.
        grad_outputs: vector of the vector-Jacobian product. When None (the default), a tensor
            of ones shaped like ``output`` is used, i.e. the gradient of ``output.sum()``.

    Returns:
        The gradient, shaped like ``input_coords``. The graph is kept (``create_graph=True``)
        so the result can itself be differentiated, e.g. for Jacobian-based losses.
    """
    # Previously the argument was unconditionally overwritten, so callers could not pass weights.
    if grad_outputs is None:
        grad_outputs = torch.ones_like(output)
    grad = torch.autograd.grad(
        output, [input_coords], grad_outputs=grad_outputs, create_graph=True
    )[0]
    return grad

def compute_jacobian_matrix_3d(input_coords, output, add_identity=True):
    """Compute the per-point 3x3 Jacobian of ``output`` w.r.t. ``input_coords``.

    Row ``i`` of each matrix is the gradient of ``output[:, i]`` w.r.t. the coordinates.

    Args:
        input_coords: coordinates that require grad; presumably shape (N, 3) — confirm with callers.
        output: tensor computed from ``input_coords`` whose first three columns are differentiated.
        add_identity: when True, 1 is added to each diagonal entry.

    Returns:
        Tensor of shape (N, 3, 3), allocated on the device and with the dtype of ``input_coords``
        so that CUDA inputs do not force a device-to-host copy inside the autograd graph.
    """
    jacobian_matrix = torch.zeros(
        input_coords.shape[0], 3, 3,
        device=input_coords.device, dtype=input_coords.dtype,
    )
    for i in range(3):
        jacobian_matrix[:, i, :] = gradient(input_coords, output[:, i])
        if add_identity:
            jacobian_matrix[:, i, i] += 1
    return jacobian_matrix

def compute_deformation_regularity(network, masked_coords, output_shape=None, device='cuda'):
    """Return the percentage of points whose Jacobian determinant is negative (folded).

    The Jacobian of ``network(coords, 0)`` w.r.t. the coordinates is computed with autograd,
    with the identity added (``add_identity=True``), in chunks of 50 000 points. Results are
    gathered on the CPU.

    Args:
        network: model called as ``network(coords, 0)``; it is put in eval mode and moved to ``device``.
        masked_coords: coordinates to evaluate (presumably (N, 3)). Used only when
            ``output_shape`` is None. When a shape is given, the full regular grid of that shape is
            evaluated instead; the original implementation always discarded this argument.
        output_shape: grid size passed to ``make_coordinate_tensor_3d``.
        device: device for the network and coordinate chunks.

    Returns:
        Folded-point percentage in [0, 100]; 0.0 when there are no points.
    """
    network.eval()
    network.to(device)
    if output_shape is not None:
        coords = make_coordinate_tensor_3d(output_shape)
    else:
        coords = masked_coords
    bs = 50000
    n_points = coords.shape[0]
    if n_points == 0:
        return 0.0
    jacobian_matrix = torch.zeros((n_points, 3, 3))
    for start in range(0, n_points, bs):
        stop = min(n_points, start + bs)
        # A fresh leaf tensor so autograd can differentiate the output w.r.t. these coordinates.
        curr_coords = coords[start:stop].detach().to(device).requires_grad_(True)
        curr_output = network(curr_coords, 0)
        jacobian_matrix[start:stop] = compute_jacobian_matrix_3d(
            curr_coords, curr_output, add_identity=True
        ).detach().cpu()
        torch.cuda.empty_cache()
    jac_det = torch.det(jacobian_matrix)
    n_folded_voxels = (jac_det < 0).sum().item()
    folded_voxels_ratio = n_folded_voxels / n_points
    return 100 * folded_voxels_ratio

def compute_landmark_accuracy(landmarks_pred, landmarks_gt, voxel_size):
    """Summarise landmark errors in physical units.

    Both landmark sets are rounded to the nearest integer before comparison. The absolute
    per-axis differences are scaled by ``voxel_size``.

    Returns:
        means, stds: arrays of length 4 ordered ``[euclidean, axis 2, axis 1, axis 0]``,
            rounded to two decimals.
        distances: unrounded per-landmark Euclidean distance.
    """
    abs_err = np.abs(np.round(landmarks_pred) - np.round(landmarks_gt)) * voxel_size
    distances = np.sqrt(np.sum(np.square(abs_err), 1))

    # Per-axis statistics followed by the Euclidean one, then reversed so the 3D value comes first.
    means = np.round(np.append(np.mean(abs_err, 0), np.mean(distances)), 2)[::-1]
    stds = np.round(np.append(np.std(abs_err, 0), np.std(distances)), 2)[::-1]

    return means, stds, distances


def compute_landmarks(network, landmarks_pre, image_size, device='cuda'):
    """Displace voxel-index landmarks with ``network``.

    Landmarks are normalised to roughly [-1, 1] via ``landmarks / (0.5 * image_size) - 1``. They are
    passed through ``network(coords, 0)``, and the output is scaled back by ``0.5 * image_size``.

    Args:
        network: model called as ``network(coords, 0)``; it is put in eval mode.
        landmarks_pre: numpy array of landmark coordinates, one row per landmark.
        image_size: per-axis image size used for normalisation.
        device: device the coordinates are moved to (default ``'cuda'``, as before).

    Returns:
        ``(landmarks_pre + delta, delta)`` with ``delta`` in voxel units.
    """
    scale_of_axes = [(0.5 * s) for s in image_size]
    coordinate_tensor = torch.as_tensor(landmarks_pre / scale_of_axes, dtype=torch.float32) - 1.0

    network.eval()
    # Inference only: no graph is needed (matches compute_landmarks_batch).
    with torch.no_grad():
        output = network(coordinate_tensor.to(device), 0)
    delta = output.cpu().numpy() * scale_of_axes
    return landmarks_pre + delta, delta

def compute_landmarks_batch(network, landmarks_pre, image_size, batch_size=50000, device='cuda'):
    """Batched variant of ``compute_landmarks`` for large landmark sets.

    Args:
        network: model called as ``network(batch, 0)``; it is put in eval mode.
        landmarks_pre: numpy array of landmark coordinates, one row per landmark.
        image_size: per-axis image size used for normalisation.
        batch_size: number of landmarks evaluated per forward pass.
        device: device the coordinates are moved to (default ``'cuda'``, as before).

    Returns:
        ``(landmarks_pre + delta, delta)`` with ``delta`` in voxel units. An empty landmark array
        yields an empty ``delta`` of the same shape instead of failing in ``torch.cat``.
    """
    scale_of_axes = [(0.5 * s) for s in image_size]
    normalized_coords = (landmarks_pre / scale_of_axes) - 1.0
    coordinate_tensor = torch.as_tensor(normalized_coords, dtype=torch.float32).to(device)

    network.eval()
    outputs = []

    with torch.no_grad():
        for start in range(0, coordinate_tensor.shape[0], batch_size):
            batch = coordinate_tensor[start:start + batch_size]
            # NOTE(review): the second argument 0 is forwarded as-is; confirm it is the intended value.
            outputs.append(network(batch, 0).cpu())

    if not outputs:
        delta = np.zeros(np.shape(normalized_coords))
    else:
        delta = torch.cat(outputs, dim=0).numpy() * scale_of_axes
    return landmarks_pre + delta, delta


def compute_landmarks_tensorized(imp_reg, landmarks_pre, image_size):
    """Displace landmarks with ``imp_reg.network`` in its sampling mode.

    The landmarks are shifted by ``imp_reg.idxs_shift`` and moved to the GPU. They are passed,
    together with ``imp_reg.coord_list``, to ``imp_reg.network(..., sample=True)``. The network
    output is scaled by half the image size per axis to obtain ``delta``.

    Returns:
        ``(landmarks_pre + delta, delta)``.
    """
    half_size = [0.5 * extent for extent in image_size]
    query = (torch.tensor(landmarks_pre) - imp_reg.idxs_shift).cuda()
    displacement = imp_reg.network(imp_reg.coord_list, query, sample=True)
    delta = displacement.cpu().detach().numpy() * half_size
    return landmarks_pre + delta, delta

def load_landmarks_anhir(landmarks_path):
    """Read a landmark CSV and return rounded coordinates with the first two columns swapped.

    The first column of the parsed table (presumably a point index) is dropped. The remaining
    values are rounded to the nearest integer, and columns 1 and 0 are returned in that order.
    """
    table = pd.read_csv(landmarks_path).to_numpy()
    coords = np.round(table[:, 1:])
    return coords[:, [1, 0]]

def load_image_DIRLab(variation=1, folder=r"D:\Data\DIRLAB\Case"):
    """Load one DIR-Lab case: T00/T50 images, masks, landmarks and voxel size.

    Files are read from ``folder + str(variation) + "Pack" + os.sep``.

    Args:
        variation: case number, 1-10; indexes the size/spacing tables below.
        folder: path prefix of the case directories. The default is a raw string: same value as
            before, but without the invalid ``\\D`` escape-sequence warning.

    Returns:
        ``(image_insp, image_exp, landmarks_insp, landmarks_exp, mask_t00, mask_t50, voxel_size)``:
        - ``image_insp`` and ``landmarks_insp`` come from the T00 files; ``image_exp`` and
          ``landmarks_exp`` come from the T50 files.
        - Images are float tensors of the tabulated shape.
        - Landmark columns 0 and 2 are swapped relative to the text files.
        - The fifth element is the mask read from ``Masks/..._T00_s.mhd`` and the sixth the mask
          read from ``..._T50_s.mhd``, both clipped to [0, 1].

    Raises:
        ValueError: if ``variation`` is not a valid case number.
    """
    # Size of data, per image pair
    image_sizes = [
        0,
        [94, 256, 256],
        [112, 256, 256],
        [104, 256, 256],
        [99, 256, 256],
        [106, 256, 256],
        [128, 512, 512],
        [136, 512, 512],
        [128, 512, 512],
        [128, 512, 512],
        [120, 512, 512],
    ]

    # Scale of data, per image pair
    voxel_sizes = [
        0,
        [2.5, 0.97, 0.97],
        [2.5, 1.16, 1.16],
        [2.5, 1.15, 1.15],
        [2.5, 1.13, 1.13],
        [2.5, 1.1, 1.1],
        [2.5, 0.97, 0.97],
        [2.5, 0.97, 0.97],
        [2.5, 0.97, 0.97],
        [2.5, 0.97, 0.97],
        [2.5, 0.97, 0.97],
    ]

    if not 1 <= variation < len(image_sizes):
        raise ValueError(f"variation must be between 1 and {len(image_sizes) - 1}, got {variation}")

    shape = image_sizes[variation]
    folder = folder + str(variation) + "Pack" + os.path.sep
    dtype = np.dtype(np.int16)

    def _read_raw(path):
        # Headerless int16 volume reshaped to the tabulated size.
        with open(path, "rb") as f:
            return np.fromfile(f, dtype).reshape(shape)

    def _read_mask(path):
        return np.clip(sitk.GetArrayFromImage(sitk.ReadImage(path)), 0, 1)

    def _read_landmarks(path):
        # Whitespace splitting tolerates tabs, CRLF endings and a missing trailing newline.
        # The previous ``line[:-1]`` chopped the last digit in that case. Blank lines are skipped.
        with open(path) as f:
            rows = [list(map(int, line.split()[:3])) for line in f if line.strip()]
        landmarks = np.array(rows)
        landmarks[:, [0, 2]] = landmarks[:, [2, 0]]
        return landmarks

    image_insp = torch.FloatTensor(_read_raw(folder + "Images/case" + str(variation) + "_T00_s.img"))
    image_exp = torch.FloatTensor(_read_raw(folder + "Images/case" + str(variation) + "_T50_s.img"))

    # NOTE(review): the original code named the T00 mask ``mask_exp`` and the T50 mask ``mask_insp``,
    # although T00 is loaded as the inspiration image. The return positions are kept unchanged;
    # confirm which phase callers expect at positions 5 and 6.
    mask_t00 = _read_mask(folder + "Masks/case" + str(variation) + "_T00_s.mhd")
    mask_t50 = _read_mask(folder + "Masks/case" + str(variation) + "_T50_s.mhd")

    landmarks_insp = _read_landmarks(folder + "ExtremePhases/Case" + str(variation) + "_300_T00_xyz.txt")
    landmarks_exp = _read_landmarks(folder + "ExtremePhases/Case" + str(variation) + "_300_T50_xyz.txt")

    return (
        image_insp,
        image_exp,
        landmarks_insp,
        landmarks_exp,
        mask_t00,
        mask_t50,
        voxel_sizes[variation],
    )

def load_image_COPDgene(variation=1, folder=r"D:\Data\DIRLAB\Case"):
    """Load one COPDgene case: iBHCT/eBHCT images, a clipped ``.mhd`` volume, landmarks and voxel size.

    Files are read from ``<folder>/copd<variation>/copd<variation>_*``.

    Args:
        variation: case number, 1-10; indexes the size/spacing tables below.
        folder: dataset root. The default is a raw string with the same value as before.

    Returns:
        ``(image_insp, image_exp, landmarks_insp, landmarks_exp, mask, voxel_size)``. Landmarks
        are parsed as floats, truncated to int32, and have columns 0 and 2 swapped.

    Raises:
        ValueError: if ``variation`` is not a valid case number.
    """
    # Size of data, per image pair
    image_sizes = [
        0,
        [121, 512, 512],
        [102, 512, 512],
        [126, 512, 512],
        [126, 512, 512],
        [131, 512, 512],
        [119, 512, 512],
        [112, 512, 512],
        [115, 512, 512],
        [116, 512, 512],
        [135, 512, 512],
    ]

    # Scale of data, per image pair
    voxel_sizes = [
        0,
        [2.5, 0.625, 0.625],
        [2.5, 0.645, 0.645],
        [2.5, 0.652, 0.652],
        [2.5, 0.59, 0.59],
        [2.5, 0.647, 0.647],
        [2.5, 0.633, 0.633],
        [2.5, 0.625, 0.625],
        [2.5, 0.586, 0.586],
        [2.5, 0.664, 0.664],
        [2.5, 0.742, 0.742],
    ]

    if not 1 <= variation < len(image_sizes):
        raise ValueError(f"variation must be between 1 and {len(image_sizes) - 1}, got {variation}")

    shape = image_sizes[variation]
    # One consistent path prefix; previously the first image path had an extra "/" after the separator.
    prefix = os.path.join(folder, f"copd{variation}", f"copd{variation}_")
    dtype = np.dtype(np.int16)

    def _read_raw(path):
        # Headerless int16 volume reshaped to the tabulated size.
        with open(path, "rb") as f:
            return np.fromfile(f, dtype).reshape(shape)

    def _read_landmarks(path):
        # Whitespace splitting is robust to a missing trailing newline; the previous ``line[:-1]``
        # dropped the last character in that case. Blank lines are skipped.
        with open(path) as f:
            rows = [list(map(float, line.split()[:3])) for line in f if line.strip()]
        landmarks = np.array(rows, dtype=np.int32)
        landmarks[:, [0, 2]] = landmarks[:, [2, 0]]
        return landmarks

    image_insp = torch.FloatTensor(_read_raw(prefix + "iBHCT.img"))
    image_exp = torch.FloatTensor(_read_raw(prefix + "eBHCT.img"))

    # NOTE(review): this reads ``copdN_iBHCT.mhd``. If that file is the header of the iBHCT image
    # rather than a lung mask, this returns the clipped CT image — confirm the intended mask file.
    mask = np.clip(sitk.GetArrayFromImage(sitk.ReadImage(prefix + "iBHCT.mhd")), 0, 1)

    landmarks_insp = _read_landmarks(prefix + "300_iBH_xyz_r1.txt")
    landmarks_exp = _read_landmarks(prefix + "300_eBH_xyz_r1.txt")

    return (
        image_insp,
        image_exp,
        landmarks_insp,
        landmarks_exp,
        mask,
        voxel_sizes[variation],
    )

def load_pair_oasis_2d(dataset_path, folders, i, j):
    """Load two 2D slices and their label maps.

    Each subject directory is ``os.path.join(dataset_path, folders[k])``. It holds
    ``slice_norm.nii.gz`` (image) and ``slice_seg24.nii.gz`` (labels); index 0 of the third
    axis is taken from each.

    Returns:
        ``(image_i, image_j, label_i, label_j)`` as float32 numpy arrays.
    """
    def _read_slice(subject, filename):
        nifti = nib.load(os.path.join(dataset_path, folders[subject], filename))
        return np.array(nifti.get_fdata()[:, :, 0], dtype='float32')

    image_name, label_name = 'slice_norm.nii.gz', 'slice_seg24.nii.gz'
    return (
        _read_slice(i, image_name),
        _read_slice(j, image_name),
        _read_slice(i, label_name),
        _read_slice(j, label_name),
    )

def load_pair_oasis_3d(dataset_path, folders, i, j):
    """Load two 3D volumes and their label maps.

    Each subject directory is ``os.path.join(dataset_path, folders[k])``. It holds
    ``aligned_norm.nii.gz`` (image) and ``aligned_seg35.nii.gz`` (labels).

    Returns:
        ``(image_i, image_j, label_i, label_j)`` as float32 numpy arrays.
    """
    def _read_volume(subject, filename):
        nifti = nib.load(os.path.join(dataset_path, folders[subject], filename))
        return np.array(nifti.get_fdata(), dtype='float32')

    image_name, label_name = 'aligned_norm.nii.gz', 'aligned_seg35.nii.gz'
    return (
        _read_volume(i, image_name),
        _read_volume(j, image_name),
        _read_volume(i, label_name),
        _read_volume(j, label_name),
    )

def load_pair_acdc(dataset_path, patient_num, flag=0):
    """Load the ED and ES frames and their ``_gt`` label volumes of one ACDC patient as a pair.

    Frame numbers are the last space-separated token of the first two lines of
    ``patient<patient_num>/Info.cfg`` (first line -> ED, second line -> ES).

    Args:
        dataset_path: directory containing the ``patient<patient_num>`` folders.
        patient_num: patient identifier, inserted verbatim into folder and file names.
        flag: truthy -> ED is fixed and ES is moving; falsy -> ES is fixed and ED is moving.

    Returns:
        ``(moving_img, fixed_img, moving_labels, fixed_labels, voxel_size)``. Arrays are float32;
        ``voxel_size`` is ``get_zooms()`` of the ES image header.

    Raises:
        ValueError: if ``Info.cfg`` has fewer than two lines.
    """
    patient_dir = os.path.join(dataset_path, f'patient{patient_num}')
    metadata_path = os.path.join(patient_dir, 'Info.cfg')

    frame_ids = []
    with open(metadata_path, 'r') as f:
        for line in f:
            frame_ids.append(line.strip().split(' ')[-1])
            if len(frame_ids) == 2:
                break
    if len(frame_ids) < 2:
        raise ValueError(f"{metadata_path}: expected ED and ES frame numbers on the first two lines")
    ed_num, es_num = frame_ids

    def _frame_path(num, suffix=''):
        # Build image and label paths directly. ``str.replace('.nii', '_gt.nii')`` on the full path
        # would also rewrite any '.nii' occurring in ``dataset_path``.
        return os.path.join(patient_dir, f'patient{patient_num}_frame{num.zfill(2)}{suffix}.nii')

    def _load(path):
        return np.array(nib.load(path).get_fdata(), dtype=np.float32)

    es_nii = nib.load(_frame_path(es_num))
    voxel_size = es_nii.header.get_zooms()
    es_img = np.array(es_nii.get_fdata(), dtype=np.float32)
    ed_img = _load(_frame_path(ed_num))

    es_gt = _load(_frame_path(es_num, '_gt'))
    ed_gt = _load(_frame_path(ed_num, '_gt'))

    if flag:
        fixed_img, moving_img = ed_img, es_img
        fixed_labels, moving_labels = ed_gt, es_gt
    else:
        fixed_img, moving_img = es_img, ed_img
        fixed_labels, moving_labels = es_gt, ed_gt
    return moving_img, fixed_img, moving_labels, fixed_labels, voxel_size
    
# def eval_single_pair_oasis_2d(imp_reg, output_shape, fixed_labels, moving_labels, device='cpu'):   
#     coords = make_coordinate_tensor_2d(output_shape).to(device)
#     network = imp_reg.network.to(device)
#     network.eval()
#     with torch.no_grad():
#         output = network(coords)
#         output = output.view([output_shape[0], output_shape[1], 2])
#         moving_labels = torch.from_numpy(moving_labels).to(device)
#         _, warped_labels = interp_full_grid(moving_labels[None, ...][None, ...].float(), output[None, ...].float(), mod='nearest')
#         warped_labels = warped_labels.squeeze(dim=[0, 1])
#         wl = warped_labels.detach().long().cpu().numpy()
#         mean_dice, dice_arr = compute_dice(wl, fixed_labels.astype('int64'))
#     return mean_dice, dice_arr

def eval_segmentation_accuracy(imp_reg, output_shape, fixed_labels, moving_labels, voxel_size=(1., 1., 1.,), device='cuda'):
    """Warp ``moving_labels`` with ``imp_reg.network`` and score the result against ``fixed_labels``.

    The network is evaluated as ``network(coords, 0)`` on the full grid of ``output_shape`` in
    chunks of 500 000 points. The result is reshaped to ``(*output_shape, 3)`` and used as the flow
    for ``interp_full_grid_3d`` with nearest-neighbour sampling.

    Args:
        imp_reg: object exposing ``network``.
        output_shape: 3D grid size.
        fixed_labels: numpy label volume (cast to int64 for scoring).
        moving_labels: numpy label volume to warp.
        voxel_size: spacing forwarded to ``compute_hd95``.
        device: device the network and coordinate chunks are moved to.

    Returns:
        ``(mean_dice, dice_arr, hd95)`` as returned by ``compute_dice`` / ``compute_hd95``.
    """
    coords = make_coordinate_tensor_3d(output_shape)
    network = imp_reg.network.to(device)
    network.eval()
    bs = 500000
    n_points = coords.shape[0]
    with torch.no_grad():
        # Allocate the full displacement buffer on the CPU directly. The previous
        # zeros_like(coords) allocated it on the coordinates' (GPU) device first, then copied it.
        output = torch.zeros(coords.shape, dtype=coords.dtype)
        for start in range(0, n_points, bs):
            stop = min(n_points, start + bs)
            output[start:stop] = network(coords[start:stop].to(device), 0).detach().cpu()
        output = output.view([output_shape[0], output_shape[1], output_shape[2], 3])
        moving = torch.from_numpy(moving_labels).to('cpu')
        _, warped_labels = interp_full_grid_3d(moving[None, None].float(), output[None].float(), mod='nearest')
        wl = warped_labels.squeeze(dim=[0, 1]).detach().long().cpu().numpy()
        fixed = fixed_labels.astype('int64')
        mean_dice, dice_arr = compute_dice(wl, fixed)
        hd95 = compute_hd95(wl, fixed, voxel_size)
    return mean_dice, dice_arr, hd95
  
def fast_trilinear_interpolation_2d(input_array, x_indices, y_indices):
    """Bilinearly sample ``input_array`` at normalised coordinates (despite the name).

    Coordinates in [-1, 1] are mapped to [0, size - 1] along the first two axes. Corner indices
    are clamped to the array bounds, so out-of-range points reproduce the nearest edge value.
    Indices come from detached coordinates, so gradients flow only through the weights.
    """
    size_x, size_y = input_array.shape[0], input_array.shape[1]
    fx = (x_indices + 1) * (size_x - 1) * 0.5
    fy = (y_indices + 1) * (size_y - 1) * 0.5

    lo_x = torch.floor(fx.detach()).to(torch.long)
    lo_y = torch.floor(fy.detach()).to(torch.long)
    hi_x = torch.clamp(lo_x + 1, 0, size_x - 1)
    hi_y = torch.clamp(lo_y + 1, 0, size_y - 1)
    lo_x = torch.clamp(lo_x, 0, size_x - 1)
    lo_y = torch.clamp(lo_y, 0, size_y - 1)

    wx = fx - lo_x
    wy = fy - lo_y

    return (
        input_array[lo_x, lo_y] * (1 - wx) * (1 - wy)
        + input_array[hi_x, lo_y] * wx * (1 - wy)
        + input_array[lo_x, hi_y] * (1 - wx) * wy
        + input_array[hi_x, hi_y] * wx * wy
    )

def fast_trilinear_interpolation_3d(input_array, x_indices, y_indices, z_indices):
    """Trilinearly sample ``input_array`` at normalised coordinates.

    Coordinates in [-1, 1] are mapped to [0, size - 1] along the first three axes. Corner indices
    are clamped to the array bounds, so out-of-range points reproduce the nearest edge value.
    Indices come from detached coordinates, so gradients flow only through the weights.
    """
    sizes = (input_array.shape[0], input_array.shape[1], input_array.shape[2])
    fx, fy, fz = [
        (coord + 1) * (size - 1) * 0.5
        for coord, size in zip((x_indices, y_indices, z_indices), sizes)
    ]

    floors = [torch.floor(f.detach()).to(torch.long) for f in (fx, fy, fz)]
    x1, y1, z1 = [torch.clamp(fl + 1, 0, size - 1) for fl, size in zip(floors, sizes)]
    x0, y0, z0 = [torch.clamp(fl, 0, size - 1) for fl, size in zip(floors, sizes)]

    wx = fx - x0
    wy = fy - y0
    wz = fz - z0

    return (
        input_array[x0, y0, z0] * (1 - wx) * (1 - wy) * (1 - wz)
        + input_array[x1, y0, z0] * wx * (1 - wy) * (1 - wz)
        + input_array[x0, y1, z0] * (1 - wx) * wy * (1 - wz)
        + input_array[x0, y0, z1] * (1 - wx) * (1 - wy) * wz
        + input_array[x1, y0, z1] * wx * (1 - wy) * wz
        + input_array[x0, y1, z1] * (1 - wx) * wy * wz
        + input_array[x1, y1, z0] * wx * wy * (1 - wz)
        + input_array[x1, y1, z1] * wx * wy * wz
    )

def interp_full_grid(mov_image, flow, mod = 'bilinear'):
    """Warp a 2D image by adding ``flow`` to the identity grid and resampling with ``grid_sample``.

    Args:
        mov_image: image passed to ``grid_sample``; its last two dimensions (H, W) define the grid.
        flow: displacement in normalised [-1, 1] units. It is indexed as ``flow[:, :, :, c]``, with
            channel 0 added to the H grid and channel 1 to the W grid; presumably (N, H, W, 2).
        mod: interpolation mode forwarded to ``grid_sample`` (``align_corners=True``).

    Returns:
        ``(sample_grid, warped)``, where ``sample_grid`` has its last axis ordered (w, h) as
        ``grid_sample`` expects.
    """
    h2, w2 = mov_image.shape[-2:]
    grid_h, grid_w = torch.meshgrid(
        torch.linspace(-1, 1, h2, device=flow.device, dtype=torch.float32),
        torch.linspace(-1, 1, w2, device=flow.device, dtype=torch.float32),
        indexing='ij',
    )
    # No squeeze here: the former ``.squeeze(1)`` removed the H axis when H == 1 and broke the stack.
    disp_h = grid_h + flow[:, :, :, 0]
    disp_w = grid_w + flow[:, :, :, 1]
    sample_grid = torch.stack((disp_w, disp_h), dim=3)
    warped = F.grid_sample(mov_image, sample_grid, mode=mod, align_corners=True)
    return sample_grid, warped

def interp_full_grid_3d(mov_image, flow, mod = 'bilinear'):
    """Warp a 3D volume by adding ``flow`` to the identity grid and resampling with ``grid_sample``.

    Args:
        mov_image: volume passed to ``grid_sample``; its last three dimensions (D, H, W) define the grid.
        flow: displacement in normalised [-1, 1] units. It is indexed as ``flow[:, :, :, :, c]``, with
            channels 0, 1, 2 added to the D, H, W grids; presumably (N, D, H, W, 3).
        mod: interpolation mode forwarded to ``grid_sample`` (``align_corners=True``).

    Returns:
        ``(sample_grid, warped)``, where ``sample_grid`` has shape (N, D, H, W, 3) with its last axis
        ordered (w, h, d) as ``grid_sample`` expects.
    """
    d2, h2, w2 = mov_image.shape[-3:]
    grid_d, grid_h, grid_w = torch.meshgrid(
        torch.linspace(-1, 1, d2, device=flow.device, dtype=torch.float32),
        torch.linspace(-1, 1, h2, device=flow.device, dtype=torch.float32),
        torch.linspace(-1, 1, w2, device=flow.device, dtype=torch.float32),
        indexing='ij',
    )
    # The flow has no channel dimension to remove. The former ``.squeeze(1)`` only dropped the
    # D axis when D == 1 (e.g. a single slice), which broke the stack below.
    disp_d = grid_d + flow[:, :, :, :, 0]
    disp_h = grid_h + flow[:, :, :, :, 1]
    disp_w = grid_w + flow[:, :, :, :, 2]
    sample_grid = torch.stack((disp_w, disp_h, disp_d), dim=4)
    warped = F.grid_sample(mov_image, sample_grid, mode=mod, align_corners=True)
    return sample_grid, warped

def count_parameters(model):
    """Return the total number of elements in ``model``'s parameters that have ``requires_grad`` set."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


def make_coordinate_slice_2d(dims=(28, 28), dimension=0, slice_pos=0, device='cuda'):
    """Make an (N, 2) tensor of normalised coordinates on a line through a 2D grid.

    Axis ``dimension`` is fixed at ``slice_pos``; the other axis spans [-1, 1] with the size taken
    from ``dims``. Only the first entry of ``dims`` is used.

    The tensor is reshaped from the grid actually built. The former ``view(np.prod(dims))`` also
    counted unused ``dims`` entries, so the default call failed.

    Args:
        dims: size of the free axis (extra entries are ignored).
        dimension: index (0 or 1) of the fixed axis.
        slice_pos: coordinate value of the fixed axis.
        device: device of the returned tensor (default ``'cuda'``, as before).
    """
    dims = list(dims)
    dims.insert(dimension, 1)

    axes = [torch.linspace(-1, 1, dims[i]) for i in range(2)]
    axes[dimension] = torch.linspace(slice_pos, slice_pos, 1)
    grid = torch.stack(torch.meshgrid(*axes, indexing='ij'), dim=2)
    return grid.reshape(-1, 2).to(device)


def make_coordinate_slice_3d(dims=(28, 28, 28), dimension=0, slice_pos=0, device='cuda'):
    """Make an (N, 3) tensor of normalised coordinates on a plane through a 3D grid.

    Axis ``dimension`` is fixed at ``slice_pos``; the two other axes span [-1, 1] with sizes taken
    from ``dims``. Only the first two entries of ``dims`` are used.

    The tensor is reshaped from the grid actually built. The former ``view(np.prod(dims))`` also
    counted unused ``dims`` entries, so the default call failed.

    Args:
        dims: sizes of the free axes, in order (extra entries are ignored).
        dimension: index (0, 1 or 2) of the fixed axis.
        slice_pos: coordinate value of the fixed axis.
        device: device of the returned tensor (default ``'cuda'``, as before).
    """
    dims = list(dims)
    dims.insert(dimension, 1)

    axes = [torch.linspace(-1, 1, dims[i]) for i in range(3)]
    axes[dimension] = torch.linspace(slice_pos, slice_pos, 1)
    grid = torch.stack(torch.meshgrid(*axes, indexing='ij'), dim=3)
    return grid.reshape(-1, 3).to(device)


def make_coordinate_tensor_2d(dims=(28, 28, 28), device='cuda'):
    """Make an (H*W, 2) tensor of normalised [-1, 1] grid coordinates in row-major ('ij') order.

    Only ``dims[0]`` and ``dims[1]`` are used. The tensor is reshaped from the grid actually built,
    so the three-entry default now works; the former ``view(np.prod(dims))`` failed on it.

    Args:
        dims: grid size; extra entries are ignored.
        device: device of the returned tensor (default ``'cuda'``, as before).
    """
    axes = [torch.linspace(-1, 1, dims[i]) for i in range(2)]
    grid = torch.stack(torch.meshgrid(*axes, indexing='ij'), dim=2)
    return grid.reshape(-1, 2).to(device)

def make_coordinate_tensor_3d(dims=(28, 28, 28), device='cuda'):
    """Make a (D*H*W, 3) tensor of normalised [-1, 1] grid coordinates in row-major ('ij') order.

    Args:
        dims: grid size; only the first three entries are used.
        device: device of the returned tensor (default ``'cuda'``, as before).
    """
    axes = [torch.linspace(-1, 1, dims[i]) for i in range(3)]
    grid = torch.stack(torch.meshgrid(*axes, indexing='ij'), dim=3)
    return grid.reshape(-1, 3).to(device)

def make_masked_coordinate_tensor_2d(mask, dims=(28, 28), stride_2d=1, device='cuda'):
    """Make an (N, 2) tensor of normalised coordinates of the grid points kept by ``mask``.

    The full grid of size ``dims`` spans [-1, 1] per axis and is subsampled with ``stride_2d``,
    exactly like ``mask[::stride_2d, ::stride_2d]``. Points where that subsampled mask is > 0 are
    kept, in row-major order.

    Coordinates are taken as ``linspace(-1, 1, d)[::stride]``, so their count always matches the
    strided mask. The former float ``arange`` could drift and divided by zero for size-1 axes.

    Args:
        mask: 2D mask of shape ``dims`` (numpy array or tensor).
        dims: full grid size.
        stride_2d: subsampling step applied to both grid and mask.
        device: device of the returned tensor (default ``'cuda'``, as before).
    """
    axes = [torch.linspace(-1, 1, int(dims[i]))[::stride_2d] for i in range(2)]
    grid = torch.stack(torch.meshgrid(*axes, indexing='ij'), dim=2).reshape(-1, 2)
    mask = mask[::stride_2d, ::stride_2d]
    grid = grid[mask.flatten() > 0, :]
    return grid.to(device)

STRIDE = 1
def make_masked_coordinate_tensor_3d(mask, dims=(28, 28, 28), stride=None, device='cuda'):
    """Make an (N, 3) tensor of normalised coordinates of the grid points kept by ``mask``.

    The full grid of size ``dims`` spans [-1, 1] per axis and is subsampled with ``stride``
    (default: module-level ``STRIDE``), exactly like ``mask[::stride, ::stride, ::stride]``.
    Points where that subsampled mask is > 0 are kept, in row-major order.

    Coordinates are ``linspace(-1, 1, d)[::stride]``. These are exactly the positions of the voxels
    the strided mask keeps; the former ``linspace(-1, 1, ceil(d / stride))`` differed whenever
    ``(d - 1) % stride != 0``. The result is identical for ``stride == 1``.

    Args:
        mask: 3D mask of shape ``dims`` (numpy array or tensor).
        dims: full grid size.
        stride: subsampling step; ``None`` uses ``STRIDE``.
        device: device of the returned tensor (default ``'cuda'``, as before).
    """
    if stride is None:
        stride = STRIDE
    axes = [torch.linspace(-1, 1, int(dims[i]))[::stride] for i in range(3)]
    grid = torch.stack(torch.meshgrid(*axes, indexing='ij'), dim=3).reshape(-1, 3)
    mask = mask[::stride, ::stride, ::stride]
    grid = grid[mask.flatten() > 0, :]
    return grid.to(device)

def make_masked_feature_tensor_3d(mask, feat, dims=(28, 28, 28), device='cuda'):
    """Flatten a per-voxel feature tensor and keep the rows where ``mask`` > 0.

    ``feat`` is viewed as ``(prod(dims), C)`` with the channel count inferred; previously it was
    hard-coded to 16. The view assumes the spatial dims come first and the channels last — TODO
    confirm with callers.

    Args:
        mask: mask with ``prod(dims)`` elements.
        feat: feature tensor with ``prod(dims) * C`` elements, viewable (contiguous).
        dims: spatial grid size.
        device: device of the returned tensor (default ``'cuda'``, as before).
    """
    n_voxels = int(np.prod(dims))
    feat_tensor = feat.view([n_voxels, -1])
    feat_tensor = feat_tensor[mask.flatten() > 0, :]
    return feat_tensor.to(device)

def get_gpu_used_memory():
    """Return GPUtil's ``memoryUsed`` for the GPU selected by ``CUDA_VISIBLE_DEVICES``.

    The first comma-separated entry of the variable is used as an index into ``GPUtil.getGPUs()``.
    An unset or empty variable selects GPU 0; previously this raised KeyError, and a list such as
    "0,1" raised on ``int``.

    Raises:
        ValueError: if the first entry is not a non-negative integer index (e.g. a GPU UUID).
    """
    gpus = GPUtil.getGPUs()
    visible = os.environ.get('CUDA_VISIBLE_DEVICES', '').strip()
    first = visible.split(',')[0].strip() if visible else '0'
    if not first.isdigit():
        raise ValueError(f"Cannot map CUDA_VISIBLE_DEVICES entry {first!r} to a GPU index")
    gpu = gpus[int(first)]
    return gpu.memoryUsed


