import os
import nibabel as nib
import numpy as np
import torch
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import axes3d

def vector2tensor(vector):
    """Expand packed symmetric-matrix components into full symmetric matrices.

    The last axis holds the upper triangle in row-major order:
      * 6 entries -> [m00, m01, m02, m11, m12, m22]  -> 3x3 matrix
      * 3 entries -> [m00, m01, m11]                 -> 2x2 matrix
    Any number of leading axes is accepted (the 4-D (X, Y, Z, k) layout is one
    case). The result lives on the input's device; a floating-point input keeps
    its dtype, other dtypes are converted to the default float dtype.

    Args:
        vector: torch tensor whose last axis has size 6 or 3.

    Returns:
        Tensor of shape (*vector.shape[:-1], n, n), or None when the last axis
        is neither 6 nor 3.
    """
    if vector.shape[-1] == 6:
        size, order = 3, [0, 1, 2, 1, 3, 4, 2, 4, 5]
    elif vector.shape[-1] == 3:
        size, order = 2, [0, 1, 1, 2]
    else:
        return None
    # Gather the entries row by row, then fold the last axis into (size, size).
    # Indexing keeps the input's device (a bare `.cuda()` would pick the default GPU).
    output = vector[..., order].reshape(*vector.shape[:-1], size, size)
    if not output.is_floating_point():
        output = output.to(torch.get_default_dtype())
    return output

def tensor2vector(tensor):
    """Pack the upper triangle of square matrices into vectors (inverse of vector2tensor).

    Output order along the last axis:
      * 3x3 input -> [m00, m01, m02, m11, m12, m22]
      * 2x2 input -> [m00, m01, m11]
    Only the upper triangle is read, so symmetry of the input is assumed.
    Any number of leading axes is accepted. The result lives on the input's
    device; a floating-point input keeps its dtype, other dtypes are converted
    to the default float dtype.

    Args:
        tensor: torch tensor of shape (..., n, n) with n == 3 or n == 2
            (only the last axis is checked).

    Returns:
        Tensor of shape (..., 6) or (..., 3), or None when the last axis is
        neither 3 nor 2.
    """
    if tensor.shape[-1] == 3:
        rows, cols = [0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2]
    elif tensor.shape[-1] == 2:
        rows, cols = [0, 0, 1], [0, 1, 1]
    else:
        return None
    # Paired advanced indices on the last two axes collapse them into one axis.
    output = tensor[..., rows, cols]
    if not output.is_floating_point():
        output = output.to(torch.get_default_dtype())
    return output

def vector2tensor_1dim(vector):
    """Expand packed symmetric-matrix components of shape (N, k) into (N, n, n) matrices.

    The last axis holds the upper triangle in row-major order:
      * 6 entries -> [m00, m01, m02, m11, m12, m22]  -> 3x3 matrix
      * 3 entries -> [m00, m01, m11]                 -> 2x2 matrix
    Extra leading axes are also accepted. The result lives on the input's
    device; a floating-point input keeps its dtype, other dtypes are converted
    to the default float dtype.

    Returns:
        Tensor of shape (*vector.shape[:-1], n, n), or None when the last axis
        is neither 6 nor 3.
    """
    if vector.shape[-1] == 6:
        size, order = 3, [0, 1, 2, 1, 3, 4, 2, 4, 5]
    elif vector.shape[-1] == 3:
        size, order = 2, [0, 1, 1, 2]
    else:
        return None
    # Gather row by row and fold; indexing keeps the input's device.
    output = vector[..., order].reshape(*vector.shape[:-1], size, size)
    if not output.is_floating_point():
        output = output.to(torch.get_default_dtype())
    return output

def tensor2vector_1dim(tensor):
    """Pack the upper triangle of (N, n, n) matrices into (N, k) vectors.

    Output order along the last axis:
      * 3x3 input -> [m00, m01, m02, m11, m12, m22]
      * 2x2 input -> [m00, m01, m11]
    Only the upper triangle is read, so symmetry is assumed. Extra leading
    axes are also accepted. The result lives on the input's device; a
    floating-point input keeps its dtype, other dtypes are converted to the
    default float dtype.

    Returns:
        Tensor of shape (..., 6) or (..., 3), or None when the last axis is
        neither 3 nor 2.
    """
    if tensor.shape[-1] == 3:
        rows, cols = [0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2]
    elif tensor.shape[-1] == 2:
        rows, cols = [0, 0, 1], [0, 1, 1]
    else:
        return None
    output = tensor[..., rows, cols]
    if not output.is_floating_point():
        output = output.to(torch.get_default_dtype())
    return output

def Log2Log_vec(data):
    """Scale the off-diagonal entries of packed symmetric matrices by sqrt(2).

    With the upper-triangle layouts [m00, m01, m02, m11, m12, m22] (6 entries)
    or [m00, m01, m11] (3 entries), this makes the Euclidean norm of the
    vector equal the Frobenius norm of the full matrix. Log_vec2Log undoes it.

    Any tensor with at least one axis is accepted; the packed components are
    on the last axis. A last axis other than 6 or 3 yields an unmodified copy.
    The input is never modified in place.

    Returns:
        A scaled copy of ``data``, or None for a 0-d input.
    """
    if data.dim() == 0:
        return None
    tempdata = data.clone()
    if data.shape[-1] == 6:
        tempdata[..., [1, 2, 4]] *= np.sqrt(2)
    elif data.shape[-1] == 3:
        tempdata[..., 1] *= np.sqrt(2)
    return tempdata

def Log_vec2Log(data):
    """Divide the off-diagonal entries of packed symmetric matrices by sqrt(2).

    Inverse of Log2Log_vec for the upper-triangle layouts
    [m00, m01, m02, m11, m12, m22] (6 entries) or [m00, m01, m11] (3 entries).
    Any tensor with at least one axis is accepted; the packed components are
    on the last axis. A last axis other than 6 or 3 yields an unmodified copy.
    The input is never modified in place.

    Returns:
        A rescaled copy of ``data``, or None for a 0-d input.
    """
    if data.dim() == 0:
        return None
    tempdata = data.clone()
    if data.shape[-1] == 6:
        tempdata[..., [1, 2, 4]] /= np.sqrt(2)
    elif data.shape[-1] == 3:
        tempdata[..., 1] /= np.sqrt(2)
    return tempdata

def getClassNum(labelname, labelset, labelindex):
    """Return the class index paired with ``labelname``.

    ``labelset`` and ``labelindex`` are parallel sequences: the first position
    ``i`` where ``labelset[i] == labelname`` selects ``labelindex[i]``.

    Returns:
        The matched ``labelindex`` entry, or -1 when ``labelname`` is absent.
    """
    position = next((i for i, label in enumerate(labelset) if label == labelname), None)
    if position is None:
        return -1
    return labelindex[position]

def getClassNum_old(labelname, use_4class):
    """Legacy mapping from a diagnosis label to a class index.

    Args:
        labelname: label string such as 'AD', 'CN', 'EMCI' or 'LMCI'.
        use_4class: either a list of labels, or a flag choosing between the
            built-in 4-class and 3-class mappings.

    Returns:
        * list ``use_4class``: position of ``labelname`` in the list, or -1.
        * truthy flag: AD -> 0, LMCI -> 1, EMCI -> 2, CN -> 3, anything else -> -1.
        * falsy flag: AD -> 0, CN -> 2, anything else -> 1.
    """
    # If 'use_4class' is a list (it is then the label set), return the index of
    # the entry that matches 'labelname'.
    if isinstance(use_4class, list):
        for i, label in enumerate(use_4class):
            if label == labelname:
                return i
        return -1
    if use_4class:
        # Unknown labels return -1 (same sentinel as the list mode) instead of
        # leaving the result unbound.
        return {'AD': 0, 'LMCI': 1, 'EMCI': 2, 'CN': 3}.get(labelname, -1)
    if labelname == 'AD':
        return 0
    if labelname == 'CN':
        return 2
    return 1

def getDataRangeFromCenter(data_dims, random_range = None):
    """Compute crop bounds, relative to a centre voxel, for a patch of size ``data_dims``.

    Args:
        data_dims: patch size per axis (the first three entries are used).
        random_range: optional per-axis jitter widths. For a width ``r != 0``
            the patch centre is shifted by ``np.random.randint(r) - int(r / 2)``
            (drawn in x, y, z order). A width of 0 means no jitter on that axis.
            None disables jitter entirely.

    Returns:
        ``(dx1, dx2, dy1, dy2, dz1, dz2)`` to be added to the centre voxel
        coordinates; for integer sizes ``dx2 - dx1 == data_dims[0]`` etc.
    """
    shifts = [0, 0, 0]
    if random_range is not None:
        for axis in range(3):
            width = random_range[axis]
            # np.random.randint(0) raises; a zero width simply means "no jitter".
            if width != 0:
                shifts[axis] = np.random.randint(width) - int(width / 2)
    bounds = []
    for axis in range(3):
        half = int(data_dims[axis] / 2)
        lo = shifts[axis] - half
        hi = shifts[axis] + half
        if hi - lo != data_dims[axis]:
            lo -= 1  # odd size: the lower side takes the extra voxel
        bounds.extend((lo, hi))
    return tuple(bounds)

def getDataRangeFromCenter_insideEllipse(data_dims, center_voxel, ellipse_size):
    """Pick a random patch offset inside an ellipsoid and return its crop bounds.

    The effective region per axis is the largest extent centred on
    ``center_voxel`` that stays within ``[0, ellipse_size)``. The patch centre
    may move by an integer offset ``o`` per axis, drawn by rejection sampling
    until ``sum((o / r) ** 2) <= 1``, with ``r`` = (effective size - patch size) / 2.
    An axis where the patch leaves less than one voxel of play (``r < 0.5``)
    keeps offset 0 and is excluded from the ellipsoid test.

    Args:
        data_dims: patch size per axis (first three entries used).
        center_voxel: centre voxel coordinates (first three entries used).
        ellipse_size: extent of the enclosing region per axis.

    Returns:
        ``(dx1, dx2, dy1, dy2, dz1, dz2)`` relative to ``center_voxel``.

    Raises:
        ValueError: if the patch is larger than the effective region on some axis.
    """
    radii = []
    for axis in range(3):
        eff_size = min(ellipse_size[axis] - center_voxel[axis], center_voxel[axis]) * 2
        radius = (eff_size - data_dims[axis]) / 2
        if radius < 0:
            raise ValueError('patch size %s exceeds the effective region %s on axis %d'
                             % (data_dims[axis], eff_size, axis))
        radii.append(radius)

    offsets = [0, 0, 0]
    dist = 100
    while dist > 1:
        dist = 0
        for axis, radius in enumerate(radii):
            width = int(2 * radius)
            if width < 1:
                # No room to move (and radius may be 0): keep the patch centred.
                offsets[axis] = 0
                continue
            offsets[axis] = np.random.randint(width) - int(radius)
            dist += (offsets[axis] / radius) ** 2

    bounds = []
    for axis in range(3):
        half = int(data_dims[axis] / 2)
        lo = offsets[axis] - half
        hi = offsets[axis] + half
        if hi - lo != data_dims[axis]:
            lo -= 1  # odd size: the lower side takes the extra voxel
        bounds.extend((lo, hi))
    return tuple(bounds)

def cropImage(data, center_voxel, dx1, dx2, dy1, dy2, dz1, dz2, normalize):
    """Cut a box out of ``data`` around ``center_voxel``, optionally standardising it.

    The box spans ``center_voxel[a] + d1 : center_voxel[a] + d2`` on each of
    the first three axes. With ``normalize`` the crop is standardised using
    the mean and standard deviation of the WHOLE ``data`` tensor (not of the
    crop), with 1e-5 added to the std to avoid division by zero.
    """
    cx, cy, cz = center_voxel[0], center_voxel[1], center_voxel[2]
    patch = data[cx + dx1:cx + dx2,
                 cy + dy1:cy + dy2,
                 cz + dz1:cz + dz2]
    if not normalize:
        return patch
    return (patch - torch.mean(data)) / (torch.std(data) + 1e-5)

def randomFlip3d_xdir(data, is_tensor = True, is_forced = False):
    """Mirror a (C, D, H, W) volume along dim 1, always or with probability 1/2.

    Args:
        data: torch tensor of shape (C, D, H, W).
        is_tensor: also negate channels 1 and 2 of the flipped result
            (presumably the xy and xz components of a 6-channel tensor field).
        is_forced: flip unconditionally; otherwise flip when
            ``np.random.randint(2) == 1``. No random number is drawn when forced.

    Returns:
        A new flipped tensor, or ``data`` itself when no flip happens.
    """
    if not (is_forced or np.random.randint(2) == 1):
        return data
    mirrored = torch.flip(data, [1])
    if is_tensor:
        for ch in (1, 2):
            mirrored[ch] = -mirrored[ch]
    return mirrored

def randomFlip3d_ydir(data, is_tensor = True, is_forced = False):
    """Mirror a (C, D, H, W) volume along dim 2, always or with probability 1/2.

    Args:
        data: torch tensor of shape (C, D, H, W).
        is_tensor: also negate channels 1 and 4 of the flipped result
            (presumably the xy and yz components of a 6-channel tensor field).
        is_forced: flip unconditionally; otherwise flip when
            ``np.random.randint(2) == 1``. No random number is drawn when forced.

    Returns:
        A new flipped tensor, or ``data`` itself when no flip happens.
    """
    if not (is_forced or np.random.randint(2) == 1):
        return data
    mirrored = torch.flip(data, [2])
    if is_tensor:
        for ch in (1, 4):
            mirrored[ch] = -mirrored[ch]
    return mirrored
    

def randomFlip3d_zdir(data, is_tensor = True, is_forced = False):
    """Mirror a (C, D, H, W) volume along dim 3, always or with probability 1/2.

    Args:
        data: torch tensor of shape (C, D, H, W).
        is_tensor: also negate channels 2 and 4 of the flipped result
            (presumably the xz and yz components of a 6-channel tensor field).
        is_forced: flip unconditionally; otherwise flip when
            ``np.random.randint(2) == 1``. No random number is drawn when forced.

    Returns:
        A new flipped tensor, or ``data`` itself when no flip happens.
    """
    if not (is_forced or np.random.randint(2) == 1):
        return data
    mirrored = torch.flip(data, [3])
    if is_tensor:
        for ch in (2, 4):
            mirrored[ch] = -mirrored[ch]
    return mirrored

def randomRotX3d_cubicInput(data, is_tensor = True, angle = None):
    """Rotate a cubic (C, N, N, N) volume about dim 1 by 0, 90, 180 or 270 degrees.

    Without ``angle`` the rotation is chosen by ``np.random.randint(4)``.
    With ``angle`` in {90, 180, 270} that rotation is applied. A random index
    is drawn on every call, even when ``angle`` is given. Any other ``angle``
    returns ``data`` unchanged.

    With ``is_tensor`` the channels are also rotated as a symmetric tensor,
    T' = R T R^T, using the same rotation R as the voxel grid. This assumes
    the 6 channels are [xx, xy, xz, yy, yz, zz] with x, y, z = dims 1, 2, 3.
    Under that assumption a field like T(p) = p p^T stays invariant.

    Returns:
        A new rotated tensor, or ``data`` itself when no rotation is applied.
    """
    # data shape (C,D,H,W)
    randIdx = np.random.randint(4)
    if (randIdx == 1 and angle is None) or angle == 90:
        # rotate 90 deg: out[:, x, y, z] = data[:, x, N-1-z, y], i.e. R: (y, z) -> (z, -y)
        temp = data.permute(0,1,3,2)
        temp = torch.flip(temp,[3])
        if is_tensor:
            idx = [0,2,1,5,4,3]
            temp = temp[idx]
            # R T R^T: T'xy = Txz, T'xz = -Txy, T'yz = -Tyz
            temp[2] = -temp[2]
            temp[4] = -temp[4]
        return temp
    elif (randIdx == 2 and angle is None) or angle == 180:
        # rotate 180 deg: (y, z) -> (-y, -z), so xy and xz change sign
        temp = torch.flip(data, [2,3])
        if is_tensor:
            temp[1] = -temp[1]
            temp[2] = -temp[2]
        return temp
    elif (randIdx == 3 and angle is None) or angle == 270:
        # rotate 270 deg: out[:, x, y, z] = data[:, x, z, N-1-y], i.e. R: (y, z) -> (-z, y)
        temp = data.permute(0,1,3,2)
        temp = torch.flip(temp,[2])
        if is_tensor:
            idx = [0,2,1,5,4,3]
            temp = temp[idx]
            # R T R^T: T'xy = -Txz, T'xz = Txy, T'yz = -Tyz
            temp[1] = -temp[1]
            temp[4] = -temp[4]
        return temp
    return data

def randomRotY3d_cubicInput(data, is_tensor = True, angle = None):
    """Rotate a cubic (C, N, N, N) volume about dim 2 by 0, 90, 180 or 270 degrees.

    Without ``angle`` the rotation is chosen by ``np.random.randint(4)``.
    With ``angle`` in {90, 180, 270} that rotation is applied. A random index
    is drawn on every call, even when ``angle`` is given. Any other ``angle``
    returns ``data`` unchanged.

    With ``is_tensor`` the channels are also rotated as a symmetric tensor,
    T' = R T R^T, using the same rotation R as the voxel grid. This assumes
    the 6 channels are [xx, xy, xz, yy, yz, zz] with x, y, z = dims 1, 2, 3.
    Under that assumption a field like T(p) = p p^T stays invariant.

    Returns:
        A new rotated tensor, or ``data`` itself when no rotation is applied.
    """
    # data shape (C,D,H,W)
    randIdx = np.random.randint(4)
    if (randIdx == 1 and angle is None) or angle == 90:
        # rotate 90 deg: out[:, x, y, z] = data[:, z, y, N-1-x], i.e. R: (x, z) -> (-z, x)
        temp = data.permute(0,3,2,1)
        temp = torch.flip(temp,[1])
        if is_tensor:
            idx = [5,4,2,3,1,0]
            temp = temp[idx]
            # R T R^T: T'xy = -Tyz, T'xz = -Txz, T'yz = Txy
            temp[1] = -temp[1]
            temp[2] = -temp[2]
        return temp
    elif (randIdx == 2 and angle is None) or angle == 180:
        # rotate 180 deg: (x, z) -> (-x, -z), so xy and yz change sign
        temp = torch.flip(data, [1,3])
        if is_tensor:
            temp[1] = -temp[1]
            temp[4] = -temp[4]
        return temp
    elif (randIdx == 3 and angle is None) or angle == 270:
        # rotate 270 deg: out[:, x, y, z] = data[:, N-1-z, y, x], i.e. R: (x, z) -> (z, -x)
        temp = data.permute(0,3,2,1)
        temp = torch.flip(temp,[3])
        if is_tensor:
            idx = [5,4,2,3,1,0]
            temp = temp[idx]
            # R T R^T: T'xy = Tyz, T'xz = -Txz, T'yz = -Txy
            temp[2] = -temp[2]
            temp[4] = -temp[4]
        return temp
    return data

def randomRotZ3d_cubicInput(data, is_tensor = True, angle = None):
    """Rotate a cubic (C, N, N, N) volume about dim 3 by 0, 90, 180 or 270 degrees.

    Without ``angle`` the rotation is chosen by ``np.random.randint(4)``.
    With ``angle`` in {90, 180, 270} that rotation is applied. A random index
    is drawn on every call, even when ``angle`` is given. Any other ``angle``
    returns ``data`` unchanged.

    With ``is_tensor`` the channels are also rotated as a symmetric tensor,
    T' = R T R^T, using the same rotation R as the voxel grid. This assumes
    the 6 channels are [xx, xy, xz, yy, yz, zz] with x, y, z = dims 1, 2, 3.
    Under that assumption a field like T(p) = p p^T stays invariant.

    Returns:
        A new rotated tensor, or ``data`` itself when no rotation is applied.
    """
    # data shape (C,D,H,W)
    randIdx = np.random.randint(4)
    if (randIdx == 1 and angle is None) or angle == 90:
        # rotate 90 deg: out[:, x, y, z] = data[:, N-1-y, x, z], i.e. R: (x, y) -> (y, -x)
        temp = data.permute(0,2,1,3)
        temp = torch.flip(temp,[2])
        if is_tensor:
            idx = [3,1,4,0,2,5]
            temp = temp[idx]
            # R T R^T: T'xy = -Txy, T'xz = Tyz, T'yz = -Txz
            temp[1] = -temp[1]
            temp[4] = -temp[4]
        return temp
    elif (randIdx == 2 and angle is None) or angle == 180:
        # rotate 180 deg: (x, y) -> (-x, -y), so xz and yz change sign
        temp = torch.flip(data, [1,2])
        if is_tensor:
            temp[2] = -temp[2]
            temp[4] = -temp[4]
        return temp
    elif (randIdx == 3 and angle is None) or angle == 270:
        # rotate 270 deg: out[:, x, y, z] = data[:, y, N-1-x, z], i.e. R: (x, y) -> (-y, x)
        temp = data.permute(0,2,1,3)
        temp = torch.flip(temp,[1])
        if is_tensor:
            idx = [3,1,4,0,2,5]
            temp = temp[idx]
            # R T R^T: T'xy = -Txy, T'xz = -Tyz, T'yz = Txz
            temp[1] = -temp[1]
            temp[2] = -temp[2]
        return temp
    return data

def affineTransform3d_cuda(data, mat, device = torch.device('cuda')):
    """Resample a 3-D or 4-D volume through an affine map with trilinear interpolation.

    Every output voxel at integer coordinate (i, j, k) takes the value of
    ``data`` sampled at ``(mat @ [i, j, k, 1])[:3]``. Neighbours outside the
    volume contribute zero (zero padding); the remaining weights are not
    renormalised.

    Args:
        data: tensor of shape (X, Y, Z) or (X, Y, Z, C) on ``device``.
        mat: float tensor of shape (K, 4) with K >= 3 on ``device``; only the
            first three rows of the product are used. Presumably float32 to
            match the sampling grid — TODO confirm with callers.
        device: device holding ``data`` and ``mat``. The sampling grid and all
            work buffers are created there. Any torch device works, including CPU.

    Returns:
        float32 tensor of shape (X, Y, Z) or (X, Y, Z, C) on ``device``.
    """
    ### only consider 3D or 4D inputs
    ### data and mat are expected to live on `device`
    shape = data.shape
    x, y, z = torch.meshgrid([torch.arange(0, shape[0]), torch.arange(0, shape[1]), torch.arange(0, shape[2])])
    X = x.contiguous().view(1, -1).float()
    Y = y.contiguous().view(1, -1).float()
    Z = z.contiguous().view(1, -1).float()
    O = torch.ones(1, Z.shape[1])
    # `.to(device)` (rather than `.cuda(device=...)`) also accepts CPU devices.
    P = torch.cat((X, Y, Z, O), 0).to(device)
    AP = torch.mm(mat, P)

    floor = torch.floor(AP[0:3])
    weights_c = AP[0:3] - floor  # per-axis weight of the upper neighbour
    weights_f = torch.ones(weights_c.shape, dtype=torch.float32, device=device) - weights_c  # lower neighbour
    base = floor.long()  # (3, M) lower-corner indices

    for i in range(3):
        # Zero only axis i's weight where that neighbour lies outside the volume;
        # masking every axis here would also drop valid neighbours along the
        # other axes on the border slices.
        weights_c[i, base[i] < -1] = 0
        weights_c[i, base[i] > shape[i] - 2] = 0
        weights_f[i, base[i] < 0] = 0
        weights_f[i, base[i] > shape[i] - 1] = 0

    num_points = X.shape[1]
    if len(shape) == 3:
        output = torch.zeros(num_points, dtype=torch.float32, device=device)
    else:
        output = torch.zeros((num_points, shape[3]), dtype=torch.float32, device=device)

    # Visit the 8 cube corners; bit `axis` of `corner` selects the upper neighbour.
    for corner in range(8):
        weight = None
        idx = []
        for axis in range(3):
            upper = (corner >> axis) & 1
            w_axis = weights_c[axis] if upper else weights_f[axis]
            weight = w_axis if weight is None else weight * w_axis
            # Clamp so indexing stays legal; out-of-range corners already have zero weight.
            idx.append((base[axis] + upper).clamp(0, shape[axis] - 1))
        sample = data[idx[0], idx[1], idx[2]]
        if len(shape) == 3:
            output += weight * sample
        else:
            output += weight.view(-1, 1) * sample

    if len(shape) == 3:
        return output.view(shape[0], shape[1], shape[2])
    return output.view(shape[0], shape[1], shape[2], -1)

def affineTransform3d(data, mat):
    """Resample a 3-D or 4-D CPU volume through an affine map with trilinear interpolation.

    Every output voxel at integer coordinate (i, j, k) takes the value of
    ``data`` sampled at ``(mat @ [i, j, k, 1])[:3]``. Neighbours outside the
    volume contribute zero (zero padding); the remaining weights are not
    renormalised.

    Args:
        data: tensor of shape (X, Y, Z) or (X, Y, Z, C).
        mat: float tensor of shape (K, 4) with K >= 3; only the first three
            rows of the product are used. Presumably float32 to match the
            sampling grid — TODO confirm with callers.

    Returns:
        Tensor of shape (X, Y, Z) or (X, Y, Z, C).
    """
    ### only consider 3D or 4D inputs
    shape = data.shape
    x, y, z = torch.meshgrid([torch.arange(0, shape[0]), torch.arange(0, shape[1]), torch.arange(0, shape[2])])
    X = x.contiguous().view(1, -1).float()
    Y = y.contiguous().view(1, -1).float()
    Z = z.contiguous().view(1, -1).float()
    O = torch.ones(1, Z.shape[1])
    P = torch.cat((X, Y, Z, O), 0)
    AP = torch.mm(mat, P)

    floor = torch.floor(AP[0:3])
    weights_c = AP[0:3] - floor  # per-axis weight of the upper neighbour
    weights_f = torch.ones(weights_c.shape) - weights_c  # per-axis weight of the lower neighbour
    base = floor.long()  # (3, M) lower-corner indices

    for i in range(3):
        # Zero only axis i's weight where that neighbour lies outside the volume;
        # masking every axis here would also drop valid neighbours along the
        # other axes on the border slices.
        weights_c[i, base[i] < -1] = 0
        weights_c[i, base[i] > shape[i] - 2] = 0
        weights_f[i, base[i] < 0] = 0
        weights_f[i, base[i] > shape[i] - 1] = 0

    num_points = X.shape[1]
    if len(shape) == 3:
        output = torch.zeros(num_points)
    else:
        output = torch.zeros(num_points, shape[3])

    # Visit the 8 cube corners; bit `axis` of `corner` selects the upper neighbour.
    for corner in range(8):
        weight = None
        idx = []
        for axis in range(3):
            upper = (corner >> axis) & 1
            w_axis = weights_c[axis] if upper else weights_f[axis]
            weight = w_axis if weight is None else weight * w_axis
            # Clamp so indexing stays legal; out-of-range corners already have zero weight.
            idx.append((base[axis] + upper).clamp(0, shape[axis] - 1))
        sample = data[idx[0], idx[1], idx[2]]
        if len(shape) == 3:
            output += weight * sample
        else:
            output += weight.view(-1, 1) * sample

    if len(shape) == 3:
        return output.view(shape[0], shape[1], shape[2])
    return output.view(shape[0], shape[1], shape[2], -1)
"""    
def applyAffineGroupAction(tensor, mat):
    # apply A*T*A' for T in 6-dim vector form
    if tensor.is_cuda:
        output = torch.zeros(tensor.shape).cuda()
        
    else:
        output = torch.zeros(tensor.shape)
    # (0,0)
    output[:,0] = tensor[:,0] * mat[0,0] * mat[0,0] + tensor[:,3] * mat[0,1] * mat[1,0] + tensor[:,5] * mat[0,2] * mat[2,0] \
                  + 2 * (tensor[:,1] * mat[0,0] * mat[1,0] + tensor[:,2] * mat[0,0] * mat[2,0] + tensor[:,4] * mat[0,1] * mat[2,0])
    # (1,1)
    output[:,3] = tensor[:,0] * mat[1,0] * mat[0,1] + tensor[:,3] * mat[1,1] * mat[1,1] + tensor[:,5] * mat[1,2] * mat[2,1] \
                  + 2 * (tensor[:,1] * mat[1,0] * mat[1,1] + tensor[:,2] * mat[1,0] * mat[2,1] + tensor[:,4] * mat[1,1] * mat[2,1])
    # (2,2)
    output[:,5] = tensor[:,0] * mat[2,0] * mat[0,2] + tensor[:,3] * mat[2,1] * mat[1,2] + tensor[:,5] * mat[2,2] * mat[2,2] \
                  + 2 * (tensor[:,1] * mat[0,0] * mat[1,0] + tensor[:,2] * mat[0,0] * mat[2,0] + tensor[:,4] * mat[0,1] * mat[2,0])
    
    return output
"""

def plot_weights_2d(tensor, num_cols=6):
    """Show each 2-D slice ``tensor[i]`` as an image in a grid, then display the figure.

    Args:
        tensor: indexable with ``.shape[0]`` kernels; each ``tensor[i]`` must
            be something ``imshow`` accepts (e.g. an (H, W) array).
        num_cols: number of grid columns.
    """
    num_kernels = tensor.shape[0]
    # Ceil division so no empty trailing row is allocated when num_kernels is a
    # multiple of num_cols; keep at least one row for the figure size.
    num_rows = max(1, (num_kernels + num_cols - 1) // num_cols)
    fig = plt.figure(figsize=(num_cols, num_rows))
    for i in range(num_kernels):
        ax = fig.add_subplot(num_rows, num_cols, i + 1)
        ax.imshow(tensor[i])
        ax.axis('off')  # hides ticks, tick labels and the frame

    plt.subplots_adjust(wspace=0.1, hspace=0.1)
    plt.show()
    

def plot_vectorfield_3d(tensor, kernel_size = 5):
    """Draw a 3-D quiver plot of a vector field and display it.

    Args:
        tensor: torch tensor indexed as ``tensor[:, :, :, c]`` for c = 0, 1, 2
            (the u, v, w components). Presumably of shape
            (kernel_size, kernel_size, kernel_size, >=3) — TODO confirm.
        kernel_size: number of grid points per axis on the unit cube.

    All components are divided by the largest absolute value in ``tensor``.
    Successive z-slices are drawn in red, green and blue.
    """
    fig = plt.figure()
    # Figure.gca() no longer accepts a `projection` keyword (Matplotlib >= 3.6).
    ax = fig.add_subplot(111, projection='3d')

    # Default 'xy' indexing swaps the first two outputs; naming them (y, x, z)
    # makes x vary along axis 0, y along axis 1 and z along axis 2.
    y, x, z = np.meshgrid(np.linspace(0, 1, kernel_size),
                          np.linspace(0, 1, kernel_size),
                          np.linspace(0, 1, kernel_size))

    scale = torch.max(torch.abs(tensor))
    if scale == 0:
        scale = 1  # all-zero field: avoid 0/0 -> NaN arrows
    u = tensor[:, :, :, 0] / scale
    v = tensor[:, :, :, 1] / scale
    w = tensor[:, :, :, 2] / scale
    colors = ('r', 'g', 'b')
    for i in range(kernel_size):
        ax.quiver(x[:, :, i], y[:, :, i], z[:, :, i], u[:, :, i], v[:, :, i], w[:, :, i],
                  length=0.1, color=colors[i % 3])
    plt.show()