import numpy as np


def _paired_cells(matrix1, matrix2, fmt):
    """Return one list of ``fmt(val1, val2)`` strings per row of the paired matrices."""
    return [
        [fmt(val1, val2) for val1, val2 in zip(row1, row2)]
        for row1, row2 in zip(matrix1, matrix2)
    ]


def _row_label_digits(n_rows):
    """Width of the right-aligned number in ``Row  1``-style labels (at least 3)."""
    return max(3, len(str(n_rows)))


def _heat_symbol(value):
    """Map a value to a heatmap glyph.

    Negative values fall into the ``·`` bucket (they are < 0.001), and NaN
    fails every comparison so it is shown as ``■``.
    """
    if value == 0:
        return "."
    if value < 0.001:
        return "·"
    if value < 0.01:
        return "o"
    if value < 0.1:
        return "O"
    return "■"


def _format_box(matrix1, matrix2):
    """Lines for the ``box`` style: equal-width cells inside a frame that fits them."""
    n_cols = matrix1.shape[1]
    rows = _paired_cells(matrix1, matrix2, lambda a, b: f"{a:5.3f}({b:5.3f})")
    # Pad to the widest cell instead of assuming 12 chars, so values >= 10
    # (or negative ones) do not break the frame.
    width = max((len(cell) for row in rows for cell in row), default=0)
    # Inner frame width = width of any row's content plus one space of padding per side.
    inner = len(" ".join([" " * width] * n_cols)) + 2
    lines = ["┌" + "─" * inner + "┐"]
    lines += ["│ " + " ".join(cell.rjust(width) for cell in row) + " │" for row in rows]
    lines.append("└" + "─" * inner + "┘")
    return lines


def _format_simple(matrix1, matrix2):
    """Lines for the ``simple`` style: ``Row i: [a(b), ...]``."""
    rows = _paired_cells(matrix1, matrix2, lambda a, b: f"{a:.3f}({b:.3f})")
    return [f"Row {i}: [" + ", ".join(row) + "]" for i, row in enumerate(rows, start=1)]


def _format_table(matrix1, matrix2):
    """Lines for the ``table`` style: column headers, a rule, then labelled rows."""
    n_rows, n_cols = matrix1.shape
    rows = _paired_cells(matrix1, matrix2, lambda a, b: f"{a:5.3f}({b:5.3f})")
    headers = [f"Col{j + 1}" for j in range(n_cols)]
    # Headers and data share one column width so they line up.
    width = max(
        (len(text) for text in headers + [cell for row in rows for cell in row]),
        default=0,
    )
    digits = _row_label_digits(n_rows)
    # The header prefix is exactly as wide as a "Row<digits>" label.
    header = " " * (3 + digits) + "".join(" " + h.rjust(width) for h in headers)
    lines = [header, "-" * len(header)]
    for i, row in enumerate(rows, start=1):
        lines.append(f"Row{i:>{digits}}" + "".join(" " + cell.rjust(width) for cell in row))
    return lines


def _format_heatmap(matrix1, matrix2):
    """Lines for the ``heatmap`` style: glyph chosen from val1, val2 shown in parentheses."""
    n_rows, n_cols = matrix1.shape
    rows = _paired_cells(matrix1, matrix2, lambda a, b: f"{_heat_symbol(a)}({b:.2f})")
    headers = [f"Col{j + 1}" for j in range(n_cols)]
    width = max(
        (len(text) for text in headers + [cell for row in rows for cell in row]),
        default=0,
    )
    digits = _row_label_digits(n_rows)
    lines = [" " * (3 + digits) + "".join("  " + h.rjust(width) for h in headers)]
    for i, row in enumerate(rows, start=1):
        lines.append(f"Row{i:>{digits}}" + "".join("  " + cell.rjust(width) for cell in row))
    return lines


def _format_markdown(matrix1, matrix2):
    """Lines for the ``markdown`` style: a GitHub-flavoured Markdown table."""
    n_cols = matrix1.shape[1]
    rows = _paired_cells(matrix1, matrix2, lambda a, b: f" {a:.3f}({b:.3f}) |")
    lines = [
        "|" + "".join(f" Col{j + 1} |" for j in range(n_cols)),
        "|" + " --- |" * n_cols,
    ]
    lines += ["|" + "".join(row) for row in rows]
    return lines


def print_matrix(matrix1: np.ndarray, matrix2: np.ndarray, matrix_name: str, style: str = "box") -> str:
    """Print two same-shaped matrices side by side as ``a(b)`` cells.

    Each cell shows the element of ``matrix1`` followed by the matching element
    of ``matrix2`` in parentheses. The formatted text is printed and returned.

    Args:
        matrix1: 2-D array whose values are shown first (and drive the
            heatmap glyphs). Anything ``np.asarray`` accepts also works.
        matrix2: Array with the same shape as ``matrix1``; its values are
            shown in parentheses.
        matrix_name: Label used in the title line.
        style: One of ``"box"``, ``"simple"``, ``"table"``, ``"heatmap"``,
            ``"markdown"``.

    Returns:
        The printed text: a title line preceded by a blank line, then one
        newline-terminated line per output row.

    Raises:
        ValueError: If ``style`` is unknown, ``matrix1`` is not 2-D, or the
            two matrices have different shapes.
    """
    formatters = {
        "box": _format_box,
        "simple": _format_simple,
        "table": _format_table,
        "heatmap": _format_heatmap,
        "markdown": _format_markdown,
    }
    if style not in formatters:
        raise ValueError(f"Unknown style {style!r}; expected one of {sorted(formatters)}")

    matrix1 = np.asarray(matrix1)
    matrix2 = np.asarray(matrix2)
    if matrix1.ndim != 2:
        raise ValueError(f"matrix1 must be 2-D, got shape {matrix1.shape}")
    # zip() would otherwise silently drop unmatched rows/columns.
    if matrix2.shape != matrix1.shape:
        raise ValueError(
            f"matrix2 shape {matrix2.shape} does not match matrix1 shape {matrix1.shape}"
        )

    n_rows, n_cols = matrix1.shape
    msg = f"\n{matrix_name} Matrix ({n_rows}x{n_cols}):\n"
    msg += "".join(line + "\n" for line in formatters[style](matrix1, matrix2))
    print(msg)
    return msg