#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
================================================================================
🔥 COMPLETE MASL TRAINING SCRIPT WITH COMPREHENSIVE ABLATION STUDIES
================================================================================

Features:
- Full MASL implementation
- MedSegNet-SSF architecture  
- 23 Ablation Study Configurations
- Config class-based control
- DuckNet data generator with virtual epochs
- Enhanced visualizations
- Prediction saving with running metrics

Usage:
    # Edit Config class, set ABLATION_ID = 0, then run:
    python script.py

Author: Your Name
Date: 2024
================================================================================
"""

import os
import sys
import time
import json
from glob import glob
from pathlib import Path

import numpy as np
import cv2
import tensorflow as tf
from tensorflow.keras.layers import (
    Input, Conv2D, DepthwiseConv2D, GlobalAveragePooling2D,
    Reshape, Multiply, BatchNormalization, Activation, Dropout,
    Concatenate, Add, UpSampling2D, LayerNormalization, Dense, Layer
)
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2
from tensorflow.keras.callbacks import (
    ModelCheckpoint, EarlyStopping, CSVLogger, Callback
)
from tensorflow.keras import backend as K
import albumentations as A
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.patches import Rectangle, Patch
from scipy import ndimage
from scipy.stats import pearsonr
import pandas as pd

# Lower the verbosity of TensorFlow's native (C++) logging.
# NOTE(review): this assignment runs after `import tensorflow` above; the variable is
# normally exported before TensorFlow is imported - confirm it still takes effect here.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

# ==============================================================================
# 🔥 ABLATION STUDY CONFIGURATIONS
# ==============================================================================

def _build_ablation(name, description, *, components=None, settings=None,
                    architecture=None, bfp_routing="soft"):
    """Return one ablation entry: the full-model defaults plus the given overrides.

    Every call builds fresh nested dicts and a fresh ``sstm_stages`` list, so
    no two entries share mutable state. Override keys that are not part of the
    defaults (e.g. ``use_simple_dice``) are appended after the default keys.
    """
    arch = {
        "use_mrfse": True,
        "use_sstm": True,
        "use_bfp": True,
        "sstm_stages": [True, True, True, True, True],
        "sstm_k": 32,
        **(architecture or {}),
    }
    # Copy so that a caller-supplied list is never shared between entries.
    arch["sstm_stages"] = list(arch["sstm_stages"])

    return {
        "name": name,
        "description": description,
        "masl_components": {
            "use_core": True,
            "use_boundary": True,
            "use_structure": True,
            "use_scale": True,
            "use_texture": True,
            **(components or {}),
        },
        "masl_settings": {
            "use_learned_weights": True,
            "use_morphology_modulation": True,
            **(settings or {}),
        },
        "architecture": arch,
        "bfp_routing": bfp_routing,
    }


# Ablation ID -> configuration. Only the deviation from the full model
# (entry 0) is spelled out for each entry.
ABLATION_CONFIGS = {
    # ==================== BASELINE ====================
    0: _build_ablation(
        "FULL_MODEL_BASELINE",
        "Full Med-SegNet-SSF + Complete MASL (5 components + morphology)",
    ),

    # ==================== MASL COMPONENT ABLATIONS (1-5) ====================
    1: _build_ablation(
        "MASL_NO_CORE",
        "MASL without Core Loss (L_core)",
        components={"use_core": False},
    ),
    2: _build_ablation(
        "MASL_NO_BOUNDARY",
        "MASL without Boundary Loss (L_bnd)",
        components={"use_boundary": False},
    ),
    3: _build_ablation(
        "MASL_NO_STRUCTURE",
        "MASL without Structure Loss (L_str)",
        components={"use_structure": False},
    ),
    4: _build_ablation(
        "MASL_NO_SCALE",
        "MASL without Scale-Aware Focal Loss (L_sca)",
        components={"use_scale": False},
    ),
    5: _build_ablation(
        "MASL_NO_TEXTURE",
        "MASL without Texture Loss (L_tex)",
        components={"use_texture": False},
    ),

    # ==================== MASL ADAPTATION MECHANISM ABLATIONS (6-7) ====================
    6: _build_ablation(
        "MASL_FIXED_WEIGHTS",
        "MASL with morphology modulation BUT fixed learned weights",
        settings={"use_learned_weights": False},
    ),
    7: _build_ablation(
        "MASL_NO_MORPHOLOGY",
        "MASL with learned weights BUT no morphology modulation (α_i = 1)",
        settings={"use_morphology_modulation": False},
    ),

    # ==================== SSTM TRUNCATION SIZE ABLATION (8-11) ====================
    8: _build_ablation(
        "SSTM_K16",
        "SSTM with K=16 (lowest complexity)",
        architecture={"sstm_k": 16},
    ),
    9: _build_ablation(
        "SSTM_K24",
        "SSTM with K=24",
        architecture={"sstm_k": 24},
    ),
    10: _build_ablation(
        "SSTM_K48",
        "SSTM with K=48 (higher complexity)",
        architecture={"sstm_k": 48},
    ),
    11: _build_ablation(
        "SSTM_K64",
        "SSTM with K=64 (highest complexity)",
        architecture={"sstm_k": 64},
    ),

    # ==================== SSTM STAGE PLACEMENT ABLATION (12-15) ====================
    12: _build_ablation(
        "SSTM_EARLY_ONLY",
        "SSTM only at Stages 1-2 (early/high-resolution)",
        architecture={"sstm_stages": [True, True, False, False, False]},
    ),
    13: _build_ablation(
        "SSTM_MIDDLE_ONLY",
        "SSTM only at Stage 3 (middle resolution)",
        architecture={"sstm_stages": [False, False, True, False, False]},
    ),
    14: _build_ablation(
        "SSTM_LATE_ONLY",
        "SSTM only at Stages 4-5 (late/bottleneck)",
        architecture={"sstm_stages": [False, False, False, True, True]},
    ),
    15: _build_ablation(
        "NO_SSTM",
        "No SSTM at any stage (baseline)",
        architecture={
            "use_sstm": False,
            "sstm_stages": [False, False, False, False, False],
        },
    ),

    # ==================== BFP ROUTING VARIANTS (16-19) ====================
    16: _build_ablation(
        "BFP_HARD_ROUTING",
        "BFP with Hard routing (threshold β at 0.5)",
        bfp_routing="hard",
    ),
    17: _build_ablation(
        "BFP_NO_ROUTING",
        "BFP without routing (just concatenate region & boundary)",
        bfp_routing="none",
    ),
    18: _build_ablation(
        "BFP_LEARNED_ROUTING",
        "BFP with learned routing weights (instead of β map)",
        bfp_routing="learned",
    ),
    19: _build_ablation(
        "NO_BFP",
        "Standard decoder (no BFP module)",
        architecture={"use_bfp": False},
    ),

    # ==================== ARCHITECTURE COMPONENT ABLATIONS (20-22) ====================
    20: _build_ablation(
        "NO_MRFSE",
        "No MRF-SE blocks",
        architecture={"use_mrfse": False},
    ),
    21: _build_ablation(
        "VANILLA_UNET",
        "Vanilla U-Net (no MRFSE, no SSTM, no BFP)",
        architecture={
            "use_mrfse": False,
            "use_sstm": False,
            "use_bfp": False,
            "sstm_stages": [False, False, False, False, False],
        },
    ),
    22: _build_ablation(
        "SIMPLE_DICE_LOSS",
        "Full architecture but simple Dice loss (not MASL)",
        components={
            "use_boundary": False,
            "use_structure": False,
            "use_scale": False,
            "use_texture": False,
        },
        settings={
            "use_learned_weights": False,
            "use_morphology_modulation": False,
            "use_simple_dice": True,
        },
    ),
}

def print_ablation_menu():
    """Print every ablation in ABLATION_CONFIGS, grouped by study category."""
    banner = "=" * 80

    # (category heading, ablation IDs listed under it), in display order.
    grouped_ids = (
        ("BASELINE", (0,)),
        ("MASL Component Ablations", (1, 2, 3, 4, 5)),
        ("MASL Adaptation Mechanisms", (6, 7)),
        ("SSTM Truncation Size (K)", (8, 9, 10, 11)),
        ("SSTM Stage Placement", (12, 13, 14, 15)),
        ("BFP Routing Variants", (16, 17, 18, 19)),
        ("Architecture Components", (20, 21, 22)),
    )

    print("\n" + banner)
    print("🔥 AVAILABLE ABLATION STUDIES")
    print(banner)

    for heading, ablation_ids in grouped_ids:
        print(f"\n📌 {heading}:")
        for ablation_id in ablation_ids:
            entry = ABLATION_CONFIGS[ablation_id]
            print(f"   [{ablation_id:2d}] {entry['name']:30s} - {entry['description']}")

    print("\n" + banner)

# ==============================================================================
# 🔥 CONFIGURATION
# ==============================================================================

class Config:
    """Run-wide settings.

    The class attributes are plain constants. Instantiating the class
    validates ``ABLATION_ID``, attaches the matching ABLATION_CONFIGS entry as
    ``ABLATION_CONFIG``, and creates the per-ablation output directory
    ``SAVE_DIR``. An unknown ID prints the menu and exits with status 1.
    """

    # ==================== 🔥 SET ABLATION ID HERE 🔥 ====================
    ABLATION_ID = 19  # ⚡ key into ABLATION_CONFIGS (0-22)
    # ==================================================================

    # ---- GPU selection (indices passed on to the GPU setup code) ----
    GPU_NUMBERS = [0]

    # ---- Dataset location ----
    DATA_ROOT = "/kaggle/input/ph2-ori/ph2_dataset"

    # ---- Model architecture ----
    INPUT_SIZE = 352
    # Per-stage filter counts.
    F1 = 24
    F2 = 32
    F3 = 64
    F4 = 80
    F5 = 128

    # ---- Block hyperparameters ----
    MRF_KERNELS = [3, 5, 7]
    SE_REDUCTION = 16
    EXPAND_RATIO = 6
    DROPOUT = 0.1
    L2_REG = 1e-4

    SSTM_SSM_STATE_DIM = 16
    SSTM_DROPOUT = 0.1

    # ---- Training ----
    BATCH_SIZE = 8
    EPOCH_EXPANSION_FACTOR = 30  # virtual-epoch multiplier for the data generator
    EPOCHS = 100
    LEARNING_RATE = 1e-4

    EARLY_STOPPING_PATIENCE = 40
    CHECKPOINT_MONITOR = "val_dice_coefficient"
    CHECKPOINT_MODE = "max"

    # ---- Reproducibility ----
    SEED = 42
    DETERMINISTIC = False

    def __init__(self):
        # Guard clause: refuse to continue with an ID that has no configuration.
        known_id = self.ABLATION_ID in ABLATION_CONFIGS
        if not known_id:
            print(f"\n❌ INVALID ABLATION_ID: {self.ABLATION_ID}")
            print(f"   Valid range: 0-{len(ABLATION_CONFIGS)-1}")
            print_ablation_menu()
            sys.exit(1)

        selected = ABLATION_CONFIGS[self.ABLATION_ID]
        self.ABLATION_CONFIG = selected

        # Make the ablation name filesystem-safe: anything that is not
        # alphanumeric or one of "-_." becomes an underscore.
        safe_chars = [
            ch if (ch.isalnum() or ch in "-_.") else "_"
            for ch in selected["name"]
        ]
        ablation_name = "".join(safe_chars)

        run_folder = f"ABLATION_{self.ABLATION_ID:02d}_{ablation_name}"
        self.SAVE_DIR = os.path.join("runs", run_folder)
        os.makedirs(self.SAVE_DIR, exist_ok=True)

        print(f"\n🔥 ABLATION STUDY #{self.ABLATION_ID}: {self.ABLATION_CONFIG['name']}")
        print(f"   {self.ABLATION_CONFIG['description']}")
        print(f"   Output: {self.SAVE_DIR}")
        print(f"   Virtual Epoch Factor: {self.EPOCH_EXPANSION_FACTOR}x")

# ==============================================================================
# GPU SETUP
# ==============================================================================

def setup_gpus(gpu_numbers=None):
    """Restrict TensorFlow to the requested GPUs and pick a distribution strategy.

    Args:
        gpu_numbers: Iterable of indices into the list of physical GPUs, or
            ``None`` to use every GPU. Indices that do not address an existing
            GPU are reported and ignored instead of being dropped silently
            (too large) or raising ``IndexError`` (too negative).

    Returns:
        ``(strategy, num_gpus)``. ``strategy`` is a ``MirroredStrategy`` when
        more than one GPU is selected, otherwise the current default strategy.
        ``num_gpus`` is 0 when no GPU is present or usable, or when device
        configuration fails.
    """
    gpus = tf.config.list_physical_devices('GPU')

    if not gpus:
        print("⚠️ No GPUs found! Using CPU.")
        return tf.distribute.get_strategy(), 0

    print(f"🔍 Total GPUs available: {len(gpus)}")

    if gpu_numbers is not None:
        total = len(gpus)
        # Accept the same indices plain list indexing accepts (negatives included).
        valid = [i for i in gpu_numbers if -total <= i < total]
        ignored = [i for i in gpu_numbers if not (-total <= i < total)]
        if ignored:
            print(f"⚠️ Ignoring GPU index(es) {ignored}: only {total} GPU(s) available")
        selected_gpus = [gpus[i] for i in valid]
    else:
        selected_gpus = gpus

    try:
        # With an empty selection this hides every GPU, i.e. the run falls back to CPU.
        tf.config.set_visible_devices(selected_gpus, 'GPU')
        for gpu in selected_gpus:
            tf.config.experimental.set_memory_growth(gpu, True)

        num_gpus = len(selected_gpus)
        strategy = tf.distribute.MirroredStrategy() if num_gpus > 1 else tf.distribute.get_strategy()
        if num_gpus:
            print(f"✅ Using {num_gpus} GPU(s)")
        else:
            print("⚠️ None of the requested GPUs exist! Using CPU.")
        return strategy, num_gpus
    except RuntimeError as e:
        # Raised when devices are reconfigured after the runtime has been initialised.
        print(f"⚠️ GPU setup error: {e}")
        return tf.distribute.get_strategy(), 0

# ==============================================================================
# UTILS & DATA LOADING
# ==============================================================================

def set_seed(seed=42, deterministic=False):
    """Seed the Python, NumPy and TensorFlow random number generators.

    Args:
        seed: Integer seed applied to ``random``, ``numpy.random`` and
            ``tf.random``; also exported as ``PYTHONHASHSEED``.
        deterministic: When true, additionally request deterministic
            TensorFlow ops. Previously this flag was accepted but ignored.
    """
    import random

    # NOTE: PYTHONHASHSEED is read at interpreter start-up, so setting it here
    # only affects child processes, not hashing in the current process.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)

    if deterministic:
        os.environ['TF_DETERMINISTIC_OPS'] = '1'
        # enable_op_determinism only exists in newer TensorFlow releases;
        # look it up defensively so older installs keep working.
        enable_determinism = getattr(tf.config.experimental, 'enable_op_determinism', None)
        if enable_determinism is not None:
            enable_determinism()
        else:
            print("⚠️ tf.config.experimental.enable_op_determinism unavailable; "
                  "relying on TF_DETERMINISTIC_OPS only.")

def get_image_mask_pairs(images_dir, masks_dir):
    """Pair every image in ``images_dir`` with its mask in ``masks_dir``.

    Images are found by extension (png/jpg/jpeg/tif/bmp) and processed in
    sorted path order. For an image with stem ``S`` the first existing file
    among ``S<ext>`` / ``S_mask<ext>`` candidates is used as its mask.

    Args:
        images_dir: Directory containing the input images.
        masks_dir: Directory containing the segmentation masks.

    Returns:
        List of ``(image_path, mask_path)`` tuples. Images with no matching
        mask are left out, and a warning with their count is printed. The
        list is empty when no images are found.
    """
    image_extensions = ['*.png', '*.jpg', '*.jpeg', '*.tif', '*.bmp']
    image_files = sorted(
        path
        for ext in image_extensions
        for path in glob(os.path.join(images_dir, ext))
    )

    if not image_files:
        print(f"⚠️ No images found in {images_dir}")
        return []

    # The first five suffixes keep their original priority. The rest cover
    # image extensions that are globbed above but previously had no mask candidate.
    mask_suffixes = (
        ".png", ".jpg", ".bmp", "_mask.png", "_mask.jpg",
        ".jpeg", ".tif", "_mask.bmp", "_mask.jpeg", "_mask.tif",
    )

    pairs = []
    unmatched = 0
    for img_path in image_files:
        stem = Path(img_path).stem
        candidates = (os.path.join(masks_dir, stem + suffix) for suffix in mask_suffixes)
        mask_path = next((c for c in candidates if os.path.exists(c)), None)
        if mask_path is None:
            unmatched += 1
        else:
            pairs.append((img_path, mask_path))

    if unmatched:
        print(f"⚠️ {unmatched} image(s) in {images_dir} have no mask in {masks_dir}; skipped")
    return pairs

def load_dataset_split(data_root):
    """Collect (image, mask) path pairs for the train/val/test splits.

    Each split is expected at ``<data_root>/<split>/images`` and
    ``<data_root>/<split>/masks``. Returns a dict keyed by
    ``'train'``, ``'val'`` and ``'test'``.
    """
    split_roots = {
        split: os.path.join(data_root, split)
        for split in ("train", "val", "test")
    }
    # Echo the resolved training directory before any pairing starts.
    print(split_roots["train"])

    return {
        split: get_image_mask_pairs(
            os.path.join(root, "images"),
            os.path.join(root, "masks"),
        )
        for split, root in split_roots.items()
    }

def get_ducknet_augmentation(cfg):
    """Build the training-time albumentations pipeline (DuckNet-style).

    The pipeline applies random flips, an optional random resized crop, a
    shift/scale/rotate, optional coarse dropout and colour jitter, then
    resizes to ``cfg.INPUT_SIZE`` x ``cfg.INPUT_SIZE``.
    """
    side = cfg.INPUT_SIZE

    flips = [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
    ]

    # Random crop covering 50-100% of the image, applied 60% of the time.
    random_crop = A.RandomResizedCrop(
        size=(side, side),
        scale=(0.5, 1.0),
        ratio=(0.9, 1.1),
        interpolation=cv2.INTER_CUBIC,
        p=0.6,
    )

    # Always applied: full-circle rotation with zero-filled borders.
    affine = A.ShiftScaleRotate(
        shift_limit=0.0625,
        scale_limit=0.2,
        rotate_limit=180,
        border_mode=cv2.BORDER_CONSTANT,
        value=0,
        p=1.0,
    )

    # 5-12 square-ish holes of 20-48 px, zero-filled in both image and mask.
    cutout = A.CoarseDropout(
        max_holes=12,
        max_height=48,
        max_width=48,
        min_holes=5,
        min_height=20,
        min_width=20,
        fill_value=0,
        mask_fill_value=0,
        p=0.5,
    )

    colour = A.ColorJitter(brightness=0.4, contrast=0.2, saturation=0.1, hue=0.01, p=1.0)

    # The crop above is optional, so a final resize makes the output size fixed.
    to_input_size = A.Resize(height=side, width=side)

    return A.Compose([*flips, random_crop, affine, cutout, colour, to_input_size], p=1.0)

def get_validation_augmentation(cfg):
    """Build the evaluation pipeline: a single resize to the square model input size."""
    side = cfg.INPUT_SIZE
    resize = A.Resize(side, side)
    return A.Compose([resize])

# ==============================================================================
# DUCKNET DATA GENERATOR WITH VIRTUAL EPOCHS
# ==============================================================================

class DuckNetExpandedGenerator(tf.keras.utils.Sequence):
    """
    DuckNet-style data generator with virtual epoch expansion.

    One Keras epoch cycles through the dataset ``expansion_factor`` times.
    Indices are reshuffled only in ``on_epoch_end``, so the cycles inside one
    epoch revisit the same batch groupings; they differ only through the
    random augmentation applied on every load.

    Args:
        pairs: Sequence of ``(image_path, mask_path)`` tuples.
        cfg: Object providing ``BATCH_SIZE``.
        augmentation: Optional callable invoked as
            ``augmentation(image=..., mask=...)`` returning a dict with
            ``"image"`` and ``"mask"``.
        shuffle: Shuffle sample order at construction and after each epoch.
        expansion_factor: Number of passes over the data per epoch.
    """
    def __init__(self, pairs, cfg, augmentation=None, shuffle=True, expansion_factor=1):
        # Let the Keras base class initialise its own state.
        super().__init__()
        self.pairs = pairs
        self.cfg = cfg
        self.augmentation = augmentation
        self.shuffle = shuffle
        self.expansion_factor = expansion_factor
        self.indices = np.arange(len(self.pairs))

        # Full batches per pass; the remainder is dropped. A dataset smaller
        # than one batch still yields a single (short) batch instead of zero
        # batches, which previously caused a modulo-by-zero in __getitem__.
        self.real_batches = len(self.pairs) // self.cfg.BATCH_SIZE
        if self.real_batches == 0 and len(self.pairs) > 0:
            self.real_batches = 1
        self.virtual_batches = self.real_batches * self.expansion_factor

        if self.shuffle:
            np.random.shuffle(self.indices)

        print(f"   DuckNet Generator: {len(self.pairs)} samples")
        print(f"   Real batches/epoch: {self.real_batches}")
        print(f"   Virtual batches/epoch: {self.virtual_batches} ({expansion_factor}x expansion)")
        print(f"   Virtual images/epoch: {len(self.pairs) * expansion_factor}")

    def __len__(self):
        """Number of batches per (virtual) epoch."""
        return self.virtual_batches

    def __getitem__(self, index):
        """Load, augment and stack one batch.

        Returns:
            ``(images, masks)`` float32 arrays. Images are scaled to [0, 1];
            masks are binarised at 127 and carry a trailing channel axis.
            Samples whose image or mask cannot be read are skipped with a
            warning, so a batch may hold fewer than ``BATCH_SIZE`` items.

        Raises:
            IndexError: If the generator holds no samples.
        """
        if self.real_batches == 0:
            raise IndexError("DuckNetExpandedGenerator has no samples to serve")

        # Map virtual batch index to real batch index (cycle through dataset)
        real_index_ptr = index % self.real_batches
        batch_start = real_index_ptr * self.cfg.BATCH_SIZE
        batch_end = batch_start + self.cfg.BATCH_SIZE
        batch_indices = self.indices[batch_start:batch_end]

        images, masks = [], []
        for idx in batch_indices:
            img_path, mask_path = self.pairs[idx]

            # cv2.imread returns None instead of raising on unreadable files.
            image = cv2.imread(img_path)
            if image is None:
                print(f"⚠️ Could not read image, skipping: {img_path}")
                continue
            mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
            if mask is None:
                print(f"⚠️ Could not read mask, skipping: {mask_path}")
                continue

            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            mask = (mask > 127).astype(np.float32)

            # Apply augmentation (different each time due to randomness)
            if self.augmentation:
                augmented = self.augmentation(image=image, mask=mask)
                image = augmented["image"]
                mask = augmented["mask"]

            # Normalize image
            image = image.astype(np.float32) / 255.0

            # Ensure mask has channel dimension
            if len(mask.shape) == 2:
                mask = np.expand_dims(mask, axis=-1)

            images.append(image)
            masks.append(mask)

        return np.array(images, dtype=np.float32), np.array(masks, dtype=np.float32)

    def on_epoch_end(self):
        """Shuffle indices at the end of each epoch"""
        if self.shuffle:
            np.random.shuffle(self.indices)

# ==============================================================================
# MODEL ARCHITECTURE (WITH ABLATION SUPPORT)
# ==============================================================================

class SpectralSelectiveTokenMixer(Layer):
    """Residual token mixer combining a frequency-domain path and a gated path.

    ``call`` adds to its input a normalised fusion of:
      * ``spectral_path``: 2-D FFT over the spatial axes, multiplication by a
        learned real-valued frequency filter, inverse FFT.
      * ``ssm_path``: per-position sigmoid gate followed by a Dense projection.

    Args:
        channels: Channel count of the input (and output) feature map.
        num_frequencies: Upper bound on the side length of the learned filter.
        ssm_state_dim: Stored and serialised, but not used by the code in this
            class.
        dropout: Dropout rate applied to the fused features during training.

    Input is expected as a 4-D channels-last tensor (the code indexes axes
    1 and 2 as height/width and the last axis as channels).
    """
    def __init__(self, channels, num_frequencies=32, ssm_state_dim=16,
                 dropout=0.0, **kwargs):
        super().__init__(**kwargs)
        self.channels = channels
        self.num_frequencies = num_frequencies
        self.ssm_state_dim = ssm_state_dim
        self.dropout_rate = dropout

    def build(self, input_shape):
        input_h, input_w = input_shape[1], input_shape[2]
        # Clamp the filter size to whichever spatial dimensions are statically
        # known; previously a known height with an unknown width crashed min().
        known_dims = [d for d in (input_h, input_w) if d]
        self.actual_frequencies = min([self.num_frequencies, *known_dims])

        self.freq_weights_real = self.add_weight(
            name='freq_weights_real',
            shape=(self.actual_frequencies, self.actual_frequencies, self.channels),
            initializer=self._get_initializer(),
            trainable=True
        )
        self.spectral_norm = LayerNormalization(epsilon=1e-6, name='spectral_norm')

        self.ssm_C = Dense(self.channels, name='ssm_C')
        self.selection_gate = Dense(self.channels, activation='sigmoid', name='selection')
        self.ssm_norm = LayerNormalization(epsilon=1e-6, name='ssm_norm')

        self.fusion = Dense(self.channels, name='fusion')
        self.fusion_norm = LayerNormalization(epsilon=1e-6, name='fusion_norm')

        self.norm = LayerNormalization(epsilon=1e-6, name='norm')
        super().build(input_shape)

    def get_config(self):
        """Return constructor arguments so the layer can be saved and reloaded."""
        config = super().get_config()
        config.update({
            'channels': self.channels,
            'num_frequencies': self.num_frequencies,
            'ssm_state_dim': self.ssm_state_dim,
            'dropout': self.dropout_rate,
        })
        return config

    def _get_initializer(self):
        """Initialiser: Gaussian band-pass (centre 0.25, sigma 0.15) scaled by 0.5."""
        def init_fn(shape, dtype=None):
            H, W, C = shape
            freq_h = np.fft.fftfreq(H)[:, np.newaxis]
            freq_w = np.fft.fftfreq(W)[np.newaxis, :]
            freq_magnitude = np.sqrt(freq_h**2 + freq_w**2)
            gaussian = np.exp(-((freq_magnitude - 0.25)**2) / (2 * 0.15**2))
            gaussian = np.repeat(gaussian[:, :, np.newaxis], C, axis=2)
            return gaussian.astype(np.float32) * 0.5
        return init_fn

    def spectral_path(self, x):
        """Filter ``x`` in the 2-D spatial frequency domain and layer-normalise it."""
        H, W = tf.shape(x)[1], tf.shape(x)[2]
        freq_size = tf.minimum(tf.minimum(H, W), self.actual_frequencies)

        # tf.signal.fft2d transforms the two innermost axes. For channels-last
        # input those are (W, C), so move channels in front of the spatial axes
        # for the transform and restore the layout afterwards.
        x_complex = tf.transpose(tf.cast(x, tf.complex64), [0, 3, 1, 2])
        x_freq = tf.transpose(tf.signal.fft2d(x_complex), [0, 2, 3, 1])

        # tf.image.resize is real-valued, so real and imaginary parts are
        # resized separately.
        x_freq_real_resized = tf.image.resize(tf.math.real(x_freq), [freq_size, freq_size])
        x_freq_imag_resized = tf.image.resize(tf.math.imag(x_freq), [freq_size, freq_size])
        x_freq_resized = tf.complex(x_freq_real_resized, x_freq_imag_resized)

        freq_filter = tf.cast(self.freq_weights_real[:freq_size, :freq_size, :], tf.complex64)
        x_freq_filtered = x_freq_resized * freq_filter

        x_freq_back_real = tf.image.resize(tf.math.real(x_freq_filtered), [H, W])
        x_freq_back_imag = tf.image.resize(tf.math.imag(x_freq_filtered), [H, W])
        x_freq_back = tf.complex(x_freq_back_real, x_freq_back_imag)

        # Same axis handling for the inverse transform.
        x_spatial = tf.signal.ifft2d(tf.transpose(x_freq_back, [0, 3, 1, 2]))
        x_spatial = tf.transpose(x_spatial, [0, 2, 3, 1])
        return self.spectral_norm(tf.math.real(x_spatial))

    def ssm_path(self, x):
        """Gate each spatial position with a sigmoid, project it, and layer-normalise."""
        B, H, W = tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2]
        x_flat = tf.reshape(x, [B, H * W, self.channels])
        selection = self.selection_gate(x_flat)
        x_selected = x_flat * selection
        x_ssm = self.ssm_C(x_selected)
        return self.ssm_norm(tf.reshape(x_ssm, [B, H, W, self.channels]))

    def call(self, x, training=None):
        """Return ``x`` plus the fused, normalised output of both paths."""
        spectral_out = self.spectral_path(x)
        ssm_out = self.ssm_path(x)
        fused = self.fusion_norm(self.fusion(tf.concat([spectral_out, ssm_out], axis=-1)))
        fused = self.norm(fused)
        if training and self.dropout_rate > 0:
            fused = tf.nn.dropout(fused, rate=self.dropout_rate)
        return x + fused

def MRF_SE_BLOCK(x, filters, activation='elu', dropout=0.0, expand_ratio=6,
                 regularizer=0.0, kernels=(3, 5, 7), se_reduction=16, name='mrf_se'):
    """Multi-receptive-field residual block with squeeze-and-excitation gating.

    Pipeline: optional 1x1 expansion to ``filters * expand_ratio`` channels ->
    BatchNorm + activation -> one depthwise convolution per entry of
    ``kernels`` (each followed by BatchNorm + activation) -> concatenation and
    a 1x1 fusion conv when more than one kernel is used -> channel gating from
    a global-average-pooled squeeze/excite pair -> 1x1 projection to
    ``filters`` channels + BatchNorm -> optional dropout -> residual ``Add``
    with the block input.

    Args:
        x: Input feature map. The final ``Add`` requires it to already have
            ``filters`` channels.
        filters: Number of output channels.
        activation: Activation used after each BatchNorm and in the squeeze conv.
        dropout: Dropout rate applied to the projected features; skipped when 0.
        expand_ratio: Channel multiplier for the expansion conv. The expansion
            conv is only inserted when this is greater than 1.
        regularizer: L2 factor for the conv kernels; no regularizer when 0.
        kernels: Depthwise kernel sizes, one branch per entry. Sizes must be
            distinct because each branch layer is named after its size.
        se_reduction: Divisor for the squeeze width (floored at 8 channels).
        name: Prefix for the explicitly named layers.

    Returns:
        Tensor with ``filters`` channels: projected features plus ``x``.
    """
    F_expanded = filters * expand_ratio
    # One regularizer object is enough; l2 holds no per-layer state.
    kernel_reg = l2(regularizer) if regularizer > 0 else None

    conv = Conv2D(F_expanded, (1, 1), padding='same',
                  kernel_regularizer=kernel_reg,
                  name=name+'_expand')(x) if expand_ratio > 1 else x
    conv = Activation(activation)(BatchNormalization()(conv))

    features = []
    for k in kernels:
        dw = DepthwiseConv2D((k, k), padding='same',
                             depthwise_regularizer=kernel_reg,
                             name=f"{name}_dw{k}x{k}")(conv)
        features.append(Activation(activation)(BatchNormalization()(dw)))

    combined = Concatenate()(features) if len(features) > 1 else features[0]
    if len(features) > 1:
        # Bring the concatenated branches back to the expanded width.
        combined = Activation(activation)(BatchNormalization()(
            Conv2D(F_expanded, (1, 1), padding='same',
                   kernel_regularizer=kernel_reg)(combined)))

    gap = Reshape((1, 1, F_expanded))(GlobalAveragePooling2D()(combined))
    # The excite layer is instantiated before the squeeze layer on purpose:
    # it keeps Keras' auto-generated layer names in their original order.
    excite = Conv2D(F_expanded, (1, 1), activation='sigmoid')
    squeeze = Conv2D(max(F_expanded//se_reduction, 8), (1, 1), activation=activation)
    se = excite(squeeze(gap))

    projected = Conv2D(filters, (1, 1), padding='same',
                       kernel_regularizer=kernel_reg)(
        Multiply()([combined, se]))
    projected = BatchNormalization()(projected)

    if dropout > 0:
        projected = Dropout(dropout)(projected)

    return Add()([projected, x])

def boundary_detection_module(features, filters, name='boundary'):
    """Predict a single-channel sigmoid map from ``features`` and gate with it.

    A 3x3 ReLU conv with ``filters // 2`` channels feeds a 1x1 sigmoid conv
    that produces the map. Returns ``(features * map, map)``.
    """
    hidden_layer = Conv2D(filters // 2, (3, 3), padding='same',
                          activation='relu', name=name + '_conv')
    map_layer = Conv2D(1, (1, 1), padding='same',
                       activation='sigmoid', name=name + '_map')

    attention = map_layer(hidden_layer(features))
    gated = Multiply()([features, attention])
    return gated, attention

def BFP_decoder_stage(decoder_input, skip_features, filters, routing_type='soft', stage_name='bfp'):
    """Boundary-Focused decoder stage with selectable region/boundary routing.

    The decoder input is upsampled x2, concatenated with the skip features and
    refined by two 3x3 conv-BN-ReLU layers ("region" features). A boundary map
    is predicted from them, the boundary-gated features get one more 3x3
    conv-BN-ReLU ("boundary" features), and both streams are merged according
    to ``routing_type`` before a final 1x1 conv-BN-ReLU.

    Args:
        decoder_input: Lower-resolution decoder tensor.
        skip_features: Encoder skip tensor at twice the resolution of
            ``decoder_input``.
        filters: Channel count of every conv in this stage.
        routing_type: How the two streams are merged:
            ``'soft'``    - blend weighted by the boundary map;
            ``'hard'``    - blend with the map thresholded at 0.5;
            ``'none'``    - plain channel concatenation;
            ``'learned'`` - blend with per-pixel softmax weights from a 1x1 conv.
        stage_name: Prefix for the boundary module's layer names.

    Returns:
        Tuple ``(output, boundary_map)``.

    Raises:
        ValueError: If ``routing_type`` is not one of the supported variants.
    """
    supported_routing = ('soft', 'hard', 'none', 'learned')
    # Fail before any layer is created; previously an unknown value fell
    # through every branch and surfaced as an UnboundLocalError.
    if routing_type not in supported_routing:
        raise ValueError(
            f"Unknown routing_type {routing_type!r}; expected one of {supported_routing}")

    region = Concatenate()([UpSampling2D((2, 2))(decoder_input), skip_features])

    region = Activation('relu')(BatchNormalization()(
        Conv2D(filters, (3, 3), padding='same')(region)))
    region = Activation('relu')(BatchNormalization()(
        Conv2D(filters, (3, 3), padding='same')(region)))

    boundary_features, boundary_map = boundary_detection_module(region, filters, stage_name+'_boundary')

    boundary_refined = Activation('relu')(BatchNormalization()(
        Conv2D(filters, (3, 3), padding='same')(boundary_features)))

    if routing_type == 'soft':
        routed = region * (1 - boundary_map) + boundary_refined * boundary_map
    elif routing_type == 'hard':
        beta_hard = tf.cast(boundary_map > 0.5, tf.float32)
        routed = region * (1 - beta_hard) + boundary_refined * beta_hard
    elif routing_type == 'none':
        routed = Concatenate()([region, boundary_refined])
    else:  # 'learned'
        combined = Concatenate()([region, boundary_refined])
        routing_weights = Conv2D(2, (1, 1), activation='softmax')(combined)
        w_region = routing_weights[..., 0:1]
        w_boundary = routing_weights[..., 1:2]
        routed = region * w_region + boundary_refined * w_boundary

    # Shared fusion head for every routing variant.
    output = Activation('relu')(BatchNormalization()(
        Conv2D(filters, (1, 1), padding='same')(routed)))

    return output, boundary_map

def build_medsegnet_ssf(cfg):
    """Assemble the Med-SegNet-SSF encoder/decoder as a Keras functional model.

    The architecture switches are read from
    ``cfg.ABLATION_CONFIG['architecture']``. A 16-filter stem is followed by
    five stride-2 encoder stages (widths ``cfg.F1`` .. ``cfg.F5``), each
    optionally extended with an MRF-SE block and a SpectralSelectiveTokenMixer.
    The decoder walks the encoder outputs in reverse, using either BFP stages
    or plain upsample-concat-conv stages. A final x2 upsampling and two 3x3
    convs feed the 1-channel sigmoid output.

    Args:
        cfg: Configuration object supplying the attributes read below.

    Returns:
        The assembled ``Model`` named ``"MedSegNet_SSF"``.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("🔥 BUILDING MED-SEGNET-SSF WITH ABLATION CONFIG")
    print(rule)

    ablation = cfg.ABLATION_CONFIG
    arch = ablation['architecture']

    print(f"   MRFSE: {arch['use_mrfse']}")
    print(f"   SSTM: {arch['use_sstm']} (K={arch['sstm_k']})")
    print(f"   SSTM Stages: {arch['sstm_stages']}")
    print(f"   BFP: {arch['use_bfp']} (routing={ablation['bfp_routing']})")

    side = cfg.INPUT_SIZE
    inp = Input((side, side, 3), name="input")

    # Stem
    x = Conv2D(16, (3, 3), padding='same')(inp)
    x = Activation('elu')(BatchNormalization()(x))

    # Encoder
    filters = [cfg.F1, cfg.F2, cfg.F3, cfg.F4, cfg.F5]
    encoder_outputs = []

    for stage, width in enumerate(filters, start=1):
        x = Conv2D(width, (3, 3), strides=2, padding='same')(x)
        x = Activation('elu')(BatchNormalization()(x))

        if arch['use_mrfse']:
            x = MRF_SE_BLOCK(
                x, width,
                activation='elu',
                dropout=cfg.DROPOUT,
                expand_ratio=cfg.EXPAND_RATIO,
                regularizer=cfg.L2_REG,
                kernels=cfg.MRF_KERNELS,
                se_reduction=cfg.SE_REDUCTION,
                name=f'mrfse_stage{stage}',
            )

        if arch['use_sstm'] and arch['sstm_stages'][stage - 1]:
            mixer = SpectralSelectiveTokenMixer(
                channels=width,
                num_frequencies=arch['sstm_k'],
                ssm_state_dim=cfg.SSTM_SSM_STATE_DIM,
                dropout=cfg.SSTM_DROPOUT,
                name=f'sstm_stage{stage}',
            )
            x = mixer(x)

        encoder_outputs.append(x)
        print(f"  Encoder Stage {stage}: filters={width}, SSTM={arch['sstm_stages'][stage - 1]}")

    # Decoder: deepest encoder output first, remaining outputs act as skips.
    decoder = encoder_outputs[-1]
    remaining_skips = encoder_outputs[-2::-1]
    decoder_filters = filters[-2::-1] + [16]

    for stage, (skip, width) in enumerate(zip(remaining_skips, decoder_filters), start=1):
        if arch['use_bfp']:
            decoder, _ = BFP_decoder_stage(
                decoder, skip, width,
                routing_type=ablation['bfp_routing'],
                stage_name=f'bfp_stage{stage}',
            )
        else:
            decoder = UpSampling2D((2, 2))(decoder)
            decoder = Concatenate()([decoder, skip])
            for _ in range(2):
                decoder = Conv2D(width, (3, 3), padding='same')(decoder)
                decoder = Activation('relu')(BatchNormalization()(decoder))

        print(f"  Decoder Stage {stage}: filters={width}, BFP={arch['use_bfp']}")

    # Output head
    decoder = UpSampling2D((2, 2))(decoder)
    for head_width in (32, 16):
        decoder = Conv2D(head_width, (3, 3), padding='same', activation='relu')(decoder)
    out = Conv2D(1, (1, 1), padding='same', activation='sigmoid', name='output')(decoder)

    model = Model(inputs=inp, outputs=out, name="MedSegNet_SSF")
    print(f"\nTotal parameters: {model.count_params():,}")
    print(rule + "\n")

    return model

# ==============================================================================
# MASL LOSS FUNCTION (WITH ABLATION SUPPORT)
# ==============================================================================

class ClipConstraint(tf.keras.constraints.Constraint):
    """Weight constraint that clamps every value into ``[min_value, max_value]``."""

    def __init__(self, min_value=0.1, max_value=10.0):
        self.min_value, self.max_value = min_value, max_value

    def __call__(self, w):
        lower, upper = self.min_value, self.max_value
        return tf.clip_by_value(w, clip_value_min=lower, clip_value_max=upper)

    def get_config(self):
        # Exactly the constructor arguments.
        return dict(min_value=self.min_value, max_value=self.max_value)

class MorphologyAwareAdaptiveLoss(Layer):
    """Composite segmentation loss assembled from an ablation configuration.

    ``ablation_config['masl_components']`` switches the five loss terms
    (``use_core``, ``use_boundary``, ``use_structure``, ``use_scale``,
    ``use_texture``). ``ablation_config['masl_settings']`` controls whether
    per-term scalar weights are created as trainable variables
    (``use_learned_weights``), whether the terms are modulated by shape
    statistics of the ground truth (``use_morphology_modulation``), and
    whether everything is replaced by a plain Dice loss (``use_simple_dice``).

    The result is the weighted mean of the enabled terms.

    NOTE(review): the learnable weights are variables of this layer, not of
    the segmentation model. Confirm that the training loop actually hands
    them to the optimizer; otherwise they stay at their initial values.
    """

    # (flag in 'masl_components', weight attribute / variable name, initial value)
    _WEIGHT_SPECS = (
        ('use_core', 'w_region', 1.0),
        ('use_boundary', 'w_boundary', 1.0),
        ('use_structure', 'w_structure', 1.0),
        ('use_scale', 'w_scale', 0.5),
        ('use_texture', 'w_texture', 0.5),
    )

    def __init__(self, ablation_config, name='masl', **kwargs):
        super().__init__(name=name, **kwargs)
        self.ablation_config = ablation_config
        self.epsilon = 1e-6

    def build(self, input_shape):
        """Create one clipped scalar weight per enabled term (if learned weights are on)."""
        masl_components = self.ablation_config['masl_components']
        masl_settings = self.ablation_config['masl_settings']

        if masl_settings.get('use_learned_weights', True):
            clip_constraint = ClipConstraint(min_value=0.1, max_value=10.0)

            for flag, attr, initial in self._WEIGHT_SPECS:
                if masl_components[flag]:
                    setattr(self, attr, self.add_weight(
                        name=attr, shape=(), initializer=tf.constant_initializer(initial),
                        trainable=True, constraint=clip_constraint, dtype=tf.float32
                    ))

        super().build(input_shape)

    # ------------------------------------------------------------------ helpers

    def _soft_perimeter(self, mask):
        """Per-sample sum of the forward-difference gradient magnitude of ``mask``."""
        dy = mask[:, 1:, :, :] - mask[:, :-1, :, :]
        dx = mask[:, :, 1:, :] - mask[:, :, :-1, :]
        # Pad the differences back to the input size so they can be combined.
        dy_padded = tf.pad(dy, [[0, 0], [0, 1], [0, 0], [0, 0]])
        dx_padded = tf.pad(dx, [[0, 0], [0, 0], [0, 1], [0, 0]])
        # epsilon inside the sqrt keeps the gradient finite where the mask is flat.
        gradient_mag = tf.sqrt(dy_padded**2 + dx_padded**2 + self.epsilon)
        return tf.reduce_sum(gradient_mag, axis=[1, 2, 3]) + self.epsilon

    def _pixelwise_bce(self, y_true, y_pred):
        """Element-wise binary cross-entropy with epsilon-stabilised logs."""
        return -(y_true * tf.math.log(y_pred + self.epsilon) +
                 (1 - y_true) * tf.math.log(1 - y_pred + self.epsilon))

    # ------------------------------------------------------------- morphology

    def morphological_dilation(self, x, kernel_size=5):
        """Grey-scale dilation implemented as a stride-1 max-pool."""
        return tf.nn.max_pool2d(x, kernel_size, strides=1, padding='SAME')

    def morphological_erosion(self, x, kernel_size=5):
        """Grey-scale erosion implemented as a negated max-pool of ``-x``."""
        return -tf.nn.max_pool2d(-x, kernel_size, strides=1, padding='SAME')

    def detect_boundary(self, mask, kernel_size=5):
        """Return dilation minus erosion of ``mask``, clipped to [0, 1]."""
        dilated = self.morphological_dilation(mask, kernel_size)
        eroded = self.morphological_erosion(mask, kernel_size)
        boundary = dilated - eroded
        return tf.clip_by_value(boundary, 0.0, 1.0)

    def analyze_structure_characteristics(self, y_true):
        """Compute batch-level shape statistics of the ground-truth masks.

        Returns a dict of scalars, each clipped to [0, 1]:
            ``tubularity``   - mean ratio of the 3x3-eroded mask sum to the mask sum;
            ``compactness``  - mean of ``4 * pi * area / perimeter**2``;
            ``irregularity`` - mean absolute second difference of the boundary
                               map along both spatial axes (summed);
            ``object_size``  - mean fraction of pixels covered by the mask.
        """
        area = tf.reduce_sum(y_true, axis=[1, 2, 3]) + self.epsilon
        total_pixels = tf.cast(tf.shape(y_true)[1] * tf.shape(y_true)[2], tf.float32)

        perimeter = self._soft_perimeter(y_true)

        skeleton_approx = self.morphological_erosion(y_true, kernel_size=3)
        skeleton_area = tf.reduce_sum(skeleton_approx, axis=[1, 2, 3]) + self.epsilon

        tubularity = tf.reduce_mean(skeleton_area / (area + self.epsilon))
        compactness = tf.reduce_mean((4 * 3.14159 * area) / (perimeter**2 + self.epsilon))
        compactness = tf.clip_by_value(compactness, 0.0, 1.0)

        boundary = self.detect_boundary(y_true, kernel_size=5)
        ddy = boundary[:, 2:, :, :] - 2*boundary[:, 1:-1, :, :] + boundary[:, :-2, :, :]
        ddx = boundary[:, :, 2:, :] - 2*boundary[:, :, 1:-1, :] + boundary[:, :, :-2, :]
        irregularity = tf.reduce_mean(tf.abs(ddy)) + tf.reduce_mean(tf.abs(ddx))

        object_size = tf.reduce_mean(area / total_pixels)

        return {
            'tubularity': tf.clip_by_value(tubularity, 0.0, 1.0),
            'compactness': compactness,
            'irregularity': tf.clip_by_value(irregularity, 0.0, 1.0),
            'object_size': tf.clip_by_value(object_size, 0.0, 1.0)
        }

    # ------------------------------------------------------------- loss terms

    def core_loss(self, y_true, y_pred):
        """0.4 * Dice loss + 0.3 * IoU loss + 0.3 * boundary-weighted BCE."""
        intersection = tf.reduce_sum(y_true * y_pred, axis=[1, 2, 3])
        true_sum = tf.reduce_sum(y_true, axis=[1, 2, 3])
        pred_sum = tf.reduce_sum(y_pred, axis=[1, 2, 3])

        dice = (2. * intersection + self.epsilon) / (true_sum + pred_sum + self.epsilon)
        dice_loss = 1.0 - tf.reduce_mean(dice)

        union = true_sum + pred_sum - intersection
        iou = (intersection + self.epsilon) / (union + self.epsilon)
        iou_loss = 1.0 - tf.reduce_mean(iou)

        # Pixels on the ground-truth boundary count six times as much.
        boundary = self.detect_boundary(y_true, kernel_size=5)
        weights = 1.0 + 5.0 * boundary
        weighted_bce = tf.reduce_mean(weights * self._pixelwise_bce(y_true, y_pred))

        return 0.4 * dice_loss + 0.3 * iou_loss + 0.3 * weighted_bce

    def boundary_loss(self, y_true, y_pred):
        """Weighted L1 distance between finite differences at offsets 1, 2 and 4."""
        total_loss = 0.0
        weights = [0.5, 0.3, 0.2]

        for scale, w in zip([1, 2, 4], weights):
            dy_true = y_true[:, scale:, :, :] - y_true[:, :-scale, :, :]
            dy_pred = y_pred[:, scale:, :, :] - y_pred[:, :-scale, :, :]

            dx_true = y_true[:, :, scale:, :] - y_true[:, :, :-scale, :]
            dx_pred = y_pred[:, :, scale:, :] - y_pred[:, :, :-scale, :]

            total_loss += w * (tf.reduce_mean(tf.abs(dy_true - dy_pred)) +
                               tf.reduce_mean(tf.abs(dx_true - dx_pred)))

        return total_loss

    def structure_aware_loss(self, y_true, y_pred, characteristics):
        """L1 gap between area / perimeter**2 of truth and prediction.

        The gap is scaled by ``characteristics['compactness']``.
        """
        area_true = tf.reduce_sum(y_true, axis=[1, 2, 3]) + self.epsilon
        perimeter_true = self._soft_perimeter(y_true)

        area_pred = tf.reduce_sum(y_pred, axis=[1, 2, 3]) + self.epsilon
        perimeter_pred = self._soft_perimeter(y_pred)

        compact_true = area_true / (perimeter_true**2 + self.epsilon)
        compact_pred = area_pred / (perimeter_pred**2 + self.epsilon)

        return characteristics['compactness'] * tf.reduce_mean(tf.abs(compact_true - compact_pred))

    def scale_aware_focal_loss(self, y_true, y_pred, characteristics):
        """Focal BCE whose exponent depends on ``characteristics['object_size']``.

        gamma is 3.0 below 0.05, 2.0 below 0.2 and 1.5 otherwise.
        """
        size = characteristics['object_size']
        gamma = tf.cond(
            size < 0.05,
            lambda: 3.0,
            lambda: tf.cond(size < 0.2, lambda: 2.0, lambda: 1.5)
        )

        # p is the probability assigned to the correct class of each pixel.
        p = y_true * y_pred + (1 - y_true) * (1 - y_pred)
        focal_weight = tf.pow(1 - p, gamma)

        return tf.reduce_mean(focal_weight * self._pixelwise_bce(y_true, y_pred))

    def texture_aware_loss(self, y_true, y_pred):
        """L1 distance between second differences of truth and prediction."""
        ddy_true = y_true[:, 2:, :, :] - 2*y_true[:, 1:-1, :, :] + y_true[:, :-2, :, :]
        ddy_pred = y_pred[:, 2:, :, :] - 2*y_pred[:, 1:-1, :, :] + y_pred[:, :-2, :, :]

        ddx_true = y_true[:, :, 2:, :] - 2*y_true[:, :, 1:-1, :] + y_true[:, :, :-2, :]
        ddx_pred = y_pred[:, :, 2:, :] - 2*y_pred[:, :, 1:-1, :] + y_pred[:, :, :-2, :]

        return tf.reduce_mean(tf.abs(ddy_true - ddy_pred)) + tf.reduce_mean(tf.abs(ddx_true - ddx_pred))

    # ------------------------------------------------------------------- call

    def call(self, y_true, y_pred):
        """Return the scalar MASL loss for a batch.

        Each enabled term contributes ``w * alpha * loss``. ``w`` is the
        learned weight (or 1.0) and ``alpha`` is the morphology modulation
        factor (or 1.0). The sum is divided by the sum of all ``w * alpha``.
        """
        y_true = tf.cast(y_true, tf.float32)
        y_pred = tf.cast(y_pred, tf.float32)

        masl_components = self.ablation_config['masl_components']
        masl_settings = self.ablation_config['masl_settings']

        # Special case: Simple Dice Loss
        if masl_settings.get('use_simple_dice', False):
            intersection = tf.reduce_sum(y_true * y_pred)
            dice = (2. * intersection + self.epsilon) / (
                tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) + self.epsilon
            )
            return 1.0 - dice

        characteristics = self.analyze_structure_characteristics(y_true)

        if masl_settings.get('use_morphology_modulation', True):
            alpha_region = 1.0 + 0.5 * characteristics['compactness']
            alpha_boundary = 1.0 + 1.5 * characteristics['tubularity'] + characteristics['compactness']
            alpha_structure = 1.0 + characteristics['tubularity']
            alpha_scale = 1.0 + 1.5 * characteristics['irregularity']
            alpha_texture = 1.0 + characteristics['irregularity']
        else:
            alpha_region = 1.0
            alpha_boundary = 1.0
            alpha_structure = 1.0
            alpha_scale = 1.0
            alpha_texture = 1.0

        use_learned = masl_settings.get('use_learned_weights', True)

        def component_weight(attr):
            # Learned scalar when enabled, constant 1.0 otherwise.
            return tf.cast(getattr(self, attr) if use_learned else 1.0, tf.float32)

        total_loss = 0.0
        total_weight = 0.0

        if masl_components['use_core']:
            l_core = self.core_loss(y_true, y_pred)
            w = component_weight('w_region')
            total_loss += w * alpha_region * l_core
            total_weight += w * alpha_region

        if masl_components['use_boundary']:
            l_boundary = self.boundary_loss(y_true, y_pred)
            w = component_weight('w_boundary')
            total_loss += w * alpha_boundary * l_boundary
            total_weight += w * alpha_boundary

        if masl_components['use_structure']:
            l_structure = self.structure_aware_loss(y_true, y_pred, characteristics)
            w = component_weight('w_structure')
            total_loss += w * alpha_structure * l_structure
            total_weight += w * alpha_structure

        if masl_components['use_scale']:
            l_scale = self.scale_aware_focal_loss(y_true, y_pred, characteristics)
            w = component_weight('w_scale')
            total_loss += w * alpha_scale * l_scale
            total_weight += w * alpha_scale

        if masl_components['use_texture']:
            l_texture = self.texture_aware_loss(y_true, y_pred)
            w = component_weight('w_texture')
            total_loss += w * alpha_texture * l_texture
            total_weight += w * alpha_texture

        masl_loss = total_loss / (total_weight + self.epsilon)

        return masl_loss

    def get_config(self):
        """Return the layer config including the required ``ablation_config``.

        Without this entry ``from_config`` cannot call ``__init__``, which has
        ``ablation_config`` as a mandatory argument.
        """
        config = super().get_config()
        config['ablation_config'] = self.ablation_config
        return config

# Module-level loss object; train_ablation_model assigns it before compiling.
_masl_instance = None

def masl_loss_fn(y_true, y_pred):
    """Plain-function adapter that forwards to the module-level MASL instance.

    Raises:
        RuntimeError: If ``_masl_instance`` has not been assigned yet.
    """
    if _masl_instance is None:
        raise RuntimeError(
            "MASL loss is not initialised: assign a MorphologyAwareAdaptiveLoss "
            "instance to _masl_instance before using masl_loss_fn."
        )
    return _masl_instance(y_true, y_pred)

# ==============================================================================
# METRICS
# ==============================================================================

def dice_coefficient(y_true, y_pred, smooth=1e-6):
    """Soft Dice coefficient over all elements of both tensors (not per sample)."""
    truth_flat, pred_flat = K.flatten(y_true), K.flatten(y_pred)
    overlap = K.sum(truth_flat * pred_flat)
    mass = K.sum(truth_flat) + K.sum(pred_flat)
    return (2. * overlap + smooth) / (mass + smooth)

def iou_score(y_true, y_pred, smooth=1e-6):
    """Soft intersection-over-union over all elements of both tensors."""
    truth_flat, pred_flat = K.flatten(y_true), K.flatten(y_pred)
    overlap = K.sum(truth_flat * pred_flat)
    combined = K.sum(truth_flat) + K.sum(pred_flat) - overlap
    numerator = overlap + smooth
    denominator = combined + smooth
    return numerator / denominator

def precision_metric(y_true, y_pred):
    """Precision of the prediction thresholded at 0.5: TP / predicted positives."""
    hard_pred = K.cast(y_pred > 0.5, tf.float32)
    hits = K.sum(y_true * hard_pred)
    flagged = K.sum(hard_pred)
    return hits / (flagged + K.epsilon())

def recall_metric(y_true, y_pred):
    """Recall of the prediction thresholded at 0.5: TP / actual positives."""
    hard_pred = K.cast(y_pred > 0.5, tf.float32)
    hits = K.sum(y_true * hard_pred)
    relevant = K.sum(y_true)
    return hits / (relevant + K.epsilon())

# ==============================================================================
# PREDICTION SAVING AND VISUALIZATION
# ==============================================================================

def _per_sample_scores(true_mask, pred_binary, eps=1e-6):
    """Return (dice, iou, precision, recall) for one mask / binary-prediction pair."""
    intersection = np.sum(true_mask * pred_binary)
    true_area = np.sum(true_mask)
    pred_area = np.sum(pred_binary)

    dice = (2. * intersection + eps) / (true_area + pred_area + eps)
    iou = (intersection + eps) / (true_area + pred_area - intersection + eps)
    precision = intersection / (pred_area + eps)
    recall = intersection / (true_area + eps)
    return dice, iou, precision, recall


def save_predictions_and_visualizations(model, test_gen, cfg):
    """Save prediction masks and create visualization grids with running metrics.

    Every batch of ``test_gen`` is run through ``model``. The raw prediction
    of each sample is written as an 8-bit PNG to ``<cfg.SAVE_DIR>/predictions``.
    Dice / IoU / precision / recall are computed per sample on the prediction
    thresholded at 0.5. The visualization grids are then written to
    ``<cfg.SAVE_DIR>/visualizations``.

    Args:
        model: Object with a Keras-style ``predict(images, verbose=0)`` method.
        test_gen: Indexable batch source; ``test_gen[i]`` yields ``(images, masks)``.
        cfg: Configuration object providing ``SAVE_DIR`` (plus whatever
            ``create_prediction_visualizations`` reads).

    Returns:
        Dict with ``mean_*`` / ``std_*`` for dice, iou, precision and recall
        and ``num_samples``. When the generator yields no samples all
        statistics are 0.0, ``num_samples`` is 0 and no visualization is made.
    """
    metric_names = ('dice', 'iou', 'precision', 'recall')

    print("\n" + "="*80)
    print("💾 SAVING PREDICTION MASKS")
    print("="*80)

    # Create output directories
    pred_dir = os.path.join(cfg.SAVE_DIR, "predictions")
    vis_dir = os.path.join(cfg.SAVE_DIR, "visualizations")
    os.makedirs(pred_dir, exist_ok=True)
    os.makedirs(vis_dir, exist_ok=True)

    # Metrics storage: one list of per-sample values per metric
    scores = {metric: [] for metric in metric_names}
    sample_idx = 0

    # Process all batches
    for batch_idx in range(len(test_gen)):
        images, masks = test_gen[batch_idx]
        predictions = model.predict(images, verbose=0)

        for true_mask, pred_mask in zip(masks, predictions):
            # Metrics are computed on the hard (thresholded) prediction
            pred_binary = (pred_mask > 0.5).astype(np.float32)
            for metric, value in zip(metric_names, _per_sample_scores(true_mask, pred_binary)):
                scores[metric].append(value)

            # The saved mask keeps the soft probabilities scaled to 0..255
            pred_mask_uint8 = (pred_mask.squeeze() * 255).astype(np.uint8)
            pred_filename = os.path.join(pred_dir, f"pred_{sample_idx:04d}.png")
            cv2.imwrite(pred_filename, pred_mask_uint8)

            sample_idx += 1

            # Print running metrics every 10 samples
            if sample_idx % 10 == 0:
                running_dice = np.mean(scores['dice'])
                print(f"   Processed {sample_idx} samples... (Dice: {running_dice:.4f})")

    if sample_idx == 0:
        # np.mean([]) would be NaN and the grid code cannot draw zero rows.
        print("⚠️  No test samples available - skipping metrics and visualizations")
        empty_summary = {}
        for metric in metric_names:
            empty_summary[f'mean_{metric}'] = 0.0
            empty_summary[f'std_{metric}'] = 0.0
        empty_summary['num_samples'] = 0
        return empty_summary

    # Calculate final statistics
    metrics_summary = {}
    for metric in metric_names:
        metrics_summary[f'mean_{metric}'] = float(np.mean(scores[metric]))
        metrics_summary[f'std_{metric}'] = float(np.std(scores[metric]))
    metrics_summary['num_samples'] = sample_idx

    print("\n" + "="*80)
    print(f"✅ SAVED {sample_idx} PREDICTION MASKS")
    print("="*80)

    print(f"\n📊 Overall Metrics:")
    print(f"   Mean Dice:      {metrics_summary['mean_dice']:.4f} ± {metrics_summary['std_dice']:.4f}")
    print(f"   Mean IoU:       {metrics_summary['mean_iou']:.4f} ± {metrics_summary['std_iou']:.4f}")
    print(f"   Mean Precision: {metrics_summary['mean_precision']:.4f} ± {metrics_summary['std_precision']:.4f}")
    print(f"   Mean Recall:    {metrics_summary['mean_recall']:.4f} ± {metrics_summary['std_recall']:.4f}")

    # Create visualization grids
    print("\n" + "="*80)
    print("🎨 CREATING VISUALIZATION GRIDS")
    print("="*80)

    create_prediction_visualizations(model, test_gen, vis_dir, cfg)

    return metrics_summary

def _pick_representative_samples(ranked, max_count=12):
    """Pick up to ``max_count`` samples from ``ranked`` (ascending Dice): worst, median, best."""
    total = len(ranked)
    count = min(max_count, total)
    if count == total:
        # Few enough samples to show every one, already ordered worst -> best.
        return list(ranked)

    tail = min(3, total // 4, count // 4)
    num_median = count - 2 * tail
    median_start = total // 2 - num_median // 2
    # Keep the median window strictly between the two tails (no duplicates).
    median_start = max(tail, min(median_start, total - tail - num_median))

    return (
        ranked[:tail] +                                    # Worst
        ranked[median_start:median_start + num_median] +   # Median
        ranked[total - tail:]                              # Best (empty when tail == 0)
    )


def create_prediction_visualizations(model, test_gen, vis_dir, cfg):
    """Create a grid of overlay images for representative test samples.

    All batches of ``test_gen`` are predicted and scored with a per-sample
    Dice (threshold 0.5), then sorted by that score. Up to 12 samples are
    drawn: the worst, a window around the median, and the best. If there are
    12 or fewer samples, all of them are drawn. The first six selected samples
    are passed on to ``create_individual_comparisons``.

    Args:
        model: Object with a Keras-style ``predict(images, verbose=0)`` method.
        test_gen: Indexable batch source; ``test_gen[i]`` yields ``(images, masks)``.
        vis_dir: Directory that receives ``prediction_grid.png``.
        cfg: Configuration object providing ``ABLATION_CONFIG['name']``.

    Returns:
        None. Nothing is drawn when the generator yields no samples.
    """
    # Collect all samples with metrics
    samples_with_metrics = []
    for batch_idx in range(len(test_gen)):
        images, masks = test_gen[batch_idx]
        predictions = model.predict(images, verbose=0)

        for img, true_mask, pred_mask in zip(images, masks, predictions):
            pred_binary = (pred_mask > 0.5).astype(np.float32)
            intersection = np.sum(true_mask * pred_binary)
            dice = (2. * intersection + 1e-6) / (np.sum(true_mask) + np.sum(pred_binary) + 1e-6)

            samples_with_metrics.append({
                'image': img,
                'true_mask': true_mask,
                'pred_mask': pred_mask,
                'dice': dice
            })

    if not samples_with_metrics:
        # A zero-row GridSpec is invalid, so there is nothing to draw.
        print("   ⚠️  No samples available - skipping visualization grids")
        return

    # Sort by dice score (worst first)
    samples_with_metrics.sort(key=lambda sample: sample['dice'])

    # The count comes from the samples actually collected; the last batch of
    # a generator may be partial.
    selected_samples = _pick_representative_samples(samples_with_metrics, max_count=12)

    # Create grid visualization
    cols = 4
    rows = (len(selected_samples) + cols - 1) // cols

    fig = plt.figure(figsize=(20, 5 * rows))
    gs = gridspec.GridSpec(rows, cols, figure=fig, hspace=0.3, wspace=0.2)

    for idx, sample in enumerate(selected_samples):
        ax = fig.add_subplot(gs[idx // cols, idx % cols])

        img = sample['image']
        true_mask = sample['true_mask'].squeeze()
        pred_mask = sample['pred_mask'].squeeze()
        pred_binary = (pred_mask > 0.5).astype(np.float32)

        # Create RGB overlay
        # NOTE(review): the +0.3 / clip(0, 1) tinting assumes float images in
        # [0, 1] with three channels - confirm against the generator.
        overlay = img.copy()

        # True mask in green
        overlay[:, :, 1] = np.where(true_mask > 0.5,
                                    np.clip(overlay[:, :, 1] + 0.3, 0, 1),
                                    overlay[:, :, 1])

        # Prediction: red where it agrees with the mask, blue where it does not
        correct_pred = pred_binary * true_mask
        incorrect_pred = pred_binary * (1 - true_mask)

        overlay[:, :, 0] = np.where(correct_pred > 0.5,
                                    np.clip(overlay[:, :, 0] + 0.3, 0, 1),
                                    overlay[:, :, 0])
        overlay[:, :, 2] = np.where(incorrect_pred > 0.5,
                                    np.clip(overlay[:, :, 2] + 0.3, 0, 1),
                                    overlay[:, :, 2])

        ax.imshow(overlay)
        ax.set_title(f"Sample {idx + 1}\nDice: {sample['dice']:.4f}",
                     fontsize=10, fontweight='bold')
        ax.axis('off')

    # Add legend
    legend_elements = [
        Patch(facecolor='green', alpha=0.5, label='Ground Truth'),
        Patch(facecolor='red', alpha=0.5, label='Correct Prediction'),
        Patch(facecolor='blue', alpha=0.5, label='False Positive')
    ]
    fig.legend(handles=legend_elements, loc='lower center', ncol=3,
               fontsize=12, frameon=False, bbox_to_anchor=(0.5, -0.02))

    plt.suptitle(f'Prediction Visualizations - {cfg.ABLATION_CONFIG["name"]}',
                 fontsize=16, fontweight='bold', y=0.995)

    vis_path = os.path.join(vis_dir, 'prediction_grid.png')
    plt.savefig(vis_path, dpi=150, bbox_inches='tight')
    plt.close()

    print(f"   ✅ Saved visualization grid: {vis_path}")

    # Create individual comparison grids
    create_individual_comparisons(selected_samples[:6], vis_dir, cfg)

def create_individual_comparisons(samples, vis_dir, cfg):
    """Create a detailed four-panel comparison row for each sample.

    Panels per row: original image, ground-truth overlay, thresholded
    prediction overlay (with Dice in the title) and a difference map
    (TP white, FN red, FP blue). The figure is saved as
    ``detailed_comparisons.png`` inside ``vis_dir``.

    Args:
        samples: Sequence of dicts with keys ``'image'``, ``'true_mask'``,
            ``'pred_mask'`` and ``'dice'``.
        vis_dir: Output directory for the figure.
        cfg: Configuration object providing ``ABLATION_CONFIG['name']``.

    Returns:
        None. Nothing is written when ``samples`` is empty.
    """
    if len(samples) == 0:
        # plt.subplots cannot create a grid with zero rows.
        print("   ⚠️  No samples available - skipping detailed comparisons")
        return

    # squeeze=False keeps ``axes`` two-dimensional even for a single row.
    fig, axes = plt.subplots(len(samples), 4, figsize=(16, 4 * len(samples)),
                             squeeze=False)

    for idx, sample in enumerate(samples):
        img = sample['image']
        true_mask = sample['true_mask'].squeeze()
        pred_mask = sample['pred_mask'].squeeze()
        pred_binary = (pred_mask > 0.5).astype(np.float32)

        # Original image
        axes[idx, 0].imshow(img)
        axes[idx, 0].set_title('Original Image', fontsize=10)
        axes[idx, 0].axis('off')

        # Ground truth
        axes[idx, 1].imshow(img)
        axes[idx, 1].imshow(true_mask, alpha=0.5, cmap='Greens', vmin=0, vmax=1)
        axes[idx, 1].set_title('Ground Truth', fontsize=10)
        axes[idx, 1].axis('off')

        # Prediction
        axes[idx, 2].imshow(img)
        axes[idx, 2].imshow(pred_binary, alpha=0.5, cmap='Reds', vmin=0, vmax=1)
        axes[idx, 2].set_title(f'Prediction\nDice: {sample["dice"]:.4f}', fontsize=10)
        axes[idx, 2].axis('off')

        # Difference map
        difference = np.zeros((*true_mask.shape, 3))
        # True Positive (White): same value on all three channels
        tp = true_mask * pred_binary
        difference += tp[..., np.newaxis]
        # False Negative (Red)
        fn = true_mask * (1 - pred_binary)
        difference[:, :, 0] += fn
        # False Positive (Blue)
        fp = (1 - true_mask) * pred_binary
        difference[:, :, 2] += fp

        axes[idx, 3].imshow(difference)
        axes[idx, 3].set_title('Difference\n(TP:White FN:Red FP:Blue)', fontsize=9)
        axes[idx, 3].axis('off')

    plt.suptitle(f'Detailed Predictions - {cfg.ABLATION_CONFIG["name"]}',
                 fontsize=14, fontweight='bold')
    plt.tight_layout()

    detail_path = os.path.join(vis_dir, 'detailed_comparisons.png')
    plt.savefig(detail_path, dpi=150, bbox_inches='tight')
    plt.close()

    print(f"   ✅ Saved detailed comparisons: {detail_path}")

# ==============================================================================
# MAIN TRAINING LOOP
# ==============================================================================

def _results_json_default(value):
    """Convert NumPy scalars/arrays to plain Python values for ``json.dump``.

    Used as the ``default=`` hook when writing ``results.json`` so that metric
    values stored as NumPy types do not abort the save.

    Raises:
        TypeError: for any other type, so genuinely unserializable objects
            still fail loudly instead of being silently stringified.
    """
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    raise TypeError(
        f"Object of type {type(value).__name__} is not JSON serializable"
    )


def train_ablation_model(cfg, strategy, num_gpus):
    """Train, evaluate and report one ablation configuration.

    Steps: seed the run, load the dataset splits, build the train/val/test
    generators, build and compile the model inside ``strategy.scope()``,
    fit with checkpoint / early-stopping / CSV-log callbacks, evaluate on the
    test generator, save predictions and visualizations, and write
    ``results.json`` into ``cfg.SAVE_DIR``.

    Args:
        cfg: Configuration object. This function reads ``SEED``,
            ``DETERMINISTIC``, ``DATA_ROOT``, ``EPOCH_EXPANSION_FACTOR``,
            ``BATCH_SIZE``, ``ABLATION_ID``, ``ABLATION_CONFIG`` (a mapping
            with at least ``name`` and ``description``), ``LEARNING_RATE``,
            ``SAVE_DIR``, ``CHECKPOINT_MONITOR``, ``CHECKPOINT_MODE``,
            ``EARLY_STOPPING_PATIENCE`` and ``EPOCHS``.
        strategy: Object exposing a ``scope()`` context manager under which
            the model is built and compiled.
        num_gpus: Number of GPUs in use; only affects the printed messages.

    Returns:
        ``(model, history, results)`` on success, where ``results`` is the
        dict that was written to ``results.json``; ``(None, None, None)``
        when the training split is empty.
    """
    # The loss instance is published as a module global; presumably
    # ``masl_loss_fn`` reads it -- verify against its definition.
    global _masl_instance

    set_seed(cfg.SEED, cfg.DETERMINISTIC)

    # 1. LOAD DATA
    splits = load_dataset_split(cfg.DATA_ROOT)

    if not splits['train']:
        print("❌ No training data found!")
        print(f"   Checked: {cfg.DATA_ROOT}")
        return None, None, None

    print("\n📊 Dataset Statistics:")
    print(f"   Training samples:   {len(splits['train'])}")
    print(f"   Validation samples: {len(splits['val'])}")
    print(f"   Test samples:       {len(splits['test'])}")

    # 2. GENERATORS (DuckNet-style with virtual epochs)
    train_aug = get_ducknet_augmentation(cfg)
    val_aug = get_validation_augmentation(cfg)

    train_gen = DuckNetExpandedGenerator(
        splits['train'], cfg,
        augmentation=train_aug,
        shuffle=True,
        expansion_factor=cfg.EPOCH_EXPANSION_FACTOR
    )

    val_gen = DuckNetExpandedGenerator(
        splits['val'], cfg,
        augmentation=val_aug,
        shuffle=False,
        expansion_factor=1  # No expansion for validation
    )

    test_gen = DuckNetExpandedGenerator(
        splits['test'], cfg,
        augmentation=val_aug,
        shuffle=False,
        expansion_factor=1  # No expansion for test
    )

    print("\n📊 Training Configuration:")
    print(f"   Steps per Epoch: {len(train_gen)}")
    if num_gpus > 1:
        print(f"   Effective batch size: {cfg.BATCH_SIZE * num_gpus}")

    # 3. BUILD MODEL
    with strategy.scope():
        model = build_medsegnet_ssf(cfg)

        _masl_instance = MorphologyAwareAdaptiveLoss(cfg.ABLATION_CONFIG)

        optimizer = tf.keras.optimizers.Adam(
            learning_rate=cfg.LEARNING_RATE, clipnorm=1.0
        )

        print(f"\n📊 Loss Function: {cfg.ABLATION_CONFIG['name']}")
        print(f"   {cfg.ABLATION_CONFIG['description']}")

        model.compile(
            optimizer=optimizer,
            loss=masl_loss_fn,
            metrics=[dice_coefficient, iou_score, precision_metric, recall_metric]
        )

    # 4. CALLBACKS
    # The checkpoint, CSV log and results.json all live in SAVE_DIR, so make
    # sure it exists before any callback tries to write into it.
    os.makedirs(cfg.SAVE_DIR, exist_ok=True)

    # NOTE(review): depending on the Keras version, restore_best_weights may
    # only take effect when early stopping actually triggers; confirm the
    # weights evaluated below match best_model.h5 when all epochs run.
    callbacks = [
        ModelCheckpoint(
            os.path.join(cfg.SAVE_DIR, "best_model.h5"),
            monitor=cfg.CHECKPOINT_MONITOR, mode=cfg.CHECKPOINT_MODE,
            save_best_only=True, verbose=1
        ),
        EarlyStopping(
            monitor=cfg.CHECKPOINT_MONITOR, mode=cfg.CHECKPOINT_MODE,
            patience=cfg.EARLY_STOPPING_PATIENCE, verbose=1, restore_best_weights=True
        ),
        CSVLogger(os.path.join(cfg.SAVE_DIR, "training_log.csv"))
    ]

    # 5. TRAIN
    gpu_info = f" on {num_gpus} GPU(s)" if num_gpus > 0 else " on CPU"
    print(f"\n🚀 STARTING TRAINING ({cfg.EPOCHS} EPOCHS){gpu_info}")
    print("="*80)
    start_time = time.time()

    history = model.fit(
        train_gen,
        validation_data=val_gen,
        epochs=cfg.EPOCHS,
        callbacks=callbacks,
        verbose=1
    )

    training_time = time.time() - start_time
    print(f"\n✅ Training finished in {training_time/60:.1f} minutes")

    # 6. EVALUATE
    print("\n" + "="*80)
    print("📊 EVALUATING ON TEST SET")
    print("="*80)
    # return_dict=True yields {metric_name: value} directly. Zipping
    # model.metrics_names with the positional result list silently drops
    # metrics whenever the two sequences differ in length.
    test_results = model.evaluate(test_gen, verbose=1, return_dict=True)

    # 7. SAVE PREDICTIONS AND CREATE VISUALIZATIONS
    prediction_metrics = save_predictions_and_visualizations(model, test_gen, cfg)

    # 8. SAVE RESULTS
    results = {
        "ablation_id": cfg.ABLATION_ID,
        "ablation_name": cfg.ABLATION_CONFIG['name'],
        "ablation_description": cfg.ABLATION_CONFIG['description'],
        "ablation_config": cfg.ABLATION_CONFIG,
        "training_time_minutes": training_time / 60,
        "total_parameters": int(model.count_params()),
        "test_results": {name: float(value) for name, value in test_results.items()},
        "prediction_metrics": prediction_metrics  # Add detailed metrics
    }

    results_path = os.path.join(cfg.SAVE_DIR, "results.json")
    with open(results_path, "w", encoding="utf-8") as f:
        # NumPy scalars/arrays inside the metrics are converted by the hook;
        # anything else unserializable still raises TypeError.
        json.dump(results, f, indent=2, default=_results_json_default)

    print("\n✅ Results saved!")
    print("\n📊 TEST RESULTS:")
    for name, value in results['test_results'].items():
        print(f"   {name:20s}: {value:.4f}")

    print(f"\n📁 Outputs saved to: {cfg.SAVE_DIR}/")
    print("   ├─ best_model.h5")
    print("   ├─ training_log.csv")
    print("   ├─ results.json")
    print(f"   ├─ predictions/ ({prediction_metrics['num_samples']} masks)")
    print("   └─ visualizations/ (2 figures)")

    return model, history, results

# ==============================================================================
# MAIN EXECUTION
# ==============================================================================

def main():
    """Run a single ablation experiment and print a closing summary.

    Builds the ``Config``, sets up the GPU strategy, trains via
    ``train_ablation_model`` and, when results were produced, prints the
    ablation name, prediction Dice (mean ± std), parameter count and
    training time. Prints no summary when training returned no results.
    """
    banner = "=" * 80

    print("\n" + banner)
    print("🔥 MASL TRAINING WITH ABLATION STUDIES")
    print(banner)

    # ABLATION_ID is selected inside the Config class itself.
    config = Config()
    strategy, num_gpus = setup_gpus(config.GPU_NUMBERS)
    _model, _history, results = train_ablation_model(config, strategy, num_gpus)

    # Nothing to summarize if training bailed out.
    if results is None:
        return

    print("\n" + banner)
    print("🎉 ABLATION STUDY COMPLETE!")
    print(banner)
    print(f"\n📊 Summary:")
    print(f"   Ablation: {config.ABLATION_CONFIG['name']}")
    # NOTE: the evaluate()-based test Dice line is disabled in this summary;
    # only the per-sample prediction Dice is printed.
    pred_metrics = results['prediction_metrics']
    print(f"   Prediction Dice: {pred_metrics['mean_dice']:.4f} ± {pred_metrics['std_dice']:.4f}")
    print(f"   Parameters: {results['total_parameters']:,}")
    print(f"   Time: {results['training_time_minutes']:.1f} min")

# Script entry point: run the experiment only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()