import os
import numpy as np
import torch
import time
import nibabel as nib
from tensor_data_util import Exp, posMetric_sqrt_func
from data_util import Log_vec2Log
from scipy.special import comb

#####################################################################################    
#                       FUNCTIONS TO EVALUATE AND SAVE RESULTS                      #
#####################################################################################

def calculate_ARI_using_CC(dataPoints, dx, clusterIds, groups_pointIdx, posMean = None, roi_range = None):
    """Compute the Adjusted Rand Index (ARI) of a clustering against a corpus callosum (CC) mask.

    Formula from Bart Moberts et al., 'Evaluation of Fiber Clustering Methods for
    Diffusion Tensor Imaging':

        ARI = (a - m1*m2/M) / (0.5*(m1 + m2) - m1*m2/M)

    The reference partition has two classes: points whose voxel lies inside the union
    of the three CC masks ('cc_back_warped.nii.gz', 'cc_body_warped.nii.gz',
    'cc_front_warped.nii.gz', read from the current working directory) and points
    whose voxel lies outside it.

    Args:
        dataPoints: tensor whose first three columns are positions; ``(position / dx).long()``
            is used as the voxel index into the cropped mask.
        dx: voxel spacing used to turn positions into voxel indices.
        clusterIds: indexable by group index ``k``; group ``k`` is counted only when
            ``clusterIds[k] > 0``.
        groups_pointIdx: sequence whose entry ``k`` lists the point indices of group ``k``.
        posMean: optional offset (moved with ``.cuda()``) added to the first three columns
            of a copy of ``dataPoints`` before indexing.
        roi_range: optional ``[x0, x1, y0, y1, z0, z1]`` sub-crop applied to the mask
            after the fixed ``[25:65, 30:85, 30:60]`` crop.

    Returns:
        ``(ARI, (M, m1, m2, a))`` where ``M`` = C(n, 2) over all counted points, ``m1`` sums
        C(row total, 2) over the inside/outside classes, ``m2`` sums C(group size, 2) over
        the counted groups, and ``a`` sums C(cell, 2) over the contingency table. The ARI
        can be nan or inf for degenerate inputs (e.g. no counted groups).
    """
    def _load_normalized_mask(path):
        # get_data() was removed in nibabel 5.0; get_fdata() returns the same scaled values.
        mask = torch.FloatTensor(nib.load(path).get_fdata()).cuda()
        return mask / torch.max(mask)

    dataPoints = dataPoints.clone()
    if posMean is not None:
        dataPoints[:, :3] += posMean.cuda()

    interestedROI = (_load_normalized_mask('cc_back_warped.nii.gz')
                     + _load_normalized_mask('cc_body_warped.nii.gz')
                     + _load_normalized_mask('cc_front_warped.nii.gz'))
    # Union of the three masks: overlapping voxels are clamped back to 1.
    interestedROI[interestedROI > 1] = 1
    interestedROI = interestedROI[25:65, 30:85, 30:60].long()
    if roi_range is not None:
        interestedROI = interestedROI[roi_range[0]:roi_range[1], roi_range[2]:roi_range[3], roi_range[4]:roi_range[5]]

    voxel = (dataPoints[:, :3] / dx).long()
    isInsideCC = interestedROI[voxel[:, 0], voxel[:, 1], voxel[:, 2]] != 0

    # One (inside, outside) column per counted group. Columns are collected in a list rather
    # than written into a pre-sized table, so the table can never be too narrow.
    columns = []
    m2 = 0
    for k, groupIdxSet in enumerate(groups_pointIdx):
        # NOTE(review): groups with clusterIds[k] == 0 are skipped here, while the save_*
        # functions in this file treat cluster ids >= 0 as valid. Confirm whether 0 should count.
        if len(groupIdxSet) > 0 and clusterIds[k] > 0:
            # Vectorized count (one GPU sync per group instead of one per point).
            n_inside = int(isInsideCC[torch.LongTensor(groupIdxSet)].sum())
            columns.append((n_inside, len(groupIdxSet) - n_inside))
            m2 += comb(len(groupIdxSet), 2)
    # Shape (2, number of counted groups): row 0 = inside CC, row 1 = outside CC.
    contingency_table = np.array(columns, dtype=float).reshape(-1, 2).T

    # Row-major accumulation, matching the cell-by-cell sum of the formula.
    a = 0
    for cell in contingency_table.ravel():
        a += comb(cell, 2)
    n = np.sum(contingency_table)
    M = comb(n, 2)
    m1 = 0.0
    for row in contingency_table:
        m1 += comb(np.sum(row), 2)
    ARI = (a - m1*m2/M) / (0.5*(m1+m2) - m1*m2/M)

    return ARI, (M, m1, m2, a)


def save_DTI_seg_result_file(data_dim, dataPoints, clusterIds, groups_pointIdx, dx, foldername, mat2 = None, 
                             posMean = None, min_len = None):
    """Write one binary NIfTI mask per cluster into ``foldername``.

    A file is written for every unique id in ``clusterIds`` that is >= 0 and whose group
    ``groups_pointIdx[id]`` holds more than ``min_len`` points (``min_len`` defaults to 0).
    Voxels at ``(position / dx).long()`` of the group's points are set to 1 in a
    ``data_dim[0] x data_dim[1] x data_dim[2]`` int64 volume. Each volume is saved as
    ``<foldername>/cluster_num_<id>_size_<group size>.nii.gz``.

    Args:
        data_dim: volume size; only the first three entries are used.
        dataPoints: tensor whose first three columns are positions.
        clusterIds: cluster ids; only their unique values are used.
        groups_pointIdx: indexable by cluster id, giving the point indices of that cluster.
        dx: voxel spacing used to turn positions into voxel indices.
        foldername: output directory, created if missing.
        mat2: 4x4 affine for the NIfTI header; a fixed default is used when None.
        posMean: optional offset (moved with ``.cuda()``) added to the first three columns
            of a copy of ``dataPoints``.
        min_len: groups with this many points or fewer are skipped.
    """
    if mat2 is None:
        mat2 = np.eye(4)
        mat2[0,0] = -2
        mat2[1,1] = 2
        mat2[2,2] = 2
        mat2[0,3] = 41
        mat2[1,3] = -67
        mat2[2,3] = -13
    dataPoints = dataPoints.clone()
    if posMean is not None:
        dataPoints[:,:3] += posMean.cuda()
    clusterIds_unique = np.unique(clusterIds)
    os.makedirs(foldername, exist_ok=True)
    if min_len is None:
        min_len = 0
    for cluster in clusterIds_unique:
        if cluster >= 0 and len(groups_pointIdx[cluster]) > min_len:
            # Allocate on the same device as the points. Indexing a CPU tensor with CUDA
            # index tensors is rejected by PyTorch.
            savedata = torch.zeros(data_dim[0], data_dim[1], data_dim[2],
                                   dtype=torch.long, device=dataPoints.device)
            clusterData = dataPoints[torch.LongTensor(groups_pointIdx[cluster])]
            savedata[(clusterData[:,0]/dx).long(), (clusterData[:,1]/dx).long(), (clusterData[:,2]/dx).long()] = 1
            # save as nifti image
            nib.save(nib.Nifti1Image(savedata.cpu().numpy(), mat2),
                     foldername + '/cluster' + '_num_' + str(cluster)
                     + '_size_' + str(len(groups_pointIdx[cluster])) + '.nii.gz')

def save_DTI_shiftedTensor_result_file(data_dim, dataPoints, shiftedDataPoints, 
                                       dx, savefilename, mat2 = None, posMean = None, use_logvec = True):
    """Scatter per-point tensor values into a 6-channel volume and save it as NIfTI.

    Values come from ``shiftedDataPoints[:, 3:]``. When ``use_logvec`` is true they are
    first mapped through ``Exp(Log_vec2Log(...))``. Each point's values are written at
    voxel ``(position / dx).long()`` of a CUDA volume of shape
    ``(data_dim[0], data_dim[1], data_dim[2], 6)``. The values presumably have shape
    (N, 6) to fill the 6 channels; TODO confirm.

    Args:
        data_dim: volume size; only the first three entries are used.
        dataPoints: tensor whose first three columns are positions.
        shiftedDataPoints: tensor whose columns from index 3 on hold the tensor values.
        dx: voxel spacing used to turn positions into voxel indices.
        savefilename: output NIfTI path.
        mat2: 4x4 affine for the NIfTI header; a fixed default is used when None.
        posMean: optional offset (moved with ``.cuda()``) added to the first three columns
            of a copy of ``dataPoints``.
        use_logvec: whether to convert the values with ``Exp(Log_vec2Log(...))``.

    Returns:
        The filled CUDA volume that was saved.
    """
    affine = mat2
    if affine is None:
        affine = np.array([[-2.0, 0.0, 0.0, 41.0],
                           [0.0, 2.0, 0.0, -67.0],
                           [0.0, 0.0, 2.0, -13.0],
                           [0.0, 0.0, 0.0, 1.0]])

    positions = dataPoints.clone()
    if posMean is not None:
        positions[:, :3] += posMean.cuda()

    raw_values = shiftedDataPoints[:, 3:]
    values = Exp(Log_vec2Log(raw_values)) if use_logvec else raw_values

    vx, vy, vz = [(positions[:, axis] / dx).long() for axis in range(3)]
    volume = torch.zeros(data_dim[0], data_dim[1], data_dim[2], 6).cuda()
    volume[vx, vy, vz, :] = values

    nib.save(nib.Nifti1Image(volume.cpu().numpy(), affine), savefilename)
    return volume

def save_DTI_shiftedTensor_result_file_2dim(data_dim, dataPoints, shiftedDataPoints, dx, savefilename, 
                                            project_suffix, slice_val, mat2 = None, posMean = None):
    """Embed 2-D per-point tensor values on one slice of a 6-channel volume and save it as NIfTI.

    Values are ``Exp(Log_vec2Log(shiftedDataPoints[:, 2:]))``, presumably (N, 3) to fill
    3 in-plane channels; TODO confirm. Point coordinates ``(u, v)`` come from the first two
    columns of ``dataPoints``, and the slice index is ``(slice_val / dx).long()``. The
    projection chooses the voxel order, the output channels receiving the 3 values, and a
    channel that receives the marker value 1e-5 at every written voxel:

        '_xy'      -> voxel (u, v, s), channels [0, 1, 3], marker channel 5
        '_yz'      -> voxel (s, u, v), channels [3, 4, 5], marker channel 0
        otherwise  -> voxel (u, s, v), channels [0, 2, 5], marker channel 3

    Args:
        data_dim: volume size; only the first three entries are used.
        dataPoints: tensor whose first two columns are in-plane positions.
        shiftedDataPoints: tensor whose columns from index 2 on hold log-vector values.
        dx: voxel spacing used to turn positions into voxel indices.
        savefilename: output NIfTI path.
        project_suffix: projection selector ('_xy', '_yz', anything else).
        slice_val: slice position; must support ``/ dx`` followed by ``.long()``.
        mat2: 4x4 affine for the NIfTI header; a fixed default is used when None.
        posMean: optional offset (moved with ``.cuda()``) added to the first two columns
            of a copy of ``dataPoints``.

    Returns:
        The filled CUDA volume of shape ``(data_dim[0], data_dim[1], data_dim[2], 6)``.
    """
    affine = mat2
    if affine is None:
        affine = np.array([[-2.0, 0.0, 0.0, 41.0],
                           [0.0, 2.0, 0.0, -67.0],
                           [0.0, 0.0, 2.0, -13.0],
                           [0.0, 0.0, 0.0, 1.0]])
    marker_value = 1e-5
    planar_values = Exp(Log_vec2Log(shiftedDataPoints[:, 2:]))

    dims = (data_dim[0], data_dim[1], data_dim[2])
    planar_volume = torch.zeros(*dims, 3).cuda()
    marker_volume = torch.zeros(*dims).cuda()
    volume = torch.zeros(*dims, 6).cuda()

    positions = dataPoints.clone()
    if posMean is not None:
        positions[:, :2] += posMean.cuda()

    u = (positions[:, 0] / dx).long()
    v = (positions[:, 1] / dx).long()
    s = (slice_val / dx).long()

    if project_suffix == '_xy':
        (i, j, k), channels, marker_channel = (u, v, s), [0, 1, 3], 5
    elif project_suffix == '_yz':
        (i, j, k), channels, marker_channel = (s, u, v), [3, 4, 5], 0
    else:
        (i, j, k), channels, marker_channel = (u, s, v), [0, 2, 5], 3

    planar_volume[i, j, k, :] = planar_values
    marker_volume[i, j, k] = marker_value

    for src, dst in enumerate(channels):
        volume[:, :, :, dst] = planar_volume[:, :, :, src]
    volume[:, :, :, marker_channel] = marker_volume

    nib.save(nib.Nifti1Image(volume.cpu().numpy(), affine), savefilename)
    return volume

def save_DTI_overlap_seg_result_file(data_dim, dataPoints, clusterIds, groups_pointIdx, dx, savefilename, 
                                     mat2 = None, posMean = None, roi_range = None, savefile = True):
    """Find clusters whose voxel masks overlap the corpus callosum (CC) mask, and optionally save them.

    The CC mask is the union of 'cc_back_warped.nii.gz', 'cc_body_warped.nii.gz' and
    'cc_front_warped.nii.gz' (read from the current working directory). Each mask is
    divided by its max, the union is clamped at 1, cropped to ``[25:65, 30:85, 30:60]``
    and cast to int64. If ``roi_range`` is given, the mask is further cropped and
    ``data_dim`` is replaced by the crop size.

    For every cluster id >= 0, voxels at ``(position / dx).long()`` of its points form a
    binary mask. A cluster "overlaps" when that mask intersects the CC mask.

    Args:
        data_dim: volume size (first three entries), overridden when ``roi_range`` is set.
        dataPoints: tensor whose first three columns are positions.
        clusterIds: cluster ids; only their unique values are used.
        groups_pointIdx: indexable by cluster id, giving the point indices of that cluster.
        dx: voxel spacing used to turn positions into voxel indices.
        savefilename: path for the summed mask of all overlapping clusters. The summed mask
            of tight clusters goes to the same path with its '.nii.gz' (or '.nii') suffix
            replaced by '_tight.nii.gz'.
        mat2: 4x4 affine for the NIfTI header; a fixed default is used when None.
        posMean: optional offset (moved with ``.cuda()``) added to the first three columns
            of a copy of ``dataPoints``.
        roi_range: optional ``[x0, x1, y0, y1, z0, z1]`` sub-crop of the CC mask.
        savefile: when False, nothing is written to disk.

    Returns:
        ``(overlapping_clusters, overlapping_score, overlapping_precision,
        overlapping_clusters_tight)``:
        - ``overlapping_clusters``: list of overlapping cluster ids.
        - ``overlapping_score``: 1-D CPU float tensor of intersection-over-union with the
          CC mask, one entry per overlapping cluster.
        - ``overlapping_precision``: 1-D CPU float tensor of intersection / cluster voxel
          count, one entry per overlapping cluster.
        - ``overlapping_clusters_tight``: ids of overlapping clusters with precision > 0.5.
        Both tensors are empty when no cluster overlaps.
    """
    if mat2 is None:
        mat2 = np.eye(4)
        mat2[0,0] = -2
        mat2[1,1] = 2
        mat2[2,2] = 2
        mat2[0,3] = 41
        mat2[1,3] = -67
        mat2[2,3] = -13
    dataPoints = dataPoints.clone()
    if posMean is not None:
        dataPoints[:,:3] += posMean.cuda()

    def _load_normalized_mask(path):
        # get_data() was removed in nibabel 5.0; get_fdata() returns the same scaled values.
        mask = torch.FloatTensor(nib.load(path).get_fdata()).cuda()
        return mask / torch.max(mask)

    interestedROI = (_load_normalized_mask('cc_back_warped.nii.gz')
                     + _load_normalized_mask('cc_body_warped.nii.gz')
                     + _load_normalized_mask('cc_front_warped.nii.gz'))
    # Union of the three masks: overlapping voxels are clamped back to 1.
    interestedROI[interestedROI > 1] = 1
    interestedROI = interestedROI[25:65, 30:85, 30:60].long()
    if roi_range is not None:
        data_dim = [roi_range[1] - roi_range[0], roi_range[3] - roi_range[2], roi_range[5] - roi_range[4]]
        interestedROI = interestedROI[roi_range[0]:roi_range[1], roi_range[2]:roi_range[3], roi_range[4]:roi_range[5]]
    # Loop-invariant: voxel count of the CC mask, used in every union.
    roi_voxels = torch.sum(interestedROI)

    savedata = torch.zeros(data_dim[0], data_dim[1], data_dim[2]).long().cuda()
    savedata2 = torch.zeros(data_dim[0], data_dim[1], data_dim[2]).long().cuda()
    overlapping_clusters = []
    overlapping_score = []
    overlapping_precision = []
    overlapping_clusters_tight = []
    for cluster in np.unique(clusterIds):
        if cluster < 0:
            continue
        tempdata = torch.zeros(data_dim[0], data_dim[1], data_dim[2]).long().cuda()
        clusterData = dataPoints[torch.LongTensor(groups_pointIdx[cluster])]
        tempdata[(clusterData[:,0]/dx).long(), (clusterData[:,1]/dx).long(), (clusterData[:,2]/dx).long()] = 1
        overlap = torch.sum(tempdata * interestedROI)
        if overlap > 0:
            savedata += tempdata
            overlapping_clusters.append(cluster)
            cluster_voxels = torch.sum(tempdata)
            intersection = overlap.cpu().view(1).float()
            union = (cluster_voxels + roi_voxels).cpu().view(1).float() - intersection
            overlapping_score.append(intersection/union)
            precision = intersection / cluster_voxels.cpu().view(1).float()
            overlapping_precision.append(precision)
            if precision > 0.5:
                overlapping_clusters_tight.append(cluster)
                savedata2 += tempdata

    # save as nifti image
    if savefile:
        nib.save(nib.Nifti1Image(savedata.cpu().numpy(), mat2), savefilename)
        # Strip the NIfTI suffix explicitly instead of assuming a 7-character '.nii.gz'.
        if savefilename.endswith('.nii.gz'):
            stem = savefilename[:-len('.nii.gz')]
        elif savefilename.endswith('.nii'):
            stem = savefilename[:-len('.nii')]
        else:
            stem = savefilename
        nib.save(nib.Nifti1Image(savedata2.cpu().numpy(), mat2), stem + '_tight.nii.gz')

    # torch.cat raises on an empty list, so return empty tensors when nothing overlaps.
    if overlapping_score:
        overlapping_score = torch.cat(overlapping_score, 0)
        overlapping_precision = torch.cat(overlapping_precision, 0)
    else:
        overlapping_score = torch.zeros(0)
        overlapping_precision = torch.zeros(0)
    return overlapping_clusters, overlapping_score, overlapping_precision, overlapping_clusters_tight