import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import utils



class EEGDataset(Dataset):
    """Paired EEG / image dataset: item ``idx`` is ``[eeg[idx], image[idx]]``.

    Each ``__getitem__`` copies exactly one sample out of the loaded sources.
    """

    def __init__(self, eeg_path, image_path, eeg_key='preprocessed_eeg_data'):
        """
        Multimodal dataset loader
        Parameters:
            eeg_path: EEG data file path. Whatever ``np.load`` returns for it
                must support ``[eeg_key]`` (e.g. a pickled dict, an .npz
                archive, or a structured .npy array with that field).
            image_path: Image data file path (.npy), opened memory-mapped
                read-only.
            eeg_key: Key under which the EEG array is stored
                (default 'preprocessed_eeg_data').
        Raises:
            ValueError: if the EEG and image sources hold a different number
                of samples.
        """
        # allow_pickle is required when the EEG file is a pickled dict. Note
        # that np.load only honours mmap_mode for .npy arrays; for a pickled
        # object or an .npz archive the requested array is read eagerly.
        eeg_container = np.load(eeg_path, mmap_mode='r', allow_pickle=True)
        self.eeg_data = eeg_container[eeg_key]
        self.image_data = np.load(image_path, mmap_mode='r')

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(self.eeg_data) != len(self.image_data):
            raise ValueError("Inconsistent number of EEG and image data samples!")

    def __len__(self):
        return len(self.eeg_data)

    def __getitem__(self, idx):
        # Returns numpy arrays; DataLoader's default collate turns them into
        # tensors. .copy() detaches the sample from any read-only memory map so
        # the result is writable (torch warns when given non-writable arrays).
        eeg = self.eeg_data[idx].copy()
        image = self.image_data[idx].copy()
        return [eeg, image]



# Build a seeded DataLoader over an EEGDataset(eeg_path, image_path).
def get_eeg_dataloader(eeg_path, image_path, batch_size, num_workers=1, seed=42, is_shuffle=True):
    """Create a DataLoader over an EEGDataset built from two files.

    Parameters:
        eeg_path: Path passed to EEGDataset as the EEG source.
        image_path: Path passed to EEGDataset as the image source.
        batch_size: Samples per batch.
        num_workers: Number of DataLoader worker processes (default 1).
        seed: Value handed to utils.seed_everything before anything else runs.
        is_shuffle: Whether the DataLoader reshuffles the data every epoch.
    Returns:
        A torch DataLoader with pin_memory enabled.
    """
    # Seed first; presumably seed_everything seeds torch's global RNG, which the
    # shuffling sampler draws from (confirm in utils).
    utils.seed_everything(seed)
    for label, path in (("eeg_path", eeg_path), ("image_path", image_path)):
        print(f"{label}:{path}")
    return DataLoader(
        EEGDataset(eeg_path, image_path),
        batch_size=batch_size,
        shuffle=is_shuffle,
        num_workers=num_workers,
        pin_memory=True,
    )



def get_eeg_dls(subject, data_path, batch_size, val_batch_size, num_workers, seed):
    """Build the training and validation EEG/image dataloaders for one subject.

    Parameters:
        subject: Subject id. An int is zero-padded to two digits (3 -> "03");
            any other value is used as-is.
        data_path: Root data directory, passed to the path templates below.
        batch_size: Batch size of the training loader.
        val_batch_size: Batch size of the validation loader.
        num_workers: DataLoader worker count for both loaders.
        seed: Seed passed to get_eeg_dataloader for both loaders.
    Returns:
        (train_dl, val_dl)
    """
    # Check subject format and build path correctly
    if isinstance(subject, int):
        subject_str = str(subject).zfill(2)
    else:
        subject_str = subject

    # NOTE(review): these are placeholder strings with no "{}" fields, so
    # .format() silently ignores data_path and subject_str, and every subject
    # resolves to the same files. Replace them with real templates before use:
    # positional field {0} is data_path and {1} is subject_str.
    training_eeg_path = "/path/to/eeg_training_folder".format(data_path, subject_str)
    training_image_path = "/path/to/image_training_folder".format(data_path)
    test_eeg_path = "/path/to/eeg_test_folder".format(data_path, subject_str)
    test_image_path = "/path/to/image_test_folder".format(data_path)

    train_dl = get_eeg_dataloader(
        training_eeg_path,
        training_image_path,
        batch_size=batch_size,
        num_workers=num_workers,
        seed=seed,
        is_shuffle=True,
    )

    # NOTE(review): the validation loader also shuffles; confirm this is
    # intended (evaluation loaders are usually deterministic).
    val_dl = get_eeg_dataloader(
        test_eeg_path,
        test_image_path,
        batch_size=val_batch_size,
        num_workers=num_workers,
        seed=seed,
        is_shuffle=True,
    )

    num_train = len(train_dl.dataset)
    num_val = len(val_dl.dataset)
    print(training_eeg_path, "\n", training_image_path)
    print(test_eeg_path, "\n", test_image_path)
    print("number of train data:", num_train, "\n")
    print("batch_size", batch_size, "\n")
    print("number of val data:", num_val, "\n")
    # Report the validation batch size (previously printed batch_size by mistake).
    print("val_batch_size", val_batch_size, "\n")

    return train_dl, val_dl
