# --------------------------------------------------------
# MS-SSIM metric from MONAI: https://docs.monai.io/en/stable/metrics.html
# FID adapted from https://github.com/pfriedri/wdm-3d/blob/main/eval/fid.py
# Pretrained ResNet (MedicalNet) adapted from https://github.com/pfriedri/wdm-3d/tree/main/eval
# Thanks to the authors of these libraries for sharing their work.
# --------------------------------------------------------

import argparse
import os
import glob
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import nibabel as nib
from torch.utils.data import Dataset, DataLoader
from scipy import linalg
from monai.metrics import MultiScaleSSIMMetric
from tqdm import tqdm
import logging
import pandas as pd
import random
from functools import partial
import re
from scipy import stats
from collections.abc import Callable


# --- Logging Setup ---
# Root logger: INFO and above, a single stream handler (stderr by default),
# one timestamped line per record.
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Module-level logger shared by the helpers in this file.
logger = logging.getLogger(__name__)

# --- Constants ---
DEFAULT_BATCH_SIZE = 4
DEFAULT_MAX_SAMPLES_PAIRWISE = 200  # default cap on generated volumes used for pairwise MS-SSIM
DEFAULT_PERCENTILE_SAMPLES = 100  # default number of real volumes sampled for global percentiles
PADDED_SHAPE = (128, 128, 128)  # volumes are zero-padded to this shape before evaluation
ORIGINAL_SHAPE = (91, 109, 91)  # only volumes with exactly this shape are evaluated
MIN_SAMPLES_FOR_STAT_TEST = 10  # minimum distribution size per side before running a rank-sum test

# --- Define Age Bands ---
# Inclusive [min, max] age ranges; a file belongs to a band when its parsed age lies inside.
AGE_BANDS = {
    "15-30": (15, 30),
    "40-55": (40, 55),
    "65-80": (65, 80)
}
# Matches age tags such as "_AGE_45", "-age-45.5" or "_AGE_.5" (case-insensitive).
# The captured number must end on a digit so that a following extension dot
# ("..._AGE_45.5.nii.gz") is not swallowed; "45.5." would be rejected by float().
AGE_PATTERN = re.compile(r'[_-]AGE[_-]([0-9]*\.?[0-9]+)', re.IGNORECASE)

# --- MedicalNet ResNet Code ---
def conv3x3x3(in_planes, out_planes, stride=1, dilation=1):
    """Return a bias-free 3x3x3 ``nn.Conv3d``.

    Padding equals the dilation, so with ``stride=1`` the spatial size of the
    input is preserved.
    """
    return nn.Conv3d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        bias=False,
    )

def downsample_basic_block(x, planes, stride):
    """Parameter-free ("type A") ResNet shortcut for a 5D (N, C, D, H, W) tensor.

    Spatially subsamples ``x`` with a kernel-size-1 average pool (i.e. a
    strided selection) and appends zero-valued channels so that the output
    has ``planes`` channels. Requires ``planes >= x.size(1)``.
    """
    out = F.avg_pool3d(x, kernel_size=1, stride=stride)
    # new_zeros inherits out's dtype and device. Concatenating `out` itself
    # (not `out.data`) keeps the identity path attached to autograd.
    zero_pads = out.new_zeros(out.size(0), planes - out.size(1), *out.shape[2:])
    return torch.cat([out, zero_pads], dim=1)

class BasicBlock(nn.Module):
    """Two-convolution residual block (3x3x3 -> 3x3x3) with an optional projected shortcut.

    Args:
        inplanes: Input channel count.
        planes: Output channel count (``expansion`` is 1).
        stride: Stride of the first convolution.
        dilation: Dilation (and padding) of both convolutions.
        downsample: Optional callable applied to the input to form the
            shortcut; identity when ``None``.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
        super().__init__()
        # Attribute names match the checkpoint keys loaded by main; creation
        # order is kept so default parameter init draws the RNG identically.
        self.conv1 = conv3x3x3(inplanes, planes, stride=stride, dilation=dilation)
        self.bn1 = nn.BatchNorm3d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3x3(planes, planes, dilation=dilation)
        self.bn2 = nn.BatchNorm3d(planes)
        self.downsample = downsample

    def forward(self, x):
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        shortcut = x if self.downsample is None else self.downsample(x)
        branch += shortcut
        return self.relu(branch)
    
class Bottleneck(nn.Module):
    """Bottleneck residual block: 1x1x1 reduce, 3x3x3 (strided/dilated), 1x1x1 expand.

    The block outputs ``planes * expansion`` channels. The shortcut is the
    input itself unless ``downsample`` is given, in which case
    ``downsample(x)`` is added instead.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
        super().__init__()
        expanded = planes * self.expansion
        # Attribute names match the checkpoint keys loaded by main; creation
        # order is kept so default parameter init draws the RNG identically.
        self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm3d(planes)
        self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, stride=stride,
                               dilation=dilation, padding=dilation, bias=False)
        self.bn2 = nn.BatchNorm3d(planes)
        self.conv3 = nn.Conv3d(planes, expanded, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm3d(expanded)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.relu(self.bn2(self.conv2(branch)))
        branch = self.bn3(self.conv3(branch))
        shortcut = x if self.downsample is None else self.downsample(x)
        branch += shortcut
        return self.relu(branch)
    
class ResNet(nn.Module):
    """MedicalNet-style 3D ResNet trunk used here as a feature extractor.

    ``sample_input_D/H/W``, ``num_seg_classes`` and ``no_cuda`` are accepted
    for signature compatibility but are not used by this class. ``forward``
    returns the globally average-pooled ``layer4`` output flattened to
    ``(batch, 512 * block.expansion)``.

    Args:
        block: Residual block class exposing an ``expansion`` attribute.
        layers: Number of blocks in each of the four stages.
        shortcut_type: ``'A'`` for the parameter-free zero-padding shortcut,
            anything else for a 1x1x1 conv + BatchNorm projection.
    """

    def __init__(self, block, layers, sample_input_D, sample_input_H, sample_input_W, num_seg_classes, shortcut_type='B', no_cuda=False):
        super().__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv3d(1, 64, kernel_size=7, stride=(2, 2, 2), padding=(3, 3, 3), bias=False)
        self.bn1 = nn.BatchNorm3d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1)

        # (planes, stride, dilation) per stage: the last two stages dilate
        # instead of striding, keeping spatial resolution.
        stage_specs = ((64, 1, 1), (128, 2, 1), (256, 1, 2), (512, 1, 4))
        for stage_idx, (planes, stride, dilation) in enumerate(stage_specs):
            stage = self._make_layer(block, planes, layers[stage_idx], shortcut_type,
                                     stride=stride, dilation=dilation)
            setattr(self, f"layer{stage_idx + 1}", stage)
        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))

        for module in self.modules():
            if isinstance(module, nn.Conv3d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out')
            elif isinstance(module, nn.BatchNorm3d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def _make_layer(self, block, planes, blocks, shortcut_type, stride=1, dilation=1):
        """Build one stage of ``blocks`` residual blocks and advance ``self.inplanes``."""
        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            if shortcut_type == 'A':
                downsample = partial(downsample_basic_block, planes=out_planes, stride=stride)
            else:
                downsample = nn.Sequential(
                    nn.Conv3d(self.inplanes, out_planes, kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm3d(out_planes),
                )
        stage = [block(self.inplanes, planes, stride=stride, dilation=dilation, downsample=downsample)]
        self.inplanes = out_planes
        stage.extend(block(self.inplanes, planes, dilation=dilation) for _ in range(blocks - 1))
        return nn.Sequential(*stage)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        return torch.flatten(self.avgpool(x), 1)
    
def resnet50(**kwargs):
    """Build a ResNet-50 (Bottleneck blocks, stage depths 3-4-6-3); kwargs go to ``ResNet``."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(Bottleneck, stage_depths, **kwargs)

def generate_medical_resnet(**kwargs):
    """Build the MedicalNet ResNet-50 trunk, declaring ``PADDED_SHAPE`` as its input size.

    ``num_seg_classes`` is fixed to 2; remaining kwargs are forwarded.
    """
    depth, height, width = PADDED_SHAPE
    return resnet50(
        sample_input_D=depth,
        sample_input_H=height,
        sample_input_W=width,
        num_seg_classes=2,
        **kwargs,
    )

def extract_age_from_filename(filename):
    """Parse the age encoded in a file name via ``AGE_PATTERN``.

    Only the base name is searched, so directory components cannot match.

    Returns:
        The age as a float, or ``None`` when the name carries no age tag or
        the captured text is not a valid number.
    """
    match = AGE_PATTERN.search(os.path.basename(filename))
    if match is None:
        return None
    # A permissive pattern can capture the dot that starts a file extension
    # (e.g. "45.5." from "..._AGE_45.5.nii.gz"); drop it before parsing.
    raw_age = match.group(1).rstrip('.')
    try:
        return float(raw_age)
    except ValueError:
        # Malformed tag: treat like "no age" instead of aborting the whole run.
        logger.debug(f"Unparseable age tag '{match.group(1)}' in {filename}")
        return None

# --- Global Intensity Statistics Calculation ---
def calculate_global_percentiles(file_paths, num_samples=100, percentiles=(0.5, 99.5)):
    """Estimate global intensity percentiles from a random subset of volumes.

    ``min(num_samples, len(file_paths))`` files are drawn with
    ``random.sample``, so the subset depends on the global ``random`` state.
    Volumes whose shape differs from ``ORIGINAL_SHAPE`` are ignored and load
    failures are logged. All remaining voxels are pooled and
    ``np.nanpercentile`` is evaluated at ``percentiles``.

    Returns:
        ``(p_low, p_high)``, or ``(None, None)`` if no volume was usable.
    """
    n_draw = min(num_samples, len(file_paths))
    logger.info(f"Calculating global intensity percentiles from {n_draw} sample files...")
    chosen = random.sample(file_paths, n_draw)

    pooled = []
    for path in tqdm(chosen, desc="Sampling intensities for normalization"):
        try:
            volume = nib.load(path).get_fdata(dtype=np.float32)
            if volume.shape == ORIGINAL_SHAPE:
                pooled.append(volume.flatten())
        except Exception as e:
            logger.warning(f"Could not load {path} for percentiles: {e}")

    if not pooled:
        logger.error("Could not load any valid data for global percentiles.")
        return None, None

    p_low, p_high = np.nanpercentile(np.concatenate(pooled), percentiles)
    logger.info(f"Global Percentiles ({percentiles[0]}%, {percentiles[1]}%): {p_low:.4f}, {p_high:.4f}")
    return p_low, p_high

# --- Dataset Class ---
class NiftiDataset(Dataset):
    """Dataset yielding percentile-normalised, zero-padded NIfTI volumes.

    Each item is loaded as float32, clipped to
    ``[global_p_low, global_p_high]``, rescaled to ``[0, 1]`` and zero-padded
    (centred) to ``PADDED_SHAPE``. The result is a float32 tensor of shape
    ``(1, *PADDED_SHAPE)``. Items whose shape is not ``ORIGINAL_SHAPE``, or
    that fail to load, are returned as ``None``; ``safe_collate`` drops them.

    Raises:
        ValueError: If either percentile is ``None``.
    """

    def __init__(self, file_paths, global_p_low, global_p_high):
        if global_p_low is None or global_p_high is None:
            raise ValueError("Global percentiles must be provided.")
        self.file_paths = file_paths
        # Plain Python floats are "weak" scalars for NumPy type promotion, so
        # they cannot upcast the float32 volumes (a NumPy float64 scalar would
        # under NumPy 2's NEP 50 rules and feed float64 into a float32 model).
        self.global_p_low = float(global_p_low)
        self.global_p_high = float(global_p_high)

    def __len__(self):
        return len(self.file_paths)

    def _calculate_padding(self, current_shape):
        """Return an ``F.pad`` tuple that centres a 3D volume inside ``PADDED_SHAPE``.

        ``F.pad`` consumes (before, after) pairs starting from the last
        dimension, hence the reversed index. Odd remainders go on the "after"
        side.
        """
        padding = []
        for i in range(3):
            total_pad = PADDED_SHAPE[2 - i] - current_shape[2 - i]
            padding.extend([total_pad // 2, total_pad - (total_pad // 2)])
        return tuple(padding)

    def __getitem__(self, idx):
        file_path = self.file_paths[idx]
        try:
            data = nib.load(file_path).get_fdata(dtype=np.float32)
            if data.shape != ORIGINAL_SHAPE:
                logger.warning(f"File {os.path.basename(file_path)} shape mismatch. Skipping.")
                return None
            # NaN voxels would survive np.clip and poison features / MS-SSIM;
            # map them to the lower percentile (0 after normalisation).
            data = np.nan_to_num(data, nan=self.global_p_low)
            data = np.clip(data, self.global_p_low, self.global_p_high)
            denom = self.global_p_high - self.global_p_low
            data = (data - self.global_p_low) / denom if denom > 1e-8 else np.zeros_like(data)
            data = np.clip(data, 0.0, 1.0).astype(np.float32, copy=False)
            volume = torch.from_numpy(data).unsqueeze(0)  # add channel dim -> (1, D, H, W)
            return F.pad(volume, self._calculate_padding(data.shape), mode='constant', value=0)
        except Exception as e:
            logger.error(f"Error loading/processing {file_path}: {e}")
            return None
        
def safe_collate(batch):
    """Stack the non-``None`` samples of ``batch``; return ``None`` if nothing is left."""
    kept = list(filter(lambda sample: sample is not None, batch))
    return torch.stack(kept) if kept else None

# --- Activation Calculation ---
@torch.no_grad()
def get_activations(dataloader, model, device, max_samples=None):
    """Run ``model`` in eval mode over ``dataloader`` and collect its outputs as NumPy.

    ``None`` batches are skipped. When ``max_samples`` is truthy, iteration
    stops once at least that many samples were processed and the result is
    truncated to ``max_samples`` rows.

    Returns:
        The concatenated outputs (along axis 0), or an empty ``np.array([])``
        when no batch was processed.
    """
    model.eval()
    chunks = []
    seen = 0
    for batch in tqdm(dataloader, desc="Calculating Activations"):
        if batch is None:
            continue
        chunks.append(model(batch.to(device)).cpu().numpy())
        seen += batch.size(0)
        if max_samples and seen >= max_samples:
            break

    if not chunks:
        return np.array([])
    stacked = np.concatenate(chunks, axis=0)
    if max_samples:
        return stacked[:max_samples]
    return stacked

# --- FID & MMD Calculation ---
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Fréchet distance between N(mu1, sigma1) and N(mu2, sigma2).

    Computes ``||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrtm(S1 @ S2))``.
    ``S1``/``S2`` are the covariances with ``eps * I`` added for numerical
    stability; the trace term uses the unregularised covariances. A complex
    square root is reduced to its real part.

    Returns:
        The distance, or NaN (with a warning) if ``sqrtm`` raises
        ``LinAlgError``/``ValueError``.
    """
    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)
    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)
    try:
        reg1 = sigma1 + np.eye(sigma1.shape[0]) * eps
        reg2 = sigma2 + np.eye(sigma2.shape[0]) * eps
        covmean, _ = linalg.sqrtm(reg1 @ reg2, disp=False)
        if np.iscomplexobj(covmean):
            covmean = covmean.real
        mean_diff = mu1 - mu2
        return np.sum(mean_diff**2) + np.trace(sigma1 + sigma2 - 2 * covmean)
    except (linalg.LinAlgError, ValueError) as e:
        logger.warning(f"sqrtm failed: {e}. Returning NaN for FID.")
        return float('nan')

def calculate_fid(acts1, acts2):
    """FID between two activation matrices (rows are samples, columns features).

    Returns NaN when either set has fewer than two rows, since a covariance
    cannot be estimated.
    """
    if min(acts1.shape[0], acts2.shape[0]) < 2:
        return float('nan')
    stats1 = (np.mean(acts1, axis=0), np.cov(acts1, rowvar=False))
    stats2 = (np.mean(acts2, axis=0), np.cov(acts2, rowvar=False))
    return calculate_frechet_distance(*stats1, *stats2)

def gaussian_kernel(x, y, sigma=1.0):
    """Return the RBF kernel matrix between the rows of ``x`` and the rows of ``y``.

    Entry ``(i, j)`` is ``exp(-||x_i - y_j||^2 / (2 * sigma^2))``. Squared
    distances are expanded as ``|x|^2 - 2 x.y + |y|^2`` and clamped at zero
    with ReLU, because rounding can make the expansion slightly negative.
    """
    row_sq_norms = (x * x).sum(1)[:, None]
    col_sq_norms = (y * y).sum(1)[None, :]
    sq_dists = F.relu(row_sq_norms - 2 * x @ y.T + col_sq_norms)
    return torch.exp(-sq_dists / (2 * sigma**2))

def calculate_mmd(acts1, acts2, device, sigma=1.0):
    """Unbiased squared MMD with a Gaussian kernel, clamped at zero.

    The within-set terms exclude the kernel diagonal (self-similarities); the
    cross term averages the full ``acts2`` x ``acts1`` kernel. Computation
    runs on ``device``.

    Returns:
        A Python float, or NaN when either set has fewer than two rows.
    """
    m, n = acts1.shape[0], acts2.shape[0]
    if m < 2 or n < 2:
        return float('nan')

    real = torch.from_numpy(acts1).to(device)
    fake = torch.from_numpy(acts2).to(device)
    k_real = gaussian_kernel(real, real, sigma)
    k_fake = gaussian_kernel(fake, fake, sigma)
    k_cross = gaussian_kernel(fake, real, sigma)

    within_real = (k_real.sum() - k_real.trace()) / (m * (m - 1))
    within_fake = (k_fake.sum() - k_fake.trace()) / (n * (n - 1))
    mmd_sq = within_real + within_fake - 2 * k_cross.mean()
    return F.relu(mmd_sq).item()

# --- Bootstrapped FID/MMD Calculation ---
def calculate_bootstrapped_fid_mmd(real_activations, gen_activations, num_bootstraps, device, mmd_sigma=5.0):
    """Bootstrap FID and MMD between real and generated activation sets.

    Each of the ``num_bootstraps`` rounds resamples both sets with
    replacement (via the global ``np.random`` state) to the size of the
    smaller set, then computes FID and MMD on the resampled pair.

    Returns:
        ``(fid_mean, fid_std, mmd_mean, mmd_std, fid_scores, mmd_scores)``
        using NaN-ignoring mean/std; four NaNs and two empty lists when the
        smaller set has fewer than two samples.
    """
    n_real = len(real_activations)
    n_gen = len(gen_activations)
    sample_size = min(n_real, n_gen)
    if sample_size < 2:
        return float('nan'), float('nan'), float('nan'), float('nan'), [], []

    fid_scores = []
    mmd_scores = []
    for _ in tqdm(range(num_bootstraps), desc="Bootstrapping FID/MMD", leave=False):
        # Draw real indices before generated ones to keep the RNG sequence fixed.
        real_sample = real_activations[np.random.choice(n_real, size=sample_size, replace=True)]
        gen_sample = gen_activations[np.random.choice(n_gen, size=sample_size, replace=True)]
        fid_scores.append(calculate_fid(real_sample, gen_sample))
        mmd_scores.append(calculate_mmd(real_sample, gen_sample, device, sigma=mmd_sigma))

    return (
        np.nanmean(fid_scores),
        np.nanstd(fid_scores),
        np.nanmean(mmd_scores),
        np.nanstd(mmd_scores),
        fid_scores,
        mmd_scores,
    )

# --- Pairwise MS-SSIM Calculation Logic ---
def calculate_pairwise_metrics(dataloader, device, metric_name="MS-SSIM", return_scores=False):
    """Compute the mean and sample variance of a metric over all unordered volume pairs.

    Every non-``None`` batch from ``dataloader`` is gathered on the CPU. Then
    MONAI's ``MultiScaleSSIMMetric`` (3D, ``data_range=1.0``,
    ``kernel_size=7``) is evaluated on ``device`` for each pair ``(i, j)``
    with ``i < j``. Pairs whose computation raises or gives a non-finite value
    are excluded, and the number excluded is reported in a warning.

    Args:
        dataloader: Iterable of batched tensors (``None`` batches are ignored).
        device: Device the pairs are moved to for the metric computation.
        metric_name: Only ``"MS-SSIM"`` is supported.
        return_scores: If True, the per-pair values are returned as well.

    Returns:
        ``(mean, variance, scores)``. ``variance`` uses ``ddof=1`` (0.0 for a
        single value) and ``scores`` is empty unless ``return_scores``.
        ``(nan, nan, [])`` when fewer than two samples or no valid value.

    Raises:
        ValueError: If ``metric_name`` is not ``"MS-SSIM"``.
    """
    if metric_name != "MS-SSIM":
        raise ValueError(f"Currently this function only supports MS-SSIM, not {metric_name}.")

    logger.info(f"Calculating pairwise {metric_name} (mean and variance)...")
    all_data = [batch_data.cpu() for batch_data in dataloader if batch_data is not None]
    if not all_data:
        logger.warning(f"No valid data batches found for pairwise {metric_name}.")
        return float('nan'), float('nan'), []

    all_data_tensor = torch.cat(all_data, dim=0)
    num_samples = all_data_tensor.size(0)
    logger.info(f"Calculating pairwise {metric_name} on {num_samples} samples.")
    if num_samples < 2:
        logger.warning(f"Need at least 2 samples for pairwise {metric_name}, found {num_samples}.")
        return float('nan'), float('nan'), []

    total_pairs = (num_samples * (num_samples - 1)) // 2
    metric_func = MultiScaleSSIMMetric(spatial_dims=3, data_range=1.0, kernel_size=7)
    metric_values = []
    num_failed = 0
    num_nonfinite = 0
    last_error = None

    with tqdm(total=total_pairs, desc=f"Calculating {metric_name} pairs", leave=False) as pbar_calc:
        # The last sample has no partner with a larger index.
        for i in range(num_samples - 1):
            row_pairs = num_samples - 1 - i
            try:
                # Transfer the anchor once per row instead of once per pair.
                img1 = all_data_tensor[i:i + 1].to(device)
            except Exception as e:
                num_failed += row_pairs
                last_error = e
                pbar_calc.update(row_pairs)
                continue
            for j in range(i + 1, num_samples):
                try:
                    img2 = all_data_tensor[j:j + 1].to(device)
                    value = metric_func(img1, img2).item()
                except Exception as e:
                    num_failed += 1
                    last_error = e
                    logger.debug(f"{metric_name} failed for pair ({i}, {j}): {e}")
                else:
                    if np.isfinite(value):
                        metric_values.append(value)
                    else:
                        num_nonfinite += 1
                finally:
                    pbar_calc.update(1)

    # Report dropped pairs instead of silently swallowing them: they bias the mean.
    if num_failed:
        logger.warning(f"{num_failed}/{total_pairs} {metric_name} pair computations raised and were skipped (last error: {last_error}).")
    if num_nonfinite:
        logger.warning(f"{num_nonfinite}/{total_pairs} {metric_name} values were non-finite and were skipped.")

    if not metric_values:
        logger.warning(f"No valid {metric_name} values were calculated.")
        return float('nan'), float('nan'), []

    avg_metric = np.mean(metric_values)
    var_metric = np.var(metric_values, ddof=1) if len(metric_values) > 1 else 0.0

    return (avg_metric, var_metric, metric_values) if return_scores else (avg_metric, var_metric, [])

# --- Main Execution Logic ---
def _dir_label(path):
    """Return the final component of a directory path, ignoring trailing separators."""
    # os.path.basename("runs/gen/") is "", which would make every directory
    # passed with a trailing slash collide under the same empty label.
    return os.path.basename(os.path.normpath(path))


def _filter_valid_shape(file_paths):
    """Return the files whose NIfTI header shape equals ORIGINAL_SHAPE.

    Files whose header cannot be read are logged and skipped rather than
    aborting the whole run.
    """
    valid = []
    for fp in file_paths:
        try:
            shape = nib.load(fp).shape
        except Exception as e:
            logger.warning(f"Could not read NIfTI header of {fp}: {e}. Skipping.")
            continue
        if shape == ORIGINAL_SHAPE:
            valid.append(fp)
    return valid


def _systematic_sample(file_paths, max_samples):
    """Deterministically thin ``file_paths`` to at most ``max_samples`` entries.

    Returns ``(subset, stride)``. When the list is longer than
    ``max_samples``, every ``stride``-th file (``stride = len // max_samples``)
    is kept and the result is truncated to ``max_samples``. Otherwise the list
    is returned unchanged with ``stride=None``.
    """
    if len(file_paths) <= max_samples:
        return file_paths, None
    stride = len(file_paths) // max_samples
    return file_paths[::stride][:max_samples], stride


def _age_in_band(path, min_age, max_age):
    """Return True if the age parsed from ``path``'s file name lies in [min_age, max_age]."""
    age = extract_age_from_filename(path)
    return age is not None and min_age <= age <= max_age


def _make_loader(file_paths, p_low, p_high, args):
    """Build an unshuffled DataLoader over NiftiDataset that drops unloadable volumes."""
    dataset = NiftiDataset(file_paths, p_low, p_high)
    return DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers, collate_fn=safe_collate)


def _std_from_var(var):
    """Square root of a variance, or NaN if it is negative or NaN."""
    return np.sqrt(var) if var >= 0 else float('nan')


def _load_feature_extractor(weights_path, device):
    """Build the MedicalNet ResNet-50 trunk, load pretrained weights and put it on ``device`` in eval mode."""
    logger.info("Loading MedicalNet ResNet-50 feature extractor...")
    model = generate_medical_resnet(no_cuda=(device.type == 'cpu'))
    ckpt = torch.load(weights_path, map_location='cpu')
    # Accept both {'state_dict': ...} checkpoints and bare state dicts.
    raw_state = ckpt['state_dict'] if isinstance(ckpt, dict) and 'state_dict' in ckpt else ckpt
    # Strip DataParallel prefixes and drop classification/segmentation heads.
    state_dict = {k.replace('module.', ''): v for k, v in raw_state.items()
                  if 'fc' not in k and 'conv_seg' not in k}
    load_result = model.load_state_dict(state_dict, strict=False)
    # strict=False ignores mismatches silently; a prefix/key mismatch would leave
    # random weights and make FID/MMD meaningless, so surface it. BatchNorm
    # step counters are irrelevant for eval-mode features.
    missing = [k for k in load_result.missing_keys if not k.endswith('num_batches_tracked')]
    if missing:
        logger.warning(f"{len(missing)} model tensors were not found in the checkpoint and keep their random initialisation (e.g. {missing[:5]}).")
    if load_result.unexpected_keys:
        logger.warning(f"Ignoring {len(load_result.unexpected_keys)} unexpected checkpoint keys (e.g. {load_result.unexpected_keys[:5]}).")
    model.to(device).eval()
    return model


def _add_significance_tests(df, dir_names, metric_distributions, alpha):
    """Add Wilcoxon rank-sum p-value/significance columns comparing each directory to the first one.

    The first entry of ``dir_names`` is the baseline. Every metric whose
    baseline distribution is non-empty is tested for every other directory
    when both distributions have at least MIN_SAMPLES_FOR_STAT_TEST values.
    Significance uses a Bonferroni-corrected ``alpha``. ``df`` is modified
    in place.
    """
    baseline_dir_name = dir_names[0]
    logger.info(f"\n--- Performing Statistical Tests vs Baseline: {baseline_dir_name} ---")
    baseline_metrics = metric_distributions.get(baseline_dir_name, {})

    metrics_to_test = [k for k, dist in baseline_metrics.items() if dist]
    num_tests = (len(dir_names) - 1) * len(metrics_to_test)
    alpha_corrected = alpha / max(1, num_tests)
    logger.info(f"Total tests: {num_tests}. Bonferroni corrected alpha: {alpha_corrected:.6f}")

    # NOTE(review): bootstrap replicates and overlapping MS-SSIM pairs are not
    # independent observations, so these p-values shrink as --num_bootstraps or
    # the sample count grows; treat them as indicative only.
    for metric_key in metrics_to_test:
        p_col = f"p_{metric_key}_vs_Baseline"
        s_col = f"sig_{metric_key}_vs_Baseline"
        df[p_col] = pd.NA
        df[s_col] = pd.NA
        base_dist = baseline_metrics[metric_key]
        for current_dir_name in dir_names[1:]:
            current_dist = metric_distributions.get(current_dir_name, {}).get(metric_key, [])
            if len(base_dist) < MIN_SAMPLES_FOR_STAT_TEST or len(current_dist) < MIN_SAMPLES_FOR_STAT_TEST:
                continue
            try:
                _, p_value = stats.ranksums(base_dist, current_dist)
            except Exception as e:
                logger.warning(f"Wilcoxon test failed for {metric_key}: {e}")
                continue
            mask = df['Directory'] == current_dir_name
            df.loc[mask, p_col] = p_value
            df.loc[mask, s_col] = p_value < alpha_corrected


def main():
    """Benchmark generated NIfTI volumes against a set of real volumes.

    For each directory in ``--gen_dirs`` this computes:
    - bootstrapped FID and MMD on MedicalNet ResNet-50 features;
    - pairwise MS-SSIM among the generated volumes;
    - the same metrics per age band in AGE_BANDS;
    - with more than one directory, rank-sum tests against the first.

    Results are written to ``--output_csv`` and printed.
    """
    parser = argparse.ArgumentParser(description="Comprehensive Benchmarking for 3D NIfTI Volumes")
    parser.add_argument("--real_dir", type=str, required=True, help="Directory containing real NIfTI files.")
    parser.add_argument("--gen_dirs", nargs='+', required=True, help="List of directories with generated NIfTI files.")
    parser.add_argument("--output_csv", type=str, required=True, help="Path to save the results CSV file.")
    parser.add_argument("--medical_resnet_path", type=str, required=True, help="Path to the pretrained MedicalNet ResNet-50 weights.")
    parser.add_argument("--batch_size", type=int, default=DEFAULT_BATCH_SIZE, help="Batch size for DataLoader.")
    parser.add_argument("--num_workers", type=int, default=4, help="Number of workers for DataLoader.")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use.")
    parser.add_argument("--max_fid_samples", type=int, default=500, help="Maximum number of generated samples to use for FID/MMD.")
    parser.add_argument("--max_ssim_samples", type=int, default=DEFAULT_MAX_SAMPLES_PAIRWISE, help=f"Maximum number of generated samples to use for MS-SSIM. Default: {DEFAULT_MAX_SAMPLES_PAIRWISE}")
    parser.add_argument("--percentile_samples", type=int, default=DEFAULT_PERCENTILE_SAMPLES, help="Number of real samples for global percentile calculation.")
    parser.add_argument("--num_bootstraps", type=int, default=100, help="Number of random bootstrap iterations for FID/MMD.")
    parser.add_argument("--mmd_sigma", type=float, default=5.0, help="Sigma for the Gaussian kernel in MMD.")
    parser.add_argument("--stat_alpha", type=float, default=0.05, help="Significance level (alpha) for statistical tests.")
    args = parser.parse_args()
    # Systematic sampling divides by these caps; zero/negative would crash later.
    if args.max_fid_samples < 1:
        parser.error("--max_fid_samples must be >= 1.")
    if args.max_ssim_samples < 1:
        parser.error("--max_ssim_samples must be >= 1.")

    device = torch.device(args.device)
    logger.info(f"Using device: {device}")
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)
    if device.type == 'cuda':
        torch.cuda.manual_seed_all(42)

    feature_extractor = _load_feature_extractor(args.medical_resnet_path, device)

    # --- Real data: activations for the full set and per age band ---
    logger.info(f"Processing ALL real data from: {args.real_dir}")
    all_real_filepaths = sorted(glob.glob(os.path.join(args.real_dir, "*.nii.gz")))
    real_files_valid_shape = _filter_valid_shape(all_real_filepaths)
    if not real_files_valid_shape:
        logger.error("No real files with correct shape found.")
        return

    global_p_low, global_p_high = calculate_global_percentiles(real_files_valid_shape, num_samples=args.percentile_samples)
    if global_p_low is None:
        logger.error("Failed to calculate global percentiles.")
        return

    real_loader = _make_loader(real_files_valid_shape, global_p_low, global_p_high, args)
    real_activations_all = get_activations(real_loader, feature_extractor, device)
    logger.info(f"Calculated activations for {real_activations_all.shape[0]} real samples.")

    # safe_collate drops files that fail to load; zipping paths with activations
    # would then silently shift every later activation onto the wrong file.
    if len(real_activations_all) == len(real_files_valid_shape):
        real_path_to_activation = dict(zip(real_files_valid_shape, real_activations_all))
    else:
        real_path_to_activation = None
        logger.error(f"Got {len(real_activations_all)} real activations for {len(real_files_valid_shape)} files; some files failed to load, so activations cannot be matched to files. Age-band metrics will be skipped.")

    real_files_by_band = {
        band_name: [fp for fp in real_files_valid_shape if _age_in_band(fp, min_age, max_age)]
        for band_name, (min_age, max_age) in AGE_BANDS.items()
    }

    # --- Generated data ---
    dir_names = [_dir_label(d) for d in args.gen_dirs]
    if len(set(dir_names)) != len(dir_names):
        logger.warning("Some --gen_dirs share the same directory name; their results will overwrite each other in the summary and statistical tests.")

    results = []
    metric_distributions = {}

    for gen_dir, dir_basename in zip(args.gen_dirs, dir_names):
        metric_distributions[dir_basename] = {}
        dists = metric_distributions[dir_basename]
        logger.info(f"\n--- Processing generated data from: {dir_basename} ---")

        row = {"Directory": dir_basename}
        gen_files_full = sorted(glob.glob(os.path.join(gen_dir, "*.nii.gz")))
        if not gen_files_full:
            logger.warning(f"No NIfTI files in {gen_dir}. Skipping.")
            continue

        gen_files_valid_shape = _filter_valid_shape(gen_files_full)

        gen_files_for_fid, stride = _systematic_sample(gen_files_valid_shape, args.max_fid_samples)
        if stride is not None:
            logger.info(f"Systematically sampled {len(gen_files_for_fid)} files for FID/MMD (stride: {stride}).")
        if len(gen_files_for_fid) < 2:
            logger.warning(f"Not enough valid samples for FID/MMD in {dir_basename}. Skipping.")
            continue

        gen_files_for_ssim, stride_ssim = _systematic_sample(gen_files_valid_shape, args.max_ssim_samples)
        if stride_ssim is not None:
            logger.info(f"Sampled {len(gen_files_for_ssim)} files for MS-SSIM (stride: {stride_ssim}).")

        gen_loader = _make_loader(gen_files_for_fid, global_p_low, global_p_high, args)
        gen_activations = get_activations(gen_loader, feature_extractor, device)
        logger.info(f"Calculated {len(gen_activations)} activations for FID/MMD.")
        row["Num_Samples_FID_MMD"] = len(gen_activations)

        is_conditioned = any(extract_age_from_filename(fp) is not None for fp in gen_files_for_fid)
        row["Is_Conditioned"] = "Yes" if is_conditioned else "No"

        logger.info("Calculating Overall metrics (All Real vs. Gen)...")
        fid_mean, fid_std, mmd_mean, mmd_std, fid_scores, mmd_scores = calculate_bootstrapped_fid_mmd(real_activations_all, gen_activations, args.num_bootstraps, device, args.mmd_sigma)
        row.update({"FID_Overall_Mean": fid_mean, "FID_Overall_Std": fid_std, "MMD_Overall_Mean": mmd_mean, "MMD_Overall_Std": mmd_std})
        dists["FID_Overall"] = fid_scores
        dists["MMD_Overall"] = mmd_scores

        ssim_loader = _make_loader(gen_files_for_ssim, global_p_low, global_p_high, args)
        ssim_mean, ssim_var, ssim_scores = calculate_pairwise_metrics(ssim_loader, device, return_scores=True)
        row.update({"MS_SSIM_Overall_Mean": ssim_mean, "MS_SSIM_Overall_Std": _std_from_var(ssim_var)})
        row["Num_Samples_MS_SSIM"] = len(gen_files_for_ssim)
        dists["MS_SSIM_Overall"] = ssim_scores

        # Same alignment guard as for the real set (built once per directory).
        gen_path_to_activation = None
        if len(gen_activations) == len(gen_files_for_fid):
            gen_path_to_activation = dict(zip(gen_files_for_fid, gen_activations))

        bands_available = real_path_to_activation is not None
        if bands_available and is_conditioned and gen_path_to_activation is None:
            logger.error(f"Some generated files in {dir_basename} failed to load; cannot match activations to ages. Skipping age-band metrics.")
            bands_available = False

        if bands_available:
            for band_name, (min_age, max_age) in AGE_BANDS.items():
                logger.info(f"Calculating metrics for Age Band: {band_name}...")
                real_files_in_band = real_files_by_band[band_name]
                if len(real_files_in_band) < 2:
                    logger.warning(f"Not enough real samples in band {band_name}.")
                    continue
                real_activations_in_band = np.array([real_path_to_activation[p] for p in real_files_in_band])

                if is_conditioned:
                    gen_files_in_band_fid = [p for p in gen_files_for_fid if _age_in_band(p, min_age, max_age)]
                    if len(gen_files_in_band_fid) < 2:
                        logger.warning(f"Not enough generated samples for FID in band {band_name}.")
                        continue
                    gen_activations_in_band = np.array([gen_path_to_activation[p] for p in gen_files_in_band_fid])
                else:
                    # Unconditioned model: compare the whole generated set with each real band.
                    gen_activations_in_band = gen_activations

                fid_mean_b, fid_std_b, mmd_mean_b, mmd_std_b, fid_scores_b, mmd_scores_b = calculate_bootstrapped_fid_mmd(real_activations_in_band, gen_activations_in_band, args.num_bootstraps, device, args.mmd_sigma)
                row.update({f"FID_{band_name}_Mean": fid_mean_b, f"FID_{band_name}_Std": fid_std_b, f"MMD_{band_name}_Mean": mmd_mean_b, f"MMD_{band_name}_Std": mmd_std_b})
                dists[f"FID_{band_name}"] = fid_scores_b
                dists[f"MMD_{band_name}"] = mmd_scores_b

                if is_conditioned:
                    gen_files_in_band_ssim = [p for p in gen_files_for_ssim if _age_in_band(p, min_age, max_age)]
                    if len(gen_files_in_band_ssim) < 2:
                        logger.warning(f"Not enough generated samples for MS-SSIM in band {band_name}.")
                        continue

                    band_ssim_loader = _make_loader(gen_files_in_band_ssim, global_p_low, global_p_high, args)
                    ssim_mean_b, ssim_var_b, ssim_scores_b = calculate_pairwise_metrics(band_ssim_loader, device, return_scores=True)
                    row.update({f"MS_SSIM_{band_name}_Mean": ssim_mean_b, f"MS_SSIM_{band_name}_Std": _std_from_var(ssim_var_b)})
                    dists[f"MS_SSIM_{band_name}"] = ssim_scores_b

        results.append(row)

    df = pd.DataFrame(results)
    if len(dir_names) > 1:
        _add_significance_tests(df, dir_names, metric_distributions, args.stat_alpha)

    output_dir = os.path.dirname(args.output_csv)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    df.to_csv(args.output_csv, index=False, float_format='%.6f')
    logger.info(f"\nBenchmark results saved to: {args.output_csv}")

    with pd.option_context('display.max_rows', None, 'display.max_columns', None, 'display.width', 300, 'display.float_format', '{:.4f}'.format):
        print("\n--- Benchmark Summary ---")
        print(df.to_string(index=False))
    logger.info("Evaluation script finished.")

if __name__ == "__main__":
    # main() returns None, so the process exits with status 0 unless main raises.
    raise SystemExit(main())