
from torch.utils.data import Dataset, DataLoader
import torch
import numpy as np
import time
# TODO: add onsets here.
# Design goal: taking a subset should not copy the data; only references to the underlying arrays are kept.
# TODO: handle the case when there is none in a dataset; it breaks at the moment.
class RegressionNumpyDataset(Dataset):
    """Concatenated view over several (features, targets) array pairs.

    The entry at position ``i`` of ``dataset_regression_features`` is paired
    with the entry at position ``i`` of ``dataset_meg_targets``. Samples are
    addressed through one global index that runs through the datasets in the
    order given. The input containers are stored by reference; data is only
    copied when samples are read.

    Indexing with a single integer returns one ``(feature, target)`` pair.
    Indexing with a ``list`` or ``np.ndarray`` of integers returns a batch
    whose rows follow the order of the requested indices. Both are returned
    as ``torch.float32`` tensors.
    """

    def __init__(self, dataset_regression_features, dataset_meg_targets):
        """Store the per-dataset arrays and build the global index table.

        Args:
            dataset_regression_features: Sequence of arrays, one per dataset,
                indexed along their first axis.
            dataset_meg_targets: Sequence of target arrays, one per dataset,
                each with as many entries as the matching feature array.

        Raises:
            ValueError: If the two sequences differ in length, or if a
                feature array and its target array differ in length.
        """
        if len(dataset_regression_features) != len(dataset_meg_targets):
            raise ValueError(
                f"got {len(dataset_regression_features)} feature arrays but "
                f"{len(dataset_meg_targets)} target arrays"
            )
        # Explicit checks instead of `assert`, which is stripped under -O.
        for dataset_index, (features, targets) in enumerate(
            zip(dataset_regression_features, dataset_meg_targets)
        ):
            if features.shape[0] != len(targets):
                raise ValueError(
                    f"dataset {dataset_index}: {features.shape[0]} feature rows "
                    f"but {len(targets)} targets"
                )

        self.dataset_lens = [features.shape[0] for features in dataset_regression_features]
        # cumulative_indices[i] is the global index of the first sample of
        # dataset i; the final entry equals the total number of samples.
        self.cumulative_indices = np.cumsum([0] + self.dataset_lens)
        self.dataset_regression_features = dataset_regression_features
        self.dataset_meg_targets = dataset_meg_targets
        self.total_len = sum(self.dataset_lens)

    def __len__(self):
        """Return the total number of samples across all datasets."""
        return self.total_len

    def _locate(self, indices):
        """Translate global indices into dataset ids and local indices.

        Args:
            indices: Non-empty 1-D integer ``np.ndarray`` of global indices.
                Negative values count from the end.

        Returns:
            Tuple ``(dataset_ids, local_indices)`` of integer arrays with the
            same length as ``indices``.

        Raises:
            IndexError: If ``indices`` is not of integer dtype or any index
                falls outside ``[-len(self), len(self))``.
        """
        if not np.issubdtype(indices.dtype, np.integer):
            raise IndexError("batch indices must be integers")
        # Signed 64-bit keeps the negative-index arithmetic below well defined
        # even when the caller passed an unsigned array.
        indices = indices.astype(np.int64, copy=False)
        indices = np.where(indices < 0, indices + self.total_len, indices)
        out_of_range = (indices < 0) | (indices >= self.total_len)
        if out_of_range.any():
            raise IndexError(
                f"index out of range for dataset of length {self.total_len}"
            )
        # side="right" skips over empty datasets that share a boundary value.
        dataset_ids = np.searchsorted(self.cumulative_indices[1:], indices, side="right")
        local_indices = indices - self.cumulative_indices[dataset_ids]
        return dataset_ids, local_indices

    def _empty_batch(self):
        """Return a zero-length ``(features, targets)`` batch."""
        if len(self.dataset_regression_features) > 0:
            # Slice the first dataset so the trailing dimensions are kept.
            empty_features = self.dataset_regression_features[0][:0]
            empty_targets = self.dataset_meg_targets[0][:0]
        else:
            empty_features = np.empty((0,))
            empty_targets = np.empty((0,))
        return (
            torch.tensor(empty_features, dtype=torch.float32),
            torch.tensor(empty_targets, dtype=torch.float32),
        )

    def get_batch(self, indices):
        """Fetch several samples at once.

        Indices are grouped per dataset so each underlying array is read with
        a single fancy-indexing call; the rows are then put back into the
        order in which they were requested.

        Args:
            indices: List or array of integer global indices. Multi-dimensional
                input is flattened. Negative values count from the end.

        Returns:
            Tuple ``(features, targets)`` of ``torch.float32`` tensors whose
            first axis follows the order of ``indices``. An empty request
            yields zero-length tensors.

        Raises:
            IndexError: If any index is non-integer or out of range.
        """
        indices = np.asarray(indices).reshape(-1)
        if indices.size == 0:
            return self._empty_batch()

        dataset_ids, local_indices = self._locate(indices)

        feature_chunks = []
        target_chunks = []
        for ds in np.unique(dataset_ids):
            chunk_indices = local_indices[dataset_ids == ds]
            feature_chunks.append(self.dataset_regression_features[ds][chunk_indices])
            target_chunks.append(self.dataset_meg_targets[ds][chunk_indices])
        batch_features = np.concatenate(feature_chunks, axis=0)
        batch_targets = np.concatenate(target_chunks, axis=0)

        # The chunks are ordered by dataset. A stable argsort of the dataset
        # ids is exactly that grouped order, so inverting it restores the
        # order the caller asked for.
        grouped_order = np.argsort(dataset_ids, kind="stable")
        positions = np.arange(grouped_order.size)
        if not np.array_equal(grouped_order, positions):
            restore = np.empty_like(grouped_order)
            restore[grouped_order] = positions
            batch_features = batch_features[restore]
            batch_targets = batch_targets[restore]

        return (
            torch.tensor(batch_features, dtype=torch.float32),
            torch.tensor(batch_targets, dtype=torch.float32),
        )

    def __getitem__(self, index):
        """Return one sample, or a batch when given a list/array of indices.

        Args:
            index: Integer global index (negative values count from the end),
                or a ``list`` / ``np.ndarray`` of such indices.

        Returns:
            Tuple ``(feature, target)`` of ``torch.float32`` tensors.

        Raises:
            IndexError: If the index is out of range.
        """
        if isinstance(index, (list, np.ndarray)):
            return self.get_batch(index)

        if index < 0:
            index += self.total_len
        if not 0 <= index < self.total_len:
            raise IndexError(
                f"index out of range for dataset of length {self.total_len}"
            )
        ds = np.searchsorted(self.cumulative_indices[1:], index, side="right")
        local_idx = index - self.cumulative_indices[ds]
        feat = self.dataset_regression_features[ds][local_idx]
        targ = self.dataset_meg_targets[ds][local_idx]
        return torch.tensor(feat, dtype=torch.float32), torch.tensor(targ, dtype=torch.float32)

    def subset(self, dataset_valid_indices):
        """Return a view restricted to the given rows of each dataset.

        Args:
            dataset_valid_indices: One collection of row indices per dataset.

        Returns:
            A ``RegressionNumpyDatasetSubset`` built from the same feature and
            target containers held by this dataset.
        """
        return RegressionNumpyDatasetSubset(
            self.dataset_regression_features,
            self.dataset_meg_targets,
            dataset_valid_indices,
        )

class RegressionNumpyDatasetSubset(Dataset):
    """Subset view over several (features, targets) array pairs.

    For each dataset, ``dataset_valid_indices`` lists the rows that belong to
    the subset. Sample ``k`` of dataset ``i`` in this view is row
    ``dataset_valid_indices[i][k]`` of the underlying arrays. The feature and
    target containers are stored by reference; data is only copied when
    samples are read.

    Indexing with a single integer returns one ``(feature, target)`` pair.
    Indexing with a ``list`` or ``np.ndarray`` of integers returns a batch
    whose rows follow the order of the requested indices. Both are returned
    as ``torch.float32`` tensors.
    """

    def __init__(self, dataset_regression_features, dataset_meg_targets, dataset_valid_indices):
        """Store the arrays and row selections and build the global index table.

        Args:
            dataset_regression_features: Sequence of arrays, one per dataset,
                indexed along their first axis.
            dataset_meg_targets: Sequence of target arrays, one per dataset,
                each with as many entries as the matching feature array.
            dataset_valid_indices: Sequence with one collection of integer row
                indices per dataset. Collections may be empty.

        Raises:
            ValueError: If the three sequences differ in length, or if a
                feature array and its target array differ in length.
        """
        dataset_count = len(dataset_regression_features)
        if len(dataset_meg_targets) != dataset_count:
            raise ValueError(
                f"got {dataset_count} feature arrays but "
                f"{len(dataset_meg_targets)} target arrays"
            )
        if len(dataset_valid_indices) != dataset_count:
            raise ValueError(
                f"got {dataset_count} datasets but "
                f"{len(dataset_valid_indices)} index collections"
            )
        # Explicit checks instead of `assert`, which is stripped under -O.
        for dataset_index, (features, targets) in enumerate(
            zip(dataset_regression_features, dataset_meg_targets)
        ):
            if len(features) != len(targets):
                raise ValueError(
                    f"dataset {dataset_index}: {len(features)} feature rows "
                    f"but {len(targets)} targets"
                )

        # Arrays allow fancy indexing in get_batch even when the caller passed
        # plain lists; np.asarray does not copy inputs that are already arrays.
        self.dataset_valid_indices = [
            np.asarray(valid_indices) for valid_indices in dataset_valid_indices
        ]
        self.dataset_lens = [len(valid_indices) for valid_indices in self.dataset_valid_indices]
        # cumulative_indices[i] is the global index of the first selected
        # sample of dataset i; the final entry equals the subset size.
        self.cumulative_indices = np.cumsum([0] + self.dataset_lens)
        self.dataset_regression_features = dataset_regression_features
        self.dataset_meg_targets = dataset_meg_targets
        self.total_len = sum(self.dataset_lens)

    def __len__(self):
        """Return the number of selected samples across all datasets."""
        return self.total_len

    def _locate(self, indices):
        """Translate global indices into dataset ids and subset-local indices.

        Args:
            indices: Non-empty 1-D integer ``np.ndarray`` of global indices.
                Negative values count from the end.

        Returns:
            Tuple ``(dataset_ids, local_indices)`` of integer arrays with the
            same length as ``indices``. ``local_indices`` are positions inside
            the matching entry of ``dataset_valid_indices``.

        Raises:
            IndexError: If ``indices`` is not of integer dtype or any index
                falls outside ``[-len(self), len(self))``.
        """
        if not np.issubdtype(indices.dtype, np.integer):
            raise IndexError("batch indices must be integers")
        # Signed 64-bit keeps the negative-index arithmetic below well defined
        # even when the caller passed an unsigned array.
        indices = indices.astype(np.int64, copy=False)
        indices = np.where(indices < 0, indices + self.total_len, indices)
        out_of_range = (indices < 0) | (indices >= self.total_len)
        if out_of_range.any():
            raise IndexError(
                f"index out of range for dataset of length {self.total_len}"
            )
        # side="right" skips over datasets with no selected rows, which share
        # a boundary value with their neighbour.
        dataset_ids = np.searchsorted(self.cumulative_indices[1:], indices, side="right")
        local_indices = indices - self.cumulative_indices[dataset_ids]
        return dataset_ids, local_indices

    def _empty_batch(self):
        """Return a zero-length ``(features, targets)`` batch."""
        if len(self.dataset_regression_features) > 0:
            # Slice the first dataset so the trailing dimensions are kept.
            empty_features = self.dataset_regression_features[0][:0]
            empty_targets = self.dataset_meg_targets[0][:0]
        else:
            empty_features = np.empty((0,))
            empty_targets = np.empty((0,))
        return (
            torch.tensor(empty_features, dtype=torch.float32),
            torch.tensor(empty_targets, dtype=torch.float32),
        )

    def get_batch(self, indices):
        """Fetch several samples at once.

        Indices are grouped per dataset so each underlying array is read with
        a single fancy-indexing call; the rows are then put back into the
        order in which they were requested.

        Args:
            indices: List or array of integer global indices. Multi-dimensional
                input is flattened. Negative values count from the end.

        Returns:
            Tuple ``(features, targets)`` of ``torch.float32`` tensors whose
            first axis follows the order of ``indices``. An empty request
            yields zero-length tensors.

        Raises:
            IndexError: If any index is non-integer or out of range.
        """
        indices = np.asarray(indices).reshape(-1)
        if indices.size == 0:
            return self._empty_batch()

        dataset_ids, local_indices = self._locate(indices)

        feature_chunks = []
        target_chunks = []
        for ds in np.unique(dataset_ids):
            # Map positions within the subset to rows of the underlying arrays.
            source_rows = self.dataset_valid_indices[ds][local_indices[dataset_ids == ds]]
            feature_chunks.append(self.dataset_regression_features[ds][source_rows])
            target_chunks.append(self.dataset_meg_targets[ds][source_rows])
        batch_features = np.concatenate(feature_chunks, axis=0)
        batch_targets = np.concatenate(target_chunks, axis=0)

        # The chunks are ordered by dataset. A stable argsort of the dataset
        # ids is exactly that grouped order, so inverting it restores the
        # order the caller asked for.
        grouped_order = np.argsort(dataset_ids, kind="stable")
        positions = np.arange(grouped_order.size)
        if not np.array_equal(grouped_order, positions):
            restore = np.empty_like(grouped_order)
            restore[grouped_order] = positions
            batch_features = batch_features[restore]
            batch_targets = batch_targets[restore]

        return (
            torch.tensor(batch_features, dtype=torch.float32),
            torch.tensor(batch_targets, dtype=torch.float32),
        )

    def __getitem__(self, index):
        """Return one sample, or a batch when given a list/array of indices.

        Args:
            index: Integer global index (negative values count from the end),
                or a ``list`` / ``np.ndarray`` of such indices.

        Returns:
            Tuple ``(feature, target)`` of ``torch.float32`` tensors.

        Raises:
            IndexError: If the index is out of range.
        """
        if isinstance(index, (list, np.ndarray)):
            return self.get_batch(index)

        if index < 0:
            index += self.total_len
        if not 0 <= index < self.total_len:
            raise IndexError(
                f"index out of range for dataset of length {self.total_len}"
            )
        ds = np.searchsorted(self.cumulative_indices[1:], index, side="right")
        local_idx = index - self.cumulative_indices[ds]
        source_row = self.dataset_valid_indices[ds][local_idx]
        feat = self.dataset_regression_features[ds][source_row]
        targ = self.dataset_meg_targets[ds][source_row]
        return torch.tensor(feat, dtype=torch.float32), torch.tensor(targ, dtype=torch.float32)
    