import torch
import random
import numpy as np
import torch.nn.functional as F
from torchvision import transforms
import math
import warnings
from typing import Dict, List, Tuple, Optional, Any
import time

# NOTE(review): this runs at import time and silences *every* warning in the
# whole process (no category/module filter), including warnings raised by
# torch/torchvision/numpy and by unrelated code. Consider narrowing it, e.g.
# warnings.filterwarnings('ignore', category=..., module=...).
warnings.filterwarnings('ignore')

def safe_config_get(config, key, default=None):
    """
    Read ``key`` from a config object that may be a mapping or a Namespace.

    Objects exposing a ``get`` method (e.g. dict) are queried with
    ``config.get(key, default)``; anything else is read as an attribute.

    :param config: Configuration object (Namespace or dict)
    :param key: Configuration key to retrieve
    :param default: Value returned when the key/attribute is absent
    :return: Configuration value or default
    """
    if not hasattr(config, 'get'):
        # Namespace-style object (or None / plain object): attribute lookup.
        return getattr(config, key, default)
    return config.get(key, default)

def safe_nested_config_get(config, keys, default=None):
    """
    Walk a path of keys through nested dict / Namespace configuration.

    Each level is read with ``.get`` when available, otherwise as an
    attribute. A missing key, an explicit ``None`` or an empty dict anywhere
    along the path yields ``default``.

    :param config: Configuration object (Namespace or dict)
    :param keys: Sequence of keys describing the nested path
    :param default: Value returned when the path cannot be resolved
    :return: Nested configuration value or default
    """
    node = config
    for key in keys:
        node = node.get(key, {}) if hasattr(node, 'get') else getattr(node, key, {})
        if node == {} or node is None:
            return default
    # Only reachable with an empty ``keys`` path for a non-empty result.
    return node if node != {} else default

class NDAHorizontalFlip(object):
    """
    Neuromorphic Data Augmentation horizontal flip transform.

    With probability ``prob`` the tensor is mirrored along its last
    dimension; all other dimensions (e.g. time) are left untouched.
    Non-tensor inputs raise ``TypeError``.
    """

    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, x):
        if not torch.is_tensor(x):
            raise TypeError("Input must be a torch.Tensor")
        skip = random.random() > self.prob
        if skip:
            return x
        return x.flip(-1)

class NDARoll(object):
    """
    Neuromorphic Data Augmentation spatial rolling transform.

    With probability ``prob`` the last two dimensions are circularly shifted
    by two independent integer offsets drawn uniformly from ``shift_range``
    (inclusive). Non-tensor inputs are returned unchanged.
    """

    def __init__(self, shift_range=(-5, 5), prob=1.0):
        self.shift_range = shift_range
        self.prob = prob

    def __call__(self, x):
        if not torch.is_tensor(x):
            return x
        if random.random() > self.prob:
            return x

        lo, hi = self.shift_range
        # First offset applies to dim -2 (rows), second to dim -1 (columns).
        shifts = (random.randint(lo, hi), random.randint(lo, hi))
        return x.roll(shifts, dims=(-2, -1))

class NDARotation(object):
    """
    Neuromorphic Data Augmentation rotation transform.
    Applies geometric rotation while maintaining temporal structure.

    Supported layouts:
      * [C, H, W] is passed to ``self.rotate`` directly.
      * [T, C, H, W] and [B, T, C, H, W] are folded into one batch of
        single-channel frames, so a single ``self.rotate`` call handles all
        frames. Per the torchvision docs, ``RandomRotation`` samples one angle
        per call, so every frame should get the same rotation.
      * Any other rank is returned unchanged.
    """

    def __init__(self, degrees=30, prob=1.0):
        self.degrees = degrees
        self.prob = prob
        self.rotate = transforms.RandomRotation(degrees=degrees)

    def __call__(self, x):
        if not torch.is_tensor(x) or random.random() > self.prob:
            return x

        if x.dim() == 3:  # [C, H, W]
            return self.rotate(x)

        if x.dim() in (4, 5):  # [T, C, H, W] or [B, T, C, H, W]
            *lead, H, W = x.shape
            # reshape (not view): view raises on non-contiguous inputs such as
            # permuted/transposed tensors; reshape copies only when needed.
            rotated = self.rotate(x.reshape(-1, 1, H, W))
            return rotated.reshape(*lead, H, W)

        return x

class NDAShear(object):
    """
    Neuromorphic Data Augmentation shear transformation.
    Implements affine shear operations along specified directions.

    Supported layouts:
      * [C, H, W] is passed to ``self.shear`` directly.
      * [T, C, H, W] and [B, T, C, H, W] are folded into one batch of
        single-channel frames, so a single ``self.shear`` call handles all
        frames. Per the torchvision docs, ``RandomAffine`` samples its
        parameters once per call, so every frame should get the same shear.
      * Any other rank is returned unchanged.
    """

    def __init__(self, shear_range=(-30, 30), prob=1.0):
        self.shear_range = shear_range
        self.prob = prob
        # Rotation disabled (degrees=0); shear_range is forwarded as ``shear``.
        self.shear = transforms.RandomAffine(degrees=0, shear=shear_range)

    def __call__(self, x):
        if not torch.is_tensor(x) or random.random() > self.prob:
            return x

        if x.dim() == 3:  # [C, H, W]
            return self.shear(x)

        if x.dim() in (4, 5):  # [T, C, H, W] or [B, T, C, H, W]
            *lead, H, W = x.shape
            # reshape (not view): view raises on non-contiguous inputs such as
            # permuted/transposed tensors; reshape copies only when needed.
            sheared = self.shear(x.reshape(-1, 1, H, W))
            return sheared.reshape(*lead, H, W)

        return x

class NDACutout(object):
    """
    Neuromorphic Data Augmentation cutout transform.

    Zeroes one rectangular patch of side up to ``length`` (clipped at the
    borders), centred on a uniformly random pixel. The patch is built as an
    (H, W) float32 mask and broadcast over all leading dimensions, so every
    frame/channel loses the same region. Non-tensor inputs are returned
    unchanged.
    """

    def __init__(self, length, prob=1.0):
        self.length = length
        self.prob = prob

    def __call__(self, x):
        if not torch.is_tensor(x) or random.random() > self.prob:
            return x

        *_, height, width = x.shape
        half = self.length // 2

        centre_y = np.random.randint(height)
        centre_x = np.random.randint(width)
        top, bottom = np.clip(centre_y - half, 0, height), np.clip(centre_y + half, 0, height)
        left, right = np.clip(centre_x - half, 0, width), np.clip(centre_x + half, 0, width)

        keep = np.ones((height, width), np.float32)
        keep[top:bottom, left:right] = 0.0
        # A 2-D mask broadcasts against any number of leading dims.
        return x * torch.from_numpy(keep).to(x.device)

class NDAConservativeZoom(object):
    """
    Neuromorphic Data Augmentation conservative scaling transform.
    Applies mild scaling operations while preserving temporal information.

    One scale factor is drawn uniformly from ``scale_range`` per call. Every
    2-D frame (last two dims) is bilinearly resized by that factor and then
    brought back to the original ``H x W``:
      * zoom in (scale > 1): center crop;
      * zoom out (scale < 1): zero padding, with the odd pixel on the
        bottom/right.

    Accepts [C, H, W], [T, C, H, W] and [B, T, C, H, W] tensors, matching the
    layouts handled by the rotation/shear transforms. Any other rank, a
    near-identity scale (within 1%), or a scale that would collapse a side to
    0 pixels returns the input unchanged.
    """

    def __init__(self, scale_range=(0.95, 1.05), prob=0.2):
        self.scale_min, self.scale_max = scale_range
        self.prob = prob

    def __call__(self, x):
        if not torch.is_tensor(x) or random.random() > self.prob:
            return x

        scale = random.uniform(self.scale_min, self.scale_max)
        if x.dim() not in (3, 4, 5):
            return x

        *lead, H, W = x.shape
        new_h, new_w = int(H * scale), int(W * scale)
        if new_h <= 0 or new_w <= 0 or abs(scale - 1.0) <= 0.01:
            return x

        # Fold all leading dims into the batch so each frame is resized
        # independently; reshape (not view) also accepts non-contiguous input.
        frames = x.reshape(-1, 1, H, W)
        scaled = F.interpolate(frames, size=(new_h, new_w),
                               mode='bilinear', align_corners=False)
        scaled = scaled.reshape(*lead, new_h, new_w)

        if new_h >= H and new_w >= W:
            top = (new_h - H) // 2
            left = (new_w - W) // 2
            return scaled[..., top:top + H, left:left + W]

        pad_h = (H - new_h) // 2
        pad_w = (W - new_w) // 2
        return F.pad(scaled, (pad_w, W - new_w - pad_w, pad_h, H - new_h - pad_h))

class NDAMixup(object):
    """
    Neuromorphic Data Augmentation "mixup" transform (per-sample).

    Despite the name, no second sample is involved. With ``lam`` drawn from
    Beta(alpha, alpha), the output is ``lam * x + (1 - lam) * (x + noise)``
    with ``noise ~ N(0, 0.01**2)``, i.e. the sample plus Gaussian noise scaled
    by ``1 - lam``. A non-positive ``alpha`` or non-tensor input disables it.
    """

    def __init__(self, alpha=0.2, prob=0.3):
        self.alpha = alpha
        self.prob = prob

    def __call__(self, x):
        if not torch.is_tensor(x) or random.random() > self.prob:
            return x
        if not self.alpha > 0:
            return x

        lam = np.random.beta(self.alpha, self.alpha)
        noisy = x + torch.randn_like(x) * 0.01
        return lam * x + (1 - lam) * noisy

def nda_cutmix_data(input_batch, target_batch, alpha=1.0):
    """
    Neuromorphic Data Augmentation CutMix implementation.

    Pastes one rectangle, copied from a randomly permuted partner sample,
    into every sample of the batch. The rectangle covers all leading dims
    (only the last two dims are treated as spatial) and is the same for
    every sample. ``input_batch`` is modified in place and also returned.

    :param input_batch: Input tensor batch (batch dim first, spatial dims last)
    :param target_batch: Target label batch
    :param alpha: Beta distribution parameter
    :return: Mixed batch, target_a, target_b, lambda, where lambda is the
        fraction of the frame area that was NOT replaced (1.0 when the batch
        has fewer than two samples and nothing is mixed)
    """
    batch_size = input_batch.size(0)
    if batch_size < 2:
        return input_batch, target_batch, target_batch, 1.0

    lam = np.random.beta(alpha, alpha)
    partner = torch.randperm(batch_size).to(input_batch.device)
    target_b = target_batch[partner]

    H, W = input_batch.shape[-2:]
    side_ratio = np.sqrt(1.0 - lam)
    half_w = int(W * side_ratio) // 2
    half_h = int(H * side_ratio) // 2

    cx = np.random.randint(W)
    cy = np.random.randint(H)
    x_lo, x_hi = np.clip(cx - half_w, 0, W), np.clip(cx + half_w, 0, W)
    y_lo, y_hi = np.clip(cy - half_h, 0, H), np.clip(cy + half_h, 0, H)

    input_batch[..., y_lo:y_hi, x_lo:x_hi] = input_batch[partner, ..., y_lo:y_hi, x_lo:x_hi]
    # Recompute lambda from the clipped box so it matches the pasted area.
    lam = 1 - ((x_hi - x_lo) * (y_hi - y_lo) / (H * W))
    return input_batch, target_batch, target_b, lam

def nda_mixup_criterion(criterion, pred, y_a, y_b, lam):
    """
    Interpolate a loss between two target sets.

    Evaluates ``criterion`` against ``y_a`` first and ``y_b`` second, then
    returns ``lam * loss_a + (1 - lam) * loss_b``.

    :param criterion: Loss function called as ``criterion(pred, target)``
    :param pred: Model predictions
    :param y_a: First target
    :param y_b: Second target
    :param lam: Weight of the first target's loss
    :return: Mixed loss value
    """
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b

class NDAMonitor(object):
    """
    Neuromorphic Data Augmentation monitoring and analytics system.
    Tracks augmentation application rates, timing, and quality metrics.

    Every recording method is a no-op when ``enabled`` is False, and
    ``get_stats_summary`` then returns an empty dict.
    """

    def __init__(self, enabled=True):
        self.enabled = enabled
        # aug_name -> {'applied': int, 'total': int}
        self.augmentation_stats = {}
        # operation_name -> list of elapsed seconds
        self.timing_stats = {}
        # metric_name -> list of recorded float values
        self.diversity_stats = {}
        # time.perf_counter() reading taken by start_timing(); None until started
        self.start_time = None

    def start_timing(self):
        """Begin timing measurement for operations."""
        if self.enabled:
            # perf_counter is monotonic and high-resolution; time.time() is
            # wall-clock and can jump (e.g. NTP adjustments).
            self.start_time = time.perf_counter()

    def end_timing(self, operation_name):
        """
        Record the seconds elapsed since the last ``start_timing`` call.

        Nothing is recorded if timing was never started.

        :param operation_name: Bucket the duration is appended to
        """
        if not self.enabled or self.start_time is None:
            return
        elapsed = time.perf_counter() - self.start_time
        self.timing_stats.setdefault(operation_name, []).append(elapsed)

    def log_augmentation(self, aug_name, applied=True):
        """
        Count one invocation of ``aug_name`` and whether it changed the input.

        :param aug_name: Augmentation identifier
        :param applied: True if the augmentation modified the sample
        """
        if not self.enabled:
            return
        stats = self.augmentation_stats.setdefault(aug_name, {'applied': 0, 'total': 0})
        stats['total'] += 1
        if applied:
            stats['applied'] += 1

    def analyze_timing_preservation(self, original_tensor, augmented_tensor):
        """
        Analyze temporal information preservation after augmentation.

        Each tensor is reduced to a per-frame activity profile by summing over
        every dim after the temporal axis. The temporal axis is taken to be
        dim 0 for 4-D inputs (presumably [T, C, H, W]) and dim 1 otherwise
        (presumably [B, T, C, H, W]). The cosine similarity of the two
        flattened profiles is appended to ``diversity_stats['timing_preservation']``.
        Inputs with fewer than 4 dims, or profiles of different length (which
        cannot be compared), are skipped.

        :param original_tensor: Original input tensor
        :param augmented_tensor: Augmented tensor
        """
        if not self.enabled or original_tensor.dim() < 4:
            return

        with torch.no_grad():
            temporal_axis = 0 if original_tensor.dim() == 4 else 1
            orig_profile = original_tensor.sum(
                dim=tuple(range(temporal_axis + 1, original_tensor.dim()))
            ).flatten().float()
            aug_profile = augmented_tensor.sum(
                dim=tuple(range(temporal_axis + 1, augmented_tensor.dim()))
            ).flatten().float()

            # cosine_similarity raises on vectors of different length.
            if orig_profile.numel() != aug_profile.numel():
                return

            correlation = F.cosine_similarity(orig_profile, aug_profile, dim=0).item()

        self.diversity_stats.setdefault('timing_preservation', []).append(correlation)

    def get_stats_summary(self) -> Dict[str, Any]:
        """
        Generate comprehensive statistics summary.

        :return: Dict with (only non-empty) sections 'augmentation_rates',
            'timing_stats' and 'diversity_stats'; empty when disabled
        """
        if not self.enabled:
            return {}

        summary = {}

        if self.augmentation_stats:
            summary['augmentation_rates'] = {
                aug_name: {
                    'application_rate': stats['applied'] / max(stats['total'], 1),
                    'total_calls': stats['total'],
                }
                for aug_name, stats in self.augmentation_stats.items()
            }

        if self.timing_stats:
            summary['timing_stats'] = {
                op_name: {
                    'mean_time': np.mean(times),
                    'std_time': np.std(times),
                    'total_calls': len(times),
                }
                for op_name, times in self.timing_stats.items()
            }

        if self.diversity_stats:
            summary['diversity_stats'] = {
                metric_name: {
                    'mean': np.mean(values),
                    'std': np.std(values),
                    'min': np.min(values),
                    'max': np.max(values),
                }
                for metric_name, values in self.diversity_stats.items()
            }

        return summary

class NDACoreAugmentation(object):
    """
    Neuromorphic Data Augmentation core transform.

    With probability ``prob`` exactly one of roll, rotation or shear is
    picked uniformly (via ``np.random.choice``) and applied. Each
    sub-transform is built with ``prob=1.0`` so the outer probability is the
    only gate. Non-tensor inputs are returned unchanged.
    """

    def __init__(self, roll_range=(-4, 4), rotation_degrees=20, shear_range=(-20, 20), prob=0.15):
        self.prob = prob
        self.roll = NDARoll(shift_range=roll_range, prob=1.0)
        self.rotate = NDARotation(degrees=rotation_degrees, prob=1.0)
        self.shear = NDAShear(shear_range=shear_range, prob=1.0)
        self.choices = ['roll', 'rotate', 'shear']

    def __call__(self, x):
        if not torch.is_tensor(x) or random.random() > self.prob:
            return x

        handlers = {'roll': self.roll, 'rotate': self.rotate, 'shear': self.shear}
        # Unknown names (only possible if ``choices`` was edited) fall through
        # and leave the input untouched.
        handler = handlers.get(np.random.choice(self.choices))
        return x if handler is None else handler(x)

class NDATransform(object):
    """
    Neuromorphic Data Augmentation main transform composer.
    Builds augmentation pipeline from configuration with dataset-specific optimizations.

    ``config`` is the augmentation section (``data_aug``). Every (sub)section
    may be a dict or a Namespace-like object; all reads go through
    ``safe_config_get``. When ``config.enabled`` is falsy the pipeline stays
    empty, ``monitor`` stays None and ``__call__`` returns its input unchanged.

    Pipeline order: horizontal flip -> conservative zoom -> core
    roll/rotate/shear -> cutout -> NDAMixup.
    """

    # Override keys whose (sequence) values are stored as tuples.
    _RANGE_KEYS = ('roll_range', 'shear_range')

    def __init__(self, config, dataset_name="cifar10_dvs"):
        self.config = config
        self.dataset_name = dataset_name.lower()
        self.transforms = []
        self.monitor = None

        enabled = safe_config_get(config, 'enabled', False)
        if not enabled:
            return

        self._setup_monitoring()
        self._setup_dataset_params()
        self._apply_dataset_overrides()
        self._build_transforms()

    def _setup_monitoring(self):
        """Create the NDAMonitor, enabled by ``nda_monitoring.enabled``."""
        monitoring_config = safe_config_get(self.config, 'nda_monitoring', {})
        enabled = safe_config_get(monitoring_config, 'enabled', False)
        self.monitor = NDAMonitor(enabled=enabled)

    def _setup_dataset_params(self):
        """
        Set dataset-specific default geometric and cutout parameters.

        Matching is by substring of the lower-cased dataset name and the first
        matching branch wins, so 'cifar10_dvs' gets the cifar10 values rather
        than the generic 'dvs' ones.
        """
        if 'cifar10' in self.dataset_name:
            self.roll_range = (-5, 5)
            self.rotation_degrees = 30
            self.shear_range = (-30, 30)
            self.cutout_length = 16
        elif 'caltech' in self.dataset_name or 'ncaltech' in self.dataset_name:
            self.roll_range = (-3, 3)
            self.rotation_degrees = 15
            self.shear_range = (-15, 15)
            self.cutout_length = 12
        elif 'gesture' in self.dataset_name or 'dvs' in self.dataset_name:
            self.roll_range = (-4, 4)
            self.rotation_degrees = 20
            self.shear_range = (-20, 20)
            self.cutout_length = 14
        elif 'imagenet' in self.dataset_name:
            self.roll_range = (-6, 6)
            self.rotation_degrees = 25
            self.shear_range = (-25, 25)
            self.cutout_length = 20
        else:
            self.roll_range = (-4, 4)
            self.rotation_degrees = 20
            self.shear_range = (-20, 20)
            self.cutout_length = 14

    def _apply_param_overrides(self, source, keys):
        """Copy every key in ``keys`` that ``source`` sets to a non-None value onto self."""
        for key in keys:
            value = safe_config_get(source, key, None)
            if value is None:
                continue
            if key in self._RANGE_KEYS:
                value = tuple(value)
            setattr(self, key, value)

    def _apply_dataset_overrides(self):
        """
        Apply dataset-specific parameter overrides from configuration.

        First ``dataset_overrides.<dataset_name>`` (roll/rotation/shear/cutout)
        is applied. Then, when ``nda_geometric.enabled`` is true and
        ``nda_geometric.dataset_adaptive`` is false, explicit ``nda_geometric``
        roll/rotation/shear values replace the dataset values.
        """
        dataset_overrides = safe_config_get(self.config, 'dataset_overrides', {})
        # safe_config_get lets a Namespace ``dataset_overrides`` section work
        # just like a dict one.
        overrides = safe_config_get(dataset_overrides, self.dataset_name, None)
        if overrides is not None:
            self._apply_param_overrides(
                overrides,
                ('roll_range', 'rotation_degrees', 'shear_range', 'cutout_length'),
            )

        nda_geometric = safe_config_get(self.config, 'nda_geometric', {})
        enabled = safe_config_get(nda_geometric, 'enabled', False)
        dataset_adaptive = safe_config_get(nda_geometric, 'dataset_adaptive', True)
        if enabled and not dataset_adaptive:
            self._apply_param_overrides(
                nda_geometric,
                ('roll_range', 'rotation_degrees', 'shear_range'),
            )

    def _build_transforms(self):
        """Build transform pipeline from configuration."""
        # Safe spatial transforms; flip probability is capped at 0.5.
        safe_spatial = safe_config_get(self.config, 'safe_spatial', {})
        horizontal_flip = safe_config_get(safe_spatial, 'horizontal_flip', 0)
        if horizontal_flip > 0:
            self.transforms.append(NDAHorizontalFlip(prob=min(horizontal_flip, 0.5)))

        # Conservative zoom
        conservative_zoom = safe_config_get(self.config, 'conservative_zoom', {})
        if safe_config_get(conservative_zoom, 'enabled', False):
            scale_range = tuple(safe_config_get(conservative_zoom, 'scale_range', [0.95, 1.05]))
            zoom_prob = safe_config_get(conservative_zoom, 'prob', 0.2)
            self.transforms.append(NDAConservativeZoom(scale_range=scale_range, prob=zoom_prob))

        # Core geometric augmentations (probability comes from safe_spatial).
        nda_geometric = safe_config_get(self.config, 'nda_geometric', {})
        geometric_enabled = safe_config_get(nda_geometric, 'enabled', False)
        core_prob = safe_config_get(safe_spatial, 'prob_threshold', 0.15)
        # NOTE(review): with ``or`` this stage is added whenever
        # prob_threshold > 0 (default 0.15), even if nda_geometric.enabled is
        # false. Kept to preserve existing behavior; confirm whether ``and``
        # was intended.
        if geometric_enabled or core_prob > 0:
            self.transforms.append(NDACoreAugmentation(
                roll_range=self.roll_range,
                rotation_degrees=self.rotation_degrees,
                shear_range=self.shear_range,
                prob=core_prob
            ))

        # Advanced techniques
        nda_advanced = safe_config_get(self.config, 'nda_advanced', {})
        if not safe_config_get(nda_advanced, 'enabled', False):
            return

        cutout_config = safe_config_get(nda_advanced, 'cutout', {})
        if safe_config_get(cutout_config, 'enabled', False):
            cutout_length = safe_config_get(cutout_config, 'length', self.cutout_length)
            cutout_prob = safe_config_get(cutout_config, 'prob', 0.3)
            self.transforms.append(NDACutout(cutout_length, prob=cutout_prob))

        mixup_config = safe_config_get(nda_advanced, 'mixup', {})
        if safe_config_get(mixup_config, 'enabled', False):
            mixup_alpha = safe_config_get(mixup_config, 'alpha', 0.2)
            mixup_prob = safe_config_get(mixup_config, 'prob', 0.3)
            self.transforms.append(NDAMixup(alpha=mixup_alpha, prob=mixup_prob))

    def __call__(self, x):
        """
        Apply transform pipeline with optional monitoring.

        When monitoring is enabled, each stage is logged as "applied" if its
        output differs from its input (``torch.equal``). The total wall time is
        recorded under 'total_augmentation'.

        :param x: Input tensor
        :return: Augmented tensor
        """
        monitoring = self.monitor is not None and self.monitor.enabled

        if self.monitor is not None:
            self.monitor.start_timing()

        original_x = x.clone() if monitoring else None

        for transform in self.transforms:
            x_before = x.clone() if monitoring else None
            x = transform(x)
            if monitoring:
                self.monitor.log_augmentation(type(transform).__name__,
                                              not torch.equal(x_before, x))

        if monitoring:
            self.monitor.analyze_timing_preservation(original_x, x)

        if self.monitor is not None:
            self.monitor.end_timing('total_augmentation')

        return x

    def get_monitor_stats(self):
        """Return monitoring statistics if available."""
        if self.monitor and self.monitor.enabled:
            return self.monitor.get_stats_summary()
        return {}

class FramewiseResize(object):
    """
    Frame-wise resize transform preserving temporal spike information.
    Handles multi-dimensional tensor resizing with temporal awareness.

    The input is cast to float32 and its last two (spatial) dims are
    bilinearly resized to ``size``; the leading dims are preserved:
      * [B, T, C, H, W]: frames folded to (B*T, C, H, W)
      * [T, C, H, W]:    frames folded to (T*C, 1, H, W)
      * [C, H, W]:       a batch of one
    Other ranks raise ``ValueError``; non-tensors raise ``TypeError``.
    """

    def __init__(self, size):
        if isinstance(size, int):
            size = (size, size)
        self.size = size

    def __call__(self, x):
        if not torch.is_tensor(x):
            raise TypeError("Input must be a torch.Tensor")

        # Bilinear interpolation requires a floating-point tensor.
        x = x.float()

        # reshape (not view) throughout: view raises on non-contiguous inputs
        # such as permuted/transposed tensors.
        if x.dim() == 5:  # [B, T, C, H, W]
            B, T, C, H, W = x.shape
            frames = x.reshape(B * T, C, H, W)
            frames = F.interpolate(frames, size=self.size, mode='bilinear', align_corners=False)
            return frames.reshape(B, T, C, *self.size)
        elif x.dim() == 4:  # [T, C, H, W]
            T, C, H, W = x.shape
            frames = x.reshape(T * C, 1, H, W)
            frames = F.interpolate(frames, size=self.size, mode='bilinear', align_corners=False)
            return frames.reshape(T, C, *self.size)
        elif x.dim() == 3:  # [C, H, W]
            frames = F.interpolate(x.unsqueeze(0), size=self.size, mode='bilinear', align_corners=False)
            return frames.squeeze(0)
        else:
            raise ValueError(f"Unsupported input shape for resize: {x.shape}")

class DVSMinMaxNormalize(object):
    """
    Min-max normalization for DVS data with numerical stability.

    Casts to float32 and maps the tensor linearly onto [0, 1] (global min to
    0, global max to 1). A (near-)constant tensor, whose value range is below
    ``eps``, maps to zeros. Empty tensors are returned as-is (after the cast).
    """

    def __init__(self, eps=1e-8):
        self.eps = eps

    def __call__(self, x):
        x = x.float()
        if x.numel() == 0:
            # min()/max() raise on empty tensors; there is nothing to scale.
            return x

        min_val = x.min()
        value_range = x.max() - min_val
        if value_range < self.eps:
            return torch.zeros_like(x)
        # The guard above already rules out division by ~0, so no epsilon is
        # added here. An additive epsilon would shrink small-but-valid ranges
        # (a range of 1e-5 would map its max to ~0.91 instead of 1).
        return (x - min_val) / value_range

class DVSClampNormalize(object):
    """
    Clamping normalization for extreme value control in DVS processing.

    Casts the input to float32 and clips every element into
    ``[min_val, max_val]``.
    """

    def __init__(self, min_val=-3.0, max_val=3.0):
        self.min_val = min_val
        self.max_val = max_val

    def __call__(self, x):
        return x.float().clamp(min=self.min_val, max=self.max_val)

def get_nda_transform(config, split='train', input_size=128, dataset_name="cifar10_dvs"):
    """
    Build the full NDA pipeline: resize, optional augmentation, normalization.

    The pipeline is always FramewiseResize -> [NDATransform, train split with
    ``data_aug.enabled`` only] -> DVSClampNormalize(-2, 2) -> DVSMinMaxNormalize.

    :param config: Configuration object (Namespace or dict)
    :param split: Data split ('train', 'validation', 'test')
    :param input_size: Target square input size
    :param dataset_name: Dataset identifier for optimization
    :return: Composed transform pipeline
    """
    stages = [FramewiseResize((input_size, input_size))]

    if split == 'train':
        aug_cfg = safe_config_get(config, 'data_aug', {})
        aug_enabled = (aug_cfg.get('enabled', False) if isinstance(aug_cfg, dict)
                       else getattr(aug_cfg, 'enabled', False))
        if aug_enabled:
            stages.append(NDATransform(aug_cfg, dataset_name))

    stages += [
        DVSClampNormalize(min_val=-2.0, max_val=2.0),
        DVSMinMaxNormalize(),
    ]
    return transforms.Compose(stages)

def get_nda_transform_no_resize(config, split='train', dataset_name="cifar10_dvs"):
    """
    Return only the augmentation stage (no resize or normalization).

    :param config: Configuration object (Namespace or dict)
    :param split: Data split ('train', 'validation', 'test')
    :param dataset_name: Dataset identifier for optimization
    :return: An NDATransform for the train split with ``data_aug.enabled``;
        otherwise an empty (identity) ``transforms.Compose``
    """
    if split == 'train':
        aug_cfg = safe_config_get(config, 'data_aug', {})
        if isinstance(aug_cfg, dict):
            aug_enabled = aug_cfg.get('enabled', False)
        else:
            aug_enabled = getattr(aug_cfg, 'enabled', False)
        if aug_enabled:
            return NDATransform(aug_cfg, dataset_name)

    return transforms.Compose([])

def get_nda_cutmix_config(config):
    """
    Read ``data_aug.nda_advanced.cutmix`` into a plain dict.

    Missing sections at any level fall back to defaults:
    enabled=False, alpha=1.0, prob=0.4.

    :param config: Configuration object (Namespace or dict)
    :return: Dict with keys 'enabled', 'alpha', 'prob'
    """
    def _read(node, key, default):
        # Nested sections: dicts via .get, anything else via attribute access.
        if isinstance(node, dict):
            return node.get(key, default)
        return getattr(node, key, default)

    data_aug = safe_config_get(config, 'data_aug', {})
    cutmix_config = _read(_read(data_aug, 'nda_advanced', {}), 'cutmix', {})

    return {
        'enabled': _read(cutmix_config, 'enabled', False),
        'alpha': _read(cutmix_config, 'alpha', 1.0),
        'prob': _read(cutmix_config, 'prob', 0.4),
    }

def get_transform(split='train', input_size=128, cfg=None, dataset_name="cifar10_dvs"):
    """
    Legacy transform interface for backward compatibility.

    Delegates to ``get_nda_transform``. Without ``cfg``, augmentation is
    disabled, so only resize and normalization are applied.

    :param split: Data split
    :param input_size: Target input size
    :param cfg: Configuration object
    :param dataset_name: Dataset identifier
    :return: Transform pipeline
    """
    effective_cfg = {'data_aug': {'enabled': False}} if cfg is None else cfg
    return get_nda_transform(effective_cfg, split, input_size, dataset_name)

def apply_nda_batch_augmentation(batch, labels, config, criterion=None):
    """
    Apply batch-level neuromorphic data augmentation including CutMix.

    CutMix runs only when it is enabled in config and a uniform draw falls
    below its ``prob``. When it runs, ``batch`` is modified in place by
    ``nda_cutmix_data``.

    :param batch: Input batch tensor
    :param labels: Target labels
    :param config: Configuration object
    :param criterion: Loss function for mixed criterion computation
    :return: ``(batch, labels)`` when CutMix is not applied;
        ``(mixed_batch, mixed_criterion)`` when it is and ``criterion`` was
        given, where ``mixed_criterion(pred)`` returns the mixed loss;
        otherwise ``(mixed_batch, (target_a, target_b, lam))``
    """
    cutmix_config = get_nda_cutmix_config(config)
    if not (cutmix_config['enabled'] and random.random() < cutmix_config['prob']):
        return batch, labels

    mixed_batch, target_a, target_b, lam = nda_cutmix_data(
        batch, labels, alpha=cutmix_config['alpha']
    )

    if criterion is None:
        return mixed_batch, (target_a, target_b, lam)

    def mixed_criterion(pred):
        return nda_mixup_criterion(criterion, pred, target_a, target_b, lam)

    return mixed_batch, mixed_criterion
