import torch
import numpy as np
from torch.utils.data import Dataset
from pathlib import Path
import mne

class PretrainingDataset(Dataset):
    """Map-style dataset where each ``.npy`` file under a directory is one example.

    On access, the stored array is wrapped in an MNE ``RawArray`` using a fixed
    23-channel EEG montage sampled at 256 Hz. It is band-pass filtered
    (0.1-75 Hz), notch filtered at 60 Hz, converted to a float32 tensor and
    min-max scaled to [-1, 1]. Each item is returned as ``{"input": tensor}``.
    """

    def __init__(self, data_path):
        """Index every ``.npy`` file below *data_path* and build the MNE Info.

        Args:
            data_path: Root directory searched recursively for ``.npy`` files.
        """
        self.data_path = Path(data_path)
        # Sorted so that a given index maps to the same file on every run and
        # machine; raw rglob order depends on the filesystem.
        self.files = self._collect_files(self.data_path)

        # One example per file.
        self.total_examples = len(self.files)

        # MNE Info describing the channel layout every file is loaded with.
        self.sfreq = 256  # Sampling frequency in Hz
        self.ch_names = [
            'EEG FP1-REF', 'EEG FP2-REF', 'EEG F3-REF', 'EEG F4-REF', 'EEG C3-REF',
            'EEG C4-REF', 'EEG P3-REF', 'EEG P4-REF', 'EEG O1-REF', 'EEG O2-REF',
            'EEG F7-REF', 'EEG F8-REF', 'EEG T3-REF', 'EEG T4-REF', 'EEG T5-REF',
            'EEG T6-REF', 'EEG A1-REF', 'EEG A2-REF', 'EEG FZ-REF', 'EEG CZ-REF',
            'EEG PZ-REF', 'EEG T1-REF', 'EEG T2-REF']

        self.ch_types = ['eeg'] * len(self.ch_names)  # Channel types
        self.info = mne.create_info(ch_names=self.ch_names, sfreq=self.sfreq, ch_types=self.ch_types)

    @staticmethod
    def _collect_files(root):
        """Return every ``.npy`` path below *root* (recursively), sorted."""
        return sorted(root.rglob('*.npy'))

    @staticmethod
    def _scale_to_unit_range(data):
        """Min-max scale *data* to [-1, 1] using its global min and max.

        The min and max are taken over the whole tensor, not per channel. If
        every value is identical (zero range), a tensor of zeros is returned
        instead of the NaNs a 0/0 division would produce.
        """
        min_x = data.min()
        max_x = data.max()
        span = max_x - min_x
        if span == 0:
            # Flat signal: place everything at the midpoint of [-1, 1].
            return torch.zeros_like(data)
        scaled = (data - min_x) / span  # -> [0, 1]
        return (scaled - 0.5) * 2  # [0, 1] -> [-0.5, 0.5] -> [-1, 1]

    def __len__(self):
        """Return the number of ``.npy`` files found at construction time."""
        return self.total_examples

    def __getitem__(self, idx):
        """Load, filter and normalize the recording at position *idx*.

        Args:
            idx: Non-negative integer index (a tensor index is converted via
                ``tolist()``). Negative indices are rejected.

        Returns:
            ``{"input": tensor}`` where the float32 tensor holds the filtered
            data scaled to [-1, 1].

        Raises:
            IndexError: If *idx* is negative or >= ``len(self)``.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        if idx < 0 or idx >= self.total_examples:
            raise IndexError("Index out of range")

        file_path = self.files[idx]
        data = np.load(file_path)

        # RawArray expects (n_channels, n_times) with n_channels matching
        # self.info; presumably each file is stored that way — TODO confirm.
        raw = mne.io.RawArray(data, self.info, verbose="ERROR")

        # Band-pass then mains-notch; both filter raw's data in place.
        raw.filter(l_freq=0.1, h_freq=75.0, verbose="ERROR")
        raw.notch_filter(60, verbose="ERROR")

        # get_data() returns an ndarray; cast to float32 for the model.
        data = torch.as_tensor(raw.get_data(), dtype=torch.float32)

        return {"input": self._scale_to_unit_range(data)}
