import torch
from scipy import sparse

# Directory holding the per-batch ``.npz`` matrices of the ``mm_sim`` dataset.
_MM_SIM_DIR = './01_data/mm_sim/'

# File-name prefix of each matrix, keyed by modality and then by stage.
_MM_SIM_FILES = {
    'rna': {
        'noisy': 'observed_transcription',
        'raw': 'potential_transcription',
        'processed': 'real_transcription',
    },
    'atac': {
        'noisy': 'peaks',
        'raw': 'open_chromatin',
        'processed': 'peaks_nonoise',
    },
    'protein': {
        'noisy': 'prot_counts',
        'raw': 'prots_translated',
        'processed': 'prots_real',
    },
}

# Which single modalities (and in which column order) each selector loads.
_MM_SIM_SELECTIONS = {
    'rna': ('rna',),
    'atac': ('atac',),
    'protein': ('protein',),
    'rna-atac': ('rna', 'atac'),
    'rna-protein': ('rna', 'protein'),
    'atac-protein': ('atac', 'protein'),
    'all': ('rna', 'atac', 'protein'),
}


def _mm_sim_path(name, stage, batch):
    """Return the path of the ``.npz`` file for one modality, stage and batch."""
    return f"{_MM_SIM_DIR}{_MM_SIM_FILES[name][stage]}_batch_{batch}.npz"


def _load_dense(path):
    """Load a SciPy sparse ``.npz`` matrix and return it as a dense tensor."""
    return torch.tensor(sparse.load_npz(path).toarray())


def _row_normalise(matrix):
    """Divide every row by its L2 norm, leaving all-zero rows as zeros."""
    norms = torch.norm(matrix.float(), dim=1, keepdim=True)
    # A zero row has norm 0; dividing by it would turn the whole row into NaN.
    safe_norms = torch.where(norms > 0, norms, torch.ones_like(norms))
    return matrix / safe_norms


def _stack_unpaired(parts):
    """Zero-pad each modality into its own column block and stack the rows."""
    widths = [part.shape[1] for part in parts]
    blocks = []
    for k, part in enumerate(parts):
        row = [
            part if j == k else part.new_zeros((part.shape[0], width))
            for j, width in enumerate(widths)
        ]
        blocks.append(torch.cat(row, dim=1))
    return torch.cat(blocks, dim=0)


def load_data(data, modality, paired, n_batches, stage, norm=False):
    """Load batches of the ``mm_sim`` dataset into one dense tensor.

    Args:
        data: Dataset name. Only ``'mm_sim'`` is loaded; for any other value
            nothing is read and ``data`` itself is returned together with two
            empty dicts.
        modality: One of ``'rna'``, ``'atac'``, ``'protein'``, ``'rna-atac'``,
            ``'rna-protein'``, ``'atac-protein'`` or ``'all'``.
        paired: For multi-modality selections, ``True`` concatenates the
            modalities column-wise (one row per sample). ``False`` places each
            modality in its own zero-padded column block and stacks the
            blocks row-wise. Ignored when a single modality is loaded.
        n_batches: Number of batch files (``..._batch_0`` up to
            ``..._batch_{n_batches - 1}``) to read and concatenate row-wise.
        stage: ``'noisy'``, ``'raw'`` or ``'processed'``; selects which file
            family is read for each modality.
        norm: If true, rows of each modality are scaled to unit L2 norm before
            combining. Only applied to multi-modality selections; a single
            modality is returned unnormalised. All-zero rows stay zero.

    Returns:
        A tuple ``(tensor, feature_counts, modality_indices)``.
        ``feature_counts`` maps ``'rna'``/``'atac'``/``'protein'`` to the column
        count of their batch-0 ``noisy`` matrices; it is only filled for
        multi-modality selections or when ``paired`` is false.
        ``modality_indices`` maps each loaded modality to ``(0, n_features)``
        for multi-modality selections and is empty otherwise.
        NOTE(review): the first element is always 0, i.e. it is a width rather
        than a column offset into the concatenated tensor -- confirm what
        callers expect before changing it.

    Raises:
        ValueError: If ``modality`` or ``stage`` is not recognised, or if
            ``n_batches`` is smaller than 1.
    """
    feature_counts = {}
    modality_indices = {}

    if data != 'mm_sim':
        # Other dataset names are passed through untouched.
        return data, feature_counts, modality_indices

    if modality not in _MM_SIM_SELECTIONS:
        raise ValueError(
            f"Unknown modality {modality!r}; expected one of {sorted(_MM_SIM_SELECTIONS)}."
        )
    if stage not in ('noisy', 'raw', 'processed'):
        raise ValueError(
            f"Unknown stage {stage!r}; expected 'noisy', 'raw' or 'processed'."
        )
    if n_batches < 1:
        raise ValueError("No data found for the specified modality and stage.")

    names = _MM_SIM_SELECTIONS[modality]
    multimodal = len(names) > 1

    if multimodal or not paired:
        # Only the column count is needed, so read it from the sparse matrix
        # instead of densifying the whole batch.
        for name in ('rna', 'atac', 'protein'):
            feature_counts[name] = sparse.load_npz(_mm_sim_path(name, 'noisy', 0)).shape[1]

    batches = []
    for i in range(n_batches):
        parts = [_load_dense(_mm_sim_path(name, stage, i)) for name in names]

        if not multimodal:
            batches.append(parts[0])
            continue

        if norm:
            parts = [_row_normalise(part) for part in parts]

        if i == 0:
            for name, part in zip(names, parts):
                modality_indices[name] = (0, part.shape[1])

        if paired:
            batches.append(torch.cat(parts, dim=1))
        else:
            # Every selection (including 'all') yields a tensor here so the
            # final row-wise concatenation below always succeeds.
            batches.append(_stack_unpaired(parts))

    return torch.cat(batches, dim=0), feature_counts, modality_indices

# Dataset classes (torch.utils.data.Dataset subclasses)

class MMSimData(torch.utils.data.Dataset):
    """Dataset over one or more 2-D modality matrices with a zero-row mask.

    Args:
        data: Either a list/tuple of 2-D tensors (one per modality, all with
            the same number of rows) or a single 2-D tensor, which is treated
            as one modality.

    Attributes set on construction:
        data: The object passed in, unchanged.
        n_modalities: Number of modalities (1 for a single tensor).
        n_samples: Row count of the first modality.
        n_features: List of per-modality column counts for list/tuple input,
            or a plain int for single-tensor input.
        total_features: Sum of all column counts.
        mask: Bool tensor of shape ``(n_samples, total_features)``. A
            modality's columns are ``False`` on rows where that modality is
            entirely zero and ``True`` elsewhere.

    Raises:
        ValueError: If ``data`` is an empty list or tuple.
    """

    def __init__(self, data):
        self.data = data
        is_sequence = isinstance(data, (list, tuple))
        # Always work on a list internally so a bare tensor is handled as a
        # single modality instead of being iterated row by row.
        self._modalities = list(data) if is_sequence else [data]
        if not self._modalities:
            raise ValueError("MMSimData requires at least one modality.")

        widths = [mod.shape[1] for mod in self._modalities]
        self.n_modalities = len(self._modalities)
        self.n_samples = self._modalities[0].shape[0]
        self.n_features = widths if is_sequence else widths[0]
        self.total_features = sum(widths)

        # Per modality: one keep-flag per row (False when the row is all
        # zero), broadcast across that modality's columns.
        per_modality_masks = []
        for mod, width in zip(self._modalities, widths):
            keep_row = ~torch.all(mod == 0, dim=1, keepdim=True)
            per_modality_masks.append(keep_row.expand(-1, width))
        self.mask = torch.cat(per_modality_masks, dim=1)

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        """Return ``(list of float rows, one per modality; mask row)``."""
        return [mod[idx].float() for mod in self._modalities], self.mask[idx]


class PairedMultimodalData(torch.utils.data.Dataset):
    """Dataset serving the same sample index from several modalities, unmasked.

    ``data`` is looked up with the entries of ``modality_names`` and each
    entry is indexed along its first axis. The sample count is taken from the
    first name in ``modality_names``; ``n_modalities`` is ``len(data)``.
    """

    def __init__(self, data, modality_names):
        # Debug output emitted on construction.
        print(len(data), modality_names)
        self.data = data
        self.n_modalities = len(data)
        self.modality_names = modality_names
        first_name = modality_names[0]
        self.n_samples = data[first_name].shape[0]

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        """Return sample ``idx`` of every modality as float tensors, in name order."""
        sample = []
        for name in self.modality_names:
            sample.append(self.data[name][idx].float())
        return sample


class PaddedMultimodalDataWithMasks(torch.utils.data.Dataset):
    """
    Dataset class that handles padded multimodal data and generates masks on-the-fly.
    Masks indicate valid (non-padded) positions in the sequence dimension.
    """

    def __init__(self, padded_tensors, sequence_lengths, modality_names):
        """
        Args:
            padded_tensors: Dict of {modality_name: padded_tensor}
            sequence_lengths: Dict of {modality_name: [seq_len_sample0, seq_len_sample1, ...]}.
                Each entry is either a scalar length or a (height, width)
                pair given as a tuple or list.
            modality_names: List of modality names in order
        """
        self.padded_tensors = padded_tensors
        self.sequence_lengths = sequence_lengths
        self.modality_names = modality_names
        self.n_modalities = len(modality_names)
        # The sample count is taken from the first modality's length list.
        self.n_samples = len(next(iter(sequence_lengths.values())))

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        """
        Returns:
            tuple: (data_list, mask_list) where both are lists of tensors per modality.
            Each mask is a bool tensor with the same shape as its sample;
            it is all False when the recorded length is not positive.
        """
        data_list = []
        mask_list = []

        for modality in self.modality_names:
            # Padded data and recorded valid extent for this sample
            data = self.padded_tensors[modality][idx]
            seq_len = self.sequence_lengths[modality][idx]

            mask = torch.zeros(data.shape, dtype=torch.bool)

            if isinstance(seq_len, (tuple, list)):
                # Two-axis extent: mark the leading (height, width) region
                height, width = seq_len
                if height > 0 and width > 0:
                    mask[:height, :width] = True
            elif seq_len > 0:
                # Scalar extent: mark the first seq_len entries of the first
                # axis, whatever the rank of the sample.
                mask[:seq_len] = True

            data_list.append(data.float())
            mask_list.append(mask)

        return data_list, mask_list