#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
================================================================================
🔥 MASL WITH COEFFICIENT PERTURBATION STUDY - COMPLETE IMPLEMENTATION
================================================================================

FEATURES:
- ✅ Coefficient perturbation study (variations of up to ±40%)
- ✅ Cross-dataset coefficient robustness testing
- ✅ Alternative coefficient sets (vessel/tumor/instrument-tuned)
- ✅ Automated comparison reports
- ✅ All original MASL features preserved
- ✅ Patch-based training and evaluation
- ✅ Comprehensive visualization

USAGE:
    1. Single training run:
       RUN_COEFFICIENT_STUDY = False
       SELECTED_COEFFICIENT_CONFIG = 'baseline'  # or any other config
    
    2. Perturbation study:
       RUN_COEFFICIENT_STUDY = True
       COEFFICIENT_STUDY_TYPE = 'perturbation'
    
    3. Cross-dataset study:
       RUN_COEFFICIENT_STUDY = True
       COEFFICIENT_STUDY_TYPE = 'cross_dataset'

================================================================================
"""

import os
import sys
import time
import json
from glob import glob
from pathlib import Path
from copy import deepcopy

import numpy as np
import cv2
import tensorflow as tf
from tensorflow.keras.layers import (
    Input, Conv2D, DepthwiseConv2D, GlobalAveragePooling2D,
    Reshape, Multiply, BatchNormalization, Activation, Dropout,
    Concatenate, Add, UpSampling2D, LayerNormalization, Dense, Layer
)
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2
from tensorflow.keras.callbacks import (
    ModelCheckpoint, EarlyStopping, CSVLogger, Callback, LearningRateScheduler
)
from tensorflow.keras import backend as K
import albumentations as A
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.patches import Rectangle, Patch
from scipy import ndimage
from scipy.stats import pearsonr
import pandas as pd

# Intended to reduce TensorFlow's native log output.
# NOTE(review): this is assigned after `import tensorflow` above; the variable is
# normally read when TensorFlow's native library loads, so it may have no effect
# at this position — confirm, and move it above the import if needed.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

# ==============================================================================
# 🔥 COEFFICIENT CONFIGURATIONS
# ==============================================================================

class MorphologyCoefficients:
    """
    Morphology modulation coefficients for MASL.

    Each configuration is a plain dict with the keys ``name``, ``description``,
    ``core``, ``boundary``, ``structure``, ``scale`` and ``texture``.

    These coefficients define the TEMPLATE for how morphological characteristics
    are intended to modulate the loss components:
    - α_core      = 1.0 + core_coeff * compactness
    - α_boundary  = 1.0 + boundary_coeff * tubularity + compactness
    - α_structure = 1.0 + structure_coeff * tubularity
    - α_scale     = 1.0 + scale_coeff * irregularity
    - α_texture   = 1.0 + texture_coeff * irregularity

    The actual adaptation happens through:
    1. Per-sample modulation (automatic, based on morphology)
    2. Per-dataset learnable weights (optimized during training)

    ``BASELINE`` must stay defined: ``get_cross_dataset_configs`` exposes it as
    ``'polyp_tuned'`` and ``Config.__init__`` falls back to it when the selected
    configuration name is unknown.
    """

    # ==================== BASELINE COEFFICIENTS ====================
    BASELINE = {
        'name': 'BASELINE',
        'description': 'Standard coefficients (polyp-tuned)',
        'core': 0.5,
        'boundary': 1.5,
        'structure': 1.0,
        'scale': 1.5,
        'texture': 1.0
    }

    # ==================== PERTURBATION STUDIES ====================
    # Each variant changes exactly one coefficient relative to BASELINE.

    PERTURB_CORE_LOW = {
        'name': 'PERTURB_CORE_-40%',
        'description': 'Core coefficient reduced by 40% (0.5 → 0.3)',
        'core': 0.3,      # -40%
        'boundary': 1.5,
        'structure': 1.0,
        'scale': 1.5,
        'texture': 1.0
    }

    PERTURB_CORE_HIGH = {
        'name': 'PERTURB_CORE_+40%',
        'description': 'Core coefficient increased by 40% (0.5 → 0.7)',
        'core': 0.7,      # +40%
        'boundary': 1.5,
        'structure': 1.0,
        'scale': 1.5,
        'texture': 1.0
    }

    PERTURB_BOUNDARY_LOW = {
        'name': 'PERTURB_BOUNDARY_-33%',
        'description': 'Boundary coefficient reduced by 33% (1.5 → 1.0)',
        'core': 0.5,
        'boundary': 1.0,  # -33%
        'structure': 1.0,
        'scale': 1.5,
        'texture': 1.0
    }

    PERTURB_BOUNDARY_HIGH = {
        'name': 'PERTURB_BOUNDARY_+40%',
        'description': 'Boundary coefficient increased by 40% (1.5 → 2.1)',
        'core': 0.5,
        'boundary': 2.1,  # +40%
        'structure': 1.0,
        'scale': 1.5,
        'texture': 1.0
    }

    PERTURB_STRUCTURE_LOW = {
        'name': 'PERTURB_STRUCTURE_-30%',
        'description': 'Structure coefficient reduced by 30% (1.0 → 0.7)',
        'core': 0.5,
        'boundary': 1.5,
        'structure': 0.7,  # -30%
        'scale': 1.5,
        'texture': 1.0
    }

    PERTURB_STRUCTURE_HIGH = {
        'name': 'PERTURB_STRUCTURE_+40%',
        'description': 'Structure coefficient increased by 40% (1.0 → 1.4)',
        'core': 0.5,
        'boundary': 1.5,
        'structure': 1.4,  # +40%
        'scale': 1.5,
        'texture': 1.0
    }

    PERTURB_SCALE_LOW = {
        'name': 'PERTURB_SCALE_-33%',
        'description': 'Scale coefficient reduced by 33% (1.5 → 1.0)',
        'core': 0.5,
        'boundary': 1.5,
        'structure': 1.0,
        'scale': 1.0,     # -33%
        'texture': 1.0
    }

    PERTURB_SCALE_HIGH = {
        'name': 'PERTURB_SCALE_+40%',
        'description': 'Scale coefficient increased by 40% (1.5 → 2.1)',
        'core': 0.5,
        'boundary': 1.5,
        'structure': 1.0,
        'scale': 2.1,     # +40%
        'texture': 1.0
    }

    # ==================== ALTERNATIVE COEFFICIENT SETS ====================
    # Optimized for different anatomical structures

    VESSEL_TUNED = {
        'name': 'VESSEL_TUNED',
        'description': 'Optimized for tubular vessels (DRIVE/CHASE)',
        'core': 0.3,
        'boundary': 1.8,  # Emphasize connectivity
        'structure': 1.2,
        'scale': 1.5,
        'texture': 1.0
    }

    TUMOR_TUNED = {
        'name': 'TUMOR_TUNED',
        'description': 'Optimized for irregular tumors (BraTS)',
        'core': 0.7,
        'boundary': 1.2,
        'structure': 0.9,
        'scale': 2.0,     # Handle size variation
        'texture': 1.2
    }

    INSTRUMENT_TUNED = {
        'name': 'INSTRUMENT_TUNED',
        'description': 'Optimized for surgical instruments (EndoVis)',
        'core': 0.4,
        'boundary': 1.6,
        'structure': 1.1,
        'scale': 1.5,
        'texture': 1.0
    }

    @classmethod
    def get_perturbation_configs(cls):
        """Return the baseline plus every single-coefficient perturbation, keyed by config id."""
        return {
            'baseline': cls.BASELINE,
            'perturb_core_low': cls.PERTURB_CORE_LOW,
            'perturb_core_high': cls.PERTURB_CORE_HIGH,
            'perturb_boundary_low': cls.PERTURB_BOUNDARY_LOW,
            'perturb_boundary_high': cls.PERTURB_BOUNDARY_HIGH,
            'perturb_structure_low': cls.PERTURB_STRUCTURE_LOW,
            'perturb_structure_high': cls.PERTURB_STRUCTURE_HIGH,
            'perturb_scale_low': cls.PERTURB_SCALE_LOW,
            'perturb_scale_high': cls.PERTURB_SCALE_HIGH
        }

    @classmethod
    def get_all_configs(cls):
        """Return every selectable configuration (perturbation set + tuned alternatives)."""
        configs = cls.get_perturbation_configs()
        configs.update({
            'vessel_tuned': cls.VESSEL_TUNED,
            'tumor_tuned': cls.TUMOR_TUNED,
            'instrument_tuned': cls.INSTRUMENT_TUNED
        })
        return configs

    @classmethod
    def get_cross_dataset_configs(cls):
        """Return the configurations compared in the cross-dataset study."""
        return {
            'polyp_tuned': cls.BASELINE,
            'vessel_tuned': cls.VESSEL_TUNED,
            'tumor_tuned': cls.TUMOR_TUNED,
            'instrument_tuned': cls.INSTRUMENT_TUNED
        }

# ==============================================================================
# 🔥 CONFIGURATION WITH COEFFICIENT STUDY SUPPORT
# ==============================================================================

class Config:
    """Central settings for data, model, training and the coefficient study.

    All settings are class attributes. Instantiating the class creates the
    output directories, resolves ``SELECTED_COEFFICIENT_CONFIG`` into
    ``COEFFICIENT_CONFIG`` and prints a short run banner.
    """

    # ---------- coefficient perturbation study ----------
    RUN_COEFFICIENT_STUDY = True  # True runs the study instead of a single training
    COEFFICIENT_STUDY_TYPE = 'all'  # 'perturbation', 'cross_dataset', or 'all'
    SELECTED_COEFFICIENT_CONFIG = 'baseline'  # config id used for a single run

    # ---------- GPU selection ----------
    GPU_NUMBERS = [0]

    # ---------- data locations (adjust DATA_ROOT for your environment) ----------
    DATA_ROOT = "/kaggle/input/chase-db/chase_data"
    TRAIN_DIR = os.path.join(DATA_ROOT, "train")
    VAL_DIR   = os.path.join(DATA_ROOT, "val")
    TEST_DIR  = os.path.join(DATA_ROOT, "test")
    SAVE_DIR  = "/kaggle/working/MASL_COEFFICIENT_STUDY"

    # ---------- patch-based training ----------
    USE_PATCH_TRAINING = True
    PATCH_SIZE = 256          # side length of a square patch, in pixels
    STRIDE = 32               # step between neighbouring patch origins
    PATCHES_PER_EPOCH = 6000  # divided by BATCH_SIZE to get steps per epoch
    MIN_VESSEL_RATIO = 0.005  # minimum foreground fraction for a patch to be kept

    # ---------- model architecture ----------
    INPUT_SIZE = 256
    F1, F2, F3, F4, F5 = 24, 32, 64, 80, 128  # encoder stage widths

    USE_MRF_SE = True
    USE_SSTM   = True
    USE_BFP    = True

    MRF_KERNELS = [3, 5, 7]
    SE_REDUCTION = 16
    EXPAND_RATIO = 6
    DROPOUT = 0.1
    L2_REG = 1e-4

    SSTM_NUM_FREQUENCIES = 32
    SSTM_SSM_STATE_DIM = 16
    # One flag per encoder stage.
    SSTM_USE_SPECTRAL = [True, True, True, True, True]
    SSTM_USE_SSM = [False, False, True, True, True]
    SSTM_DROPOUT = 0.1

    # ---------- training ----------
    BATCH_SIZE = 16
    EPOCHS = 50  # shortened for the coefficient study (use 350 for full training)
    LEARNING_RATE = 5e-4

    EARLY_STOPPING_PATIENCE = 30
    CHECKPOINT_MONITOR = "val_dice_coefficient"
    CHECKPOINT_MODE = "max"

    # ---------- preprocessing ----------
    USE_CLAHE = True
    USE_GREEN_CHANNEL = True
    CLAHE_CLIP_LIMIT = 2.0
    CLAHE_TILE_GRID_SIZE = (8, 8)

    # ---------- test-time augmentation ----------
    USE_TTA = True
    TTA_AUGMENTATIONS = 8

    # ---------- field-of-view masking ----------
    USE_FOV_MASK = True
    FOV_MARGIN = 20  # pixels subtracted from the FOV circle radius

    # ---------- visualization ----------
    SAVE_PREDICTIONS = True
    SAVE_OVERLAYS = True
    VIZ_DIR = os.path.join(SAVE_DIR, "visualizations")

    SEED = 42
    DETERMINISTIC = False

    def __init__(self):
        # Make sure both output locations exist before anything is written.
        for out_dir in (self.SAVE_DIR, self.VIZ_DIR):
            os.makedirs(out_dir, exist_ok=True)

        # Resolve the selected coefficient set; unknown ids fall back to BASELINE.
        registry = MorphologyCoefficients.get_all_configs()
        if self.SELECTED_COEFFICIENT_CONFIG not in registry:
            print(f"⚠️ Unknown coefficient config: {self.SELECTED_COEFFICIENT_CONFIG}")
            print("   Using BASELINE")
            self.COEFFICIENT_CONFIG = MorphologyCoefficients.BASELINE
        else:
            self.COEFFICIENT_CONFIG = registry[self.SELECTED_COEFFICIENT_CONFIG]

        print(f"🔥 MASL TRAINING WITH COEFFICIENT PERTURBATION STUDY")
        print(f"   Coefficient Study: {'ENABLED' if self.RUN_COEFFICIENT_STUDY else 'DISABLED'}")

        # In study mode only the study type is reported.
        if self.RUN_COEFFICIENT_STUDY:
            print(f"   Study Type: {self.COEFFICIENT_STUDY_TYPE.upper()}")
            return

        # Single-run mode: report the resolved coefficient set.
        print(f"   Selected Config: {self.COEFFICIENT_CONFIG['name']}")
        print(f"   Coefficients: core={self.COEFFICIENT_CONFIG['core']:.1f}, "
              f"boundary={self.COEFFICIENT_CONFIG['boundary']:.1f}, "
              f"structure={self.COEFFICIENT_CONFIG['structure']:.1f}, "
              f"scale={self.COEFFICIENT_CONFIG['scale']:.1f}, "
              f"texture={self.COEFFICIENT_CONFIG['texture']:.1f}")

# Module-level configuration instance. Constructing it runs at import time:
# it creates SAVE_DIR/VIZ_DIR and prints the run banner.
config = Config()

# ==============================================================================
# GPU SETUP
# ==============================================================================

def setup_gpus(gpu_numbers=None):
    """Restrict TensorFlow to the requested GPUs and pick a distribution strategy.

    Args:
        gpu_numbers: Optional iterable of GPU indices. Indices that are not
            below the number of physical GPUs are ignored. ``None`` selects
            every physical GPU.

    Returns:
        ``(strategy, num_gpus)``. A ``MirroredStrategy`` is used when more than
        one GPU is selected, otherwise the default strategy. When no GPU is
        present, or device configuration raises ``RuntimeError``, the default
        strategy and ``0`` are returned.
    """
    available = tf.config.list_physical_devices('GPU')
    if len(available) == 0:
        print("⚠️ No GPUs found! Using CPU.")
        return tf.distribute.get_strategy(), 0

    print(f"🔍 Total GPUs available: {len(available)}")

    chosen = available if gpu_numbers is None else [
        available[idx] for idx in gpu_numbers if idx < len(available)
    ]

    try:
        tf.config.set_visible_devices(chosen, 'GPU')
        for device in chosen:
            tf.config.experimental.set_memory_growth(device, True)

        num_gpus = len(chosen)
        if num_gpus > 1:
            strategy = tf.distribute.MirroredStrategy()
        else:
            strategy = tf.distribute.get_strategy()
        print(f"✅ Using {num_gpus} GPU(s)")
        return strategy, num_gpus
    except RuntimeError as err:
        print(f"⚠️ GPU setup error: {err}")
        return tf.distribute.get_strategy(), 0

# Configure the GPUs listed in the configuration and obtain the distribution
# strategy; this runs at import time.
strategy, num_gpus = setup_gpus(config.GPU_NUMBERS)

# ==============================================================================
# PREPROCESSING FUNCTIONS
# ==============================================================================

def preprocess_image_clahe(image, cfg):
    """Reduce an image to one channel, optionally apply CLAHE, and return 3 equal channels.

    Args:
        image: 2-D or 3-D array. For 3-D input the green channel (index 1) is
            used when ``cfg.USE_GREEN_CHANNEL`` is set, otherwise an RGB→gray
            conversion. 2-D input is used as-is.
        cfg: Object providing ``USE_GREEN_CHANNEL``, ``USE_CLAHE``,
            ``CLAHE_CLIP_LIMIT`` and ``CLAHE_TILE_GRID_SIZE``.

    Returns:
        ``float32`` array of shape ``(H, W, 3)`` holding the single channel
        divided by 255 and repeated three times.
    """
    # Pick the single channel to work on.
    if len(image.shape) != 3:
        channel = image
    elif cfg.USE_GREEN_CHANNEL:
        channel = image[:, :, 1]
    else:
        channel = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

    # Optional local contrast enhancement; CLAHE is applied to a uint8 copy.
    if cfg.USE_CLAHE:
        equalizer = cv2.createCLAHE(
            clipLimit=cfg.CLAHE_CLIP_LIMIT,
            tileGridSize=cfg.CLAHE_TILE_GRID_SIZE,
        )
        channel = equalizer.apply(channel.astype(np.uint8))

    scaled = channel.astype(np.float32) / 255.0

    # Replicate the channel so the result has three identical planes.
    return np.repeat(scaled[..., np.newaxis], 3, axis=-1)

def apply_fov_mask(image, mask, cfg):
    """Zero out everything outside a centred circular field of view.

    The circle is centred at ``(w // 2, h // 2)`` with radius
    ``min(h, w) // 2 - cfg.FOV_MARGIN``. When ``cfg.USE_FOV_MASK`` is false the
    inputs are returned untouched.

    Args:
        image: 2-D or 3-D array; the first two axes are treated as (h, w).
        mask: 2-D array with the same (h, w).
        cfg: Object providing ``USE_FOV_MASK`` and ``FOV_MARGIN``.

    Returns:
        ``(image, mask)`` multiplied by the boolean FOV region.
    """
    if not cfg.USE_FOV_MASK:
        return image, mask

    height, width = image.shape[:2]
    cx, cy = width // 2, height // 2
    radius = min(height, width) // 2 - cfg.FOV_MARGIN

    rows, cols = np.ogrid[:height, :width]
    inside = (cols - cx) ** 2 + (rows - cy) ** 2 <= radius ** 2

    # Add a trailing axis so the region broadcasts over the channels of a 3-D image.
    region = inside[:, :, np.newaxis] if len(image.shape) == 3 else inside

    return image * region, mask * inside

# ==============================================================================
# UTILS
# ==============================================================================

def set_seed(seed=42, deterministic=False):
    """Seed the NumPy, Python and TensorFlow RNGs and export PYTHONHASHSEED.

    Args:
        seed: Seed value passed to every generator.
        deterministic: Accepted for interface compatibility; this function
            does not use it.
    """
    import random as py_random

    os.environ['PYTHONHASHSEED'] = str(seed)
    for seed_fn in (np.random.seed, py_random.seed, tf.random.set_seed):
        seed_fn(seed)

def get_image_mask_pairs(images_dir, masks_dir):
    """Pair every image in ``images_dir`` with its mask in ``masks_dir``.

    Images are collected by extension (png, jpg, jpeg, tif, bmp) and sorted by
    path. For an image with stem ``S`` the mask is the first existing file among
    ``S<ext>`` and ``S_mask<ext>`` candidates, tried in a fixed priority order.

    Args:
        images_dir: Directory containing the input images.
        masks_dir: Directory containing the masks.

    Returns:
        List of ``(image_path, mask_path)`` tuples in sorted image order.
        Images without a matching mask are left out; their count is printed.
        An empty list is returned (with a warning) when no image is found.
    """
    image_extensions = ('.png', '.jpg', '.jpeg', '.tif', '.bmp')
    image_files = sorted(
        path
        for ext in image_extensions
        for path in glob(os.path.join(images_dir, f'*{ext}'))
    )

    if not image_files:
        print(f"⚠️ No images found in {images_dir}")
        return []

    # The first five suffixes keep the historical lookup priority; the rest
    # cover every extension that is accepted for images, so e.g. a .bmp image
    # with a .bmp mask is no longer dropped.
    mask_suffixes = (
        '.png', '.jpg', '.tif', '_mask.png', '_mask.jpg',
        '.jpeg', '.bmp', '_mask.tif', '_mask.jpeg', '_mask.bmp',
    )

    pairs = []
    unmatched = 0
    for img_path in image_files:
        stem = Path(img_path).stem
        candidates = (os.path.join(masks_dir, stem + suffix) for suffix in mask_suffixes)
        mask_path = next((cand for cand in candidates if os.path.exists(cand)), None)
        if mask_path is None:
            unmatched += 1
        else:
            pairs.append((img_path, mask_path))

    if unmatched:
        # Surface silently shrinking datasets instead of hiding them.
        print(f"⚠️ {unmatched} image(s) in {images_dir} have no matching mask "
              f"in {masks_dir} and were skipped")
    return pairs

def load_dataset_split(split_dir):
    """Return the image/mask pairs found under ``<split_dir>/images`` and ``<split_dir>/masks``."""
    images_dir, masks_dir = (os.path.join(split_dir, sub) for sub in ("images", "masks"))
    return get_image_mask_pairs(images_dir, masks_dir)

# ==============================================================================
# PATCH EXTRACTION
# ==============================================================================

def extract_patches_from_image(image, mask, cfg):
    """Slide a square window over one image and keep patches with enough foreground.

    Args:
        image: Array whose first two axes are (h, w).
        mask: Array with the same (h, w); its sum over a window divided by the
            window area is compared against ``cfg.MIN_VESSEL_RATIO``.
        cfg: Object providing ``PATCH_SIZE``, ``STRIDE`` and ``MIN_VESSEL_RATIO``.

    Returns:
        List of ``(image_patch, mask_patch)`` tuples in row-major window order.
        Windows start at multiples of ``STRIDE`` and must fit entirely inside
        the image.
    """
    size = cfg.PATCH_SIZE
    height, width = image.shape[:2]
    window_area = size * size

    origins = [
        (top, left)
        for top in range(0, height - size + 1, cfg.STRIDE)
        for left in range(0, width - size + 1, cfg.STRIDE)
    ]

    kept = []
    for top, left in origins:
        window = (slice(top, top + size), slice(left, left + size))
        mask_patch = mask[window]
        if np.sum(mask_patch) / window_area >= cfg.MIN_VESSEL_RATIO:
            kept.append((image[window], mask_patch))
    return kept

def extract_all_patches(pairs, cfg):
    """Load every image/mask pair, preprocess it and collect its training patches.

    For each pair the image is read and converted BGR→RGB, the mask is read as
    grayscale and binarised at 127, then CLAHE preprocessing, FOV masking and
    patch extraction are applied.

    Args:
        pairs: Iterable of ``(image_path, mask_path)`` tuples.
        cfg: Configuration object forwarded to the preprocessing and patch
            extraction helpers.

    Returns:
        List of ``(image_patch, mask_patch)`` tuples from all readable pairs.
        Pairs whose image or mask cannot be read, or whose sizes differ, are
        skipped with a warning.
    """
    print(f"\n🔍 Extracting patches from {len(pairs)} images...")
    all_patches = []

    for img_path, mask_path in pairs:
        image = cv2.imread(img_path)
        if image is None:
            print(f"⚠️ Could not read image {img_path}; skipping")
            continue

        # cv2.imread signals failure by returning None; check before thresholding.
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            print(f"⚠️ Could not read mask {mask_path}; skipping")
            continue

        # Patches are cut with the same window from both arrays, so they must align.
        if image.shape[:2] != mask.shape[:2]:
            print(f"⚠️ Size mismatch between {img_path} {image.shape[:2]} "
                  f"and {mask_path} {mask.shape[:2]}; skipping")
            continue

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = (mask > 127).astype(np.float32)

        image = preprocess_image_clahe(image, cfg)
        image, mask = apply_fov_mask(image, mask, cfg)

        all_patches.extend(extract_patches_from_image(image, mask, cfg))

    print(f"✅ Extracted {len(all_patches)} patches")
    return all_patches

# ==============================================================================
# AUGMENTATION
# ==============================================================================

def get_patch_augmentation(cfg):
    """Build the albumentations pipeline applied to training patches.

    The pipeline chains geometric transforms (flips, rotation, elastic and
    grid distortion, all reflecting at the border) followed by intensity
    transforms (brightness/contrast, Gaussian noise, Gaussian blur).
    ``cfg`` is accepted but not read here.
    """
    reflect = cv2.BORDER_REFLECT_101

    geometric = [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.Rotate(limit=180, border_mode=reflect, p=0.9),
        A.ElasticTransform(alpha=25, sigma=4, alpha_affine=4, border_mode=reflect, p=0.4),
        A.GridDistortion(num_steps=5, distort_limit=0.1, border_mode=reflect, p=0.3),
    ]
    intensity = [
        A.RandomBrightnessContrast(brightness_limit=0.15, contrast_limit=0.15, p=0.6),
        A.GaussNoise(var_limit=(5.0, 15.0), p=0.2),
        A.GaussianBlur(blur_limit=(3, 5), p=0.1),
    ]
    return A.Compose(geometric + intensity, p=1.0)

# ==============================================================================
# DATA GENERATORS
# ==============================================================================

class PatchBasedGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that serves random batches from a pre-extracted patch pool.

    Every batch is an independent random draw from ``all_patches``; the batch
    ``index`` is not used to select data. The epoch length is
    ``cfg.PATCHES_PER_EPOCH // cfg.BATCH_SIZE`` steps.

    Args:
        all_patches: Sequence of ``(image_patch, mask_patch)`` tuples.
        cfg: Object providing ``PATCHES_PER_EPOCH`` and ``BATCH_SIZE``.
        augmentation: Optional callable invoked as
            ``augmentation(image=..., mask=...)`` returning a dict with
            ``"image"`` and ``"mask"``.
        shuffle: Stored for interface compatibility; sampling is always random.
    """
    def __init__(self, all_patches, cfg, augmentation=None, shuffle=True):
        # Let the Keras base class set up its own state.
        super().__init__()
        self.all_patches = all_patches
        self.cfg = cfg
        self.augmentation = augmentation
        self.shuffle = shuffle
        self.steps_per_epoch = cfg.PATCHES_PER_EPOCH // cfg.BATCH_SIZE
        
        print(f"   Patch Generator: {len(all_patches)} total patches")
        print(f"   Steps per epoch: {self.steps_per_epoch}")
        print(f"   Patches per epoch: {self.steps_per_epoch * cfg.BATCH_SIZE}")
            
    def __len__(self):
        """Return the number of batches per epoch."""
        return self.steps_per_epoch

    def __getitem__(self, index):
        """Return one random batch as ``(images, masks)`` float32 arrays.

        Raises:
            ValueError: If the patch pool is empty.
        """
        pool_size = len(self.all_patches)
        if pool_size == 0:
            raise ValueError("PatchBasedGenerator has no patches to sample from")

        batch_size = self.cfg.BATCH_SIZE
        # Sampling without replacement needs at least batch_size patches;
        # smaller pools fall back to sampling with replacement.
        indices = np.random.choice(pool_size, batch_size, replace=pool_size < batch_size)
        
        images, masks = [], []
        for idx in indices:
            img_patch, mask_patch = self.all_patches[idx]
            
            if self.augmentation:
                augmented = self.augmentation(image=img_patch, mask=mask_patch)
                img_patch = augmented["image"]
                mask_patch = augmented["mask"]
            
            # Masks are stored as (H, W); add a trailing channel axis.
            if len(mask_patch.shape) == 2:
                mask_patch = np.expand_dims(mask_patch, axis=-1)
            
            images.append(img_patch)
            masks.append(mask_patch)
            
        return np.array(images, dtype=np.float32), np.array(masks, dtype=np.float32)

    def on_epoch_end(self):
        """No per-epoch state to refresh; batches are drawn at random on every call."""
        pass

# ==============================================================================
# MODEL ARCHITECTURE (MEDSEGNET-SSF)
# ==============================================================================

class SpectralSelectiveTokenMixer(Layer):
    """Residual token mixer with an optional spectral path and an optional gated path.

    The layer expects channels-last input ``[B, H, W, C]`` with ``C == channels``.

    * Spectral path: 2-D FFT over the spatial axes, bilinear resize of the
      real/imaginary parts to ``freq_size x freq_size``, element-wise product
      with a learnable real filter, resize back, inverse FFT, layer norm.
    * Gated ("ssm") path: a sigmoid gate multiplies the flattened tokens, which
      are then projected by a dense layer and layer-normalised.

    When both paths are active their outputs are concatenated and fused by a
    dense layer. The result is normalised, optionally dropped out during
    training, and added to the input.

    Args:
        channels: Number of input/output channels.
        num_frequencies: Upper bound for the learnable filter's side length.
        ssm_state_dim: Stored on the layer; not used by the computations below.
        use_spectral: Enable the spectral path.
        use_ssm: Enable the gated path.
        dropout: Dropout rate applied to the mixed features during training.
    """
    def __init__(self, channels, num_frequencies=32, ssm_state_dim=16, 
                 use_spectral=True, use_ssm=True, dropout=0.0, **kwargs):
        super().__init__(**kwargs)
        self.channels = channels
        self.num_frequencies = num_frequencies
        self.ssm_state_dim = ssm_state_dim
        self.use_spectral = use_spectral
        self.use_ssm = use_ssm
        self.dropout_rate = dropout
        
    def build(self, input_shape):
        """Create the filter weights and sub-layers for the enabled paths."""
        input_h, input_w = input_shape[1], input_shape[2]
        # The filter never needs to be larger than the feature map itself.
        self.actual_frequencies = min(self.num_frequencies, input_h, input_w) if input_h else self.num_frequencies
        
        if self.use_spectral:
            self.freq_weights_real = self.add_weight(
                name='freq_weights_real',
                shape=(self.actual_frequencies, self.actual_frequencies, self.channels),
                initializer=self._get_initializer(),
                trainable=True
            )
            self.spectral_norm = LayerNormalization(epsilon=1e-6, name='spectral_norm')
        
        if self.use_ssm:
            self.ssm_C = Dense(self.channels, name='ssm_C')
            self.selection_gate = Dense(self.channels, activation='sigmoid', name='selection')
            self.ssm_norm = LayerNormalization(epsilon=1e-6, name='ssm_norm')
        
        if self.use_spectral and self.use_ssm:
            self.fusion = Dense(self.channels, name='fusion')
            self.fusion_norm = LayerNormalization(epsilon=1e-6, name='fusion_norm')
        
        self.norm = LayerNormalization(epsilon=1e-6, name='norm')
        super().build(input_shape)

    def get_config(self):
        """Return the constructor arguments so the layer can be serialised and reloaded."""
        config = super().get_config()
        config.update({
            'channels': self.channels,
            'num_frequencies': self.num_frequencies,
            'ssm_state_dim': self.ssm_state_dim,
            'use_spectral': self.use_spectral,
            'use_ssm': self.use_ssm,
            'dropout': self.dropout_rate,
        })
        return config
    
    def _get_initializer(self):
        """Return an initializer producing a Gaussian band-pass filter scaled by 0.5.

        The Gaussian is centred at frequency magnitude 0.25 with sigma 0.15 and
        is replicated across channels.
        """
        def init_fn(shape, dtype=None):
            H, W, C = shape
            freq_h = np.fft.fftfreq(H)[:, np.newaxis]
            freq_w = np.fft.fftfreq(W)[np.newaxis, :]
            freq_magnitude = np.sqrt(freq_h**2 + freq_w**2)
            gaussian = np.exp(-((freq_magnitude - 0.25)**2) / (2 * 0.15**2))
            gaussian = np.repeat(gaussian[:, :, np.newaxis], C, axis=2)
            return gaussian.astype(np.float32) * 0.5
        return init_fn

    @staticmethod
    def _fft2d_over_hw(x, inverse=False):
        """Apply a 2-D (inverse) FFT over the H and W axes of a complex NHWC tensor."""
        # tf.signal.fft2d / ifft2d transform the two innermost axes, so the
        # channel axis is moved out of the way (NHWC -> NCHW) and restored after.
        transform = tf.signal.ifft2d if inverse else tf.signal.fft2d
        x_nchw = tf.transpose(x, [0, 3, 1, 2])
        return tf.transpose(transform(x_nchw), [0, 2, 3, 1])
    
    def spectral_path(self, x):
        """Filter ``x`` in the spatial-frequency domain and return the normalised result."""
        H, W = tf.shape(x)[1], tf.shape(x)[2]
        freq_size = tf.minimum(tf.minimum(H, W), self.actual_frequencies)
        x_freq = self._fft2d_over_hw(tf.cast(x, tf.complex64))
        # tf.image.resize works on real tensors, so real and imaginary parts
        # are resized separately and recombined.
        x_freq_real = tf.math.real(x_freq)
        x_freq_imag = tf.math.imag(x_freq)
        x_freq_real_resized = tf.image.resize(x_freq_real, [freq_size, freq_size], method='bilinear')
        x_freq_imag_resized = tf.image.resize(x_freq_imag, [freq_size, freq_size], method='bilinear')
        x_freq_resized = tf.complex(x_freq_real_resized, x_freq_imag_resized)
        freq_filter = tf.cast(self.freq_weights_real[:freq_size, :freq_size, :], tf.complex64)
        x_freq_filtered = x_freq_resized * freq_filter
        x_freq_filt_real = tf.math.real(x_freq_filtered)
        x_freq_filt_imag = tf.math.imag(x_freq_filtered)
        x_freq_back_real = tf.image.resize(x_freq_filt_real, [H, W], method='bilinear')
        x_freq_back_imag = tf.image.resize(x_freq_filt_imag, [H, W], method='bilinear')
        x_freq_back = tf.complex(x_freq_back_real, x_freq_back_imag)
        x_spatial = self._fft2d_over_hw(x_freq_back, inverse=True)
        return self.spectral_norm(tf.math.real(x_spatial))
    
    def ssm_path(self, x):
        """Gate the flattened tokens, project them, and restore the spatial layout."""
        B, H, W = tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2]
        x_flat = tf.reshape(x, [B, H * W, self.channels])
        selection = self.selection_gate(x_flat)
        x_selected = x_flat * selection
        x_ssm = self.ssm_C(x_selected)
        return self.ssm_norm(tf.reshape(x_ssm, [B, H, W, self.channels]))
    
    def call(self, x, training=None):
        """Mix tokens through the enabled paths and add the result to ``x``."""
        outputs = []
        if self.use_spectral:
            outputs.append(self.spectral_path(x))
        if self.use_ssm:
            outputs.append(self.ssm_path(x))
        
        if len(outputs) == 2:
            fused = self.fusion_norm(self.fusion(tf.concat(outputs, axis=-1)))
        elif len(outputs) == 1:
            fused = outputs[0]
        else:
            # Neither path enabled: the residual branch is just the normalised input.
            fused = x
        
        fused = self.norm(fused)
        if training and self.dropout_rate > 0:
            fused = tf.nn.dropout(fused, rate=self.dropout_rate)
        return x + fused

def MRF_SE_BLOCK(x, filters, activation='elu', dropout=0.0, expand_ratio=6, 
                 regularizer=0.0, kernels=[3, 5, 7], se_reduction=16, name='mrf_se'):
    """Inverted-residual block with parallel depthwise kernels and squeeze-excitation.

    Stages: 1x1 expansion (skipped when ``expand_ratio <= 1``) → one depthwise
    branch per entry in ``kernels`` → concat + 1x1 fuse (only with more than
    one branch) → squeeze-excitation gate → 1x1 projection to ``filters`` →
    optional dropout → residual add with ``x``.

    Args:
        x: Input Keras tensor; it is added to the projected output, so its
            channel count must be compatible with ``filters``.
        filters: Output channel count.
        activation: Activation used after expansion, depthwise and fuse stages
            and inside the SE reduce convolution.
        dropout: Dropout rate applied after projection when > 0.
        expand_ratio: Multiplier for the internal channel width.
        regularizer: L2 factor for the convolution kernels; disabled when <= 0.
        kernels: Depthwise kernel sizes, one branch each.
        se_reduction: Channel reduction factor of the SE bottleneck (minimum 8 channels).
        name: Prefix for all layer names.

    Returns:
        Output Keras tensor of the block.
    """
    def weight_decay():
        # A separate regularizer instance per layer, or None when disabled.
        return l2(regularizer) if regularizer > 0 else None

    width = filters * expand_ratio

    # --- pointwise expansion ---
    if expand_ratio > 1:
        expanded = Conv2D(width, (1, 1), padding='same', kernel_initializer='he_uniform',
                          kernel_regularizer=weight_decay(), name=name+'_expand')(x)
    else:
        expanded = x
    expanded = BatchNormalization(name=name+'_expand_bn')(expanded)
    expanded = Activation(activation, name=name+'_expand_act')(expanded)

    # --- multi-receptive-field depthwise branches ---
    branches = []
    for k in kernels:
        tag = f"{name}_dw{k}x{k}"
        branch = DepthwiseConv2D((k, k), padding='same', depthwise_initializer='he_uniform',
                                 depthwise_regularizer=weight_decay(), name=tag)(expanded)
        branch = BatchNormalization(name=f"{tag}_bn")(branch)
        branch = Activation(activation, name=f"{tag}_act")(branch)
        branches.append(branch)

    # --- merge the branches (a single branch passes through unchanged) ---
    if len(branches) > 1:
        merged = Concatenate(name=name+'_concat')(branches)
        merged = Conv2D(width, (1, 1), padding='same', kernel_initializer='he_uniform',
                        kernel_regularizer=weight_decay(), name=name+'_fuse')(merged)
        merged = BatchNormalization(name=name+'_fuse_bn')(merged)
        merged = Activation(activation, name=name+'_fuse_act')(merged)
    else:
        merged = branches[0]

    # --- squeeze-and-excitation channel gate ---
    pooled = GlobalAveragePooling2D(name=name+'_gap')(merged)
    pooled = Reshape((1, 1, width), name=name+'_reshape')(pooled)
    squeezed = Conv2D(max(width//se_reduction, 8), (1, 1), activation=activation,
                      kernel_initializer='he_uniform', name=name+'_se_reduce')(pooled)
    gate = Conv2D(width, (1, 1), activation='sigmoid', kernel_initializer='he_uniform',
                  name=name+'_se_expand')(squeezed)
    gated = Multiply(name=name+'_se_mult')([merged, gate])

    # --- linear projection back to `filters` channels ---
    projected = Conv2D(filters, (1, 1), padding='same', kernel_initializer='he_uniform',
                       kernel_regularizer=weight_decay(), name=name+'_project')(gated)
    projected = BatchNormalization(name=name+'_project_bn')(projected)

    if dropout > 0:
        projected = Dropout(dropout, name=name+'_dropout')(projected)

    return Add(name=name+'_add')([projected, x])

def boundary_detection_module(features, filters, name='boundary'):
    """Predict a single-channel sigmoid map from ``features`` and gate the features with it.

    Args:
        features: Input Keras tensor.
        filters: Channel budget; the hidden 3x3 convolution uses ``filters // 2``.
        name: Prefix for the layer names.

    Returns:
        ``(gated_features, boundary_map)`` where ``gated_features`` is
        ``features`` multiplied by ``boundary_map``.
    """
    hidden = Conv2D(filters // 2, (3, 3), padding='same', activation='relu',
                    name=name + '_conv')(features)
    boundary_map = Conv2D(1, (1, 1), padding='same', activation='sigmoid',
                          name=name + '_map')(hidden)
    gated = Multiply(name=name + '_mult')([features, boundary_map])
    return gated, boundary_map

def BFP_decoder_stage(decoder_input, skip_features, filters, stage_name='bfp'):
    """Run one decoder stage: upsample, merge with the skip, and blend a boundary-refined branch.

    Args:
        decoder_input: Keras tensor from the previous (coarser) decoder level;
            it is upsampled by 2 before concatenation.
        skip_features: Encoder tensor concatenated with the upsampled input.
        filters: Channel count of every convolution in this stage.
        stage_name: Prefix for the explicitly named layers.

    Returns:
        ``(output, boundary_map)``: the fused stage output and the
        single-channel sigmoid map from the boundary module.

    Note:
        The BatchNormalization/Activation layers and the arithmetic in the
        final blend carry no explicit names, so Keras auto-generates them.
    """
    # Upsample the coarse features and stack them with the skip connection.
    region = Concatenate(name=stage_name+'_concat')([
        UpSampling2D((2, 2), name=stage_name+'_up')(decoder_input),
        skip_features
    ])
    
    # Two 3x3 conv -> BN -> ReLU layers form the region branch.
    region = Activation('relu')(BatchNormalization()(
        Conv2D(filters, (3, 3), padding='same', name=stage_name+'_region_conv1')(region)))
    region = Activation('relu')(BatchNormalization()(
        Conv2D(filters, (3, 3), padding='same', name=stage_name+'_region_conv2')(region)))
    
    # Boundary branch: features gated by the predicted map, plus the map itself.
    boundary_features, boundary_map = boundary_detection_module(region, filters, stage_name+'_boundary')
    
    boundary_refined = Activation('relu')(BatchNormalization()(
        Conv2D(filters, (3, 3), padding='same', name=stage_name+'_boundary_refine')(boundary_features)))
    
    # Per-pixel blend: the map weights the refined branch, (1 - map) weights the
    # region branch; a 1x1 conv -> BN -> ReLU fuses the result.
    output = Activation('relu')(BatchNormalization()(
        Conv2D(filters, (1, 1), padding='same', name=stage_name+'_fusion')(
            region * (1 - boundary_map) + boundary_refined * boundary_map)))
    
    return output, boundary_map

def build_medsegnet_ssf(cfg):
    """Assemble the MedSegNet-SSF encoder/decoder as a Keras functional model.

    Args:
        cfg: Configuration object providing ``INPUT_SIZE``, the stage widths
            ``F1``..``F5``, the ``USE_MRF_SE`` / ``USE_SSTM`` / ``USE_BFP``
            switches and the hyper-parameters of the corresponding blocks.

    Returns:
        ``Model`` named ``"MedSegNet_SSF"`` mapping an
        ``(INPUT_SIZE, INPUT_SIZE, 3)`` input to a single-channel sigmoid output.
    """
    print("\n" + "="*80)
    print("🔥 BUILDING MED-SEGNET-SSF")
    print("="*80)
    
    inp = Input((cfg.INPUT_SIZE, cfg.INPUT_SIZE, 3), name="input")
    
    # Stem: full-resolution 3x3 conv -> BN -> ELU.
    x = Conv2D(16, (3, 3), padding='same', kernel_initializer='he_uniform', name='stem_conv')(inp)
    x = BatchNormalization(name='stem_bn')(x)
    x = Activation('elu', name='stem_act')(x)
    
    encoder_outputs = []
    filters = [cfg.F1, cfg.F2, cfg.F3, cfg.F4, cfg.F5]
    
    # Encoder: every stage halves the resolution with a stride-2 convolution,
    # then optionally applies the MRF-SE block and the token mixer.
    for i, f in enumerate(filters):
        x = Conv2D(f, (3, 3), strides=2, padding='same', kernel_initializer='he_uniform')(x)
        x = BatchNormalization()(x)
        x = Activation('elu')(x)
        
        if cfg.USE_MRF_SE:
            x = MRF_SE_BLOCK(x, f, activation='elu', dropout=cfg.DROPOUT,
                           expand_ratio=cfg.EXPAND_RATIO, regularizer=cfg.L2_REG,
                           kernels=cfg.MRF_KERNELS, se_reduction=cfg.SE_REDUCTION,
                           name=f'mrfse_stage{i+1}')
        
        if cfg.USE_SSTM:
            x = SpectralSelectiveTokenMixer(
                channels=f, num_frequencies=cfg.SSTM_NUM_FREQUENCIES,
                ssm_state_dim=cfg.SSTM_SSM_STATE_DIM,
                use_spectral=cfg.SSTM_USE_SPECTRAL[i],
                use_ssm=cfg.SSTM_USE_SSM[i],
                dropout=cfg.SSTM_DROPOUT,
                name=f'sstm_stage{i+1}'
            )(x)
        
        encoder_outputs.append(x)
    
    # Decoder starts at the deepest encoder output and walks back up the skips.
    skip_connections = encoder_outputs[::-1]
    decoder = skip_connections[0]
    decoder_filters = filters[::-1][1:] + [16]
    
    # NOTE(review): the decoder only advances when cfg.USE_BFP is true; with it
    # disabled no skip connection or intermediate upsampling is applied and the
    # output will not match the input resolution — confirm this is never used.
    for i, (skip, f) in enumerate(zip(skip_connections[1:], decoder_filters)):
        if cfg.USE_BFP:
            decoder, _ = BFP_decoder_stage(decoder, skip, f, stage_name=f'bfp_stage{i+1}')
    
    # Final 2x upsampling (undoing the first encoder stride) and prediction head.
    decoder = UpSampling2D((2, 2))(decoder)
    decoder = Conv2D(32, (3, 3), padding='same', activation='relu')(decoder)
    decoder = Conv2D(16, (3, 3), padding='same', activation='relu')(decoder)
    out = Conv2D(1, (1, 1), padding='same', activation='sigmoid', name='output')(decoder)
    
    model = Model(inputs=inp, outputs=out, name="MedSegNet_SSF")
    print(f"\nTotal parameters: {model.count_params():,}")
    print("="*80 + "\n")
    
    return model

# ==============================================================================
# 🔥 MODIFIED: MASL LOSS FUNCTION WITH COEFFICIENT SUPPORT
# ==============================================================================

class ClipConstraint(tf.keras.constraints.Constraint):
    """Weight constraint that clamps every value into ``[min_value, max_value]``."""

    def __init__(self, min_value=0.1, max_value=10.0):
        self.min_value, self.max_value = min_value, max_value

    def __call__(self, w):
        """Return ``w`` clipped to the configured range."""
        return tf.clip_by_value(
            w,
            clip_value_min=self.min_value,
            clip_value_max=self.max_value,
        )

    def get_config(self):
        """Return the bounds so the constraint can be serialised."""
        return dict(min_value=self.min_value, max_value=self.max_value)

class MorphologyAwareAdaptiveLoss(Layer):
    """Morphology-Aware Adaptive Segmentation Loss (MASL) packaged as a Keras layer.

    ``call(y_true, y_pred)`` evaluates five loss terms (core, boundary,
    structure, scale, texture). Each term is scaled by a weight variable
    created in ``build`` and by a modulation factor derived from morphology
    statistics of the ground-truth mask; the result is a normalised weighted
    average of the five terms.

    Both inputs are sliced and pooled as rank-4 tensors with the batch on
    axis 0 and the spatial dimensions on axes 1 and 2.

    NOTE(review): the weight variables belong to this layer rather than to the
    segmentation model; confirm that the training setup actually hands them to
    the optimizer, otherwise they stay at their initial values.
    """

    # Keys read from ``coefficient_config`` (and written back by ``get_config``).
    _COEFF_KEYS = ('core', 'boundary', 'structure', 'scale', 'texture')

    def __init__(self, coefficient_config, name='masl', **kwargs):
        """
        Initialize MASL with specific morphology coefficients.

        Args:
            coefficient_config: Mapping with the keys 'core', 'boundary',
                'structure', 'scale' and 'texture'; each value is converted
                with ``float``. Other keys are ignored here.
            name: Layer name forwarded to ``Layer.__init__``.
            **kwargs: Extra keyword arguments forwarded to ``Layer.__init__``.

        Raises:
            KeyError: If one of the five coefficient keys is missing.
        """
        super().__init__(name=name, **kwargs)
        self.epsilon = 1e-6

        # 🔥 STORE CONFIGURABLE COEFFICIENTS
        self.coeff_core = float(coefficient_config['core'])
        self.coeff_boundary = float(coefficient_config['boundary'])
        self.coeff_structure = float(coefficient_config['structure'])
        self.coeff_scale = float(coefficient_config['scale'])
        self.coeff_texture = float(coefficient_config['texture'])

        print(f"\n🔥 MASL Initialized with coefficients:")
        print(f"   Core:      {self.coeff_core:.2f}")
        print(f"   Boundary:  {self.coeff_boundary:.2f}")
        print(f"   Structure: {self.coeff_structure:.2f}")
        print(f"   Scale:     {self.coeff_scale:.2f}")
        print(f"   Texture:   {self.coeff_texture:.2f}")

    def build(self, input_shape):
        """Create the five scalar term weights, each constrained to [0.1, 10.0].

        ``input_shape`` is only forwarded to ``Layer.build``; the weights are
        scalars and do not depend on it.
        """
        clip_constraint = ClipConstraint(min_value=0.1, max_value=10.0)

        self.w_region = self.add_weight(name='w_region', shape=(), initializer=tf.constant_initializer(1.0),
            trainable=True, constraint=clip_constraint)
        self.w_boundary = self.add_weight(name='w_boundary', shape=(), initializer=tf.constant_initializer(1.0),
            trainable=True, constraint=clip_constraint)
        self.w_structure = self.add_weight(name='w_structure', shape=(), initializer=tf.constant_initializer(1.0),
            trainable=True, constraint=clip_constraint)
        self.w_scale = self.add_weight(name='w_scale', shape=(), initializer=tf.constant_initializer(0.5),
            trainable=True, constraint=clip_constraint)
        self.w_texture = self.add_weight(name='w_texture', shape=(), initializer=tf.constant_initializer(0.5),
            trainable=True, constraint=clip_constraint)
        super().build(input_shape)

    def morphological_dilation(self, x, kernel_size=5):
        """Grey-scale dilation implemented as a stride-1 max pool with SAME padding."""
        return tf.nn.max_pool2d(x, kernel_size, strides=1, padding='SAME')

    def morphological_erosion(self, x, kernel_size=5):
        """Grey-scale erosion implemented as a negated max pool of the negated input."""
        return -tf.nn.max_pool2d(-x, kernel_size, strides=1, padding='SAME')

    def detect_boundary(self, mask, kernel_size=5):
        """Return the morphological gradient (dilation minus erosion), clipped to [0, 1]."""
        dilated = self.morphological_dilation(mask, kernel_size)
        eroded = self.morphological_erosion(mask, kernel_size)
        boundary = dilated - eroded
        return tf.clip_by_value(boundary, 0.0, 1.0)

    def _soft_perimeter(self, mask):
        """Per-sample sum of forward-difference gradient magnitudes of ``mask``.

        Differences along axes 1 and 2 are zero-padded at the far edge so both
        keep the input's spatial size; epsilon keeps the square root and the
        returned sum strictly positive.
        """
        dy = mask[:, 1:, :, :] - mask[:, :-1, :, :]
        dx = mask[:, :, 1:, :] - mask[:, :, :-1, :]
        dy_padded = tf.pad(dy, [[0, 0], [0, 1], [0, 0], [0, 0]])
        dx_padded = tf.pad(dx, [[0, 0], [0, 0], [0, 1], [0, 0]])
        gradient_mag = tf.sqrt(dy_padded**2 + dx_padded**2 + self.epsilon)
        return tf.reduce_sum(gradient_mag, axis=[1, 2, 3]) + self.epsilon

    def analyze_structure_characteristics(self, y_true):
        """Summarise the ground-truth morphology as four batch-averaged scalars.

        Returns:
            Dict with 'tubularity', 'compactness', 'irregularity' and
            'object_size', each clipped to [0, 1].
        """
        area = tf.reduce_sum(y_true, axis=[1, 2, 3]) + self.epsilon
        total_pixels = tf.cast(tf.shape(y_true)[1] * tf.shape(y_true)[2], tf.float32)

        perimeter = self._soft_perimeter(y_true)

        # A 3x3 erosion stands in for a skeleton; its area relative to the
        # mask area is what the code calls tubularity.
        skeleton_approx = self.morphological_erosion(y_true, kernel_size=3)
        skeleton_area = tf.reduce_sum(skeleton_approx, axis=[1, 2, 3]) + self.epsilon

        tubularity = tf.reduce_mean(skeleton_area / (area + self.epsilon))
        compactness = tf.reduce_mean((4 * 3.14159 * area) / (perimeter**2 + self.epsilon))
        compactness = tf.clip_by_value(compactness, 0.0, 1.0)

        # Irregularity: mean absolute second difference of the boundary map.
        boundary = self.detect_boundary(y_true, kernel_size=5)
        ddy = boundary[:, 2:, :, :] - 2*boundary[:, 1:-1, :, :] + boundary[:, :-2, :, :]
        ddx = boundary[:, :, 2:, :] - 2*boundary[:, :, 1:-1, :] + boundary[:, :, :-2, :]
        irregularity = tf.reduce_mean(tf.abs(ddy)) + tf.reduce_mean(tf.abs(ddx))

        object_size = tf.reduce_mean(area / total_pixels)

        return {
            'tubularity': tf.clip_by_value(tubularity, 0.0, 1.0),
            'compactness': compactness,
            'irregularity': tf.clip_by_value(irregularity, 0.0, 1.0),
            'object_size': tf.clip_by_value(object_size, 0.0, 1.0)
        }

    def core_loss(self, y_true, y_pred):
        """Return 0.4 * Dice loss + 0.3 * IoU loss + 0.3 * boundary-weighted BCE."""
        intersection = tf.reduce_sum(y_true * y_pred, axis=[1, 2, 3])
        dice = (2. * intersection + self.epsilon) / (
            tf.reduce_sum(y_true, axis=[1, 2, 3]) + tf.reduce_sum(y_pred, axis=[1, 2, 3]) + self.epsilon)
        dice_loss = 1.0 - tf.reduce_mean(dice)

        union = (tf.reduce_sum(y_true, axis=[1, 2, 3]) + tf.reduce_sum(y_pred, axis=[1, 2, 3]) - intersection)
        iou = (intersection + self.epsilon) / (union + self.epsilon)
        iou_loss = 1.0 - tf.reduce_mean(iou)

        # Pixels on the ground-truth boundary count up to six times as much.
        boundary = self.detect_boundary(y_true, kernel_size=5)
        weights = 1.0 + 5.0 * boundary
        bce = -(y_true * tf.math.log(y_pred + self.epsilon) + (1 - y_true) * tf.math.log(1 - y_pred + self.epsilon))
        weighted_bce = tf.reduce_mean(weights * bce)

        return 0.4 * dice_loss + 0.3 * iou_loss + 0.3 * weighted_bce

    def boundary_loss(self, y_true, y_pred):
        """L1 distance between finite differences of target and prediction.

        Differences are taken at offsets 1, 2 and 4 along both spatial axes and
        combined with weights 0.5, 0.3 and 0.2.
        """
        total_loss = 0.0
        weights = [0.5, 0.3, 0.2]

        for scale, w in zip([1, 2, 4], weights):
            dy_true = y_true[:, scale:, :, :] - y_true[:, :-scale, :, :]
            dy_pred = y_pred[:, scale:, :, :] - y_pred[:, :-scale, :, :]
            dx_true = y_true[:, :, scale:, :] - y_true[:, :, :-scale, :]
            dx_pred = y_pred[:, :, scale:, :] - y_pred[:, :, :-scale, :]
            total_loss += w * (tf.reduce_mean(tf.abs(dy_true - dy_pred)) + tf.reduce_mean(tf.abs(dx_true - dx_pred)))

        return total_loss

    def structure_aware_loss(self, y_true, y_pred, characteristics):
        """Mean absolute gap in area / perimeter**2, scaled by the compactness statistic."""
        area_true = tf.reduce_sum(y_true, axis=[1, 2, 3]) + self.epsilon
        perimeter_true = self._soft_perimeter(y_true)

        area_pred = tf.reduce_sum(y_pred, axis=[1, 2, 3]) + self.epsilon
        perimeter_pred = self._soft_perimeter(y_pred)

        compact_true = area_true / (perimeter_true**2 + self.epsilon)
        compact_pred = area_pred / (perimeter_pred**2 + self.epsilon)

        return characteristics['compactness'] * tf.reduce_mean(tf.abs(compact_true - compact_pred))

    def scale_aware_focal_loss(self, y_true, y_pred, characteristics):
        """Focal BCE whose exponent depends on the relative object size.

        gamma is 3.0 below 5% of the image, 2.0 below 20%, and 1.5 otherwise.
        """
        size = characteristics['object_size']
        gamma = tf.cond(size < 0.05, lambda: 3.0, lambda: tf.cond(size < 0.2, lambda: 2.0, lambda: 1.5))

        # p is the probability the prediction assigns to the true class.
        p = y_true * y_pred + (1 - y_true) * (1 - y_pred)
        focal_weight = tf.pow(1 - p, gamma)
        bce = -(y_true * tf.math.log(y_pred + self.epsilon) + (1 - y_true) * tf.math.log(1 - y_pred + self.epsilon))

        return tf.reduce_mean(focal_weight * bce)

    def texture_aware_loss(self, y_true, y_pred):
        """L1 distance between second differences of target and prediction."""
        ddy_true = y_true[:, 2:, :, :] - 2*y_true[:, 1:-1, :, :] + y_true[:, :-2, :, :]
        ddy_pred = y_pred[:, 2:, :, :] - 2*y_pred[:, 1:-1, :, :] + y_pred[:, :-2, :, :]
        ddx_true = y_true[:, :, 2:, :] - 2*y_true[:, :, 1:-1, :] + y_true[:, :, :-2, :]
        ddx_pred = y_pred[:, :, 2:, :] - 2*y_pred[:, :, 1:-1, :] + y_pred[:, :, :-2, :]

        return tf.reduce_mean(tf.abs(ddy_true - ddy_pred)) + tf.reduce_mean(tf.abs(ddx_true - ddx_pred))

    def call(self, y_true, y_pred):
        """Return the scalar MASL value for one batch.

        Both inputs are cast to float32. The five terms are combined as
        ``sum(w_i * alpha_i * l_i) / (sum(w_i * alpha_i) + epsilon)``.
        """
        y_true = tf.cast(y_true, tf.float32)
        y_pred = tf.cast(y_pred, tf.float32)

        characteristics = self.analyze_structure_characteristics(y_true)

        # 🔥 USE CONFIGURABLE COEFFICIENTS FOR MODULATION
        alpha_region = 1.0 + self.coeff_core * characteristics['compactness']
        # NOTE(review): only tubularity is multiplied by coeff_boundary here;
        # compactness is added unscaled. Confirm that this is intended.
        alpha_boundary = 1.0 + self.coeff_boundary * characteristics['tubularity'] + characteristics['compactness']
        alpha_structure = 1.0 + self.coeff_structure * characteristics['tubularity']
        alpha_scale = 1.0 + self.coeff_scale * characteristics['irregularity']
        alpha_texture = 1.0 + self.coeff_texture * characteristics['irregularity']

        l_core = self.core_loss(y_true, y_pred)
        l_boundary = self.boundary_loss(y_true, y_pred)
        l_structure = self.structure_aware_loss(y_true, y_pred, characteristics)
        l_scale = self.scale_aware_focal_loss(y_true, y_pred, characteristics)
        l_texture = self.texture_aware_loss(y_true, y_pred)

        weighted_core = self.w_region * alpha_region * l_core
        weighted_boundary = self.w_boundary * alpha_boundary * l_boundary
        weighted_structure = self.w_structure * alpha_structure * l_structure
        weighted_scale = self.w_scale * alpha_scale * l_scale
        weighted_texture = self.w_texture * alpha_texture * l_texture

        total_weight = (self.w_region * alpha_region + self.w_boundary * alpha_boundary +
                       self.w_structure * alpha_structure + self.w_scale * alpha_scale + self.w_texture * alpha_texture)

        masl_loss = (weighted_core + weighted_boundary + weighted_structure + weighted_scale + weighted_texture) / (total_weight + self.epsilon)

        return masl_loss

    def get_config(self):
        """Return a config from which ``from_config`` can rebuild this layer.

        The coefficients are stored twice: nested under 'coefficient_config'
        (the argument ``__init__`` requires) and under the flat 'coeff_*' keys
        that earlier versions of this method produced.
        """
        config = super().get_config()
        coefficients = {key: float(getattr(self, f'coeff_{key}')) for key in self._COEFF_KEYS}
        config['coefficient_config'] = coefficients
        config.update({f'coeff_{key}': value for key, value in coefficients.items()})
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild the layer from ``get_config`` output.

        The flat 'coeff_*' keys are removed before calling ``__init__``, which
        does not accept them; they are used as the coefficient source only when
        'coefficient_config' is absent.
        """
        config = dict(config)
        legacy = {key: config.pop(f'coeff_{key}', None) for key in cls._COEFF_KEYS}
        coefficient_config = config.pop('coefficient_config', None)
        if coefficient_config is None:
            coefficient_config = legacy
        return cls(coefficient_config, **config)

# Module-level MASL instance read by masl_loss_fn(). initialize_masl() replaces
# it, so every training run that calls initialize_masl() gets a fresh instance.
_masl_instance = None

def masl_loss_fn(y_true, y_pred):
    """Evaluate the module-level MASL instance on one (target, prediction) pair.

    Raises:
        RuntimeError: If no MASL instance has been installed yet.
    """
    masl = _masl_instance
    if masl is not None:
        return masl(y_true, y_pred)
    raise RuntimeError("MASL instance not initialized! Call initialize_masl() first.")

def initialize_masl(coefficient_config):
    """Create a fresh MASL layer, install it as the module-level instance and return it."""
    global _masl_instance
    masl = _masl_instance = MorphologyAwareAdaptiveLoss(coefficient_config)
    # Build right away with a placeholder shape so the layer's weights exist
    # before the loss is first called.
    placeholder_shape = (None, 256, 256, 1)
    masl.build(placeholder_shape)
    return masl

# ==============================================================================
# METRICS
# ==============================================================================

def dice_coefficient(y_true, y_pred, smooth=1e-6):
    """Soft Dice score over all elements of both tensors, flattened together.

    ``smooth`` is added to numerator and denominator.
    """
    truth, pred = K.flatten(y_true), K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    total = K.sum(truth) + K.sum(pred)
    return (2. * overlap + smooth) / (total + smooth)

def iou_score(y_true, y_pred, smooth=1e-6):
    """Soft intersection-over-union over all elements of both tensors.

    ``smooth`` is added to numerator and denominator.
    """
    truth, pred = K.flatten(y_true), K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    combined = K.sum(truth) + K.sum(pred) - overlap
    return (overlap + smooth) / (combined + smooth)

def precision_metric(y_true, y_pred):
    """Precision of the prediction after thresholding it at 0.5.

    ``K.epsilon()`` in the denominator avoids dividing by zero when nothing is
    predicted positive.
    """
    hard_pred = K.cast(y_pred > 0.5, tf.float32)
    hits = K.sum(y_true * hard_pred)
    return hits / (K.sum(hard_pred) + K.epsilon())

def recall_metric(y_true, y_pred):
    """Recall of the prediction after thresholding it at 0.5.

    ``K.epsilon()`` in the denominator avoids dividing by zero when the target
    contains no positives.
    """
    hard_pred = K.cast(y_pred > 0.5, tf.float32)
    hits = K.sum(y_true * hard_pred)
    return hits / (K.sum(y_true) + K.epsilon())

# ==============================================================================
# PATCH-BASED PREDICTION
# ==============================================================================

def _predict_patches_single(model, image, cfg):
    """Helper: Predict full image using overlapping patches (no TTA).

    Windows of ``cfg.PATCH_SIZE`` are placed every ``cfg.STRIDE`` pixels, plus
    one extra window aligned to the bottom/right edge whenever the regular grid
    would leave border pixels uncovered. Overlapping predictions are averaged.
    An image smaller than one patch is zero-padded to patch size for
    prediction and the result is cropped back.

    Args:
        model: Object exposing ``predict(batch, verbose=0)`` that returns one
            map per patch; channel 0 of each map is used.
        image: Array whose first two axes are height and width.
        cfg: Object providing ``PATCH_SIZE`` and ``STRIDE``.

    Returns:
        float32 array of shape ``(height, width)``.
    """
    patch = cfg.PATCH_SIZE
    stride = cfg.STRIDE
    h, w = image.shape[:2]

    # Guarantee that at least one full window fits along each axis.
    pad_h, pad_w = max(patch - h, 0), max(patch - w, 0)
    if pad_h or pad_w:
        pad_spec = [(0, pad_h), (0, pad_w)] + [(0, 0)] * (image.ndim - 2)
        image = np.pad(image, pad_spec, mode='constant')
    full_h, full_w = image.shape[:2]

    def window_starts(length):
        starts = list(range(0, length - patch + 1, stride))
        last = length - patch
        # Without this extra window the trailing (length - patch) % stride
        # pixels would never be predicted and would come back as zeros.
        if starts[-1] != last:
            starts.append(last)
        return starts

    windows = [(y, x) for y in window_starts(full_h) for x in window_starts(full_w)]

    prediction = np.zeros((full_h, full_w), dtype=np.float32)
    counts = np.zeros((full_h, full_w), dtype=np.float32)

    # Predict several windows per call instead of one call per window; the
    # chunk size bounds the memory of the stacked batch.
    windows_per_call = 32
    for begin in range(0, len(windows), windows_per_call):
        group = windows[begin:begin + windows_per_call]
        batch = np.stack([image[y:y + patch, x:x + patch] for y, x in group])
        outputs = model.predict(batch, verbose=0)
        for (y, x), out in zip(group, outputs):
            prediction[y:y + patch, x:x + patch] += out[:, :, 0]
            counts[y:y + patch, x:x + patch] += 1

    prediction = prediction / (counts + 1e-6)
    return prediction[:h, :w]

def predict_full_image_with_patches(model, image, cfg, use_tta=False):
    """Predict on full-resolution image using overlapping patches.

    Args:
        model: Forwarded to ``_predict_patches_single``.
        image: Array whose first two axes are height and width.
        cfg: Forwarded to ``_predict_patches_single``.
        use_tta: When true, average the predictions obtained under the eight
            distinct flip/rotation symmetries of the image grid, each mapped
            back to the original orientation.

    Returns:
        Prediction map in the orientation of ``image``.
    """
    if not use_tta:
        return _predict_patches_single(model, image, cfg)

    def identity(a):
        return a

    def hflip(a):
        return np.flip(a, axis=1)

    def vflip(a):
        return np.flip(a, axis=0)

    def rot(k):
        return lambda a: np.rot90(a, k, axes=(0, 1))

    def transpose_hw(a):
        return np.swapaxes(a, 0, 1)

    def anti_transpose_hw(a):
        # Reflection about the anti-diagonal: transpose, then flip both axes.
        return np.flip(np.swapaxes(a, 0, 1), axis=(0, 1))

    # (forward, inverse) pairs. Flipping both axes is the same as a 180°
    # rotation, so it is not listed separately; the anti-diagonal reflection
    # is the eighth distinct view.
    transforms = [
        (identity, identity),
        (hflip, hflip),
        (vflip, vflip),
        (rot(1), rot(-1)),
        (rot(2), rot(-2)),
        (rot(3), rot(-3)),
        (transpose_hw, transpose_hw),
        (anti_transpose_hw, anti_transpose_hw),
    ]

    predictions = [
        inverse(_predict_patches_single(model, forward(image), cfg))
        for forward, inverse in transforms
    ]
    return np.mean(predictions, axis=0)

# ==============================================================================
# VISUALIZATION FUNCTIONS
# ==============================================================================

def save_prediction_image(original_image, ground_truth, prediction, save_path, image_name):
    """Write a four-panel figure: input, ground truth, prediction and green overlay.

    The overlay is a copy of ``original_image`` in which every pixel whose
    prediction exceeds 0.5 is set to ``[0, 255, 0]``. The figure is saved to
    ``save_path`` at 150 dpi and then closed.
    """
    fig, axes = plt.subplots(1, 4, figsize=(20, 5))

    overlay = original_image.copy()
    overlay[(prediction > 0.5).astype(np.uint8) == 1] = [0, 255, 0]

    panels = [
        (original_image, 'Original Image', {}),
        (ground_truth, 'Ground Truth', {'cmap': 'gray'}),
        (prediction, 'Prediction', {'cmap': 'gray'}),
        (overlay, 'Overlay (Green=Prediction)', {}),
    ]
    for ax, (picture, title, imshow_kwargs) in zip(axes, panels):
        ax.imshow(picture, **imshow_kwargs)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.axis('off')

    plt.suptitle(f'{image_name}', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()

def save_individual_predictions(prediction, save_dir, image_name):
    """Write the prediction to ``save_dir`` as two 8-bit PNG files.

    ``<image_name>_prediction.png`` holds the prediction scaled by 255;
    ``<image_name>_prediction_binary.png`` holds the 0.5-thresholded mask as
    0 / 255.
    """
    outputs = {
        f"{image_name}_prediction.png": (prediction * 255).astype(np.uint8),
        f"{image_name}_prediction_binary.png": ((prediction > 0.5) * 255).astype(np.uint8),
    }
    for filename, pixels in outputs.items():
        cv2.imwrite(os.path.join(save_dir, filename), pixels)

# ==============================================================================
# EVALUATION FUNCTIONS
# ==============================================================================

def evaluate_test_set(model, test_pairs, cfg, use_tta=False):
    """Evaluate model on test set.

    Each image is preprocessed, predicted patch-wise, thresholded at 0.5 and
    compared with its mask (binarised at > 127). Pairs whose image or mask
    cannot be read are reported and skipped.

    Args:
        model: Forwarded to ``predict_full_image_with_patches``.
        test_pairs: Iterable of ``(image_path, mask_path)``.
        cfg: Forwarded to the preprocessing, FOV-masking and prediction helpers.
        use_tta: Forwarded to ``predict_full_image_with_patches``.

    Returns:
        Dict mapping 'dice', 'iou', 'precision' and 'recall' to
        ``{'mean': float, 'std': float}``. Both values are NaN when no pair
        could be evaluated.
    """
    scores = {'dice': [], 'iou': [], 'precision': [], 'recall': []}

    for img_path, mask_path in test_pairs:
        image_original = cv2.imread(img_path)
        if image_original is None:
            print(f"⚠️ Skipping unreadable image: {img_path}")
            continue

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            # Without this guard the comparison below fails on None.
            print(f"⚠️ Skipping pair with unreadable mask: {mask_path}")
            continue

        image_original_rgb = cv2.cvtColor(image_original, cv2.COLOR_BGR2RGB)
        mask = (mask > 127).astype(np.float32)

        image_preprocessed = preprocess_image_clahe(image_original_rgb, cfg)
        image_preprocessed, mask_masked = apply_fov_mask(image_preprocessed, mask, cfg)

        pred = predict_full_image_with_patches(model, image_preprocessed, cfg, use_tta=use_tta)
        pred_binary = (pred > 0.5).astype(np.float32)

        intersection = np.sum(mask_masked * pred_binary)
        truth_sum = np.sum(mask_masked)
        pred_sum = np.sum(pred_binary)
        union = truth_sum + pred_sum - intersection

        scores['dice'].append((2.0 * intersection) / (truth_sum + pred_sum + 1e-6))
        scores['iou'].append(intersection / (union + 1e-6))
        scores['precision'].append(intersection / (pred_sum + 1e-6))
        scores['recall'].append(intersection / (truth_sum + 1e-6))

    def summarize(values):
        # np.mean of an empty list is NaN plus a RuntimeWarning; say so explicitly.
        if not values:
            return {'mean': float('nan'), 'std': float('nan')}
        return {'mean': float(np.mean(values)), 'std': float(np.std(values))}

    return {metric: summarize(values) for metric, values in scores.items()}

# ==============================================================================
# LEARNING RATE SCHEDULER
# ==============================================================================

def cosine_annealing_with_warmup(epoch, lr, total_epochs=150, warmup_epochs=10, min_lr=1e-6):
    """Cosine annealing with warmup.

    Args:
        epoch: Zero-based epoch index.
        lr: Peak learning rate, reached at the end of the warmup.
        total_epochs: Epoch at which the schedule reaches ``min_lr``.
        warmup_epochs: Number of epochs of linear ramp-up towards ``lr``.
        min_lr: Final learning rate.

    Returns:
        Learning rate for ``epoch`` as a float. Epochs at or beyond
        ``total_epochs`` (or any post-warmup epoch when there is no decay
        phase) yield ``min_lr``.
    """
    if epoch < warmup_epochs:
        return lr * (epoch + 1) / warmup_epochs

    decay_epochs = total_epochs - warmup_epochs
    if decay_epochs <= 0:
        # No room for a decay phase: avoid dividing by zero and go to the floor.
        progress = 1.0
    else:
        # Clamp so the cosine cannot wrap around and raise the rate again.
        progress = min((epoch - warmup_epochs) / decay_epochs, 1.0)
    return float(min_lr + (lr - min_lr) * (1 + np.cos(np.pi * progress)) / 2)

# ==============================================================================
# 🔥 SINGLE TRAINING RUN (FOR COEFFICIENT STUDY)
# ==============================================================================

def train_single_config(cfg, strategy, num_gpus, coeff_config, train_patches, val_patches, test_pairs):
    """Train model with specific coefficient configuration.

    Args:
        cfg: Run configuration. This function reads LEARNING_RATE, EPOCHS,
            SAVE_DIR, CHECKPOINT_MONITOR, CHECKPOINT_MODE and
            EARLY_STOPPING_PATIENCE and forwards ``cfg`` to the model builder,
            the generators and the evaluation.
        strategy: Object whose ``scope()`` wraps model construction and compile.
        num_gpus: Not used by this function.
        coeff_config: Coefficient dictionary; its 'name' and 'description' are
            printed and the whole dict is passed to ``initialize_masl``.
        train_patches: Passed to the training ``PatchBasedGenerator``.
        val_patches: Passed to the validation ``PatchBasedGenerator``.
        test_pairs: Passed to ``evaluate_test_set``.

    Returns:
        Dict with 'test_results', 'training_time_minutes', 'final_val_dice',
        'learned_weights' and 'final_dice'.

    The Keras session is cleared on every exit path, including exceptions, so a
    failed configuration does not leave its state behind for the next one.
    """
    print(f"\n🔥 Training with: {coeff_config['name']}")
    print(f"   {coeff_config['description']}")

    # Initialize MASL with specific coefficients
    masl = initialize_masl(coeff_config)

    model = None
    try:
        # Create generators
        train_aug = get_patch_augmentation(cfg)
        train_gen = PatchBasedGenerator(train_patches, cfg, augmentation=train_aug, shuffle=True)
        val_gen = PatchBasedGenerator(val_patches, cfg, augmentation=None, shuffle=False)

        # Build model
        with strategy.scope():
            model = build_medsegnet_ssf(cfg)
            optimizer = tf.keras.optimizers.Adam(learning_rate=cfg.LEARNING_RATE, clipnorm=1.0)

            model.compile(
                optimizer=optimizer,
                loss=masl_loss_fn,
                metrics=[dice_coefficient, iou_score, precision_metric, recall_metric]
            )

        # Callbacks
        lr_scheduler = LearningRateScheduler(
            lambda epoch: cosine_annealing_with_warmup(epoch, cfg.LEARNING_RATE, cfg.EPOCHS, warmup_epochs=10),
            verbose=0
        )

        callbacks = [
            ModelCheckpoint(
                os.path.join(cfg.SAVE_DIR, "best_model.h5"),
                monitor=cfg.CHECKPOINT_MONITOR, mode=cfg.CHECKPOINT_MODE,
                save_best_only=True, verbose=0
            ),
            EarlyStopping(
                monitor=cfg.CHECKPOINT_MONITOR, mode=cfg.CHECKPOINT_MODE,
                patience=cfg.EARLY_STOPPING_PATIENCE, verbose=0, restore_best_weights=True
            ),
            lr_scheduler
        ]

        # Train
        start_time = time.time()
        history = model.fit(
            train_gen,
            validation_data=val_gen,
            epochs=cfg.EPOCHS,
            callbacks=callbacks,
            verbose=0  # Suppress training output for cleaner study logs
        )
        training_time = time.time() - start_time

        # Evaluate
        print("\n📊 Evaluating...")
        test_results = evaluate_test_set(model, test_pairs, cfg, use_tta=True)

        print(f"   Dice: {test_results['dice']['mean']:.4f} ± {test_results['dice']['std']:.4f}")
        print(f"   IoU:  {test_results['iou']['mean']:.4f} ± {test_results['iou']['std']:.4f}")

        # Read the MASL weight variables as they stand after training.
        # NOTE(review): confirm the optimizer actually updates these variables;
        # they are created on the loss layer, not on the model.
        final_weights = {
            'w_region': float(masl.w_region.numpy()),
            'w_boundary': float(masl.w_boundary.numpy()),
            'w_structure': float(masl.w_structure.numpy()),
            'w_scale': float(masl.w_scale.numpy()),
            'w_texture': float(masl.w_texture.numpy())
        }

        return {
            'test_results': test_results,
            'training_time_minutes': training_time / 60,
            'final_val_dice': float(max(history.history['val_dice_coefficient'])),
            'learned_weights': final_weights,
            'final_dice': test_results['dice']['mean']
        }
    finally:
        # Clean up on success and on failure. K is tensorflow.keras.backend,
        # so a single clear_session() call is enough.
        model = None
        K.clear_session()

# ==============================================================================
# 🔥 COEFFICIENT PERTURBATION STUDY
# ==============================================================================

def run_coefficient_study(cfg, strategy, num_gpus):
    """Run comprehensive coefficient perturbation study.

    Loads the dataset splits and extracts patches once, then trains and
    evaluates one model per coefficient configuration selected by
    ``cfg.COEFFICIENT_STUDY_TYPE`` ('perturbation', 'cross_dataset', 'all';
    anything else runs the baseline only). Each configuration writes into its
    own sub-directory of ``cfg.SAVE_DIR``. A configuration that raises is
    recorded as failed and the study continues.

    Args:
        cfg: Run configuration.
        strategy: Forwarded to ``train_single_config``.
        num_gpus: Forwarded to ``train_single_config``.

    Returns:
        Dict mapping configuration name to a record with 'config',
        'training_successful' and either 'results' or 'error'.
    """

    print("\n" + "="*80)
    print("🔥 COEFFICIENT PERTURBATION STUDY")
    print("="*80)

    # Load data once
    train_pairs = load_dataset_split(cfg.TRAIN_DIR)
    val_pairs = load_dataset_split(cfg.VAL_DIR)
    test_pairs = load_dataset_split(cfg.TEST_DIR)

    print(f"\n📊 Dataset: {len(train_pairs)} train, {len(val_pairs)} val, {len(test_pairs)} test")

    # Extract patches once
    print("\n🔍 Extracting patches...")
    train_patches = extract_all_patches(train_pairs, cfg)
    val_patches = extract_all_patches(val_pairs, cfg)

    # Determine which configs to test
    if cfg.COEFFICIENT_STUDY_TYPE == 'perturbation':
        configs_to_test = MorphologyCoefficients.get_perturbation_configs()
        print("   Study Type: PERTURBATION (±40% variations)")
    elif cfg.COEFFICIENT_STUDY_TYPE == 'cross_dataset':
        configs_to_test = MorphologyCoefficients.get_cross_dataset_configs()
        print("   Study Type: CROSS-DATASET (different anatomies)")
    elif cfg.COEFFICIENT_STUDY_TYPE == 'all':
        configs_to_test = MorphologyCoefficients.get_all_configs()
        print("   Study Type: COMPREHENSIVE (all configurations)")
    else:
        configs_to_test = {'baseline': MorphologyCoefficients.BASELINE}
        print("   Study Type: SINGLE CONFIG")

    print(f"   Total configurations: {len(configs_to_test)}")
    print("="*80)

    all_results = {}

    for position, (config_name, coeff_config) in enumerate(configs_to_test.items(), start=1):
        print(f"\n{'='*80}")
        print(f"🔥 TESTING [{position}/{len(configs_to_test)}]: {coeff_config['name']}")
        print(f"   Coefficients: {{{coeff_config['core']}, {coeff_config['boundary']}, "
              f"{coeff_config['structure']}, {coeff_config['scale']}, {coeff_config['texture']}}}")
        print(f"{'='*80}")

        # Update save directory
        cfg_copy = deepcopy(cfg)
        cfg_copy.SAVE_DIR = os.path.join(cfg.SAVE_DIR, config_name)
        os.makedirs(cfg_copy.SAVE_DIR, exist_ok=True)

        # Train and evaluate
        try:
            # Reseed before every configuration, as the standard training loop
            # does once, so runs start from the same random state and their
            # scores differ because of the coefficients only.
            set_seed(cfg.SEED, cfg.DETERMINISTIC)

            results = train_single_config(
                cfg_copy, strategy, num_gpus, coeff_config,
                train_patches, val_patches, test_pairs
            )

            all_results[config_name] = {
                'config': coeff_config,
                'results': results,
                'training_successful': True
            }

            print(f"\n✅ {config_name} completed successfully!")

        except Exception as e:
            # Per-configuration boundary: record the failure and keep going so
            # the remaining configurations still run.
            print(f"\n❌ {config_name} failed: {str(e)}")
            all_results[config_name] = {
                'config': coeff_config,
                'error': str(e),
                'training_successful': False
            }

    # Generate comparison report
    generate_coefficient_study_report(all_results, cfg)

    return all_results

# ==============================================================================
# 🔥 REPORT GENERATION
# ==============================================================================

def generate_coefficient_study_report(all_results, cfg):
    """Generate comprehensive comparison report.

    Prints a comparison table of all successful runs against a baseline (the
    run named 'baseline', or the first successful run when that is missing),
    optionally prints a perturbation analysis, and writes
    ``coefficient_study_report.json`` into ``cfg.SAVE_DIR``.

    Args:
        all_results: Dict mapping configuration name to a record with
            'training_successful', 'config' and, for successful runs,
            'results' (containing 'final_dice', 'test_results' and optionally
            'learned_weights').
        cfg: Object providing ``SAVE_DIR``.

    Returns:
        None. Nothing is written when no run was successful.
    """

    print("\n" + "="*80)
    print("📊 COEFFICIENT PERTURBATION STUDY REPORT")
    print("="*80)

    # Extract successful results
    successful_results = {k: v for k, v in all_results.items() if v.get('training_successful', False)}

    if not successful_results:
        print("❌ No successful training runs!")
        return

    # Find baseline
    baseline_name = 'baseline'
    if baseline_name not in successful_results:
        baseline_name = next(iter(successful_results))

    baseline_dice = successful_results[baseline_name]['results']['final_dice']

    def percent_of_baseline(value, spec):
        # A baseline Dice of exactly 0 makes relative changes undefined.
        if baseline_dice == 0:
            return "n/a"
        return f"{format(value / baseline_dice * 100, spec)}%"

    # Create comparison table
    print(f"\n📊 PERFORMANCE COMPARISON (Baseline: {baseline_name})")
    print("="*80)
    print(f"{'Config':<25} {'Coefficients':<30} {'Dice':<12} {'Δ Dice':<12}")
    print("-"*80)

    coefficient_keys = ('core', 'boundary', 'structure', 'scale', 'texture')

    for config_name, data in successful_results.items():
        coeffs = data['config']
        dice = data['results']['final_dice']
        dice_std = data['results']['test_results']['dice']['std']
        delta = dice - baseline_dice

        # Show all five coefficients, matching what the study driver prints.
        coeff_str = "{" + ", ".join(f"{coeffs[key]:.1f}" for key in coefficient_keys) + "}"

        print(f"{config_name:<25} {coeff_str:<30} {dice:.4f}±{dice_std:.4f}  {delta:+.4f} ({percent_of_baseline(delta, '+.2f')})")

    print("="*80)

    # Perturbation analysis (if applicable): look at every run name, not just
    # the first one, which is typically the baseline.
    has_perturbation_runs = any('perturb_' in name for name in successful_results)
    if has_perturbation_runs or len(successful_results) > 4:
        print("\n🔍 PERTURBATION ANALYSIS")
        print("="*80)

        # Tracks the largest absolute Dice deviation from the baseline, so an
        # improvement of the same size counts as much as a drop.
        max_degradation = 0
        max_config = None

        for config_name, data in successful_results.items():
            if config_name == baseline_name:
                continue

            dice = data['results']['final_dice']
            degradation = baseline_dice - dice

            if abs(degradation) > abs(max_degradation):
                max_degradation = degradation
                max_config = config_name

        print(f"Maximum Degradation: {abs(max_degradation):.4f} ({percent_of_baseline(abs(max_degradation), '.2f')})")
        print(f"   Config: {max_config}")

        if abs(max_degradation) < 0.006:
            print("\n✅ EXCELLENT ROBUSTNESS: <0.6% degradation under ±40% perturbation")
        elif abs(max_degradation) < 0.01:
            print("\n✅ GOOD ROBUSTNESS: <1.0% degradation")
        else:
            print("\n⚠️ MODERATE SENSITIVITY: Consider coefficient tuning")

    # Save to JSON
    report_data = {
        'baseline_config': baseline_name,
        'baseline_dice': baseline_dice,
        'configs': {}
    }

    for config_name, data in successful_results.items():
        report_data['configs'][config_name] = {
            'coefficients': data['config'],
            'dice': data['results']['final_dice'],
            'dice_std': data['results']['test_results']['dice']['std'],
            'delta_vs_baseline': data['results']['final_dice'] - baseline_dice,
            'learned_weights': data['results'].get('learned_weights', {})
        }

    os.makedirs(cfg.SAVE_DIR, exist_ok=True)
    report_path = os.path.join(cfg.SAVE_DIR, "coefficient_study_report.json")
    with open(report_path, 'w', encoding='utf-8') as f:
        json.dump(report_data, f, indent=2)

    print(f"\n📁 Report saved to: {report_path}")
    print("="*80)

# ==============================================================================
# MAIN TRAINING LOOP (STANDARD - NON-STUDY)
# ==============================================================================

def train_masl_model(cfg, strategy, num_gpus):
    """Train and evaluate a single MASL configuration (standard, non-study run).

    Steps, in order: seed the run, load the train/val/test splits, extract
    patches from the train and val splits, initialize MASL from
    ``cfg.COEFFICIENT_CONFIG``, build and compile the model inside
    ``strategy.scope()``, fit it with checkpoint / early-stopping / CSV-log /
    LR-schedule callbacks, evaluate on the test pairs with TTA enabled, and
    write ``results.json`` into ``cfg.SAVE_DIR``.

    Args:
        cfg: Configuration object. Attributes read here: SEED, DETERMINISTIC,
            TRAIN_DIR, VAL_DIR, TEST_DIR, COEFFICIENT_CONFIG, LEARNING_RATE,
            EPOCHS, SAVE_DIR, CHECKPOINT_MONITOR, CHECKPOINT_MODE and
            EARLY_STOPPING_PATIENCE (helpers called with ``cfg`` may read more).
        strategy: Object whose ``scope()`` context manager wraps model
            construction and compilation.
        num_gpus: Not used by this function; kept so the signature matches
            the call sites.

    Returns:
        ``(model, history)`` on success. ``(None, None)`` when no training
        data is found, so callers that unpack two values do not crash.
    """

    set_seed(cfg.SEED, cfg.DETERMINISTIC)

    # Load data
    train_pairs = load_dataset_split(cfg.TRAIN_DIR)
    val_pairs = load_dataset_split(cfg.VAL_DIR)
    test_pairs = load_dataset_split(cfg.TEST_DIR)

    if not train_pairs:
        print("❌ No training data found!")
        # Keep the return shape consistent with the success path: callers
        # unpack ``model, history`` and a bare ``return`` would raise TypeError.
        return None, None

    print(f"\n📊 Dataset Statistics:")
    print(f"   Training images:   {len(train_pairs)}")
    print(f"   Validation images: {len(val_pairs)}")
    print(f"   Test images:       {len(test_pairs)}")

    # Extract patches
    print("\n🔥 EXTRACTING PATCHES...")
    train_patches = extract_all_patches(train_pairs, cfg)
    val_patches = extract_all_patches(val_pairs, cfg)

    # Initialize MASL (done before the model is compiled with masl_loss_fn)
    initialize_masl(cfg.COEFFICIENT_CONFIG)

    # Create generators: augmentation and shuffling only on the training side
    train_aug = get_patch_augmentation(cfg)
    train_gen = PatchBasedGenerator(train_patches, cfg, augmentation=train_aug, shuffle=True)
    val_gen = PatchBasedGenerator(val_patches, cfg, augmentation=None, shuffle=False)

    # Build model
    with strategy.scope():
        model = build_medsegnet_ssf(cfg)
        optimizer = tf.keras.optimizers.Adam(learning_rate=cfg.LEARNING_RATE, clipnorm=1.0)

        print("\n📊 Loss Function: MASL")
        print("   - Morphology-Aware Adaptive Segmentation Loss")
        print("   - 5 Components: Core + Boundary + Structure + Scale + Texture")

        model.compile(
            optimizer=optimizer,
            loss=masl_loss_fn,
            metrics=[dice_coefficient, iou_score, precision_metric, recall_metric]
        )

    # Callbacks
    lr_scheduler = LearningRateScheduler(
        lambda epoch: cosine_annealing_with_warmup(epoch, cfg.LEARNING_RATE, cfg.EPOCHS, warmup_epochs=10),
        verbose=1
    )

    callbacks = [
        ModelCheckpoint(
            os.path.join(cfg.SAVE_DIR, "best_model_patch_masl.h5"),
            monitor=cfg.CHECKPOINT_MONITOR, mode=cfg.CHECKPOINT_MODE,
            save_best_only=True, verbose=1
        ),
        EarlyStopping(
            monitor=cfg.CHECKPOINT_MONITOR, mode=cfg.CHECKPOINT_MODE,
            patience=cfg.EARLY_STOPPING_PATIENCE, verbose=1, restore_best_weights=True
        ),
        CSVLogger(os.path.join(cfg.SAVE_DIR, "training_log_patch_masl.csv")),
        lr_scheduler
    ]

    # Train
    print(f"\n🚀 STARTING TRAINING")
    print("="*80)
    start_time = time.time()

    history = model.fit(
        train_gen,
        validation_data=val_gen,
        epochs=cfg.EPOCHS,
        callbacks=callbacks,
        verbose=1
    )

    training_time = time.time() - start_time
    print(f"\n✅ Training finished in {training_time/60:.1f} minutes")

    # Evaluate
    print("\n📊 FINAL EVALUATION")
    print("="*80)
    test_results = evaluate_test_set(model, test_pairs, cfg, use_tta=True)

    print(f"\n✅ Test Dice: {test_results['dice']['mean']:.4f} ± {test_results['dice']['std']:.4f}")
    print(f"   Test IoU:  {test_results['iou']['mean']:.4f} ± {test_results['iou']['std']:.4f}")

    # Snapshot the MASL component weights as plain floats so they are
    # JSON-serializable.
    weight_names = ('w_region', 'w_boundary', 'w_structure', 'w_scale', 'w_texture')
    learned_weights = {
        name: float(getattr(_masl_instance, name).numpy())
        for name in weight_names
    }

    # Save results
    final_results = {
        "model": "MedSegNet-SSF",
        "loss_function": "MASL",
        "coefficient_config": cfg.COEFFICIENT_CONFIG,
        "training_time_minutes": training_time / 60,
        "test_results": test_results,
        "learned_weights": learned_weights
    }

    with open(os.path.join(cfg.SAVE_DIR, "results.json"), "w") as f:
        json.dump(final_results, f, indent=2)

    print(f"\n📁 Results saved to: {cfg.SAVE_DIR}/results.json")
    print("="*80)

    return model, history

# ==============================================================================
# MAIN EXECUTION
# ==============================================================================

# Script entry point: runs either the multi-configuration coefficient study or
# a single training run, depending on config.RUN_COEFFICIENT_STUDY.
if __name__ == "__main__":
    print("\n" + "="*80)
    print("🔥 MASL WITH COEFFICIENT PERTURBATION STUDY")
    print("="*80)
    
    if config.RUN_COEFFICIENT_STUDY:
        print("\n⚡ RUNNING COEFFICIENT STUDY")
        print(f"   Study Type: {config.COEFFICIENT_STUDY_TYPE.upper()}")
        print(f"   This will train multiple models with different coefficient sets")
        print("\n⚠️ Note: This may take several hours depending on the number of configurations")
        print("="*80 + "\n")
        
        all_results = run_coefficient_study(config, strategy, num_gpus)
        
        print("\n🎉 COEFFICIENT STUDY COMPLETE!")
        print("="*80)
        print("\n✅ Check coefficient_study_report.json for detailed comparison!")
        
    else:
        print("\n⚡ RUNNING SINGLE TRAINING")
        print(f"   Coefficient Config: {config.COEFFICIENT_CONFIG['name']}")
        print(f"   Coefficients: {{{config.COEFFICIENT_CONFIG['core']}, "
              f"{config.COEFFICIENT_CONFIG['boundary']}, {config.COEFFICIENT_CONFIG['structure']}}}")
        print("\n💡 Tip: Set RUN_COEFFICIENT_STUDY=True to compare multiple coefficient sets")
        print("="*80 + "\n")
        
        outcome = train_masl_model(config, strategy, num_gpus)
        
        # train_masl_model bails out early when no training data is found.
        # Unpacking that result blindly would raise TypeError (or falsely
        # report success), so stop here with a non-zero exit status instead.
        if outcome is None or outcome[0] is None:
            print("\n❌ TRAINING ABORTED: no model was produced (see messages above)")
            sys.exit(1)
        
        model, history = outcome
        
        print("\n🎉 TRAINING COMPLETE!")
        print("="*80)
        print("\n✅ Check results.json for detailed metrics!")
    
    print("\n🚀 All done! 🚀")