"""
Helper functions from https://github.com/zhangjun001/ICNet.

Some functions has been modified.
"""

import numpy as np
import torch.utils.data as Data
import nibabel as nib
import torch
import itertools
from scipy.ndimage import zoom

import matplotlib.pyplot as plt

def generate_grid(imgshape):
    x = np.arange(imgshape[0])
    y = np.arange(imgshape[1])
    z = np.arange(imgshape[2])
    grid = np.rollaxis(np.array(np.meshgrid(z, y, x)), 0, 4)
    grid = np.swapaxes(grid,0,2)
    grid = np.swapaxes(grid,1,2)
    return grid


def generate_grid_unit(imgshape):
    x = (np.arange(imgshape[0]) - ((imgshape[0]-1)/2)) / (imgshape[0]-1) * 2
    y = (np.arange(imgshape[1]) - ((imgshape[1]-1)/2)) / (imgshape[1]-1) * 2
    z = (np.arange(imgshape[2]) - ((imgshape[2]-1)/2)) / (imgshape[2]-1) * 2
    grid = np.rollaxis(np.array(np.meshgrid(z, y, x)), 0, 4)
    grid = np.swapaxes(grid,0,2)
    grid = np.swapaxes(grid,1,2)
    return grid


def transform_unit_flow_to_flow(flow):
    x, y, z, _ = flow.shape
    flow[:, :, :, 0] = flow[:, :, :, 0] * (z-1)
    flow[:, :, :, 1] = flow[:, :, :, 1] * (y-1)
    flow[:, :, :, 2] = flow[:, :, :, 2] * (x-1)

    return flow


def transform_unit_flow_to_flow_cuda(flow):
    b, x, y, z, c = flow.shape
    flow[:, :, :, :, 0] = flow[:, :, :, :, 0] * (z-1)
    flow[:, :, :, :, 1] = flow[:, :, :, :, 1] * (y-1)
    flow[:, :, :, :, 2] = flow[:, :, :, :, 2] * (x-1)

    return flow


def load_4D(name):
    X = nib.load(name)
    X = X.get_fdata()
    X = np.reshape(X, (1,) + X.shape)
    return X


def crop_center(img, cropx, cropy, cropz):
    x, y, z = img.shape
    startx = x//2 - cropx//2
    starty = y//2 - cropy//2
    startz = z//2 - cropz//2
    return img[startx:startx+cropx, starty:starty+cropy, startz:startz+cropz]


def load_4D_with_crop(name, cropx, cropy, cropz):
    X = nib.load(name)
    X = X.get_fdata()

    x, y, z = X.shape
    startx = x//2 - cropx//2
    starty = y//2 - cropy//2
    startz = z//2 - cropz//2

    X = X[startx:startx+cropx, starty:starty+cropy, startz:startz+cropz]

    X = np.reshape(X, (1,) + X.shape)
    return X


def load_4D_with_header(name):
    X = nib.load(name)
    X_npy = X.get_fdata()
    X_npy = np.reshape(X_npy, (1,) + X_npy.shape)
    return X_npy, X.header, X.affine


def load_5D(name):
    X = fixed_nii = nib.load(name)
    X = X.get_fdata()
    X = np.reshape(X, (1,)+(1,)+ X.shape)
    return X


def imgnorm(img):
    i_max = np.max(img)
    i_min = np.min(img)
    norm = (img - i_min)/(i_max - i_min)
    return norm


def save_img(I_img,savename,header=None,affine=None):
    if header is None or affine is None:
        affine = np.diag([1, 1, 1, 1])
        new_img = nib.nifti1.Nifti1Image(I_img, affine, header=None)
    else:
        new_img = nib.nifti1.Nifti1Image(I_img, affine, header=header)

    nib.save(new_img, savename)


def save_flow(I_img,savename):
    affine = np.diag([1, 1, 1, 1])
    new_img = nib.nifti1.Nifti1Image(I_img, affine, header=None)
    nib.save(new_img, savename)


class Dataset_epoch(Data.Dataset):
  'Characterizes a dataset for PyTorch'
  def __init__(self, names, labels=None, norm=False):
        'Initialization'
        super(Dataset_epoch, self).__init__()
        self.names = names
        self.labels = labels
        self.norm = norm
        self.index_pair = list(itertools.permutations(names, 2))
        self.label_pair = list(itertools.permutations(labels, 2)) if labels is not None else None

  def __len__(self):
        'Denotes the total number of samples'
        return len(self.index_pair)

  def __getitem__(self, step):
        'Generates one sample of data'
        # Select sample
        img_A = load_4D(self.index_pair[step][0])
        img_B = load_4D(self.index_pair[step][1])

        if self.norm:
            ret_A, ret_B = torch.from_numpy(imgnorm(img_A)).float(), torch.from_numpy(imgnorm(img_B)).float()
        else:
            ret_A, ret_B = torch.from_numpy(img_A).float(), torch.from_numpy(img_B).float()

        # img_A = zoom(img_A, (1, 0.5, 0.5, 0.5), order=0)
        # img_B = zoom(img_B, (1, 0.5, 0.5, 0.5), order=0)
        if self.labels is None:
            return ret_A, ret_B
        else:
            lab_A = load_4D(self.label_pair[step][0])
            lab_B = load_4D(self.label_pair[step][1])
            return ret_A, ret_B, torch.from_numpy(lab_A).float(), torch.from_numpy(lab_B).float()


class Dataset_epoch_crop(Data.Dataset):
  'Characterizes a dataset for PyTorch'
  def __init__(self, names, labels=None, norm=False):
        'Initialization'
        super(Dataset_epoch_crop, self).__init__()
        self.names = names
        self.norm = norm
        self.index_pair = list(itertools.permutations(names, 2))

  def __len__(self):
        'Denotes the total number of samples'
        return len(self.index_pair)

  def __getitem__(self, step):
        'Generates one sample of data'
        # Select sample
        img_A = load_4D_with_crop(self.index_pair[step][0], cropx=160, cropy=144, cropz=192)
        img_B = load_4D_with_crop(self.index_pair[step][1], cropx=160, cropy=144, cropz=192)
        # img_A = zoom(img_A, (1, 0.5, 0.5, 0.5), order=0)
        # img_B = zoom(img_B, (1, 0.5, 0.5, 0.5), order=0)

        if self.norm:
            return torch.from_numpy(imgnorm(img_A)).float(), torch.from_numpy(imgnorm(img_B)).float()
        else:
            return torch.from_numpy(img_A).float(), torch.from_numpy(img_B).float()


class Predict_dataset_crop(Data.Dataset):
    def __init__(self, fixed_list, move_list, fixed_label_list, move_label_list, norm=False):
        super(Predict_dataset_crop, self).__init__()
        self.fixed_list = fixed_list
        self.move_list = move_list
        self.fixed_label_list = fixed_label_list
        self.move_label_list = move_label_list
        self.norm = norm

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.move_list)

    def __getitem__(self, index):
        fixed_img = load_4D_with_crop(self.fixed_list, cropx=160, cropy=144, cropz=192)
        moved_img = load_4D_with_crop(self.move_list[index], cropx=160, cropy=144, cropz=192)
        fixed_label = load_4D_with_crop(self.fixed_label_list, cropx=160, cropy=144, cropz=192)
        moved_label = load_4D_with_crop(self.move_label_list[index], cropx=160, cropy=144, cropz=192)

        if self.norm:
            fixed_img = imgnorm(fixed_img)
            moved_img = imgnorm(moved_img)

        fixed_img = torch.from_numpy(fixed_img)
        moved_img = torch.from_numpy(moved_img)
        fixed_label = torch.from_numpy(fixed_label)
        moved_label = torch.from_numpy(moved_label)

        output = {'fixed': fixed_img.float(), 'move': moved_img.float(),
                  'fixed_label': fixed_label.float(), 'move_label': moved_label.float(), 'index': index}
        return output


class Predict_dataset(Data.Dataset):
    def __init__(self, fixed_list, move_list, fixed_label_list, move_label_list, norm=False):
        super(Predict_dataset, self).__init__()
        self.fixed_list = fixed_list
        self.move_list = move_list
        self.fixed_label_list = fixed_label_list
        self.move_label_list = move_label_list
        self.norm = norm

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.move_list)

    def __getitem__(self, index):
        fixed_img = load_4D(self.fixed_list)
        moved_img = load_4D(self.move_list[index])
        fixed_label = load_4D(self.fixed_label_list)
        moved_label = load_4D(self.move_label_list[index])

        if self.norm:
            fixed_img = imgnorm(fixed_img)
            moved_img = imgnorm(moved_img)

        fixed_img = torch.from_numpy(fixed_img)
        moved_img = torch.from_numpy(moved_img)
        fixed_label = torch.from_numpy(fixed_label)
        moved_label = torch.from_numpy(moved_label)

        output = {'fixed': fixed_img.float(), 'move': moved_img.float(),
                  'fixed_label': fixed_label.float(), 'move_label': moved_label.float(), 'index': index}
        return output


