import torch
import os
import numpy as np
import scipy.io as sio
from training.datasets.dataset import Dataset
from training.evaluation_utils import get_helmholtz_loss_diffusion_pde

class DiffusionPDEHelmholtzDataset(Dataset):
    """Helmholtz dataset in DiffusionPDE format, loaded from ``.mat`` files.

    Each ``.mat`` file provides ``f_data`` (stored as ``self.a``) and
    ``psi_data`` (stored as ``self.u``). Two training modes are handled here:

    * ``'conditional'``: the base-class ``normalize`` returns an
      ``(input, output)`` pair and ``__getitem__`` yields
      ``(output, input, mask)``.
    * ``'unified'``: ``normalize`` returns a single joint tensor and
      ``__getitem__`` yields ``(joint, zeros, mask)``. While ``self.training``
      is truthy, a 2-channel conditioning mask is sampled per item according
      to ``unified_regime``.
    """

    def __init__(self, path, resolution=None, pde_direction='forward', use_labels=True,
                 label_shape=None, normalizer_path=None, offset=0, num=None,
                 train_downsample=False, normalizer="UnitGaussian", training_mode="conditional",
                 use_sparse_conditioning=False, masking_strategy=None, unified_regime=None, **kwargs):
        """Load, subsample and normalize Helmholtz data.

        Args:
            path (str): A single ``.mat`` file, or a directory whose ``.mat``
                files are all loaded and concatenated along dim 0. Every file
                must contain ``f_data`` and ``psi_data``.
            resolution (int, optional): Target spatial size used when
                ``train_downsample`` is True. It is also stored as ``self.res``
                and returned by the ``resolution`` property.
            pde_direction (str): ``'forward'`` or ``'inverse'``. Forwarded to
                the base class.
            use_labels (bool): Forwarded to the base class.
            label_shape (tuple, optional): Overrides the label shape inferred
                from the normalized data.
            normalizer_path (str, optional): Forwarded to the base class.
            offset (int): First sample kept when ``num`` is given.
            num (int, optional): If given, keep only samples
                ``[offset, offset + num)``.
            train_downsample (bool): If True, subsample the data to ``resolution``.
            normalizer (str): Forwarded to the base class.
            training_mode (str): ``'conditional'`` or ``'unified'``.
            use_sparse_conditioning (bool): Stored; read by ``has_masks``.
            masking_strategy: Forwarded to the base class.
            unified_regime (dict, optional): Used only in unified mode.
                Optional keys are ``'task_probs'`` (task name -> weight) and
                ``'sparse_obs_rates'`` (sparse task name -> ``[min, max]``).
                ``None`` means "use the defaults".
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If the target resolution is not in ``(0, 128]``, or if
                the unified task weights do not sum to a positive value.
            FileNotFoundError: If ``path`` is a directory without ``.mat`` files.
        """
        super().__init__(gain=1.0, use_labels=use_labels, pde_direction=pde_direction,
                         normalizer_path=normalizer_path, normalizer=normalizer, training_mode=training_mode,
                         masking_strategy=masking_strategy)

        self.use_sparse_conditioning = use_sparse_conditioning
        self.training_mode = training_mode
        if self.training_mode == 'unified':
            self._configure_unified_regime(unified_regime)

        print("--- Initializing DiffusionPDEHelmholtzDataset ---")
        print(f"  Training mode: {self.training_mode}, PDE direction: {self.pde_direction}")

        self.a, self.u = self._load_raw(path)

        # The raw fields are assumed (not checked) to be 128x128.
        original_size = 128
        target_size = (resolution or original_size) if train_downsample else original_size
        if not 0 < target_size <= original_size:
            # Without this check, sub would be 0 (slice-step error) or negative (silent flip).
            raise ValueError(
                f"resolution must be in (0, {original_size}] when train_downsample=True, got {resolution}"
            )
        # Integer stride. When target_size does not divide original_size, the
        # crop in _subsample trims the extra rows and columns.
        sub = original_size // target_size

        # Apply optional slicing for an evaluation subset.
        if num is not None:
            self.a = self.a[offset:offset + num]
            self.u = self.u[offset:offset + num]

        self.a = self._subsample(self.a, sub, target_size)
        self.u = self._subsample(self.u, sub, target_size)
        print(f"  Downsampled data shapes: a={self.a.shape}, u={self.u.shape}")

        input_data = output_data = joint_data = None
        if self.training_mode == 'unified':
            joint_data = self.normalize(self.a, self.u, gain=1.0, check_valid=True)
            print(f"  Unified mode: joint_data shape={joint_data.shape}")
        elif self.training_mode == 'conditional':
            input_data, output_data = self.normalize(self.a, self.u, gain=1.0, check_valid=True)
            print(f"  Conditional mode: input_data shape={input_data.shape}, output_data shape={output_data.shape}")
        self.input_data = input_data
        self.output_data = output_data
        self.joint_data = joint_data

        def _shape_or_na(t):
            return t.shape if t is not None else 'N/A'

        print(f"  After normalization: input_data shape={_shape_or_na(self.input_data)}, "
              f"output_data shape={_shape_or_na(self.output_data)}, "
              f"joint_data shape={_shape_or_na(self.joint_data)}")

        # An explicit label_shape wins. Otherwise infer it from the normalized data.
        if label_shape:
            self.label_shape = label_shape
        elif self.training_mode == 'conditional':
            source = self.output_data if self.output_data is not None else self.input_data
            self.label_shape = source.shape[1:]
        else:  # unified
            self.label_shape = self.joint_data.shape[1:]
        self.res = resolution

    def _configure_unified_regime(self, unified_regime):
        """Set the task-sampling probabilities and sparse observation-rate ranges for unified mode."""
        # The default for unified_regime is None; treat it as "use all defaults"
        # instead of failing on None.get(...).
        unified_regime = unified_regime or {}
        task_probs = unified_regime.get('task_probs', {
            'full_fwd': 0.25, 'full_inv': 0.25, 'sparse_fwd': 0.2, 'sparse_inv': 0.2, 'uncond': 0.1
        })
        total_prob = sum(task_probs.values())
        if total_prob <= 0:
            # Fail here rather than later, inside task sampling in __getitem__.
            raise ValueError(
                f"unified_regime['task_probs'] must have a positive total probability, got {task_probs}"
            )
        # Normalize the probabilities so they sum to 1.
        self.task_probs = {k: v / total_prob for k, v in task_probs.items()}
        self.task_types = list(self.task_probs.keys())
        self.task_prob_values = list(self.task_probs.values())

        # Observation-rate ranges for the sparse tasks.
        self.sparse_obs_rates = unified_regime.get('sparse_obs_rates', {
            'sparse_fwd': [0.03, 0.5],
            'sparse_inv': [0.03, 0.5]
        })

    @staticmethod
    def _read_mat(file_path):
        """Read one ``.mat`` file and return ``(f_data, psi_data)`` as float32 tensors."""
        raw = sio.loadmat(file_path)
        return (torch.tensor(raw["f_data"], dtype=torch.float32),
                torch.tensor(raw["psi_data"], dtype=torch.float32))

    def _load_raw(self, path):
        """Return ``(a, u)`` from a single ``.mat`` file or a directory of them."""
        if os.path.isdir(path):
            return self.load_and_merge_data(path)
        return self._read_mat(path)

    @staticmethod
    def _subsample(x, sub, target_size):
        """Keep every ``sub``-th entry along dims 1 and 2, crop to ``target_size``, and return a float32 copy."""
        # .to(..., copy=True) replaces torch.tensor(tensor), which warns. The
        # copy also stops the result from keeping the full-size storage alive.
        return x[:, ::sub, ::sub][:, :target_size, :target_size].to(dtype=torch.float32, copy=True)

    @property
    def resolution(self):
        """Return the configured resolution, or the spatial height of the loaded data if none was given."""
        # Dim 1 of self.a is a spatial axis; see the [:, ::sub, ::sub] slicing in __init__.
        return self.res if self.res else self.a.shape[1]

    @property
    def name(self):
        return "DiffusionPDEHelmholtzDataset"

    @property
    def print_name(self):
        return "diffusion_pde_helmholtz"

    @property
    def pde_loss_function(self):
        """Physics-loss callable for this PDE."""
        return get_helmholtz_loss_diffusion_pde

    @property
    def num_channels(self):
        return 2 if self.training_mode == "unified" else 1

    @property
    def y_dim(self):
        """How many channels are in y?"""
        return 2 if self.training_mode == "unified" else 1

    @property
    def has_masks(self):
        """Indicates if dataset provides masks"""
        return self.use_sparse_conditioning or self.training_mode == 'unified'

    def __len__(self):
        return len(self.joint_data) if self.training_mode == 'unified' else len(self.input_data)

    def __getitem__(self, idx):
        if self.training_mode == 'unified':
            return self._getitem_unified(idx)
        return self._getitem_conditional(idx)

    def _getitem_conditional(self, idx):
        """Return ``(output, input, mask)`` for sample ``idx``.

        If there is no output data, ``output`` is zeros of ``self.label_shape``.
        ``mask`` is all ones with the shape of ``input``. The diffusion model
        learns the output distribution conditioned on the input.
        """
        input_data = self.input_data[idx]
        output_data = self.output_data[idx] if self.output_data is not None else torch.zeros(self.label_shape)
        mask = torch.ones_like(input_data)
        return output_data, input_data, mask

    def _getitem_unified(self, idx):
        """Return ``(joint, placeholder, mask)`` for the unified inpainting framework.

        ``placeholder`` is a zero tensor shaped like ``joint``, because the
        unified path does not use a separate label. ``mask`` has 2 channels.
        """
        au_data = self.joint_data[idx]
        # NOTE(review): self.training is not set in this class; presumably the
        # base Dataset provides and toggles it. Confirm.
        if self.training:
            mask = self._generate_unified_mask(spatial_shape=au_data.shape[1:])
        else:
            # During evaluation use an all-ones dummy mask; it is handled outside this class.
            mask = torch.ones((2, au_data.shape[1], au_data.shape[2]), dtype=torch.float32)
        return au_data, torch.zeros_like(au_data), mask

    @staticmethod
    def _sample_obs_rate(min_rate, max_rate, alpha=3.0):
        """Draw an observation rate in ``[min_rate, max_rate]`` that is biased toward ``min_rate``.

        With probability 0.5, the rate is uniform over the lowest 10% of the
        range. Otherwise it is ``min_rate + r**alpha * (max_rate - min_rate)``
        with ``r ~ U(0, 1)``, which concentrates mass near ``min_rate``. A
        larger ``alpha`` gives a stronger bias.
        """
        span = max_rate - min_rate
        if torch.rand(1).item() < 0.5:
            return min_rate + torch.rand(1).item() * span * 0.1
        # r**alpha (not 1 - r**alpha): that form would favour max_rate, i.e. low sparsity.
        return min_rate + (torch.rand(1).item() ** alpha) * span

    def _generate_unified_mask(self, spatial_shape):
        """Sample a task and build the matching 2-channel conditioning mask ``[m_a, m_u]``.

        Tasks:

        * ``'full_fwd'`` observes all of ``a``.
        * ``'full_inv'`` observes all of ``u``.
        * ``'sparse_fwd'`` / ``'sparse_inv'`` observe a random subset of
          ``a`` / ``u``, at a rate drawn from ``self.sparse_obs_rates[task]``.
        * Any other task (e.g. ``'uncond'``) observes nothing.

        Returns:
            torch.Tensor: A float32 mask of shape ``(2, H, W)``.
        """
        H, W = spatial_shape
        m_a = torch.zeros((H, W), dtype=torch.float32)
        m_u = torch.zeros((H, W), dtype=torch.float32)

        # Use torch's RNG: DataLoader reseeds it in every worker. NumPy's global
        # state is copied into forked workers, so np.random would make each
        # worker draw the same task sequence.
        probs = torch.tensor(self.task_prob_values, dtype=torch.float64)
        task = self.task_types[torch.multinomial(probs, num_samples=1).item()]

        if task == 'full_fwd':  # Full forward (a -> u)
            m_a = torch.ones((H, W), dtype=torch.float32)
        elif task == 'full_inv':  # Full inverse (u -> a)
            m_u = torch.ones((H, W), dtype=torch.float32)
        elif task in self.sparse_obs_rates:
            min_rate, max_rate = self.sparse_obs_rates[task]
            obs_rate = self._sample_obs_rate(min_rate, max_rate)
            # Random binary mask with the sampled observation rate.
            sparse_mask = (torch.rand(spatial_shape) < obs_rate).float()
            if task == 'sparse_fwd':
                m_a = sparse_mask
            elif task == 'sparse_inv':
                m_u = sparse_mask
        # Otherwise (e.g. 'uncond'), m_a and m_u stay all zeros.

        return torch.stack([m_a, m_u], dim=0)

    def load_and_merge_data(self, dir_path):
        """Load every ``.mat`` file in a directory and concatenate the data along dim 0.

        Files are processed in sorted filename order, so the merged order, and
        therefore any ``offset``/``num`` subset, is deterministic.

        Args:
            dir_path (str): Directory containing ``.mat`` files with
                ``f_data`` and ``psi_data`` entries.

        Returns:
            tuple: Concatenated ``(a, u)`` float32 tensors.

        Raises:
            FileNotFoundError: If the directory contains no ``.mat`` files.
        """
        # os.listdir returns entries in arbitrary order; sort for reproducibility.
        mat_files = sorted(name for name in os.listdir(dir_path) if name.endswith('.mat'))
        if not mat_files:
            raise FileNotFoundError(f"No .mat files found in directory: {dir_path}")

        pairs = [self._read_mat(os.path.join(dir_path, name)) for name in mat_files]
        a_merged = torch.cat([a for a, _ in pairs], dim=0)
        u_merged = torch.cat([u for _, u in pairs], dim=0)
        return a_merged, u_merged
