import os
import sys
import cv2
import numpy as np
from sklearn import base

def concat_subviews(base_path, view_subimg_pairs, concat_mode='horizontal'):
    """
    Crop square panels out of per-view image strips and concatenate them.

    Each view is read from ``{view_idx}.png`` inside ``base_path``. The file is
    treated as a horizontal strip of side-by-side panels whose width equals the
    strip's height. For every requested entry, the panel at ``subimg_idx`` is cut
    out and converted to BGRA. If a third element ``opacity_idx`` is present, the
    panel at that index (from the same file) is converted to a single channel
    and written into the alpha channel.

    Args:
    - base_path (str): Path to the directory containing the view images.
    - view_subimg_pairs (list of tuples): Each tuple is
      ``(view_idx, subimg_idx)`` or ``(view_idx, subimg_idx, opacity_idx)``.
    - concat_mode (str): 'horizontal' or 'vertical' concatenation of the final image.

    Returns:
    - result_img (numpy array): The concatenated BGRA image.

    Raises:
    - ValueError: if ``concat_mode`` is invalid, ``view_subimg_pairs`` is empty,
      or a panel index lies outside the image.
    - FileNotFoundError: if a view image cannot be read.
    """
    # Validate the mode before any file I/O so a typo fails fast.
    if concat_mode == 'horizontal':
        stack = np.hstack
    elif concat_mode == 'vertical':
        stack = np.vstack
    else:
        raise ValueError("concat_mode must be 'horizontal' or 'vertical'")

    if not view_subimg_pairs:
        raise ValueError("view_subimg_pairs must contain at least one entry")

    cropped_images = []
    for view in view_subimg_pairs:
        view_idx, subimg_idx = view[:2]

        img_path = os.path.join(base_path, f'{view_idx}.png')
        img = cv2.imread(img_path)
        if img is None:
            raise FileNotFoundError(f"Image not found at path: {img_path}")

        # Panels are square: the strip's height is the width of each panel.
        panel_width = img.shape[0]

        subimg = _to_bgra(_crop_panel(img, subimg_idx, panel_width, img_path))

        if len(view) > 2:
            opacity_map = _crop_panel(img, view[2], panel_width, img_path)
            subimg[:, :, 3] = _to_single_channel(opacity_map)

        cropped_images.append(subimg)

    return stack(cropped_images)


def _crop_panel(img, panel_idx, panel_width, img_path):
    """Return columns ``[panel_idx * panel_width, (panel_idx + 1) * panel_width)`` of ``img``."""
    x_start = panel_idx * panel_width
    x_end = x_start + panel_width
    # Plain slicing would silently return a narrower (or empty) array here.
    if panel_idx < 0 or x_end > img.shape[1]:
        raise ValueError(
            f"Subimage index {panel_idx} is out of range for {img_path} "
            f"(image width {img.shape[1]}, panel width {panel_width})"
        )
    return img[:, x_start:x_end]


def _to_bgra(subimg):
    """Return a new BGRA array built from a grayscale, BGR, or 4-channel panel."""
    if subimg.ndim == 2 or subimg.shape[2] == 1:
        return cv2.cvtColor(subimg, cv2.COLOR_GRAY2BGRA)
    if subimg.shape[2] == 3:
        return cv2.cvtColor(subimg, cv2.COLOR_BGR2BGRA)
    # Already 4 channels: copy so writing alpha does not modify the source strip.
    return subimg.copy()


def _to_single_channel(panel):
    """Reduce a panel to one channel so it can be used as an alpha mask."""
    if panel.ndim == 2:
        return panel
    if panel.shape[2] == 1:
        return panel[:, :, 0]
    if panel.shape[2] == 4:
        return cv2.cvtColor(panel, cv2.COLOR_BGRA2GRAY)
    return cv2.cvtColor(panel, cv2.COLOR_BGR2GRAY)

def concat_and_scale_images(subdir_path, horizontal_list, vertical_list):
    """
    Build a side-by-side composite from the view images in one directory.

    The panels selected by ``horizontal_list`` are laid out in a row. The panels
    selected by ``vertical_list`` are stacked into a column. The column is resized
    so its height equals the row's height, keeping its aspect ratio with the width
    truncated to an int, and is appended to the right of the row.

    Args:
    - subdir_path (str): Directory containing the ``{view_idx}.png`` files.
    - horizontal_list (list): Tuples passed to ``concat_subviews`` for the row.
    - vertical_list (list): Tuples passed to ``concat_subviews`` for the column.

    Returns:
    - numpy array: The row followed by the rescaled column.
    """
    row = concat_subviews(subdir_path, horizontal_list, concat_mode='horizontal')
    column = concat_subviews(subdir_path, vertical_list, concat_mode='vertical')

    target_height = row.shape[0]
    column_height, column_width = column.shape[:2]
    target_width = int(column_width * target_height / column_height)

    # cv2.resize takes dsize as (width, height).
    column_resized = cv2.resize(
        column,
        (target_width, target_height),
        interpolation=cv2.INTER_AREA,
    )

    return np.hstack((row, column_resized))

def create_image_grid(images, grid_size, gap_ratio):
    """
    Arrange equally sized images row-major into a grid on a zero-filled BGRA canvas.

    Args:
    - images (list of numpy arrays): Images that all share the first image's height
      and width. Grayscale (2-D or single-channel) and BGR inputs are converted to
      BGRA; 4-channel inputs are copied in unchanged.
    - grid_size (tuple): ``(rows, cols)`` of the grid.
    - gap_ratio (float): Spacing between cells as a fraction of the image height.
      The same pixel gap is used horizontally and vertically. A negative value
      makes neighbouring cells overlap, and must be greater than -1.

    Returns:
    - grid (numpy array): uint8 array of shape
      ``(rows * (h + gap) - gap, cols * (w + gap) - gap, 4)``. Unfilled cells stay
      all-zero. Where cells overlap, later images overwrite earlier ones (no alpha
      blending).

    Raises:
    - ValueError: if ``images`` is empty, contains more images than grid cells,
      contains images of differing size, or the gap collapses the cell pitch.
    """
    if not images:
        raise ValueError("images must contain at least one image")
    rows, cols = grid_size
    if len(images) > rows * cols:
        raise ValueError(f"{len(images)} images do not fit in a {rows}x{cols} grid")

    img_height, img_width = images[0].shape[:2]
    gap = int(img_height * gap_ratio)
    # A pitch of zero or less would stack every cell onto the previous one.
    if img_height + gap <= 0 or img_width + gap <= 0:
        raise ValueError(f"gap_ratio {gap_ratio} is too negative for {img_height}x{img_width} images")

    grid = np.zeros((rows * (img_height + gap) - gap, cols * (img_width + gap) - gap, 4), dtype=np.uint8)

    for idx, img in enumerate(images):
        if img.shape[:2] != (img_height, img_width):
            raise ValueError(
                f"image {idx} has size {img.shape[:2]}, expected {(img_height, img_width)}"
            )
        if img.ndim == 2 or img.shape[2] == 1:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGRA)
        elif img.shape[2] == 3:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2BGRA)
        row, col = divmod(idx, cols)
        y = row * (img_height + gap)
        x = col * (img_width + gap)
        grid[y:y + img_height, x:x + img_width] = img

    return grid

def process_all_subdirs(base_path, iteration, horizontal_pairs, vertical_pairs, ignore_list=None, grid_size=(2, 3), gap_ratio=0.05):
    """
    Build one composite per run directory under ``base_path`` and tile them into a grid.

    Immediate subdirectories of ``base_path`` are visited in sorted name order, so
    the grid layout is reproducible. A subdirectory is skipped when its name
    contains any string in ``ignore_list``. Every other subdirectory contributes
    one image, built by ``concat_and_scale_images`` from
    ``<subdir>/save/it{iteration}-test``. Collection stops once the grid is full.

    Args:
    - base_path (str): Directory whose subdirectories are the runs to compare.
    - iteration (int): Number used to form the ``it{iteration}-test`` directory name.
    - horizontal_pairs (list): Tuples for the horizontal strip of each composite.
    - vertical_pairs (list): Tuples for the vertical strip of each composite.
    - ignore_list (list of str, optional): Substrings of subdirectory names to skip.
    - grid_size (tuple): ``(rows, cols)`` of the output grid.
    - gap_ratio (float): Cell spacing passed to ``create_image_grid``.

    Returns:
    - numpy array: The grid image.

    Raises:
    - ValueError: if no subdirectory remains after filtering.
    """
    # None instead of a mutable [] default; behaves as an empty ignore list.
    ignore_list = [] if ignore_list is None else ignore_list
    capacity = grid_size[0] * grid_size[1]

    result_images = []
    for name in sorted(os.listdir(base_path)):
        subdir = os.path.join(base_path, name)
        if not os.path.isdir(subdir):
            continue
        # Match the directory name only. Matching the full path would skip every
        # run whenever base_path itself contains an ignore token.
        if any(ignore in name for ignore in ignore_list):
            continue
        subdir_path = os.path.join(subdir, 'save', f'it{iteration}-test')
        result_images.append(concat_and_scale_images(subdir_path, horizontal_pairs, vertical_pairs))
        if len(result_images) == capacity:
            break  # Limit to the number of images we can place in the grid

    if not result_images:
        raise ValueError(f"No usable subdirectories found in {base_path}")

    return create_image_grid(result_images, grid_size, gap_ratio)

# Root directory containing one subdirectory per run. It is hard-coded here;
# swap in the commented-out line below to take it from the command line instead.
base_path = "mass_runs/20240505_215321_Default_check_seeding/ours-inversion"
# base_path = sys.argv[1]

# Negative gap: adjacent grid cells overlap by int(0.2 * cell height) pixels.
gap_ratio = -0.2
# View indices, i.e. the "{view}.png" file names read by the layouts below.
view_1 = 0
view_2 = 20

# Each tuple is (view_idx, subimg_idx, opacity_idx): crop panel subimg_idx of
# "{view_idx}.png" and use panel opacity_idx of the same file as its alpha.
# Keep exactly one layout active; its suffix tags the output file name.

# Layout 1
# horizontal_pairs = [(view_1, 0, 2)]
# vertical_pairs   = [(view_1, 1, 2), (view_2, 0, 2), (view_2, 1, 2)]
# suffix=""

# Layout 2
# horizontal_pairs = [(view_1, 0, 2), (view_1, 1, 2)]
# vertical_pairs   = [(view_2, 0, 2), (view_2, 1, 2)]
# suffix="_l2"

# Layout 3
horizontal_pairs = [(view_1, 0, 2), (view_2, 0, 2)]
vertical_pairs = [(view_1, 1, 2), (view_2, 1, 2)]
suffix = "_l3"

concatenated_image = process_all_subdirs(
    base_path,
    8000,  # selects the "save/it8000-test" directory inside each run
    horizontal_pairs,
    vertical_pairs,
    gap_ratio=gap_ratio,
    ignore_list=["cello", "piano"],
)

out_path = os.path.join(base_path, "main_ours_results" + suffix + ".png")
# cv2.imwrite reports failure through its return value rather than raising;
# check it so a failed write is not followed by a misleading success message.
if not cv2.imwrite(out_path, concatenated_image):
    raise OSError(f"Failed to write image to: {out_path}")
print("Image saved to: ", out_path)
