#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
================================================================================
🔥 MULTI-MODEL MEDICAL IMAGE SEGMENTATION TRAINING SCRIPT
================================================================================

FEATURES:
- ✅ Multiple Model Architectures (UNet, UNet++, PraNet, RAPUNet, etc.)
- ✅ Patch-based inference for evaluation (NO MORE RESIZING!)
- ✅ Correct TTA implementation
- ✅ Save predicted images for each model
- ✅ Side-by-side visualizations
- ✅ MASL loss function
- ✅ Model selection capability

MODELS INCLUDED:
1. UNet
2. UNet++
3. PraNet
4. RAPUNet
5. SwinUNet
6. TransUNet
7. UMamba
8. MedSegNet-SSF
9. DuckNet
10. nnUNet

Author: Sanaullah
Date: 2025
================================================================================
"""

import os
import sys
import time
import json
from glob import glob
from pathlib import Path

import numpy as np
import cv2
import tensorflow as tf
from tensorflow.keras.layers import (
    Input, Conv2D, DepthwiseConv2D, GlobalAveragePooling2D,
    Reshape, Multiply, BatchNormalization, Activation, Dropout,
    Concatenate, Add, UpSampling2D, LayerNormalization, Dense, Layer,
    MaxPooling2D, Conv2DTranspose, MultiHeadAttention, Embedding
)
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2
from tensorflow.keras.callbacks import (
    ModelCheckpoint, EarlyStopping, CSVLogger, Callback, LearningRateScheduler
)
from tensorflow.keras import backend as K
import albumentations as A
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.patches import Rectangle, Patch
from scipy import ndimage
from scipy.stats import pearsonr
import pandas as pd

# Ask TensorFlow's native logging to hide INFO and WARNING messages.
# NOTE(review): this assignment runs after `import tensorflow` above; confirm
# it still takes effect there, or move it ahead of the import.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

# ==============================================================================
# 🔥 CONFIGURATION
# ==============================================================================

class Config:
    """Central settings for the multi-model segmentation run.

    Every option is a class attribute, so values can be read either from the
    class or from an instance. Instantiating the class creates ``SAVE_DIR``
    (if missing) and prints a short summary of the active configuration.
    """

    # ==================== GPU CONFIGURATION ====================
    # Indices into the list of physical GPUs; handed to setup_gpus().
    GPU_NUMBERS = [0]

    # ==================== DATA PATHS ====================
    DATA_ROOT = "/kaggle/input/chase-db/chase_data"  # ⚡ CHANGE THIS
    TRAIN_DIR = os.path.join(DATA_ROOT, "train")
    VAL_DIR   = os.path.join(DATA_ROOT, "val")
    TEST_DIR  = os.path.join(DATA_ROOT, "test")
    # Output directory; created by __init__.
    SAVE_DIR  = "/kaggle/working/MULTI_MODEL_OUTPUT"

    # ==================== MODEL SELECTION ====================
    # Model name -> whether it should be trained (True/False).
    MODELS_TO_TRAIN = {
        "UNet": True,
        "UNetPlusPlus": True,
        "PraNet": True,
        "RAPUNet": True,
        "SwinUNet": True,   # Slower, requires more memory
        "TransUNet": True,  # Slower, requires more memory
        "UMamba": True,     # Slower, requires more memory
        # "MedSegNet-SSF": False,
        "DuckNet": True,
        "nnUNet": True,
    }

    # ==================== PATCH TRAINING SETTINGS ====================
    USE_PATCH_TRAINING = True
    PATCH_SIZE = 256          # side length of the square training patches
    STRIDE = 32               # step between patch origins during extraction
    PATCHES_PER_EPOCH = 4000  # divided by BATCH_SIZE to get steps per epoch
    MIN_VESSEL_RATIO = 0.005  # minimum foreground fraction for a patch to be kept

    # ==================== MODEL ARCHITECTURE ====================
    INPUT_SIZE = 256  # spatial size of the model input (square, 3 channels)
    F1 = 24
    F2 = 32
    F3 = 64
    F4 = 80
    F5 = 128

    USE_MRF_SE = True
    USE_SSTM   = True
    USE_BFP    = True

    MRF_KERNELS = [3, 5, 7]
    SE_REDUCTION = 16
    EXPAND_RATIO = 6
    DROPOUT = 0.1
    L2_REG = 1e-4

    SSTM_NUM_FREQUENCIES = 32
    SSTM_SSM_STATE_DIM = 16
    SSTM_USE_SPECTRAL = [True] * 5
    SSTM_USE_SSM = [False] * 2 + [True] * 3
    SSTM_DROPOUT = 0.1

    # ==================== TRAINING SETTINGS ====================
    BATCH_SIZE = 16
    EPOCHS = 100
    LEARNING_RATE = 5e-4

    EARLY_STOPPING_PATIENCE = 50
    CHECKPOINT_MONITOR = "val_dice_coefficient"
    CHECKPOINT_MODE = "max"

    # ==================== PREPROCESSING ====================
    USE_CLAHE = True          # apply CLAHE in preprocess_image_clahe()
    USE_GREEN_CHANNEL = True  # use channel index 1 instead of an RGB->gray conversion
    CLAHE_CLIP_LIMIT = 2.0
    CLAHE_TILE_GRID_SIZE = (8, 8)

    # ==================== TEST-TIME AUGMENTATION ====================
    USE_TTA = True
    TTA_AUGMENTATIONS = 8

    # ==================== FOV MASKING ====================
    USE_FOV_MASK = True
    FOV_MARGIN = 20  # pixels subtracted from the inscribed-circle radius

    # ==================== VISUALIZATION SETTINGS ====================
    SAVE_PREDICTIONS = True
    SAVE_OVERLAYS = True

    SEED = 42
    DETERMINISTIC = False

    def __init__(self):
        """Create the output directory and print a configuration summary."""
        os.makedirs(self.SAVE_DIR, exist_ok=True)

        def state(flag):
            return 'ENABLED' if flag else 'DISABLED'

        selection = self.MODELS_TO_TRAIN
        report = [
            "🔥 MULTI-MODEL TRAINING CONFIGURATION",
            f"   Models to train: {sum(selection.values())}/{len(selection)}",
        ]
        # One line per model that is switched on, in dictionary order.
        report.extend(
            f"      ✅ {model_name}"
            for model_name, enabled in selection.items()
            if enabled
        )
        report.append(f"   Patch Size: {self.PATCH_SIZE}x{self.PATCH_SIZE}")
        report.append(f"   CLAHE Preprocessing: {state(self.USE_CLAHE)}")
        report.append(f"   Test-Time Augmentation: {state(self.USE_TTA)}")
        print("\n".join(report))

# Module-level configuration instance. Constructing it creates Config.SAVE_DIR
# and prints the configuration summary at import time.
config = Config()

# ==============================================================================
# GPU SETUP
# ==============================================================================

def setup_gpus(gpu_numbers=None):
    """Select GPUs, enable memory growth, and choose a distribution strategy.

    Args:
        gpu_numbers: Optional iterable of indices into the list of physical
            GPUs. Indices at or beyond the number of GPUs are dropped. When
            ``None``, every physical GPU is used.

    Returns:
        A ``(strategy, num_gpus)`` tuple. ``strategy`` is a
        ``MirroredStrategy`` when more than one GPU was selected, otherwise
        the current default strategy. ``num_gpus`` is 0 when no GPU is found
        or when device setup raises ``RuntimeError``.
    """
    available = tf.config.list_physical_devices('GPU')
    if not available:
        print("⚠️ No GPUs found! Using CPU.")
        return tf.distribute.get_strategy(), 0

    print(f"🔍 Total GPUs available: {len(available)}")

    chosen = available if gpu_numbers is None else [
        available[i] for i in gpu_numbers if i < len(available)
    ]

    try:
        tf.config.set_visible_devices(chosen, 'GPU')
        for device in chosen:
            tf.config.experimental.set_memory_growth(device, True)

        count = len(chosen)
        if count > 1:
            strategy = tf.distribute.MirroredStrategy()
        else:
            strategy = tf.distribute.get_strategy()
        print(f"✅ Using {count} GPU(s)")
        return strategy, count
    except RuntimeError as err:
        # Device configuration failed; report it and fall back to the
        # default strategy.
        print(f"⚠️ GPU setup error: {err}")
        return tf.distribute.get_strategy(), 0

# Configure devices once at import time using the GPU indices from the config;
# `strategy` and `num_gpus` are module-level globals.
strategy, num_gpus = setup_gpus(config.GPU_NUMBERS)

# ==============================================================================
# PREPROCESSING FUNCTIONS
# ==============================================================================

def preprocess_image_clahe(image, cfg):
    """Turn an image into a normalized 3-channel grayscale array.

    A single channel is selected (channel index 1 when
    ``cfg.USE_GREEN_CHANNEL`` is set, otherwise an RGB-to-gray conversion;
    2-D inputs are used as they are), optionally equalized with CLAHE,
    divided by 255 and replicated to three identical channels.

    Args:
        image: 2-D or 3-D array. Values are cast to uint8 before CLAHE.
        cfg: Object providing ``USE_GREEN_CHANNEL``, ``USE_CLAHE``,
            ``CLAHE_CLIP_LIMIT`` and ``CLAHE_TILE_GRID_SIZE``.

    Returns:
        float32 array of shape ``(H, W, 3)``.
    """
    if image.ndim != 3:
        # Already single-channel: nothing to select.
        channel = image
    elif cfg.USE_GREEN_CHANNEL:
        channel = image[:, :, 1]
    else:
        channel = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

    if cfg.USE_CLAHE:
        equalizer = cv2.createCLAHE(
            clipLimit=cfg.CLAHE_CLIP_LIMIT,
            tileGridSize=cfg.CLAHE_TILE_GRID_SIZE,
        )
        channel = equalizer.apply(channel.astype(np.uint8))

    normalized = channel.astype(np.float32) / 255.0
    # Replicate the single channel so the result has three identical planes.
    return np.repeat(normalized[..., np.newaxis], 3, axis=-1)

def apply_fov_mask(image, mask, cfg):
    """Zero out everything outside a centered circular field of view.

    The circle is centered at ``(w // 2, h // 2)`` and has radius
    ``min(h, w) // 2 - cfg.FOV_MARGIN``. Pixels outside it are multiplied by
    zero in both ``image`` and ``mask``.

    Args:
        image: Array of shape ``(H, W)`` or ``(H, W, C)``.
        mask: Array whose shape broadcasts with ``(H, W)``.
        cfg: Object providing ``USE_FOV_MASK`` and ``FOV_MARGIN``.

    Returns:
        ``(image, mask)``. The inputs are returned untouched when
        ``cfg.USE_FOV_MASK`` is false; otherwise new masked arrays.
    """
    if not cfg.USE_FOV_MASK:
        return image, mask

    h, w = image.shape[:2]
    center_x, center_y = w // 2, h // 2
    radius = min(h, w) // 2 - cfg.FOV_MARGIN

    if radius < 0:
        # The margin is larger than the inscribed circle, so the field of
        # view is empty. Squaring a negative radius would otherwise produce
        # a disc of radius |radius|.
        inside = np.zeros((h, w), dtype=bool)
    else:
        rows, cols = np.ogrid[:h, :w]
        inside = (cols - center_x) ** 2 + (rows - center_y) ** 2 <= radius ** 2

    if image.ndim == 3:
        # Broadcast the 2-D region over the channel axis.
        image = image * inside[:, :, np.newaxis]
    else:
        image = image * inside

    mask = mask * inside

    return image, mask

# ==============================================================================
# UTILS
# ==============================================================================

def set_seed(seed=42, deterministic=False):
    """Seed Python, NumPy and TensorFlow random number generators.

    Args:
        seed: Integer seed applied to ``random``, ``numpy.random`` and
            ``tf.random``; also exported as ``PYTHONHASHSEED``.
        deterministic: When true, additionally request deterministic
            TensorFlow ops. Previously this flag was accepted but ignored.
    """
    import random

    # Only affects interpreters started after this point (e.g. worker
    # subprocesses); hash randomization of the running process is fixed at
    # startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)

    if deterministic:
        # Environment switch honoured by older TF 2.x releases.
        os.environ['TF_DETERMINISTIC_OPS'] = '1'
        # Newer releases expose an explicit API; guard it so the function
        # keeps working on versions that do not have it.
        enable_determinism = getattr(
            tf.config.experimental, 'enable_op_determinism', None
        )
        if enable_determinism is not None:
            enable_determinism()

def get_image_mask_pairs(images_dir, masks_dir):
    """Pair every image in ``images_dir`` with its mask in ``masks_dir``.

    Images are files matching ``*.png``, ``*.jpg``, ``*.jpeg``, ``*.tif`` or
    ``*.bmp``, processed in sorted path order. For an image with stem ``S``
    the mask candidates are tried in this order: ``S.png``, ``S.jpg``,
    ``S.tif``, ``S_mask.png``, ``S_mask.jpg``; the first existing one wins.
    Images without any matching mask are left out.

    Returns:
        List of ``(image_path, mask_path)`` tuples; an empty list (after
        printing a warning) when no image is found.
    """
    patterns = ('*.png', '*.jpg', '*.jpeg', '*.tif', '*.bmp')
    image_files = sorted(
        path
        for pattern in patterns
        for path in glob(os.path.join(images_dir, pattern))
    )

    if not image_files:
        print(f"⚠️ No images found in {images_dir}")
        return []

    mask_suffixes = ('.png', '.jpg', '.tif', '_mask.png', '_mask.jpg')
    pairs = []
    for img_path in image_files:
        stem = Path(img_path).stem
        candidates = (
            os.path.join(masks_dir, stem + suffix) for suffix in mask_suffixes
        )
        # First candidate that exists on disk, or None when there is none.
        found = next((cand for cand in candidates if os.path.exists(cand)), None)
        if found is not None:
            pairs.append((img_path, found))
    return pairs

def load_dataset_split(split_dir):
    """Return the (image, mask) path pairs of one dataset split.

    The split directory is expected to contain an ``images`` and a ``masks``
    sub-directory; pairing is delegated to ``get_image_mask_pairs``.
    """
    return get_image_mask_pairs(
        os.path.join(split_dir, "images"),
        os.path.join(split_dir, "masks"),
    )

# ==============================================================================
# PATCH EXTRACTION
# ==============================================================================

def extract_patches_from_image(image, mask, cfg):
    """Cut square patches on a regular grid and keep those with enough foreground.

    Patch origins step by ``cfg.STRIDE`` in both directions and only
    positions where a full ``cfg.PATCH_SIZE`` square fits are visited. A
    patch is kept when the sum of its mask values divided by the patch area
    is at least ``cfg.MIN_VESSEL_RATIO``.

    Returns:
        List of ``(image_patch, mask_patch)`` tuples in row-major scan order.
        The patches are slices of the input arrays.
    """
    size = cfg.PATCH_SIZE
    step = cfg.STRIDE
    height, width = image.shape[:2]
    area = size * size

    kept = []
    for top in range(0, height - size + 1, step):
        rows = slice(top, top + size)
        for left in range(0, width - size + 1, step):
            cols = slice(left, left + size)
            mask_tile = mask[rows, cols]
            coverage = np.sum(mask_tile) / area
            if coverage >= cfg.MIN_VESSEL_RATIO:
                kept.append((image[rows, cols], mask_tile))
    return kept

def extract_all_patches(pairs, cfg):
    """Pre-extract training patches from every (image, mask) pair.

    Each image is read with OpenCV, converted from BGR to RGB, preprocessed
    with ``preprocess_image_clahe`` and FOV-masked with ``apply_fov_mask``;
    its mask is read as grayscale and binarized at 127. Patches are then cut
    with ``extract_patches_from_image``.

    Pairs whose image or mask cannot be read are skipped with a warning. An
    unreadable mask used to crash on the threshold comparison.

    Args:
        pairs: Iterable of ``(image_path, mask_path)`` tuples.
        cfg: Configuration object forwarded to the helper functions.

    Returns:
        List of ``(image_patch, mask_patch)`` tuples from all readable pairs.
    """
    print(f"\n🔍 Extracting patches from {len(pairs)} images...")
    all_patches = []

    for img_path, mask_path in pairs:
        image = cv2.imread(img_path)
        if image is None:
            print(f"⚠️ Could not read image, skipping: {img_path}")
            continue

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            # cv2.imread signals failure with None rather than an exception.
            print(f"⚠️ Could not read mask, skipping: {mask_path}")
            continue

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = (mask > 127).astype(np.float32)

        image = preprocess_image_clahe(image, cfg)
        image, mask = apply_fov_mask(image, mask, cfg)

        all_patches.extend(extract_patches_from_image(image, mask, cfg))

    print(f"✅ Extracted {len(all_patches)} patches")
    return all_patches

# ==============================================================================
# AUGMENTATION
# ==============================================================================

def get_patch_augmentation(cfg):
    """Build the albumentations pipeline used for training patches.

    The pipeline applies geometric transforms (flips, rotation, elastic and
    grid distortion) followed by intensity transforms (brightness/contrast,
    Gaussian noise, blur). ``cfg`` is accepted for interface symmetry but is
    not read.

    NOTE(review): preprocess_image_clahe() yields float32 images in [0, 1];
    ``var_limit=(5.0, 15.0)`` on GaussNoise looks like it was chosen for a
    0-255 scale. Confirm against the installed albumentations version.
    """
    reflect = cv2.BORDER_REFLECT_101

    geometric = [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.Rotate(limit=180, border_mode=reflect, p=0.9),
        A.ElasticTransform(alpha=25, sigma=4, alpha_affine=4, border_mode=reflect, p=0.4),
        A.GridDistortion(num_steps=5, distort_limit=0.1, border_mode=reflect, p=0.3),
    ]
    photometric = [
        A.RandomBrightnessContrast(brightness_limit=0.15, contrast_limit=0.15, p=0.6),
        A.GaussNoise(var_limit=(5.0, 15.0), p=0.2),
        A.GaussianBlur(blur_limit=(3, 5), p=0.1),
    ]
    return A.Compose(geometric + photometric, p=1.0)

# ==============================================================================
# DATA GENERATORS
# ==============================================================================

class PatchBasedGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that serves random batches of pre-extracted patches.

    Every batch is an independent random draw from ``all_patches``; the batch
    index passed by Keras is not used to select data. An epoch consists of
    ``cfg.PATCHES_PER_EPOCH // cfg.BATCH_SIZE`` batches.

    Args:
        all_patches: Sequence of ``(image_patch, mask_patch)`` tuples.
        cfg: Object providing ``PATCHES_PER_EPOCH`` and ``BATCH_SIZE``.
        augmentation: Optional callable invoked as
            ``augmentation(image=..., mask=...)`` that returns a mapping with
            ``"image"`` and ``"mask"`` entries.
        shuffle: Stored for callers; sampling is random regardless of it.
    """

    def __init__(self, all_patches, cfg, augmentation=None, shuffle=True):
        # Let the Keras base class initialise its own state.
        super().__init__()
        self.all_patches = all_patches
        self.cfg = cfg
        self.augmentation = augmentation
        self.shuffle = shuffle
        self.steps_per_epoch = cfg.PATCHES_PER_EPOCH // cfg.BATCH_SIZE

    def __len__(self):
        """Return the number of batches per epoch."""
        return self.steps_per_epoch

    def __getitem__(self, index):
        """Return one ``(images, masks)`` batch as float32 arrays.

        Raises:
            ValueError: If there are no patches to sample from.
        """
        num_patches = len(self.all_patches)
        if num_patches == 0:
            raise ValueError("PatchBasedGenerator has no patches to sample from")

        batch_size = self.cfg.BATCH_SIZE
        # Sampling without replacement needs at least `batch_size` patches;
        # with a smaller pool, fall back to sampling with replacement
        # instead of raising.
        indices = np.random.choice(
            num_patches, batch_size, replace=num_patches < batch_size
        )

        images, masks = [], []
        for idx in indices:
            img_patch, mask_patch = self.all_patches[idx]

            if self.augmentation:
                augmented = self.augmentation(image=img_patch, mask=mask_patch)
                img_patch = augmented["image"]
                mask_patch = augmented["mask"]

            if len(mask_patch.shape) == 2:
                # Add a trailing channel axis to 2-D masks.
                mask_patch = np.expand_dims(mask_patch, axis=-1)

            images.append(img_patch)
            masks.append(mask_patch)

        return np.array(images, dtype=np.float32), np.array(masks, dtype=np.float32)

    def on_epoch_end(self):
        """No per-epoch state to reset; batches are drawn at random."""
        pass

# ==============================================================================
# 🔥 MODEL ARCHITECTURES
# ==============================================================================

# ======================== 1. UNET ========================
def build_unet(cfg):
    """Build a classic U-Net with four down/up levels.

    Encoder widths are 32-64-128-256 with a 512-wide bottleneck. Each level
    uses two 3x3 ReLU convolutions; the decoder upsamples with 2x2 transposed
    convolutions and concatenates the matching encoder output. The head is a
    1x1 sigmoid convolution producing one channel.

    Args:
        cfg: Object providing ``INPUT_SIZE`` (square input, 3 channels).

    Returns:
        Uncompiled Keras ``Model`` named ``'UNet'``.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("🔥 BUILDING UNET")
    print(rule)

    def conv_pair(tensor, width):
        # Two stacked 3x3 ReLU convolutions at the same width.
        for _ in range(2):
            tensor = Conv2D(width, (3, 3), activation='relu', padding='same')(tensor)
        return tensor

    inputs = Input((cfg.INPUT_SIZE, cfg.INPUT_SIZE, 3))

    # Encoder: remember each level's output for the skip connections.
    skips = []
    x = inputs
    for width in (32, 64, 128, 256):
        x = conv_pair(x, width)
        skips.append(x)
        x = MaxPooling2D((2, 2))(x)

    # Bottleneck
    x = conv_pair(x, 512)

    # Decoder: walk the skips from deepest to shallowest.
    for width, skip in zip((256, 128, 64, 32), reversed(skips)):
        x = Conv2DTranspose(width, (2, 2), strides=(2, 2), padding='same')(x)
        x = Concatenate()([x, skip])
        x = conv_pair(x, width)

    outputs = Conv2D(1, (1, 1), activation='sigmoid')(x)

    model = Model(inputs=[inputs], outputs=[outputs], name='UNet')
    print(f"Total parameters: {model.count_params():,}")
    print(rule + "\n")
    return model

# ======================== 2. UNET++ ========================
def build_unet_plusplus(cfg):
    """Build a U-Net++ style network with nested skip connections.

    Node ``xIJ`` sits at depth ``I`` and nesting column ``J``. Backbone nodes
    (column 0) and the first nested nodes use two 3x3 ReLU convolutions;
    ``x21``, ``x12`` and ``x03`` use a single one. Only ``x03`` feeds the
    1x1 sigmoid head.

    Args:
        cfg: Object providing ``INPUT_SIZE`` (square input, 3 channels).

    Returns:
        Uncompiled Keras ``Model`` named ``'UNetPlusPlus'``.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("🔥 BUILDING UNET++")
    print(rule)

    def convs(tensor, width, repeats=2):
        # `repeats` stacked 3x3 ReLU convolutions.
        for _ in range(repeats):
            tensor = Conv2D(width, (3, 3), activation='relu', padding='same')(tensor)
        return tensor

    def up_and_merge(deeper, width, same_level):
        # Upsample the deeper node and concatenate it with every earlier
        # node of the target level (upsampled tensor first).
        upsampled = Conv2DTranspose(width, (2, 2), strides=(2, 2), padding='same')(deeper)
        return Concatenate()([upsampled, *same_level])

    inputs = Input((cfg.INPUT_SIZE, cfg.INPUT_SIZE, 3))

    # Depth 0 and 1 backbone
    x00 = convs(inputs, 32)
    p0 = MaxPooling2D((2, 2))(x00)

    x10 = convs(p0, 64)
    p1 = MaxPooling2D((2, 2))(x10)

    # First nested column
    x01 = convs(up_and_merge(x10, 32, [x00]), 32)

    x20 = convs(p1, 128)
    p2 = MaxPooling2D((2, 2))(x20)

    x11 = convs(up_and_merge(x20, 64, [x10]), 64)
    x02 = convs(up_and_merge(x11, 32, [x00, x01]), 32)

    # Deepest backbone node
    x30 = convs(p2, 256)

    # Final diagonal: a single convolution per node.
    x21 = convs(up_and_merge(x30, 128, [x20]), 128, repeats=1)
    x12 = convs(up_and_merge(x21, 64, [x10, x11]), 64, repeats=1)
    x03 = convs(up_and_merge(x12, 32, [x00, x01, x02]), 32, repeats=1)

    outputs = Conv2D(1, (1, 1), activation='sigmoid')(x03)

    model = Model(inputs=[inputs], outputs=[outputs], name='UNetPlusPlus')
    print(f"Total parameters: {model.count_params():,}")
    print(rule + "\n")
    return model

# ======================== 3. PraNet ========================
def rfb_block(x, filters, name='rfb'):
    """Receptive Field Block: four parallel branches fused by a 1x1 conv.

    Branch 1 is a plain 1x1 convolution. Branches 2-4 are a 1x1 convolution
    followed by a 3x3 convolution with dilation rate 1, 3 and 5. Every
    branch has ``filters // 4`` channels; the concatenation is projected to
    ``filters`` channels, batch-normalized and passed through ReLU.

    Args:
        x: Input feature map.
        filters: Number of output channels.
        name: Prefix for all layer names in the block.
    """
    width = filters // 4

    branches = [Conv2D(width, (1, 1), padding='same', name=f'{name}_b1')(x)]
    for branch_id, rate in ((2, 1), (3, 3), (4, 5)):
        branch = Conv2D(width, (1, 1), padding='same',
                        name=f'{name}_b{branch_id}_1')(x)
        branch = Conv2D(width, (3, 3), padding='same', dilation_rate=rate,
                        name=f'{name}_b{branch_id}_2')(branch)
        branches.append(branch)

    fused = Concatenate(name=f'{name}_concat')(branches)
    fused = Conv2D(filters, (1, 1), padding='same', name=f'{name}_conv')(fused)
    fused = BatchNormalization(name=f'{name}_bn')(fused)
    return Activation('relu', name=f'{name}_act')(fused)

def build_pranet(cfg):
    """Build the PraNet-style model used in this script.

    A four-stage convolutional encoder (64-128-256-512, two 3x3 ReLU
    convolutions per stage, max-pooling between stages) feeds one
    ``rfb_block`` per stage. The decoder upsamples the deepest refined
    feature three times, concatenating the refined feature of the matching
    stage and applying one 3x3 ReLU convolution each time.

    Args:
        cfg: Object providing ``INPUT_SIZE`` (square input, 3 channels).

    Returns:
        Uncompiled Keras ``Model`` named ``'PraNet'``.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("🔥 BUILDING PRANET")
    print(rule)

    inputs = Input((cfg.INPUT_SIZE, cfg.INPUT_SIZE, 3))
    widths = (64, 128, 256, 512)

    # Encoder: pool before every stage except the first.
    stage_outputs = []
    x = inputs
    for stage, width in enumerate(widths):
        if stage > 0:
            x = MaxPooling2D((2, 2))(x)
        for _ in range(2):
            x = Conv2D(width, (3, 3), activation='relu', padding='same')(x)
        stage_outputs.append(x)

    # RFB blocks, named rfb1..rfb4 from shallow to deep.
    refined = [
        rfb_block(feature, width, f'rfb{number}')
        for number, (feature, width) in enumerate(zip(stage_outputs, widths), start=1)
    ]

    # Decoder: start from the deepest refined map and climb back up.
    x = refined[-1]
    for width, skip in zip((256, 128, 64), reversed(refined[:-1])):
        x = UpSampling2D((2, 2))(x)
        x = Concatenate()([x, skip])
        x = Conv2D(width, (3, 3), activation='relu', padding='same')(x)

    outputs = Conv2D(1, (1, 1), activation='sigmoid')(x)

    model = Model(inputs=[inputs], outputs=[outputs], name='PraNet')
    print(f"Total parameters: {model.count_params():,}")
    print(rule + "\n")
    return model

# ======================== 4. RAPUNet ========================
def residual_attention_block(x, filters, name='ra'):
    """Residual block whose residual path is rescaled by channel attention.

    The residual path is conv-BN-ReLU-conv-BN (3x3 kernels). A channel gate
    is computed from its global average (1x1 convs: ``filters // 8`` ReLU
    units, then ``filters`` sigmoid units) and multiplied onto it. The input
    is added back, through a 1x1 projection when its channel count differs
    from ``filters``, and the sum is passed through ReLU.

    Args:
        x: Input feature map.
        filters: Number of output channels.
        name: Prefix for all layer names in the block.
    """
    # Residual path
    body = Conv2D(filters, (3, 3), padding='same', name=f'{name}_conv1')(x)
    body = BatchNormalization(name=f'{name}_bn1')(body)
    body = Activation('relu', name=f'{name}_act1')(body)
    body = Conv2D(filters, (3, 3), padding='same', name=f'{name}_conv2')(body)
    body = BatchNormalization(name=f'{name}_bn2')(body)

    # Channel gate: squeeze to 1x1, bottleneck, expand to per-channel weights.
    gate = GlobalAveragePooling2D(name=f'{name}_gap')(body)
    gate = Reshape((1, 1, filters), name=f'{name}_reshape')(gate)
    gate = Conv2D(filters // 8, (1, 1), activation='relu', name=f'{name}_att1')(gate)
    gate = Conv2D(filters, (1, 1), activation='sigmoid', name=f'{name}_att2')(gate)

    gated = Multiply(name=f'{name}_mult')([body, gate])

    # Identity path, projected only when the channel counts disagree.
    identity = x
    if identity.shape[-1] != filters:
        identity = Conv2D(filters, (1, 1), padding='same', name=f'{name}_proj')(identity)

    summed = Add(name=f'{name}_add')([identity, gated])
    return Activation('relu', name=f'{name}_act')(summed)

def build_rapunet(cfg):
    """Build the RAPUNet model: a U-Net made of residual attention blocks.

    A 3x3 stem convolution (32 channels) is followed by four encoder levels
    (32-64-128-256), a 512-channel bottleneck and four decoder levels, each
    built from ``residual_attention_block``. Decoder levels upsample with
    2x2 transposed convolutions and concatenate the matching encoder output.

    Args:
        cfg: Object providing ``INPUT_SIZE`` (square input, 3 channels).

    Returns:
        Uncompiled Keras ``Model`` named ``'RAPUNet'``.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("🔥 BUILDING RAPUNet")
    print(rule)

    inputs = Input((cfg.INPUT_SIZE, cfg.INPUT_SIZE, 3))
    level_widths = (32, 64, 128, 256)

    # Encoder
    x = Conv2D(32, (3, 3), padding='same')(inputs)
    encoder_outputs = {}
    for level, width in enumerate(level_widths, start=1):
        x = residual_attention_block(x, width, f'ra_e{level}')
        encoder_outputs[level] = x
        x = MaxPooling2D((2, 2))(x)

    # Bottleneck
    x = residual_attention_block(x, 512, 'ra_b')

    # Decoder, from level 4 back up to level 1.
    for level in (4, 3, 2, 1):
        width = level_widths[level - 1]
        x = Conv2DTranspose(width, (2, 2), strides=(2, 2), padding='same')(x)
        x = Concatenate()([x, encoder_outputs[level]])
        x = residual_attention_block(x, width, f'ra_d{level}')

    outputs = Conv2D(1, (1, 1), activation='sigmoid')(x)

    model = Model(inputs=[inputs], outputs=[outputs], name='RAPUNet')
    print(f"Total parameters: {model.count_params():,}")
    print(rule + "\n")
    return model

# ======================== 5. SwinUNet ========================
def window_partition(x, window_size):
    """Partition a 4-D feature map into non-overlapping square windows.

    Args:
        x: Rank-4 tensor unpacked as ``(B, H, W, C)``. ``H``, ``W`` and ``C``
            must be statically known, and ``H`` and ``W`` must be divisible
            by ``window_size`` for the reshape to be valid.
        window_size: Side length of each window.

    Returns:
        Tensor of shape ``(-1, window_size, window_size, C)`` holding all
        windows of all batch elements.
    """
    # The static batch size is None for symbolic or variable-batch tensors,
    # which tf.reshape cannot accept; let it infer that axis with -1.
    _, H, W, C = x.shape
    x = tf.reshape(
        x, [-1, H // window_size, window_size, W // window_size, window_size, C]
    )
    # Bring the two window-grid axes together before flattening them.
    windows = tf.reshape(
        tf.transpose(x, [0, 1, 3, 2, 4, 5]), [-1, window_size, window_size, C]
    )
    return windows

def build_swinunet(cfg):
    """Build the simplified, convolution-only "SwinUNet" used in this script.

    A 4x4 stride-4 convolution embeds the input into 96 channels. Three
    encoder stages (96-192-384) each apply a 3x3 ReLU convolution with layer
    normalization and downsample with a stride-2 convolution that doubles
    the width. After a 768-channel bottleneck, the decoder mirrors the
    encoder with 2x2 transposed convolutions and skip concatenations; the
    last decoder stage has no layer normalization. A 4x4 stride-4 transposed
    convolution restores the input resolution before the 1x1 sigmoid head.

    Args:
        cfg: Object providing ``INPUT_SIZE`` (square input, 3 channels).

    Returns:
        Uncompiled Keras ``Model`` named ``'SwinUNet'``.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("🔥 BUILDING SWINUNET")
    print(rule)

    inputs = Input((cfg.INPUT_SIZE, cfg.INPUT_SIZE, 3))

    # Patch embedding
    x = Conv2D(96, (4, 4), strides=4, padding='same')(inputs)
    x = LayerNormalization()(x)

    # Encoder stages: feature conv + norm, then a strided conv to go deeper.
    skips = []
    for width in (96, 192, 384):
        x = Conv2D(width, (3, 3), padding='same', activation='relu')(x)
        x = LayerNormalization()(x)
        skips.append(x)
        x = Conv2D(width * 2, (3, 3), strides=2, padding='same')(x)

    # Bottleneck
    x = Conv2D(768, (3, 3), padding='same', activation='relu')(x)
    x = LayerNormalization()(x)

    # Decoder: the shallowest stage (96 channels) is left un-normalized.
    for width in (384, 192, 96):
        x = Conv2DTranspose(width, (2, 2), strides=2, padding='same')(x)
        x = Concatenate()([x, skips.pop()])
        x = Conv2D(width, (3, 3), padding='same', activation='relu')(x)
        if width != 96:
            x = LayerNormalization()(x)

    # Final upsampling
    x = Conv2DTranspose(48, (4, 4), strides=4, padding='same')(x)
    outputs = Conv2D(1, (1, 1), activation='sigmoid')(x)

    model = Model(inputs=[inputs], outputs=[outputs], name='SwinUNet')
    print(f"Total parameters: {model.count_params():,}")
    print(rule + "\n")
    return model

# ======================== 6. TransUNet ========================
def build_transunet(cfg):
    """Build the simplified TransUNet: CNN encoder, one transformer block, CNN decoder.

    The encoder has three levels (64-128-256), each with two 3x3 ReLU
    convolutions followed by max-pooling. The pooled feature map is
    flattened into a token sequence and passed through one pre-normalized
    self-attention block (8 heads) and an MLP with residual connections,
    then reshaped back to a feature map. The decoder upsamples three times
    with skip concatenations.

    The flatten/unflatten steps use Keras ``Reshape`` layers rather than raw
    ``tf.reshape`` so that every operation in the graph is a Keras layer.

    Args:
        cfg: Object providing ``INPUT_SIZE`` (square input, 3 channels).

    Returns:
        Uncompiled Keras ``Model`` named ``'TransUNet'``.
    """
    print("\n" + "="*80)
    print("🔥 BUILDING TRANSUNET")
    print("="*80)

    inputs = Input((cfg.INPUT_SIZE, cfg.INPUT_SIZE, 3))

    # CNN Encoder
    x = Conv2D(64, (3, 3), padding='same', activation='relu')(inputs)
    x = Conv2D(64, (3, 3), padding='same', activation='relu')(x)
    e1 = x
    x = MaxPooling2D((2, 2))(x)

    x = Conv2D(128, (3, 3), padding='same', activation='relu')(x)
    x = Conv2D(128, (3, 3), padding='same', activation='relu')(x)
    e2 = x
    x = MaxPooling2D((2, 2))(x)

    x = Conv2D(256, (3, 3), padding='same', activation='relu')(x)
    x = Conv2D(256, (3, 3), padding='same', activation='relu')(x)
    e3 = x
    x = MaxPooling2D((2, 2))(x)

    # Transformer bottleneck (simplified)
    # Static spatial/channel sizes; the batch axis is not needed because
    # Reshape leaves it untouched.
    _, H, W, C = x.shape
    x_flat = Reshape((H * W, C))(x)
    x_flat = LayerNormalization()(x_flat)

    # Self-attention
    attn_out = MultiHeadAttention(num_heads=8, key_dim=C//8)(x_flat, x_flat)
    x_flat = Add()([x_flat, attn_out])
    x_flat = LayerNormalization()(x_flat)

    # MLP
    mlp = Dense(C*4, activation='relu')(x_flat)
    mlp = Dense(C)(mlp)
    x_flat = Add()([x_flat, mlp])

    # Back from token sequence to a spatial feature map.
    x = Reshape((H, W, C))(x_flat)

    # Decoder
    u3 = Conv2DTranspose(256, (2, 2), strides=2, padding='same')(x)
    u3 = Concatenate()([u3, e3])
    d3 = Conv2D(256, (3, 3), padding='same', activation='relu')(u3)
    d3 = Conv2D(256, (3, 3), padding='same', activation='relu')(d3)

    u2 = Conv2DTranspose(128, (2, 2), strides=2, padding='same')(d3)
    u2 = Concatenate()([u2, e2])
    d2 = Conv2D(128, (3, 3), padding='same', activation='relu')(u2)
    d2 = Conv2D(128, (3, 3), padding='same', activation='relu')(d2)

    u1 = Conv2DTranspose(64, (2, 2), strides=2, padding='same')(d2)
    u1 = Concatenate()([u1, e1])
    d1 = Conv2D(64, (3, 3), padding='same', activation='relu')(u1)
    d1 = Conv2D(64, (3, 3), padding='same', activation='relu')(d1)

    outputs = Conv2D(1, (1, 1), activation='sigmoid')(d1)

    model = Model(inputs=[inputs], outputs=[outputs], name='TransUNet')
    print(f"Total parameters: {model.count_params():,}")
    print("="*80 + "\n")
    return model

# ======================== 7. UMamba ========================
def build_umamba(cfg):
    """Build the simplified "UMamba" model: a U-Net of depthwise-separable blocks.

    After a 3x3 stem (32 channels, BN, ReLU) every stage is a depthwise 3x3
    convolution followed by a 1x1 convolution, batch normalization and ReLU.
    Encoder widths are 32-64-128-256 with a 512-wide bottleneck; the decoder
    upsamples with 2x2 transposed convolutions and concatenates the matching
    encoder output before each stage.

    Args:
        cfg: Object providing ``INPUT_SIZE`` (square input, 3 channels).

    Returns:
        Uncompiled Keras ``Model`` named ``'UMamba'``.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("🔥 BUILDING UMAMBA")
    print(rule)

    def separable_stage(tensor, width):
        # Depthwise 3x3 -> pointwise 1x1 -> BN -> ReLU.
        tensor = DepthwiseConv2D((3, 3), padding='same')(tensor)
        tensor = Conv2D(width, (1, 1), padding='same')(tensor)
        tensor = BatchNormalization()(tensor)
        return Activation('relu')(tensor)

    inputs = Input((cfg.INPUT_SIZE, cfg.INPUT_SIZE, 3))

    # Stem: a regular convolution, since the input has only three channels.
    x = Conv2D(32, (3, 3), padding='same')(inputs)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    skips = [x]
    x = MaxPooling2D((2, 2))(x)

    # Remaining encoder stages
    for width in (64, 128, 256):
        x = separable_stage(x, width)
        skips.append(x)
        x = MaxPooling2D((2, 2))(x)

    # Bottleneck
    x = separable_stage(x, 512)

    # Decoder
    for width, skip in zip((256, 128, 64, 32), reversed(skips)):
        x = Conv2DTranspose(width, (2, 2), strides=2, padding='same')(x)
        x = Concatenate()([x, skip])
        x = separable_stage(x, width)

    outputs = Conv2D(1, (1, 1), activation='sigmoid')(x)

    model = Model(inputs=[inputs], outputs=[outputs], name='UMamba')
    print(f"Total parameters: {model.count_params():,}")
    print(rule + "\n")
    return model

# ======================== 8. MedSegNet-SSF ========================
class SpectralSelectiveTokenMixer(Layer):
    """Residual token mixer with an optional spectral branch and an optional gated branch.

    The spectral branch filters the 2-D Fourier spectrum of the feature map with a
    learnable real-valued mask; the gated ("ssm") branch applies a sigmoid gate
    followed by a dense projection to every spatial position. Enabled branches are
    fused (when both are active), layer-normalised and added back to the input.

    Args:
        channels: Number of channels of the input feature map (last axis).
        num_frequencies: Upper bound on the side length of the learnable mask.
        ssm_state_dim: Stored and serialised, but not used by any computation
            in this layer.
        use_spectral: Enable the frequency-domain branch.
        use_ssm: Enable the gated projection branch.
        dropout: Dropout rate applied to the mixed features when ``training`` is truthy.
    """

    def __init__(self, channels, num_frequencies=32, ssm_state_dim=16,
                 use_spectral=True, use_ssm=True, dropout=0.0, **kwargs):
        super().__init__(**kwargs)
        self.channels = channels
        self.num_frequencies = num_frequencies
        self.ssm_state_dim = ssm_state_dim
        self.use_spectral = use_spectral
        self.use_ssm = use_ssm
        self.dropout_rate = dropout

    def build(self, input_shape):
        """Create the spectral mask and the sub-layers required by the enabled branches."""
        input_h, input_w = input_shape[1], input_shape[2]
        # The mask never needs to be larger than the (statically known) feature map.
        # A missing width is skipped instead of being passed to min(), which would raise.
        limits = [self.num_frequencies]
        if input_h:
            limits.extend(dim for dim in (input_h, input_w) if dim)
        self.actual_frequencies = min(limits)

        if self.use_spectral:
            self.freq_weights_real = self.add_weight(
                name='freq_weights_real',
                shape=(self.actual_frequencies, self.actual_frequencies, self.channels),
                initializer=self._get_initializer(),
                trainable=True
            )
            self.spectral_norm = LayerNormalization(epsilon=1e-6, name='spectral_norm')

        if self.use_ssm:
            self.ssm_C = Dense(self.channels, name='ssm_C')
            self.selection_gate = Dense(self.channels, activation='sigmoid', name='selection')
            self.ssm_norm = LayerNormalization(epsilon=1e-6, name='ssm_norm')

        if self.use_spectral and self.use_ssm:
            self.fusion = Dense(self.channels, name='fusion')
            self.fusion_norm = LayerNormalization(epsilon=1e-6, name='fusion_norm')

        self.norm = LayerNormalization(epsilon=1e-6, name='norm')
        super().build(input_shape)

    def _get_initializer(self):
        """Return an initializer producing a Gaussian band centred on frequency 0.25."""
        def init_fn(shape, dtype=None):
            H, W, C = shape
            freq_h = np.fft.fftfreq(H)[:, np.newaxis]
            freq_w = np.fft.fftfreq(W)[np.newaxis, :]
            freq_magnitude = np.sqrt(freq_h**2 + freq_w**2)
            gaussian = np.exp(-((freq_magnitude - 0.25)**2) / (2 * 0.15**2))
            gaussian = np.repeat(gaussian[:, :, np.newaxis], C, axis=2)
            return gaussian.astype(np.float32) * 0.5
        return init_fn

    def spectral_path(self, x):
        """Filter ``x`` (batch, height, width, channels) in the spatial-frequency domain."""
        H, W = tf.shape(x)[1], tf.shape(x)[2]
        freq_size = tf.minimum(tf.minimum(H, W), self.actual_frequencies)

        # tf.signal.fft2d transforms the two INNERMOST axes. The input is
        # channels-last, so the spatial axes are moved to the end first;
        # otherwise the FFT would mix width with channels.
        x_complex = tf.cast(x, tf.complex64)
        x_freq = tf.signal.fft2d(tf.transpose(x_complex, [0, 3, 1, 2]))
        x_freq = tf.transpose(x_freq, [0, 2, 3, 1])

        # tf.image.resize works on real tensors, so real and imaginary parts
        # are resampled separately down to the mask resolution.
        x_freq_real_resized = tf.image.resize(tf.math.real(x_freq), [freq_size, freq_size], method='bilinear')
        x_freq_imag_resized = tf.image.resize(tf.math.imag(x_freq), [freq_size, freq_size], method='bilinear')
        x_freq_resized = tf.complex(x_freq_real_resized, x_freq_imag_resized)

        freq_filter = tf.cast(self.freq_weights_real[:freq_size, :freq_size, :], tf.complex64)
        x_freq_filtered = x_freq_resized * freq_filter

        # Resample the filtered spectrum back to the feature-map resolution.
        x_freq_back_real = tf.image.resize(tf.math.real(x_freq_filtered), [H, W], method='bilinear')
        x_freq_back_imag = tf.image.resize(tf.math.imag(x_freq_filtered), [H, W], method='bilinear')
        x_freq_back = tf.complex(x_freq_back_real, x_freq_back_imag)

        # Same axis handling as the forward transform.
        x_spatial = tf.signal.ifft2d(tf.transpose(x_freq_back, [0, 3, 1, 2]))
        x_spatial = tf.transpose(x_spatial, [0, 2, 3, 1])
        return self.spectral_norm(tf.math.real(x_spatial))

    def ssm_path(self, x):
        """Apply the sigmoid selection gate and the dense projection per spatial position."""
        B, H, W = tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2]
        x_flat = tf.reshape(x, [B, H * W, self.channels])
        selection = self.selection_gate(x_flat)
        x_selected = x_flat * selection
        x_ssm = self.ssm_C(x_selected)
        return self.ssm_norm(tf.reshape(x_ssm, [B, H, W, self.channels]))

    def call(self, x, training=None):
        """Return ``x`` plus the normalised output of the enabled branches."""
        outputs = []
        if self.use_spectral:
            outputs.append(self.spectral_path(x))
        if self.use_ssm:
            outputs.append(self.ssm_path(x))

        if len(outputs) == 2:
            fused = self.fusion_norm(self.fusion(tf.concat(outputs, axis=-1)))
        elif len(outputs) == 1:
            fused = outputs[0]
        else:
            # Both branches disabled: only the final normalisation is applied.
            fused = x

        fused = self.norm(fused)
        if training and self.dropout_rate > 0:
            fused = tf.nn.dropout(fused, rate=self.dropout_rate)
        return x + fused

    def get_config(self):
        """Return the constructor arguments so models using this layer can be saved and reloaded."""
        config = super().get_config()
        config.update({
            'channels': self.channels,
            'num_frequencies': self.num_frequencies,
            'ssm_state_dim': self.ssm_state_dim,
            'use_spectral': self.use_spectral,
            'use_ssm': self.use_ssm,
            'dropout': self.dropout_rate,
        })
        return config

def MRF_SE_BLOCK(x, filters, activation='elu', dropout=0.0, expand_ratio=6,
                 regularizer=0.0, kernels=(3, 5, 7), se_reduction=16, name='mrf_se'):
    """Multi-receptive-field inverted-residual block with squeeze-and-excitation.

    Pipeline: optional 1x1 expansion -> BN/activation -> one depthwise branch per
    kernel size -> concatenate and fuse with a 1x1 conv (only when there is more
    than one branch) -> channel re-weighting -> 1x1 projection to ``filters`` ->
    BN -> optional dropout -> residual add with the block input.

    Args:
        x: Input feature map. It is added to the projected output, so it must
            already carry ``filters`` channels.
        filters: Number of output channels.
        activation: Activation used after the expansion, branch and fuse stages
            and inside the squeeze layer.
        dropout: Dropout rate applied after the projection; skipped when 0.
        expand_ratio: Channel multiplier of the expansion conv; the conv is
            skipped when the ratio is not greater than 1.
        regularizer: L2 factor for the conv kernels; no regulariser when 0.
        kernels: Depthwise kernel sizes, one parallel branch each. The default
            is a tuple so no mutable object is shared between calls.
        se_reduction: Reduction factor of the squeeze layer (never below 8 channels).
        name: Prefix for all layer names created by this block.

    Returns:
        Tensor with ``filters`` channels.
    """
    def _weight_decay():
        # A fresh regulariser object per layer, or none when disabled.
        return l2(regularizer) if regularizer > 0 else None

    F_expanded = filters * expand_ratio

    # --- Expansion ---------------------------------------------------------
    if expand_ratio > 1:
        expanded = Conv2D(F_expanded, (1, 1), padding='same', kernel_initializer='he_uniform',
                          kernel_regularizer=_weight_decay(),
                          name=name+'_expand')(x)
    else:
        expanded = x
    expanded = BatchNormalization(name=name+'_expand_bn')(expanded)
    expanded = Activation(activation, name=name+'_expand_act')(expanded)

    # --- Parallel depthwise branches, one per kernel size ------------------
    features = []
    for k in kernels:
        branch = DepthwiseConv2D((k, k), padding='same', depthwise_initializer='he_uniform',
                                 depthwise_regularizer=_weight_decay(),
                                 name=f"{name}_dw{k}x{k}")(expanded)
        branch = BatchNormalization(name=f"{name}_dw{k}x{k}_bn")(branch)
        branch = Activation(activation, name=f"{name}_dw{k}x{k}_act")(branch)
        features.append(branch)

    # --- Fuse branches back to the expanded width --------------------------
    if len(features) > 1:
        combined = Concatenate(name=name+'_concat')(features)
        combined = Conv2D(F_expanded, (1, 1), padding='same', kernel_initializer='he_uniform',
                          kernel_regularizer=_weight_decay(),
                          name=name+'_fuse')(combined)
        combined = BatchNormalization(name=name+'_fuse_bn')(combined)
        combined = Activation(activation, name=name+'_fuse_act')(combined)
    else:
        combined = features[0]

    # --- Squeeze-and-excitation channel gate -------------------------------
    squeezed = GlobalAveragePooling2D(name=name+'_gap')(combined)
    squeezed = Reshape((1, 1, F_expanded), name=name+'_reshape')(squeezed)
    gate = Conv2D(max(F_expanded//se_reduction, 8), (1, 1), activation=activation,
                  kernel_initializer='he_uniform', name=name+'_se_reduce')(squeezed)
    gate = Conv2D(F_expanded, (1, 1), activation='sigmoid', kernel_initializer='he_uniform',
                  name=name+'_se_expand')(gate)
    recalibrated = Multiply(name=name+'_se_mult')([combined, gate])

    # --- Linear projection and residual ------------------------------------
    projected = Conv2D(filters, (1, 1), padding='same', kernel_initializer='he_uniform',
                       kernel_regularizer=_weight_decay(),
                       name=name+'_project')(recalibrated)
    projected = BatchNormalization(name=name+'_project_bn')(projected)

    if dropout > 0:
        projected = Dropout(dropout, name=name+'_dropout')(projected)

    return Add(name=name+'_add')([projected, x])

def boundary_detection_module(features, filters, name='boundary'):
    """Predict a one-channel sigmoid map from ``features`` and gate the features with it.

    Returns:
        Tuple ``(gated_features, boundary_map)`` where ``gated_features`` is the
        element-wise product of ``features`` and the predicted map.
    """
    hidden = Conv2D(
        filters // 2, (3, 3), padding='same', activation='relu', name=f'{name}_conv'
    )(features)
    boundary_map = Conv2D(
        1, (1, 1), padding='same', activation='sigmoid', name=f'{name}_map'
    )(hidden)
    gated_features = Multiply(name=f'{name}_mult')([features, boundary_map])
    return gated_features, boundary_map

def BFP_decoder_stage(decoder_input, skip_features, filters, stage_name='bfp'):
    """One decoder stage that blends region features with boundary-refined features.

    The decoder input is upsampled by 2, concatenated with the skip features and
    passed through two 3x3 Conv-BN-ReLU units. A boundary map is predicted from
    the result and used to mix the region features with a refined, boundary-gated
    copy before a final 1x1 Conv-BN-ReLU.

    Returns:
        Tuple ``(stage_output, boundary_map)``.
    """
    def _conv_bn_relu(tensor, kernel_size, conv_name):
        # Named convolution followed by auto-named batch norm and ReLU.
        tensor = Conv2D(filters, kernel_size, padding='same', name=conv_name)(tensor)
        tensor = BatchNormalization()(tensor)
        return Activation('relu')(tensor)

    upsampled = UpSampling2D((2, 2), name=stage_name+'_up')(decoder_input)
    region = Concatenate(name=stage_name+'_concat')([upsampled, skip_features])
    region = _conv_bn_relu(region, (3, 3), stage_name+'_region_conv1')
    region = _conv_bn_relu(region, (3, 3), stage_name+'_region_conv2')

    boundary_features, boundary_map = boundary_detection_module(
        region, filters, stage_name+'_boundary')
    boundary_refined = _conv_bn_relu(
        boundary_features, (3, 3), stage_name+'_boundary_refine')

    # Region features away from the boundary, refined features on it.
    blended = region * (1 - boundary_map) + boundary_refined * boundary_map
    output = _conv_bn_relu(blended, (1, 1), stage_name+'_fusion')

    return output, boundary_map

def build_medsegnet_ssf(cfg):
    """Build the MedSegNet-SSF encoder/decoder segmentation model.

    Encoder: a 16-filter stem followed by five stride-2 stages with widths
    ``cfg.F1`` .. ``cfg.F5``; each stage optionally adds an MRF-SE block
    (``cfg.USE_MRF_SE``) and a spectral/selective token mixer (``cfg.USE_SSTM``).
    Decoder: four stages that each upsample by 2 and merge the matching encoder
    output, then a final upsampling back to the input resolution and a
    one-channel sigmoid head.

    When ``cfg.USE_BFP`` is false, a plain upsample/concat/conv stage replaces
    the boundary-focused stage. Previously the decoder loop did nothing in that
    case, so the model output had only 1/16 of the input resolution.

    Args:
        cfg: Configuration object providing the attributes read below
            (INPUT_SIZE, F1..F5, USE_MRF_SE, USE_SSTM, USE_BFP and the
            block-specific hyper-parameters).

    Returns:
        A Keras ``Model`` named ``"MedSegNet-SSF"``.
    """
    print("\n" + "="*80)
    print("🔥 BUILDING MED-SEGNET-SSF")
    print("="*80)

    inp = Input((cfg.INPUT_SIZE, cfg.INPUT_SIZE, 3), name="input")

    x = Conv2D(16, (3, 3), padding='same', kernel_initializer='he_uniform', name='stem_conv')(inp)
    x = BatchNormalization(name='stem_bn')(x)
    x = Activation('elu', name='stem_act')(x)

    encoder_outputs = []
    filters = [cfg.F1, cfg.F2, cfg.F3, cfg.F4, cfg.F5]

    for i, f in enumerate(filters):
        # Strided convolution halves the resolution at every stage.
        x = Conv2D(f, (3, 3), strides=2, padding='same', kernel_initializer='he_uniform')(x)
        x = BatchNormalization()(x)
        x = Activation('elu')(x)

        if cfg.USE_MRF_SE:
            x = MRF_SE_BLOCK(x, f, activation='elu', dropout=cfg.DROPOUT,
                           expand_ratio=cfg.EXPAND_RATIO, regularizer=cfg.L2_REG,
                           kernels=cfg.MRF_KERNELS, se_reduction=cfg.SE_REDUCTION,
                           name=f'mrfse_stage{i+1}')

        if cfg.USE_SSTM:
            x = SpectralSelectiveTokenMixer(
                channels=f, num_frequencies=cfg.SSTM_NUM_FREQUENCIES,
                ssm_state_dim=cfg.SSTM_SSM_STATE_DIM,
                use_spectral=cfg.SSTM_USE_SPECTRAL[i],
                use_ssm=cfg.SSTM_USE_SSM[i],
                dropout=cfg.SSTM_DROPOUT,
                name=f'sstm_stage{i+1}'
            )(x)

        encoder_outputs.append(x)

    # Deepest feature map first; the remaining four serve as skip connections.
    skip_connections = encoder_outputs[::-1]
    decoder = skip_connections[0]
    decoder_filters = filters[::-1][1:]

    for i, (skip, f) in enumerate(zip(skip_connections[1:], decoder_filters)):
        if cfg.USE_BFP:
            decoder, _ = BFP_decoder_stage(decoder, skip, f, stage_name=f'bfp_stage{i+1}')
        else:
            # Plain decoder stage so the output still reaches full resolution.
            # All layers are explicitly named, which leaves the auto-generated
            # layer names of the default (BFP) configuration unchanged.
            decoder = UpSampling2D((2, 2), name=f'plain_stage{i+1}_up')(decoder)
            decoder = Concatenate(name=f'plain_stage{i+1}_concat')([decoder, skip])
            decoder = Conv2D(f, (3, 3), padding='same', name=f'plain_stage{i+1}_conv')(decoder)
            decoder = BatchNormalization(name=f'plain_stage{i+1}_bn')(decoder)
            decoder = Activation('relu', name=f'plain_stage{i+1}_act')(decoder)

    # The first encoder stage runs at half resolution, so one more upsampling is needed.
    decoder = UpSampling2D((2, 2))(decoder)
    decoder = Conv2D(32, (3, 3), padding='same', activation='relu')(decoder)
    decoder = Conv2D(16, (3, 3), padding='same', activation='relu')(decoder)
    out = Conv2D(1, (1, 1), padding='same', activation='sigmoid', name='output')(decoder)

    model = Model(inputs=inp, outputs=out, name="MedSegNet-SSF")
    print(f"Total parameters: {model.count_params():,}")
    print("="*80 + "\n")

    return model

# ======================== 9. DuckNet ========================
def duck_block(x, filters, name='duck'):
    """Residual block: two stacked 3x3 Conv-BN-ReLU units plus a 1x1-projected shortcut.

    Args:
        x: Input feature map.
        filters: Number of channels produced by every convolution in the block.
        name: Prefix for the names of all layers created here.

    Returns:
        Sum of the main path and the projected shortcut.
    """
    main_path = x
    for tag in ('w1', 'w2'):
        main_path = Conv2D(filters, (3, 3), padding='same', name=f'{name}_{tag}')(main_path)
        main_path = BatchNormalization(name=f'{name}_{tag}_bn')(main_path)
        main_path = Activation('relu', name=f'{name}_{tag}_act')(main_path)

    # The 1x1 projection matches the shortcut's channel count to the main path.
    shortcut = Conv2D(filters, (1, 1), padding='same', name=f'{name}_res')(x)

    return Add(name=f'{name}_add')([main_path, shortcut])

def build_ducknet(cfg):
    """Build the DuckNet model: a four-level U-shaped network made of ``duck_block`` units.

    Args:
        cfg: Configuration object; only ``cfg.INPUT_SIZE`` is read.

    Returns:
        A Keras ``Model`` named ``'DuckNet'`` with a one-channel sigmoid output.
    """
    print("\n" + "="*80)
    print("🔥 BUILDING DUCKNET")
    print("="*80)

    inputs = Input((cfg.INPUT_SIZE, cfg.INPUT_SIZE, 3))
    x = Conv2D(32, (3, 3), padding='same')(inputs)

    # Encoder: keep every stage output for the skip connections, then halve the resolution.
    skips = []
    for width, tag in ((32, 'duck_e1'), (64, 'duck_e2'), (128, 'duck_e3'), (256, 'duck_e4')):
        stage = duck_block(x, width, tag)
        skips.append(stage)
        x = MaxPooling2D((2, 2))(stage)

    # Bottleneck
    x = duck_block(x, 512, 'duck_b')

    # Decoder: upsample, merge with the matching encoder output (deepest first), refine.
    for width, tag in ((256, 'duck_d4'), (128, 'duck_d3'), (64, 'duck_d2'), (32, 'duck_d1')):
        x = Conv2DTranspose(width, (2, 2), strides=2, padding='same')(x)
        x = Concatenate()([x, skips.pop()])
        x = duck_block(x, width, tag)

    outputs = Conv2D(1, (1, 1), activation='sigmoid')(x)

    model = Model(inputs=[inputs], outputs=[outputs], name='DuckNet')
    print(f"Total parameters: {model.count_params():,}")
    print("="*80 + "\n")
    return model

# ======================== 10. nnUNet ========================
def build_nnunet(cfg):
    """Build the nnUNet-style model: five encoder levels, a bottleneck and five decoder levels.

    Every level is a pair of 3x3 Conv-BatchNorm-ReLU units. Downsampling uses
    stride-2 convolutions and upsampling uses stride-2 transposed convolutions.

    Args:
        cfg: Configuration object; only ``cfg.INPUT_SIZE`` is read.

    Returns:
        A Keras ``Model`` named ``'nnUNet'`` with a one-channel sigmoid output.
    """
    print("\n" + "="*80)
    print("🔥 BUILDING nnUNet")
    print("="*80)

    inputs = Input((cfg.INPUT_SIZE, cfg.INPUT_SIZE, 3))

    def conv_block(x, filters, name):
        # Two Conv-BN-ReLU units (batch norm is used here, not instance norm).
        for idx in (1, 2):
            x = Conv2D(filters, (3, 3), padding='same', name=f'{name}_conv{idx}')(x)
            x = BatchNormalization(name=f'{name}_bn{idx}')(x)
            x = Activation('relu', name=f'{name}_act{idx}')(x)
        return x

    stage_widths = (32, 64, 128, 256, 512)

    # Encoder: store each level's features, then downsample with a strided convolution.
    skips = []
    x = inputs
    for level, width in enumerate(stage_widths, start=1):
        level_features = conv_block(x, width, f'enc{level}')
        skips.append(level_features)
        x = Conv2D(width, (3, 3), strides=2, padding='same')(level_features)

    # Bottleneck
    x = conv_block(x, 512, 'bottleneck')

    # Decoder: walk the levels from deepest to shallowest.
    for level in range(len(stage_widths), 0, -1):
        width = stage_widths[level - 1]
        x = Conv2DTranspose(width, (2, 2), strides=2, padding='same')(x)
        x = Concatenate()([x, skips[level - 1]])
        x = conv_block(x, width, f'dec{level}')

    outputs = Conv2D(1, (1, 1), activation='sigmoid')(x)

    model = Model(inputs=[inputs], outputs=[outputs], name='nnUNet')
    print(f"Total parameters: {model.count_params():,}")
    print("="*80 + "\n")
    return model

# ==============================================================================
# MODEL FACTORY
# ==============================================================================

def get_model(model_name, cfg):
    """Look up the builder registered under ``model_name`` and call it with ``cfg``.

    Raises:
        ValueError: If ``model_name`` is not one of the registered architectures.
    """
    registry = {
        "UNet": build_unet,
        "UNetPlusPlus": build_unet_plusplus,
        "PraNet": build_pranet,
        "RAPUNet": build_rapunet,
        "SwinUNet": build_swinunet,
        "TransUNet": build_transunet,
        "UMamba": build_umamba,
        "MedSegNet-SSF": build_medsegnet_ssf,
        "DuckNet": build_ducknet,
        "nnUNet": build_nnunet,
    }

    builder = registry.get(model_name)
    if builder is None:
        raise ValueError(f"Unknown model: {model_name}")
    return builder(cfg)

# ==============================================================================
# MASL LOSS FUNCTION
# ==============================================================================

class ClipConstraint(tf.keras.constraints.Constraint):
    """Weight constraint that clamps every value into ``[min_value, max_value]``."""

    def __init__(self, min_value=0.1, max_value=10.0):
        self.min_value, self.max_value = min_value, max_value

    def __call__(self, w):
        lower, upper = self.min_value, self.max_value
        return tf.clip_by_value(w, lower, upper)

    def get_config(self):
        # Both bounds are needed to re-create the constraint.
        return dict(min_value=self.min_value, max_value=self.max_value)

class MorphologyAwareAdaptiveLoss(Layer):
    """Composite segmentation loss whose terms are re-weighted from mask morphology.

    Five terms (core region, multi-scale boundary, structure, scale-aware focal,
    texture) are combined as a weighted average. Each term's weight is the
    product of a clipped scalar layer weight and a factor derived from the
    ground-truth mask (tubularity, compactness, irregularity).
    Masks and predictions are indexed as 4-D tensors (batch, rows, cols, channels).
    """

    def __init__(self, name='masl', **kwargs):
        super().__init__(name=name, **kwargs)
        self.epsilon = 1e-6

    def build(self, input_shape):
        """Create the five scalar term weights, each clipped to [0.1, 10.0]."""
        shared_clip = ClipConstraint(min_value=0.1, max_value=10.0)
        initial_values = (
            ('w_region', 1.0),
            ('w_boundary', 1.0),
            ('w_structure', 1.0),
            ('w_scale', 0.5),
            ('w_texture', 0.5),
        )
        for weight_name, start_value in initial_values:
            scalar = self.add_weight(
                name=weight_name, shape=(),
                initializer=tf.constant_initializer(start_value),
                trainable=True, constraint=shared_clip)
            setattr(self, weight_name, scalar)
        super().build(input_shape)

    # ------------------------------------------------------------------
    # Shared building blocks
    # ------------------------------------------------------------------
    def _soft_perimeter(self, mask):
        """Per-sample sum of the forward-difference gradient magnitude of ``mask``."""
        row_diff = mask[:, 1:, :, :] - mask[:, :-1, :, :]
        col_diff = mask[:, :, 1:, :] - mask[:, :, :-1, :]
        # Zero-pad the trailing edge so both differences share the mask's shape.
        row_diff = tf.pad(row_diff, [[0, 0], [0, 1], [0, 0], [0, 0]])
        col_diff = tf.pad(col_diff, [[0, 0], [0, 0], [0, 1], [0, 0]])
        magnitude = tf.sqrt(row_diff**2 + col_diff**2 + self.epsilon)
        return tf.reduce_sum(magnitude, axis=[1, 2, 3]) + self.epsilon

    @staticmethod
    def _second_differences(t):
        """Return the second-order finite differences of ``t`` along rows and columns."""
        along_rows = t[:, 2:, :, :] - 2*t[:, 1:-1, :, :] + t[:, :-2, :, :]
        along_cols = t[:, :, 2:, :] - 2*t[:, :, 1:-1, :] + t[:, :, :-2, :]
        return along_rows, along_cols

    def _pixel_bce(self, y_true, y_pred):
        """Per-pixel binary cross-entropy with epsilon inside both logarithms."""
        return -(y_true * tf.math.log(y_pred + self.epsilon) + (1 - y_true) * tf.math.log(1 - y_pred + self.epsilon))

    # ------------------------------------------------------------------
    # Morphology
    # ------------------------------------------------------------------
    def morphological_dilation(self, x, kernel_size=5):
        """Grey-scale dilation implemented as a stride-1 max pool."""
        return tf.nn.max_pool2d(x, kernel_size, strides=1, padding='SAME')

    def morphological_erosion(self, x, kernel_size=5):
        """Grey-scale erosion: negate, max-pool, negate back."""
        return -tf.nn.max_pool2d(-x, kernel_size, strides=1, padding='SAME')

    def detect_boundary(self, mask, kernel_size=5):
        """Boundary band: dilation minus erosion, clipped to [0, 1]."""
        band = (self.morphological_dilation(mask, kernel_size)
                - self.morphological_erosion(mask, kernel_size))
        return tf.clip_by_value(band, 0.0, 1.0)

    def analyze_structure_characteristics(self, y_true):
        """Return batch-level scalars describing the mask, each clipped to [0, 1].

        Keys: ``tubularity``, ``compactness``, ``irregularity``, ``object_size``.
        """
        eps = self.epsilon
        area = tf.reduce_sum(y_true, axis=[1, 2, 3]) + eps
        total_pixels = tf.cast(tf.shape(y_true)[1] * tf.shape(y_true)[2], tf.float32)

        perimeter = self._soft_perimeter(y_true)

        # A 3x3 erosion serves as the skeleton approximation.
        skeleton_approx = self.morphological_erosion(y_true, kernel_size=3)
        skeleton_area = tf.reduce_sum(skeleton_approx, axis=[1, 2, 3]) + eps

        tubularity = tf.reduce_mean(skeleton_area / (area + eps))
        compactness = tf.reduce_mean((4 * 3.14159 * area) / (perimeter**2 + eps))
        compactness = tf.clip_by_value(compactness, 0.0, 1.0)

        # Irregularity: mean absolute second difference of the boundary band.
        boundary = self.detect_boundary(y_true, kernel_size=5)
        ddy, ddx = self._second_differences(boundary)
        irregularity = tf.reduce_mean(tf.abs(ddy)) + tf.reduce_mean(tf.abs(ddx))

        object_size = tf.reduce_mean(area / total_pixels)

        return {
            'tubularity': tf.clip_by_value(tubularity, 0.0, 1.0),
            'compactness': compactness,
            'irregularity': tf.clip_by_value(irregularity, 0.0, 1.0),
            'object_size': tf.clip_by_value(object_size, 0.0, 1.0)
        }

    # ------------------------------------------------------------------
    # Loss terms
    # ------------------------------------------------------------------
    def core_loss(self, y_true, y_pred):
        """0.4 * Dice loss + 0.3 * IoU loss + 0.3 * boundary-weighted BCE."""
        eps = self.epsilon
        intersection = tf.reduce_sum(y_true * y_pred, axis=[1, 2, 3])
        true_sum = tf.reduce_sum(y_true, axis=[1, 2, 3])
        pred_sum = tf.reduce_sum(y_pred, axis=[1, 2, 3])

        dice = (2. * intersection + eps) / (true_sum + pred_sum + eps)
        dice_loss = 1.0 - tf.reduce_mean(dice)

        union = (true_sum + pred_sum - intersection)
        iou = (intersection + eps) / (union + eps)
        iou_loss = 1.0 - tf.reduce_mean(iou)

        # Pixels inside the boundary band get up to six times the base weight.
        pixel_weights = 1.0 + 5.0 * self.detect_boundary(y_true, kernel_size=5)
        weighted_bce = tf.reduce_mean(pixel_weights * self._pixel_bce(y_true, y_pred))

        return 0.4 * dice_loss + 0.3 * iou_loss + 0.3 * weighted_bce

    def boundary_loss(self, y_true, y_pred):
        """L1 distance between finite differences of mask and prediction at offsets 1, 2 and 4."""
        total_loss = 0.0
        for offset, scale_weight in ((1, 0.5), (2, 0.3), (4, 0.2)):
            dy_true = y_true[:, offset:, :, :] - y_true[:, :-offset, :, :]
            dy_pred = y_pred[:, offset:, :, :] - y_pred[:, :-offset, :, :]
            dx_true = y_true[:, :, offset:, :] - y_true[:, :, :-offset, :]
            dx_pred = y_pred[:, :, offset:, :] - y_pred[:, :, :-offset, :]
            row_term = tf.reduce_mean(tf.abs(dy_true - dy_pred))
            col_term = tf.reduce_mean(tf.abs(dx_true - dx_pred))
            total_loss += scale_weight * (row_term + col_term)
        return total_loss

    def structure_aware_loss(self, y_true, y_pred, characteristics):
        """Absolute gap in area/perimeter^2 between mask and prediction, scaled by compactness."""
        eps = self.epsilon

        def _compactness(mask):
            mask_area = tf.reduce_sum(mask, axis=[1, 2, 3]) + eps
            mask_perimeter = self._soft_perimeter(mask)
            return mask_area / (mask_perimeter**2 + eps)

        compact_true = _compactness(y_true)
        compact_pred = _compactness(y_pred)

        return characteristics['compactness'] * tf.reduce_mean(tf.abs(compact_true - compact_pred))

    def scale_aware_focal_loss(self, y_true, y_pred, characteristics):
        """Focal BCE whose exponent is 3.0, 2.0 or 1.5 depending on the object-size fraction."""
        size = characteristics['object_size']

        def _medium_or_large():
            return tf.cond(size < 0.2, lambda: 2.0, lambda: 1.5)

        gamma = tf.cond(size < 0.05, lambda: 3.0, _medium_or_large)

        # Probability assigned to the correct class at every pixel.
        p = y_true * y_pred + (1 - y_true) * (1 - y_pred)
        focal_weight = tf.pow(1 - p, gamma)

        return tf.reduce_mean(focal_weight * self._pixel_bce(y_true, y_pred))

    def texture_aware_loss(self, y_true, y_pred):
        """L1 distance between the second-order differences of mask and prediction."""
        ddy_true, ddx_true = self._second_differences(y_true)
        ddy_pred, ddx_pred = self._second_differences(y_pred)
        return tf.reduce_mean(tf.abs(ddy_true - ddy_pred)) + tf.reduce_mean(tf.abs(ddx_true - ddx_pred))

    def call(self, y_true, y_pred):
        """Return the weighted average of the five loss terms as a scalar."""
        y_true = tf.cast(y_true, tf.float32)
        y_pred = tf.cast(y_pred, tf.float32)

        traits = self.analyze_structure_characteristics(y_true)

        # Morphology-dependent multipliers for each term.
        alpha_region = 1.0 + 0.5 * traits['compactness']
        alpha_boundary = 1.0 + 1.5 * traits['tubularity'] + traits['compactness']
        alpha_structure = 1.0 + traits['tubularity']
        alpha_scale = 1.0 + 1.5 * traits['irregularity']
        alpha_texture = 1.0 + traits['irregularity']

        # Effective coefficient of a term = clipped layer weight * multiplier.
        coef_region = self.w_region * alpha_region
        coef_boundary = self.w_boundary * alpha_boundary
        coef_structure = self.w_structure * alpha_structure
        coef_scale = self.w_scale * alpha_scale
        coef_texture = self.w_texture * alpha_texture

        numerator = (coef_region * self.core_loss(y_true, y_pred)
                     + coef_boundary * self.boundary_loss(y_true, y_pred)
                     + coef_structure * self.structure_aware_loss(y_true, y_pred, traits)
                     + coef_scale * self.scale_aware_focal_loss(y_true, y_pred, traits)
                     + coef_texture * self.texture_aware_loss(y_true, y_pred))

        total_weight = (coef_region + coef_boundary +
                        coef_structure + coef_scale + coef_texture)

        return numerator / (total_weight + self.epsilon)

    def get_config(self):
        return super().get_config()

# Single shared loss layer: every call to masl_loss_fn reuses the same weights.
_masl_instance = MorphologyAwareAdaptiveLoss()


def masl_loss_fn(y_true, y_pred):
    """Evaluate the module-level MASL layer on ``(y_true, y_pred)`` and return its loss."""
    loss_value = _masl_instance(y_true, y_pred)
    return loss_value

# ==============================================================================
# METRICS
# ==============================================================================

def dice_coefficient(y_true, y_pred, smooth=1e-6):
    """Soft Dice coefficient computed over all elements of both tensors at once."""
    truth = tf.reshape(y_true, [-1])
    guess = tf.reshape(y_pred, [-1])
    overlap = tf.reduce_sum(truth * guess)
    denominator = tf.reduce_sum(truth) + tf.reduce_sum(guess) + smooth
    return (2. * overlap + smooth) / denominator

def iou_score(y_true, y_pred, smooth=1e-6):
    """Soft intersection-over-union computed over all elements of both tensors at once."""
    truth = tf.reshape(y_true, [-1])
    guess = tf.reshape(y_pred, [-1])
    overlap = tf.reduce_sum(truth * guess)
    combined = tf.reduce_sum(truth) + tf.reduce_sum(guess) - overlap
    return (overlap + smooth) / (combined + smooth)

def precision_metric(y_true, y_pred):
    """Precision of the prediction thresholded at 0.5: true positives / predicted positives."""
    hard_pred = tf.cast(tf.greater(y_pred, 0.5), tf.float32)
    hits = tf.reduce_sum(y_true * hard_pred)
    flagged = tf.reduce_sum(hard_pred)
    # K.epsilon() keeps the ratio finite when nothing is predicted positive.
    return hits / (flagged + K.epsilon())

def recall_metric(y_true, y_pred):
    """Recall of the prediction thresholded at 0.5: true positives / actual positives."""
    hard_pred = tf.cast(tf.greater(y_pred, 0.5), tf.float32)
    hits = tf.reduce_sum(y_true * hard_pred)
    relevant = tf.reduce_sum(y_true)
    # K.epsilon() keeps the ratio finite when the mask has no positives.
    return hits / (relevant + K.epsilon())

# ==============================================================================
# PATCH-BASED PREDICTION
# ==============================================================================

def _patch_origins(length, patch_size, stride):
    """Return patch start offsets along one axis such that the whole axis is covered.

    Requires ``length >= patch_size``. Offsets advance by ``stride``; when the last
    regular patch would stop short of the end, one extra patch flush with the end
    is added.
    """
    last_origin = length - patch_size
    origins = list(range(0, last_origin + 1, stride))
    if origins[-1] != last_origin:
        origins.append(last_origin)
    return origins


def _predict_patches_single(model, image, cfg):
    """Helper: Predict full image using overlapping patches (no TTA).

    Args:
        model: Object with a Keras-style ``predict(batch, verbose=0)`` method that
            returns an array indexed as (batch, rows, cols, channels); channel 0
            is used as the prediction.
        image: Array whose first two axes are (rows, cols).
        cfg: Configuration providing ``PATCH_SIZE`` and ``STRIDE``.

    Returns:
        float32 array of shape ``image.shape[:2]`` holding the average of all
        patch predictions that cover each pixel.
    """
    patch_size = cfg.PATCH_SIZE
    h, w = image.shape[:2]

    # Images smaller than one patch are zero-padded at the bottom/right so that
    # at least one patch fits; the padding is cropped away again at the end.
    pad_h = max(patch_size - h, 0)
    pad_w = max(patch_size - w, 0)
    if pad_h or pad_w:
        padding = [(0, pad_h), (0, pad_w)] + [(0, 0)] * (image.ndim - 2)
        image = np.pad(image, padding, mode='constant')
    padded_h, padded_w = image.shape[:2]

    prediction = np.zeros((padded_h, padded_w), dtype=np.float32)
    counts = np.zeros((padded_h, padded_w), dtype=np.float32)

    row_origins = _patch_origins(padded_h, patch_size, cfg.STRIDE)
    col_origins = _patch_origins(padded_w, patch_size, cfg.STRIDE)

    for y in row_origins:
        for x in col_origins:
            patch = image[y:y+patch_size, x:x+patch_size]
            pred_patch = model.predict(np.expand_dims(patch, 0), verbose=0)[0]
            pred_patch = pred_patch[:, :, 0]

            prediction[y:y+patch_size, x:x+patch_size] += pred_patch
            counts[y:y+patch_size, x:x+patch_size] += 1

    # Every pixel is covered by at least one patch, so counts is never zero.
    prediction = prediction / counts
    return prediction[:h, :w]


def predict_full_image_with_patches(model, image, cfg, use_tta=False):
    """Predict on full-resolution image using overlapping patches.

    With ``use_tta=True`` the prediction is averaged over the eight symmetries of
    the square (identity, two axis flips, three rotations, and the two diagonal
    reflections). Each transformed prediction is mapped back to the original
    orientation before averaging.

    Args:
        model: Model passed through to the patch predictor.
        image: Array whose first two axes are (rows, cols).
        cfg: Configuration providing ``PATCH_SIZE`` and ``STRIDE``.
        use_tta: Enable test-time augmentation.

    Returns:
        2-D float array of shape ``image.shape[:2]``.
    """
    if not use_tta:
        return _predict_patches_single(model, image, cfg)

    def _anti_transpose(arr):
        # Reflection about the anti-diagonal; applying it twice is the identity.
        return np.rot90(np.swapaxes(arr, 0, 1), 2, axes=(0, 1))

    # (forward transform for the image, inverse transform for the prediction).
    # Flipping both axes is NOT listed separately: it equals the 180-degree
    # rotation and would weight that view twice.
    transforms = [
        (lambda a: a, lambda p: p),
        (lambda a: np.flip(a, axis=1), lambda p: np.flip(p, axis=1)),
        (lambda a: np.flip(a, axis=0), lambda p: np.flip(p, axis=0)),
        (lambda a: np.rot90(a, 1, axes=(0, 1)), lambda p: np.rot90(p, -1, axes=(0, 1))),
        (lambda a: np.rot90(a, 2, axes=(0, 1)), lambda p: np.rot90(p, -2, axes=(0, 1))),
        (lambda a: np.rot90(a, 3, axes=(0, 1)), lambda p: np.rot90(p, -3, axes=(0, 1))),
        (lambda a: np.swapaxes(a, 0, 1), lambda p: np.swapaxes(p, 0, 1)),
        (_anti_transpose, _anti_transpose),
    ]

    predictions = [
        inverse(_predict_patches_single(model, forward(image), cfg))
        for forward, inverse in transforms
    ]
    return np.mean(predictions, axis=0)

# ==============================================================================
# VISUALIZATION FUNCTIONS
# ==============================================================================

def save_prediction_image(original_image, ground_truth, prediction, save_path, image_name):
    """Write a four-panel figure (image, ground truth, prediction, overlay) to ``save_path``.

    In the overlay panel, pixels whose prediction exceeds 0.5 are painted green
    on a copy of the original image.
    """
    title_style = dict(fontsize=14, fontweight='bold')
    fig, axes = plt.subplots(1, 4, figsize=(20, 5))

    # Paint predicted foreground green on a copy so the caller's image stays untouched.
    overlay = original_image.copy()
    overlay[prediction > 0.5] = [0, 255, 0]

    panels = (
        (original_image, None, 'Original Image'),
        (ground_truth, 'gray', 'Ground Truth'),
        (prediction, 'gray', 'Prediction'),
        (overlay, None, 'Overlay (Green=Prediction)'),
    )
    for axis, (picture, colormap, title) in zip(axes, panels):
        if colormap is None:
            axis.imshow(picture)
        else:
            axis.imshow(picture, cmap=colormap)
        axis.set_title(title, **title_style)
        axis.axis('off')

    fig.suptitle(f'{image_name}', fontsize=16, fontweight='bold')
    fig.tight_layout()
    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close(fig)

def save_individual_predictions(prediction, save_dir, image_name):
    """Write the prediction to ``save_dir`` twice: scaled by 255, and thresholded at 0.5.

    Files are named ``<image_name>_prediction.png`` and
    ``<image_name>_prediction_binary.png``.
    """
    renderings = (
        (f"{image_name}_prediction.png", (prediction * 255).astype(np.uint8)),
        (f"{image_name}_prediction_binary.png", ((prediction > 0.5) * 255).astype(np.uint8)),
    )
    for filename, pixels in renderings:
        cv2.imwrite(os.path.join(save_dir, filename), pixels)

# ==============================================================================
# EVALUATION FUNCTIONS
# ==============================================================================

def _summarize_scores(scores):
    """Return ``{'mean', 'std', 'all'}`` for one list of per-image scores.

    An empty list yields NaN statistics explicitly instead of relying on
    ``np.mean([])``, which only reaches NaN after emitting a RuntimeWarning.
    """
    if not scores:
        return {'mean': float('nan'), 'std': float('nan'), 'all': scores}
    return {
        'mean': float(np.mean(scores)),
        'std': float(np.std(scores)),
        'all': scores,
    }


def evaluate_model(model, test_pairs, cfg, model_name, use_tta=False):
    """Evaluate a model on full images using patch-based inference.

    For every ``(image_path, mask_path)`` pair the image is read with OpenCV,
    converted BGR->RGB, passed through ``preprocess_image_clahe`` and
    ``apply_fov_mask``, and predicted with
    ``predict_full_image_with_patches``. The prediction is thresholded at 0.5
    and compared with the mask returned by ``apply_fov_mask`` to obtain Dice,
    IoU, precision and recall (each with a 1e-6 term in the denominator).

    Pairs whose image or mask cannot be read are skipped with a warning; the
    stems of the images that were actually scored are returned under
    ``'images'`` so per-image scores can be matched to files.

    Args:
        model: Model forwarded to ``predict_full_image_with_patches``.
        test_pairs: Iterable of ``(image_path, mask_path)`` tuples.
        cfg: Configuration object; this function reads ``cfg.SAVE_DIR`` and
            ``cfg.SAVE_PREDICTIONS`` and forwards ``cfg`` to the helpers.
        model_name: Used in log output and as a sub-directory of
            ``cfg.SAVE_DIR``.
        use_tta: Forwarded to the patch predictor; also selects the
            ``with_tta`` / ``no_tta`` output sub-directory.

    Returns:
        dict with keys ``'dice'``, ``'iou'``, ``'precision'``, ``'recall'``
        (each ``{'mean': float, 'std': float, 'all': list[float]}``) and
        ``'images'`` (list of image stems, in the same order as the ``'all'``
        lists). Means and stds are NaN when no image could be evaluated.
    """
    tta_suffix = "with_tta" if use_tta else "no_tta"
    tta_label = "WITH TTA" if use_tta else "WITHOUT TTA"
    rule = '=' * 80

    print(f"\n{rule}")
    print(f"📊 EVALUATING {model_name} {tta_label}")
    print(f"{rule}")

    dice_scores = []
    iou_scores = []
    precision_scores = []
    recall_scores = []
    evaluated_images = []

    # Create visualization directory
    viz_dir = os.path.join(cfg.SAVE_DIR, model_name, tta_suffix)
    os.makedirs(viz_dir, exist_ok=True)

    for img_path, mask_path in test_pairs:
        image_name = Path(img_path).stem

        # Load full-resolution image
        image_original = cv2.imread(img_path)
        if image_original is None:
            print(f"  ⚠️ Skipping {image_name}: could not read image {img_path}")
            continue
        image_original_rgb = cv2.cvtColor(image_original, cv2.COLOR_BGR2RGB)

        # Load mask; cv2.imread returns None (no exception) on failure, which
        # would otherwise blow up in the comparison below.
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            print(f"  ⚠️ Skipping {image_name}: could not read mask {mask_path}")
            continue
        mask = (mask > 127).astype(np.float32)

        # Preprocessing
        image_preprocessed = preprocess_image_clahe(image_original_rgb, cfg)
        image_preprocessed, mask_masked = apply_fov_mask(image_preprocessed, mask, cfg)

        # Predict using patches
        pred = predict_full_image_with_patches(model, image_preprocessed, cfg, use_tta=use_tta)
        pred_binary = (pred > 0.5).astype(np.float32)

        # Calculate metrics (sums computed once and reused)
        mask_sum = np.sum(mask_masked)
        pred_sum = np.sum(pred_binary)
        intersection = np.sum(mask_masked * pred_binary)
        union = mask_sum + pred_sum - intersection

        # Stored as plain Python floats so the lists serialize cleanly.
        dice = float((2.0 * intersection) / (mask_sum + pred_sum + 1e-6))
        iou = float(intersection / (union + 1e-6))
        precision = float(intersection / (pred_sum + 1e-6))
        recall = float(intersection / (mask_sum + 1e-6))

        dice_scores.append(dice)
        iou_scores.append(iou)
        precision_scores.append(precision)
        recall_scores.append(recall)
        evaluated_images.append(image_name)

        print(f"  {image_name}: Dice={dice:.4f}, IoU={iou:.4f}")

        # Save visualizations
        if cfg.SAVE_PREDICTIONS:
            save_path = os.path.join(viz_dir, f"{image_name}_comparison.png")
            save_prediction_image(image_original_rgb, mask_masked, pred, save_path, image_name)
            save_individual_predictions(pred, viz_dir, image_name)

    results = {
        'dice': _summarize_scores(dice_scores),
        'iou': _summarize_scores(iou_scores),
        'precision': _summarize_scores(precision_scores),
        'recall': _summarize_scores(recall_scores),
        'images': evaluated_images,
    }

    print(f"\n{rule}")
    print(f"RESULTS FOR {model_name} ({tta_label})")
    print(f"{rule}")
    print(f"Dice:      {results['dice']['mean']:.4f} ± {results['dice']['std']:.4f}")
    print(f"IoU:       {results['iou']['mean']:.4f} ± {results['iou']['std']:.4f}")
    print(f"Precision: {results['precision']['mean']:.4f} ± {results['precision']['std']:.4f}")
    print(f"Recall:    {results['recall']['mean']:.4f} ± {results['recall']['std']:.4f}")
    print(f"{rule}")

    return results

# ==============================================================================
# LEARNING RATE SCHEDULER
# ==============================================================================

def cosine_annealing_with_warmup(epoch, lr, total_epochs=100, warmup_epochs=10, min_lr=1e-6):
    """Return the learning rate for ``epoch`` under linear warmup + cosine decay.

    During warmup (``epoch < warmup_epochs``) the rate grows linearly as
    ``lr * (epoch + 1) / warmup_epochs``, reaching ``lr`` on the last warmup
    epoch. Afterwards it follows half a cosine period from ``lr`` down to
    ``min_lr`` over the remaining ``total_epochs - warmup_epochs`` epochs.

    Args:
        epoch: Zero-based epoch index.
        lr: Peak learning rate reached at the end of warmup.
        total_epochs: Epoch index at which the rate reaches ``min_lr``.
        warmup_epochs: Number of warmup epochs.
        min_lr: Floor of the cosine decay.

    Returns:
        The learning rate for the given epoch.
    """
    if epoch < warmup_epochs:
        return lr * (epoch + 1) / warmup_epochs

    decay_epochs = total_epochs - warmup_epochs
    if decay_epochs <= 0:
        # No decay phase exists (warmup covers the whole schedule); avoid the
        # division by zero and report the end-of-schedule rate.
        return min_lr

    # Clamp so epochs past total_epochs stay at min_lr instead of riding the
    # cosine back up.
    progress = min((epoch - warmup_epochs) / decay_epochs, 1.0)
    return min_lr + (lr - min_lr) * (1 + np.cos(np.pi * progress)) / 2

# ==============================================================================
# MAIN TRAINING LOOP FOR SINGLE MODEL
# ==============================================================================

def train_single_model(model_name, cfg, strategy, train_gen, val_gen, test_gen, test_pairs):
    """Train one model, evaluate it with and without TTA, and save the results.

    Steps: build and compile the model inside ``strategy.scope()``, fit it on
    ``train_gen`` / ``val_gen`` with checkpointing, early stopping, CSV
    logging and a warmup + cosine learning-rate schedule, then call
    ``evaluate_model`` twice (without and with TTA) on ``test_pairs``.

    Files written under ``cfg.SAVE_DIR/<model_name>/``:
    ``best_<model>.h5``, ``training_log_<model>.csv``,
    ``results_<model>.json`` and ``per_image_results_<model>.csv``.

    Args:
        model_name: Name passed to ``get_model`` and used for output paths.
        cfg: Configuration object (reads ``SAVE_DIR``, ``LEARNING_RATE``,
            ``EPOCHS``, ``CHECKPOINT_MONITOR``, ``CHECKPOINT_MODE``,
            ``EARLY_STOPPING_PATIENCE``).
        strategy: Object whose ``scope()`` context wraps model construction.
        train_gen: Training data passed to ``model.fit``.
        val_gen: Validation data passed to ``model.fit``.
        test_gen: Accepted for interface compatibility; not used here.
        test_pairs: ``(image_path, mask_path)`` pairs for full-image
            evaluation.

    Returns:
        dict with ``model``, ``training_time_minutes``, ``no_tta``,
        ``with_tta`` and ``tta_boost`` entries (mean metrics only).
    """
    rule = '=' * 80
    print(f"\n{rule}")
    print(f"🚀 TRAINING MODEL: {model_name}")
    print(f"{rule}")

    model_save_dir = os.path.join(cfg.SAVE_DIR, model_name)
    os.makedirs(model_save_dir, exist_ok=True)

    # The Keras session is cleared in ``finally`` so a model that fails
    # midway does not keep its memory while the next model trains.
    try:
        # Build model
        with strategy.scope():
            model = get_model(model_name, cfg)

            optimizer = tf.keras.optimizers.Adam(learning_rate=cfg.LEARNING_RATE, clipnorm=1.0)

            model.compile(
                optimizer=optimizer,
                loss="binary_crossentropy",
                metrics=[dice_coefficient, iou_score, precision_metric, recall_metric]
            )

        # Callbacks
        lr_scheduler = LearningRateScheduler(
            lambda epoch: cosine_annealing_with_warmup(epoch, cfg.LEARNING_RATE, cfg.EPOCHS, warmup_epochs=10),
            verbose=0
        )

        # NOTE(review): whether restore_best_weights also applies when
        # training runs to the last epoch (no early stop) depends on the
        # Keras version -- confirm that evaluation below uses the intended
        # weights.
        callbacks = [
            ModelCheckpoint(
                os.path.join(model_save_dir, f"best_{model_name}.h5"),
                monitor=cfg.CHECKPOINT_MONITOR,
                mode=cfg.CHECKPOINT_MODE,
                save_best_only=True,
                verbose=1
            ),
            EarlyStopping(
                monitor=cfg.CHECKPOINT_MONITOR,
                mode=cfg.CHECKPOINT_MODE,
                patience=cfg.EARLY_STOPPING_PATIENCE,
                verbose=1,
                restore_best_weights=True
            ),
            CSVLogger(os.path.join(model_save_dir, f"training_log_{model_name}.csv")),
            lr_scheduler
        ]

        # Train
        start_time = time.time()

        model.fit(
            train_gen,
            validation_data=val_gen,
            epochs=cfg.EPOCHS,
            callbacks=callbacks,
            verbose=1
        )

        training_time = time.time() - start_time
        print(f"\n✅ {model_name} training finished in {training_time/60:.1f} minutes")

        # Evaluate
        print(f"\n{rule}")
        print(f"📊 EVALUATING {model_name}")
        print(f"{rule}")

        # Evaluate without TTA
        no_tta_results = evaluate_model(model, test_pairs, cfg, model_name, use_tta=False)

        # Evaluate with TTA
        with_tta_results = evaluate_model(model, test_pairs, cfg, model_name, use_tta=True)

        baseline_dice = no_tta_results['dice']['mean']
        dice_gain = with_tta_results['dice']['mean'] - baseline_dice
        # A model that learned nothing scores Dice 0 without TTA; the
        # relative gain is undefined there, so report 0.0 instead of raising
        # ZeroDivisionError after training has already completed. The
        # absolute gain is still recorded under "dice".
        dice_gain_percent = (dice_gain / baseline_dice) * 100 if baseline_dice else 0.0

        # Save results
        results = {
            "model": model_name,
            "training_time_minutes": training_time / 60,
            "no_tta": {
                "dice": no_tta_results['dice']['mean'],
                "dice_std": no_tta_results['dice']['std'],
                "iou": no_tta_results['iou']['mean'],
                "precision": no_tta_results['precision']['mean'],
                "recall": no_tta_results['recall']['mean']
            },
            "with_tta": {
                "dice": with_tta_results['dice']['mean'],
                "dice_std": with_tta_results['dice']['std'],
                "iou": with_tta_results['iou']['mean'],
                "precision": with_tta_results['precision']['mean'],
                "recall": with_tta_results['recall']['mean']
            },
            "tta_boost": {
                "dice": dice_gain,
                "dice_percent": dice_gain_percent
            }
        }

        with open(os.path.join(model_save_dir, f"results_{model_name}.json"), "w") as f:
            json.dump(results, f, indent=2)

        # Save per-image results. evaluate_model skips pairs it cannot read,
        # so the score lists may be shorter than test_pairs; pick row labels
        # whose length matches the scores so the DataFrame can be built.
        scored_count = len(no_tta_results['dice']['all'])
        image_names = no_tta_results.get('images')
        if image_names is None or len(image_names) != scored_count:
            image_names = [Path(img_path).stem for img_path, _ in test_pairs]
        if len(image_names) != scored_count:
            print(f"⚠️ {len(image_names) - scored_count} test image(s) were not evaluated; "
                  f"per-image rows are labelled by position")
            image_names = [f"evaluated_{idx}" for idx in range(1, scored_count + 1)]

        results_df = pd.DataFrame({
            'image': image_names,
            'dice_no_tta': no_tta_results['dice']['all'],
            'dice_with_tta': with_tta_results['dice']['all'],
            'iou_no_tta': no_tta_results['iou']['all'],
            'iou_with_tta': with_tta_results['iou']['all']
        })
        results_df.to_csv(os.path.join(model_save_dir, f"per_image_results_{model_name}.csv"), index=False)

        return results
    finally:
        # Clear session to free memory (also on failure)
        K.clear_session()

# ==============================================================================
# MAIN MULTI-MODEL TRAINING ORCHESTRATOR
# ==============================================================================

def train_all_models(cfg, strategy):
    """Train every enabled model sequentially and write a comparison report.

    Seeds the run with ``set_seed``, loads the train/val/test splits with
    ``load_dataset_split``, extracts patches, builds the patch generators and
    calls ``train_single_model`` for each name enabled in
    ``cfg.MODELS_TO_TRAIN``. A failure in one model is recorded and does not
    stop the remaining models. Finally writes ``model_comparison.csv`` and
    ``all_results.json`` into ``cfg.SAVE_DIR`` and prints a summary.

    Args:
        cfg: Configuration object (reads ``SEED``, ``DETERMINISTIC``,
            ``TRAIN_DIR``, ``VAL_DIR``, ``TEST_DIR``, ``MODELS_TO_TRAIN``,
            ``SAVE_DIR``; also forwarded to the helpers).
        strategy: Forwarded to ``train_single_model``.

    Returns:
        dict mapping each model name to its results dict, or to
        ``{"error": <message>}`` when that model failed. Returns ``None``
        when no training data is found.
    """
    import traceback  # local import: only needed for per-model error reports

    set_seed(cfg.SEED, cfg.DETERMINISTIC)
    rule = '=' * 80

    # Load data
    train_pairs = load_dataset_split(cfg.TRAIN_DIR)
    val_pairs = load_dataset_split(cfg.VAL_DIR)
    test_pairs = load_dataset_split(cfg.TEST_DIR)

    if not train_pairs:
        print("❌ No training data found!")
        return

    print(f"\n📊 Dataset Statistics:")
    print(f"   Training images:   {len(train_pairs)}")
    print(f"   Validation images: {len(val_pairs)}")
    print(f"   Test images:       {len(test_pairs)}")

    # Extract patches
    print("\n🔥 EXTRACTING PATCHES...")
    train_patches = extract_all_patches(train_pairs, cfg)
    val_patches = extract_all_patches(val_pairs, cfg)
    test_patches = extract_all_patches(test_pairs, cfg)

    # Create generators
    train_aug = get_patch_augmentation(cfg)
    train_gen = PatchBasedGenerator(train_patches, cfg, augmentation=train_aug, shuffle=True)
    val_gen = PatchBasedGenerator(val_patches, cfg, augmentation=None, shuffle=False)
    test_gen = PatchBasedGenerator(test_patches, cfg, augmentation=None, shuffle=False)

    # Get list of models to train
    models_to_train = [name for name, enabled in cfg.MODELS_TO_TRAIN.items() if enabled]

    print(f"\n{rule}")
    print(f"🔥 STARTING MULTI-MODEL TRAINING")
    print(f"{rule}")
    print(f"Models to train: {models_to_train}")
    print(f"Total: {len(models_to_train)} models")
    print(f"{rule}\n")

    # Train each model
    all_results = {}
    for i, model_name in enumerate(models_to_train, 1):
        print(f"\n{'#'*80}")
        print(f"# MODEL {i}/{len(models_to_train)}: {model_name}")
        print(f"{'#'*80}")

        # Deliberately broad: one failing architecture must not abort the
        # whole sweep. The traceback is printed so the cause is not lost.
        try:
            results = train_single_model(model_name, cfg, strategy, train_gen, val_gen, test_gen, test_pairs)
        except Exception as e:
            print(f"\n❌ Error training {model_name}: {e}")
            traceback.print_exc()
            all_results[model_name] = {"error": str(e)}
        else:
            all_results[model_name] = results
            print(f"\n✅ {model_name} completed successfully!")

    # Save comprehensive results
    print(f"\n{rule}")
    print("📊 GENERATING COMPREHENSIVE COMPARISON")
    print(f"{rule}")

    # Create comparison table
    comparison_data = []
    for model_name, results in all_results.items():
        if "error" in results:
            continue
        comparison_data.append({
            "Model": model_name,
            "Dice (No TTA)": f"{results['no_tta']['dice']:.4f} ± {results['no_tta']['dice_std']:.4f}",
            "Dice (With TTA)": f"{results['with_tta']['dice']:.4f} ± {results['with_tta']['dice_std']:.4f}",
            "IoU (No TTA)": f"{results['no_tta']['iou']:.4f}",
            "IoU (With TTA)": f"{results['with_tta']['iou']:.4f}",
            # Signed format: a hard-coded "+" printed "+-0.0012" when TTA hurt.
            "TTA Boost": f"{results['tta_boost']['dice']:+.4f} ({results['tta_boost']['dice_percent']:+.2f}%)",
            "Training Time (min)": f"{results['training_time_minutes']:.1f}"
        })

    # SAVE_DIR is otherwise only created as a side effect of training a
    # model; make sure it exists when nothing was trained.
    os.makedirs(cfg.SAVE_DIR, exist_ok=True)

    comparison_df = pd.DataFrame(comparison_data)
    comparison_df.to_csv(os.path.join(cfg.SAVE_DIR, "model_comparison.csv"), index=False)

    # Save all results
    with open(os.path.join(cfg.SAVE_DIR, "all_results.json"), "w") as f:
        json.dump(all_results, f, indent=2)

    # Print final summary
    print(f"\n{rule}")
    print("🏆 FINAL RESULTS SUMMARY")
    print(f"{rule}")
    print(comparison_df.to_string(index=False))
    print(f"{rule}")

    # Find best model (highest with-TTA Dice strictly above 0)
    best_model = None
    best_dice = 0.0
    for model_name, results in all_results.items():
        if "error" in results:
            continue
        dice = results['with_tta']['dice']
        if dice > best_dice:
            best_dice = dice
            best_model = model_name

    if best_model:
        print(f"\n🏆 BEST MODEL: {best_model}")
        print(f"   Dice (With TTA): {best_dice:.4f}")
        print(f"   Location: {os.path.join(cfg.SAVE_DIR, best_model)}/")

    print(f"\n📁 All results saved to: {cfg.SAVE_DIR}/")
    print(f"   ├─ model_comparison.csv       ⭐ COMPARISON TABLE")
    print(f"   ├─ all_results.json           ⭐ ALL RESULTS")
    print(f"   └─ [model_name]/              📂 Individual model results")
    print(f"       ├─ best_[model].h5")
    print(f"       ├─ results_[model].json")
    print(f"       ├─ per_image_results_[model].csv")
    print(f"       ├─ no_tta/                📸 Predictions without TTA")
    print(f"       └─ with_tta/              📸 Predictions with TTA")

    print(f"\n{rule}")
    print("🎉 MULTI-MODEL TRAINING COMPLETE!")
    print(f"{rule}\n")

    return all_results

# ==============================================================================
# MAIN EXECUTION
# ==============================================================================

if __name__ == "__main__":
    # Script entry point: print a run summary, train every enabled model,
    # then print the closing pointers to the output files.
    banner_rule = "="*80

    print("\n" + banner_rule)
    print("🔥 MULTI-MODEL MEDICAL IMAGE SEGMENTATION TRAINING")
    print(banner_rule)

    # Echo the settings this run will use.
    print("\n📋 CONFIGURATION:")
    print(f"   Data Root: {config.DATA_ROOT}")
    print(f"   Save Directory: {config.SAVE_DIR}")
    print(f"   Patch Size: {config.PATCH_SIZE}x{config.PATCH_SIZE}")
    print(f"   Batch Size: {config.BATCH_SIZE}")
    print(f"   Epochs: {config.EPOCHS}")

    # One line per known model, ticked when it is enabled.
    print("\n🎯 MODELS TO TRAIN:")
    for model_name, enabled in config.MODELS_TO_TRAIN.items():
        print(f"   {'✅' if enabled else '⬜'} {model_name}")

    for usage_line in (
        "\n⚠️ IMPORTANT:",
        "   1. Set DATA_ROOT to your dataset path",
        "   2. Adjust MODELS_TO_TRAIN to select models",
        "   3. Each model will be trained sequentially",
        "   4. Results will be saved per model",
    ):
        print(usage_line)

    print("\n" + banner_rule + "\n")

    all_results = train_all_models(config, strategy)

    for closing_line in (
        "\n" + banner_rule,
        "🎉 ALL DONE!",
        banner_rule,
        "\n✅ Check model_comparison.csv for results comparison!",
        "✅ Check individual model folders for detailed results!",
        "\n🚀 Happy segmenting! 🚀\n",
    ):
        print(closing_line)