#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
================================================================================
🔥 MED-SEGNET-SSF: SPECTRAL-SELECTIVE FEATURES FOR MEDICAL SEGMENTATION
================================================================================

FEATURES:
- ✅ MedSegNet-SSF Architecture (MRF-SE + SSTM + BFP + MASL)
- ✅ Comprehensive Spectral Truncation Analysis
- ✅ Multi-GPU Support

Date: 2025
================================================================================
"""

import os
import sys
import time
import json
from glob import glob
from pathlib import Path

import numpy as np
import cv2
import tensorflow as tf
from tensorflow.keras.layers import (
    Input, Conv2D, DepthwiseConv2D, GlobalAveragePooling2D,
    Reshape, Multiply, BatchNormalization, Activation, Dropout,
    Concatenate, Add, UpSampling2D, LayerNormalization, Dense, Layer
)
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2
from tensorflow.keras.callbacks import (
    ModelCheckpoint, EarlyStopping, CSVLogger, LearningRateScheduler
)
from tensorflow.keras import backend as K
import albumentations as A
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.patches import Rectangle, Patch
from scipy import ndimage
from scipy.stats import pearsonr, ttest_ind
from scipy.ndimage import distance_transform_edt
import pandas as pd
import seaborn as sns

# NOTE(review): this runs after `import tensorflow` at the top of the file, so
# TensorFlow may already have initialised its C++ logging; for full effect the
# variable would need to be set before that import.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"


# Publication-style typography applied to every figure this script saves.
plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['Times New Roman'],
    'font.size': 11,
    'axes.labelsize': 12,
    'axes.titlesize': 13,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'legend.fontsize': 10,
    'figure.titlesize': 14,
})

# ==============================================================================
# 🔥 CONFIGURATION
# ==============================================================================

class Config:
    """Experiment-wide settings for MedSegNet-SSF.

    Every setting is a class attribute, so it can be read from the class or
    from an instance. Constructing an instance creates ``SAVE_DIR`` together
    with its ``figures``, ``tables`` and ``predictions`` sub-folders and
    prints a short summary of the key settings.
    """

    # ---- Hardware --------------------------------------------------------
    GPU_NUMBERS = [0]  # physical GPU indices passed to setup_gpus()

    # ---- Data locations --------------------------------------------------
    DATA_ROOT = "/kaggle/input/endovis-17/BinarySegmentation"
    TRAIN_DIR = os.path.join(DATA_ROOT, "train")
    VAL_DIR = os.path.join(DATA_ROOT, "val")
    TEST_DIR = os.path.join(DATA_ROOT, "test")
    SAVE_DIR = "MedSegNet_SSF_OUTPUT"

    # ---- Spectral truncation analysis -----------------------------------
    ANALYZE_SPECTRAL_TRUNCATION = True
    TRUNCATION_SIZES = [16, 24, 32, 48, 64, 128]  # window side lengths K
    SPECTRAL_ANALYSIS_SAMPLES = 50  # max images drawn for the analysis

    # ---- Model architecture ---------------------------------------------
    INPUT_SIZE = 352  # square model input side (pixels)
    F1, F2, F3, F4, F5 = 24, 32, 64, 80, 128

    USE_MRF_SE = True
    USE_SSTM = True
    USE_BFP = True
    MRF_KERNELS = [3, 5, 7]
    SE_REDUCTION = 16
    EXPAND_RATIO = 6
    SSTM_NUM_FREQUENCIES = 32  # K=32 (optimal from analysis)
    SSTM_SSM_STATE_DIM = 16
    SSTM_USE_SPECTRAL = [True, True, True, True, True]
    SSTM_USE_SSM = [False, False, True, True, True]
    SSTM_DROPOUT = 0.1

    DROPOUT = 0.1
    L2_REG = 1e-4

    # ---- Training --------------------------------------------------------
    BATCH_SIZE = 4
    EPOCH_EXPANSION_FACTOR = 30  # virtual passes over the data per epoch
    EPOCHS = 30
    LEARNING_RATE = 1e-4

    # NOTE(review): patience (40) exceeds EPOCHS (30); if this is fed to
    # EarlyStopping it can never trigger as configured.
    EARLY_STOPPING_PATIENCE = 40
    CHECKPOINT_MONITOR = "val_dice_coefficient"
    CHECKPOINT_MODE = "max"

    SEED = 42
    DETERMINISTIC = False  # presumably forwarded to set_seed(); confirm

    def __init__(self):
        os.makedirs(self.SAVE_DIR, exist_ok=True)
        for subdir in ("figures", "tables", "predictions"):
            os.makedirs(os.path.join(self.SAVE_DIR, subdir), exist_ok=True)

        spectral_state = 'ENABLED' if self.ANALYZE_SPECTRAL_TRUNCATION else 'DISABLED'
        print("🔥 MED-SEGNET-SSF CONFIGURATION")
        print(f"   Virtual Epoch Factor: {self.EPOCH_EXPANSION_FACTOR}x")
        print(f"   SSTM Frequencies (K): {self.SSTM_NUM_FREQUENCIES}")
        print(f"   Spectral Analysis: {spectral_state}")


# Module-level configuration instance (creates the output folders on import).
config = Config()

# ==============================================================================
# GPU SETUP
# ==============================================================================

def setup_gpus(gpu_numbers=None):
    """Restrict TensorFlow to the requested GPUs and build a distribution strategy.

    Args:
        gpu_numbers: Iterable of physical GPU indices to use. Negative indices
            count from the end, as in list indexing. ``None`` selects every
            GPU. Out-of-range indices are ignored with a warning and
            duplicates are collapsed, so each device is counted once.

    Returns:
        ``(strategy, num_gpus)``. ``strategy`` is a ``MirroredStrategy`` when
        more than one GPU is selected, otherwise the default strategy.
        ``num_gpus`` is the number of GPUs made visible (0 on the CPU paths).
    """
    gpus = tf.config.list_physical_devices('GPU')

    if not gpus:
        print("⚠️ No GPUs found! Using CPU.")
        return tf.distribute.get_strategy(), 0

    print(f"🔍 Total GPUs available: {len(gpus)}")

    selected_gpus = [gpus[i] for i in _resolve_gpu_indices(len(gpus), gpu_numbers)]
    if not selected_gpus:
        # An empty visible-device list hides every GPU, so TF runs on the CPU.
        print("⚠️ No requested GPU index is available; hiding all GPUs and using CPU.")

    try:
        tf.config.set_visible_devices(selected_gpus, 'GPU')
        for gpu in selected_gpus:
            tf.config.experimental.set_memory_growth(gpu, True)

        num_gpus = len(selected_gpus)
        strategy = tf.distribute.MirroredStrategy() if num_gpus > 1 else tf.distribute.get_strategy()
        print(f"✅ Using {num_gpus} GPU(s)")
        return strategy, num_gpus
    except RuntimeError as e:
        # set_visible_devices / set_memory_growth raise RuntimeError once the
        # TF runtime has already been initialised.
        print(f"⚠️ GPU setup error: {e}")
        return tf.distribute.get_strategy(), 0


def _resolve_gpu_indices(num_available, gpu_numbers):
    """Map requested GPU indices to valid, de-duplicated indices in request order."""
    if gpu_numbers is None:
        return list(range(num_available))

    resolved, ignored = [], []
    for i in gpu_numbers:
        if -num_available <= i < num_available:
            idx = i % num_available  # normalise negative indices
            if idx not in resolved:
                resolved.append(idx)
        else:
            ignored.append(i)

    if ignored:
        print(f"⚠️ Ignoring unavailable GPU indices {ignored} "
              f"(only {num_available} GPU(s) present)")
    return resolved

# Resolve the GPU selection once, at import time: `strategy` is the
# tf.distribute strategy returned by setup_gpus() and `num_gpus` the number of
# GPUs it made visible (0 when running on CPU).
strategy, num_gpus = setup_gpus(config.GPU_NUMBERS)

# ==============================================================================
# UTILS & AUGMENTATION
# ==============================================================================

def set_seed(seed=42, deterministic=False):
    """Seed Python, NumPy and TensorFlow RNGs and optionally request deterministic TF ops.

    Args:
        seed: Seed applied to ``random``, ``numpy.random`` and ``tf.random``.
        deterministic: When True, ask TensorFlow for deterministic kernels.
            This sets the ``TF_DETERMINISTIC_OPS`` environment flag and calls
            ``tf.config.experimental.enable_op_determinism`` when the installed
            TF provides it. Deterministic kernels can be slower.

    Note:
        ``PYTHONHASHSEED`` is only honoured by interpreters started after it
        is set, such as worker subprocesses. It does not change string hashing
        in the current process.
    """
    import random
    np.random.seed(seed)
    random.seed(seed)
    tf.random.set_seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

    if deterministic:
        os.environ['TF_DETERMINISTIC_OPS'] = '1'
        enable_determinism = getattr(tf.config.experimental, 'enable_op_determinism', None)
        if enable_determinism is not None:
            enable_determinism()
        else:
            print("⚠️ tf.config.experimental.enable_op_determinism is unavailable; "
                  "relying on TF_DETERMINISTIC_OPS only.")

def get_image_mask_pairs(images_dir, masks_dir):
    """Pair every image in ``images_dir`` with its mask in ``masks_dir``.

    Images are matched on their stem. For ``<stem>.<ext>`` the mask candidates
    are tried in this order: ``<stem>.png``, ``.jpg``, ``.tif``,
    ``<stem>_mask.png``, ``_mask.jpg``, then the remaining image extensions
    (plain and ``_mask`` variants). The first existing candidate wins.

    Args:
        images_dir: Directory containing the images.
        masks_dir: Directory containing the masks.

    Returns:
        Sorted list of ``(image_path, mask_path)`` tuples. Images without a
        mask are skipped and counted in a printed warning. An empty list is
        returned when ``images_dir`` holds no images.
    """
    image_extensions = ['*.png', '*.jpg', '*.jpeg', '*.tif', '*.tiff', '*.bmp']
    # Lower- and upper-case patterns; the set removes duplicates on
    # filesystems where glob matching is case-insensitive.
    found = set()
    for ext in image_extensions:
        found.update(glob(os.path.join(images_dir, ext)))
        found.update(glob(os.path.join(images_dir, ext.upper())))
    image_files = sorted(found)

    if not image_files:
        print(f"⚠️ No images found in {images_dir}")
        return []

    # The first five entries reproduce the historical lookup order; the rest
    # only add matches for images that previously had none.
    mask_suffixes = [
        ".png", ".jpg", ".tif", "_mask.png", "_mask.jpg",
        ".jpeg", ".tiff", ".bmp", "_mask.jpeg", "_mask.tif", "_mask.tiff", "_mask.bmp",
    ]

    pairs = []
    unmatched = 0
    for img_path in image_files:
        stem = Path(img_path).stem
        mask_path = next(
            (cand for cand in (os.path.join(masks_dir, stem + suffix) for suffix in mask_suffixes)
             if os.path.exists(cand)),
            None,
        )
        if mask_path is None:
            unmatched += 1
            continue
        pairs.append((img_path, mask_path))

    if unmatched:
        print(f"⚠️ {unmatched} image(s) in {images_dir} have no matching mask and were skipped")
    return pairs

def load_dataset_split(split_dir):
    """Return the ``(image_path, mask_path)`` pairs of one dataset split.

    ``split_dir`` is expected to contain ``images`` and ``masks``
    sub-directories. Matching is delegated to ``get_image_mask_pairs``.
    """
    images_dir, masks_dir = (os.path.join(split_dir, sub) for sub in ("images", "masks"))
    return get_image_mask_pairs(images_dir, masks_dir)

def get_augmentation(cfg):
    """Build the (aggressive) training augmentation pipeline.

    The transforms are applied in this order:
      * random horizontal and vertical flips (p=0.5 each);
      * a shift/scale/rotate that always fires (rotation up to ±180°, borders
        filled with a constant);
      * colour jitter that always fires;
      * a resize to ``cfg.INPUT_SIZE`` x ``cfg.INPUT_SIZE``.
    """
    geometric = [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=180,
                           border_mode=cv2.BORDER_CONSTANT, p=1.0),
    ]
    photometric = [
        A.ColorJitter(brightness=0.4, contrast=0.2, saturation=0.1, hue=0.01, p=1.0),
    ]
    sizing = [A.Resize(height=cfg.INPUT_SIZE, width=cfg.INPUT_SIZE)]
    return A.Compose(geometric + photometric + sizing, p=1.0)

def get_validation_augmentation(cfg):
    """Deterministic evaluation pipeline: only resize to the model input size."""
    side = cfg.INPUT_SIZE
    return A.Compose([A.Resize(height=side, width=side)])

# ==============================================================================
# DATA GENERATOR
# ==============================================================================

class ExpandedGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that repeats the real batches ``expansion_factor`` times per epoch.

    One Keras epoch consists of ``real_batches * expansion_factor`` batches.
    Batch ``i`` is real batch ``i % real_batches``, so each sample is visited
    ``expansion_factor`` times per epoch, each time with a fresh augmentation.

    Args:
        pairs: List of ``(image_path, mask_path)`` tuples.
        cfg: Configuration object; ``cfg.BATCH_SIZE`` is read.
        augmentation: Optional callable invoked as
            ``augmentation(image=..., mask=...)`` that returns a mapping with
            ``"image"`` and ``"mask"`` entries.
        shuffle: Shuffle the sample order at construction and after every epoch.
        expansion_factor: Number of passes over the real batches per epoch.
        drop_last: Drop the final incomplete batch. ``None`` (the default)
            drops it only when ``shuffle`` is True, so non-shuffled evaluation
            generators see every sample. A dataset smaller than one batch
            always yields a single partial batch instead of an empty generator.
    """

    def __init__(self, pairs, cfg, augmentation=None, shuffle=True, expansion_factor=1,
                 drop_last=None):
        # Keras 3's PyDataset (aliased as Sequence) expects its __init__ to run.
        super().__init__()
        self.pairs = pairs
        self.cfg = cfg
        self.augmentation = augmentation
        self.shuffle = shuffle
        self.expansion_factor = expansion_factor
        self.indices = np.arange(len(self.pairs))
        self.drop_last = shuffle if drop_last is None else drop_last
        # Unreadable paths already reported, so each is warned about only once.
        self._warned_paths = set()

        n_pairs = len(self.pairs)
        batch_size = self.cfg.BATCH_SIZE
        if self.drop_last and n_pairs >= batch_size:
            self.real_batches = n_pairs // batch_size
        else:
            self.real_batches = -(-n_pairs // batch_size)  # ceiling division
        self.virtual_batches = self.real_batches * self.expansion_factor

        if self.shuffle:
            np.random.shuffle(self.indices)

    def __len__(self):
        return self.virtual_batches

    def __getitem__(self, index):
        if self.real_batches == 0:
            raise IndexError("ExpandedGenerator has no samples to serve")

        # Map the virtual batch index back onto a real batch.
        real_index_ptr = index % self.real_batches
        batch_start = real_index_ptr * self.cfg.BATCH_SIZE
        batch_indices = self.indices[batch_start:batch_start + self.cfg.BATCH_SIZE]

        images, masks = [], []
        for idx in batch_indices:
            sample = self._load_sample(*self.pairs[idx])
            if sample is None:
                continue  # best effort: unreadable files shrink the batch
            image, mask = sample
            images.append(image)
            masks.append(mask)

        return np.array(images, dtype=np.float32), np.array(masks, dtype=np.float32)

    def _load_sample(self, img_path, mask_path):
        """Read, binarise and optionally augment one pair; ``None`` if either file is unreadable."""
        image = cv2.imread(img_path)
        if image is None:
            self._warn_unreadable(img_path)
            return None
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            self._warn_unreadable(mask_path)
            return None

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Binarise the grayscale mask: pixels above 127 are foreground.
        mask = (mask > 127).astype(np.float32)

        if self.augmentation is not None:
            augmented = self.augmentation(image=image, mask=mask)
            image = augmented["image"]
            mask = augmented["mask"]

        image = image.astype(np.float32) / 255.0
        if mask.ndim == 2:
            mask = np.expand_dims(mask, axis=-1)  # add the channel axis
        return image, mask

    def _warn_unreadable(self, path):
        """Print a one-time warning for a file OpenCV could not decode."""
        if path not in self._warned_paths:
            self._warned_paths.add(path)
            print(f"⚠️ Could not read {path}; skipping sample")

    def on_epoch_end(self):
        if self.shuffle:
            np.random.shuffle(self.indices)

# ==============================================================================
# 🔥 SPECTRAL ANALYSIS FUNCTIONS
# ==============================================================================

def compute_spectral_energy_retention(image, K):
    """Fraction of spectral energy inside the centred KxK low-frequency window.

    The 2-D FFT of every channel is ``fftshift``-ed so that DC sits at
    ``(H // 2, W // 2)``. The energy (squared magnitude) inside the KxK window
    around DC is then divided by the channel's total energy.

    Args:
        image: 2-D array ``(H, W)`` or multi-channel array ``(H, W, C)``.
        K: Window side length. Even K spans offsets ``[-K/2, K/2 - 1]`` around
            DC; odd K spans ``[-(K-1)/2, (K-1)/2]``. The window is clipped to
            the spectrum, so ``K >= max(H, W)`` retains all energy.

    Returns:
        Retained energy fraction in [0, 1] as a NumPy float. For multi-channel
        input this is the mean of the per-channel fractions.
    """
    axes = (0, 1)
    power = np.abs(np.fft.fftshift(np.fft.fft2(image, axes=axes), axes=axes)) ** 2

    h, w = power.shape[:2]
    half = K // 2
    # Clip the window bounds so that a K larger than the spectrum never yields
    # negative (wrapping) slice starts.
    r0, c0 = max(h // 2 - half, 0), max(w // 2 - half, 0)
    r1, c1 = min(h // 2 - half + K, h), min(w // 2 - half + K, w)

    total = power.sum(axis=axes)
    kept = power[r0:r1, c0:c1].sum(axis=axes)
    # The epsilon keeps an all-zero channel from dividing by zero.
    return np.mean(kept / (total + 1e-10))

def compute_reconstruction_error(image, K):
    """Reconstruct an image from its centred KxK spectral window and measure the RMSE.

    A 3-D input is converted to grayscale through an 8-bit round trip
    (``cv2.COLOR_RGB2GRAY``); this assumes RGB values in [0, 1]. A 2-D input
    is used as-is. Every Fourier coefficient outside the KxK window around DC
    (after ``fftshift``) is zeroed before the inverse FFT.

    Args:
        image: 2-D array, or RGB array ``(H, W, 3)`` with values in [0, 1].
        K: Window side length. Even K spans offsets ``[-K/2, K/2 - 1]``; odd K
            spans ``[-(K-1)/2, (K-1)/2]``. The window is clipped to the
            spectrum, so ``K >= max(H, W)`` reconstructs exactly.

    Returns:
        ``(rmse, reconstructed)``: the root-mean-square error between the
        grayscale input and its reconstruction, and the real part of that
        reconstruction (same shape as the grayscale input).
    """
    if image.ndim == 3:
        gray = cv2.cvtColor((image * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY).astype(np.float32) / 255.0
    else:
        gray = image

    spectrum = np.fft.fftshift(np.fft.fft2(gray))

    h, w = spectrum.shape
    half = K // 2
    # Clip the window so that large K cannot produce wrapping negative starts.
    r0, c0 = max(h // 2 - half, 0), max(w // 2 - half, 0)
    r1, c1 = min(h // 2 - half + K, h), min(w // 2 - half + K, w)

    kept = np.zeros_like(spectrum)
    kept[r0:r1, c0:c1] = spectrum[r0:r1, c0:c1]

    reconstructed = np.real(np.fft.ifft2(np.fft.ifftshift(kept)))
    rmse = np.sqrt(np.mean((gray - reconstructed) ** 2))
    return rmse, reconstructed

def analyze_spectral_concentration(pairs, cfg, modality_name="Dataset"):
    """Measure how much image content each spectral truncation size K preserves.

    Up to ``cfg.SPECTRAL_ANALYSIS_SAMPLES`` images are drawn without
    replacement using the global NumPy RNG. Each is read, converted to RGB,
    resized to ``cfg.INPUT_SIZE`` x ``cfg.INPUT_SIZE`` and scaled to [0, 1].
    For every K in ``cfg.TRUNCATION_SIZES`` smaller than the image side, the
    energy retention (in %) and the grayscale reconstruction RMSE are recorded.

    Args:
        pairs: List of ``(image_path, mask_path)`` tuples; only images are used.
        cfg: Configuration object.
        modality_name: Label printed in the section header.

    Returns:
        ``(stats, results, sample_images)``:
          * ``results[K]``: per-image ``'energy'`` (%) and ``'rmse'`` lists.
          * ``stats[K]``: their mean/std/min/max, present only for K with at
            least one value.
          * ``sample_images``: the first three successfully loaded images.
    """
    print(f"\n{'='*80}")
    print(f"🔬 SPECTRAL ENERGY ANALYSIS: {modality_name}")
    print(f"{'='*80}")

    results = {K: {'energy': [], 'rmse': []} for K in cfg.TRUNCATION_SIZES}
    sample_images = []

    n_samples = min(cfg.SPECTRAL_ANALYSIS_SAMPLES, len(pairs))
    if n_samples <= 0:
        print("⚠️ No image pairs available; skipping spectral analysis.")
        return {}, results, sample_images

    selected_pairs = np.random.choice(len(pairs), n_samples, replace=False)

    n_loaded = 0
    for idx in selected_pairs:
        img_path, _ = pairs[idx]
        image = cv2.imread(img_path)
        if image is None:
            continue

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, (cfg.INPUT_SIZE, cfg.INPUT_SIZE))
        image = image.astype(np.float32) / 255.0
        n_loaded += 1

        if len(sample_images) < 3:
            sample_images.append(image.copy())

        side = min(image.shape[0], image.shape[1])
        for K in cfg.TRUNCATION_SIZES:
            # Skip windows that would cover the entire image spectrum.
            if K < side:
                energy_retention = compute_spectral_energy_retention(image, K)
                rmse, _ = compute_reconstruction_error(image, K)
                results[K]['energy'].append(energy_retention * 100)
                results[K]['rmse'].append(rmse)

    print(f"   Analyzed {n_loaded}/{n_samples} sampled images")

    # Summary statistics per K (only where at least one value was recorded).
    stats = {}
    for K in cfg.TRUNCATION_SIZES:
        energy, rmse_vals = results[K]['energy'], results[K]['rmse']
        if not energy:
            continue
        stats[K] = {
            'energy_mean': np.mean(energy),
            'energy_std': np.std(energy),
            'energy_min': np.min(energy),
            'energy_max': np.max(energy),
            'rmse_mean': np.mean(rmse_vals),
            'rmse_std': np.std(rmse_vals),
            'rmse_min': np.min(rmse_vals),
            'rmse_max': np.max(rmse_vals),
        }

    # Console table. Each cell is padded to the same width as its header so
    # that the columns line up.
    print(f"\n{'K Value':<10} {'Energy (%)':<20} {'RMSE':<20} {'Compression':<15}")
    print("-" * 75)
    for K in cfg.TRUNCATION_SIZES:
        if K not in stats:
            continue
        s = stats[K]
        compression = (K**2 / cfg.INPUT_SIZE**2) * 100
        energy_cell = f"{s['energy_mean']:.2f}±{s['energy_std']:.2f}"
        rmse_cell = f"{s['rmse_mean']:.4f}±{s['rmse_std']:.4f}"
        print(f"{K:<10} {energy_cell:<20} {rmse_cell:<20} {compression:.2f}%")

    return stats, results, sample_images

def create_comprehensive_spectral_figure(stats, results, sample_images, cfg):
    """Create Figure 3: Comprehensive Spectral Analysis (3 x 6 panel grid).

    * Row 1: up to three sample images, each next to its log-magnitude FFT
      with the K=16/32/64 truncation windows outlined.
    * Row 2: (A) mean energy retention vs K, (B) reconstruction RMSE vs K on a
      log scale, (C) energy retention vs percentage of coefficients retained.
    * Row 3: (D)/(E) per-image distributions of energy and RMSE, and (F) the
      normalised cost proxies ``K^2 * log2(K^2)`` ("FLOPs") and
      ``K^2 * cfg.F3`` ("parameters").

    Args:
        stats: ``{K: {'energy_mean', 'energy_std', 'rmse_mean', 'rmse_std', ...}}``,
            as built by ``analyze_spectral_concentration``.
        results: ``{K: {'energy': [...], 'rmse': [...]}}`` per-image values.
        sample_images: RGB images, assumed float in [0, 1]; at most three are drawn.
        cfg: Configuration; ``INPUT_SIZE``, ``F3`` and ``SAVE_DIR`` are read.

    Returns:
        Path of the saved PDF. A PNG copy is written alongside it.

    Raises:
        ValueError: If ``stats`` is empty (nothing to plot).
    """
    print("\n📊 Generating Figure 3: Comprehensive Spectral Analysis...")
    if not stats:
        raise ValueError("No spectral statistics to plot (empty `stats`).")

    def draw_boxplot(ax, data, tick_labels):
        # Matplotlib 3.9 renamed boxplot's ``labels`` keyword to
        # ``tick_labels`` and deprecated the old name. Try the new keyword
        # first and fall back on older versions, which reject it.
        style = dict(patch_artist=True, widths=0.6, showmeans=True, meanline=True)
        try:
            return ax.boxplot(data, tick_labels=tick_labels, **style)
        except TypeError:
            return ax.boxplot(data, labels=tick_labels, **style)

    fig = plt.figure(figsize=(20, 12))
    gs = gridspec.GridSpec(3, 6, figure=fig, hspace=0.35, wspace=0.35,
                           left=0.05, right=0.95, top=0.95, bottom=0.05)

    # ==================== ROW 1: FFT VISUALIZATION ====================
    # (K, line width, colour, legend label, line style) for each outlined window.
    fft_windows = [
        (16, 1.5, '#3498db', 'K=16', '-'),
        (32, 2.5, '#2ecc71', 'K=32 (Optimal)', '--'),
        (64, 1.5, '#e74c3c', 'K=64', '-'),
    ]
    for i, img in enumerate(sample_images[:3]):
        ax_img = fig.add_subplot(gs[0, i * 2])
        ax_img.imshow(img)
        ax_img.set_title(f'Sample {i+1}', fontweight='bold')
        ax_img.axis('off')

        ax_fft = fig.add_subplot(gs[0, i * 2 + 1])
        gray = cv2.cvtColor((img * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
        magnitude = np.log(np.abs(np.fft.fftshift(np.fft.fft2(gray))) + 1)
        ax_fft.imshow(magnitude, cmap='hot', interpolation='bilinear')

        h, w = magnitude.shape
        cy, cx = h // 2, w // 2
        for k, lw, color, label, ls in fft_windows:
            k_half = k // 2
            # The window covers pixel indices [c - k/2, c + k/2). imshow centres
            # pixel i on coordinate i, so the outer edge starts at -0.5.
            ax_fft.add_patch(Rectangle((cx - k_half - 0.5, cy - k_half - 0.5), k, k,
                                       linewidth=lw, edgecolor=color, facecolor='none',
                                       label=label, linestyle=ls))

        ax_fft.set_title('FFT Magnitude (log)', fontweight='bold')
        ax_fft.axis('off')

        if i == 2:
            ax_fft.legend(loc='upper right', fontsize=8, framealpha=0.9)

    # ==================== ROW 2: PERFORMANCE METRICS ====================
    K_values = sorted(stats.keys())
    idx_32 = K_values.index(32) if 32 in K_values else None

    # (A) Energy Retention
    ax_energy = fig.add_subplot(gs[1, 0:2])
    energy_means = [stats[K]['energy_mean'] for K in K_values]
    energy_stds = [stats[K]['energy_std'] for K in K_values]
    energy_lo = [m - s for m, s in zip(energy_means, energy_stds)]
    energy_hi = [m + s for m, s in zip(energy_means, energy_stds)]

    ax_energy.plot(K_values, energy_means, 'o-', linewidth=2.5, markersize=8,
                   color='#2980b9', label='Mean Energy Retention')
    ax_energy.fill_between(K_values, energy_lo, energy_hi,
                           alpha=0.2, color='#2980b9', label='±1 Std Dev')

    if idx_32 is not None:
        ax_energy.plot(K_values[idx_32], energy_means[idx_32], 'D',
                       markersize=12, color='#27ae60', label='K=32 (Optimal)',
                       markeredgecolor='black', markeredgewidth=1.5, zorder=5)

    ax_energy.axhline(y=95, color='#e74c3c', linestyle='--', linewidth=2,
                      alpha=0.7, label='95% Threshold')
    ax_energy.set_xlabel('Truncation Size K', fontweight='bold')
    ax_energy.set_ylabel('Energy Retention (%)', fontweight='bold')
    ax_energy.set_title('(A) Spectral Energy Retention', fontweight='bold', loc='left')
    ax_energy.legend(loc='lower right', framealpha=0.95)
    ax_energy.grid(True, alpha=0.3, linestyle='--')
    # The default view is 85-102 %. It is extended downward only when the
    # mean - std band dips below 85 %, so data is never clipped off the axes.
    y_floor = 85.0
    if min(energy_lo) < y_floor:
        y_floor = float(np.floor(min(energy_lo))) - 2.0
    ax_energy.set_ylim([y_floor, 102])

    # (B) RMSE
    ax_rmse = fig.add_subplot(gs[1, 2:4])
    rmse_means = [stats[K]['rmse_mean'] for K in K_values]
    rmse_stds = [stats[K]['rmse_std'] for K in K_values]

    ax_rmse.plot(K_values, rmse_means, 's-', linewidth=2.5, markersize=8,
                 color='#e74c3c', label='Mean RMSE')
    ax_rmse.fill_between(K_values,
                         [m - s for m, s in zip(rmse_means, rmse_stds)],
                         [m + s for m, s in zip(rmse_means, rmse_stds)],
                         alpha=0.2, color='#e74c3c', label='±1 Std Dev')

    if idx_32 is not None:
        ax_rmse.plot(K_values[idx_32], rmse_means[idx_32], 'D',
                     markersize=12, color='#27ae60', label='K=32 (Optimal)',
                     markeredgecolor='black', markeredgewidth=1.5, zorder=5)

    ax_rmse.set_xlabel('Truncation Size K', fontweight='bold')
    ax_rmse.set_ylabel('Reconstruction RMSE', fontweight='bold')
    ax_rmse.set_title('(B) Reconstruction Error', fontweight='bold', loc='left')
    ax_rmse.legend(loc='upper right', framealpha=0.95)
    ax_rmse.grid(True, alpha=0.3, linestyle='--')
    ax_rmse.set_yscale('log')

    # (C) Pareto Frontier
    ax_pareto = fig.add_subplot(gs[1, 4:6])
    compression_ratios = [(K**2 / cfg.INPUT_SIZE**2) * 100 for K in K_values]

    scatter = ax_pareto.scatter(compression_ratios, energy_means,
                                c=K_values, s=200, cmap='viridis',
                                edgecolors='black', linewidths=1.5,
                                alpha=0.8, zorder=3)
    ax_pareto.plot(compression_ratios, energy_means, '--',
                   linewidth=1.5, color='gray', alpha=0.5, zorder=2)

    if idx_32 is not None:
        ax_pareto.scatter(compression_ratios[idx_32], energy_means[idx_32],
                          s=400, marker='*', color='#f1c40f',
                          edgecolors='black', linewidths=2,
                          label='K=32 (Optimal)', zorder=5)

    for ratio, energy, K in zip(compression_ratios, energy_means, K_values):
        ax_pareto.annotate(f'K={K}', xy=(ratio, energy),
                           xytext=(5, 5), textcoords='offset points',
                           fontsize=9, fontweight='bold')

    ax_pareto.set_xlabel('Spectral Coefficients Retained (%)', fontweight='bold')
    ax_pareto.set_ylabel('Energy Retention (%)', fontweight='bold')
    ax_pareto.set_title('(C) Pareto Frontier', fontweight='bold', loc='left')
    ax_pareto.legend(loc='lower right', framealpha=0.95)
    ax_pareto.grid(True, alpha=0.3, linestyle='--')

    cbar = fig.colorbar(scatter, ax=ax_pareto)
    cbar.set_label('K', rotation=270, labelpad=20, fontweight='bold')

    # ==================== ROW 3: DISTRIBUTION ANALYSIS ====================
    colors_gradient = plt.cm.viridis(np.linspace(0, 1, len(K_values)))

    # (D) Energy Box Plot
    ax_box_energy = fig.add_subplot(gs[2, 0:2])
    bp1 = draw_boxplot(ax_box_energy, [results[K]['energy'] for K in K_values], K_values)
    for patch, color in zip(bp1['boxes'], colors_gradient):
        patch.set_facecolor(color)
        patch.set_alpha(0.7)
    if idx_32 is not None:
        bp1['boxes'][idx_32].set_edgecolor('#27ae60')
        bp1['boxes'][idx_32].set_linewidth(3)

    ax_box_energy.set_xlabel('Truncation Size K', fontweight='bold')
    ax_box_energy.set_ylabel('Energy Retention (%)', fontweight='bold')
    ax_box_energy.set_title('(D) Energy Distribution', fontweight='bold', loc='left')
    ax_box_energy.grid(True, alpha=0.3, linestyle='--', axis='y')
    ax_box_energy.axhline(y=95, color='#e74c3c', linestyle='--', linewidth=1.5, alpha=0.5)

    # (E) RMSE Box Plot
    ax_box_rmse = fig.add_subplot(gs[2, 2:4])
    bp2 = draw_boxplot(ax_box_rmse, [results[K]['rmse'] for K in K_values], K_values)
    for patch, color in zip(bp2['boxes'], colors_gradient):
        patch.set_facecolor(color)
        patch.set_alpha(0.7)
    if idx_32 is not None:
        bp2['boxes'][idx_32].set_edgecolor('#27ae60')
        bp2['boxes'][idx_32].set_linewidth(3)

    ax_box_rmse.set_xlabel('Truncation Size K', fontweight='bold')
    ax_box_rmse.set_ylabel('Reconstruction RMSE', fontweight='bold')
    ax_box_rmse.set_title('(E) RMSE Distribution', fontweight='bold', loc='left')
    ax_box_rmse.grid(True, alpha=0.3, linestyle='--', axis='y')
    ax_box_rmse.set_yscale('log')

    # (F) Computational Complexity (cost proxies, each normalised to its max)
    ax_complexity = fig.add_subplot(gs[2, 4:6])
    flops = np.array([(K**2) * np.log2(K**2) for K in K_values])
    params = np.array([K**2 * cfg.F3 for K in K_values])
    flops_norm = flops / flops.max() * 100
    params_norm = params / params.max() * 100

    x = np.arange(len(K_values))
    width = 0.35
    bars1 = ax_complexity.bar(x - width / 2, flops_norm, width,
                              label='Relative FLOPs', color='#3498db', alpha=0.8)
    bars2 = ax_complexity.bar(x + width / 2, params_norm, width,
                              label='Relative Parameters', color='#e74c3c', alpha=0.8)

    if idx_32 is not None:
        for bars in (bars1, bars2):
            bars[idx_32].set_edgecolor('#27ae60')
            bars[idx_32].set_linewidth(3)

    ax_complexity.set_xlabel('Truncation Size K', fontweight='bold')
    ax_complexity.set_ylabel('Normalized Cost (%)', fontweight='bold')
    ax_complexity.set_title('(F) Computational Complexity', fontweight='bold', loc='left')
    ax_complexity.set_xticks(x)
    ax_complexity.set_xticklabels(K_values)
    ax_complexity.legend(framealpha=0.95)
    ax_complexity.grid(True, alpha=0.3, linestyle='--', axis='y')

    fig.suptitle('Figure 3: Comprehensive Spectral Truncation Analysis',
                 fontsize=16, fontweight='bold', y=0.98)

    save_path = os.path.join(cfg.SAVE_DIR, "figures", "figure3_comprehensive_spectral_analysis.pdf")
    fig.savefig(save_path, dpi=300, bbox_inches='tight', format='pdf')
    fig.savefig(os.path.splitext(save_path)[0] + '.png', dpi=300, bbox_inches='tight')
    plt.close(fig)

    print(f"✅ Saved to: {save_path}")
    return save_path

def create_reconstruction_comparison_figure(sample_images, cfg):
    """Create Figure 4: Reconstruction Quality Comparison.

    Each row shows one sample:
      * the original image;
      * a grayscale reconstruction from the centred KxK spectral window for
        every K in ``K_values_viz``, annotated with its RMSE (K=32 framed in
        green);
      * the absolute error map of the K=32 reconstruction.

    Args:
        sample_images: Non-empty sequence of RGB images, assumed float in [0, 1].
        cfg: Configuration; ``SAVE_DIR`` is read.

    Returns:
        Path of the saved PDF. A PNG copy is written alongside it.

    Raises:
        ValueError: If ``sample_images`` is empty.
    """
    print("\n📊 Generating Figure 4: Reconstruction Quality...")
    if len(sample_images) == 0:
        raise ValueError("Figure 4 needs at least one sample image.")

    K_values_viz = [16, 32, 48, 128]
    highlight_K = 32  # framed in green and used for the error map
    n_cols = len(K_values_viz) + 2

    fig = plt.figure(figsize=(20, 10))
    gs = gridspec.GridSpec(len(sample_images), n_cols,
                           figure=fig, hspace=0.25, wspace=0.15)

    def clear_frame(ax):
        # Remove ticks and spines but keep the axes on. axis('off') would
        # also hide the row ylabel and the green highlight frame.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

    for row, img in enumerate(sample_images):
        ax_orig = fig.add_subplot(gs[row, 0])
        ax_orig.imshow(img)
        ax_orig.set_title('Original' if row == 0 else '', fontweight='bold')
        ax_orig.set_ylabel(f'Sample {row+1}', fontweight='bold', fontsize=12)
        clear_frame(ax_orig)

        reconstructed_hl = None
        for col, K in enumerate(K_values_viz):
            rmse, reconstructed = compute_reconstruction_error(img, K)
            if K == highlight_K:
                reconstructed_hl = reconstructed  # reused for the error map

            ax = fig.add_subplot(gs[row, col + 1])
            # Show the single-channel reconstruction with fixed [0, 1] limits.
            # Out-of-range values saturate instead of being clipped as RGB
            # data with a warning.
            ax.imshow(reconstructed, cmap='gray', vmin=0, vmax=1)

            title = f'K={K}' if row == 0 else ''
            ax.set_title(title, fontweight='bold',
                         color='#27ae60' if K == highlight_K else 'black')

            ax.text(0.5, 0.05, f'RMSE: {rmse:.4f}',
                    transform=ax.transAxes, ha='center',
                    bbox=dict(boxstyle='round', facecolor='white', alpha=0.8),
                    fontsize=9, fontweight='bold')

            clear_frame(ax)
            if K == highlight_K:
                for spine in ax.spines.values():
                    spine.set_visible(True)
                    spine.set_edgecolor('#27ae60')
                    spine.set_linewidth(3)

        if reconstructed_hl is None:
            _, reconstructed_hl = compute_reconstruction_error(img, highlight_K)

        ax_error = fig.add_subplot(gs[row, n_cols - 1])
        gray_orig = cv2.cvtColor((img * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY).astype(np.float32) / 255.0
        error = np.abs(gray_orig - reconstructed_hl)

        im = ax_error.imshow(error, cmap='hot', vmin=0, vmax=0.1)
        ax_error.set_title('Error (K=32)' if row == 0 else '', fontweight='bold')
        ax_error.axis('off')

        if row == 0:
            cbar = fig.colorbar(im, ax=ax_error, fraction=0.046)
            cbar.set_label('Absolute Error', rotation=270, labelpad=15, fontsize=9)

    fig.suptitle('Figure 4: Visual Reconstruction Quality',
                 fontsize=16, fontweight='bold')

    save_path = os.path.join(cfg.SAVE_DIR, "figures", "figure4_reconstruction_comparison.pdf")
    fig.savefig(save_path, dpi=300, bbox_inches='tight', format='pdf')
    fig.savefig(os.path.splitext(save_path)[0] + '.png', dpi=300, bbox_inches='tight')
    plt.close(fig)

    print(f"✅ Saved to: {save_path}")
    return save_path

def create_frequency_band_analysis_figure(sample_images, cfg):
    """Create Figure 5: Frequency Band Analysis.

    The grayscale power spectrum of each sample (after ``fftshift``) is split
    into six concentric square rings. Ring ``(start, end)`` is the energy
    inside the centred ``end x end`` window minus the energy inside the
    ``start x start`` window, as a percentage of the total energy. Up to five
    samples get their own panel; the sixth panel shows mean ± std over all
    samples. Unused sample panels are hidden.

    Args:
        sample_images: Non-empty sequence of RGB images, assumed float in [0, 1].
        cfg: Configuration; ``SAVE_DIR`` is read.

    Returns:
        Path of the saved PDF. A PNG copy is written alongside it.

    Raises:
        ValueError: If ``sample_images`` is empty.
    """
    print("\n📊 Generating Figure 5: Frequency Band Analysis...")
    if len(sample_images) == 0:
        raise ValueError("Figure 5 needs at least one sample image.")

    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    axes = axes.flatten()

    bands = [
        (0, 8, 'Very Low'),
        (8, 16, 'Low'),
        (16, 32, 'Medium'),
        (32, 64, 'High'),
        (64, 128, 'Very High'),
        (128, 256, 'Ultra High')
    ]
    band_names = [name for _, _, name in bands]
    highlight = 2  # the 'Medium' ring, bounded by the K=16 and K=32 windows
    x = np.arange(len(band_names))

    def window_energy(power, width):
        """Energy inside the centred width x width window, clipped to the spectrum."""
        h, w = power.shape
        half = width // 2
        cy, cx = h // 2, w // 2
        return power[max(cy - half, 0):min(cy + half, h),
                     max(cx - half, 0):min(cx + half, w)].sum()

    def style_band_axis(ax):
        ax.set_ylabel('Energy (%)', fontweight='bold')
        ax.grid(True, alpha=0.3, axis='y')
        # Fix the tick positions before labelling them. set_xticklabels alone
        # warns when the locator is not fixed.
        ax.set_xticks(x)
        ax.set_xticklabels(band_names, rotation=45, ha='right')

    all_band_energies = {name: [] for name in band_names}
    for img in sample_images:
        gray = cv2.cvtColor((img * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
        power = np.abs(np.fft.fftshift(np.fft.fft2(gray))) ** 2
        # The epsilon guards against an all-black image (zero total energy).
        total_energy = power.sum() + 1e-10

        for start, end, name in bands:
            # window_energy(power, 0) is 0, so the innermost band needs no special case.
            band_energy = window_energy(power, end) - window_energy(power, start)
            all_band_energies[name].append((band_energy / total_energy) * 100)

    n_panels = min(len(sample_images), 5)
    for i in range(n_panels):
        ax = axes[i]
        sample_energies = [all_band_energies[name][i] for name in band_names]

        bars = ax.bar(band_names, sample_energies, color='#3498db', alpha=0.7, edgecolor='black')
        bars[highlight].set_color('#27ae60')
        bars[highlight].set_alpha(0.9)

        ax.set_title(f'Sample {i+1}', fontweight='bold')
        peak = max(sample_energies)
        ax.set_ylim([0, peak * 1.2 if peak > 0 else 1])
        style_band_axis(ax)

    # Hide sample panels that have no image to show.
    for ax in axes[n_panels:5]:
        ax.axis('off')

    ax = axes[5]
    avg_energies = [np.mean(all_band_energies[name]) for name in band_names]
    std_energies = [np.std(all_band_energies[name]) for name in band_names]

    bars = ax.bar(band_names, avg_energies, yerr=std_energies,
                  color='#3498db', alpha=0.7, edgecolor='black',
                  capsize=5, error_kw={'linewidth': 2})
    bars[highlight].set_color('#27ae60')
    bars[highlight].set_alpha(0.9)

    ax.set_title('Average Across All Samples', fontweight='bold')
    style_band_axis(ax)

    ax.annotate('K=32 Band\n(Optimal)',
                xy=(highlight, avg_energies[highlight]),
                xytext=(highlight, avg_energies[highlight] + 5),
                fontsize=10, fontweight='bold', color='#27ae60',
                ha='center',
                arrowprops=dict(arrowstyle='->', color='#27ae60', lw=2))

    fig.suptitle('Figure 5: Frequency Band Energy Distribution',
                 fontsize=16, fontweight='bold')
    fig.tight_layout()

    save_path = os.path.join(cfg.SAVE_DIR, "figures", "figure5_frequency_band_analysis.pdf")
    fig.savefig(save_path, dpi=300, bbox_inches='tight', format='pdf')
    fig.savefig(os.path.splitext(save_path)[0] + '.png', dpi=300, bbox_inches='tight')
    plt.close(fig)

    print(f"✅ Saved to: {save_path}")
    return save_path

def create_publication_tables(stats, results, cfg):
    """Create publication-ready tables for the spectral truncation study.

    Table 1 summarizes spectral energy retention for every truncation size
    ``K`` in ``stats`` and is written both as CSV and as a LaTeX ``tabular``
    (the LaTeX needs ``booktabs`` and ``xcolor`` with the ``table`` option).
    Table 2 compares every other ``K`` against the ``K=32`` baseline with an
    independent two-sample t-test (``scipy.stats.ttest_ind``) and Cohen's d,
    and is written as CSV.

    Args:
        stats: Mapping ``K -> dict`` with keys ``energy_mean``, ``energy_std``,
            ``energy_min``, ``energy_max``, ``rmse_mean`` and ``rmse_std``.
        results: Mapping ``K -> dict`` whose ``'energy'`` entry is a sequence
            of per-sample energy-retention values.
        cfg: Configuration providing ``SAVE_DIR`` and ``INPUT_SIZE``.

    Returns:
        Tuple ``(df_table1, df_table2)``. ``df_table2`` is empty when the
        baseline ``K`` is missing from ``results``.
    """
    print("\n📋 Generating Publication Tables...")

    # Reference truncation size: highlighted in Table 1, unit of the
    # "Relative Cost" column, and the comparison baseline of Table 2.
    baseline_K = 32

    tables_dir = os.path.join(cfg.SAVE_DIR, "tables")
    # to_csv()/open() do not create missing parent directories.
    os.makedirs(tables_dir, exist_ok=True)

    # Table 1: Spectral Energy Retention
    K_values = sorted(stats.keys())
    table1_data = []
    for K in K_values:
        s = stats[K]
        table1_data.append({
            'K': K,
            'Energy (%)': f"{s['energy_mean']:.2f} ± {s['energy_std']:.2f}",
            'Min (%)': f"{s['energy_min']:.2f}",
            'Max (%)': f"{s['energy_max']:.2f}",
            'RMSE': f"{s['rmse_mean']:.4f} ± {s['rmse_std']:.4f}",
            'Coefficients': f"{K**2:,}",
            # Share of the full INPUT_SIZE x INPUT_SIZE spectrum that is kept.
            'Compression (%)': f"{(K**2 / cfg.INPUT_SIZE**2) * 100:.2f}",
            'Relative Cost': f"{K**2 / (baseline_K**2):.2f}×"
        })

    df_table1 = pd.DataFrame(table1_data)
    table1_path = os.path.join(tables_dir, "table1_spectral_energy_retention.csv")
    df_table1.to_csv(table1_path, index=False)

    # LaTeX format. The cells contain non-ASCII characters (± and ×), so the
    # encoding is fixed to UTF-8 instead of the platform's locale default.
    latex_path = table1_path.replace('.csv', '.tex')
    with open(latex_path, 'w', encoding='utf-8') as f:
        f.write("% Table 1: Spectral Energy Retention\n")
        f.write("\\begin{table}[ht]\n")
        f.write("\\centering\n")
        f.write("\\caption{Spectral Energy Retention Analysis}\n")
        f.write("\\label{tab:spectral_energy}\n")
        f.write("\\begin{tabular}{cccccccc}\n")
        f.write("\\toprule\n")
        f.write("$K$ & Energy & Min & Max & RMSE & Coeff. & Compress. & Cost \\\\\n")
        f.write("\\midrule\n")
        for _, row in df_table1.iterrows():
            if row['K'] == baseline_K:
                f.write("\\rowcolor{green!20}\n")
            f.write(f"{row['K']} & {row['Energy (%)']} & {row['Min (%)']} & "
                    f"{row['Max (%)']} & {row['RMSE']} & {row['Coefficients']} & "
                    f"{row['Compression (%)']} & {row['Relative Cost']} \\\\\n")
        f.write("\\bottomrule\n")
        f.write("\\end{tabular}\n")
        f.write("\\end{table}\n")

    print(f"✅ Table 1 saved to: {table1_path}")
    print("\n" + df_table1.to_string(index=False))

    # Table 2: Statistical Significance
    table2_data = []
    if baseline_K not in results:
        print(f"\n⚠️  Baseline K={baseline_K} not found in results; Table 2 will be empty.")
    else:
        baseline_energy = results[baseline_K]['energy']
        for K in K_values:
            if K == baseline_K:
                continue
            other_energy = results[K]['energy']

            # Student's t-test (scipy default equal_var=True).
            _, p_value_energy = ttest_ind(baseline_energy, other_energy)

            mean_diff = np.mean(baseline_energy) - np.mean(other_energy)
            # Cohen's d with the root-mean of the two population (ddof=0) variances.
            pooled_std = np.sqrt(
                (np.std(baseline_energy)**2 + np.std(other_energy)**2) / 2
            )
            cohens_d = mean_diff / pooled_std if pooled_std > 0 else 0

            # A NaN p-value (e.g. constant samples) fails every comparison -> 'ns'.
            if p_value_energy < 0.001:
                significance = '***'
            elif p_value_energy < 0.01:
                significance = '**'
            elif p_value_energy < 0.05:
                significance = '*'
            else:
                significance = 'ns'

            table2_data.append({
                'K': K,
                'Δ Energy (%)': f"{mean_diff:.2f}",
                'p-value': f"{p_value_energy:.4f}",
                'Cohen\'s d': f"{cohens_d:.3f}",
                'Significance': significance,
                # '✓' means the baseline retains more energy than this K.
                'Better?': '✓' if mean_diff > 0 else '✗'
            })

    df_table2 = pd.DataFrame(table2_data)
    table2_path = os.path.join(tables_dir, "table2_statistical_significance.csv")
    df_table2.to_csv(table2_path, index=False)

    print(f"\n✅ Table 2 saved to: {table2_path}")
    print("\n" + df_table2.to_string(index=False))
    print("\nSignificance: *** p<0.001, ** p<0.01, * p<0.05, ns = not significant")

    return df_table1, df_table2

def generate_spectral_analysis_figures(train_pairs, cfg):
    """Run the spectral concentration study and emit all of its artifacts.

    Analyzes ``train_pairs`` with ``analyze_spectral_concentration``, then
    renders figures 3-5 and the publication tables. Each renderer saves its
    own files, so their return values are not kept here.

    Returns:
        The ``stats`` object returned by ``analyze_spectral_concentration``.
    """
    rule = '=' * 80
    print(f"\n{rule}")
    print("📊 GENERATING SPECTRAL ANALYSIS FIGURES")
    print(rule)

    stats, results, sample_images = analyze_spectral_concentration(
        train_pairs, cfg, "Training Dataset"
    )

    create_comprehensive_spectral_figure(stats, results, sample_images, cfg)
    create_reconstruction_comparison_figure(sample_images, cfg)
    create_frequency_band_analysis_figure(sample_images, cfg)
    create_publication_tables(stats, results, cfg)

    print(f"\n{rule}")
    print("✅ SPECTRAL ANALYSIS COMPLETE")
    print(f"{rule}\n")

    return stats

# ==============================================================================
# 🔥 MED-SEGNET-SSF ARCHITECTURE
# ==============================================================================

class SpectralSelectiveTokenMixer(Layer):
    """Spectral-Selective Token Mixer (SSTM).

    Residual token mixer with up to two branches:

    * spectral path: 2-D FFT over the spatial (H, W) axes, then bilinear
      resampling of the spectrum to a ``K x K`` grid with
      ``K = min(num_frequencies, H, W)``. The spectrum is multiplied by a
      learnable real per-channel filter, resampled back to ``H x W``, passed
      through an inverse FFT, and its real part is LayerNorm-ed.
    * selective path: per-position sigmoid gate followed by a Dense
      projection, then LayerNorm.

    With both paths enabled, their outputs are concatenated, fused by a Dense
    layer and normalized. The mixed result is normalized again, optionally
    dropped out while training, and added to the input.

    Args:
        channels: Channel count of the input and output. It must equal the
            input's last axis, because the selective path reshapes with it and
            the output is added residually.
        num_frequencies: Upper bound on the spectral filter size ``K``.
        ssm_state_dim: Stored for configuration only; not used in the
            computation.
        use_spectral: Enable the spectral path.
        use_ssm: Enable the selective path.
        dropout: Dropout rate applied to the mixed features during training.
    """
    def __init__(self, channels, num_frequencies=32, ssm_state_dim=16,
                 use_spectral=True, use_ssm=True, dropout=0.0, **kwargs):
        super().__init__(**kwargs)
        self.channels = channels
        self.num_frequencies = num_frequencies
        self.ssm_state_dim = ssm_state_dim
        self.use_spectral = use_spectral
        self.use_ssm = use_ssm
        self.dropout_rate = dropout

    def build(self, input_shape):
        """Create the spectral filter and the sub-layers of the enabled paths."""
        input_h, input_w = input_shape[1], input_shape[2]
        # With an unknown spatial size, fall back to the configured bound.
        self.actual_frequencies = min(self.num_frequencies, input_h, input_w) if input_h else self.num_frequencies

        if self.use_spectral:
            self.freq_weights_real = self.add_weight(
                name='freq_weights_real',
                shape=(self.actual_frequencies, self.actual_frequencies, self.channels),
                initializer=self._get_initializer(),
                trainable=True
            )
            self.spectral_norm = LayerNormalization(epsilon=1e-6, name='spectral_norm')

        if self.use_ssm:
            self.ssm_C = Dense(self.channels, name='ssm_C')
            self.selection_gate = Dense(self.channels, activation='sigmoid', name='selection')
            self.ssm_norm = LayerNormalization(epsilon=1e-6, name='ssm_norm')

        if self.use_spectral and self.use_ssm:
            self.fusion = Dense(self.channels, name='fusion')
            self.fusion_norm = LayerNormalization(epsilon=1e-6, name='fusion_norm')

        self.norm = LayerNormalization(epsilon=1e-6, name='norm')
        super().build(input_shape)

    def _get_initializer(self):
        """Return an initializer for the spectral filter.

        The filter is a Gaussian band-pass over the radial FFT frequency,
        centered at 0.25 cycles/sample with sigma 0.15, scaled by 0.5 and
        repeated for every channel.
        """
        def init_fn(shape, dtype=None):
            H, W, C = shape
            freq_h = np.fft.fftfreq(H)[:, np.newaxis]
            freq_w = np.fft.fftfreq(W)[np.newaxis, :]
            freq_magnitude = np.sqrt(freq_h**2 + freq_w**2)
            gaussian = np.exp(-((freq_magnitude - 0.25)**2) / (2 * 0.15**2))
            gaussian = np.repeat(gaussian[:, :, np.newaxis], C, axis=2)
            return gaussian.astype(np.float32) * 0.5
        return init_fn

    def spectral_path(self, x):
        """Filter ``x`` (NHWC) in the 2-D spatial-frequency domain."""
        H, W = tf.shape(x)[1], tf.shape(x)[2]
        freq_size = tf.minimum(tf.minimum(H, W), self.actual_frequencies)

        # tf.signal.fft2d transforms the two innermost axes. Move channels in
        # front of the spatial axes, (B, H, W, C) -> (B, C, H, W), so that the
        # FFT runs over H and W rather than over W and C.
        x_chw = tf.transpose(tf.cast(x, tf.complex64), [0, 3, 1, 2])
        x_freq = tf.transpose(tf.signal.fft2d(x_chw), [0, 2, 3, 1])  # (B, H, W, C)

        # tf.image.resize only accepts real tensors: resample real/imag separately.
        x_freq_real = tf.math.real(x_freq)
        x_freq_imag = tf.math.imag(x_freq)
        x_freq_real_resized = tf.image.resize(x_freq_real, [freq_size, freq_size], method='bilinear')
        x_freq_imag_resized = tf.image.resize(x_freq_imag, [freq_size, freq_size], method='bilinear')
        x_freq_resized = tf.complex(x_freq_real_resized, x_freq_imag_resized)

        freq_filter = tf.cast(self.freq_weights_real[:freq_size, :freq_size, :], tf.complex64)
        x_freq_filtered = x_freq_resized * freq_filter

        x_freq_filt_real = tf.math.real(x_freq_filtered)
        x_freq_filt_imag = tf.math.imag(x_freq_filtered)
        x_freq_back_real = tf.image.resize(x_freq_filt_real, [H, W], method='bilinear')
        x_freq_back_imag = tf.image.resize(x_freq_filt_imag, [H, W], method='bilinear')
        x_freq_back = tf.complex(x_freq_back_real, x_freq_back_imag)

        # Inverse FFT over the spatial axes as well, then back to NHWC.
        x_spatial = tf.signal.ifft2d(tf.transpose(x_freq_back, [0, 3, 1, 2]))
        x_spatial = tf.transpose(x_spatial, [0, 2, 3, 1])
        return self.spectral_norm(tf.math.real(x_spatial))

    def ssm_path(self, x):
        """Gate every spatial position and project it with a Dense layer."""
        B, H, W = tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2]
        x_flat = tf.reshape(x, [B, H * W, self.channels])
        selection = self.selection_gate(x_flat)
        x_selected = x_flat * selection
        x_ssm = self.ssm_C(x_selected)
        return self.ssm_norm(tf.reshape(x_ssm, [B, H, W, self.channels]))

    def call(self, x, training=None):
        """Mix ``x`` through the enabled paths and add the result residually."""
        outputs = []
        if self.use_spectral:
            outputs.append(self.spectral_path(x))
        if self.use_ssm:
            outputs.append(self.ssm_path(x))

        if len(outputs) == 2:
            fused = self.fusion_norm(self.fusion(tf.concat(outputs, axis=-1)))
        elif len(outputs) == 1:
            fused = outputs[0]
        else:
            fused = x

        fused = self.norm(fused)
        if training and self.dropout_rate > 0:
            fused = tf.nn.dropout(fused, rate=self.dropout_rate)
        return x + fused

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created on load."""
        config = super().get_config()
        config.update({
            'channels': self.channels,
            'num_frequencies': self.num_frequencies,
            'ssm_state_dim': self.ssm_state_dim,
            'use_spectral': self.use_spectral,
            'use_ssm': self.use_ssm,
            'dropout': self.dropout_rate,
        })
        return config

def MRF_SE_BLOCK(x, filters, activation='elu', dropout=0.0, expand_ratio=6,
                 regularizer=0.0, kernels=(3, 5, 7), se_reduction=16, name='mrf_se'):
    """Multi-Receptive Field block with Squeeze-Excitation and a residual add.

    Pipeline:
      1. A 1x1 expansion to ``filters * expand_ratio`` channels, applied only
         when ``expand_ratio > 1``, followed by BN and activation.
      2. One depthwise conv (+BN, activation) per size in ``kernels``.
      3. When there is more than one kernel: concat, 1x1 fuse, BN, activation.
      4. SE channel gating: global pool, 1x1 reduce, 1x1 sigmoid expand.
      5. 1x1 projection to ``filters`` channels, then BN.
      6. Dropout if ``dropout > 0``.
      7. Residual add with ``x``.

    Args:
        x: Input tensor; its channel count must equal ``filters``.
        filters: Output channels.
        activation: Activation name used throughout the block.
        dropout: Dropout rate applied after the projection (0 disables it).
        expand_ratio: Channel expansion factor of the inverted bottleneck.
        regularizer: L2 factor for the conv kernels (0 disables it).
        kernels: Iterable of depthwise kernel sizes (at least one).
        se_reduction: Channel reduction ratio of the SE bottleneck (min 8).
        name: Prefix for every layer name.

    Returns:
        Tensor with ``filters`` channels.

    Raises:
        ValueError: if ``kernels`` is empty, or if ``x`` has a known channel
            count that differs from ``filters`` (the residual add needs them
            equal).
    """
    # Tuple default instead of a shared mutable list; any iterable is accepted.
    kernels = tuple(kernels)
    if not kernels:
        raise ValueError(f"{name}: 'kernels' must contain at least one kernel size")
    in_channels = x.shape[-1]
    if in_channels is not None and in_channels != filters:
        raise ValueError(
            f"{name}: input has {in_channels} channels but filters={filters}; "
            "the residual add requires them to match"
        )

    def _reg():
        # A fresh regularizer per layer (or none when disabled).
        return l2(regularizer) if regularizer > 0 else None

    F_expanded = filters * expand_ratio

    conv = Conv2D(F_expanded, (1, 1), padding='same', kernel_initializer='he_uniform',
                  kernel_regularizer=_reg(),
                  name=name+'_expand')(x) if expand_ratio > 1 else x
    conv = Activation(activation, name=name+'_expand_act')(BatchNormalization(name=name+'_expand_bn')(conv))

    # Parallel depthwise branches with different receptive fields.
    features = []
    for k in kernels:
        dw = DepthwiseConv2D((k, k), padding='same', depthwise_initializer='he_uniform',
                             depthwise_regularizer=_reg(),
                             name=f"{name}_dw{k}x{k}")(conv)
        features.append(Activation(activation, name=f"{name}_dw{k}x{k}_act")(
            BatchNormalization(name=f"{name}_dw{k}x{k}_bn")(dw)))

    combined = Concatenate(name=name+'_concat')(features) if len(features) > 1 else features[0]
    if len(features) > 1:
        # Fuse the concatenated branches back to F_expanded channels.
        combined = Activation(activation, name=name+'_fuse_act')(
            BatchNormalization(name=name+'_fuse_bn')(
                Conv2D(F_expanded, (1, 1), padding='same', kernel_initializer='he_uniform',
                       kernel_regularizer=_reg(),
                       name=name+'_fuse')(combined)))

    # Squeeze-Excitation: per-channel gate computed from globally pooled features.
    gap = Reshape((1, 1, F_expanded), name=name+'_reshape')(GlobalAveragePooling2D(name=name+'_gap')(combined))
    se = Conv2D(F_expanded, (1, 1), activation='sigmoid', kernel_initializer='he_uniform', name=name+'_se_expand')(
        Conv2D(max(F_expanded//se_reduction, 8), (1, 1), activation=activation, kernel_initializer='he_uniform',
               name=name+'_se_reduce')(gap))

    projected = Conv2D(filters, (1, 1), padding='same', kernel_initializer='he_uniform',
                       kernel_regularizer=_reg(),
                       name=name+'_project')(Multiply(name=name+'_se_mult')([combined, se]))
    projected = BatchNormalization(name=name+'_project_bn')(projected)

    if dropout > 0:
        projected = Dropout(dropout, name=name+'_dropout')(projected)

    out = Add(name=name+'_add')([projected, x])
    return out

def boundary_detection_module(features, filters, name='boundary'):
    """Predict a soft boundary map and use it to gate ``features``.

    A 3x3 ReLU conv with ``filters // 2`` channels feeds a 1x1 sigmoid conv
    that yields a single-channel map. The input features are multiplied by
    that map.

    Returns:
        ``(gated_features, boundary_map)``.
    """
    hidden = Conv2D(filters // 2, (3, 3), padding='same', activation='relu',
                    name=f'{name}_conv')(features)
    edge_map = Conv2D(1, (1, 1), padding='same', activation='sigmoid',
                      name=f'{name}_map')(hidden)
    gated = Multiply(name=f'{name}_mult')([features, edge_map])
    return gated, edge_map

def BFP_decoder_stage(decoder_input, skip_features, filters, stage_name='bfp'):
    """Boundary-Focused Progressive decoder stage.

    The stage proceeds as follows:
      1. Upsample ``decoder_input`` by 2 and concatenate it with
         ``skip_features``.
      2. Refine the result with two conv-BN-ReLU layers.
      3. Predict a boundary map. The boundary-gated features get a further
         conv-BN-ReLU.
      4. Blend region and boundary features pixel-wise, using the map as the
         mixing weight.
      5. Fuse with a 1x1 conv-BN-ReLU.

    Returns:
        ``(output, boundary_map)``.
    """
    def conv_bn_relu(tensor, kernel, layer_name):
        # Named conv followed by auto-named BatchNormalization and ReLU.
        conv_out = Conv2D(filters, kernel, padding='same', name=layer_name)(tensor)
        normed = BatchNormalization()(conv_out)
        return Activation('relu')(normed)

    upsampled = UpSampling2D((2, 2), name=f'{stage_name}_up')(decoder_input)
    merged = Concatenate(name=f'{stage_name}_concat')([upsampled, skip_features])

    region = conv_bn_relu(merged, (3, 3), f'{stage_name}_region_conv1')
    region = conv_bn_relu(region, (3, 3), f'{stage_name}_region_conv2')

    boundary_features, boundary_map = boundary_detection_module(
        region, filters, f'{stage_name}_boundary')
    boundary_refined = conv_bn_relu(boundary_features, (3, 3), f'{stage_name}_boundary_refine')

    # Where the map is high, use the boundary-refined features; elsewhere the region ones.
    blended = region * (1 - boundary_map) + boundary_refined * boundary_map
    output = conv_bn_relu(blended, (1, 1), f'{stage_name}_fusion')

    return output, boundary_map

def build_medsegnet_ssf(cfg):
    """Build the MedSegNet-SSF segmentation model.

    Encoder:
      - A stem conv.
      - Five stride-2 stages with ``cfg.F1``..``cfg.F5`` filters. Each stage
        is optionally followed by an MRF-SE block (``cfg.USE_MRF_SE``) and an
        SSTM layer (``cfg.USE_SSTM``).

    Decoder:
      - Four stages that each upsample by 2 and merge the matching encoder
        skip. They use BFP stages when ``cfg.USE_BFP`` is set, otherwise plain
        U-Net stages.
      - A final x2 upsample.
      - A single-channel sigmoid head.

    Args:
        cfg: Configuration object with the architecture settings used below.

    Returns:
        Uncompiled ``tf.keras`` Model.

    Raises:
        ValueError: if ``cfg.INPUT_SIZE`` is not a multiple of 32. The five
            stride-2 stages must be undone exactly by the decoder's
            upsampling for skips and output to line up.
    """
    print("\n" + "="*80)
    print("🔥 BUILDING MED-SEGNET-SSF")
    print("="*80)

    if cfg.INPUT_SIZE % 32 != 0:
        raise ValueError(
            f"INPUT_SIZE must be a multiple of 32 (got {cfg.INPUT_SIZE}); "
            "the encoder downsamples 5 times by a factor of 2"
        )

    inp = Input((cfg.INPUT_SIZE, cfg.INPUT_SIZE, 3), name="input")

    x = Conv2D(16, (3, 3), padding='same', kernel_initializer='he_uniform', name='stem_conv')(inp)
    x = BatchNormalization(name='stem_bn')(x)
    x = Activation('elu', name='stem_act')(x)

    encoder_outputs = []
    filters = [cfg.F1, cfg.F2, cfg.F3, cfg.F4, cfg.F5]

    for i, f in enumerate(filters):
        x = Conv2D(f, (3, 3), strides=2, padding='same', kernel_initializer='he_uniform')(x)
        x = BatchNormalization()(x)
        x = Activation('elu')(x)

        if cfg.USE_MRF_SE:
            x = MRF_SE_BLOCK(x, f, activation='elu', dropout=cfg.DROPOUT,
                             expand_ratio=cfg.EXPAND_RATIO, regularizer=cfg.L2_REG,
                             kernels=cfg.MRF_KERNELS, se_reduction=cfg.SE_REDUCTION,
                             name=f'mrfse_stage{i+1}')

        if cfg.USE_SSTM:
            x = SpectralSelectiveTokenMixer(
                channels=f, num_frequencies=cfg.SSTM_NUM_FREQUENCIES,
                ssm_state_dim=cfg.SSTM_SSM_STATE_DIM,
                use_spectral=cfg.SSTM_USE_SPECTRAL[i],
                use_ssm=cfg.SSTM_USE_SSM[i],
                dropout=cfg.SSTM_DROPOUT,
                name=f'sstm_stage{i+1}'
            )(x)

        encoder_outputs.append(x)

    # Deepest features first; each decoder stage consumes the next-shallower skip.
    skip_connections = encoder_outputs[::-1]
    decoder = skip_connections[0]
    decoder_filters = filters[::-1][1:] + [16]

    for i, (skip, f) in enumerate(zip(skip_connections[1:], decoder_filters)):
        if cfg.USE_BFP:
            decoder, _ = BFP_decoder_stage(decoder, skip, f, stage_name=f'bfp_stage{i+1}')
        else:
            # Plain U-Net stage. Without it, the BFP ablation would skip all four
            # upsamplings and the output would be INPUT_SIZE/16 wide.
            up = UpSampling2D((2, 2), name=f'dec_stage{i+1}_up')(decoder)
            merged = Concatenate(name=f'dec_stage{i+1}_concat')([up, skip])
            decoder = Conv2D(f, (3, 3), padding='same', kernel_initializer='he_uniform',
                             name=f'dec_stage{i+1}_conv')(merged)
            decoder = BatchNormalization(name=f'dec_stage{i+1}_bn')(decoder)
            decoder = Activation('relu', name=f'dec_stage{i+1}_act')(decoder)

    # The stem runs at full resolution, so one more x2 upsample restores it.
    decoder = UpSampling2D((2, 2))(decoder)
    decoder = Conv2D(32, (3, 3), padding='same', activation='relu')(decoder)
    decoder = Conv2D(16, (3, 3), padding='same', activation='relu')(decoder)
    out = Conv2D(1, (1, 1), padding='same', activation='sigmoid', name='output')(decoder)

    model = Model(inputs=inp, outputs=out, name="MedSegNet_SSF")

    print(f"\nTotal parameters: {model.count_params():,}")
    print("="*80 + "\n")

    return model

# ==============================================================================
# MASL LOSS FUNCTION
# ==============================================================================

class ClipConstraint(tf.keras.constraints.Constraint):
    """Keras constraint that clamps every weight value into ``[min_value, max_value]``."""

    def __init__(self, min_value=0.1, max_value=10.0):
        self.min_value, self.max_value = min_value, max_value

    def __call__(self, w):
        lower, upper = self.min_value, self.max_value
        return tf.clip_by_value(w, lower, upper)

    def get_config(self):
        """Return the clipping bounds for serialization."""
        return dict(min_value=self.min_value, max_value=self.max_value)

class MorphologyAwareAdaptiveLoss(Layer):
    """MASL: Morphology-Aware Adaptive Segmentation Loss.

    Combines five loss terms:
      - region: Dice + IoU + boundary-weighted BCE;
      - multi-scale boundary-gradient matching;
      - compactness matching;
      - size-adaptive focal BCE;
      - second-derivative ("texture") matching.

    Each term is scaled by a weight ``w_*`` (a layer variable constrained to
    [0.1, 10]). It is also scaled by a factor ``alpha_*`` computed from
    batch-level shape statistics of the ground truth. The returned loss is
    the weighted average: the sum of weighted terms divided by the sum of the
    weights.

    Called as ``layer(y_true, y_pred)``; both inputs are cast to float32. The
    indexing below requires rank-4 tensors, presumably (batch, H, W, 1) masks
    with ``y_pred`` in [0, 1] (TODO: confirm against the data generator and
    model head).

    NOTE(review): the ``w_*`` variables are created with ``trainable=True``.
    However, this layer is used as a loss function rather than inside the
    model, so whether an optimizer ever updates them (and applies
    ClipConstraint) depends on the training loop; verify. If nothing updates
    them, they behave as constants.
    """

    def __init__(self, name='masl', **kwargs):
        super().__init__(name=name, **kwargs)
        # Guards every division and logarithm below.
        self.epsilon = 1e-6

    def build(self, input_shape):
        """Create the five scalar term weights, each constrained to [0.1, 10.0]."""
        clip_constraint = ClipConstraint(min_value=0.1, max_value=10.0)

        self.w_region = self.add_weight(
            name='w_region', shape=(), initializer=tf.constant_initializer(1.0),
            trainable=True, constraint=clip_constraint
        )
        self.w_boundary = self.add_weight(
            name='w_boundary', shape=(), initializer=tf.constant_initializer(1.0),
            trainable=True, constraint=clip_constraint
        )
        self.w_structure = self.add_weight(
            name='w_structure', shape=(), initializer=tf.constant_initializer(1.0),
            trainable=True, constraint=clip_constraint
        )
        # The scale and texture terms start at half the weight of the others.
        self.w_scale = self.add_weight(
            name='w_scale', shape=(), initializer=tf.constant_initializer(0.5),
            trainable=True, constraint=clip_constraint
        )
        self.w_texture = self.add_weight(
            name='w_texture', shape=(), initializer=tf.constant_initializer(0.5),
            trainable=True, constraint=clip_constraint
        )
        super().build(input_shape)

    def morphological_dilation(self, x, kernel_size=5):
        """Grayscale dilation: stride-1 ``kernel_size`` max-pool with SAME padding."""
        return tf.nn.max_pool2d(x, kernel_size, strides=1, padding='SAME')

    def morphological_erosion(self, x, kernel_size=5):
        """Grayscale erosion as a min-pool (negated max-pool of the negated input)."""
        return -tf.nn.max_pool2d(-x, kernel_size, strides=1, padding='SAME')

    def detect_boundary(self, mask, kernel_size=5):
        """Return the band around mask edges: dilation minus erosion, clipped to [0, 1]."""
        dilated = self.morphological_dilation(mask, kernel_size)
        eroded = self.morphological_erosion(mask, kernel_size)
        boundary = dilated - eroded
        return tf.clip_by_value(boundary, 0.0, 1.0)

    def analyze_structure_characteristics(self, y_true):
        """Compute shape descriptors of the ground truth, each reduced to a batch scalar.

        Returns:
            Dict with scalar tensors ``tubularity``, ``compactness``,
            ``irregularity`` and ``object_size``, each clipped to [0, 1].
        """
        # Per-sample foreground area and image size in pixels.
        area = tf.reduce_sum(y_true, axis=[1, 2, 3]) + self.epsilon
        total_pixels = tf.cast(tf.shape(y_true)[1] * tf.shape(y_true)[2], tf.float32)

        # Perimeter estimate: sum of forward-difference gradient magnitudes.
        # The epsilon inside the sqrt adds sqrt(1e-6) = 1e-3 for every pixel,
        # so the estimate has a small floor proportional to image size.
        dy = y_true[:, 1:, :, :] - y_true[:, :-1, :, :]
        dx = y_true[:, :, 1:, :] - y_true[:, :, :-1, :]
        dy_padded = tf.pad(dy, [[0, 0], [0, 1], [0, 0], [0, 0]])
        dx_padded = tf.pad(dx, [[0, 0], [0, 0], [0, 1], [0, 0]])
        gradient_mag = tf.sqrt(dy_padded**2 + dx_padded**2 + self.epsilon)
        perimeter = tf.reduce_sum(gradient_mag, axis=[1, 2, 3]) + self.epsilon

        # "Skeleton" is approximated by a 3x3 erosion, not a true skeletonization.
        skeleton_approx = self.morphological_erosion(y_true, kernel_size=3)
        skeleton_area = tf.reduce_sum(skeleton_approx, axis=[1, 2, 3]) + self.epsilon

        # NOTE(review): the eroded/original area ratio is close to 1 for thick,
        # compact regions. It is close to 0 for thin (1-2 px) structures,
        # which the 3x3 erosion removes. The value therefore looks inverted
        # relative to its name "tubularity"; confirm intent.
        tubularity = tf.reduce_mean(skeleton_area / (area + self.epsilon))
        # Isoperimetric quotient 4*pi*A/P^2 (1 for a disc), averaged over the batch.
        compactness = tf.reduce_mean((4 * 3.14159 * area) / (perimeter**2 + self.epsilon))
        compactness = tf.clip_by_value(compactness, 0.0, 1.0)

        # Irregularity: mean absolute second difference of the boundary band
        # along H and along W.
        boundary = self.detect_boundary(y_true, kernel_size=5)
        ddy = boundary[:, 2:, :, :] - 2*boundary[:, 1:-1, :, :] + boundary[:, :-2, :, :]
        ddx = boundary[:, :, 2:, :] - 2*boundary[:, :, 1:-1, :] + boundary[:, :, :-2, :]
        irregularity = tf.reduce_mean(tf.abs(ddy)) + tf.reduce_mean(tf.abs(ddx))

        # Mean foreground fraction over the batch.
        object_size = tf.reduce_mean(area / total_pixels)

        return {
            'tubularity': tf.clip_by_value(tubularity, 0.0, 1.0),
            'compactness': compactness,
            'irregularity': tf.clip_by_value(irregularity, 0.0, 1.0),
            'object_size': tf.clip_by_value(object_size, 0.0, 1.0)
        }

    def core_loss(self, y_true, y_pred):
        """Return 0.4 * Dice loss + 0.3 * IoU loss + 0.3 * boundary-weighted BCE.

        Dice and IoU are computed per sample and averaged. BCE pixels inside
        the ground-truth boundary band get weight 1 + 5 = 6; all others get
        weight 1.
        """
        intersection = tf.reduce_sum(y_true * y_pred, axis=[1, 2, 3])
        dice = (2. * intersection + self.epsilon) / (
            tf.reduce_sum(y_true, axis=[1, 2, 3]) + 
            tf.reduce_sum(y_pred, axis=[1, 2, 3]) + self.epsilon
        )
        dice_loss = 1.0 - tf.reduce_mean(dice)

        union = (tf.reduce_sum(y_true, axis=[1, 2, 3]) + 
                tf.reduce_sum(y_pred, axis=[1, 2, 3]) - intersection)
        iou = (intersection + self.epsilon) / (union + self.epsilon)
        iou_loss = 1.0 - tf.reduce_mean(iou)

        boundary = self.detect_boundary(y_true, kernel_size=5)
        weights = 1.0 + 5.0 * boundary
        bce = -(y_true * tf.math.log(y_pred + self.epsilon) + 
               (1 - y_true) * tf.math.log(1 - y_pred + self.epsilon))
        weighted_bce = tf.reduce_mean(weights * bce)

        return 0.4 * dice_loss + 0.3 * iou_loss + 0.3 * weighted_bce

    def boundary_loss(self, y_true, y_pred):
        """Return the L1 mismatch of finite differences along H and W.

        Differences are taken at offsets 1, 2 and 4 and weighted 0.5, 0.3 and
        0.2 respectively.
        """
        total_loss = 0.0
        weights = [0.5, 0.3, 0.2]

        for scale, w in zip([1, 2, 4], weights):
            dy_true = y_true[:, scale:, :, :] - y_true[:, :-scale, :, :]
            dy_pred = y_pred[:, scale:, :, :] - y_pred[:, :-scale, :, :]
            dx_true = y_true[:, :, scale:, :] - y_true[:, :, :-scale, :]
            dx_pred = y_pred[:, :, scale:, :] - y_pred[:, :, :-scale, :]
            total_loss += w * (tf.reduce_mean(tf.abs(dy_true - dy_pred)) + 
                             tf.reduce_mean(tf.abs(dx_true - dx_pred)))

        return total_loss

    def structure_aware_loss(self, y_true, y_pred, characteristics):
        """Return the L1 difference of per-sample area/perimeter^2, scaled by batch compactness.

        Unlike the compactness descriptor, no 4*pi factor is applied here.
        """
        area_true = tf.reduce_sum(y_true, axis=[1, 2, 3]) + self.epsilon
        dy_true = y_true[:, 1:, :, :] - y_true[:, :-1, :, :]
        dx_true = y_true[:, :, 1:, :] - y_true[:, :, :-1, :]
        dy_true_padded = tf.pad(dy_true, [[0, 0], [0, 1], [0, 0], [0, 0]])
        dx_true_padded = tf.pad(dx_true, [[0, 0], [0, 0], [0, 1], [0, 0]])
        perimeter_true = tf.reduce_sum(tf.sqrt(dy_true_padded**2 + dx_true_padded**2 + self.epsilon), 
                                      axis=[1, 2, 3]) + self.epsilon

        area_pred = tf.reduce_sum(y_pred, axis=[1, 2, 3]) + self.epsilon
        dy_pred = y_pred[:, 1:, :, :] - y_pred[:, :-1, :, :]
        dx_pred = y_pred[:, :, 1:, :] - y_pred[:, :, :-1, :]
        dy_pred_padded = tf.pad(dy_pred, [[0, 0], [0, 1], [0, 0], [0, 0]])
        dx_pred_padded = tf.pad(dx_pred, [[0, 0], [0, 0], [0, 1], [0, 0]])
        perimeter_pred = tf.reduce_sum(tf.sqrt(dy_pred_padded**2 + dx_pred_padded**2 + self.epsilon), 
                                      axis=[1, 2, 3]) + self.epsilon

        compact_true = area_true / (perimeter_true**2 + self.epsilon)
        compact_pred = area_pred / (perimeter_pred**2 + self.epsilon)

        return characteristics['compactness'] * tf.reduce_mean(tf.abs(compact_true - compact_pred))

    def scale_aware_focal_loss(self, y_true, y_pred, characteristics):
        """Return focal BCE whose gamma depends on the batch foreground fraction.

        gamma = 3.0 when the fraction is below 0.05, 2.0 when it is below
        0.2, and 1.5 otherwise.
        """
        size = characteristics['object_size']
        gamma = tf.cond(
            size < 0.05,
            lambda: 3.0,
            lambda: tf.cond(size < 0.2, lambda: 2.0, lambda: 1.5)
        )

        # p is the predicted probability of the ground-truth class.
        p = y_true * y_pred + (1 - y_true) * (1 - y_pred)
        focal_weight = tf.pow(1 - p, gamma)
        bce = -(y_true * tf.math.log(y_pred + self.epsilon) + 
               (1 - y_true) * tf.math.log(1 - y_pred + self.epsilon))

        return tf.reduce_mean(focal_weight * bce)

    def texture_aware_loss(self, y_true, y_pred):
        """Return the L1 mismatch of second differences along H and W."""
        ddy_true = y_true[:, 2:, :, :] - 2*y_true[:, 1:-1, :, :] + y_true[:, :-2, :, :]
        ddy_pred = y_pred[:, 2:, :, :] - 2*y_pred[:, 1:-1, :, :] + y_pred[:, :-2, :, :]
        ddx_true = y_true[:, :, 2:, :] - 2*y_true[:, :, 1:-1, :] + y_true[:, :, :-2, :]
        ddx_pred = y_pred[:, :, 2:, :] - 2*y_pred[:, :, 1:-1, :] + y_pred[:, :, :-2, :]

        return tf.reduce_mean(tf.abs(ddy_true - ddy_pred)) + tf.reduce_mean(tf.abs(ddx_true - ddx_pred))

    def call(self, y_true, y_pred):
        """Return the scalar MASL loss for a batch."""
        y_true = tf.cast(y_true, tf.float32)
        y_pred = tf.cast(y_pred, tf.float32)

        characteristics = self.analyze_structure_characteristics(y_true)

        # Data-dependent multipliers from ground-truth shape statistics (all >= 1).
        alpha_region = 1.0 + 0.5 * characteristics['compactness']
        alpha_boundary = 1.0 + 1.5 * characteristics['tubularity'] + characteristics['compactness']
        alpha_structure = 1.0 + characteristics['tubularity']
        alpha_scale = 1.0 + 1.5 * characteristics['irregularity']
        alpha_texture = 1.0 + characteristics['irregularity']

        l_core = self.core_loss(y_true, y_pred)
        l_boundary = self.boundary_loss(y_true, y_pred)
        l_structure = self.structure_aware_loss(y_true, y_pred, characteristics)
        l_scale = self.scale_aware_focal_loss(y_true, y_pred, characteristics)
        l_texture = self.texture_aware_loss(y_true, y_pred)

        weighted_core = self.w_region * alpha_region * l_core
        weighted_boundary = self.w_boundary * alpha_boundary * l_boundary
        weighted_structure = self.w_structure * alpha_structure * l_structure
        weighted_scale = self.w_scale * alpha_scale * l_scale
        weighted_texture = self.w_texture * alpha_texture * l_texture

        total_weight = (self.w_region * alpha_region + 
                       self.w_boundary * alpha_boundary + 
                       self.w_structure * alpha_structure + 
                       self.w_scale * alpha_scale + 
                       self.w_texture * alpha_texture)

        # Normalizing by the total weight makes the result a weighted average,
        # so uniformly scaling all weights leaves the loss unchanged.
        masl_loss = (weighted_core + weighted_boundary + weighted_structure + 
                    weighted_scale + weighted_texture) / (total_weight + self.epsilon)

        return masl_loss

    def get_config(self):
        """Return the base-layer config; ``epsilon`` is fixed in ``__init__``."""
        return super().get_config()

# One shared loss layer, so its weights are created once and reused on every call.
_masl_instance = MorphologyAwareAdaptiveLoss()

def masl_loss_fn(y_true, y_pred):
    """Keras-compatible ``loss(y_true, y_pred)`` wrapper around the shared MASL layer."""
    loss_layer = _masl_instance
    return loss_layer(y_true, y_pred)

# ==============================================================================
# METRICS
# ==============================================================================

def dice_coefficient(y_true, y_pred, smooth=1e-6):
    """Soft Dice coefficient over the whole batch, flattened into one vector.

    ``smooth`` is added to numerator and denominator, so two empty masks
    score 1.
    """
    truth = K.flatten(y_true)
    guess = K.flatten(y_pred)
    overlap = K.sum(truth * guess)
    denominator = K.sum(truth) + K.sum(guess) + smooth
    return (2. * overlap + smooth) / denominator

def iou_score(y_true, y_pred, smooth=1e-6):
    """Soft Jaccard index (IoU) over the whole batch, flattened into one vector."""
    truth = K.flatten(y_true)
    guess = K.flatten(y_pred)
    overlap = K.sum(truth * guess)
    combined = K.sum(truth) + K.sum(guess) - overlap
    numerator = overlap + smooth
    return numerator / (combined + smooth)

def precision_metric(y_true, y_pred):
    """Batch precision of predictions thresholded at 0.5: TP / (TP + FP)."""
    hard_pred = K.cast(y_pred > 0.5, tf.float32)
    predicted_count = K.sum(hard_pred)
    hits = K.sum(y_true * hard_pred)
    return hits / (predicted_count + K.epsilon())

def recall_metric(y_true, y_pred):
    """Batch recall of predictions thresholded at 0.5: TP / (TP + FN)."""
    hard_pred = K.cast(y_pred > 0.5, tf.float32)
    positives = K.sum(y_true)
    hits = K.sum(y_true * hard_pred)
    return hits / (positives + K.epsilon())

# ==============================================================================
# PREDICTION MASK SAVING
# ==============================================================================

def _binary_segmentation_metrics(gt_mask, pred_binary, eps=1e-6):
    """Return ``(dice, iou, precision, recall)`` for one ground-truth/prediction pair.

    ``eps`` is added to both numerator and denominator of Dice and IoU, as in
    ``dice_coefficient``/``iou_score``. An empty prediction on an empty ground
    truth therefore scores 1.0 rather than 0.0.
    """
    tp = np.sum(gt_mask * pred_binary)
    fp = np.sum((1 - gt_mask) * pred_binary)
    fn = np.sum(gt_mask * (1 - pred_binary))
    gt_sum = np.sum(gt_mask)
    pred_sum = np.sum(pred_binary)

    dice = (2. * tp + eps) / (gt_sum + pred_sum + eps)
    iou = (tp + eps) / (gt_sum + pred_sum - tp + eps)
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    return dice, iou, precision, recall


def _to_uint8(array):
    """Scale a [0, 1] float array to uint8, clipping out-of-range values instead of wrapping."""
    return np.clip(array * 255, 0, 255).astype(np.uint8)


def _write_image(path, image):
    """Write ``image`` with OpenCV; print a warning when the write fails."""
    if not cv2.imwrite(path, image):
        print(f"   ⚠️  Could not write {path}")


def save_prediction_masks(model, test_gen, save_dir):
    """Predict every batch of ``test_gen`` and save images, masks, overlays and metrics.

    For each sample ``sample_NNNN`` the following files are written under
    ``save_dir``:
      - ``original_images/<id>.png``;
      - ``ground_truth/<id>_gt.png``;
      - ``probability_masks/<id>_prob.png`` and ``<id>_prob_color.png``
        (JET colormap);
      - ``binary_masks/<id>_binary.png`` (threshold 0.5);
      - ``overlays/<id>_overlay.png`` (green = prediction contours,
        red = ground-truth contours).
    Per-sample Dice, IoU, precision and recall are written to
    ``prediction_metrics.csv``.

    Args:
        model: Object with a Keras-style ``predict(images, verbose=0)``.
        test_gen: Indexable with ``len``; item ``b`` is ``(images, masks)``.
            Images are presumably float RGB in [0, 1]; they are clipped before
            the uint8 conversion (confirm against the generator).
        save_dir: Output root; the sub-directories are created if missing.

    Returns:
        DataFrame with one row of metrics per sample (empty if ``test_gen``
        yields no samples).
    """
    print("\n" + "="*80)
    print("💾 SAVING PREDICTION MASKS")
    print("="*80)

    prob_dir = os.path.join(save_dir, "probability_masks")
    binary_dir = os.path.join(save_dir, "binary_masks")
    overlay_dir = os.path.join(save_dir, "overlays")
    gt_dir = os.path.join(save_dir, "ground_truth")
    img_dir = os.path.join(save_dir, "original_images")

    for d in [prob_dir, binary_dir, overlay_dir, gt_dir, img_dir]:
        os.makedirs(d, exist_ok=True)

    sample_count = 0
    all_metrics = []

    for batch_idx in range(len(test_gen)):
        images, masks = test_gen[batch_idx]
        predictions = model.predict(images, verbose=0)

        for img, mask, prediction in zip(images, masks, predictions):
            sample_count += 1
            filename = f'sample_{sample_count:04d}'

            gt_mask = mask[:, :, 0]
            pred = prediction[:, :, 0]
            pred_binary = (pred > 0.5).astype(np.float32)

            dice, iou, precision, recall = _binary_segmentation_metrics(gt_mask, pred_binary)
            all_metrics.append({
                'sample_id': filename,
                'dice': dice,
                'iou': iou,
                'precision': precision,
                'recall': recall
            })

            # Images are held in RGB; OpenCV writes BGR.
            img_uint8 = _to_uint8(img)
            _write_image(os.path.join(img_dir, f'{filename}.png'),
                         cv2.cvtColor(img_uint8, cv2.COLOR_RGB2BGR))

            gt_uint8 = _to_uint8(gt_mask)
            _write_image(os.path.join(gt_dir, f'{filename}_gt.png'), gt_uint8)

            prob_uint8 = _to_uint8(pred)
            _write_image(os.path.join(prob_dir, f'{filename}_prob.png'), prob_uint8)
            prob_colorized = cv2.applyColorMap(prob_uint8, cv2.COLORMAP_JET)
            _write_image(os.path.join(prob_dir, f'{filename}_prob_color.png'), prob_colorized)

            binary_uint8 = _to_uint8(pred_binary)
            _write_image(os.path.join(binary_dir, f'{filename}_binary.png'), binary_uint8)

            # Contours are drawn with RGB colors on the RGB image, then converted to BGR once.
            overlay = img_uint8.copy()
            contours, _ = cv2.findContours(binary_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            cv2.drawContours(overlay, contours, -1, (0, 255, 0), 2)

            gt_contours, _ = cv2.findContours(gt_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            cv2.drawContours(overlay, gt_contours, -1, (255, 0, 0), 2)

            cv2.putText(overlay, f'Dice: {dice:.4f}', (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
            cv2.putText(overlay, 'Green=Pred, Red=GT', (10, 60),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

            _write_image(os.path.join(overlay_dir, f'{filename}_overlay.png'),
                         cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))

            if sample_count % 10 == 0:
                print(f"   Processed {sample_count} samples... (Dice: {dice:.4f})")

    # Explicit columns keep the CSV header (and the summary below) valid even with no samples.
    df_metrics = pd.DataFrame(all_metrics, columns=['sample_id', 'dice', 'iou', 'precision', 'recall'])
    metrics_file = os.path.join(save_dir, "prediction_metrics.csv")
    df_metrics.to_csv(metrics_file, index=False)

    print(f"\n✅ SAVED {sample_count} PREDICTION MASKS")
    if df_metrics.empty:
        print("\n⚠️  No test samples were processed; no metrics to summarize.")
        return df_metrics

    print(f"\n📊 Overall Metrics:")
    print(f"   Mean Dice:      {df_metrics['dice'].mean():.4f} ± {df_metrics['dice'].std():.4f}")
    print(f"   Mean IoU:       {df_metrics['iou'].mean():.4f} ± {df_metrics['iou'].std():.4f}")
    print(f"   Mean Precision: {df_metrics['precision'].mean():.4f} ± {df_metrics['precision'].std():.4f}")
    print(f"   Mean Recall:    {df_metrics['recall'].mean():.4f} ± {df_metrics['recall'].std():.4f}")

    return df_metrics

# ==============================================================================
# LEARNING RATE SCHEDULER
# ==============================================================================

def cosine_annealing_with_warmup(epoch, lr, total_epochs=30, warmup_epochs=5, min_lr=1e-6):
    """Learning rate for ``epoch`` under linear warmup followed by cosine decay.

    Warmup phase: during the first ``warmup_epochs`` epochs the rate rises
    linearly as ``lr * (epoch + 1) / warmup_epochs``.

    Decay phase: afterwards it follows a half cosine from ``lr`` down to
    ``min_lr``, reaching ``min_lr`` at ``total_epochs`` and staying there for
    any later epoch.

    Args:
        epoch: Zero-based epoch index.
        lr: Peak (base) learning rate.
        total_epochs: Epoch at which the decay reaches ``min_lr``.
        warmup_epochs: Number of linear warmup epochs (0 disables warmup).
        min_lr: Floor of the cosine decay.

    Returns:
        The learning rate as a float.
    """
    if epoch < warmup_epochs:
        return lr * (epoch + 1) / warmup_epochs
    # Guard against a zero-length (or negative) decay phase when
    # total_epochs <= warmup_epochs, which used to divide by zero.
    decay_epochs = max(1, total_epochs - warmup_epochs)
    # Clamp so the rate stays at min_lr past total_epochs instead of rising
    # again along the next cosine period.
    progress = min(1.0, (epoch - warmup_epochs) / decay_epochs)
    return min_lr + (lr - min_lr) * (1 + np.cos(np.pi * progress)) / 2

# ==============================================================================
# MAIN TRAINING LOOP
# ==============================================================================

def train_model(cfg, strategy, num_gpus):
    """Train, evaluate and export predictions for MedSegNet-SSF.

    Steps:
      1. Seed the RNGs and load the train/val/test splits.
      2. Optionally generate the spectral-analysis figures.
      3. Build the data generators.
      4. Compile the model inside ``strategy.scope()``.
      5. Fit with checkpoint / early-stopping / CSV-log / LR-schedule callbacks.
      6. Evaluate on the test split, write ``results.json`` and save the
         per-sample prediction masks.

    Args:
        cfg: Configuration object. Attributes read here: SEED, DETERMINISTIC,
            TRAIN_DIR, VAL_DIR, TEST_DIR, SAVE_DIR, BATCH_SIZE, EPOCHS,
            LEARNING_RATE, EPOCH_EXPANSION_FACTOR, CHECKPOINT_MONITOR,
            CHECKPOINT_MODE, EARLY_STOPPING_PATIENCE, GPU_NUMBERS,
            ANALYZE_SPECTRAL_TRUNCATION and, optionally, WARMUP_EPOCHS
            (defaults to 5).
        strategy: Distribution strategy whose ``scope()`` the model and
            optimizer are created in.
        num_gpus: Number of GPUs in use (0 means CPU). Used for logging and
            for the effective batch size reported in ``results.json``.

    Returns:
        ``(model, history)`` on success, or ``(None, None)`` when no training
        data is found (a pair, so callers can always unpack the result).
    """
    set_seed(cfg.SEED, cfg.DETERMINISTIC)
    # Checkpoint, CSV log and results.json are all written under SAVE_DIR.
    os.makedirs(cfg.SAVE_DIR, exist_ok=True)

    # 1. LOAD DATA
    train_pairs = load_dataset_split(cfg.TRAIN_DIR)
    val_pairs = load_dataset_split(cfg.VAL_DIR)
    test_pairs = load_dataset_split(cfg.TEST_DIR)

    if not train_pairs:
        print("❌ No training data found!")
        return None, None

    print(f"\n📊 Dataset Statistics:")
    print(f"   Training samples:   {len(train_pairs)}")
    print(f"   Validation samples: {len(val_pairs)}")
    print(f"   Test samples:       {len(test_pairs)}")

    # ==================== SPECTRAL ANALYSIS ====================
    if cfg.ANALYZE_SPECTRAL_TRUNCATION:
        generate_spectral_analysis_figures(train_pairs, cfg)

    # 2. GENERATORS
    train_aug = get_augmentation(cfg)
    val_aug = get_validation_augmentation(cfg)

    train_gen = ExpandedGenerator(
        train_pairs, cfg, augmentation=train_aug, shuffle=True,
        expansion_factor=cfg.EPOCH_EXPANSION_FACTOR
    )

    # Empty splits get no generator, so fit/evaluate are never handed
    # zero-length data; validation / test evaluation is skipped instead.
    val_gen = None
    if val_pairs:
        val_gen = ExpandedGenerator(
            val_pairs, cfg, augmentation=val_aug, shuffle=False, expansion_factor=1
        )
    else:
        print(f"⚠️ No validation data found - training without validation "
              f"(monitored metric '{cfg.CHECKPOINT_MONITOR}' may be unavailable)")

    test_gen = None
    if test_pairs:
        test_gen = ExpandedGenerator(
            test_pairs, cfg, augmentation=val_aug, shuffle=False, expansion_factor=1
        )

    print(f"\n📊 Training Configuration:")
    print(f"   Steps per Epoch: {len(train_gen)}")
    print(f"   Virtual images/epoch: {len(train_pairs)*cfg.EPOCH_EXPANSION_FACTOR}")
    if num_gpus > 1:
        print(f"   Effective batch size: {cfg.BATCH_SIZE * num_gpus}")

    # 3. BUILD MODEL
    # Model, optimizer and compile must happen inside the strategy scope so
    # that variables are created under the distribution strategy.
    with strategy.scope():
        model = build_medsegnet_ssf(cfg)

        optimizer = tf.keras.optimizers.Adam(learning_rate=cfg.LEARNING_RATE, clipnorm=1.0)

        model.compile(
            optimizer=optimizer,
            loss=masl_loss_fn,
            metrics=[dice_coefficient, iou_score, precision_metric, recall_metric]
        )

    # 4. CALLBACKS
    warmup_epochs = getattr(cfg, "WARMUP_EPOCHS", 5)
    # The schedule receives the configured base LR (not Keras' running LR),
    # so each epoch's value is computed from the same peak rate.
    lr_scheduler = LearningRateScheduler(
        lambda epoch: cosine_annealing_with_warmup(
            epoch, cfg.LEARNING_RATE, cfg.EPOCHS, warmup_epochs=warmup_epochs
        ),
        verbose=0
    )

    callbacks = [
        ModelCheckpoint(
            os.path.join(cfg.SAVE_DIR, "best_MedSegNet_SSF.h5"),
            monitor=cfg.CHECKPOINT_MONITOR, mode=cfg.CHECKPOINT_MODE,
            save_best_only=True, verbose=1
        ),
        EarlyStopping(
            monitor=cfg.CHECKPOINT_MONITOR, mode=cfg.CHECKPOINT_MODE,
            patience=cfg.EARLY_STOPPING_PATIENCE, verbose=1, restore_best_weights=True
        ),
        CSVLogger(os.path.join(cfg.SAVE_DIR, "training_log.csv")),
        lr_scheduler
    ]

    # 5. TRAIN
    gpu_info = f" on {num_gpus} GPU(s)" if num_gpus > 0 else " on CPU"
    print(f"\n🚀 STARTING TRAINING ({cfg.EPOCHS} EPOCHS){gpu_info}")
    print("="*80)
    start_time = time.time()

    history = model.fit(
        train_gen,
        validation_data=val_gen,
        epochs=cfg.EPOCHS,
        callbacks=callbacks,
        verbose=1
    )

    training_time = time.time() - start_time
    print(f"\n✅ Training finished in {training_time/60:.1f} minutes")

    # 6. EVALUATE
    test_results = {}
    if test_gen is not None:
        print("\n" + "="*80)
        print("📊 EVALUATING ON TEST SET")
        print("="*80)
        # return_dict pairs each value with its own metric name, rather than
        # relying on model.metrics_names lining up with a positional list.
        test_metrics = model.evaluate(test_gen, verbose=1, return_dict=True)
        test_results = {name: float(value) for name, value in test_metrics.items()}
    else:
        print("\n⚠️ No test data found - skipping test evaluation")

    results = {
        "model": "MedSegNet-SSF",
        "loss_function": "MASL",
        "gpu_config": {
            "num_gpus": num_gpus,
            "gpu_numbers": cfg.GPU_NUMBERS,
            "effective_batch_size": cfg.BATCH_SIZE * max(num_gpus, 1)
        },
        "training_time_minutes": training_time / 60,
        "test_results": test_results
    }

    with open(os.path.join(cfg.SAVE_DIR, "results.json"), "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)

    # 7. SAVE PREDICTION MASKS
    if test_gen is not None:
        masks_dir = os.path.join(cfg.SAVE_DIR, "predictions")
        save_prediction_masks(model, test_gen, masks_dir)

    # 8. FINAL SUMMARY
    print("\n" + "="*80)
    print("📋 FINAL SUMMARY")
    print("="*80)
    print(f"   Training time: {training_time/60:.1f} minutes")
    for name, value in test_results.items():
        print(f"   Test {name}: {value:.4f}")
    print(f"   Outputs saved to: {cfg.SAVE_DIR}")

    return model, history

# ==============================================================================
# MAIN EXECUTION
# ==============================================================================

if __name__ == "__main__":
    # Script entry point: trains using the module-level `config`, `strategy`
    # and `num_gpus`, which are defined elsewhere in this file.
    print("\n" + "="*80)
    print("🔥 MED-SEGNET-SSF: SPECTRAL-SELECTIVE FEATURES")
    print("📊 WITH PUBLICATION-QUALITY SPECTRAL ANALYSIS")
    print("="*80)

    outcome = train_model(config, strategy, num_gpus)
    # train_model may return None (or a (None, None) pair) when it cannot
    # train, e.g. no training data. Handle that instead of crashing on unpacking.
    model, history = outcome if outcome is not None else (None, None)

    if model is None:
        print("\n❌ Training did not run - see the messages above.")
        sys.exit(1)

    print("\n✅ ALL DONE!")
    print("🚀 Check the output directory for results and figures! 🚀\n")