import os
import torch
import numpy as np
import scipy.io as sio
from training.utils import UnitGaussianNormalizer, ScaledGaussianNormalizer, ScaledGaussianNormalizer2
from torch.utils.data import Dataset as TorchDataset

class Dataset(TorchDataset):
    def __init__(self, gain=1.0, use_labels=True, pde_direction='forward',
                 normalizer_path=None, normalizer="UnitGaussian",
                 training_mode="conditional", masking_strategy=None):
        """Initialise shared normalization and masking state for PDE datasets.

        Args:
            gain: Multiplier applied to normalized data (see norm_input/norm_output).
            use_labels: Accepted for interface compatibility; not used by this class.
            pde_direction: 'forward' maps a -> u; any other value is treated as
                the inverse direction (u -> a).
            normalizer_path: Directory holding (or receiving) ``normalizers.npy``;
                a falsy value disables persistence.
            normalizer: "UnitGaussian", "ScaledGaussian", or anything else for
                ScaledGaussianNormalizer2.
            training_mode: "conditional" or "unified".
            masking_strategy: Optional dict of masking/curriculum options; missing
                keys fall back to the defaults below.
        """
        # Common attributes
        self.gain = gain
        self.training_mode = training_mode  # "conditional" or "unified"
        if normalizer_path:
            self.normalizer_path = os.path.join(normalizer_path, "normalizers.npy")
        else:
            self.normalizer_path = None
        self.pde_direction = pde_direction
        self.pde_type = None
        self.normalizer = normalizer
        # Canonical per-field normalizers, filled in by normalize()/load_normalizers().
        # Initialised here so save_normalizers()/get_normalizers() can compare them
        # against None before anything has been fitted.
        self.a_normalizer = None
        self.u_normalizer = None
        # Direction-dependent aliases used by norm_input()/norm_output().
        self.input_normalizer = None
        self.output_normalizer = None
        self.training = True  # Default to training mode

        # Masking strategy parameters
        self.masking_strategy = masking_strategy or {}
        self.random_sample_masking = self.masking_strategy.get('random_sample_masking', False)
        self.final_mask_sample_rate = self.masking_strategy.get('final_mask_sample_rate', 0.0)
        self.final_mask_observation_rate = self.masking_strategy.get('final_mask_observation_rate', 1.0)
        self.enable_sparsity_curriculum = self.masking_strategy.get('enable_sparsity_curriculum', True)
        self.enable_sample_curriculum = self.masking_strategy.get('enable_sample_curriculum', False)

        # Sparsity curriculum (per-point observation rate); defaults to no annealing.
        sparsity_curriculum = self.masking_strategy.get('sparsity_curriculum', {})
        self.initial_obs_rate = sparsity_curriculum.get('initial_obs_rate', self.final_mask_observation_rate)
        self.sparsity_curriculum_kimg = sparsity_curriculum.get('sparsity_curriculum_kimg', 100)
        self.sparsity_schedule = sparsity_curriculum.get('sparsity_schedule', 'cosine')

        # Sample masking curriculum (fraction of samples that get masked).
        sample_curriculum = self.masking_strategy.get('sample_curriculum', {})
        self.initial_sample_rate = sample_curriculum.get('initial_sample_rate', self.final_mask_sample_rate)
        self.sample_curriculum_kimg = sample_curriculum.get('sample_curriculum_kimg', 50)
        self.sample_schedule = sample_curriculum.get('sample_schedule', 'linear')

        # Fill strategy for masked points
        self.fill_strategy = self.masking_strategy.get('fill_strategy', 'mean')
        self.noise_scale = self.masking_strategy.get('noise_scale', 0.01)

        # Current curriculum state
        self.current_kimg = 0
        self.current_obs_rate = self.initial_obs_rate
        self.current_sample_rate = self.initial_sample_rate
        

    
    @property
    def name(self):
        """Dataset name; concrete subclasses must override this property."""
        message = "Subclasses should implement this property."
        raise NotImplementedError(message)
    
    @property
    def num_channels(self):
        """Channel count of a sample; concrete subclasses must override this."""
        message = "Subclasses should implement this property."
        raise NotImplementedError(message)

    @property
    def resolution(self):
        """Spatial resolution of a sample; concrete subclasses must override this."""
        message = "Subclasses should implement this property."
        raise NotImplementedError(message)

    @property
    def label_dim(self):
        """Label dimensionality, read from ``self.y_dim``.

        The base class never assigns ``y_dim``; a subclass is expected to set it,
        otherwise accessing this property raises AttributeError.
        """
        dim = self.y_dim
        return dim
    
    @property
    def pde_loss_function(self):
        """PDE residual/loss callable; concrete subclasses must override this."""
        message = "Subclasses should implement this property."
        raise NotImplementedError(message)

    #### old code for conditional version only
    def old_normalize(self, input_data: torch.Tensor, output_data: torch.Tensor, gain: float,
                check_valid: bool = False):
        """Legacy (conditional-only) normalization of input/output data.

        Kept for backward compatibility; ``normalize`` is the canonical path.
        If ``self.normalizer_path`` exists the normalizers are loaded from it,
        otherwise they are fitted on the given tensors and (when a path is set)
        saved.

        Args:
            input_data: Model input tensor (``a`` for forward problems, ``u`` otherwise).
            output_data: Model target tensor (``u`` for forward problems, ``a`` otherwise).
            gain: Scale applied after encoding; stored on ``self.gain``.
            check_valid: If True, assert the normalize/denormalize round trip (atol=1e-4).

        Returns:
            ``(input_normed, output_normed)``.
        """
        self.gain = gain
        if self.normalizer_path and os.path.exists(self.normalizer_path):
            # Load normalization parameters if path exists
            self.load_normalizers(self.normalizer_path)
        else:
            if self.normalizer == "UnitGaussian":
                print("Using UnitGaussianNormalizer")
                normalizer_cls = UnitGaussianNormalizer
            elif self.normalizer == "ScaledGaussian":
                normalizer_cls = ScaledGaussianNormalizer
            else:
                normalizer_cls = ScaledGaussianNormalizer2
            self.input_normalizer = normalizer_cls(input_data)
            self.output_normalizer = normalizer_cls(output_data)

            # save_normalizers() persists the canonical a/u normalizers, so mirror
            # input/output onto a/u by PDE direction; otherwise the save below
            # would find them unset and write nothing.
            if self.pde_direction == 'forward':
                self.a_normalizer, self.u_normalizer = self.input_normalizer, self.output_normalizer
            else:
                self.a_normalizer, self.u_normalizer = self.output_normalizer, self.input_normalizer

            if self.normalizer_path:
                self.save_normalizers(self.normalizer_path)

        input_normed = self.norm_input(input_data)
        output_normed = self.norm_output(output_data)

        if check_valid:
            # Round-trip check: decode(encode(x)) must reproduce x.
            input_reconstructed = self.denorm_input(input_normed)
            output_reconstructed = self.denorm_output(output_normed)
            assert torch.isclose(input_reconstructed, input_data, atol=1e-4).all(), "Input normalization round-trip failed"
            assert torch.isclose(output_reconstructed, output_data, atol=1e-4).all(), "Output normalization round-trip failed"

        return input_normed, output_normed

    def normalize(self, a_data: torch.Tensor, u_data: torch.Tensor, gain: float,
                check_valid: bool = False):
        """
        Build (or load) the canonical ``a``/``u`` normalizers and return normalized data.

        If ``self.normalizer_path`` points to an existing file the normalizers are
        loaded from it; otherwise they are fitted on ``a_data``/``u_data`` with the
        class selected by ``self.normalizer`` and, when a path is configured, saved.

        Args:
            a_data: Tensor of ``a`` samples (forward-problem input).
            u_data: Tensor of ``u`` samples (forward-problem output).
            gain: Scale applied after encoding; stored on ``self.gain``.
            check_valid: If True, assert that denormalizing recovers the data (atol=1e-4).

        Returns:
            * ``training_mode == 'conditional'``: ``(input_normed, output_normed)``
              ordered by ``self.pde_direction`` ('forward': a -> u, otherwise u -> a).
            * ``training_mode == 'unified'``: one tensor with a and u stacked on dim 1.

        Raises:
            ValueError: If ``self.training_mode`` is neither 'conditional' nor
                'unified' (this previously returned None silently).
        """
        self.gain = gain
        print("normalizer path: ", self.normalizer_path)
        if self.normalizer_path and os.path.exists(self.normalizer_path):
            self.load_normalizers(self.normalizer_path)
            print("loading normalisers")
        else:
            print("Creating new normalisers")
            # Fit canonical normalizers on the full dataset statistics.
            if self.normalizer == "UnitGaussian":
                normalizer_cls = UnitGaussianNormalizer
            elif self.normalizer == "ScaledGaussian":
                normalizer_cls = ScaledGaussianNormalizer
            else:
                normalizer_cls = ScaledGaussianNormalizer2
            self.a_normalizer = normalizer_cls(a_data)
            self.u_normalizer = normalizer_cls(u_data)

            if self.normalizer_path:
                self.save_normalizers(self.normalizer_path)

        # input_normalizer / output_normalizer aliases used by norm_input/norm_output.
        self._set_conditional_normalizer_aliases()

        if self.training_mode == 'conditional':
            if self.pde_direction == 'forward':
                source, target = a_data, u_data
            else:  # inverse
                source, target = u_data, a_data
            input_normed = self.norm_input(source)
            output_normed = self.norm_output(target)

            print(f"  [normalize conditional] Input normed shape: {input_normed.shape}, Output normed shape: {output_normed.shape}")
            if check_valid:
                assert torch.isclose(self.denorm_input(input_normed), source, atol=1e-4).all()
                assert torch.isclose(self.denorm_output(output_normed), target, atol=1e-4).all()
            return input_normed, output_normed

        if self.training_mode == 'unified':
            print(f"  [normalize unified] Stacking a_data ({a_data.shape}) and u_data ({u_data.shape})")
            au_stacked = torch.stack([a_data, u_data], dim=1)
            au_normed = self.norm_tensor(au_stacked)
            print(f"  [normalize unified] Joint tensor normed shape: {au_normed.shape}")
            if check_valid:
                denormed_tensor = self.denorm_tensor(au_normed)
                print(f"  [normalize unified check] Denormed tensor shape: {denormed_tensor.shape}")
                assert torch.isclose(denormed_tensor, au_stacked, atol=1e-4).all()
            return au_normed

        raise ValueError(
            f"Unknown training_mode {self.training_mode!r}; expected 'conditional' or 'unified'."
        )

    def _set_conditional_normalizer_aliases(self):
        """Point input/output normalizer aliases at the canonical a/u normalizers.

        Only acts in 'conditional' training mode: 'forward' maps a -> input and
        u -> output; any other direction swaps them. Other modes are left untouched.
        """
        if self.training_mode != 'conditional':
            return
        is_forward = self.pde_direction == 'forward'
        self.input_normalizer = self.a_normalizer if is_forward else self.u_normalizer
        self.output_normalizer = self.u_normalizer if is_forward else self.a_normalizer

    def norm_tensor(self, tensor):
        """Normalize a two-channel (a, u) tensor stacked on dim 1.

        Channel 0 is encoded with ``a_normalizer`` and channel 1 with
        ``u_normalizer``; both are then scaled by ``self.gain``.
        """
        assert tensor.shape[1] == 2, "Input tensor must have 2 channels (a, u)"
        a_part, u_part = tensor[:, 0], tensor[:, 1]
        channels = [
            self.a_normalizer.encode(a_part) * self.gain,
            self.u_normalizer.encode(u_part) * self.gain,
        ]
        return torch.stack(channels, dim=1)
    
    def denorm_tensor(self, normed_tensor, device="cpu"):
        """Invert ``norm_tensor`` for a two-channel (a, u) tensor stacked on dim 1.

        The tensor is moved to the CPU only when ``device == "cpu"``; any other
        value leaves it where it is. Each channel is divided by ``self.gain`` and
        decoded with its own normalizer (0 -> a, 1 -> u).
        """
        assert normed_tensor.shape[1] == 2, "Input tensor must have 2 channels (a, u)"
        source = normed_tensor.to(device) if device == "cpu" else normed_tensor
        decoded_a = self.a_normalizer.decode(source[:, 0] / self.gain)
        decoded_u = self.u_normalizer.decode(source[:, 1] / self.gain)
        return torch.stack((decoded_a, decoded_u), dim=1)

    def norm_input(self, input_data):
        """Encode ``input_data`` with the input normalizer and scale by ``self.gain``."""
        encoded = self.input_normalizer.encode(input_data)
        scaled = encoded * self.gain
        return scaled

    def denorm_input(self, input_normed):
        """Undo ``norm_input``: move to CPU, divide out ``self.gain``, then decode."""
        on_cpu = input_normed.to("cpu")
        unscaled = on_cpu / self.gain
        return self.input_normalizer.decode(unscaled)

    def norm_output(self, output_data):
        """Encode ``output_data`` with the output normalizer and scale by ``self.gain``."""
        encoded = self.output_normalizer.encode(output_data)
        scaled = encoded * self.gain
        return scaled

    def denorm_output(self, output_normed):
        """Undo ``norm_output``: move to CPU, divide out ``self.gain``, then decode."""
        on_cpu = output_normed.to("cpu")
        unscaled = on_cpu / self.gain
        return self.output_normalizer.decode(unscaled)

    def update_normalizer_path(self, path):
        """Redirect normalizer persistence to ``<path>/normalizers.npy`` and save there."""
        target_file = os.path.join(path, "normalizers.npy")
        self.normalizer_path = target_file
        self.save_normalizers(target_file)

    def save_normalizers(self, path):
        """Save canonical normalization parameters to ``path`` via ``torch.save``.

        The file holds a dict with keys 'a_mean', 'a_std', 'u_mean', 'u_std'.
        If either canonical normalizer is missing (never fitted/loaded, or the
        attribute was never created), a warning is printed and nothing is written.
        """
        # getattr: instances built without the base __init__ (or before any
        # normalizer exists) must hit the warning path, not AttributeError.
        a_norm = getattr(self, 'a_normalizer', None)
        u_norm = getattr(self, 'u_normalizer', None)
        if a_norm is None or u_norm is None:
            print("Warning: Normalizers not initialized. Cannot save.")
            return

        save_data = {
            'a_mean': a_norm.mean,
            'a_std': a_norm.std,
            'u_mean': u_norm.mean,
            'u_std': u_norm.std,
        }
        torch.save(save_data, path)
        print(f"Normalization parameters saved to {path} with PDE direction '{self.pde_direction}'")
    
    def old_get_normalizers(self):
        """Legacy: return mean/std of the input/output normalizers keyed as a/u.

        For 'forward' the input normalizer is reported as ``a`` and the output
        as ``u``; for any other direction the labels are swapped. Input entries
        always come first in the returned dict.
        """
        if self.pde_direction == 'forward':
            keys = ('a_mean', 'a_std', 'u_mean', 'u_std')
        else:  # inverse
            keys = ('u_mean', 'u_std', 'a_mean', 'a_std')
        values = (
            self.input_normalizer.mean,
            self.input_normalizer.std,
            self.output_normalizer.mean,
            self.output_normalizer.std,
        )
        return dict(zip(keys, values))
    
    def get_normalizers(self):
        """Return canonical normalization parameters, or None if not available.

        Returns:
            dict with keys 'a_mean', 'a_std', 'u_mean', 'u_std', or None when
            either canonical normalizer is missing (unset or never created).
        """
        # getattr avoids AttributeError on instances where the canonical
        # normalizers were never assigned.
        a_norm = getattr(self, 'a_normalizer', None)
        u_norm = getattr(self, 'u_normalizer', None)
        if a_norm is None or u_norm is None:
            return None
        return {
            'a_mean': a_norm.mean,
            'a_std': a_norm.std,
            'u_mean': u_norm.mean,
            'u_std': u_norm.std,
        }
    
    def load_normalizers(self, path):
        """Load canonical normalization parameters from ``path`` and set up aliases.

        The checkpoint must contain 'a_mean', 'a_std', 'u_mean', 'u_std' (the
        format written by ``save_normalizers``). The normalizer class is chosen
        by ``self.normalizer``; in conditional mode the input/output aliases are
        refreshed afterwards.
        """
        # map_location="cpu": denorm_input/denorm_output always move data to the
        # CPU before decoding, so the statistics must live there too (this also
        # lets GPU-saved checkpoints load on CPU-only machines).
        checkpoint = torch.load(path, map_location="cpu")

        if self.normalizer == "UnitGaussian":
            normalizer_cls = UnitGaussianNormalizer
        elif self.normalizer == "ScaledGaussian":
            normalizer_cls = ScaledGaussianNormalizer
        else:
            normalizer_cls = ScaledGaussianNormalizer2

        def _restore(mean_key, std_key):
            # Build on a dummy tensor, then overwrite the fitted statistics.
            norm = normalizer_cls(torch.zeros(1))
            norm.mean = checkpoint[mean_key]
            norm.std = checkpoint[std_key]
            return norm

        self.a_normalizer = _restore('a_mean', 'a_std')
        self.u_normalizer = _restore('u_mean', 'u_std')

        # Set aliases for backward compatibility
        self._set_conditional_normalizer_aliases()

        print(f"Canonical normalization parameters (a, u) loaded from {path}")

    def old_load_normalizers(self, path):
        """Legacy loader that restores the input/output normalizers directly.

        Checkpoint entries are mapped by ``self.pde_direction``: 'forward' uses
        a_* for the input and u_* for the output; any other direction swaps
        them. ``a_normalizer`` / ``u_normalizer`` are not touched.
        """
        checkpoint = torch.load(path)

        def build_placeholder():
            # Instantiate the configured normalizer on a dummy tensor; its
            # mean/std are overwritten from the checkpoint right after.
            if self.normalizer == "UnitGaussian":
                return UnitGaussianNormalizer(torch.zeros(1))
            if self.normalizer == "ScaledGaussian":
                return ScaledGaussianNormalizer(torch.zeros(1))
            return ScaledGaussianNormalizer2(torch.zeros(1))

        if self.pde_direction == 'forward':
            input_keys, output_keys = ('a_mean', 'a_std'), ('u_mean', 'u_std')
        else:  # inverse
            input_keys, output_keys = ('u_mean', 'u_std'), ('a_mean', 'a_std')

        self.input_normalizer = build_placeholder()
        self.input_normalizer.mean = checkpoint[input_keys[0]]
        self.input_normalizer.std = checkpoint[input_keys[1]]

        self.output_normalizer = build_placeholder()
        self.output_normalizer.mean = checkpoint[output_keys[0]]
        self.output_normalizer.std = checkpoint[output_keys[1]]

        print(f"Normalization parameters loaded from {path} with PDE direction '{self.pde_direction}'")
    
    def __getitem__(self, idx):
        """Return sample ``idx``; concrete subclasses must override this."""
        message = "Subclasses should implement this method."
        raise NotImplementedError(message)

    def generate_mask(self, spatial_shape, observation_rate=None):
        """Generate a random binary observation mask over the last two dims of ``spatial_shape``.

        Args:
            spatial_shape: Shape whose last two entries are ``(H, W)``; any leading
                entries (e.g. a channel dim) are ignored.
            observation_rate: Fraction of points to keep in training mode. When None,
                ``current_obs_rate`` is used if the sparsity curriculum is enabled,
                else ``final_mask_observation_rate``. In evaluation mode this is
                ignored and ``final_mask_observation_rate`` is always used.

        Returns:
            Float tensor of shape ``(H, W)``: 1 = observed, 0 = masked.
        """
        H, W = tuple(spatial_shape)[-2:]

        if not self.training:
            # Evaluation always uses the final target rate.
            rate = self.final_mask_observation_rate
        elif observation_rate is not None:
            rate = observation_rate
        else:
            rate = self.current_obs_rate if self.enable_sparsity_curriculum else self.final_mask_observation_rate

        if rate >= 1.0:
            # Full observation
            return torch.ones(H, W)

        # NOTE: the seed depends only on current_kimg, so every call at the same
        # kimg (including every sample of a batch) draws the same mask.
        seed = int(self.current_kimg * 10000)
        generator = torch.Generator(device='cpu').manual_seed(seed)
        mask = torch.rand(H, W, generator=generator) < rate
        return mask.float()

    def apply_mask_to_input(self, input_data, mask, fill_strategy="mean", noise_scale=0.01):
        """Apply an observation mask to ``input_data`` using a configurable fill strategy.

        Args:
            input_data: Tensor whose last two dims are spatial.
            mask: Observation mask (1 = observed, 0 = masked); bool or numeric. A
                2-D mask is broadcast over a leading channel dim of a 3-D input.
            fill_strategy: "zero", "mean" (mean of observed points per image),
                "noise" (N(0, noise_scale)), "zero_noise" (N(0, noise_scale/10));
                any other value falls back to zero filling.
            noise_scale: Standard deviation used by the noise strategies.

        Returns:
            Tensor shaped like ``input_data`` with masked points filled.
        """
        # Bring the mask onto the input's device and into a numeric dtype so the
        # arithmetic below works for bool masks and CPU masks on GPU inputs.
        mask_dtype = input_data.dtype if input_data.is_floating_point() else torch.get_default_dtype()
        mask = mask.to(device=input_data.device, dtype=mask_dtype)
        if input_data.dim() == 3 and mask.dim() == 2:  # Add channel dimension to mask
            mask = mask.unsqueeze(0)
        elif input_data.dim() == 2 and mask.dim() == 3:  # Remove channel dimension from mask
            mask = mask.squeeze(0)

        if fill_strategy == "mean":
            # Mean over observed points only (mask == 1), per image.
            observed_sum = (input_data * mask).sum(dim=(-2, -1), keepdim=True)
            observed_count = mask.sum(dim=(-2, -1), keepdim=True)
            # Guard the denominator so the unused where-branch never produces NaN.
            safe_count = torch.where(observed_count > 0, observed_count, torch.ones_like(observed_count))
            mean_value = torch.where(observed_count > 0,
                                     observed_sum / safe_count,
                                     input_data.mean())
            return input_data * mask + mean_value * (1 - mask)
        if fill_strategy == "noise":
            # Fill masked regions with small noise
            noise = torch.randn_like(input_data) * noise_scale
            return input_data * mask + noise * (1 - mask)
        if fill_strategy == "zero_noise":
            # Zeros plus 10x smaller noise, keeping masked values close to zero.
            small_noise = torch.randn_like(input_data) * (noise_scale * 0.1)
            return input_data * mask + small_noise * (1 - mask)
        # "zero" and any unknown strategy: fill masked regions with zeros.
        return input_data * mask

    def set_training(self, is_training):
        """Switch between training (True) and evaluation (False) behaviour.

        The flag is stored on ``self.training`` and consulted by
        ``generate_mask`` and ``apply_masking``.
        """
        self.training = is_training
    
    def set_process_seed(self, seed, rank=None):
        """Remember the training loop's process seed and (optional) rank.

        Stored as ``self.process_seed`` / ``self.process_rank``; nothing in this
        class currently reads them.
        """
        self.process_seed, self.process_rank = seed, rank

    def update_curriculum(self, kimg):
        """Advance the masking curricula to training progress ``kimg``.

        Each enabled curriculum anneals from its initial rate (progress 0) to its
        final rate (progress 1, reached at ``*_curriculum_kimg``) using the
        configured schedule: 'linear', 'cosine', or 'exponential' (power 2.5
        for sparsity, 2 for samples). An unknown schedule name holds the final
        rate for both curricula.

        Args:
            kimg: Training progress in thousands of images.

        Returns:
            dict with 'sample_rate', 'obs_rate' and 'kimg'. When both curricula are
            disabled, the rates are the fixed placeholders 0.0 / 1.1.
        """
        # Record progress unconditionally: generate_mask() and apply_masking() seed
        # their RNGs from current_kimg, so a stale value would freeze the masks.
        self.current_kimg = kimg

        if not self.enable_sample_curriculum and not self.enable_sparsity_curriculum:
            return {
                'sample_rate': 0.0,
                'obs_rate': 1.1,
                'kimg': self.current_kimg
            }

        def _anneal(initial, final, progress, schedule, power):
            # Interpolate from ``initial`` (progress 0) to ``final`` (progress 1).
            if schedule == 'linear':
                return initial + (final - initial) * progress
            if schedule == 'cosine':
                return final + (initial - final) * 0.5 * (1 + np.cos(np.pi * progress))
            if schedule == 'exponential':
                return final + (initial - final) * (1 - progress) ** power
            # Unknown schedule: hold the final (target) rate.
            return final

        # Observation rate (sparsity curriculum)
        if self.enable_sparsity_curriculum and self.sparsity_curriculum_kimg > 0:
            progress = min(1.0, kimg / self.sparsity_curriculum_kimg)
            self.current_obs_rate = _anneal(self.initial_obs_rate, self.final_mask_observation_rate,
                                            progress, self.sparsity_schedule, 2.5)

        # Sample masking rate (sample curriculum)
        if self.enable_sample_curriculum and self.sample_curriculum_kimg > 0:
            progress = min(1.0, kimg / self.sample_curriculum_kimg)
            self.current_sample_rate = _anneal(self.initial_sample_rate, self.final_mask_sample_rate,
                                               progress, self.sample_schedule, 2)

        return {
            'sample_rate': self.current_sample_rate,
            'obs_rate': self.current_obs_rate,
            'kimg': self.current_kimg
        }

    def apply_masking(self, input_data, kimg=None, return_masks=True):
        """Apply masking to input data based on current curriculum state

        Masking only happens in training mode with ``random_sample_masking``
        enabled; otherwise every sample is returned unchanged with an all-ones mask.

        Args:
            input_data (torch.Tensor): Tensor to be masked. A 2-D tensor is treated
                as a single ``(H, W)`` image; otherwise dim 0 is the batch dim and
                masks are drawn over the last two dims of each sample.
            kimg (int, optional): Current training progress in kimg. Updates curriculum if provided.
            return_masks (bool): Whether to return masks separately

        Returns:
            tuple: (masked_inputs, masks) if return_masks=True, otherwise just masked_inputs.
            ``masks`` has the same shape as ``masked_inputs`` (1 = observed).
        """
        # Resolve the effective rates once; they do not change inside the loop.
        sample_rate = obs_rate = None
        if self.random_sample_masking:
            if kimg is not None and kimg != self.current_kimg:
                curriculum = self.update_curriculum(kimg)
            else:
                curriculum = {
                    'sample_rate': self.current_sample_rate,
                    'obs_rate': self.current_obs_rate,
                }
            sample_rate = curriculum['sample_rate'] if self.enable_sample_curriculum else self.final_mask_sample_rate
            obs_rate = curriculum['obs_rate'] if self.enable_sparsity_curriculum else self.final_mask_observation_rate

        # Handle batch or single input
        is_batch = input_data.dim() != 2
        batch = input_data if is_batch else input_data.unsqueeze(0)

        masks = []
        masked_inputs = []
        for i, sample in enumerate(batch):
            apply_mask = False
            if self.training and self.random_sample_masking:
                # Per-sample coin flip, seeded from progress + index for reproducibility.
                sample_seed = int(self.current_kimg * 10000 + i)
                sample_gen = torch.Generator(device='cpu').manual_seed(sample_seed)
                apply_mask = torch.rand(1, generator=sample_gen).item() < sample_rate

            if apply_mask and self.final_mask_observation_rate < 1.0:
                mask = self.generate_mask(sample.shape[-2:], observation_rate=obs_rate)
                # Match the sample's device/dtype/shape so masks stack cleanly
                # with the all-ones masks of unmasked samples.
                mask = mask.to(device=sample.device, dtype=sample.dtype).expand_as(sample)
                masked_input = self.apply_mask_to_input(sample, mask, self.fill_strategy, self.noise_scale)
            else:
                # Full observation (no masking)
                mask = torch.ones_like(sample)
                masked_input = sample

            masks.append(mask)
            masked_inputs.append(masked_input)

        if masks:
            masks = torch.stack(masks)
            masked_inputs = torch.stack(masked_inputs)
        else:
            # Empty batch: nothing to mask (torch.stack would fail on []).
            masks = torch.ones_like(batch)
            masked_inputs = batch

        # Return single image if input was single
        if not is_batch:
            masks = masks[0]
            masked_inputs = masked_inputs[0]

        if return_masks:
            return masked_inputs, masks
        return masked_inputs