import nibabel as nib
import numpy as np
from skimage.measure import label, regionprops, marching_cubes
from skimage.morphology import binary_erosion, remove_small_objects

def surface_area(regionmask, *, spacing=(0.78125, 0.78125, 2.0)):
    """Return the surface area of a 3-D binary region mask.

    Used as a ``regionprops`` extra property, where it receives the region's
    mask cropped to its bounding box.

    The mask is zero-padded by one voxel on every side before running marching
    cubes at the 0.5 iso-level. Padding closes the mesh where the region
    touches the edge of its crop. The 0.5 level places the surface midway
    between foreground and background voxels instead of on background voxel
    centres.

    Parameters
    ----------
    regionmask : array_like
        Binary mask of the region (non-zero = foreground).
    spacing : tuple of 3 floats, keyword-only
        Voxel size along each axis. Defaults to the previously hard-coded
        ``(0.78125, 0.78125, 2.0)``. The returned area is in the squared
        units of this spacing.

    Returns
    -------
    float
        Total area of the triangulated iso-surface. Returns ``0`` when the
        mask is not 3-D, and ``0.0`` when it has no foreground voxels.
    """
    regionmask = np.asarray(regionmask)
    if regionmask.ndim != 3:
        return 0
    mask = regionmask.astype(bool)
    if not mask.any():
        # marching_cubes raises when the level lies outside the data range.
        return 0.0
    # One voxel of background on every side guarantees a closed surface and
    # also satisfies marching_cubes' minimum of 2 samples per axis.
    padded = np.pad(mask, 1, mode='constant', constant_values=False).astype(np.float32)
    verts, faces, _, _ = marching_cubes(padded, level=0.5, spacing=spacing)
    tri = verts[faces]  # (n_faces, 3 vertices, 3 coordinates)
    cross = np.cross(tri[:, 1] - tri[:, 0], tri[:, 2] - tri[:, 0])
    # Each triangle's area is half the norm of the cross product of two edges.
    return float(np.linalg.norm(cross, axis=1).sum() / 2)


def get_connected_components(image):
    """Label the connected components of a binary volume and measure them.

    Labelling first uses ``connectivity=3``; for a 3-D input that is the full
    26-neighbourhood. If that yields fewer than two components, the image is
    relabelled with ``connectivity=2`` (the 18-neighbourhood in 3-D). The
    stricter neighbourhood separates parts that touch only at a corner.

    Parameters
    ----------
    image : ndarray of bool
        Binary mask to label. It is expected to be 3-D, because
        ``connectivity=3`` is requested.

    Returns
    -------
    label_image : ndarray of int
        Label array produced by ``label``.
    regions : list
        ``regionprops`` entries for every component, each with an extra
        ``surface_area`` property.
    """
    labelled, n_found = label(image, connectivity=3, return_num=True)
    if n_found < 2:
        # Too few components with full connectivity: retry with a stricter
        # neighbourhood in the hope of splitting structures that only touch.
        print('Warning: less than 2 connected components')
        labelled, n_found = label(image, connectivity=2, return_num=True)
        print('Connected components:', n_found)
    props = regionprops(labelled, extra_properties=[surface_area])
    return labelled, props

def create_multichannel_mask(image):
    """Build a label mask separating left lung, right lung and liver.

    Lung voxels (``image == 1``) are split into connected components. The two
    components with the largest surface area are taken to be the lungs. Of
    those two, the one with more voxels (``area``) is labelled as the right
    lung.

    Parameters
    ----------
    image : ndarray
        Label volume in which 1 marks lung and 2 marks liver.

    Returns
    -------
    ndarray of uint8
        Same shape as ``image``, using 0 = background, 1 = left lung,
        2 = liver and 3 = right lung.

    Raises
    ------
    ValueError
        If fewer than two lung components are found, or if either selected
        component has a surface area below 1000.
    """
    liver_image = image == 2
    label_image, regions = get_connected_components(image == 1)
    sorted_regions = sorted(regions, key=lambda r: r.surface_area, reverse=True)

    if len(sorted_regions) < 2:
        raise ValueError("Expected at least two connected components for the lungs.")

    # Assume the two components with the largest surface area are the lungs.
    largest_region, second_largest_region = sorted_regions[0], sorted_regions[1]

    # Assign sides by voxel count: the larger lung is taken to be the right one.
    # NOTE(review): anatomical assumption (left lung usually smaller) — confirm
    # for the data this runs on.
    if largest_region.area > second_largest_region.area:
        right_region, left_region = largest_region, second_largest_region
    else:
        right_region, left_region = second_largest_region, largest_region

    # Reject tiny components that are unlikely to be real lungs.
    if left_region.surface_area < 1000 or right_region.surface_area < 1000:
        raise ValueError("Lung components are too small (surface area < 1000).")

    # Report the regions actually assigned to each side, not the surface-area order.
    print(f'Area of left lung: {left_region.surface_area}, right lung: {right_region.surface_area}')

    # Lung and liver voxels come from different values of `image`, so the
    # three assignments below never overlap.
    multichannel_mask = np.zeros(image.shape, dtype=np.uint8)
    multichannel_mask[label_image == left_region.label] = 1
    multichannel_mask[label_image == right_region.label] = 3
    multichannel_mask[liver_image] = 2

    return multichannel_mask





def create_multichannel_mask_with_erosion(image):
    """Build a left-lung / right-lung / liver mask, eroding lungs until they split.

    Lung voxels (``image == 1``) first have objects smaller than 500 voxels
    removed. If fewer than two connected components remain, the lung mask is
    repeatedly eroded until at least two components appear. The two
    components with the largest surface area are taken to be the lungs. Of
    those two, the one with more voxels (``area``) is labelled as the right
    lung.

    NOTE(review): the lung labels come from the (possibly eroded) mask, so
    voxels removed by erosion or by the small-object filter are left as
    background in the output.

    Parameters
    ----------
    image : ndarray
        Label volume in which 1 marks lung and 2 marks liver.

    Returns
    -------
    ndarray of uint8
        Same shape as ``image``, using 0 = background, 1 = left lung,
        2 = liver and 3 = right lung.

    Raises
    ------
    ValueError
        If erosion can no longer change the lung mask (for example, the mask
        is empty) while fewer than two components exist.
    """
    liver_image = image == 2
    lung_image = image == 1
    # remove small connected components
    lung_image = remove_small_objects(lung_image, 500)
    label_image, regions = get_connected_components(lung_image)
    sorted_regions = sorted(regions, key=lambda r: r.surface_area, reverse=True)

    c = 0
    while len(sorted_regions) < 2:
        # erode the lung_image
        print('Warning: less than 2 connected components ... Eroding the image', c)
        c += 1
        eroded = binary_erosion(lung_image)
        # Once erosion stops changing the mask (e.g. it is already empty),
        # further iterations cannot create new components; stop instead of
        # looping forever.
        if np.array_equal(eroded, lung_image):
            raise ValueError("Expected at least two connected components for the lungs.")
        lung_image = eroded
        label_image, regions = get_connected_components(lung_image)
        sorted_regions = sorted(regions, key=lambda r: r.surface_area, reverse=True)

    # Assume the two components with the largest surface area are the lungs.
    largest_region, second_largest_region = sorted_regions[0], sorted_regions[1]

    # Assign sides by voxel count: the larger lung is taken to be the right one.
    # NOTE(review): anatomical assumption (left lung usually smaller) — confirm
    # for the data this runs on.
    if largest_region.area > second_largest_region.area:
        right_region, left_region = largest_region, second_largest_region
    else:
        right_region, left_region = second_largest_region, largest_region

    # Report the regions actually assigned to each side, not the surface-area order.
    print(f'Area of left lung: {left_region.surface_area}, right lung: {right_region.surface_area}')

    multichannel_mask = np.zeros(image.shape, dtype=np.uint8)
    multichannel_mask[label_image == left_region.label] = 1
    multichannel_mask[label_image == right_region.label] = 3
    multichannel_mask[liver_image] = 2

    return multichannel_mask
