#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
================================================================================
🔥 ENDOVIS17 - COMPLETE: MULTIPLE ARCHITECTURES + ALL FIXES
================================================================================
FINAL VERSION - ALL DIMENSION ISSUES RESOLVED
================================================================================
"""

import os
import sys
import time
import json
from glob import glob
from pathlib import Path

import numpy as np
import cv2
import tensorflow as tf
from tensorflow.keras.layers import (
    Input, Conv2D, DepthwiseConv2D, GlobalAveragePooling2D,
    Reshape, Multiply, BatchNormalization, Activation, Dropout,
    Concatenate, Add, UpSampling2D, LayerNormalization, Dense, Layer,
    MaxPooling2D, Conv2DTranspose, MultiHeadAttention, Embedding
)
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.callbacks import (
    ModelCheckpoint, EarlyStopping, CSVLogger, ReduceLROnPlateau
)
import albumentations as A
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.patches import Patch

# Ask TensorFlow's C++ runtime to hide INFO and WARNING log lines.
# NOTE(review): this runs after `import tensorflow` above, so it probably does
# not silence messages logged while TensorFlow was importing. Move it above the
# TensorFlow import if that is the intent.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

# ==============================================================================
# CONFIGURATION
# ==============================================================================

class Config:
    """Experiment-wide settings for EndoVis17 instrument-part segmentation.

    Everything except ``SAVE_DIR`` is a class-level constant. Creating an
    instance does three things: derives ``SAVE_DIR`` from ``ARCHITECTURE``,
    creates that directory, and prints the architecture and loss type.

    ``CLASS_NAMES`` and ``CLASS_COLORS`` are recomputed from ``NUM_CLASSES`` on
    every access, so they follow any later reassignment of ``NUM_CLASSES``.
    """

    # --- Hardware -------------------------------------------------------------
    GPU_NUMBERS = [0]

    # --- Paths ----------------------------------------------------------------
    DATA_ROOT = "/kaggle/input/endovis-17/PartsSegmentation"
    TRAIN_DIR = os.path.join(DATA_ROOT, "train")
    VAL_DIR = os.path.join(DATA_ROOT, "val")
    TEST_DIR = os.path.join(DATA_ROOT, "test")
    # Placeholder only: __init__ replaces it with an architecture-specific path.
    SAVE_DIR = "/kaggle/working/ENDOVIS17_COMPLETE_FIXED"

    # --- Labels ---------------------------------------------------------------
    NUM_CLASSES = 6

    @property
    def CLASS_NAMES(self):
        """Label per class index: "Class 0", then "Instrument Part <i>"."""
        labels = []
        for idx in range(self.NUM_CLASSES):
            labels.append(f"Class {idx}" if idx == 0 else f"Instrument Part {idx}")
        return labels

    @property
    def CLASS_COLORS(self):
        """RGB palette (uint8, one row per class) with black for index 0.

        Classes 1.. are spread evenly over the hue circle with saturation 0.9
        and value 0.95.
        """
        import colorsys
        denom = max(self.NUM_CLASSES - 1, 1)
        palette = [[0, 0, 0]]
        palette.extend(
            [int(channel * 255)
             for channel in colorsys.hsv_to_rgb((k - 1) / denom, 0.9, 0.95)]
            for k in range(1, self.NUM_CLASSES)
        )
        return np.array(palette, dtype=np.uint8)

    # --- Architecture selection -----------------------------------------------
    ARCHITECTURE = "transunet"  # "medsegnet_ssf", "umamba", "transunet", "pranet"

    # --- MedSegNet-SSF --------------------------------------------------------
    INPUT_SIZE = 512
    F1, F2, F3, F4, F5 = 24, 32, 64, 80, 128
    USE_MRF_SE = True
    USE_SSTM = True
    USE_BFP = True
    MRF_KERNELS = [3, 5, 7]
    SE_REDUCTION = 16
    EXPAND_RATIO = 6
    DROPOUT = 0.15
    L2_REG = 1e-4
    SSTM_NUM_FREQUENCIES = 32
    SSTM_SSM_STATE_DIM = 16
    SSTM_USE_SPECTRAL = [True, True, True, True, True]
    SSTM_USE_SSM = [False, False, True, True, True]
    SSTM_DROPOUT = 0.1

    # --- Baseline architectures -----------------------------------------------
    UMAMBA_CHANNELS = [32, 64, 128, 256, 512]
    UMAMBA_SSM_DIM = 16
    TRANSUNET_NUM_HEADS = 4
    TRANSUNET_TRANSFORMER_LAYERS = 2
    PRANET_CHANNELS = [32, 64, 128, 256]

    # --- Training -------------------------------------------------------------
    BATCH_SIZE = 2
    EPOCH_EXPANSION_FACTOR = 30
    EPOCHS = 30
    LEARNING_RATE = 1e-4
    USE_CLASS_WEIGHTS = True
    USE_MIXED_PRECISION = True
    EARLY_STOPPING_PATIENCE = 10
    CHECKPOINT_MONITOR = "val_overall_dice"
    CHECKPOINT_MODE = "max"
    SEED = 42
    NUM_VIS_SAMPLES = 5

    # --- Loss -----------------------------------------------------------------
    LOSS_TYPE = "dice_focal"  # "dice_focal", "multiclass_masl"
    DICE_WEIGHT = 0.5
    FOCAL_WEIGHT = 0.5
    FOCAL_ALPHA = 0.25
    FOCAL_GAMMA = 2.0

    def __init__(self):
        arch_tag = self.ARCHITECTURE.upper()
        self.SAVE_DIR = f"/kaggle/working/ENDOVIS17_{arch_tag}"
        os.makedirs(self.SAVE_DIR, exist_ok=True)
        print(f"🔥 Architecture: {arch_tag}")
        print(f"   Loss: {self.LOSS_TYPE.upper()}")

# Module-level configuration instance. Constructing it creates the
# architecture-specific SAVE_DIR and prints the chosen architecture and loss.
config = Config()

# ==============================================================================
# GPU & SEED
# ==============================================================================

def setup_gpus(gpu_numbers=None, use_mixed_precision=True):
    """Select visible GPUs, enable memory growth and optionally mixed precision.

    Args:
        gpu_numbers: Indices into ``tf.config.list_physical_devices('GPU')``.
            ``None`` or an empty list selects every detected GPU. Duplicates
            are removed. Indices outside ``[0, number_of_gpus)`` are ignored
            with a warning. If none of the requested indices exist, all
            detected GPUs are used instead of silently hiding every GPU.
        use_mixed_precision: When True and GPU configuration succeeds, set the
            global Keras policy to ``mixed_float16``.

    Returns:
        ``(strategy, num_gpus)``. ``strategy`` is a ``MirroredStrategy`` when
        more than one GPU is selected, otherwise the default strategy.
        ``num_gpus`` is the number of configured GPUs; it is 0 when no GPU is
        found or configuration failed.
    """
    gpus = tf.config.list_physical_devices('GPU')
    if not gpus:
        print("⚠️ No GPU detected - using the default (CPU) strategy")
        return tf.distribute.get_strategy(), 0

    if gpu_numbers:
        requested = list(dict.fromkeys(gpu_numbers))  # de-duplicate, keep order
        valid = [i for i in requested if 0 <= i < len(gpus)]
        invalid = [i for i in requested if not 0 <= i < len(gpus)]
        if invalid:
            print(f"⚠️ Ignoring unavailable GPU indices {invalid} "
                  f"({len(gpus)} GPU(s) detected)")
        if valid:
            selected_gpus = [gpus[i] for i in valid]
        else:
            # An empty selection would hide every GPU and quietly run on CPU.
            print("⚠️ None of the requested GPUs exist - using all detected GPUs")
            selected_gpus = gpus
    else:
        selected_gpus = gpus

    try:
        tf.config.set_visible_devices(selected_gpus, 'GPU')
        for gpu in selected_gpus:
            tf.config.experimental.set_memory_growth(gpu, True)

        if use_mixed_precision:
            from tensorflow.keras import mixed_precision
            policy = mixed_precision.Policy('mixed_float16')
            mixed_precision.set_global_policy(policy)
            print("✅ Mixed precision enabled")

        num_gpus = len(selected_gpus)
        strategy = tf.distribute.MirroredStrategy() if num_gpus > 1 else tf.distribute.get_strategy()
        print(f"✅ Using {num_gpus} GPU(s)")
        return strategy, num_gpus
    except RuntimeError as err:
        # TensorFlow raises RuntimeError e.g. when devices were already
        # initialised. Keep running with its defaults, but say so.
        print(f"⚠️ GPU setup failed ({err}); using the default strategy")
        return tf.distribute.get_strategy(), 0

def set_seed(seed=42):
    """Seed the global RNGs of NumPy, Python's ``random`` and TensorFlow.

    ``PYTHONHASHSEED`` is also exported. It only affects hash randomisation in
    Python processes started after this call, not the running interpreter.
    """
    import random

    for seeder in (np.random.seed, random.seed, tf.random.set_seed):
        seeder(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

# Configure GPUs once at import time, using the module-level config.
# `strategy` is the tf.distribute strategy returned by setup_gpus;
# `num_gpus` is the number of GPUs it configured (0 on CPU or on failure).
strategy, num_gpus = setup_gpus(config.GPU_NUMBERS, config.USE_MIXED_PRECISION)

# ==============================================================================
# DATA LOADING
# ==============================================================================

def get_image_mask_pairs(images_dir, masks_dir):
    """Pair every ``.png``/``.jpg`` image in *images_dir* with its mask.

    The mask is looked up first under the image's exact file name. If that
    file does not exist, a ``.png`` mask with the same stem is accepted, since
    masks are usually stored losslessly even when frames are JPEG. Images with
    no mask are skipped.

    Args:
        images_dir: directory containing the frames.
        masks_dir: directory containing the label masks.

    Returns:
        list of ``(image_path, mask_path)`` tuples, ordered by image path.
    """
    image_files = sorted(
        glob(os.path.join(images_dir, "*.png")) + glob(os.path.join(images_dir, "*.jpg"))
    )
    pairs = []
    for img_path in image_files:
        img_name = Path(img_path)
        candidates = (
            os.path.join(masks_dir, img_name.name),         # exact same file name
            os.path.join(masks_dir, img_name.stem + ".png"),  # lossless mask for a JPEG frame
        )
        mask_path = next((c for c in candidates if os.path.exists(c)), None)
        if mask_path is not None:
            pairs.append((img_path, mask_path))
    return pairs

def load_dataset_split(split_dir):
    """Return the image/mask pairs of a split directory.

    The split is expected to contain ``images/`` and ``masks/`` subfolders.
    """
    images_dir, masks_dir = (os.path.join(split_dir, sub) for sub in ("images", "masks"))
    return get_image_mask_pairs(images_dir, masks_dir)

def create_global_pixel_mapping(all_splits):
    """Build one raw-pixel -> class-index mapping shared by all splits.

    Every readable grayscale mask in every split is scanned; unreadable masks
    are skipped. The union of distinct pixel values is sorted, and each value
    maps to its position in that sorted list.

    Args:
        all_splits: ``{split_name: [(image_path, mask_path), ...]}``.

    Returns:
        ``(mapping, num_classes)`` with ``num_classes == len(mapping)``.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("🔥 CREATING GLOBAL PIXEL MAPPING")
    print(rule)

    def _distinct_values(pairs):
        # Union of np.unique() over every mask in one split.
        found = set()
        for _, mask_path in pairs:
            mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
            if mask is not None:
                found.update(np.unique(mask).tolist())
        return found

    seen = set()
    for split_name, pairs in all_splits.items():
        print(f"\nScanning {split_name} ({len(pairs)} images)...")
        split_values = _distinct_values(pairs)
        print(f"   Found pixels: {sorted(split_values)}")
        seen |= split_values

    ordered = sorted(seen)
    global_mapping = dict(zip(ordered, range(len(ordered))))
    num_classes = len(ordered)

    print("\n" + rule)
    print("✅ GLOBAL MAPPING CREATED")
    print(rule)
    print(f"   Pixels: {ordered}")
    print(f"   Mapping: {global_mapping}")
    print(f"   NUM_CLASSES: {num_classes}")
    print(rule)

    return global_mapping, num_classes

def preprocess_mask_with_mapping(mask, pixel_mapping):
    """Translate raw mask intensities into class indices.

    Args:
        mask: array of raw pixel values.
        pixel_mapping: ``{raw_value: class_index}``.

    Returns:
        ``uint8`` array shaped like ``mask``. Pixels whose value is not a key of
        ``pixel_mapping`` become 0.
    """
    class_map = np.zeros_like(mask, dtype=np.uint8)
    for raw_value in pixel_mapping:
        selected = mask == raw_value
        class_map[selected] = pixel_mapping[raw_value]
    return class_map

def calculate_class_weights(pairs, cfg, pixel_mapping, num_samples=100):
    """Estimate inverse-frequency class weights from a sample of masks.

    Up to ``num_samples`` masks are read, spread evenly over ``pairs``. Taking
    only the first N would bias the estimate, because filename-sorted frames
    all come from the first sequence(s). ``num_samples=None`` uses every
    pair. Masks are remapped with ``pixel_mapping`` and counted per class;
    class indices outside ``[0, cfg.NUM_CLASSES)`` are ignored.

    For each class present in the sample, the weight is
    ``total / (NUM_CLASSES * count)``. The non-zero weights are then rescaled
    to sum to ``NUM_CLASSES``. Absent classes get weight 0.

    Returns:
        ``(class_weights, class_counts)``: float32 and int64 arrays of length
        ``cfg.NUM_CLASSES``.
    """
    print("\n🔍 Calculating class weights...")

    num_classes = cfg.NUM_CLASSES
    class_counts = np.zeros(num_classes, dtype=np.int64)

    if num_samples is None:
        n_pick = len(pairs)
    else:
        n_pick = max(0, min(int(num_samples), len(pairs)))
    if n_pick:
        # Evenly spaced indices. With n_pick <= len(pairs) the spacing is >= 1,
        # so the rounded indices are distinct.
        picked = np.linspace(0, len(pairs) - 1, n_pick).round().astype(np.int64)
    else:
        picked = np.empty(0, dtype=np.int64)

    for idx in picked:
        _, mask_path = pairs[int(idx)]
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            continue

        mask = preprocess_mask_with_mapping(mask, pixel_mapping)
        # One pass per mask instead of one comparison per class. Slicing drops
        # any index >= num_classes, as the per-class range did before.
        counts = np.bincount(mask.ravel(), minlength=num_classes)[:num_classes]
        class_counts += counts.astype(np.int64)

    total_pixels = np.sum(class_counts)
    class_weights = np.zeros(num_classes, dtype=np.float32)
    present = class_counts > 0
    class_weights[present] = total_pixels / (num_classes * class_counts[present])

    non_zero_sum = np.sum(class_weights[class_weights > 0])
    if non_zero_sum > 0:
        class_weights = class_weights / non_zero_sum * num_classes
    else:
        print("⚠️ No mask pixels were counted - all class weights are 0")

    print(f"\n📊 Class Distribution:")
    for class_idx in range(num_classes):
        if class_counts[class_idx] > 0:
            pct = (class_counts[class_idx] / total_pixels) * 100
            print(f"   Class {class_idx}: {class_counts[class_idx]:>10,} pixels ({pct:>5.2f}%), weight={class_weights[class_idx]:.4f}")

    return class_weights, class_counts

# ==============================================================================
# AUGMENTATION
# ==============================================================================

def get_surgical_augmentation(cfg):
    """Training-time albumentations pipeline; the generator passes it an image and its mask together.

    The pipeline applies geometric perturbations (horizontal flip, ±15°
    rotation), then photometric ones (brightness/contrast, Gaussian blur,
    Gaussian noise). It ends with a resize to a ``cfg.INPUT_SIZE`` square.
    """
    size = cfg.INPUT_SIZE
    transforms = [
        # geometric
        A.HorizontalFlip(p=0.5),
        A.Rotate(limit=15, p=0.7),
        # photometric
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.8),
        A.GaussianBlur(blur_limit=(3, 5), p=0.2),
        A.GaussNoise(p=0.3),
        # fixed network input size
        A.Resize(height=size, width=size),
    ]
    return A.Compose(transforms, p=1.0)

def get_validation_augmentation(cfg):
    """Evaluation pipeline: only resize to a ``cfg.INPUT_SIZE`` square."""
    side = cfg.INPUT_SIZE
    resize = A.Resize(side, side)
    return A.Compose([resize])

# ==============================================================================
# DATA GENERATOR
# ==============================================================================

class EndoVis17Generator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` yielding ``(images, one_hot_masks)`` float32 batches.

    Each sample is prepared as follows:
      - the image is read with OpenCV and converted BGR -> RGB;
      - the mask is read as grayscale and remapped with ``pixel_mapping``;
      - the optional albumentations ``augmentation`` is applied to both;
      - the image is scaled to [0, 1];
      - the mask is one-hot encoded to ``cfg.NUM_CLASSES`` channels.

    ``expansion_factor`` repeats the data within one Keras epoch. ``len()`` is
    ``real_batches * expansion_factor``, and batch ``i`` is real batch
    ``i % real_batches``.

    The trailing partial batch is dropped, except when the split is smaller
    than one batch; then a single partial batch is used. Samples whose image
    or mask cannot be read are skipped, so a batch can be smaller than
    ``cfg.BATCH_SIZE``.
    """

    def __init__(self, pairs, cfg, pixel_mapping, augmentation=None, shuffle=True, expansion_factor=1):
        # Keras 3 aliases Sequence to PyDataset, whose __init__ should run
        # (it warns otherwise). For the Keras 2 Sequence this is a no-op.
        super().__init__()
        self.pairs = pairs
        self.cfg = cfg
        self.pixel_mapping = pixel_mapping
        self.augmentation = augmentation
        self.shuffle = shuffle
        self.expansion_factor = expansion_factor
        self.indices = np.arange(len(self.pairs))

        # A non-empty split smaller than one batch still yields one partial
        # batch. Otherwise the generator would be empty, and __getitem__ would
        # divide by zero.
        full_batches = len(self.pairs) // self.cfg.BATCH_SIZE
        self.real_batches = max(full_batches, 1) if len(self.pairs) else 0
        self.virtual_batches = self.real_batches * self.expansion_factor

        if self.shuffle:
            np.random.shuffle(self.indices)

    def __len__(self):
        return self.virtual_batches

    def __getitem__(self, index):
        if self.real_batches == 0:
            raise IndexError("EndoVis17Generator has no image/mask pairs")

        real_index_ptr = index % self.real_batches
        batch_start = real_index_ptr * self.cfg.BATCH_SIZE
        batch_end = batch_start + self.cfg.BATCH_SIZE
        batch_indices = self.indices[batch_start:batch_end]

        images, masks = [], []
        for idx in batch_indices:
            img_path, mask_path = self.pairs[idx]

            image = cv2.imread(img_path)
            if image is None:
                continue
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
            if mask is None:
                continue

            mask = preprocess_mask_with_mapping(mask, self.pixel_mapping)

            if self.augmentation:
                augmented = self.augmentation(image=image, mask=mask)
                image = augmented["image"]
                mask = augmented["mask"]

            image = image.astype(np.float32) / 255.0
            # Row lookup in the identity matrix == one-hot encoding.
            mask_one_hot = np.eye(self.cfg.NUM_CLASSES, dtype=np.float32)[mask.astype(np.int32)]

            images.append(image)
            masks.append(mask_one_hot)

        return np.array(images, dtype=np.float32), np.array(masks, dtype=np.float32)

    def on_epoch_end(self):
        if self.shuffle:
            np.random.shuffle(self.indices)

# ==============================================================================
# ARCHITECTURE 1: MEDSEGNET-SSF
# ==============================================================================

class SpectralSelectiveTokenMixer(Layer):
    """Residual token mixer with a spectral-filter path and a gated-projection path.

    Args:
        channels: channel count of the input; both paths output this many.
        num_frequencies: side length of the learned square frequency filter.
            ``build`` clips it to the static input height/width when known.
        ssm_state_dim: stored and serialised; not used by the computation.
        use_spectral: enable ``spectral_path``.
        use_ssm: enable ``ssm_path``, a sigmoid-gated per-token ``Dense``
            projection. No recurrent/state-space scan is performed.
        dropout: dropout rate applied to the mixed features during training.

    ``call`` returns ``x + dropout(norm(fused))``. ``fused`` is one of:
      - the single enabled path;
      - a ``Dense`` + LayerNorm fusion of both paths;
      - ``x`` itself when both paths are disabled.
    """

    def __init__(self, channels, num_frequencies=32, ssm_state_dim=16,
                 use_spectral=True, use_ssm=True, dropout=0.0, **kwargs):
        super().__init__(**kwargs)
        self.channels = channels
        self.num_frequencies = num_frequencies
        self.ssm_state_dim = ssm_state_dim
        self.use_spectral = use_spectral
        self.use_ssm = use_ssm
        self.dropout_rate = dropout

    def build(self, input_shape):
        input_h, input_w = input_shape[1], input_shape[2]
        # Filter side: at most num_frequencies and at most the static H and W.
        self.actual_frequencies = min(self.num_frequencies, input_h, input_w) if input_h else self.num_frequencies

        if self.use_spectral:
            self.freq_weights_real = self.add_weight(
                name='freq_weights_real',
                shape=(self.actual_frequencies, self.actual_frequencies, self.channels),
                initializer=self._get_initializer(),
                trainable=True
            )
            self.spectral_norm = LayerNormalization(epsilon=1e-6, name='spectral_norm')

        if self.use_ssm:
            self.ssm_C = Dense(self.channels, name='ssm_C')
            self.selection_gate = Dense(self.channels, activation='sigmoid', name='selection')
            self.ssm_norm = LayerNormalization(epsilon=1e-6, name='ssm_norm')

        if self.use_spectral and self.use_ssm:
            self.fusion = Dense(self.channels, name='fusion')
            self.fusion_norm = LayerNormalization(epsilon=1e-6, name='fusion_norm')

        self.norm = LayerNormalization(epsilon=1e-6, name='norm')
        super().build(input_shape)

    def _get_initializer(self):
        """Return the initialiser for the (F, F, C) frequency filter.

        The weights are ``0.5 * exp(-(|f| - 0.25)**2 / (2 * 0.15**2))`` over
        the unshifted ``np.fft.fftfreq`` grid, repeated across channels.
        """
        def init_fn(shape, dtype=None):
            H, W, C = shape
            freq_h = np.fft.fftfreq(H)[:, np.newaxis]
            freq_w = np.fft.fftfreq(W)[np.newaxis, :]
            freq_magnitude = np.sqrt(freq_h**2 + freq_w**2)
            gaussian = np.exp(-((freq_magnitude - 0.25)**2) / (2 * 0.15**2))
            gaussian = np.repeat(gaussian[:, :, np.newaxis], C, axis=2)
            return gaussian.astype(np.float32) * 0.5
        return init_fn

    def spectral_path(self, x):
        """Filter ``x`` (channels-last) in the 2-D spatial frequency domain.

        ``tf.signal.fft2d``/``ifft2d`` transform the two innermost axes.
        Applied directly to (B, H, W, C) they would transform (W, C), so the
        tensor is moved to (B, C, H, W) around each transform.

        The spectrum is bilinearly resized to ``freq_size`` x ``freq_size``,
        multiplied by the learned per-channel filter, resized back to H x W
        and inverted. Only the real part is kept before normalisation.
        """
        H, W = tf.shape(x)[1], tf.shape(x)[2]
        freq_size = tf.minimum(tf.minimum(H, W), self.actual_frequencies)

        x_complex = tf.cast(x, tf.complex64)
        # (B, H, W, C) -> (B, C, H, W): FFT over the spatial axes.
        x_freq = tf.signal.fft2d(tf.transpose(x_complex, [0, 3, 1, 2]))
        x_freq = tf.transpose(x_freq, [0, 2, 3, 1])  # back to (B, H, W, C)

        x_freq_real_resized = tf.image.resize(tf.math.real(x_freq), [freq_size, freq_size], method='bilinear')
        x_freq_imag_resized = tf.image.resize(tf.math.imag(x_freq), [freq_size, freq_size], method='bilinear')
        x_freq_resized = tf.complex(x_freq_real_resized, x_freq_imag_resized)

        freq_filter = tf.cast(self.freq_weights_real[:freq_size, :freq_size, :], tf.complex64)
        x_freq_filtered = x_freq_resized * freq_filter

        x_freq_back_real = tf.image.resize(tf.math.real(x_freq_filtered), [H, W], method='bilinear')
        x_freq_back_imag = tf.image.resize(tf.math.imag(x_freq_filtered), [H, W], method='bilinear')
        x_freq_back = tf.complex(x_freq_back_real, x_freq_back_imag)

        # Inverse FFT, again over the spatial axes.
        x_spatial = tf.signal.ifft2d(tf.transpose(x_freq_back, [0, 3, 1, 2]))
        x_spatial = tf.transpose(x_spatial, [0, 2, 3, 1])
        return self.spectral_norm(tf.math.real(x_spatial))

    def ssm_path(self, x):
        """Per-token sigmoid gate followed by a linear projection, then LayerNorm."""
        B, H, W = tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2]
        x_flat = tf.reshape(x, [B, H * W, self.channels])
        selection = self.selection_gate(x_flat)
        x_selected = x_flat * selection
        x_ssm = self.ssm_C(x_selected)
        return self.ssm_norm(tf.reshape(x_ssm, [B, H, W, self.channels]))

    def call(self, x, training=None):
        outputs = []
        if self.use_spectral:
            outputs.append(self.spectral_path(x))
        if self.use_ssm:
            outputs.append(self.ssm_path(x))

        if len(outputs) == 2:
            fused = self.fusion_norm(self.fusion(tf.concat(outputs, axis=-1)))
        elif len(outputs) == 1:
            fused = outputs[0]
        else:
            fused = x

        fused = self.norm(fused)
        if training and self.dropout_rate > 0:
            fused = tf.nn.dropout(fused, rate=self.dropout_rate)
        return x + fused

    def get_config(self):
        """Serialise constructor arguments so saved models can be reloaded."""
        config = super().get_config()
        config.update({
            "channels": self.channels,
            "num_frequencies": self.num_frequencies,
            "ssm_state_dim": self.ssm_state_dim,
            "use_spectral": self.use_spectral,
            "use_ssm": self.use_ssm,
            "dropout": self.dropout_rate,
        })
        return config

def MRF_SE_BLOCK(x, filters, activation='elu', dropout=0.0, expand_ratio=6,
                 regularizer=0.0, kernels=(3, 5, 7), se_reduction=16, name='mrf_se'):
    """Multi-receptive-field inverted-residual block with squeeze-and-excitation.

    Pipeline:
      1. 1x1 expansion to ``filters * expand_ratio`` channels (skipped when
         ``expand_ratio <= 1``), then BN + activation.
      2. One depthwise conv branch per kernel size in ``kernels``, each
         followed by BN + activation.
      3. Concatenation plus 1x1 fusion conv/BN/activation (only when there is
         more than one branch).
      4. Squeeze-and-excitation channel gating.
      5. 1x1 projection back to ``filters``, BN, optional dropout.
      6. Residual ``Add`` with ``x``.

    ``x`` must already have ``filters`` channels for the residual add.
    ``kernels`` defaults to a tuple so the default cannot be mutated between
    calls; any non-empty iterable of ints is accepted.

    Raises:
        ValueError: if ``kernels`` is empty.
    """
    kernels = list(kernels)
    if not kernels:
        raise ValueError("MRF_SE_BLOCK needs at least one kernel size")

    F_expanded = filters * expand_ratio

    def _reg():
        # A fresh L2 regularizer per layer, or None when regularization is off.
        return l2(regularizer) if regularizer > 0 else None

    # 1) Pointwise expansion.
    if expand_ratio > 1:
        conv = Conv2D(F_expanded, (1, 1), padding='same', kernel_initializer='he_uniform',
                      kernel_regularizer=_reg(), name=name + '_expand')(x)
    else:
        conv = x
    conv = BatchNormalization(name=name + '_expand_bn')(conv)
    conv = Activation(activation, name=name + '_expand_act')(conv)

    # 2) Parallel depthwise branches: one receptive field per kernel size.
    features = []
    for k in kernels:
        dw = DepthwiseConv2D((k, k), padding='same',
                             depthwise_initializer='he_uniform',
                             depthwise_regularizer=_reg(),
                             name=f"{name}_dw{k}x{k}")(conv)
        dw = BatchNormalization(name=f"{name}_dw{k}x{k}_bn")(dw)
        features.append(Activation(activation, name=f"{name}_dw{k}x{k}_act")(dw))

    # 3) Fuse the branches back to F_expanded channels.
    if len(features) > 1:
        combined = Concatenate(name=name + '_concat')(features)
        combined = Conv2D(F_expanded, (1, 1), padding='same',
                          kernel_initializer='he_uniform',
                          kernel_regularizer=_reg(),
                          name=name + '_fuse')(combined)
        combined = BatchNormalization(name=name + '_fuse_bn')(combined)
        combined = Activation(activation, name=name + '_fuse_act')(combined)
    else:
        combined = features[0]

    # 4) Squeeze-and-excitation: pool -> reduce -> expand -> sigmoid gate.
    gap = GlobalAveragePooling2D(name=name + '_gap')(combined)
    gap = Reshape((1, 1, F_expanded), name=name + '_reshape')(gap)
    se = Conv2D(max(F_expanded // se_reduction, 8), (1, 1),
                activation=activation,
                kernel_initializer='he_uniform',
                name=name + '_se_reduce')(gap)
    se = Conv2D(F_expanded, (1, 1), activation='sigmoid',
                kernel_initializer='he_uniform',
                name=name + '_se_expand')(se)
    excited = Multiply(name=name + '_se_mult')([combined, se])

    # 5) Project back to `filters` channels.
    projected = Conv2D(filters, (1, 1), padding='same',
                       kernel_initializer='he_uniform',
                       kernel_regularizer=_reg(),
                       name=name + '_project')(excited)
    projected = BatchNormalization(name=name + '_project_bn')(projected)

    if dropout > 0:
        projected = Dropout(dropout, name=name + '_dropout')(projected)

    # 6) Residual connection.
    return Add(name=name + '_add')([projected, x])

def boundary_detection_module(features, filters, name='boundary'):
    """Predict a 1-channel sigmoid boundary map and gate ``features`` with it.

    A 3x3 ReLU conv with ``filters // 2`` channels feeds a 1x1 sigmoid conv.
    The resulting map multiplies ``features``.

    Returns:
        ``(gated_features, boundary_map)``.
    """
    hidden = Conv2D(filters // 2, (3, 3), padding='same',
                    activation='relu', name=name + '_conv')(features)
    edge_map = Conv2D(1, (1, 1), padding='same',
                      activation='sigmoid', name=name + '_map')(hidden)
    gated = Multiply(name=name + '_mult')([features, edge_map])
    return gated, edge_map

def BFP_decoder_stage(decoder_input, skip_features, filters, stage_name='bfp'):
    """Boundary-aware decoder stage.

    1. Upsample ``decoder_input`` x2 and concatenate ``skip_features``.
    2. Two 3x3 conv-BN-ReLU layers produce region features.
    3. ``boundary_detection_module`` gates them with a sigmoid boundary map.
       A 3x3 conv-BN-ReLU then refines the gated features.
    4. Blend per pixel, ``region * (1 - b) + refined * b``, and fuse with a
       1x1 conv-BN-ReLU.

    Returns:
        ``(output, boundary_map)``.
    """
    def conv_bn_relu(tensor, kernel, layer_name):
        # Named conv followed by auto-named BatchNorm and ReLU.
        conv = Conv2D(filters, kernel, padding='same', name=layer_name)(tensor)
        return Activation('relu')(BatchNormalization()(conv))

    upsampled = UpSampling2D((2, 2), name=stage_name + '_up')(decoder_input)
    region = Concatenate(name=stage_name + '_concat')([upsampled, skip_features])
    region = conv_bn_relu(region, (3, 3), stage_name + '_region_conv1')
    region = conv_bn_relu(region, (3, 3), stage_name + '_region_conv2')

    boundary_features, boundary_map = boundary_detection_module(
        region, filters, stage_name + '_boundary')
    boundary_refined = conv_bn_relu(boundary_features, (3, 3), stage_name + '_boundary_refine')

    blended = region * (1 - boundary_map) + boundary_refined * boundary_map
    output = conv_bn_relu(blended, (1, 1), stage_name + '_fusion')

    return output, boundary_map

def _plain_decoder_stage(decoder_input, skip_features, filters, stage_name):
    """U-Net style decoder stage used when BFP is disabled.

    Upsample x2, concatenate the skip, then apply two 3x3 conv-BN-ReLU layers.
    """
    x = UpSampling2D((2, 2), name=stage_name + '_up')(decoder_input)
    x = Concatenate(name=stage_name + '_concat')([x, skip_features])
    for j in (1, 2):
        x = Conv2D(filters, (3, 3), padding='same', name=f'{stage_name}_conv{j}')(x)
        x = BatchNormalization(name=f'{stage_name}_bn{j}')(x)
        x = Activation('relu', name=f'{stage_name}_act{j}')(x)
    return x


def build_medsegnet_ssf(cfg):
    """Build the MedSegNet-SSF segmentation model.

    Encoder:
      - a 3x3 stem with 16 channels;
      - five stride-2 stages with widths ``cfg.F1..cfg.F5``. Each stage is
        conv-BN-ELU, optionally followed by an MRF-SE block
        (``cfg.USE_MRF_SE``) and a SpectralSelectiveTokenMixer
        (``cfg.USE_SSTM``, per-stage flags in ``cfg.SSTM_USE_SPECTRAL`` /
        ``cfg.SSTM_USE_SSM``).

    Decoder:
      - four stages, each upsampling x2 and fusing the matching skip. They
        use ``BFP_decoder_stage`` when ``cfg.USE_BFP`` is set, otherwise a
        plain upsample/concat/conv stage, so the output resolution is the same
        either way;
      - a final x2 upsample and two convs;
      - a float32 softmax over ``cfg.NUM_CLASSES``.
    """
    print("\n" + "="*80)
    print("🔥 BUILDING MED-SEGNET-SSF")
    print(f"   Input: {cfg.INPUT_SIZE}x{cfg.INPUT_SIZE}x3")
    print(f"   Output: {cfg.INPUT_SIZE}x{cfg.INPUT_SIZE}x{cfg.NUM_CLASSES}")
    print("="*80)

    inp = Input((cfg.INPUT_SIZE, cfg.INPUT_SIZE, 3), name="input")

    x = Conv2D(16, (3, 3), padding='same', kernel_initializer='he_uniform', name='stem_conv')(inp)
    x = BatchNormalization(name='stem_bn')(x)
    x = Activation('elu', name='stem_act')(x)

    encoder_outputs = []
    filters = [cfg.F1, cfg.F2, cfg.F3, cfg.F4, cfg.F5]

    for i, f in enumerate(filters):
        x = Conv2D(f, (3, 3), strides=2, padding='same', kernel_initializer='he_uniform')(x)
        x = BatchNormalization()(x)
        x = Activation('elu')(x)

        if cfg.USE_MRF_SE:
            x = MRF_SE_BLOCK(x, f, activation='elu', dropout=cfg.DROPOUT,
                             expand_ratio=cfg.EXPAND_RATIO, regularizer=cfg.L2_REG,
                             kernels=cfg.MRF_KERNELS, se_reduction=cfg.SE_REDUCTION,
                             name=f'mrfse_stage{i+1}')

        if cfg.USE_SSTM:
            x = SpectralSelectiveTokenMixer(
                channels=f, num_frequencies=cfg.SSTM_NUM_FREQUENCIES,
                ssm_state_dim=cfg.SSTM_SSM_STATE_DIM,
                use_spectral=cfg.SSTM_USE_SPECTRAL[i],
                use_ssm=cfg.SSTM_USE_SSM[i],
                dropout=cfg.SSTM_DROPOUT,
                name=f'sstm_stage{i+1}'
            )(x)

        encoder_outputs.append(x)

    skip_connections = encoder_outputs[::-1]
    decoder = skip_connections[0]
    decoder_filters = filters[::-1][1:] + [16]

    # zip() stops after the four skips, so the trailing 16 is never used.
    for i, (skip, f) in enumerate(zip(skip_connections[1:], decoder_filters)):
        if cfg.USE_BFP:
            decoder, _ = BFP_decoder_stage(decoder, skip, f, stage_name=f'bfp_stage{i+1}')
        else:
            # Without a decoder stage here, disabling BFP never upsampled, and
            # the model output came out at 1/16 of the input resolution.
            decoder = _plain_decoder_stage(decoder, skip, f, stage_name=f'plain_dec_stage{i+1}')

    decoder = UpSampling2D((2, 2))(decoder)
    decoder = Conv2D(32, (3, 3), padding='same', activation='relu')(decoder)
    decoder = Conv2D(16, (3, 3), padding='same', activation='relu')(decoder)

    out = Conv2D(cfg.NUM_CLASSES, (1, 1), padding='same', activation='softmax',
                 dtype='float32', name='output')(decoder)

    model = Model(inputs=inp, outputs=out, name="MedSegNet_SSF")
    print(f"\nTotal parameters: {model.count_params():,}")
    print("="*80 + "\n")

    return model

# ==============================================================================
# ARCHITECTURE 2: UMAMBA (FULLY FIXED)
# ==============================================================================

class SSMBlock(Layer):
    """Residual token-wise projection block used by the UMamba baseline.

    Processing steps:
      - flatten the spatial grid to ``(B, H*W, channels)``;
      - apply three ``Dense`` layers in sequence, with no activation between
        them;
      - apply LayerNorm, and dropout during training;
      - reshape back and add the input.

    Despite the name, no recurrent/state-space scan is performed, and
    ``ssm_dim`` is stored but unused.
    """

    def __init__(self, channels, ssm_dim=16, dropout=0.1, **kwargs):
        super().__init__(**kwargs)
        self.channels = channels
        self.ssm_dim = ssm_dim
        self.dropout_rate = dropout

    def build(self, input_shape):
        self.input_proj = Dense(self.channels, name='input_proj')
        self.ssm_transform = Dense(self.channels, name='ssm_transform')
        self.output_proj = Dense(self.channels, name='output_proj')
        self.norm = LayerNormalization(epsilon=1e-6, name='norm')
        super().build(input_shape)

    def call(self, x, training=None):
        dims = tf.shape(x)
        batch, height, width = dims[0], dims[1], dims[2]

        tokens = tf.reshape(x, [batch, height * width, self.channels])
        for projection in (self.input_proj, self.ssm_transform, self.output_proj):
            tokens = projection(tokens)
        tokens = self.norm(tokens)

        if training and self.dropout_rate > 0:
            tokens = tf.nn.dropout(tokens, rate=self.dropout_rate)

        return x + tf.reshape(tokens, [batch, height, width, self.channels])

def umamba_encoder_block(x, filters, stage, dropout=0.1, ssm_dim=16):
    """One UMamba encoder stage.

    A stride-2 3x3 conv-BN-ReLU halves the resolution. An ``SSMBlock``
    follows, then a 3x3 conv-BN-ReLU refinement. Output has ``filters``
    channels; layer names start with ``umamba_enc{stage}_``.
    """
    prefix = f'umamba_enc{stage}'

    out = Conv2D(filters, (3, 3), strides=2, padding='same',
                 kernel_initializer='he_uniform', name=f'{prefix}_down')(x)
    out = Activation('relu', name=f'{prefix}_act1')(
        BatchNormalization(name=f'{prefix}_bn1')(out))

    out = SSMBlock(filters, ssm_dim=ssm_dim, dropout=dropout,
                   name=f'{prefix}_ssm')(out)

    out = Conv2D(filters, (3, 3), padding='same',
                 kernel_initializer='he_uniform', name=f'{prefix}_refine')(out)
    return Activation('relu', name=f'{prefix}_act2')(
        BatchNormalization(name=f'{prefix}_bn2')(out))

def umamba_decoder_block(x, skip, filters, stage, dropout=0.1):
    """One UMamba decoder stage.

    The stage applies, in order:
      - a 2x2 stride-2 transposed conv (BN + ReLU) to upsample;
      - concatenation with ``skip``;
      - two 3x3 conv-BN-ReLU layers;
      - optional dropout.
    """
    tag = f'umamba_dec{stage}'

    def bn_relu(tensor, idx):
        tensor = BatchNormalization(name=f'{tag}_bn{idx}')(tensor)
        return Activation('relu', name=f'{tag}_act{idx}')(tensor)

    up = Conv2DTranspose(filters, (2, 2), strides=2, padding='same',
                         kernel_initializer='he_uniform', name=f'{tag}_up')(x)
    y = Concatenate(name=f'{tag}_concat')([bn_relu(up, 1), skip])

    for conv_idx, bn_idx in ((1, 2), (2, 3)):
        y = Conv2D(filters, (3, 3), padding='same',
                   kernel_initializer='he_uniform', name=f'{tag}_conv{conv_idx}')(y)
        y = bn_relu(y, bn_idx)

    if dropout > 0:
        y = Dropout(dropout, name=f'{tag}_dropout')(y)

    return y

def build_umamba(cfg):
    """Build the UMamba-style U-Net baseline.

    The network consists of:
      - a full-resolution 3x3 stem with ``cfg.UMAMBA_CHANNELS[0]`` channels;
      - one stride-2 ``umamba_encoder_block`` per remaining entry of
        ``cfg.UMAMBA_CHANNELS``;
      - a decoder that consumes the encoder features deepest-first, each
        ``umamba_decoder_block`` taking the skip's channel count;
      - a 1x1 float32 softmax conv producing ``cfg.NUM_CLASSES`` maps.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("🔥 BUILDING UMAMBA (FULLY FIXED)")
    print(f"   Input: {cfg.INPUT_SIZE}x{cfg.INPUT_SIZE}x3")
    print(f"   Output: {cfg.INPUT_SIZE}x{cfg.INPUT_SIZE}x{cfg.NUM_CLASSES}")
    print(rule)

    widths = cfg.UMAMBA_CHANNELS
    inp = Input((cfg.INPUT_SIZE, cfg.INPUT_SIZE, 3), name="input")

    stem = Conv2D(widths[0], (3, 3), padding='same',
                  kernel_initializer='he_uniform', name='umamba_stem')(inp)
    stem = BatchNormalization(name='umamba_stem_bn')(stem)
    stem = Activation('relu', name='umamba_stem_act')(stem)

    # Encoder features, shallowest first; each block halves the resolution.
    features = [stem]
    for stage, width in enumerate(widths[1:], start=1):
        features.append(umamba_encoder_block(features[-1], width, stage,
                                             dropout=cfg.DROPOUT,
                                             ssm_dim=cfg.UMAMBA_SSM_DIM))

    # Decoder: start from the deepest feature, take skips deepest-first.
    x = features[-1]
    for stage, skip in enumerate(reversed(features[:-1]), start=1):
        x = umamba_decoder_block(x, skip, skip.shape[-1], stage, dropout=cfg.DROPOUT)

    out = Conv2D(cfg.NUM_CLASSES, (1, 1), padding='same',
                 activation='softmax', dtype='float32', name='output')(x)

    model = Model(inputs=inp, outputs=out, name="UMamba")
    print(f"\nTotal parameters: {model.count_params():,}")
    print(rule + "\n")

    return model

# ==============================================================================
# ARCHITECTURE 3: TRANSUNET (FULLY FIXED)
# ==============================================================================

class TransformerBlock(Layer):
    """Pre-norm transformer encoder block over token sequences.

    The last input axis is ``embed_dim``. The block computes
    ``x = x + MHA(LN1(x))`` and then ``x = x + MLP(LN2(x))``.

    The MLP is ``Dense(embed_dim * mlp_ratio, gelu)`` -> ``Dropout`` ->
    ``Dense(embed_dim)`` -> ``Dropout``. Attention uses ``num_heads`` heads
    with key size ``embed_dim // num_heads``.
    """

    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.1, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.dropout_rate = dropout

    def build(self, input_shape):
        self.attn = MultiHeadAttention(
            num_heads=self.num_heads, key_dim=self.embed_dim // self.num_heads,
            dropout=self.dropout_rate, name='mha'
        )
        self.norm1 = LayerNormalization(epsilon=1e-6, name='norm1')

        mlp_hidden_dim = int(self.embed_dim * self.mlp_ratio)
        self.mlp = tf.keras.Sequential([
            Dense(mlp_hidden_dim, activation='gelu', name='mlp_fc1'),
            Dropout(self.dropout_rate),
            Dense(self.embed_dim, name='mlp_fc2'),
            Dropout(self.dropout_rate)
        ], name='mlp')
        self.norm2 = LayerNormalization(epsilon=1e-6, name='norm2')

        super().build(input_shape)

    def call(self, x, training=None):
        # Normalise once and reuse the result as both query and value.
        h = self.norm1(x)
        x = x + self.attn(h, h, training=training)
        x = x + self.mlp(self.norm2(x), training=training)
        return x

    def get_config(self):
        """Serialise constructor arguments so saved models can be reloaded."""
        config = super().get_config()
        config.update({
            "embed_dim": self.embed_dim,
            "num_heads": self.num_heads,
            "mlp_ratio": self.mlp_ratio,
            "dropout": self.dropout_rate,
        })
        return config

def transunet_encoder_block(x, filters, stage):
    """TransUNet CNN encoder stage: a stride-2 and a stride-1 3x3 conv-BN-ReLU.

    Halves the spatial resolution and outputs ``filters`` channels. Layer
    names use the prefix ``transunet_enc{stage}_``.
    """
    for idx, stride in ((1, 2), (2, 1)):
        x = Conv2D(filters, (3, 3), strides=stride, padding='same',
                   kernel_initializer='he_uniform',
                   name=f'transunet_enc{stage}_conv{idx}')(x)
        x = BatchNormalization(name=f'transunet_enc{stage}_bn{idx}')(x)
        x = Activation('relu', name=f'transunet_enc{stage}_act{idx}')(x)
    return x

def transunet_decoder_block(x, skip, filters, stage):
    """TransUNet decoder stage.

    A 2x2 stride-2 transposed conv upsamples, the skip is concatenated, and
    two 3x3 convs follow. Every conv is followed by BN + ReLU; layer names use
    the prefix ``transunet_dec{stage}_``.
    """
    prefix = f'transunet_dec{stage}'

    up = Conv2DTranspose(filters, (2, 2), strides=2, padding='same',
                         kernel_initializer='he_uniform', name=prefix + '_up')(x)
    up = BatchNormalization(name=prefix + '_bn1')(up)
    up = Activation('relu', name=prefix + '_act1')(up)

    y = Concatenate(name=prefix + '_concat')([up, skip])
    for conv_no in (1, 2):
        bn_no = conv_no + 1
        y = Conv2D(filters, (3, 3), padding='same',
                   kernel_initializer='he_uniform', name=f'{prefix}_conv{conv_no}')(y)
        y = BatchNormalization(name=f'{prefix}_bn{bn_no}')(y)
        y = Activation('relu', name=f'{prefix}_act{bn_no}')(y)

    return y

class _LearnedPositionEmbedding(Layer):
    """Add a trainable ``(1, num_tokens, dim)`` position embedding to the input.

    Holding the embedding as a layer weight makes it a tracked, trainable and
    saved model variable. A bare ``tf.Variable`` mixed into a functional graph
    is not reliably tracked by the Keras model.
    """

    def __init__(self, num_tokens, dim, **kwargs):
        super().__init__(**kwargs)
        self.num_tokens = num_tokens
        self.dim = dim

    def build(self, input_shape):
        self.pos_encoding = self.add_weight(
            name='pos_encoding',
            shape=(1, self.num_tokens, self.dim),
            initializer=tf.keras.initializers.RandomNormal(stddev=0.02),
            trainable=True,
        )
        super().build(input_shape)

    def call(self, x):
        return x + self.pos_encoding

    def get_config(self):
        config = super().get_config()
        config.update({"num_tokens": self.num_tokens, "dim": self.dim})
        return config


def build_transunet(cfg):
    """Build the TransUNet baseline.

    The network consists of:
      - a full-resolution 3x3 stem;
      - ``len(channels) - 1`` stride-2 CNN encoder stages. ``channels`` is
        ``cfg.TRANSUNET_CHANNELS`` when defined, else ``[32, 64, 128, 256]``;
      - a transformer bottleneck: the deepest map is flattened to tokens, a
        learned position embedding is added, and
        ``cfg.TRANSUNET_TRANSFORMER_LAYERS`` transformer blocks are applied;
      - a reshape back to a feature map, decoded with skips to full
        resolution.

    Raises:
        ValueError: if ``cfg.INPUT_SIZE`` is not divisible by
            ``2 ** (len(channels) - 1)``.
    """
    print("\n" + "="*80)
    print("🔥 BUILDING TRANSUNET (FULLY FIXED)")
    print(f"   Input: {cfg.INPUT_SIZE}x{cfg.INPUT_SIZE}x3")
    print(f"   Output: {cfg.INPUT_SIZE}x{cfg.INPUT_SIZE}x{cfg.NUM_CLASSES}")
    print("="*80)

    channels = list(getattr(cfg, 'TRANSUNET_CHANNELS', [32, 64, 128, 256]))
    num_down = len(channels) - 1  # the stem keeps full resolution
    if cfg.INPUT_SIZE % (2 ** num_down):
        raise ValueError(
            f"INPUT_SIZE={cfg.INPUT_SIZE} must be divisible by {2 ** num_down} "
            f"for {num_down} stride-2 encoder stages"
        )

    inp = Input((cfg.INPUT_SIZE, cfg.INPUT_SIZE, 3), name="input")

    x = Conv2D(channels[0], (3, 3), padding='same',
               kernel_initializer='he_uniform', name='transunet_stem')(inp)
    x = BatchNormalization(name='transunet_stem_bn')(x)
    x = Activation('relu', name='transunet_stem_act')(x)

    encoder_outputs = [x]
    for i, filters in enumerate(channels[1:], 1):
        x = transunet_encoder_block(x, filters, i)
        encoder_outputs.append(x)
    # Default widths with 512 input: [512@32, 256@64, 128@128, 64@256]

    # Transformer bottleneck. Only num_down stride-2 stages exist, so the side
    # is INPUT_SIZE / 2**num_down (64 for 512), not INPUT_SIZE / 2**len(channels).
    bottleneck_h = cfg.INPUT_SIZE // (2 ** num_down)
    bottleneck_w = bottleneck_h
    num_patches = bottleneck_h * bottleneck_w
    embed_dim = channels[-1]

    # Keras Reshape layers keep the batch axis intact and work on symbolic
    # Keras tensors, unlike tf.reshape(..., [-1, ...]).
    x_trans = Reshape((num_patches, embed_dim), name='transunet_to_tokens')(x)
    x_trans = _LearnedPositionEmbedding(num_patches, embed_dim, name='pos_encoding')(x_trans)

    for i in range(cfg.TRANSUNET_TRANSFORMER_LAYERS):
        x_trans = TransformerBlock(
            embed_dim, cfg.TRANSUNET_NUM_HEADS, dropout=cfg.DROPOUT,
            name=f'transformer_block_{i}'
        )(x_trans)

    x = Reshape((bottleneck_h, bottleneck_w, embed_dim), name='transunet_from_tokens')(x_trans)

    # Decoder: skips and widths deepest-first, excluding the bottleneck itself.
    skip_list = encoder_outputs[-2::-1]
    filters_list = channels[-2::-1]

    for i, (skip, filters) in enumerate(zip(skip_list, filters_list), 1):
        x = transunet_decoder_block(x, skip, filters, i)

    out = Conv2D(cfg.NUM_CLASSES, (1, 1), padding='same',
                 activation='softmax', dtype='float32', name='output')(x)

    model = Model(inputs=inp, outputs=out, name="TransUNet")
    print(f"\nTotal parameters: {model.count_params():,}")
    print("="*80 + "\n")

    return model

# ==============================================================================
# ARCHITECTURE 4: PRANET (FULLY FIXED)
# ==============================================================================

class RFB_Block(Layer):
    """Receptive-field block: four parallel conv branches fused by a 1x1 conv.

    Branch 1 is a single 1x1 conv; branches 2-4 are a 1x1 conv followed by a
    3x3 conv with dilation 3, 5 and 7 respectively. Every branch produces
    ``out_channels // 4`` channels. The branch outputs are concatenated on the
    last axis, projected to ``out_channels`` by a 1x1 conv, batch-normalised
    and passed through ReLU.

    Args:
        in_channels: Stored for serialization only; the convolutions infer
            their input width from the incoming tensor.
        out_channels: Number of channels in the block's output.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels

    def build(self, input_shape):
        branch_filters = self.out_channels // 4

        self.branch1 = Conv2D(branch_filters, (1, 1), padding='same', name='rfb_branch1')

        self.branch2 = tf.keras.Sequential([
            Conv2D(branch_filters, (1, 1), name='rfb_br2_1'),
            Conv2D(branch_filters, (3, 3), dilation_rate=3, padding='same', name='rfb_br2_2')
        ], name='branch2')

        self.branch3 = tf.keras.Sequential([
            Conv2D(branch_filters, (1, 1), name='rfb_br3_1'),
            Conv2D(branch_filters, (3, 3), dilation_rate=5, padding='same', name='rfb_br3_2')
        ], name='branch3')

        self.branch4 = tf.keras.Sequential([
            Conv2D(branch_filters, (1, 1), name='rfb_br4_1'),
            Conv2D(branch_filters, (3, 3), dilation_rate=7, padding='same', name='rfb_br4_2')
        ], name='branch4')

        self.conv_cat = Conv2D(self.out_channels, (1, 1), padding='same', name='rfb_cat')
        self.bn = BatchNormalization(name='rfb_bn')

        super().build(input_shape)

    def call(self, x, training=None):
        # All four branches see the same input and differ only in dilation.
        branches = [self.branch1(x), self.branch2(x), self.branch3(x), self.branch4(x)]
        # Use plain TF ops: constructing Keras layers (Concatenate/Activation)
        # inside call() creates a new, untracked layer object on every call.
        x_cat = tf.concat(branches, axis=-1)
        x_out = self.conv_cat(x_cat)
        return tf.nn.relu(self.bn(x_out, training=training))

    def get_config(self):
        # __init__ takes extra arguments, so Keras needs them in the config to
        # serialize/rebuild a model that contains this layer (e.g. when a
        # ModelCheckpoint saves the whole model).
        config = super().get_config()
        config.update({
            'in_channels': self.in_channels,
            'out_channels': self.out_channels,
        })
        return config

class ReverseAttention(Layer):
    """Reverse-attention refinement with a residual connection.

    A 1x1 sigmoid conv predicts a single-channel attention map ``att`` from
    the input. The output of the first 3x3 conv + BN + ReLU is multiplied by
    ``1 - att``, which emphasises regions the map does *not* attend to. The
    result is refined by a second 3x3 conv + BN + ReLU and added back to the
    input (residual).

    Args:
        channels: Number of channels produced by the two 3x3 convolutions and
            by the layer. When this differs from the input channel count, the
            residual path is projected by a 1x1 conv so the addition is valid.
    """

    def __init__(self, channels, **kwargs):
        super().__init__(**kwargs)
        self.channels = channels

    def build(self, input_shape):
        self.conv1 = Conv2D(self.channels, (3, 3), padding='same', name='ra_conv1')
        self.bn1 = BatchNormalization(name='ra_bn1')

        self.conv2 = Conv2D(self.channels, (3, 3), padding='same', name='ra_conv2')
        self.bn2 = BatchNormalization(name='ra_bn2')

        self.conv_att = Conv2D(1, (1, 1), padding='same', activation='sigmoid', name='ra_att')

        # The residual add needs equal widths on both sides. Callers may feed,
        # for example, a concatenation of an upsampled map and a skip
        # connection whose width differs from `channels`; project it then.
        in_channels = input_shape[-1]
        if in_channels is not None and int(in_channels) != self.channels:
            self.shortcut = Conv2D(self.channels, (1, 1), padding='same', name='ra_shortcut')
        else:
            self.shortcut = None

        super().build(input_shape)

    def call(self, x, training=None):
        att = self.conv_att(x)

        x1 = tf.nn.relu(self.bn1(self.conv1(x), training=training))
        # Reverse attention: suppress what the attention map already covers.
        x1 = x1 * (1 - att)

        x2 = tf.nn.relu(self.bn2(self.conv2(x1), training=training))

        residual = x if self.shortcut is None else self.shortcut(x)
        return x2 + residual

    def get_config(self):
        # Required for serializing models containing this layer, since
        # __init__ takes an extra argument.
        config = super().get_config()
        config.update({'channels': self.channels})
        return config

def pranet_encoder_block(x, filters, stage):
    """Downsample ``x`` by 2 and refine it with an RFB block.

    A stride-2 3x3 conv (He-uniform init) followed by BN and ReLU halves the
    spatial size; an ``RFB_Block`` then keeps ``filters`` output channels.
    Layer names are prefixed with ``pranet_enc{stage}`` so every stage is
    uniquely named.
    """
    down = Conv2D(filters, (3, 3), strides=2, padding='same',
                  kernel_initializer='he_uniform', name=f'pranet_enc{stage}_conv')(x)
    down = BatchNormalization(name=f'pranet_enc{stage}_bn')(down)
    down = Activation('relu', name=f'pranet_enc{stage}_act')(down)

    rfb = RFB_Block(filters, filters, name=f'pranet_enc{stage}_rfb')
    return rfb(down)

def build_pranet(cfg):
    """Build a PraNet-style encoder/decoder segmentation model.

    Layout (channel widths come from ``cfg.PRANET_CHANNELS``):
      * stem: 3x3 conv + BN + ReLU at input resolution, PRANET_CHANNELS[0] channels;
      * encoder: one ``pranet_encoder_block`` per remaining PRANET_CHANNELS
        entry, each halving height and width;
      * bottleneck: ``ReverseAttention`` on the deepest feature map;
      * decoder: for each skip (deepest first), a stride-2 transposed conv to
        the skip's width, concatenation with the skip, ``ReverseAttention`` on
        the concatenation, and a 3x3 refine conv back to the skip's width;
      * head: 1x1 conv with float32 softmax over ``cfg.NUM_CLASSES``.

    Args:
        cfg: Config object; reads INPUT_SIZE, PRANET_CHANNELS and NUM_CLASSES.

    Returns:
        An uncompiled ``tf.keras.Model`` named "PraNet".
    """
    print("\n" + "="*80)
    print("🔥 BUILDING PRANET (FULLY FIXED)")
    print(f"   Input: {cfg.INPUT_SIZE}x{cfg.INPUT_SIZE}x3")
    print(f"   Output: {cfg.INPUT_SIZE}x{cfg.INPUT_SIZE}x{cfg.NUM_CLASSES}")
    print("="*80)

    inp = Input((cfg.INPUT_SIZE, cfg.INPUT_SIZE, 3), name="input")

    # Stem: full-resolution features with PRANET_CHANNELS[0] channels.
    x = Conv2D(cfg.PRANET_CHANNELS[0], (3, 3), padding='same',
               kernel_initializer='he_uniform', name='pranet_stem')(inp)
    x = BatchNormalization(name='pranet_stem_bn')(x)
    x = Activation('relu', name='pranet_stem_act')(x)

    # Encoder: each stage halves H and W; keep every output for skips.
    encoder_outputs = [x]
    for stage, filters in enumerate(cfg.PRANET_CHANNELS[1:], 1):
        x = pranet_encoder_block(x, filters, stage)
        encoder_outputs.append(x)

    # Bottleneck: the deepest map already has PRANET_CHANNELS[-1] channels,
    # so the reverse-attention residual add lines up.
    x = ReverseAttention(cfg.PRANET_CHANNELS[-1], name='pranet_ra_bottleneck')(encoder_outputs[-1])

    # Decoder: walk the skips from the deepest-but-one back to the stem.
    for stage, skip in enumerate(reversed(encoder_outputs[:-1]), 1):
        filters = skip.shape[-1]

        x = Conv2DTranspose(filters, (2, 2), strides=2, padding='same',
                            kernel_initializer='he_uniform', name=f'pranet_dec{stage}_up')(x)
        x = BatchNormalization(name=f'pranet_dec{stage}_bn1')(x)
        x = Activation('relu', name=f'pranet_dec{stage}_act1')(x)

        x = Concatenate(name=f'pranet_dec{stage}_concat')([x, skip])

        # ReverseAttention adds its output back onto its input, so it must
        # produce as many channels as the concatenation carries (2 * filters).
        # Requesting only `filters` made the residual add fail on a shape
        # mismatch while the graph was being built.
        x = ReverseAttention(int(x.shape[-1]), name=f'pranet_ra_dec{stage}')(x)

        # Refine and reduce back to the skip's width.
        x = Conv2D(filters, (3, 3), padding='same',
                   kernel_initializer='he_uniform', name=f'pranet_dec{stage}_refine')(x)
        x = BatchNormalization(name=f'pranet_dec{stage}_bn2')(x)
        x = Activation('relu', name=f'pranet_dec{stage}_act2')(x)
    # x is now back at input resolution with PRANET_CHANNELS[0] channels.

    # float32 head keeps the softmax numerically stable under mixed precision.
    out = Conv2D(cfg.NUM_CLASSES, (1, 1), padding='same',
                 activation='softmax', dtype='float32', name='output')(x)

    model = Model(inputs=inp, outputs=out, name="PraNet")
    print(f"\nTotal parameters: {model.count_params():,}")
    print("="*80 + "\n")

    return model

# ==============================================================================
# ARCHITECTURE FACTORY
# ==============================================================================

def build_model(cfg):
    """Build the network selected by ``cfg.ARCHITECTURE`` (case-insensitive).

    Supported values: "medsegnet_ssf", "umamba", "transunet", "pranet".

    Raises:
        ValueError: if the architecture name is not recognised.
    """
    # Lambdas keep the lookup lazy: a builder is only resolved when chosen.
    builders = {
        "medsegnet_ssf": lambda: build_medsegnet_ssf(cfg),
        "umamba": lambda: build_umamba(cfg),
        "transunet": lambda: build_transunet(cfg),
        "pranet": lambda: build_pranet(cfg),
    }

    builder = builders.get(cfg.ARCHITECTURE.lower())
    if builder is None:
        raise ValueError(f"Unknown architecture: {cfg.ARCHITECTURE}")
    return builder()

# ==============================================================================
# LOSS FUNCTIONS
# ==============================================================================

def dice_loss_multiclass(class_weights=None):
    """Return a soft multi-class Dice loss over the foreground classes.

    Channel 0 (background) is excluded. Soft Dice is computed per sample and
    per foreground class, reducing over axes 1 and 2 (spatial, assuming NHWC
    one-hot targets), with additive smoothing of 1.0. The loss is
    ``1 - mean Dice``.

    Args:
        class_weights: Optional per-class sequence indexed like the channels.
            It is used only as a mask: foreground classes with weight <= 0
            are left out of the average, and all others count equally.

    Returns:
        ``loss_fn(y_true, y_pred) -> scalar tensor`` in [0, 1].
    """
    def loss_fn(y_true, y_pred):
        smooth = 1.0
        y_true = tf.cast(y_true, y_pred.dtype)
        y_pred = tf.clip_by_value(y_pred, 1e-7, 1 - 1e-7)

        y_true_fg = y_true[..., 1:]
        y_pred_fg = y_pred[..., 1:]

        # Shapes after reduction: (batch, num_foreground_classes).
        intersection = tf.reduce_sum(y_true_fg * y_pred_fg, axis=[1, 2])
        union = tf.reduce_sum(y_true_fg, axis=[1, 2]) + tf.reduce_sum(y_pred_fg, axis=[1, 2])

        dice_per_class = (2.0 * intersection + smooth) / (union + smooth)

        if class_weights is not None:
            weights = tf.constant(class_weights[1:], dtype=tf.float32)
            valid_mask = tf.cast(weights > 0, tf.float32)
            # Average over the batch first. Summing the full (batch, class)
            # matrix and dividing only by the number of valid classes would
            # make the "mean" grow with batch size and push the loss below 0.
            dice_per_class_mean = tf.reduce_mean(dice_per_class, axis=0)
            num_valid = tf.reduce_sum(valid_mask) + 1e-7
            mean_dice = tf.reduce_sum(dice_per_class_mean * valid_mask) / num_valid
        else:
            mean_dice = tf.reduce_mean(dice_per_class)

        return 1.0 - mean_dice

    return loss_fn

def focal_loss_multiclass(class_weights=None, alpha=0.25, gamma=2.0):
    """Return a multi-class focal loss for one-hot targets.

    Per element: ``-y_true * log(p) * (1 - p_t) ** gamma``, where ``p_t`` is
    the predicted probability of the true class at that pixel. If
    ``class_weights`` is given, each pixel is further scaled by the weight of
    its (argmax) target class. The result is the mean over all elements,
    including the class axis, times the scalar ``alpha``.
    """
    def loss_fn(y_true, y_pred):
        probs = tf.clip_by_value(y_pred, 1e-7, 1 - 1e-7)
        per_class_ce = -y_true * tf.math.log(probs)

        # Probability assigned to the true class, one value per pixel.
        p_true = tf.reduce_sum(y_true * probs, axis=-1)
        modulator = tf.pow(1.0 - p_true, gamma)
        weighted = per_class_ce * tf.expand_dims(modulator, axis=-1)

        if class_weights is None:
            return alpha * tf.reduce_mean(weighted)

        weight_table = tf.constant(class_weights, dtype=tf.float32)
        target_labels = tf.argmax(y_true, axis=-1)
        pixel_scale = tf.gather(weight_table, target_labels)
        weighted = weighted * tf.expand_dims(pixel_scale, axis=-1)
        return alpha * tf.reduce_mean(weighted)

    return loss_fn

def combined_dice_focal_loss(class_weights=None, dice_weight=0.5, focal_weight=0.5,
                             focal_alpha=0.25, focal_gamma=2.0):
    """Return ``dice_weight * dice_loss + focal_weight * focal_loss``.

    Both components receive the same ``class_weights``; ``focal_alpha`` and
    ``focal_gamma`` are forwarded to the focal term.
    """
    dice_term = dice_loss_multiclass(class_weights)
    focal_term = focal_loss_multiclass(class_weights, alpha=focal_alpha, gamma=focal_gamma)

    def loss_fn(y_true, y_pred):
        dice_value = dice_term(y_true, y_pred)
        focal_value = focal_term(y_true, y_pred)
        return dice_weight * dice_value + focal_weight * focal_value

    return loss_fn

class ClipConstraint(Constraint):
    """Weight constraint that clips every value into ``[min_value, max_value]``.

    Args:
        min_value: Lower bound applied after each update.
        max_value: Upper bound applied after each update.
    """

    def __init__(self, min_value=0.1, max_value=10.0):
        self.min_value = min_value
        self.max_value = max_value

    def __call__(self, w):
        return tf.clip_by_value(w, self.min_value, self.max_value)

    def get_config(self):
        # The base Constraint.get_config() returns {}, which would drop the
        # bounds on serialization and rebuild the constraint with defaults.
        return {'min_value': self.min_value, 'max_value': self.max_value}

class MulticlassMASL(Layer):
    """Multi-class region + boundary segmentation loss packaged as a Keras layer.

    For each foreground channel ``c`` in ``1..num_classes-1`` the class loss is
    ``w_region * core + w_boundary * boundary``, where:

      * core = 0.4 * (1 - Dice) + 0.3 * (1 - IoU) + 0.3 * boundary-weighted BCE;
      * boundary = L1 gap between target and prediction finite differences
        along axes 1 and 2 at offsets 1, 2 and 4 (weights 0.5, 0.3, 0.2).

    Class losses are combined as a weighted mean. The weight is
    ``class_weights[c]`` (1.0 if ``class_weights`` is None), zeroed when the
    class has no target pixels anywhere in the batch.

    ``w_region``, ``w_boundary`` and ``w_structure`` are trainable scalar
    weights clipped to [0.1, 10]. NOTE(review): ``w_structure`` is never used
    in ``call``. Because this layer is used through a loss function rather
    than as part of the model, these weights are presumably not among the
    model's trainable variables. Confirm whether they are ever updated.

    Args:
        num_classes: Total number of channels, background (index 0) included.
        class_weights: Optional indexable of per-class weights (index 0 unused).
        name: Layer name.
        **kwargs: Forwarded to ``Layer``. ``dtype`` defaults to ``'float32'``.
    """

    def __init__(self, num_classes=6, class_weights=None, name='multiclass_masl', **kwargs):
        # Keep the loss math in float32 even under a mixed_float16 global
        # policy. Otherwise Keras autocasts y_true/y_pred to float16 on entry
        # and reads the scalar weights as float16, which then clash with the
        # float32 tensors computed in call(). An explicit dtype still wins.
        kwargs.setdefault('dtype', 'float32')
        super().__init__(name=name, **kwargs)
        self.num_classes = num_classes
        self.class_weights = class_weights
        self.epsilon = 1e-6

    def build(self, input_shape):
        """Create the three scalar mixing weights, each clipped to [0.1, 10]."""
        clip_constraint = ClipConstraint(min_value=0.1, max_value=10.0)

        self.w_region = self.add_weight(
            name='w_region', shape=(), initializer=tf.constant_initializer(1.0),
            trainable=True, constraint=clip_constraint
        )
        self.w_boundary = self.add_weight(
            name='w_boundary', shape=(), initializer=tf.constant_initializer(1.0),
            trainable=True, constraint=clip_constraint
        )
        # NOTE(review): created but not referenced anywhere in call().
        self.w_structure = self.add_weight(
            name='w_structure', shape=(), initializer=tf.constant_initializer(0.5),
            trainable=True, constraint=clip_constraint
        )

        super().build(input_shape)

    def morphological_dilation(self, x, kernel_size=5):
        """Grey-scale dilation: max over a kernel_size window (stride 1, SAME)."""
        return tf.nn.max_pool2d(x, kernel_size, strides=1, padding='SAME')

    def morphological_erosion(self, x, kernel_size=5):
        """Grey-scale erosion, computed as ``-dilation(-x)``."""
        return -tf.nn.max_pool2d(-x, kernel_size, strides=1, padding='SAME')

    def detect_boundary(self, mask, kernel_size=5):
        """Return the band around mask edges: dilation minus erosion, clipped to [0, 1]."""
        dilated = self.morphological_dilation(mask, kernel_size)
        eroded = self.morphological_erosion(mask, kernel_size)
        boundary = dilated - eroded
        return tf.clip_by_value(boundary, 0.0, 1.0)

    def core_loss_per_class(self, y_true_class, y_pred_class):
        """Region loss for one class: 0.4*(1-Dice) + 0.3*(1-IoU) + 0.3*weighted BCE.

        Dice and IoU are computed per sample (sums over axes 1-3) and then
        averaged over the batch. In the BCE term, pixels in the target's
        boundary band get weight 6 (1 + 5) and all others weight 1.
        """
        intersection = tf.reduce_sum(y_true_class * y_pred_class, axis=[1, 2, 3])
        dice = (2. * intersection + self.epsilon) / (
            tf.reduce_sum(y_true_class, axis=[1, 2, 3]) +
            tf.reduce_sum(y_pred_class, axis=[1, 2, 3]) + self.epsilon
        )
        dice_loss = 1.0 - tf.reduce_mean(dice)

        union = (tf.reduce_sum(y_true_class, axis=[1, 2, 3]) +
                 tf.reduce_sum(y_pred_class, axis=[1, 2, 3]) - intersection)
        iou = (intersection + self.epsilon) / (union + self.epsilon)
        iou_loss = 1.0 - tf.reduce_mean(iou)

        boundary = self.detect_boundary(y_true_class, kernel_size=5)
        weights = 1.0 + 5.0 * boundary
        bce = -(y_true_class * tf.math.log(y_pred_class + self.epsilon) +
                (1 - y_true_class) * tf.math.log(1 - y_pred_class + self.epsilon))
        weighted_bce = tf.reduce_mean(weights * bce)

        return 0.4 * dice_loss + 0.3 * iou_loss + 0.3 * weighted_bce

    def boundary_loss_per_class(self, y_true_class, y_pred_class):
        """Multi-scale edge loss for one class.

        Compares finite differences of target and prediction along axis 1 and
        axis 2 (height/width, assuming NHWC input) at offsets 1, 2 and 4,
        weighted 0.5, 0.3 and 0.2, using the mean absolute difference.
        """
        total_loss = 0.0
        weights = [0.5, 0.3, 0.2]

        for scale, w in zip([1, 2, 4], weights):
            dy_true = y_true_class[:, scale:, :, :] - y_true_class[:, :-scale, :, :]
            dy_pred = y_pred_class[:, scale:, :, :] - y_pred_class[:, :-scale, :, :]

            dx_true = y_true_class[:, :, scale:, :] - y_true_class[:, :, :-scale, :]
            dx_pred = y_pred_class[:, :, scale:, :] - y_pred_class[:, :, :-scale, :]

            total_loss += w * (tf.reduce_mean(tf.abs(dy_true - dy_pred)) +
                               tf.reduce_mean(tf.abs(dx_true - dx_pred)))

        return total_loss

    def call(self, y_true, y_pred):
        """Weighted mean of per-class losses over channels 1..num_classes-1."""
        y_true = tf.cast(y_true, tf.float32)
        y_pred = tf.cast(y_pred, tf.float32)
        y_pred = tf.clip_by_value(y_pred, self.epsilon, 1.0 - self.epsilon)

        losses = []
        weights = []

        for class_idx in range(1, self.num_classes):
            y_true_class = y_true[..., class_idx:class_idx+1]
            y_pred_class = y_pred[..., class_idx:class_idx+1]

            if self.class_weights is not None:
                class_weight = float(self.class_weights[class_idx])
            else:
                class_weight = 1.0

            # Classes with no target pixels in the batch contribute nothing.
            class_sum = tf.reduce_sum(y_true_class)
            class_has_pixels = tf.cast(class_sum > 0.0, tf.float32)
            effective_weight = class_weight * class_has_pixels

            l_core = self.core_loss_per_class(y_true_class, y_pred_class)
            l_boundary = self.boundary_loss_per_class(y_true_class, y_pred_class)

            class_loss = self.w_region * l_core + self.w_boundary * l_boundary

            losses.append(class_loss)
            weights.append(effective_weight)

        losses_tensor = tf.stack(losses)
        weights_tensor = tf.stack(weights)

        # epsilon keeps the division finite when no class is present.
        total_weight = tf.reduce_sum(weights_tensor) + self.epsilon
        weighted_loss = tf.reduce_sum(losses_tensor * weights_tensor) / total_weight

        return weighted_loss

def create_multiclass_masl_loss(num_classes, class_weights=None):
    """Return ``loss_fn(y_true, y_pred)`` backed by a single MulticlassMASL layer.

    The layer is instantiated once, here, and every call to the returned
    function reuses that same instance.
    """
    masl = MulticlassMASL(num_classes=num_classes, class_weights=class_weights)

    def loss_fn(y_true, y_pred):
        loss_value = masl(y_true, y_pred)
        return loss_value

    return loss_fn

# ==============================================================================
# METRICS
# ==============================================================================

def create_instrument_dice_metrics(num_classes, class_names, global_mapping):
    """Build hard (argmax) Dice metric functions for ``Model.compile``.

    The returned dict contains, in order:
      * ``'overall_dice'``: mean Dice over foreground classes 1..num_classes-1;
      * ``'pixel_{v}_dice'`` for each foreground class, where ``v`` is the
        original mask pixel value mapped to that class by ``global_mapping``
        (the class index itself if no pixel value maps to it).

    Each function's ``__name__`` equals its dict key. Keras names function
    metrics after ``__name__``; without this, every per-class metric would be
    called ``instrument_dice`` and the metric names would collide.

    Args:
        num_classes: Total number of classes, background (index 0) included.
        class_names: Currently unused; kept for interface compatibility.
        global_mapping: Dict mapping original pixel value -> class index.

    Returns:
        Dict[str, callable]; each callable is ``fn(y_true, y_pred) -> scalar``.
    """
    pixel_for_class = {class_idx: pixel_val
                       for pixel_val, class_idx in global_mapping.items()}

    def _class_dice(y_true_class, y_pred_class, c):
        # Hard Dice for class c. The 1e-6 on both sides yields 1.0 when the
        # class is absent from both target and prediction.
        y_true_c = tf.cast(tf.equal(y_true_class, c), tf.float32)
        y_pred_c = tf.cast(tf.equal(y_pred_class, c), tf.float32)
        intersection = tf.reduce_sum(y_true_c * y_pred_c)
        return (2.0 * intersection + 1e-6) / (
            tf.reduce_sum(y_true_c) + tf.reduce_sum(y_pred_c) + 1e-6)

    def overall_dice(y_true, y_pred):
        y_pred_class = tf.argmax(y_pred, axis=-1)
        y_true_class = tf.argmax(y_true, axis=-1)
        dices = [_class_dice(y_true_class, y_pred_class, c)
                 for c in range(1, num_classes)]
        return tf.reduce_mean(tf.stack(dices))

    def _make_dice_metric(c, metric_name):
        # Factory binds `c` per call, avoiding the late-binding closure trap.
        def instrument_dice(y_true, y_pred):
            y_pred_class = tf.argmax(y_pred, axis=-1)
            y_true_class = tf.argmax(y_true, axis=-1)
            return _class_dice(y_true_class, y_pred_class, c)

        instrument_dice.__name__ = metric_name
        return instrument_dice

    metrics = {'overall_dice': overall_dice}
    for class_idx in range(1, num_classes):
        pixel_val = pixel_for_class.get(class_idx, class_idx)
        metric_name = f"pixel_{pixel_val}_dice"
        metrics[metric_name] = _make_dice_metric(class_idx, metric_name)

    return metrics

def mean_iou(y_true, y_pred, num_classes=None):
    """Mean hard IoU over foreground classes 1..num_classes-1.

    Targets and predictions are argmax-ed over the last axis. Each class IoU
    uses 1e-6 smoothing, so a class absent from both scores 1.0. When
    ``num_classes`` is None (Keras calls metrics with two arguments), the
    module-level ``config.NUM_CLASSES`` is read at call time.
    """
    n_classes = config.NUM_CLASSES if num_classes is None else num_classes
    pred_labels = tf.argmax(y_pred, axis=-1)
    true_labels = tf.argmax(y_true, axis=-1)

    def _iou(c):
        target = tf.cast(tf.equal(true_labels, c), tf.float32)
        guess = tf.cast(tf.equal(pred_labels, c), tf.float32)
        overlap = tf.reduce_sum(target * guess)
        union = tf.reduce_sum(target) + tf.reduce_sum(guess) - overlap
        return (overlap + 1e-6) / (union + 1e-6)

    per_class = [_iou(c) for c in range(1, n_classes)]
    return tf.reduce_mean(tf.stack(per_class))

# ==============================================================================
# VISUALIZATION
# ==============================================================================

def mask_to_color(mask, class_colors):
    """Convert a 2-D integer label mask into an RGB ``uint8`` image.

    Pixels whose label ``k`` satisfies ``0 <= k < len(class_colors)`` get
    ``class_colors[k]``; every other pixel stays black (0, 0, 0).
    """
    height, width = mask.shape
    rgb = np.zeros((height, width, 3), dtype=np.uint8)
    for label in range(len(class_colors)):
        selected = mask == label
        rgb[selected] = class_colors[label]
    return rgb

def visualize_predictions(model, test_gen, cfg, num_samples=5, save_dir=None):
    """Draw a grid of image / ground truth / prediction / overlay rows.

    For each of the first ``num_samples`` batches of ``test_gen`` (capped at
    ``len(test_gen)``), only the first image of the batch is shown. The
    columns are: the input image, the colour-coded ground truth, the
    colour-coded prediction, and a 50/50 blend of image and prediction. The
    blend is titled with that image's mean foreground Dice.

    Args:
        model: Keras model whose ``predict`` returns per-pixel class scores.
        test_gen: Indexable generator yielding ``(images, one_hot_masks)``.
        cfg: Config providing NUM_CLASSES, CLASS_COLORS, CLASS_NAMES and
            ARCHITECTURE.
        num_samples: Maximum number of rows to draw.
        save_dir: If given, it is created if needed and the figure is written
            to ``save_dir/predictions.png``.
    """
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    print("\n🎨 Generating Visualizations...")

    num_samples = min(num_samples, len(test_gen))
    if num_samples <= 0:
        # plt.subplots rejects a zero-row grid, and there is nothing to draw.
        print("⚠️ No samples to visualize")
        return

    fig, axes = plt.subplots(num_samples, 4, figsize=(16, 4 * num_samples))
    if num_samples == 1:
        axes = axes.reshape(1, -1)

    for i in range(num_samples):
        images, masks_onehot = test_gen[i]
        img = images[0:1]

        gt_mask = np.argmax(masks_onehot[0], axis=-1)
        gt_color = mask_to_color(gt_mask, cfg.CLASS_COLORS)

        pred_onehot = model.predict(img, verbose=0)[0]
        pred_mask = np.argmax(pred_onehot, axis=-1)
        pred_color = mask_to_color(pred_mask, cfg.CLASS_COLORS)

        # Hard Dice per foreground class, with the same 1e-6 smoothing in
        # numerator and denominator as the training Dice metrics. A class
        # absent from both masks then scores 1.0 instead of 0, so the
        # displayed value is comparable to the logged metric.
        class_dices = []
        for c in range(1, cfg.NUM_CLASSES):
            gt_c = (gt_mask == c).astype(np.float32)
            pred_c = (pred_mask == c).astype(np.float32)
            intersection = np.sum(gt_c * pred_c)
            dice = (2.0 * intersection + 1e-6) / (np.sum(gt_c) + np.sum(pred_c) + 1e-6)
            class_dices.append(dice)

        overall_dice = np.mean(class_dices)

        axes[i, 0].imshow(img[0])
        axes[i, 0].set_title(f'Image {i+1}' if i == 0 else '', fontsize=12, fontweight='bold')
        axes[i, 0].axis('off')

        axes[i, 1].imshow(gt_color)
        axes[i, 1].set_title('Ground Truth' if i == 0 else '', fontsize=12, fontweight='bold')
        axes[i, 1].axis('off')

        axes[i, 2].imshow(pred_color)
        axes[i, 2].set_title('Prediction' if i == 0 else '', fontsize=12, fontweight='bold')
        axes[i, 2].axis('off')

        # NOTE(review): assumes images are scaled to [0, 1] for blending.
        overlay = (img[0] * 0.5 + pred_color / 255.0 * 0.5)
        axes[i, 3].imshow(overlay)
        axes[i, 3].set_title(f'Overlay (Dice={overall_dice:.3f})' if i == 0 else f'Dice={overall_dice:.3f}',
                             fontsize=12, fontweight='bold')
        axes[i, 3].axis('off')

    legend_elements = [Patch(facecolor=cfg.CLASS_COLORS[i]/255.0,
                             edgecolor='black', label=cfg.CLASS_NAMES[i])
                       for i in range(cfg.NUM_CLASSES)]
    fig.legend(handles=legend_elements, loc='lower center', ncol=min(4, cfg.NUM_CLASSES),
               fontsize=11, frameon=True)

    plt.suptitle(f'{cfg.ARCHITECTURE.upper()} - Instrument Segmentation',
                 fontsize=16, weight='bold')
    plt.tight_layout(rect=[0, 0.05, 1, 0.98])

    if save_dir:
        save_path = os.path.join(save_dir, 'predictions.png')
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"✅ Saved to {save_path}")

    # Close this specific figure rather than whatever is "current".
    plt.close(fig)

# ==============================================================================
# TRAINING
# ==============================================================================

def train_endovis17(cfg, strategy, num_gpus):
    """Load EndoVis17 splits, train the configured model, evaluate and report.

    Steps:
      1. Seed RNGs and make sure ``cfg.SAVE_DIR`` exists.
      2. Load the train/val/test pairs, build one pixel-value -> class mapping
         across all splits (updating ``cfg.NUM_CLASSES``), and compute class
         weights on the training split.
      3. Inside ``strategy.scope()``: build the model, the Adam optimizer
         (wrapped in a LossScaleOptimizer when mixed precision is on), the
         loss selected by ``cfg.LOSS_TYPE`` and the Dice/IoU metrics, then
         compile.
      4. Fit with checkpointing, early stopping, LR reduction and CSV logging.
      5. Evaluate on the test split, write ``results.json``, and save
         prediction visualizations.

    Args:
        cfg: Mutable config object (``NUM_CLASSES`` is overwritten).
        strategy: Distribution strategy whose scope wraps model creation.
        num_gpus: Not used in this function body.

    Returns:
        ``(model, history)`` as returned by building and ``model.fit``.
    """
    set_seed(cfg.SEED)

    # The checkpoint, CSV log, results.json and visualizations are all written
    # under cfg.SAVE_DIR; CSVLogger/ModelCheckpoint do not create it.
    os.makedirs(cfg.SAVE_DIR, exist_ok=True)

    print("\n" + "="*80)
    print("📊 LOADING DATASET")
    print("="*80)

    train_pairs = load_dataset_split(cfg.TRAIN_DIR)
    val_pairs = load_dataset_split(cfg.VAL_DIR)
    test_pairs = load_dataset_split(cfg.TEST_DIR)

    all_splits = {
        "TRAIN": train_pairs,
        "VAL": val_pairs,
        "TEST": test_pairs
    }

    # A single mapping across all splits keeps class indices consistent.
    global_pixel_mapping, num_classes = create_global_pixel_mapping(all_splits)
    cfg.NUM_CLASSES = num_classes

    print(f"\nDataset: Train={len(train_pairs)}, Val={len(val_pairs)}, Test={len(test_pairs)}")

    class_weights, class_counts = calculate_class_weights(train_pairs, cfg, global_pixel_mapping)

    train_aug = get_surgical_augmentation(cfg)
    val_aug = get_validation_augmentation(cfg)

    train_gen = EndoVis17Generator(train_pairs, cfg, global_pixel_mapping, train_aug, True, cfg.EPOCH_EXPANSION_FACTOR)
    val_gen = EndoVis17Generator(val_pairs, cfg, global_pixel_mapping, val_aug, False, 1)
    test_gen = EndoVis17Generator(test_pairs, cfg, global_pixel_mapping, val_aug, False, 1)

    print("\n✅ Global mapping applied")

    with strategy.scope():
        model = build_model(cfg)

        optimizer = tf.keras.optimizers.Adam(learning_rate=cfg.LEARNING_RATE, clipnorm=1.0)

        if cfg.USE_MIXED_PRECISION:
            optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)

        print(f"\n🔥 LOSS: {cfg.LOSS_TYPE.upper()}")

        if cfg.LOSS_TYPE == "multiclass_masl":
            loss_fn = create_multiclass_masl_loss(
                num_classes=cfg.NUM_CLASSES,
                class_weights=class_weights
            )
        else:
            loss_fn = combined_dice_focal_loss(
                class_weights=class_weights,
                dice_weight=cfg.DICE_WEIGHT,
                focal_weight=cfg.FOCAL_WEIGHT,
                focal_alpha=cfg.FOCAL_ALPHA,
                focal_gamma=cfg.FOCAL_GAMMA
            )

        instrument_metrics = create_instrument_dice_metrics(
            cfg.NUM_CLASSES,
            cfg.CLASS_NAMES,
            global_pixel_mapping
        )
        all_metrics = [mean_iou] + list(instrument_metrics.values())

        model.compile(optimizer=optimizer, loss=loss_fn, metrics=all_metrics)

        print(f"\n✅ Model compiled")

    callbacks = [
        ModelCheckpoint(
            os.path.join(cfg.SAVE_DIR, "best_model.h5"),
            monitor=cfg.CHECKPOINT_MONITOR, mode=cfg.CHECKPOINT_MODE,
            save_best_only=True, verbose=1
        ),
        EarlyStopping(
            monitor=cfg.CHECKPOINT_MONITOR, mode=cfg.CHECKPOINT_MODE,
            patience=cfg.EARLY_STOPPING_PATIENCE, verbose=1, restore_best_weights=True
        ),
        # 'val_overall_dice' is the validation log key of the
        # 'overall_dice' metric from create_instrument_dice_metrics.
        ReduceLROnPlateau(
            monitor='val_overall_dice', factor=0.5, patience=5,
            verbose=1, mode='max', min_lr=1e-6
        ),
        CSVLogger(os.path.join(cfg.SAVE_DIR, "training_log.csv")),
    ]

    print(f"\n🚀 TRAINING ({cfg.EPOCHS} EPOCHS)")
    print("="*80)
    start_time = time.time()

    history = model.fit(
        train_gen,
        validation_data=val_gen,
        epochs=cfg.EPOCHS,
        callbacks=callbacks,
        verbose=1
    )

    training_time = time.time() - start_time
    print(f"\n✅ Training finished in {training_time/60:.1f} min")

    print("\n" + "="*80)
    print("📊 TEST EVALUATION")
    print("="*80)
    test_results = model.evaluate(test_gen, verbose=1)

    results = {
        "model": cfg.ARCHITECTURE.upper(),
        "loss_type": cfg.LOSS_TYPE,
        "num_classes": cfg.NUM_CLASSES,
        "global_mapping": {str(k): int(v) for k, v in global_pixel_mapping.items()},
        "training_time_minutes": training_time / 60,
        "test_metrics": {name: float(value) for name, value in zip(model.metrics_names, test_results)}
    }

    with open(os.path.join(cfg.SAVE_DIR, "results.json"), "w") as f:
        json.dump(results, f, indent=2)

    print(f"\n✅ Test Results:")
    for metric, value in results['test_metrics'].items():
        print(f"   {metric}: {value:.4f}")

    viz_dir = os.path.join(cfg.SAVE_DIR, "visualizations")
    visualize_predictions(model, test_gen, cfg, num_samples=cfg.NUM_VIS_SAMPLES, save_dir=viz_dir)

    print("\n" + "="*80)
    print("✅ COMPLETE!")
    print("="*80)
    print(f"\n📁 Results: {cfg.SAVE_DIR}/")

    return model, history

# ==============================================================================
# MAIN
# ==============================================================================

if __name__ == "__main__":
    # Print a run summary, then train with the module-level config,
    # distribution strategy and GPU count.
    _rule = "="*80
    print("\n" + _rule)
    print("🔥 ENDOVIS17 - ALL FIXES APPLIED")
    print(_rule)
    print(f"\n📌 Architecture: {config.ARCHITECTURE.upper()}")
    print(f"📌 Loss: {config.LOSS_TYPE.upper()}")
    print(f"📌 Batch Size: {config.BATCH_SIZE}")
    print(f"📌 Mixed Precision: {'ON' if config.USE_MIXED_PRECISION else 'OFF'}")
    print("\n✅ ALL FIXES:")
    for _fix_line in (
        "   ✅ Global pixel mapping",
        "   ✅ Pixel-based metrics",
        "   ✅ All decoder dimensions fixed",
        "   ✅ Memory optimizations",
    ):
        print(_fix_line)
    print(_rule)

    model, history = train_endovis17(config, strategy, num_gpus)

    if model is not None:
        print("\n🎉 DONE!")