{"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"pygments_lexer":"ipython3","nbconvert_exporter":"python","version":"3.6.4","file_extension":".py","codemirror_mode":{"name":"ipython","version":3},"name":"python","mimetype":"text/x-python"},"kaggle":{"accelerator":"gpu","dataSources":[{"sourceId":14236875,"sourceType":"datasetVersion","datasetId":9082981}],"dockerImageVersionId":31260,"isInternetEnabled":true,"language":"python","sourceType":"notebook","isGpuEnabled":true}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"code","source":"import os\nimport sys\nimport time\nimport json\nfrom glob import glob\nfrom pathlib import Path\n\nimport numpy as np\nimport cv2\nimport tensorflow as tf\nfrom tensorflow.keras.layers import (\n    Input, Conv2D, DepthwiseConv2D, GlobalAveragePooling2D,\n    Reshape, Multiply, BatchNormalization, Activation, Dropout,\n    Concatenate, Add, UpSampling2D, LayerNormalization, Dense, Layer,\n    MaxPooling2D, Conv2DTranspose, MultiHeadAttention, Lambda\n)\nfrom tensorflow.keras.models import Model\nfrom tensorflow.keras.regularizers import l2\nfrom tensorflow.keras.callbacks import (\n    ModelCheckpoint, EarlyStopping, CSVLogger, ReduceLROnPlateau\n)\nfrom tensorflow.keras import backend as K\nimport albumentations as A\nimport matplotlib.pyplot as plt\nimport pandas as pd\n\nos.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"2\"\n\n# ==============================================================================\n# CONFIGURATION\n# ==============================================================================\n\nclass Config:\n    \"\"\"Central configuration for training\"\"\"\n    \n    # ==================== MODEL SELECTION ====================\n    MODEL_NAME = \"VM-UNet\"  # Options: \"UNet\", \"UNetPlusPlus\", \"PraNet\", \"RAPUNet\", \n                             #          \"SwinUNet\", \"TransUNet\", \"UMamba\", \"DUCKNET\",\n                             #          \"MedSegNet-SSF\", \"nnUNet\", \"VM-UNet\", \"MedNeXt\"\n    \n    # ==================== GPU CONFIGURATION ====================\n    GPU_NUMBERS = [0]  # List of GPU indices to use, e.g., [0] or [0, 1]\n    \n    # ==================== DATA PATHS ====================\n    DATA_ROOT = \"/kaggle/input/endovis-17/BinarySegmentation\"\n    TRAIN_DIR = os.path.join(DATA_ROOT, \"train\")\n    VAL_DIR   = os.path.join(DATA_ROOT, \"val\")\n    TEST_DIR  = os.path.join(DATA_ROOT, \"test\")\n    \n    # ==================== MODEL ARCHITECTURE ====================\n    INPUT_SIZE  = 352\n    F1, F2, F3, F4, F5 = 24, 32, 64, 80, 128\n    \n    # MedSegNet-SSF Specific\n    USE_MRF_SE = True\n    USE_SSTM   = True\n    USE_BFP    = True\n    MRF_KERNELS = [3, 5, 7]\n    SE_REDUCTION = 16\n    EXPAND_RATIO = 6\n    SSTM_NUM_FREQUENCIES = 32\n    SSTM_SSM_STATE_DIM = 16\n    SSTM_USE_SPECTRAL = [True, True, True, True, True]\n    SSTM_USE_SSM = [False, False, True, True, True]\n    SSTM_DROPOUT = 0.1\n    \n    # Transformer-based models\n    NUM_HEADS = 4\n    TRANSFORMER_LAYERS = 6\n    PATCH_SIZE = 16\n    \n    # DUCK-Net specific\n    DUCKNET_FILTERS = [17, 34, 68, 136, 272]\n    DUCKNET_DILATION_RATES = [1, 3, 5]\n    \n    # Universal parameters\n    DROPOUT = 0.1\n    L2_REG = 1e-4\n\n    # ==================== TRAINING SETTINGS ====================\n    BATCH_SIZE    = 4\n    EPOCH_EXPANSION_FACTOR = 30  # Virtual epoch multiplier\n    EPOCHS        = 30\n    LEARNING_RATE = 1e-4\n    \n    EARLY_STOPPING_PATIENCE = 40\n    CHECKPOINT_MONITOR      = \"val_dice_coefficient\"\n    CHECKPOINT_MODE         = \"max\"\n    \n    # Loss function\n    USE_MASL = False  # Set to True to use MASL, False for binary_crossentropy\n\n    SEED          = 42\n    DETERMINISTIC = False\n\n    def __init__(self):\n        self.SAVE_DIR = f\"{self.MODEL_NAME}_OUTPUT\"\n        os.makedirs(self.SAVE_DIR, exist_ok=True)\n        print(f\"🔥 TRAINING CONFIGURATION\")\n        print(f\"   Model: {self.MODEL_NAME}\")\n        print(f\"   Virtual Epoch Factor: {self.EPOCH_EXPANSION_FACTOR}x\")\n        print(f\"   Loss Function: {'MASL' if self.USE_MASL else 'Binary Crossentropy'}\")\n\nconfig = Config()\n\n# ==============================================================================\n# GPU SETUP\n# ==============================================================================\n\ndef setup_gpus(gpu_numbers=None):\n    \"\"\"Configure GPUs based on specified GPU numbers\"\"\"\n    gpus = tf.config.list_physical_devices('GPU')\n    \n    if not gpus:\n        print(\"⚠️ No GPUs found! Using CPU.\")\n        return tf.distribute.get_strategy(), 0\n    \n    print(f\"🔍 Total GPUs available: {len(gpus)}\")\n    for i, gpu in enumerate(gpus):\n        print(f\"   GPU {i}: {gpu.name}\")\n    \n    if gpu_numbers is not None:\n        invalid_gpus = [g for g in gpu_numbers if g >= len(gpus)]\n        if invalid_gpus:\n            raise ValueError(f\"Invalid GPU numbers: {invalid_gpus}\")\n        \n        selected_gpus = [gpus[i] for i in gpu_numbers]\n        print(f\"\\n✅ Using selected GPUs: {gpu_numbers}\")\n    else:\n        selected_gpus = gpus\n        print(f\"\\n✅ Using all available GPUs\")\n    \n    try:\n        tf.config.set_visible_devices(selected_gpus, 'GPU')\n        \n        for gpu in selected_gpus:\n            tf.config.experimental.set_memory_growth(gpu, True)\n        \n        num_gpus = len(selected_gpus)\n        \n        if num_gpus > 1:\n            strategy = tf.distribute.MirroredStrategy()\n            print(f\"🚀 Multi-GPU Training: {num_gpus} GPUs\")\n            print(f\"   Effective batch size: {config.BATCH_SIZE * num_gpus}\")\n        else:\n            strategy = tf.distribute.get_strategy()\n            print(f\"📍 Single GPU Training\")\n        \n        return strategy, num_gpus\n        \n    except RuntimeError as e:\n        print(f\"⚠️ GPU setup error: {e}\")\n        return tf.distribute.get_strategy(), 0\n\nstrategy, num_gpus = setup_gpus(config.GPU_NUMBERS)\n\n# ==============================================================================\n# UTILS & AUGMENTATION\n# ==============================================================================\n\ndef set_seed(seed=42, deterministic=False):\n    \"\"\"Set random seeds for reproducibility\"\"\"\n    import random\n    np.random.seed(seed)\n    random.seed(seed)\n    tf.random.set_seed(seed)\n    os.environ['PYTHONHASHSEED'] = str(seed)\n\ndef get_image_mask_pairs(images_dir, masks_dir):\n    \"\"\"Find and pair images with their corresponding masks\"\"\"\n    image_extensions = ['*.png', '*.jpg', '*.jpeg', '*.tif', '*.bmp']\n    image_files = []\n    for ext in image_extensions:\n        image_files.extend(glob(os.path.join(images_dir, ext)))\n    image_files = sorted(image_files)\n    \n    if len(image_files) == 0:\n        print(f\"⚠️ No images found in {images_dir}\")\n        return []\n\n    pairs = []\n    for img_path in image_files:\n        img_name = Path(img_path).stem\n        possible_names = [\n            f\"{img_name}.png\", f\"{img_name}.jpg\", f\"{img_name}.tif\", f\"{img_name}.bmp\", \n            f\"{img_name}_mask.png\", f\"{img_name}_mask.jpg\"\n        ]\n        for mask_name in possible_names:\n            cand = os.path.join(masks_dir, mask_name)\n            if os.path.exists(cand):\n                pairs.append((img_path, cand))\n                break\n    return pairs\n\ndef load_dataset_split(split_dir):\n    \"\"\"Load image-mask pairs from a dataset split\"\"\"\n    images_dir = os.path.join(split_dir, \"images\")\n    masks_dir  = os.path.join(split_dir, \"masks\")\n    return get_image_mask_pairs(images_dir, masks_dir)\n\ndef get_training_augmentation(cfg):\n    \"\"\"Aggressive training augmentation pipeline\"\"\"\n    return A.Compose([\n        A.HorizontalFlip(p=0.5),\n        A.VerticalFlip(p=0.5),\n        A.ShiftScaleRotate(\n            shift_limit=0.0625, \n            scale_limit=0.2, \n            rotate_limit=180, \n            border_mode=cv2.BORDER_CONSTANT,\n            p=1.0\n        ),\n        A.ColorJitter(brightness=0.4, contrast=0.2, saturation=0.1, hue=0.01, p=1.0),\n        A.Resize(height=cfg.INPUT_SIZE, width=cfg.INPUT_SIZE),\n    ], p=1.0)\n\ndef get_validation_augmentation(cfg):\n    \"\"\"Validation augmentation (resize only)\"\"\"\n    return A.Compose([A.Resize(cfg.INPUT_SIZE, cfg.INPUT_SIZE)])\n\n# ==============================================================================\n# DATA GENERATOR\n# ==============================================================================\n\nclass ExpandedDataGenerator(tf.keras.utils.Sequence):\n    \"\"\"Data generator with epoch expansion for longer virtual epochs\"\"\"\n    \n    def __init__(self, pairs, cfg, augmentation=None, shuffle=True, expansion_factor=1):\n        self.pairs = pairs\n        self.cfg = cfg\n        self.augmentation = augmentation\n        self.shuffle = shuffle\n        self.expansion_factor = expansion_factor\n        self.indices = np.arange(len(self.pairs))\n        \n        self.real_batches = len(self.pairs) // self.cfg.BATCH_SIZE\n        self.virtual_batches = self.real_batches * self.expansion_factor\n        \n        if self.shuffle:\n            np.random.shuffle(self.indices)\n            \n    def __len__(self):\n        return self.virtual_batches\n\n    def __getitem__(self, index):\n        real_index_ptr = index % self.real_batches\n        \n        batch_start = real_index_ptr * self.cfg.BATCH_SIZE\n        batch_end = batch_start + self.cfg.BATCH_SIZE\n        batch_indices = self.indices[batch_start:batch_end]\n        \n        images, masks = [], []\n        \n        for idx in batch_indices:\n            img_path, mask_path = self.pairs[idx]\n            \n            image = cv2.imread(img_path)\n            if image is None: \n                continue\n            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n            \n            mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)\n            mask = (mask > 127).astype(np.float32)\n            \n            if self.augmentation:\n                augmented = self.augmentation(image=image, mask=mask)\n                image = augmented[\"image\"]\n                mask = augmented[\"mask\"]\n            \n            image = image.astype(np.float32) / 255.0\n            if len(mask.shape) == 2:\n                mask = np.expand_dims(mask, axis=-1)\n            \n            images.append(image)\n            masks.append(mask)\n            \n        return np.array(images, dtype=np.float32), np.array(masks, dtype=np.float32)\n\n    def on_epoch_end(self):\n        if self.shuffle:\n            np.random.shuffle(self.indices)\n\n# ==============================================================================\n# COMMON BUILDING BLOCKS\n# ==============================================================================\n\ndef conv_block(x, filters, kernel_size=3, activation='relu', use_bn=True, dropout=0.0, name='conv'):\n    \"\"\"Standard convolutional block\"\"\"\n    x = Conv2D(filters, kernel_size, padding='same', \n               kernel_initializer='he_normal', name=f'{name}_conv')(x)\n    if use_bn:\n        x = BatchNormalization(name=f'{name}_bn')(x)\n    x = Activation(activation, name=f'{name}_act')(x)\n    if dropout > 0:\n        x = Dropout(dropout, name=f'{name}_dropout')(x)\n    return x\n\ndef attention_gate(g, x, filters, name='att'):\n    \"\"\"Attention gate for skip connections\"\"\"\n    g_conv = Conv2D(filters, 1, padding='same', kernel_initializer='he_normal', name=f'{name}_g')(g)\n    x_conv = Conv2D(filters, 1, padding='same', kernel_initializer='he_normal', name=f'{name}_x')(x)\n    \n    psi = Activation('relu', name=f'{name}_relu')(Add(name=f'{name}_add')([g_conv, x_conv]))\n    psi = Conv2D(1, 1, padding='same', kernel_initializer='he_normal', name=f'{name}_psi')(psi)\n    psi = Activation('sigmoid', name=f'{name}_sigmoid')(psi)\n    \n    return Multiply(name=f'{name}_mult')([x, psi])\n\n# ==============================================================================\n# MODEL 1: U-Net\n# ==============================================================================\n\ndef build_unet(cfg):\n    \"\"\"Classic U-Net architecture\"\"\"\n    print(\"\\n\" + \"=\"*80)\n    print(\"🔥 BUILDING U-NET\")\n    print(\"=\"*80)\n    \n    inputs = Input((cfg.INPUT_SIZE, cfg.INPUT_SIZE, 3), name=\"input\")\n    \n    # Encoder\n    c1 = conv_block(inputs, 64, name='enc1_1')\n    c1 = conv_block(c1, 64, name='enc1_2')\n    p1 = MaxPooling2D((2, 2), name='pool1')(c1)\n    \n    c2 = conv_block(p1, 128, name='enc2_1')\n    c2 = conv_block(c2, 128, name='enc2_2')\n    p2 = MaxPooling2D((2, 2), name='pool2')(c2)\n    \n    c3 = conv_block(p2, 256, name='enc3_1')\n    c3 = conv_block(c3, 256, name='enc3_2')\n    p3 = MaxPooling2D((2, 2), name='pool3')(c3)\n    \n    c4 = conv_block(p3, 512, name='enc4_1')\n    c4 = conv_block(c4, 512, name='enc4_2')\n    p4 = MaxPooling2D((2, 2), name='pool4')(c4)\n    \n    # Bottleneck\n    c5 = conv_block(p4, 1024, name='bottleneck_1')\n    c5 = conv_block(c5, 1024, name='bottleneck_2')\n    \n    # Decoder\n    u6 = Conv2DTranspose(512, 2, strides=2, padding='same', name='up6')(c5)\n    u6 = Concatenate(name='concat6')([u6, c4])\n    c6 = conv_block(u6, 512, name='dec6_1')\n    c6 = conv_block(c6, 512, name='dec6_2')\n    \n    u7 = Conv2DTranspose(256, 2, strides=2, padding='same', name='up7')(c6)\n    u7 = Concatenate(name='concat7')([u7, c3])\n    c7 = conv_block(u7, 256, name='dec7_1')\n    c7 = conv_block(c7, 256, name='dec7_2')\n    \n    u8 = Conv2DTranspose(128, 2, strides=2, padding='same', name='up8')(c7)\n    u8 = Concatenate(name='concat8')([u8, c2])\n    c8 = conv_block(u8, 128, name='dec8_1')\n    c8 = conv_block(c8, 128, name='dec8_2')\n    \n    u9 = Conv2DTranspose(64, 2, strides=2, padding='same', name='up9')(c8)\n    u9 = Concatenate(name='concat9')([u9, c1])\n    c9 = conv_block(u9, 64, name='dec9_1')\n    c9 = conv_block(c9, 64, name='dec9_2')\n    \n    outputs = Conv2D(1, 1, activation='sigmoid', name='output')(c9)\n    \n    model = Model(inputs, outputs, name='UNet')\n    print(f\"\\nTotal parameters: {model.count_params():,}\")\n    print(\"=\"*80 + \"\\n\")\n    return model\n\n# ==============================================================================\n# MODEL 2: UNet++\n# ==============================================================================\n\ndef build_unetplusplus(cfg):\n    \"\"\"UNet++ with nested skip pathways\"\"\"\n    print(\"\\n\" + \"=\"*80)\n    print(\"🔥 BUILDING UNET++\")\n    print(\"=\"*80)\n    \n    inputs = Input((cfg.INPUT_SIZE, cfg.INPUT_SIZE, 3), name=\"input\")\n    \n    filters = [32, 64, 128, 256, 512]\n    \n    # Encoder\n    x00 = conv_block(inputs, filters[0], name='x00_1')\n    x00 = conv_block(x00, filters[0], name='x00_2')\n    p0 = MaxPooling2D((2, 2), name='pool0')(x00)\n    \n    x10 = conv_block(p0, filters[1], name='x10_1')\n    x10 = conv_block(x10, filters[1], name='x10_2')\n    p1 = MaxPooling2D((2, 2), name='pool1')(x10)\n    \n    x01 = Conv2DTranspose(filters[0], 2, strides=2, padding='same', name='up01')(x10)\n    x01 = Concatenate(name='concat01')([x00, x01])\n    x01 = conv_block(x01, filters[0], name='x01_1')\n    x01 = conv_block(x01, filters[0], name='x01_2')\n    \n    x20 = conv_block(p1, filters[2], name='x20_1')\n    x20 = conv_block(x20, filters[2], name='x20_2')\n    p2 = MaxPooling2D((2, 2), name='pool2')(x20)\n    \n    x11 = Conv2DTranspose(filters[1], 2, strides=2, padding='same', name='up11')(x20)\n    x11 = Concatenate(name='concat11')([x10, x11])\n    x11 = conv_block(x11, filters[1], name='x11_1')\n    x11 = conv_block(x11, filters[1], name='x11_2')\n    \n    x02 = Conv2DTranspose(filters[0], 2, strides=2, padding='same', name='up02')(x11)\n    x02 = Concatenate(name='concat02')([x00, x01, x02])\n    x02 = conv_block(x02, filters[0], name='x02_1')\n    x02 = conv_block(x02, filters[0], name='x02_2')\n    \n    x30 = conv_block(p2, filters[3], name='x30_1')\n    x30 = conv_block(x30, filters[3], name='x30_2')\n    p3 = MaxPooling2D((2, 2), name='pool3')(x30)\n    \n    x21 = Conv2DTranspose(filters[2], 2, strides=2, padding='same', name='up21')(x30)\n    x21 = Concatenate(name='concat21')([x20, x21])\n    x21 = conv_block(x21, filters[2], name='x21_1')\n    x21 = conv_block(x21, filters[2], name='x21_2')\n    \n    x12 = Conv2DTranspose(filters[1], 2, strides=2, padding='same', name='up12')(x21)\n    x12 = Concatenate(name='concat12')([x10, x11, x12])\n    x12 = conv_block(x12, filters[1], name='x12_1')\n    x12 = conv_block(x12, filters[1], name='x12_2')\n    \n    x03 = Conv2DTranspose(filters[0], 2, strides=2, padding='same', name='up03')(x12)\n    x03 = Concatenate(name='concat03')([x00, x01, x02, x03])\n    x03 = conv_block(x03, filters[0], name='x03_1')\n    x03 = conv_block(x03, filters[0], name='x03_2')\n    \n    # Bottleneck\n    x40 = conv_block(p3, filters[4], name='x40_1')\n    x40 = conv_block(x40, filters[4], name='x40_2')\n    \n    # Decoder\n    x31 = Conv2DTranspose(filters[3], 2, strides=2, padding='same', name='up31')(x40)\n    x31 = Concatenate(name='concat31')([x30, x31])\n    x31 = conv_block(x31, filters[3], name='x31_1')\n    x31 = conv_block(x31, filters[3], name='x31_2')\n    \n    x22 = Conv2DTranspose(filters[2], 2, strides=2, padding='same', name='up22')(x31)\n    x22 = Concatenate(name='concat22')([x20, x21, x22])\n    x22 = conv_block(x22, filters[2], name='x22_1')\n    x22 = conv_block(x22, filters[2], name='x22_2')\n    \n    x13 = Conv2DTranspose(filters[1], 2, strides=2, padding='same', name='up13')(x22)\n    x13 = Concatenate(name='concat13')([x10, x11, x12, x13])\n    x13 = conv_block(x13, filters[1], name='x13_1')\n    x13 = conv_block(x13, filters[1], name='x13_2')\n    \n    x04 = Conv2DTranspose(filters[0], 2, strides=2, padding='same', name='up04')(x13)\n    x04 = Concatenate(name='concat04')([x00, x01, x02, x03, x04])\n    x04 = conv_block(x04, filters[0], name='x04_1')\n    x04 = conv_block(x04, filters[0], name='x04_2')\n    \n    outputs = Conv2D(1, 1, activation='sigmoid', name='output')(x04)\n    \n    model = Model(inputs, outputs, name='UNetPlusPlus')\n    print(f\"\\nTotal parameters: {model.count_params():,}\")\n    print(\"=\"*80 + \"\\n\")\n    return model\n\n# ==============================================================================\n# MODEL 3: PraNet (Parallel Reverse Attention Network)\n# ==============================================================================\n\ndef rfb_module(x, filters, name='rfb'):\n    \"\"\"Receptive Field Block\"\"\"\n    branch1 = Conv2D(filters//4, 1, padding='same', name=f'{name}_b1_1')(x)\n    branch1 = Conv2D(filters//4, 3, padding='same', dilation_rate=1, name=f'{name}_b1_2')(branch1)\n    \n    branch2 = Conv2D(filters//4, 1, padding='same', name=f'{name}_b2_1')(x)\n    branch2 = Conv2D(filters//4, 3, padding='same', dilation_rate=3, name=f'{name}_b2_2')(branch2)\n    \n    branch3 = Conv2D(filters//4, 1, padding='same', name=f'{name}_b3_1')(x)\n    branch3 = Conv2D(filters//4, 3, padding='same', dilation_rate=5, name=f'{name}_b3_2')(branch3)\n    \n    branch4 = Conv2D(filters//4, 1, padding='same', name=f'{name}_b4')(x)\n    \n    concat = Concatenate(name=f'{name}_concat')([branch1, branch2, branch3, branch4])\n    output = Conv2D(filters, 1, padding='same', name=f'{name}_fuse')(concat)\n    output = BatchNormalization(name=f'{name}_bn')(output)\n    output = Activation('relu', name=f'{name}_act')(output)\n    \n    return output\n\ndef reverse_attention(x, filters, name='ra'):\n    \"\"\"Reverse Attention Module\"\"\"\n    attention = GlobalAveragePooling2D(name=f'{name}_gap')(x)\n    attention = Dense(filters//16, activation='relu', name=f'{name}_fc1')(attention)\n    attention = Dense(filters, activation='sigmoid', name=f'{name}_fc2')(attention)\n    attention = Reshape((1, 1, filters), name=f'{name}_reshape')(attention)\n    \n    x_att = Multiply(name=f'{name}_mult')([x, attention])\n    x_reverse = Lambda(lambda x: 1 - x, name=f'{name}_reverse')(attention)\n    x_rev_att = Multiply(name=f'{name}_rev_mult')([x, x_reverse])\n    \n    return Add(name=f'{name}_add')([x_att, x_rev_att])\n\ndef build_pranet(cfg):\n    \"\"\"PraNet: Parallel Reverse Attention Network\"\"\"\n    print(\"\\n\" + \"=\"*80)\n    print(\"🔥 BUILDING PRANET\")\n    print(\"=\"*80)\n    \n    inputs = Input((cfg.INPUT_SIZE, cfg.INPUT_SIZE, 3), name=\"input\")\n    \n    # Encoder with RFB modules\n    x = conv_block(inputs, 64, name='enc1')\n    x = MaxPooling2D((2, 2), name='pool1')(x)\n    e1 = rfb_module(x, 64, name='rfb1')\n    \n    x = conv_block(e1, 128, name='enc2')\n    x = MaxPooling2D((2, 2), name='pool2')(x)\n    e2 = rfb_module(x, 128, name='rfb2')\n    \n    x = conv_block(e2, 256, name='enc3')\n    x = MaxPooling2D((2, 2), name='pool3')(x)\n    e3 = rfb_module(x, 256, name='rfb3')\n    \n    x = conv_block(e3, 512, name='enc4')\n    x = MaxPooling2D((2, 2), name='pool4')(x)\n    e4 = rfb_module(x, 512, name='rfb4')\n    \n    # Decoder with reverse attention\n    d4 = reverse_attention(e4, 512, name='ra4')\n    d4 = UpSampling2D((2, 2), name='up4')(d4)\n    d4 = Concatenate(name='concat4')([d4, e3])\n    d4 = conv_block(d4, 256, name='dec4')\n    \n    d3 = reverse_attention(d4, 256, name='ra3')\n    d3 = UpSampling2D((2, 2), name='up3')(d3)\n    d3 = Concatenate(name='concat3')([d3, e2])\n    d3 = conv_block(d3, 128, name='dec3')\n    \n    d2 = reverse_attention(d3, 128, name='ra2')\n    d2 = UpSampling2D((2, 2), name='up2')(d2)\n    d2 = Concatenate(name='concat2')([d2, e1])\n    d2 = conv_block(d2, 64, name='dec2')\n    \n    d1 = reverse_attention(d2, 64, name='ra1')\n    d1 = UpSampling2D((2, 2), name='up1')(d1)\n    d1 = conv_block(d1, 32, name='dec1')\n    \n    outputs = Conv2D(1, 1, activation='sigmoid', name='output')(d1)\n    \n    model = Model(inputs, outputs, name='PraNet')\n    print(f\"\\nTotal parameters: {model.count_params():,}\")\n    print(\"=\"*80 + \"\\n\")\n    return model\n\n# ==============================================================================\n# MODEL 4: RAPUNet (Residual Attention Pyramid U-Net)\n# ==============================================================================\n\ndef residual_block(x, filters, name='res'):\n    \"\"\"Residual block with attention\"\"\"\n    shortcut = x\n    \n    x = conv_block(x, filters, name=f'{name}_1')\n    x = conv_block(x, filters, activation='linear', name=f'{name}_2')\n    \n    if shortcut.shape[-1] != filters:\n        shortcut = Conv2D(filters, 1, padding='same', name=f'{name}_shortcut')(shortcut)\n    \n    x = Add(name=f'{name}_add')([x, shortcut])\n    x = Activation('relu', name=f'{name}_act')(x)\n    \n    return x\n\ndef pyramid_pooling(x, filters, name='ppm'):\n    \"\"\"Pyramid Pooling Module\"\"\"\n    input_shape = x.shape\n    h, w = input_shape[1], input_shape[2]\n    \n    # Pool 1: Global Average Pooling\n    pool1 = GlobalAveragePooling2D(keepdims=True, name=f'{name}_gap')(x)\n    pool1 = Conv2D(filters//4, 1, padding='same', name=f'{name}_conv1')(pool1)\n    pool1 = Lambda(\n        lambda inp: tf.image.resize(inp, [h, w], method='bilinear'),\n        output_shape=(h, w, filters//4),\n        name=f'{name}_up1'\n    )(pool1)\n    \n    # Pool 2: 2x2 pooling\n    pool2 = MaxPooling2D((2, 2), name=f'{name}_pool2')(x)\n    pool2 = Conv2D(filters//4, 1, padding='same', name=f'{name}_conv2')(pool2)\n    pool2 = Lambda(\n        lambda inp: tf.image.resize(inp, [h, w], method='bilinear'),\n        output_shape=(h, w, filters//4),\n        name=f'{name}_up2'\n    )(pool2)\n    \n    # Pool 3: 4x4 pooling  \n    pool3 = MaxPooling2D((4, 4), name=f'{name}_pool3')(x)\n    pool3 = Conv2D(filters//4, 1, padding='same', name=f'{name}_conv3')(pool3)\n    pool3 = Lambda(\n        lambda inp: tf.image.resize(inp, [h, w], method='bilinear'),\n        output_shape=(h, w, filters//4),\n        name=f'{name}_up3'\n    )(pool3)\n    \n    # Pool 4: 8x8 pooling\n    pool4 = MaxPooling2D((8, 8), name=f'{name}_pool4')(x)\n    pool4 = Conv2D(filters//4, 1, padding='same', name=f'{name}_conv4')(pool4)\n    pool4 = Lambda(\n        lambda inp: tf.image.resize(inp, [h, w], method='bilinear'),\n        output_shape=(h, w, filters//4),\n        name=f'{name}_up4'\n    )(pool4)\n    \n    # Concatenate all pyramid levels\n    concat = Concatenate(name=f'{name}_concat')([pool1, pool2, pool3, pool4])\n    output = Conv2D(filters, 1, padding='same', name=f'{name}_fuse')(concat)\n    \n    return output\n\ndef build_rapunet(cfg):\n    \"\"\"RAPUNet: Residual Attention Pyramid U-Net\"\"\"\n    print(\"\\n\" + \"=\"*80)\n    print(\"🔥 BUILDING RAPUNET\")\n    print(\"=\"*80)\n    \n    inputs = Input((cfg.INPUT_SIZE, cfg.INPUT_SIZE, 3), name=\"input\")\n    \n    # Encoder\n    c1 = residual_block(inputs, 64, name='enc1')\n    p1 = MaxPooling2D((2, 2), name='pool1')(c1)\n    \n    c2 = residual_block(p1, 128, name='enc2')\n    p2 = MaxPooling2D((2, 2), name='pool2')(c2)\n    \n    c3 = residual_block(p2, 256, name='enc3')\n    p3 = MaxPooling2D((2, 2), name='pool3')(c3)\n    \n    c4 = residual_block(p3, 512, name='enc4')\n    p4 = MaxPooling2D((2, 2), name='pool4')(c4)\n    \n    # Bottleneck with Pyramid Pooling\n    c5 = residual_block(p4, 1024, name='bottleneck')\n    c5 = pyramid_pooling(c5, 1024, name='ppm')\n    \n    # Decoder with attention gates\n    u6 = Conv2DTranspose(512, 2, strides=2, padding='same', name='up6')(c5)\n    a6 = attention_gate(u6, c4, 512, name='att6')\n    u6 = Concatenate(name='concat6')([u6, a6])\n    c6 = residual_block(u6, 512, name='dec6')\n    \n    u7 = Conv2DTranspose(256, 2, strides=2, padding='same', name='up7')(c6)\n    a7 = attention_gate(u7, c3, 256, name='att7')\n    u7 = Concatenate(name='concat7')([u7, a7])\n    c7 = residual_block(u7, 256, name='dec7')\n    \n    u8 = Conv2DTranspose(128, 2, strides=2, padding='same', name='up8')(c7)\n    a8 = attention_gate(u8, c2, 128, name='att8')\n    u8 = Concatenate(name='concat8')([u8, a8])\n    c8 = residual_block(u8, 128, name='dec8')\n    \n    u9 = Conv2DTranspose(64, 2, strides=2, padding='same', name='up9')(c8)\n    a9 = attention_gate(u9, c1, 64, name='att9')\n    u9 = Concatenate(name='concat9')([u9, a9])\n    c9 = residual_block(u9, 64, name='dec9')\n    \n    outputs = Conv2D(1, 1, activation='sigmoid', name='output')(c9)\n    \n    model = Model(inputs, outputs, name='RAPUNet')\n    print(f\"\\nTotal parameters: {model.count_params():,}\")\n    print(\"=\"*80 + \"\\n\")\n    return model\n\n# ==============================================================================\n# MODEL 5: Swin-UNet (Swin Transformer U-Net)\n# ==============================================================================\n\nclass PatchEmbedding(Layer):\n    \"\"\"Patch embedding layer\"\"\"\n    def __init__(self, patch_size=4, embed_dim=96, **kwargs):\n        super().__init__(**kwargs)\n        self.patch_size = patch_size\n        self.embed_dim = embed_dim\n        \n    def build(self, input_shape):\n        self.projection = Conv2D(\n            self.embed_dim, \n            kernel_size=self.patch_size, \n            strides=self.patch_size,\n            padding='valid',\n            name='projection'\n        )\n        super().build(input_shape)\n        \n    def call(self, x):\n        x = self.projection(x)\n        return x\n\nclass SwinTransformerBlock(Layer):\n    \"\"\"Simplified Swin Transformer Block\"\"\"\n    def __init__(self, dim, num_heads, window_size=7, **kwargs):\n        super().__init__(**kwargs)\n        self.dim = dim\n        self.num_heads = num_heads\n        self.window_size = window_size\n        \n    def build(self, input_shape):\n        self.norm1 = LayerNormalization(epsilon=1e-5, name='norm1')\n        self.attn = MultiHeadAttention(\n            num_heads=self.num_heads,\n            key_dim=self.dim // self.num_heads,\n            name='attn'\n        )\n        self.norm2 = LayerNormalization(epsilon=1e-5, name='norm2')\n        self.mlp = tf.keras.Sequential([\n            Dense(self.dim * 4, activation='gelu', name='fc1'),\n            Dense(self.dim, name='fc2')\n        ], name='mlp')\n        super().build(input_shape)\n        \n    def call(self, x):\n        h, w = x.shape[1], x.shape[2]\n        shortcut = x\n        \n        x = self.norm1(x)\n        x_2d = Reshape((h * w, self.dim))(x)\n        attn_out = self.attn(x_2d, x_2d)\n        attn_out = Reshape((h, w, self.dim))(attn_out)\n        x = Add()([shortcut, attn_out])\n        \n        shortcut = x\n        x = self.norm2(x)\n        x_2d = Reshape((h * w, self.dim))(x)\n        mlp_out = self.mlp(x_2d)\n        mlp_out = Reshape((h, w, self.dim))(mlp_out)\n        x = Add()([shortcut, mlp_out])\n        \n        return x\n\ndef build_swinunet(cfg):\n    \"\"\"Swin-UNet: Swin Transformer for segmentation\"\"\"\n    print(\"\\n\" + \"=\"*80)\n    print(\"🔥 BUILDING SWIN-UNET\")\n    print(\"=\"*80)\n    \n    inputs = Input((cfg.INPUT_SIZE, cfg.INPUT_SIZE, 3), name=\"input\")\n    \n    # Patch embedding\n    x = PatchEmbedding(patch_size=4, embed_dim=96, name='patch_embed')(inputs)\n    \n    # Encoder\n    e1 = SwinTransformerBlock(96, 3, name='swin1')(x)\n    e1 = SwinTransformerBlock(96, 3, name='swin1_2')(e1)\n    p1 = Conv2D(192, 2, strides=2, padding='same', name='down1')(e1)\n    \n    e2 = SwinTransformerBlock(192, 6, name='swin2')(p1)\n    e2 = SwinTransformerBlock(192, 6, name='swin2_2')(e2)\n    p2 = Conv2D(384, 2, strides=2, padding='same', name='down2')(e2)\n    \n    e3 = SwinTransformerBlock(384, 12, name='swin3')(p2)\n    e3 = SwinTransformerBlock(384, 12, name='swin3_2')(e3)\n    p3 = Conv2D(768, 2, strides=2, padding='same', name='down3')(e3)\n    \n    # Bottleneck\n    bottleneck = SwinTransformerBlock(768, 24, name='swin_bottleneck')(p3)\n    \n    # Decoder\n    u1 = UpSampling2D((2, 2), name='up1')(bottleneck)\n    u1 = Conv2D(384, 1, padding='same', name='reduce1')(u1)\n    u1 = Concatenate(name='concat1')([u1, e3])\n    d1 = conv_block(u1, 384, name='dec1')\n    \n    u2 = UpSampling2D((2, 2), name='up2')(d1)\n    u2 = Conv2D(192, 1, padding='same', name='reduce2')(u2)\n    u2 = Concatenate(name='concat2')([u2, e2])\n    d2 = conv_block(u2, 192, name='dec2')\n    \n    u3 = UpSampling2D((2, 2), name='up3')(d2)\n    u3 = Conv2D(96, 1, padding='same', name='reduce3')(u3)\n    u3 = Concatenate(name='concat3')([u3, e1])\n    d3 = conv_block(u3, 96, name='dec3')\n    \n    # Final upsampling to original resolution\n    final = UpSampling2D((4, 4), name='final_up')(d3)\n    final = conv_block(final, 64, name='final_conv1')\n    final = conv_block(final, 32, name='final_conv2')\n    \n    outputs = Conv2D(1, 1, activation='sigmoid', name='output')(final)\n    \n    model = Model(inputs, outputs, name='SwinUNet')\n    print(f\"\\nTotal parameters: {model.count_params():,}\")\n    print(\"=\"*80 + \"\\n\")\n    return model\n\n# ==============================================================================\n# MODEL 6: TransUNet\n# ==============================================================================\n\ndef transformer_encoder(x, num_heads, mlp_dim, name='transformer'):\n    \"\"\"Transformer encoder block\"\"\"\n    # Multi-head attention\n    shortcut = x\n    x = LayerNormalization(epsilon=1e-6, name=f'{name}_norm1')(x)\n    x = MultiHeadAttention(\n        num_heads=num_heads, \n        key_dim=x.shape[-1] // num_heads,\n        name=f'{name}_mha'\n    )(x, x)\n    x = Add(name=f'{name}_add1')([x, shortcut])\n    \n    # MLP\n    shortcut = x\n    x = LayerNormalization(epsilon=1e-6, name=f'{name}_norm2')(x)\n    x = Dense(mlp_dim, activation='gelu', name=f'{name}_mlp1')(x)\n    x = Dense(shortcut.shape[-1], name=f'{name}_mlp2')(x)\n    x = Add(name=f'{name}_add2')([x, shortcut])\n    \n    return x\n\ndef build_transunet(cfg):\n    \"\"\"TransUNet: Transformer + U-Net\"\"\"\n    print(\"\\n\" + \"=\"*80)\n    print(\"🔥 BUILDING TRANSUNET\")\n    print(\"=\"*80)\n    \n    inputs = Input((cfg.INPUT_SIZE, cfg.INPUT_SIZE, 3), name=\"input\")\n    \n    # CNN Encoder\n    c1 = conv_block(inputs, 64, name='enc1_1')\n    c1 = conv_block(c1, 64, name='enc1_2')\n    p1 = MaxPooling2D((2, 2), name='pool1')(c1)\n    \n    c2 = conv_block(p1, 128, name='enc2_1')\n    c2 = conv_block(c2, 128, name='enc2_2')\n    p2 = MaxPooling2D((2, 2), name='pool2')(c2)\n    \n    c3 = conv_block(p2, 256, name='enc3_1')\n    c3 = conv_block(c3, 256, name='enc3_2')\n    p3 = MaxPooling2D((2, 2), name='pool3')(c3)\n    \n    # Transformer bottleneck\n    h, w, c = p3.shape[1], p3.shape[2], p3.shape[3]\n    x = Reshape((h * w, c), name='reshape')(p3)\n    \n    for i in range(cfg.TRANSFORMER_LAYERS):\n        x = transformer_encoder(x, cfg.NUM_HEADS, 512, name=f'transformer_{i}')\n    \n    x = Reshape((h, w, c), name='reshape_back')(x)\n    \n    # CNN Decoder\n    u1 = Conv2DTranspose(256, 2, strides=2, padding='same', name='up1')(x)\n    u1 = Concatenate(name='concat1')([u1, c3])\n    d1 = conv_block(u1, 256, name='dec1_1')\n    d1 = conv_block(d1, 256, name='dec1_2')\n    \n    u2 = Conv2DTranspose(128, 2, strides=2, padding='same', name='up2')(d1)\n    u2 = Concatenate(name='concat2')([u2, c2])\n    d2 = conv_block(u2, 128, name='dec2_1')\n    d2 = conv_block(d2, 128, name='dec2_2')\n    \n    u3 = Conv2DTranspose(64, 2, strides=2, padding='same', name='up3')(d2)\n    u3 = Concatenate(name='concat3')([u3, c1])\n    d3 = conv_block(u3, 64, name='dec3_1')\n    d3 = conv_block(d3, 64, name='dec3_2')\n    \n    outputs = Conv2D(1, 1, activation='sigmoid', name='output')(d3)\n    \n    model = Model(inputs, outputs, name='TransUNet')\n    print(f\"\\nTotal parameters: {model.count_params():,}\")\n    print(\"=\"*80 + \"\\n\")\n    return model\n\n# ==============================================================================\n# MODEL 7: UMamba (State Space Model)\n# ==============================================================================\n\nclass SSMLayer(Layer):\n    \"\"\"Simplified State Space Model Layer\"\"\"\n    def __init__(self, d_model, d_state=16, **kwargs):\n        super().__init__(**kwargs)\n        self.d_model = d_model\n        self.d_state = d_state\n        \n    def build(self, input_shape):\n        self.A = self.add_weight(\n            name='A',\n            shape=(self.d_model, self.d_state),\n            initializer='glorot_uniform',\n            trainable=True\n        )\n        self.B = Dense(self.d_state, name='B')\n        self.C = Dense(self.d_model, name='C')\n        self.D = Dense(self.d_model, name='D')\n        \n        self.norm = LayerNormalization(epsilon=1e-5, name='norm')\n        super().build(input_shape)\n        \n    def call(self, x):\n        B, H, W, C = x.shape\n        shortcut = x\n        \n        # Flatten spatial dimensions\n        x_flat = Reshape((H * W, C))(x)\n        \n        # Apply SSM\n        x_ssm = self.C(x_flat) + self.D(x_flat)\n        x_ssm = self.norm(x_ssm)\n        \n        # Reshape back\n        x_out = Reshape((H, W, C))(x_ssm)\n        \n        return Add()([shortcut, x_out])\n\ndef build_umamba(cfg):\n    \"\"\"UMamba: Mamba-based U-Net\"\"\"\n    print(\"\\n\" + \"=\"*80)\n    print(\"🔥 BUILDING UMAMBA\")\n    print(\"=\"*80)\n    \n    inputs = Input((cfg.INPUT_SIZE, cfg.INPUT_SIZE, 3), name=\"input\")\n    \n    # Encoder with SSM blocks\n    c1 = conv_block(inputs, 64, name='enc1')\n    c1 = SSMLayer(64, name='ssm1')(c1)\n    p1 = MaxPooling2D((2, 2), name='pool1')(c1)\n    \n    c2 = conv_block(p1, 128, name='enc2')\n    c2 = SSMLayer(128, name='ssm2')(c2)\n    p2 = MaxPooling2D((2, 2), name='pool2')(c2)\n    \n    c3 = conv_block(p2, 256, name='enc3')\n    c3 = SSMLayer(256, name='ssm3')(c3)\n    p3 = MaxPooling2D((2, 2), name='pool3')(c3)\n    \n    c4 = conv_block(p3, 512, name='enc4')\n    c4 = SSMLayer(512, name='ssm4')(c4)\n    p4 = MaxPooling2D((2, 2), name='pool4')(c4)\n    \n    # Bottleneck\n    c5 = conv_block(p4, 1024, name='bottleneck_1')\n    c5 = SSMLayer(1024, name='ssm_bottleneck')(c5)\n    c5 = conv_block(c5, 1024, name='bottleneck_2')\n    \n    # Decoder\n    u6 = Conv2DTranspose(512, 2, strides=2, padding='same', name='up6')(c5)\n    u6 = Concatenate(name='concat6')([u6, c4])\n    c6 = conv_block(u6, 512, name='dec6')\n    c6 = SSMLayer(512, name='ssm_dec6')(c6)\n    \n    u7 = Conv2DTranspose(256, 2, strides=2, padding='same', name='up7')(c6)\n    u7 = Concatenate(name='concat7')([u7, c3])\n    c7 = conv_block(u7, 256, name='dec7')\n    c7 = SSMLayer(256, name='ssm_dec7')(c7)\n    \n    u8 = Conv2DTranspose(128, 2, strides=2, padding='same', name='up8')(c7)\n    u8 = Concatenate(name='concat8')([u8, c2])\n    c8 = conv_block(u8, 128, name='dec8')\n    c8 = SSMLayer(128, name='ssm_dec8')(c8)\n    \n    u9 = Conv2DTranspose(64, 2, strides=2, padding='same', name='up9')(c8)\n    u9 = Concatenate(name='concat9')([u9, c1])\n    c9 = conv_block(u9, 64, name='dec9')\n    c9 = SSMLayer(64, name='ssm_dec9')(c9)\n    \n    outputs = Conv2D(1, 1, activation='sigmoid', name='output')(c9)\n    \n    model = Model(inputs, outputs, name='UMamba')\n    print(f\"\\nTotal parameters: {model.count_params():,}\")\n    print(\"=\"*80 + \"\\n\")\n    return model\n\n# ==============================================================================\n# MODEL 8: DUCK-Net\n# ==============================================================================\n\ndef duck_block(x, filters, dilation_rates=[1, 3, 5], name='duck'):\n    \"\"\"DUCK Block: Dilated and Upsampled Concatenation Kernel block\"\"\"\n    # Dimension reduction\n    x_reduced = Conv2D(filters // 2, 1, padding='same', \n                      kernel_initializer='he_normal',\n                      name=f'{name}_reduce')(x)\n    x_reduced = BatchNormalization(name=f'{name}_reduce_bn')(x_reduced)\n    x_reduced = Activation('relu', name=f'{name}_reduce_act')(x_reduced)\n    \n    # Parallel dilated convolutions\n    branches = []\n    for i, rate in enumerate(dilation_rates):\n        branch = Conv2D(filters // len(dilation_rates), 3, \n                       padding='same', \n                       dilation_rate=rate,\n                       kernel_initializer='he_normal',\n                       name=f'{name}_dil{rate}')(x_reduced)\n        branch = BatchNormalization(name=f'{name}_dil{rate}_bn')(branch)\n        branch = Activation('relu', name=f'{name}_dil{rate}_act')(branch)\n        branches.append(branch)\n    \n    # Concatenate all branches\n    if len(branches) > 1:\n        x_concat = Concatenate(name=f'{name}_concat')(branches)\n    else:\n        x_concat = branches[0]\n    \n    # Fusion with 1x1 convolution\n    x_fused = Conv2D(filters, 1, padding='same',\n                    kernel_initializer='he_normal',\n                    name=f'{name}_fuse')(x_concat)\n    x_fused = BatchNormalization(name=f'{name}_fuse_bn')(x_fused)\n    \n    # Residual connection\n    if x.shape[-1] != filters:\n        shortcut = Conv2D(filters, 1, padding='same',\n                         kernel_initializer='he_normal',\n                         name=f'{name}_shortcut')(x)\n        shortcut = BatchNormalization(name=f'{name}_shortcut_bn')(shortcut)\n    else:\n        shortcut = x\n    \n    x_out = Add(name=f'{name}_add')([x_fused, shortcut])\n    x_out = Activation('relu', name=f'{name}_out_act')(x_out)\n    \n    return x_out\n\ndef separated_block(x, filters, name='sep'):\n    \"\"\"Separated convolution block\"\"\"\n    x = DepthwiseConv2D(3, padding='same',\n                       depthwise_initializer='he_normal',\n                       name=f'{name}_dw')(x)\n    x = BatchNormalization(name=f'{name}_dw_bn')(x)\n    x = Activation('relu', name=f'{name}_dw_act')(x)\n    \n    x = Conv2D(filters, 1, padding='same',\n              kernel_initializer='he_normal',\n              name=f'{name}_pw')(x)\n    x = BatchNormalization(name=f'{name}_pw_bn')(x)\n    x = Activation('relu', name=f'{name}_pw_act')(x)\n    \n    return x\n\ndef build_ducknet(cfg):\n    \"\"\"DUCK-Net: Dilated and Upsampled Concatenation Kernel Network\"\"\"\n    print(\"\\n\" + \"=\"*80)\n    print(\"🔥 BUILDING DUCK-NET\")\n    print(\"=\"*80)\n    \n    inputs = Input((cfg.INPUT_SIZE, cfg.INPUT_SIZE, 3), name=\"input\")\n    \n    filters = cfg.DUCKNET_FILTERS\n    dilation_rates = cfg.DUCKNET_DILATION_RATES\n    \n    print(f\"   Filter configuration: {filters}\")\n    print(f\"   Dilation rates: {dilation_rates}\")\n    \n    # Encoder\n    e1 = Conv2D(filters[0], 3, padding='same', \n               kernel_initializer='he_normal', name='enc1_conv')(inputs)\n    e1 = BatchNormalization(name='enc1_bn')(e1)\n    e1 = Activation('relu', name='enc1_act')(e1)\n    e1 = duck_block(e1, filters[0], dilation_rates, name='duck1')\n    p1 = MaxPooling2D((2, 2), name='pool1')(e1)\n    \n    e2 = duck_block(p1, filters[1], dilation_rates, name='duck2_1')\n    e2 = duck_block(e2, filters[1], dilation_rates, name='duck2_2')\n    p2 = MaxPooling2D((2, 2), name='pool2')(e2)\n    \n    e3 = duck_block(p2, filters[2], dilation_rates, name='duck3_1')\n    e3 = duck_block(e3, filters[2], dilation_rates, name='duck3_2')\n    p3 = MaxPooling2D((2, 2), name='pool3')(e3)\n    \n    e4 = duck_block(p3, filters[3], dilation_rates, name='duck4_1')\n    e4 = duck_block(e4, filters[3], dilation_rates, name='duck4_2')\n    p4 = MaxPooling2D((2, 2), name='pool4')(e4)\n    \n    # Bottleneck\n    e5 = duck_block(p4, filters[4], dilation_rates, name='duck5_1')\n    e5 = duck_block(e5, filters[4], dilation_rates, name='duck5_2')\n    \n    # Decoder\n    u4 = Conv2DTranspose(filters[3], 2, strides=2, padding='same', name='up4')(e5)\n    u4 = Concatenate(name='concat4')([u4, e4])\n    d4 = duck_block(u4, filters[3], dilation_rates, name='duck_dec4_1')\n    d4 = separated_block(d4, filters[3], name='sep4')\n    \n    u3 = Conv2DTranspose(filters[2], 2, strides=2, padding='same', name='up3')(d4)\n    u3 = Concatenate(name='concat3')([u3, e3])\n    d3 = duck_block(u3, filters[2], dilation_rates, name='duck_dec3_1')\n    d3 = separated_block(d3, filters[2], name='sep3')\n    \n    u2 = Conv2DTranspose(filters[1], 2, strides=2, padding='same', name='up2')(d3)\n    u2 = Concatenate(name='concat2')([u2, e2])\n    d2 = duck_block(u2, filters[1], dilation_rates, name='duck_dec2_1')\n    d2 = separated_block(d2, filters[1], name='sep2')\n    \n    u1 = Conv2DTranspose(filters[0], 2, strides=2, padding='same', name='up1')(d2)\n    u1 = Concatenate(name='concat1')([u1, e1])\n    d1 = duck_block(u1, filters[0], dilation_rates, name='duck_dec1_1')\n    d1 = separated_block(d1, filters[0], name='sep1')\n    \n    outputs = Conv2D(1, 1, activation='sigmoid', name='output')(d1)\n    \n    model = Model(inputs, outputs, name='DUCKNET')\n    print(f\"\\nTotal parameters: {model.count_params():,}\")\n    print(\"=\"*80 + \"\\n\")\n    return model\n\n# ==============================================================================\n# MODEL 9: MedSegNet-SSF\n# ==============================================================================\n\nclass SpectralSelectiveTokenMixer(Layer):\n    \"\"\"Spectral-Selective Token Mixer\"\"\"\n    def __init__(self, channels, num_frequencies=32, ssm_state_dim=16, \n                 use_spectral=True, use_ssm=True, dropout=0.0, **kwargs):\n        super().__init__(**kwargs)\n        self.channels = channels\n        self.num_frequencies = num_frequencies\n        self.ssm_state_dim = ssm_state_dim\n        self.use_spectral = use_spectral\n        self.use_ssm = use_ssm\n        self.dropout_rate = dropout\n        \n    def build(self, input_shape):\n        input_h, input_w = input_shape[1], input_shape[2]\n        self.actual_frequencies = min(self.num_frequencies, input_h, input_w) if input_h else self.num_frequencies\n        \n        if self.use_spectral:\n            self.freq_weights_real = self.add_weight(\n                name='freq_weights_real',\n                shape=(self.actual_frequencies, self.actual_frequencies, self.channels),\n                initializer=self._get_initializer(),\n                trainable=True\n            )\n            self.freq_weights_imag = self.add_weight(\n                name='freq_weights_imag',\n                shape=(self.actual_frequencies, self.actual_frequencies, self.channels),\n                initializer='zeros',\n                trainable=True\n            )\n            self.spectral_norm = LayerNormalization(epsilon=1e-6, name='spectral_norm')\n        \n        if self.use_ssm:\n            self.ssm_delta = Dense(self.channels, name='ssm_delta')\n            self.ssm_A = self.add_weight(\n                name='ssm_A',\n                shape=(self.channels, self.ssm_state_dim),\n                initializer='glorot_uniform',\n                trainable=True\n            )\n            self.ssm_B = Dense(self.ssm_state_dim, name='ssm_B')\n            self.ssm_C = Dense(self.channels, name='ssm_C')\n            self.selection_gate = Dense(self.channels, activation='sigmoid', name='selection')\n            self.ssm_norm = LayerNormalization(epsilon=1e-6, name='ssm_norm')\n        \n        if self.use_spectral and self.use_ssm:\n            self.fusion = Dense(self.channels, name='fusion')\n            self.fusion_norm = LayerNormalization(epsilon=1e-6, name='fusion_norm')\n        \n        self.norm = LayerNormalization(epsilon=1e-6, name='norm')\n        super().build(input_shape)\n    \n    def _get_initializer(self):\n        def init_fn(shape, dtype=None):\n            H, W, C = shape\n            freq_h = np.fft.fftfreq(H)[:, np.newaxis]\n            freq_w = np.fft.fftfreq(W)[np.newaxis, :]\n            freq_magnitude = np.sqrt(freq_h**2 + freq_w**2)\n            gaussian = np.exp(-((freq_magnitude - 0.25)**2) / (2 * 0.15**2))\n            gaussian = np.repeat(gaussian[:, :, np.newaxis], C, axis=2)\n            return gaussian.astype(np.float32) * 0.5\n        return init_fn\n    \n    def spectral_path(self, x):\n        H, W = tf.shape(x)[1], tf.shape(x)[2]\n        freq_size = tf.minimum(tf.minimum(H, W), self.actual_frequencies)\n        x_complex = tf.cast(x, tf.complex64)\n        x_freq = tf.signal.fft2d(x_complex)\n        x_freq_real = tf.math.real(x_freq)\n        x_freq_imag = tf.math.imag(x_freq)\n        x_freq_real_resized = tf.image.resize(x_freq_real, [freq_size, freq_size], method='bilinear')\n        x_freq_imag_resized = tf.image.resize(x_freq_imag, [freq_size, freq_size], method='bilinear')\n        x_freq_resized = tf.complex(x_freq_real_resized, x_freq_imag_resized)\n        freq_filter = tf.cast(self.freq_weights_real[:freq_size, :freq_size, :], tf.complex64)\n        x_freq_filtered = x_freq_resized * freq_filter\n        x_freq_filt_real = tf.math.real(x_freq_filtered)\n        x_freq_filt_imag = tf.math.imag(x_freq_filtered)\n        x_freq_back_real = tf.image.resize(x_freq_filt_real, [H, W], method='bilinear')\n        x_freq_back_imag = tf.image.resize(x_freq_filt_imag, [H, W], method='bilinear')\n        x_freq_back = tf.complex(x_freq_back_real, x_freq_back_imag)\n        x_spatial = tf.signal.ifft2d(x_freq_back)\n        return self.spectral_norm(tf.math.real(x_spatial))\n    \n    def ssm_path(self, x):\n        B, H, W = tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2]\n        x_flat = tf.reshape(x, [B, H * W, self.channels])\n        selection = self.selection_gate(x_flat)\n        x_selected = x_flat * selection\n        x_ssm = self.ssm_C(x_selected)\n        return self.ssm_norm(tf.reshape(x_ssm, [B, H, W, self.channels]))\n    \n    def call(self, x, training=None):\n        outputs = []\n        if self.use_spectral:\n            outputs.append(self.spectral_path(x))\n        if self.use_ssm:\n            outputs.append(self.ssm_path(x))\n        \n        if len(outputs) == 2:\n            fused = self.fusion_norm(self.fusion(tf.concat(outputs, axis=-1)))\n        elif len(outputs) == 1:\n            fused = outputs[0]\n        else:\n            fused = x\n        \n        fused = self.norm(fused)\n        if training and self.dropout_rate > 0:\n            fused = tf.nn.dropout(fused, rate=self.dropout_rate)\n        return x + fused\n\ndef MRF_SE_BLOCK(x, filters, activation='elu', dropout=0.0, expand_ratio=6, \n                 regularizer=0.0, kernels=[3, 5, 7], se_reduction=16, name='mrf_se'):\n    \"\"\"Multi-Receptive Field with Squeeze-Excitation Block\"\"\"\n    F_expanded = filters * expand_ratio\n    \n    conv = Conv2D(F_expanded, (1, 1), padding='same', \n                  kernel_initializer='he_uniform',\n                  kernel_regularizer=l2(regularizer) if regularizer > 0 else None,\n                  name=name+'_expand')(x) if expand_ratio > 1 else x\n    conv = Activation(activation, name=name+'_expand_act')(\n        BatchNormalization(name=name+'_expand_bn')(conv))\n    \n    features = []\n    for k in kernels:\n        dw = DepthwiseConv2D((k, k), padding='same',\n                            depthwise_initializer='he_uniform',\n                            depthwise_regularizer=l2(regularizer) if regularizer > 0 else None,\n                            name=f\"{name}_dw{k}x{k}\")(conv)\n        features.append(Activation(activation, name=f\"{name}_dw{k}x{k}_act\")(\n            BatchNormalization(name=f\"{name}_dw{k}x{k}_bn\")(dw)))\n    \n    combined = Concatenate(name=name+'_concat')(features) if len(features) > 1 else features[0]\n    if len(features) > 1:\n        combined = Activation(activation, name=name+'_fuse_act')(\n            BatchNormalization(name=name+'_fuse_bn')(\n                Conv2D(F_expanded, (1, 1), padding='same',\n                      kernel_initializer='he_uniform',\n                      kernel_regularizer=l2(regularizer) if regularizer > 0 else None,\n                      name=name+'_fuse')(combined)))\n    \n    gap = Reshape((1, 1, F_expanded), name=name+'_reshape')(\n        GlobalAveragePooling2D(name=name+'_gap')(combined))\n    se = Conv2D(F_expanded, (1, 1), activation='sigmoid',\n               kernel_initializer='he_uniform',\n               name=name+'_se_expand')(\n        Conv2D(max(F_expanded//se_reduction, 8), (1, 1),\n              activation=activation,\n              kernel_initializer='he_uniform',\n              name=name+'_se_reduce')(gap))\n    \n    projected = Conv2D(filters, (1, 1), padding='same',\n                      kernel_initializer='he_uniform',\n                      kernel_regularizer=l2(regularizer) if regularizer > 0 else None,\n                      name=name+'_project')(Multiply(name=name+'_se_mult')([combined, se]))\n    projected = BatchNormalization(name=name+'_project_bn')(projected)\n    \n    if dropout > 0:\n        projected = Dropout(dropout, name=name+'_dropout')(projected)\n    \n    out = Add(name=name+'_add')([projected, x])\n    return out\n\ndef boundary_detection_module(features, filters, name='boundary'):\n    \"\"\"Boundary detection module\"\"\"\n    boundary_conv = Conv2D(filters // 2, (3, 3), padding='same',\n                          activation='relu', name=name + '_conv')(features)\n    boundary_map = Conv2D(1, (1, 1), padding='same',\n                         activation='sigmoid', name=name + '_map')(boundary_conv)\n    return Multiply(name=name + '_mult')([features, boundary_map]), boundary_map\n\ndef BFP_decoder_stage(decoder_input, skip_features, filters, stage_name='bfp'):\n    \"\"\"Boundary-Focused Progressively decoder stage\"\"\"\n    region = Concatenate(name=stage_name+'_concat')([\n        UpSampling2D((2, 2), name=stage_name+'_up')(decoder_input),\n        skip_features\n    ])\n    \n    region = Activation('relu')(BatchNormalization()(\n        Conv2D(filters, (3, 3), padding='same', name=stage_name+'_region_conv1')(region)))\n    region = Activation('relu')(BatchNormalization()(\n        Conv2D(filters, (3, 3), padding='same', name=stage_name+'_region_conv2')(region)))\n    \n    boundary_features, boundary_map = boundary_detection_module(region, filters, stage_name+'_boundary')\n    \n    boundary_refined = Activation('relu')(BatchNormalization()(\n        Conv2D(filters, (3, 3), padding='same', name=stage_name+'_boundary_refine')(boundary_features)))\n    \n    output = Activation('relu')(BatchNormalization()(\n        Conv2D(filters, (1, 1), padding='same', name=stage_name+'_fusion')(\n            region * (1 - boundary_map) + boundary_refined * boundary_map)))\n    \n    return output, boundary_map\n\ndef build_medsegnet_ssf(cfg):\n    \"\"\"Build MedSegNet-SSF model\"\"\"\n    print(\"\\n\" + \"=\"*80)\n    print(\"🔥 BUILDING MED-SEGNET-SSF\")\n    print(\"=\"*80)\n    \n    inp = Input((cfg.INPUT_SIZE, cfg.INPUT_SIZE, 3), name=\"input\")\n    \n    x = Conv2D(16, (3, 3), padding='same',\n              kernel_initializer='he_uniform', name='stem_conv')(inp)\n    x = BatchNormalization(name='stem_bn')(x)\n    x = Activation('elu', name='stem_act')(x)\n    \n    encoder_outputs = []\n    filters = [cfg.F1, cfg.F2, cfg.F3, cfg.F4, cfg.F5]\n    \n    for i, f in enumerate(filters):\n        x = Conv2D(f, (3, 3), strides=2, padding='same',\n                  kernel_initializer='he_uniform')(x)\n        x = BatchNormalization()(x)\n        x = Activation('elu')(x)\n        \n        if cfg.USE_MRF_SE:\n            x = MRF_SE_BLOCK(x, f, activation='elu',\n                           dropout=cfg.DROPOUT,\n                           expand_ratio=cfg.EXPAND_RATIO,\n                           regularizer=cfg.L2_REG,\n                           kernels=cfg.MRF_KERNELS,\n                           se_reduction=cfg.SE_REDUCTION,\n                           name=f'mrfse_stage{i+1}')\n        \n        if cfg.USE_SSTM:\n            x = SpectralSelectiveTokenMixer(\n                channels=f,\n                num_frequencies=cfg.SSTM_NUM_FREQUENCIES,\n                ssm_state_dim=cfg.SSTM_SSM_STATE_DIM,\n                use_spectral=cfg.SSTM_USE_SPECTRAL[i],\n                use_ssm=cfg.SSTM_USE_SSM[i],\n                dropout=cfg.SSTM_DROPOUT,\n                name=f'sstm_stage{i+1}'\n            )(x)\n        \n        encoder_outputs.append(x)\n        print(f\"  Encoder Stage {i+1}: filters={f}\")\n    \n    skip_connections = encoder_outputs[::-1]\n    decoder = skip_connections[0]\n    decoder_filters = filters[::-1][1:] + [16]\n    \n    for i, (skip, f) in enumerate(zip(skip_connections[1:], decoder_filters)):\n        if cfg.USE_BFP:\n            decoder, _ = BFP_decoder_stage(decoder, skip, f, stage_name=f'bfp_stage{i+1}')\n            print(f\"  Decoder Stage {i+1}: filters={f}\")\n    \n    decoder = UpSampling2D((2, 2))(decoder)\n    decoder = Conv2D(32, (3, 3), padding='same', activation='relu')(decoder)\n    decoder = Conv2D(16, (3, 3), padding='same', activation='relu')(decoder)\n    out = Conv2D(1, (1, 1), padding='same', activation='sigmoid', name='output')(decoder)\n    \n    model = Model(inputs=inp, outputs=out, name=\"MedSegNet_SSF\")\n    \n    print(f\"\\nTotal parameters: {model.count_params():,}\")\n    print(\"=\"*80 + \"\\n\")\n    \n    return model\n\n# ==============================================================================\n# MODEL 10: nnU-Net (2020 Baseline)\n# ==============================================================================\n\ndef instance_norm_block(x, filters, kernel_size=3, name='inorm'):\n    \"\"\"nnU-Net style block with Instance Normalization and Leaky ReLU\"\"\"\n    x = Conv2D(filters, kernel_size, padding='same', \n               kernel_initializer='he_normal', name=f'{name}_conv')(x)\n    x = LayerNormalization(epsilon=1e-5, name=f'{name}_norm')(x)\n    x = Activation('relu', name=f'{name}_act')(x)\n    return x\n\ndef nnunet_residual_block(x, filters, name='nnunet_res'):\n    \"\"\"nnU-Net residual block\"\"\"\n    shortcut = x\n    \n    x = instance_norm_block(x, filters, kernel_size=3, name=f'{name}_1')\n    x = instance_norm_block(x, filters, kernel_size=3, name=f'{name}_2')\n    \n    if shortcut.shape[-1] != filters:\n        shortcut = Conv2D(filters, 1, padding='same', \n                         kernel_initializer='he_normal',\n                         name=f'{name}_shortcut')(shortcut)\n        shortcut = LayerNormalization(epsilon=1e-5, name=f'{name}_shortcut_norm')(shortcut)\n    \n    x = Add(name=f'{name}_add')([x, shortcut])\n    return x\n\ndef build_nnunet(cfg):\n    \"\"\"nnU-Net: Self-configuring method for medical image segmentation\"\"\"\n    print(\"\\n\" + \"=\"*80)\n    print(\"🔥 BUILDING nnU-NET\")\n    print(\"=\"*80)\n    \n    inputs = Input((cfg.INPUT_SIZE, cfg.INPUT_SIZE, 3), name=\"input\")\n    \n    x = Conv2D(32, 3, padding='same', kernel_initializer='he_normal', name='stem')(inputs)\n    x = LayerNormalization(epsilon=1e-5, name='stem_norm')(x)\n    x = Activation('relu', name='stem_act')(x)\n    \n    filters = [32, 64, 128, 256, 512]\n    encoder_outputs = []\n    \n    for i, f in enumerate(filters):\n        x = nnunet_residual_block(x, f, name=f'enc_res{i+1}')\n        encoder_outputs.append(x)\n        \n        if i < len(filters) - 1:\n            x = Conv2D(f*2, 3, strides=2, padding='same',\n                      kernel_initializer='he_normal',\n                      name=f'down{i+1}')(x)\n            x = LayerNormalization(epsilon=1e-5, name=f'down{i+1}_norm')(x)\n            x = Activation('relu', name=f'down{i+1}_act')(x)\n        \n        print(f\"  Encoder Stage {i+1}: filters={f}\")\n    \n    for i in range(len(filters)-2, -1, -1):\n        f = filters[i]\n        \n        x = Conv2DTranspose(f, 2, strides=2, padding='same',\n                           kernel_initializer='he_normal',\n                           name=f'up{i+1}')(x)\n        x = LayerNormalization(epsilon=1e-5, name=f'up{i+1}_norm')(x)\n        x = Activation('relu', name=f'up{i+1}_act')(x)\n        \n        x = Concatenate(name=f'concat{i+1}')([x, encoder_outputs[i]])\n        x = nnunet_residual_block(x, f, name=f'dec_res{i+1}')\n        \n        print(f\"  Decoder Stage {i+1}: filters={f}\")\n    \n    outputs = Conv2D(1, 1, activation='sigmoid', name='output')(x)\n    \n    model = Model(inputs, outputs, name='nnUNet')\n    print(f\"\\nTotal parameters: {model.count_params():,}\")\n    print(\"=\"*80 + \"\\n\")\n    return model\n\n# ==============================================================================\n# MODEL 11: VM-UNet (Vision Mamba - 2024)\n# ==============================================================================\n\nclass VSSBlock(Layer):\n    \"\"\"Visual State Space Block for VM-UNet\"\"\"\n    def __init__(self, dim, d_state=16, expand_ratio=2, **kwargs):\n        super().__init__(**kwargs)\n        self.dim = dim\n        self.d_state = d_state\n        self.expand_ratio = expand_ratio\n        self.hidden_dim = int(dim * expand_ratio)\n        \n    def build(self, input_shape):\n        self.proj_in = Dense(self.hidden_dim * 2, name='proj_in')\n        self.proj_out = Dense(self.dim, name='proj_out')\n        \n        self.x_proj = Dense(self.d_state * 2, name='x_proj')\n        self.dt_proj = Dense(self.dim, name='dt_proj')\n        \n        self.A_log = self.add_weight(\n            name='A_log',\n            shape=(self.dim, self.d_state),\n            initializer='glorot_uniform',\n            trainable=True\n        )\n        \n        self.D = self.add_weight(\n            name='D',\n            shape=(self.dim,),\n            initializer='ones',\n            trainable=True\n        )\n        \n        self.norm = LayerNormalization(epsilon=1e-5, name='norm')\n        self.dropout = Dropout(0.1)\n        \n        super().build(input_shape)\n    \n    def call(self, x, training=None):\n        B, H, W, C = x.shape\n        shortcut = x\n        \n        x = Reshape((H * W, C))(x)\n        x = self.norm(x)\n        \n        x_proj = self.proj_in(x)\n        x, gate = tf.split(x_proj, 2, axis=-1)\n        \n        x = x * tf.nn.silu(gate)\n        \n        x_ssm = self.x_proj(x)\n        B_expanded, C_expanded = tf.split(x_ssm, 2, axis=-1)\n        \n        y = self.dt_proj(x)\n        y = self.proj_out(y)\n        \n        if training:\n            y = self.dropout(y)\n        \n        y = Reshape((H, W, C))(y)\n        \n        return Add()([shortcut, y])\n\ndef patch_merging(x, dim, name='patch_merge'):\n    \"\"\"Patch merging layer for downsampling\"\"\"\n    x0 = x[:, 0::2, 0::2, :]\n    x1 = x[:, 1::2, 0::2, :]\n    x2 = x[:, 0::2, 1::2, :]\n    x3 = x[:, 1::2, 1::2, :]\n    \n    x = Concatenate(axis=-1, name=f'{name}_concat')([x0, x1, x2, x3])\n    x = LayerNormalization(epsilon=1e-5, name=f'{name}_norm')(x)\n    x = Dense(dim * 2, name=f'{name}_dense')(x)\n    \n    return x\n\ndef patch_expanding(x, dim, name='patch_expand'):\n    \"\"\"Patch expanding layer for upsampling\"\"\"\n    x = Dense(dim * 4, name=f'{name}_dense')(x)\n    x = LayerNormalization(epsilon=1e-5, name=f'{name}_norm')(x)\n    \n    B, H, W, C = x.shape\n    x = Reshape((H, W, 2, 2, dim), name=f'{name}_reshape1')(x)\n    x = Lambda(lambda t: tf.transpose(t, [0, 1, 3, 2, 4, 5]), \n               name=f'{name}_transpose')(x)\n    x = Reshape((H * 2, W * 2, dim), name=f'{name}_reshape2')(x)\n    \n    return x\n\ndef build_vmunet(cfg):\n    \"\"\"VM-UNet: Vision Mamba U-Net\"\"\"\n    print(\"\\n\" + \"=\"*80)\n    print(\"🔥 BUILDING VM-UNET\")\n    print(\"=\"*80)\n    \n    inputs = Input((cfg.INPUT_SIZE, cfg.INPUT_SIZE, 3), name=\"input\")\n    \n    x = Conv2D(48, 4, strides=4, padding='same',\n              kernel_initializer='he_normal', name='patch_embed')(inputs)\n    x = LayerNormalization(epsilon=1e-5, name='patch_embed_norm')(x)\n    \n    dims = [48, 96, 192, 384]\n    encoder_outputs = []\n    depths = [2, 2, 4, 2]\n    \n    for stage, (dim, depth) in enumerate(zip(dims, depths)):\n        for i in range(depth):\n            x = VSSBlock(dim, d_state=16, expand_ratio=2,\n                        name=f'vss_enc{stage+1}_{i+1}')(x)\n        \n        encoder_outputs.append(x)\n        print(f\"  Encoder Stage {stage+1}: dim={dim}, depth={depth}\")\n        \n        if stage < len(dims) - 1:\n            x = patch_merging(x, dim, name=f'merge{stage+1}')\n    \n    for stage in range(len(dims)-2, -1, -1):\n        dim = dims[stage]\n        depth = depths[stage]\n        \n        x = patch_expanding(x, dim, name=f'expand{stage+1}')\n        \n        x = Concatenate(name=f'concat{stage+1}')([x, encoder_outputs[stage]])\n        x = Dense(dim, name=f'fuse{stage+1}')(x)\n        x = LayerNormalization(epsilon=1e-5, name=f'fuse{stage+1}_norm')(x)\n        \n        for i in range(depth):\n            x = VSSBlock(dim, d_state=16, expand_ratio=2,\n                        name=f'vss_dec{stage+1}_{i+1}')(x)\n        \n        print(f\"  Decoder Stage {stage+1}: dim={dim}, depth={depth}\")\n    \n    x = patch_expanding(x, 24, name='final_expand')\n    x = patch_expanding(x, 12, name='final_expand2')\n    \n    H, W = cfg.INPUT_SIZE, cfg.INPUT_SIZE\n    x = Reshape((H, W, 12), name='final_reshape')(x)\n    x = Conv2D(32, 3, padding='same', activation='relu', name='head_conv1')(x)\n    x = Conv2D(16, 3, padding='same', activation='relu', name='head_conv2')(x)\n    outputs = Conv2D(1, 1, activation='sigmoid', name='output')(x)\n    \n    model = Model(inputs, outputs, name='VMUNet')\n    print(f\"\\nTotal parameters: {model.count_params():,}\")\n    print(\"=\"*80 + \"\\n\")\n    return model\n\n# ==============================================================================\n# MODEL 12: MedNeXt\n# ==============================================================================\n\ndef mednext_block(x, filters, kernel_size=7, expand_ratio=4, name='mednext'):\n    \"\"\"MedNeXt Block: ConvNeXt-inspired block for medical imaging\"\"\"\n    shortcut = x\n    \n    x = DepthwiseConv2D(kernel_size, padding='same',\n                       depthwise_initializer='he_normal',\n                       name=f'{name}_dw')(x)\n    x = LayerNormalization(epsilon=1e-6, name=f'{name}_norm1')(x)\n    \n    hidden_dim = filters * expand_ratio\n    x = Dense(hidden_dim, name=f'{name}_expand')(x)\n    x = Activation('gelu', name=f'{name}_gelu')(x)\n    x = Dense(filters, name=f'{name}_compress')(x)\n    \n    if shortcut.shape[-1] != filters:\n        shortcut = Conv2D(filters, 1, padding='same',\n                         kernel_initializer='he_normal',\n                         name=f'{name}_shortcut')(shortcut)\n    \n    x = Add(name=f'{name}_add')([shortcut, x])\n    \n    return x\n\ndef mednext_stage(x, filters, num_blocks=3, kernel_size=7, \n                  downsample=False, name='stage'):\n    \"\"\"MedNeXt stage with multiple blocks\"\"\"\n    \n    if downsample:\n        x = LayerNormalization(epsilon=1e-6, name=f'{name}_down_norm')(x)\n        x = Conv2D(filters, 2, strides=2, padding='same',\n                  kernel_initializer='he_normal',\n                  name=f'{name}_down')(x)\n    \n    for i in range(num_blocks):\n        x = mednext_block(x, filters, kernel_size=kernel_size,\n                         expand_ratio=4, name=f'{name}_block{i+1}')\n    \n    return x\n\ndef build_mednext(cfg):\n    \"\"\"MedNeXt: Modern ConvNeXt-inspired architecture for medical imaging\"\"\"\n    print(\"\\n\" + \"=\"*80)\n    print(\"🔥 BUILDING MEDNEXT\")\n    print(\"=\"*80)\n    \n    inputs = Input((cfg.INPUT_SIZE, cfg.INPUT_SIZE, 3), name=\"input\")\n    \n    x = Conv2D(32, 4, strides=4, padding='same',\n              kernel_initializer='he_normal', name='stem')(inputs)\n    x = LayerNormalization(epsilon=1e-6, name='stem_norm')(x)\n    \n    filters = [32, 64, 128, 256, 512]\n    depths = [3, 3, 9, 3]\n    kernel_sizes = [7, 7, 7, 7]\n    encoder_outputs = []\n    \n    x = mednext_stage(x, filters[0], num_blocks=depths[0],\n                     kernel_size=kernel_sizes[0],\n                     downsample=False, name='enc1')\n    encoder_outputs.append(x)\n    print(f\"  Encoder Stage 1: filters={filters[0]}, blocks={depths[0]}, kernel={kernel_sizes[0]}\")\n    \n    for i in range(1, len(depths)):\n        x = mednext_stage(x, filters[i], num_blocks=depths[i],\n                         kernel_size=kernel_sizes[i],\n                         downsample=True, name=f'enc{i+1}')\n        encoder_outputs.append(x)\n        print(f\"  Encoder Stage {i+1}: filters={filters[i]}, blocks={depths[i]}, kernel={kernel_sizes[i]}\")\n    \n    x = mednext_stage(x, filters[4], num_blocks=3,\n                     kernel_size=7, downsample=True, name='bottleneck')\n    print(f\"  Bottleneck: filters={filters[4]}, blocks=3\")\n    \n    for i in range(len(depths)-1, -1, -1):\n        f = filters[i]\n        \n        x = LayerNormalization(epsilon=1e-6, name=f'dec{i+1}_up_norm')(x)\n        x = Conv2DTranspose(f, 2, strides=2, padding='same',\n                           kernel_initializer='he_normal',\n                           name=f'dec{i+1}_up')(x)\n        \n        skip = encoder_outputs[i]\n        x = Concatenate(name=f'dec{i+1}_concat')([x, skip])\n        x = Conv2D(f, 1, padding='same',\n                  kernel_initializer='he_normal',\n                  name=f'dec{i+1}_fuse')(x)\n        x = LayerNormalization(epsilon=1e-6, name=f'dec{i+1}_fuse_norm')(x)\n        \n        x = mednext_stage(x, f, num_blocks=depths[i],\n                         kernel_size=kernel_sizes[i],\n                         downsample=False, name=f'dec{i+1}')\n        \n        print(f\"  Decoder Stage {i+1}: filters={f}, blocks={depths[i]}\")\n    \n    x = LayerNormalization(epsilon=1e-6, name='final_norm')(x)\n    x = Conv2DTranspose(16, 4, strides=4, padding='same',\n                       kernel_initializer='he_normal',\n                       name='final_up')(x)\n    \n    x = Conv2D(32, 3, padding='same', activation='gelu', name='head_conv1')(x)\n    x = Conv2D(16, 3, padding='same', activation='gelu', name='head_conv2')(x)\n    outputs = Conv2D(1, 1, activation='sigmoid', name='output')(x)\n    \n    model = Model(inputs, outputs, name='MedNeXt')\n    print(f\"\\nTotal parameters: {model.count_params():,}\")\n    print(\"=\"*80 + \"\\n\")\n    return model\n\n# ==============================================================================\n# MODEL FACTORY\n# ==============================================================================\n\ndef build_model(cfg):\n    \"\"\"Factory function to build the selected model\"\"\"\n    model_builders = {\n        \"UNet\": build_unet,\n        \"UNetPlusPlus\": build_unetplusplus,\n        \"PraNet\": build_pranet,\n        \"RAPUNet\": build_rapunet,\n        \"SwinUNet\": build_swinunet,\n        \"TransUNet\": build_transunet,\n        \"UMamba\": build_umamba,\n        \"DUCKNET\": build_ducknet,\n        \"MedSegNet-SSF\": build_medsegnet_ssf,\n        \"nnUNet\": build_nnunet,\n        \"VM-UNet\": build_vmunet,\n        \"MedNeXt\": build_mednext\n    }\n    \n    if cfg.MODEL_NAME not in model_builders:\n        raise ValueError(f\"Unknown model: {cfg.MODEL_NAME}. Available: {list(model_builders.keys())}\")\n    \n    return model_builders[cfg.MODEL_NAME](cfg)\n\n# ==============================================================================\n# MASL LOSS FUNCTION\n# ==============================================================================\n\nclass ClipConstraint(tf.keras.constraints.Constraint):\n    \"\"\"Custom constraint to clip weight values\"\"\"\n    def __init__(self, min_value=0.1, max_value=10.0):\n        self.min_value = min_value\n        self.max_value = max_value\n    \n    def __call__(self, w):\n        return tf.clip_by_value(w, self.min_value, self.max_value)\n    \n    def get_config(self):\n        return {'min_value': self.min_value, 'max_value': self.max_value}\n\nclass MorphologyAwareAdaptiveLoss(Layer):\n    \"\"\"MASL: Morphology-Aware Adaptive Segmentation Loss\"\"\"\n    \n    def __init__(self, name='masl', **kwargs):\n        super().__init__(name=name, **kwargs)\n        self.epsilon = 1e-6\n        \n    def build(self, input_shape):\n        clip_constraint = ClipConstraint(min_value=0.1, max_value=10.0)\n        \n        self.w_region = self.add_weight(\n            name='w_region', shape=(), initializer=tf.constant_initializer(1.0),\n            trainable=True, constraint=clip_constraint\n        )\n        self.w_boundary = self.add_weight(\n            name='w_boundary', shape=(), initializer=tf.constant_initializer(1.0),\n            trainable=True, constraint=clip_constraint\n        )\n        self.w_structure = self.add_weight(\n            name='w_structure', shape=(), initializer=tf.constant_initializer(1.0),\n            trainable=True, constraint=clip_constraint\n        )\n        self.w_scale = self.add_weight(\n            name='w_scale', shape=(), initializer=tf.constant_initializer(0.5),\n            trainable=True, constraint=clip_constraint\n        )\n        self.w_texture = self.add_weight(\n            name='w_texture', shape=(), initializer=tf.constant_initializer(0.5),\n            trainable=True, constraint=clip_constraint\n        )\n        super().build(input_shape)\n    \n    def morphological_dilation(self, x, kernel_size=5):\n        return tf.nn.max_pool2d(x, kernel_size, strides=1, padding='SAME')\n    \n    def morphological_erosion(self, x, kernel_size=5):\n        return -tf.nn.max_pool2d(-x, kernel_size, strides=1, padding='SAME')\n    \n    def detect_boundary(self, mask, kernel_size=5):\n        dilated = self.morphological_dilation(mask, kernel_size)\n        eroded = self.morphological_erosion(mask, kernel_size)\n        boundary = dilated - eroded\n        return tf.clip_by_value(boundary, 0.0, 1.0)\n    \n    def analyze_structure_characteristics(self, y_true):\n        area = tf.reduce_sum(y_true, axis=[1, 2, 3]) + self.epsilon\n        total_pixels = tf.cast(tf.shape(y_true)[1] * tf.shape(y_true)[2], tf.float32)\n        \n        dy = y_true[:, 1:, :, :] - y_true[:, :-1, :, :]\n        dx = y_true[:, :, 1:, :] - y_true[:, :, :-1, :]\n        dy_padded = tf.pad(dy, [[0, 0], [0, 1], [0, 0], [0, 0]])\n        dx_padded = tf.pad(dx, [[0, 0], [0, 0], [0, 1], [0, 0]])\n        gradient_mag = tf.sqrt(dy_padded**2 + dx_padded**2 + self.epsilon)\n        perimeter = tf.reduce_sum(gradient_mag, axis=[1, 2, 3]) + self.epsilon\n        \n        skeleton_approx = self.morphological_erosion(y_true, kernel_size=3)\n        skeleton_area = tf.reduce_sum(skeleton_approx, axis=[1, 2, 3]) + self.epsilon\n        \n        tubularity = tf.reduce_mean(skeleton_area / (area + self.epsilon))\n        compactness = tf.reduce_mean((4 * 3.14159 * area) / (perimeter**2 + self.epsilon))\n        compactness = tf.clip_by_value(compactness, 0.0, 1.0)\n        \n        boundary = self.detect_boundary(y_true, kernel_size=5)\n        ddy = boundary[:, 2:, :, :] - 2*boundary[:, 1:-1, :, :] + boundary[:, :-2, :, :]\n        ddx = boundary[:, :, 2:, :] - 2*boundary[:, :, 1:-1, :] + boundary[:, :, :-2, :]\n        irregularity = tf.reduce_mean(tf.abs(ddy)) + tf.reduce_mean(tf.abs(ddx))\n        \n        object_size = tf.reduce_mean(area / total_pixels)\n        \n        return {\n            'tubularity': tf.clip_by_value(tubularity, 0.0, 1.0),\n            'compactness': compactness,\n            'irregularity': tf.clip_by_value(irregularity, 0.0, 1.0),\n            'object_size': tf.clip_by_value(object_size, 0.0, 1.0)\n        }\n    \n    def core_loss(self, y_true, y_pred):\n        intersection = tf.reduce_sum(y_true * y_pred, axis=[1, 2, 3])\n        dice = (2. * intersection + self.epsilon) / (\n            tf.reduce_sum(y_true, axis=[1, 2, 3]) + \n            tf.reduce_sum(y_pred, axis=[1, 2, 3]) + self.epsilon\n        )\n        dice_loss = 1.0 - tf.reduce_mean(dice)\n        \n        union = (tf.reduce_sum(y_true, axis=[1, 2, 3]) + \n                tf.reduce_sum(y_pred, axis=[1, 2, 3]) - intersection)\n        iou = (intersection + self.epsilon) / (union + self.epsilon)\n        iou_loss = 1.0 - tf.reduce_mean(iou)\n        \n        boundary = self.detect_boundary(y_true, kernel_size=5)\n        weights = 1.0 + 5.0 * boundary\n        bce = -(y_true * tf.math.log(y_pred + self.epsilon) + \n               (1 - y_true) * tf.math.log(1 - y_pred + self.epsilon))\n        weighted_bce = tf.reduce_mean(weights * bce)\n        \n        return 0.4 * dice_loss + 0.3 * iou_loss + 0.3 * weighted_bce\n    \n    def boundary_loss(self, y_true, y_pred):\n        total_loss = 0.0\n        weights = [0.5, 0.3, 0.2]\n        \n        for scale, w in zip([1, 2, 4], weights):\n            dy_true = y_true[:, scale:, :, :] - y_true[:, :-scale, :, :]\n            dy_pred = y_pred[:, scale:, :, :] - y_pred[:, :-scale, :, :]\n            \n            dx_true = y_true[:, :, scale:, :] - y_true[:, :, :-scale, :]\n            dx_pred = y_pred[:, :, scale:, :] - y_pred[:, :, :-scale, :]\n            \n            total_loss += w * (tf.reduce_mean(tf.abs(dy_true - dy_pred)) + \n                             tf.reduce_mean(tf.abs(dx_true - dx_pred)))\n        \n        return total_loss\n    \n    def structure_aware_loss(self, y_true, y_pred, characteristics):\n        area_true = tf.reduce_sum(y_true, axis=[1, 2, 3]) + self.epsilon\n        dy_true = y_true[:, 1:, :, :] - y_true[:, :-1, :, :]\n        dx_true = y_true[:, :, 1:, :] - y_true[:, :, :-1, :]\n        dy_true_padded = tf.pad(dy_true, [[0, 0], [0, 1], [0, 0], [0, 0]])\n        dx_true_padded = tf.pad(dx_true, [[0, 0], [0, 0], [0, 1], [0, 0]])\n        perimeter_true = tf.reduce_sum(tf.sqrt(dy_true_padded**2 + dx_true_padded**2 + self.epsilon), \n                                      axis=[1, 2, 3]) + self.epsilon\n        \n        area_pred = tf.reduce_sum(y_pred, axis=[1, 2, 3]) + self.epsilon\n        dy_pred = y_pred[:, 1:, :, :] - y_pred[:, :-1, :, :]\n        dx_pred = y_pred[:, :, 1:, :] - y_pred[:, :, :-1, :]\n        dy_pred_padded = tf.pad(dy_pred, [[0, 0], [0, 1], [0, 0], [0, 0]])\n        dx_pred_padded = tf.pad(dx_pred, [[0, 0], [0, 0], [0, 1], [0, 0]])\n        perimeter_pred = tf.reduce_sum(tf.sqrt(dy_pred_padded**2 + dx_pred_padded**2 + self.epsilon), \n                                      axis=[1, 2, 3]) + self.epsilon\n        \n        compact_true = area_true / (perimeter_true**2 + self.epsilon)\n        compact_pred = area_pred / (perimeter_pred**2 + self.epsilon)\n        \n        return characteristics['compactness'] * tf.reduce_mean(tf.abs(compact_true - compact_pred))\n    \n    def scale_aware_focal_loss(self, y_true, y_pred, characteristics):\n        size = characteristics['object_size']\n        gamma = tf.cond(\n            size < 0.05,\n            lambda: 3.0,\n            lambda: tf.cond(size < 0.2, lambda: 2.0, lambda: 1.5)\n        )\n        \n        p = y_true * y_pred + (1 - y_true) * (1 - y_pred)\n        focal_weight = tf.pow(1 - p, gamma)\n        bce = -(y_true * tf.math.log(y_pred + self.epsilon) + \n               (1 - y_true) * tf.math.log(1 - y_pred + self.epsilon))\n        \n        return tf.reduce_mean(focal_weight * bce)\n    \n    def texture_aware_loss(self, y_true, y_pred):\n        ddy_true = y_true[:, 2:, :, :] - 2*y_true[:, 1:-1, :, :] + y_true[:, :-2, :, :]\n        ddy_pred = y_pred[:, 2:, :, :] - 2*y_pred[:, 1:-1, :, :] + y_pred[:, :-2, :, :]\n        \n        ddx_true = y_true[:, :, 2:, :] - 2*y_true[:, :, 1:-1, :] + y_true[:, :, :-2, :]\n        ddx_pred = y_pred[:, :, 2:, :] - 2*y_pred[:, :, 1:-1, :] + y_pred[:, :, :-2, :]\n        \n        return tf.reduce_mean(tf.abs(ddy_true - ddy_pred)) + tf.reduce_mean(tf.abs(ddx_true - ddx_pred))\n    \n    def call(self, y_true, y_pred):\n        y_true = tf.cast(y_true, tf.float32)\n        y_pred = tf.cast(y_pred, tf.float32)\n        \n        characteristics = self.analyze_structure_characteristics(y_true)\n        \n        alpha_region = 1.0 + 0.5 * characteristics['compactness']\n        alpha_boundary = 1.0 + 1.5 * characteristics['tubularity'] + characteristics['compactness']\n        alpha_structure = 1.0 + characteristics['tubularity']\n        alpha_scale = 1.0 + 1.5 * characteristics['irregularity']\n        alpha_texture = 1.0 + characteristics['irregularity']\n        \n        l_core = self.core_loss(y_true, y_pred)\n        l_boundary = self.boundary_loss(y_true, y_pred)\n        l_structure = self.structure_aware_loss(y_true, y_pred, characteristics)\n        l_scale = self.scale_aware_focal_loss(y_true, y_pred, characteristics)\n        l_texture = self.texture_aware_loss(y_true, y_pred)\n        \n        weighted_core = self.w_region * alpha_region * l_core\n        weighted_boundary = self.w_boundary * alpha_boundary * l_boundary\n        weighted_structure = self.w_structure * alpha_structure * l_structure\n        weighted_scale = self.w_scale * alpha_scale * l_scale\n        weighted_texture = self.w_texture * alpha_texture * l_texture\n        \n        total_weight = (self.w_region * alpha_region + \n                       self.w_boundary * alpha_boundary + \n                       self.w_structure * alpha_structure + \n                       self.w_scale * alpha_scale + \n                       self.w_texture * alpha_texture)\n        \n        masl_loss = (weighted_core + weighted_boundary + weighted_structure + \n                    weighted_scale + weighted_texture) / (total_weight + self.epsilon)\n        \n        return masl_loss\n    \n    def get_config(self):\n        return super().get_config()\n\n_masl_instance = MorphologyAwareAdaptiveLoss()\n\ndef masl_loss_fn(y_true, y_pred):\n    \"\"\"MASL loss function wrapper\"\"\"\n    return _masl_instance(y_true, y_pred)\n\n# ==============================================================================\n# METRICS\n# ==============================================================================\n\ndef dice_coefficient(y_true, y_pred, smooth=1e-6):\n    \"\"\"Dice coefficient metric\"\"\"\n    y_true_f = K.flatten(y_true)\n    y_pred_f = K.flatten(y_pred)\n    intersection = K.sum(y_true_f * y_pred_f)\n    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)\n\ndef iou_score(y_true, y_pred, smooth=1e-6):\n    \"\"\"IoU (Jaccard) metric\"\"\"\n    y_true_f = K.flatten(y_true)\n    y_pred_f = K.flatten(y_pred)\n    intersection = K.sum(y_true_f * y_pred_f)\n    union = K.sum(y_true_f) + K.sum(y_pred_f) - intersection\n    return (intersection + smooth) / (union + smooth)\n\ndef precision_metric(y_true, y_pred):\n    \"\"\"Precision metric\"\"\"\n    y_pred_bin = K.cast(y_pred > 0.5, tf.float32)\n    true_positives = K.sum(y_true * y_pred_bin)\n    predicted_positives = K.sum(y_pred_bin)\n    return true_positives / (predicted_positives + K.epsilon())\n\ndef recall_metric(y_true, y_pred):\n    \"\"\"Recall metric\"\"\"\n    y_pred_bin = K.cast(y_pred > 0.5, tf.float32)\n    true_positives = K.sum(y_true * y_pred_bin)\n    actual_positives = K.sum(y_true)\n    return true_positives / (actual_positives + K.epsilon())\n\n# ==============================================================================\n# PREDICTION MASK SAVING\n# ==============================================================================\n\ndef save_prediction_masks(model, test_gen, save_dir, \n                         save_probabilities=True, \n                         save_binary=True, \n                         save_overlay=True):\n    \"\"\"Save prediction masks for all test samples\"\"\"\n    print(\"\\n\" + \"=\"*80)\n    print(\"💾 SAVING PREDICTION MASKS\")\n    print(\"=\"*80)\n    \n    if save_probabilities:\n        prob_dir = os.path.join(save_dir, \"probability_masks\")\n        os.makedirs(prob_dir, exist_ok=True)\n    \n    if save_binary:\n        binary_dir = os.path.join(save_dir, \"binary_masks\")\n        os.makedirs(binary_dir, exist_ok=True)\n    \n    if save_overlay:\n        overlay_dir = os.path.join(save_dir, \"overlays\")\n        os.makedirs(overlay_dir, exist_ok=True)\n    \n    gt_dir = os.path.join(save_dir, \"ground_truth\")\n    os.makedirs(gt_dir, exist_ok=True)\n    \n    img_dir = os.path.join(save_dir, \"original_images\")\n    os.makedirs(img_dir, exist_ok=True)\n    \n    metrics_file = os.path.join(save_dir, \"prediction_metrics.csv\")\n    \n    sample_count = 0\n    all_metrics = []\n    \n    for batch_idx in range(len(test_gen)):\n        images, masks = test_gen[batch_idx]\n        predictions = model.predict(images, verbose=0)\n        \n        for i in range(len(images)):\n            sample_count += 1\n            \n            img = images[i]\n            gt_mask = masks[i, :, :, 0]\n            pred = predictions[i, :, :, 0]\n            pred_binary = (pred > 0.5).astype(np.float32)\n            \n            # Calculate metrics\n            intersection = np.sum(gt_mask * pred_binary)\n            union = np.sum(gt_mask) + np.sum(pred_binary) - intersection\n            dice = (2. * intersection) / (np.sum(gt_mask) + np.sum(pred_binary) + 1e-6)\n            iou = intersection / (union + 1e-6)\n            \n            tp = np.sum(gt_mask * pred_binary)\n            fp = np.sum((1 - gt_mask) * pred_binary)\n            fn = np.sum(gt_mask * (1 - pred_binary))\n            precision = tp / (tp + fp + 1e-6)\n            recall = tp / (tp + fn + 1e-6)\n            \n            all_metrics.append({\n                'sample_id': f'sample_{sample_count:04d}',\n                'dice': dice,\n                'iou': iou,\n                'precision': precision,\n                'recall': recall\n            })\n            \n            filename = f'sample_{sample_count:04d}'\n            \n            # Save original image\n            img_uint8 = (img * 255).astype(np.uint8)\n            cv2.imwrite(\n                os.path.join(img_dir, f'{filename}.png'),\n                cv2.cvtColor(img_uint8, cv2.COLOR_RGB2BGR)\n            )\n            \n            # Save ground truth\n            gt_uint8 = (gt_mask * 255).astype(np.uint8)\n            cv2.imwrite(os.path.join(gt_dir, f'{filename}_gt.png'), gt_uint8)\n            \n            # Save probability mask\n            if save_probabilities:\n                prob_uint8 = (pred * 255).astype(np.uint8)\n                cv2.imwrite(os.path.join(prob_dir, f'{filename}_prob.png'), prob_uint8)\n                prob_colorized = cv2.applyColorMap(prob_uint8, cv2.COLORMAP_JET)\n                cv2.imwrite(os.path.join(prob_dir, f'{filename}_prob_color.png'), prob_colorized)\n            \n            # Save binary mask\n            if save_binary:\n                binary_uint8 = (pred_binary * 255).astype(np.uint8)\n                cv2.imwrite(os.path.join(binary_dir, f'{filename}_binary.png'), binary_uint8)\n            \n            # Save overlay\n            if save_overlay:\n                overlay = img_uint8.copy()\n                pred_uint8 = (pred_binary * 255).astype(np.uint8)\n                contours, _ = cv2.findContours(pred_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)\n                cv2.drawContours(overlay, contours, -1, (0, 255, 0), 2)\n                \n                gt_uint8 = (gt_mask * 255).astype(np.uint8)\n                gt_contours, _ = cv2.findContours(gt_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)\n                cv2.drawContours(overlay, gt_contours, -1, (255, 0, 0), 2)\n                \n                cv2.putText(overlay, f'Dice: {dice:.4f}', (10, 30), \n                           cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)\n                cv2.putText(overlay, 'Green=Pred, Red=GT', (10, 60),\n                           cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)\n                \n                cv2.imwrite(\n                    os.path.join(overlay_dir, f'{filename}_overlay.png'),\n                    cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR)\n                )\n            \n            if sample_count % 10 == 0:\n                print(f\"   Processed {sample_count} samples... (Dice: {dice:.4f})\")\n    \n    # Save metrics CSV\n    df_metrics = pd.DataFrame(all_metrics)\n    df_metrics.to_csv(metrics_file, index=False)\n    \n    print(\"\\n\" + \"=\"*80)\n    print(f\"✅ SAVED {sample_count} PREDICTION MASKS\")\n    print(\"=\"*80)\n    print(f\"\\n📊 Overall Metrics:\")\n    print(f\"   Mean Dice:      {df_metrics['dice'].mean():.4f} ± {df_metrics['dice'].std():.4f}\")\n    print(f\"   Mean IoU:       {df_metrics['iou'].mean():.4f} ± {df_metrics['iou'].std():.4f}\")\n    print(f\"   Mean Precision: {df_metrics['precision'].mean():.4f} ± {df_metrics['precision'].std():.4f}\")\n    print(f\"   Mean Recall:    {df_metrics['recall'].mean():.4f} ± {df_metrics['recall'].std():.4f}\")\n    print(\"=\"*80)\n    \n    return df_metrics\n\n# ==============================================================================\n# MAIN TRAINING LOOP\n# ==============================================================================\n\ndef train_model(cfg, strategy, num_gpus):\n    \"\"\"Main training function\"\"\"\n    set_seed(cfg.SEED, cfg.DETERMINISTIC)\n    \n    # 1. LOAD DATA\n    train_pairs = load_dataset_split(cfg.TRAIN_DIR)\n    val_pairs = load_dataset_split(cfg.VAL_DIR)\n    test_pairs = load_dataset_split(cfg.TEST_DIR)\n    \n    if not train_pairs:\n        print(\"❌ No training data found!\")\n        return\n    \n    print(f\"\\n📊 Dataset Statistics:\")\n    print(f\"   Training samples:   {len(train_pairs)}\")\n    print(f\"   Validation samples: {len(val_pairs)}\")\n    print(f\"   Test samples:       {len(test_pairs)}\")\n    \n    # 2. GENERATORS\n    train_aug = get_training_augmentation(cfg)\n    val_aug = get_validation_augmentation(cfg)\n    \n    train_gen = ExpandedDataGenerator(\n        train_pairs, cfg, augmentation=train_aug, shuffle=True,\n        expansion_factor=cfg.EPOCH_EXPANSION_FACTOR\n    )\n    \n    val_gen = ExpandedDataGenerator(\n        val_pairs, cfg, augmentation=val_aug, shuffle=False, expansion_factor=1\n    )\n    \n    test_gen = ExpandedDataGenerator(\n        test_pairs, cfg, augmentation=val_aug, shuffle=False, expansion_factor=1\n    )\n    \n    print(f\"\\n📊 Training Configuration:\")\n    print(f\"   Steps per Epoch: {len(train_gen)}\")\n    print(f\"   Virtual images/epoch: {len(train_pairs)*cfg.EPOCH_EXPANSION_FACTOR}\")\n    if num_gpus > 1:\n        print(f\"   Effective batch size: {cfg.BATCH_SIZE * num_gpus}\")\n    \n    # 3. BUILD MODEL\n    with strategy.scope():\n        model = build_model(cfg)\n        \n        # 4. COMPILE\n        optimizer = tf.keras.optimizers.Adam(learning_rate=cfg.LEARNING_RATE, clipnorm=1.0)\n        \n        loss_fn = masl_loss_fn if cfg.USE_MASL else \"binary_crossentropy\"\n        loss_name = \"MASL\" if cfg.USE_MASL else \"Binary Crossentropy\"\n        \n        model.compile(\n            optimizer=optimizer,\n            loss=loss_fn,\n            metrics=[dice_coefficient, iou_score, precision_metric, recall_metric]\n        )\n        \n        print(f\"\\n✅ Model compiled with {loss_name} loss\")\n    \n    # 5. CALLBACKS\n    callbacks = [\n        ModelCheckpoint(\n            os.path.join(cfg.SAVE_DIR, f\"best_model_{cfg.MODEL_NAME}.h5\"),\n            monitor=cfg.CHECKPOINT_MONITOR,\n            mode=cfg.CHECKPOINT_MODE,\n            save_best_only=True,\n            verbose=1\n        ),\n        EarlyStopping(\n            monitor=cfg.CHECKPOINT_MONITOR,\n            mode=cfg.CHECKPOINT_MODE,\n            patience=cfg.EARLY_STOPPING_PATIENCE,\n            verbose=1,\n            restore_best_weights=True\n        ),\n        ReduceLROnPlateau(\n            monitor=cfg.CHECKPOINT_MONITOR,\n            mode=cfg.CHECKPOINT_MODE,\n            factor=0.5,\n            patience=10,\n            min_lr=1e-7,\n            verbose=1\n        ),\n        CSVLogger(os.path.join(cfg.SAVE_DIR, f\"training_log_{cfg.MODEL_NAME}.csv\"))\n    ]\n    \n    # 6. TRAIN\n    gpu_info = f\" on {num_gpus} GPU(s)\" if num_gpus > 0 else \" on CPU\"\n    print(f\"\\n🚀 STARTING TRAINING ({cfg.EPOCHS} EPOCHS){gpu_info}\")\n    print(\"=\"*80)\n    start_time = time.time()\n    \n    history = model.fit(\n        train_gen,\n        validation_data=val_gen,\n        epochs=cfg.EPOCHS,\n        callbacks=callbacks,\n        verbose=1\n    )\n    \n    training_time = time.time() - start_time\n    print(f\"\\n✅ Training finished in {training_time/60:.1f} minutes\")\n    \n    # 7. EVALUATE ON TEST SET\n    print(\"\\n\" + \"=\"*80)\n    print(\"📊 EVALUATING ON TEST SET\")\n    print(\"=\"*80)\n    test_results = model.evaluate(test_gen, verbose=1)\n    \n    results = {\n        \"model\": cfg.MODEL_NAME,\n        \"loss_function\": loss_name,\n        \"gpu_config\": {\n            \"num_gpus\": num_gpus,\n            \"gpu_numbers\": cfg.GPU_NUMBERS,\n            \"effective_batch_size\": cfg.BATCH_SIZE * max(num_gpus, 1)\n        },\n        \"training_time_minutes\": training_time / 60,\n        \"test_results\": {\n            name: float(value) \n            for name, value in zip(model.metrics_names, test_results)\n        }\n    }\n    \n    # Save results\n    with open(os.path.join(cfg.SAVE_DIR, f\"results_{cfg.MODEL_NAME}.json\"), \"w\") as f:\n        json.dump(results, f, indent=2)\n    \n    print(f\"\\n✅ Results saved to {cfg.SAVE_DIR}/results_{cfg.MODEL_NAME}.json\")\n    \n    # 8. SAVE PREDICTION MASKS\n    print(\"\\n\" + \"=\"*80)\n    print(\"💾 GENERATING PREDICTION MASKS\")\n    print(\"=\"*80)\n    \n    masks_dir = os.path.join(cfg.SAVE_DIR, \"predictions\")\n    \n    try:\n        df_metrics = save_prediction_masks(\n            model, test_gen,\n            save_dir=masks_dir,\n            save_probabilities=True,\n            save_binary=True,\n            save_overlay=True\n        )\n    except Exception as e:\n        print(f\"⚠️ Prediction mask saving failed: {e}\")\n    \n    # 9. FINAL SUMMARY\n    print(\"\\n\" + \"=\"*80)\n    print(f\"✅ {cfg.MODEL_NAME} TRAINING COMPLETE!\")\n    print(\"=\"*80)\n    print(f\"\\n📁 Outputs saved to: {cfg.SAVE_DIR}/\")\n    print(f\"   ├─ best_model_{cfg.MODEL_NAME}.h5\")\n    print(f\"   ├─ training_log_{cfg.MODEL_NAME}.csv\")\n    print(f\"   ├─ results_{cfg.MODEL_NAME}.json\")\n    print(f\"   └─ predictions/\")\n    print(f\"       ├─ probability_masks/\")\n    print(f\"       ├─ binary_masks/\")\n    print(f\"       ├─ overlays/\")\n    print(f\"       ├─ ground_truth/\")\n    print(f\"       ├─ original_images/\")\n    print(f\"       └─ prediction_metrics.csv\")\n    \n    print(f\"\\n📊 Final Test Results:\")\n    print(f\"   Loss:      {test_results[0]:.4f}\")\n    print(f\"   Dice:      {test_results[1]:.4f}\")\n    print(f\"   IoU:       {test_results[2]:.4f}\")\n    print(f\"   Precision: {test_results[3]:.4f}\")\n    print(f\"   Recall:    {test_results[4]:.4f}\")\n    print(\"=\"*80 + \"\\n\")\n    \n    return model, history\n\n# ==============================================================================\n# MAIN EXECUTION\n# ==============================================================================\n\nif __name__ == \"__main__\":\n    model, history = train_model(config, strategy, num_gpus)","metadata":{"_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","trusted":true},"outputs":[],"execution_count":null}]}