import torch
import torch.utils.data as Data
from scipy import io
import os

class MAMEMDataLoader:
    """Build train/validation/test ``DataLoader`` objects for one MAMEM subject.

    The subject file ``U<subject:03d>.mat`` is read from ``data_path`` and its
    ``x_test`` / ``y_test`` arrays are split by position: samples ``[0, 300)``
    form the training set, ``[300, 400)`` the validation set and ``[400, 500)``
    the test set.

    Args:
        subject: Subject identifier; anything accepted by ``int()``.
        ratio: Stored on the instance; not used by this class.
        data_path: Directory containing the ``U###.mat`` files.
        bs: Batch size used for all three loaders.
        dev: Device (string or ``torch.device``) the tensors are moved to.
        finetune: When true, every dataset item additionally carries a tensor
            filled with the subject id, i.e. items are
            ``(x, subject, y)`` instead of ``(x, y)``.
    """

    def __init__(self, subject, ratio=8, data_path='./data/MAMEM/', bs=64, dev='cpu', finetune=False):
        self.finetune = finetune
        self.subject = subject
        self.ratio = ratio
        self.data_path = data_path
        self.bs = bs
        self.dev = torch.device(dev)
        self.timeseries_length = 125

        # Populated by get_dataloader().
        self.trainloader = None
        self.validloader = None
        self.testloader = None

    def _subject_tensor(self, n_samples):
        """Return an ``(n_samples, 1, 1, timeseries_length)`` tensor filled with the subject id."""
        return torch.full(
            (n_samples, 1, 1, self.timeseries_length),
            int(self.subject),
            device=self.dev,
        )

    def _make_loader(self, *tensors):
        """Wrap ``tensors`` in a ``TensorDataset`` and a shuffling ``DataLoader``."""
        # Only CPU tensors can be pinned; the tensors already live on
        # ``self.dev``, so pinning must be disabled for any non-CPU device.
        pin = self.dev.type == 'cpu' and torch.cuda.is_available()
        return Data.DataLoader(
            dataset=Data.TensorDataset(*tensors),
            batch_size=self.bs,
            shuffle=True,
            num_workers=0,
            pin_memory=pin,
        )

    def get_dataloader(self):
        """Load the subject file and return ``(trainloader, validloader, testloader)``.

        The split tensors are also stored on the instance as ``x_train``,
        ``y_train``, ``x_valid``, ``y_valid``, ``x_test`` and ``y_test`` (plus
        ``train_subject`` / ``valid_subject`` / ``test_subject`` in finetune
        mode), and the loaders as ``trainloader`` / ``validloader`` /
        ``testloader``.
        """
        mat = io.loadmat(os.path.join(self.data_path, 'U' + f'{int(self.subject):03d}' + '.mat'))

        samples = torch.Tensor(mat['x_test'])
        labels = torch.Tensor(mat['y_test']).view(-1)

        self.x_train = samples[:300].to(self.dev)
        self.y_train = labels[:300].long().to(self.dev)
        self.x_valid = samples[300:400].to(self.dev)
        self.y_valid = labels[300:400].long().to(self.dev)
        self.x_test = samples[400:500].to(self.dev)
        self.y_test = labels[400:500].long().to(self.dev)

        if self.finetune:
            # Size each subject tensor from the actual split so that files
            # with fewer than 500 trials do not trip TensorDataset's
            # equal-length check.
            self.train_subject = self._subject_tensor(self.x_train.shape[0])
            self.valid_subject = self._subject_tensor(self.x_valid.shape[0])
            self.test_subject = self._subject_tensor(self.x_test.shape[0])
            self.trainloader = self._make_loader(self.x_train, self.train_subject, self.y_train)
            self.validloader = self._make_loader(self.x_valid, self.valid_subject, self.y_valid)
            self.testloader = self._make_loader(self.x_test, self.test_subject, self.y_test)
        else:
            self.trainloader = self._make_loader(self.x_train, self.y_train)
            self.validloader = self._make_loader(self.x_valid, self.y_valid)
            self.testloader = self._make_loader(self.x_test, self.y_test)

        return self.trainloader, self.validloader, self.testloader

from sklearn.model_selection import train_test_split

class MAMEMDataLoader_fixed:
    """Build randomly split train/validation/test loaders for one MAMEM subject.

    The subject file ``U<subject:03d>.mat`` is read from ``data_path`` and its
    ``x_test`` / ``y_test`` arrays are split at random with
    ``train_test_split``: 60% training, then the remaining 40% is halved into
    validation and test sets.

    Args:
        subject: Subject identifier; anything accepted by ``int()``.
        data_path: Directory containing the ``U###.mat`` files.
        bs: Batch size used for all three loaders.
        dev: Device (string or ``torch.device``) the tensors are moved to.
        finetune: When true, every dataset item additionally carries a tensor
            filled with the subject id, i.e. items are
            ``(x, subject, y)`` instead of ``(x, y)``.
    """

    def __init__(self, subject, data_path='./data/MAMEM/', bs=64, dev='cpu', finetune=False):
        # ``finetune`` is read by get_dataloader(); it must exist on the
        # instance or every call fails with AttributeError.
        self.finetune = finetune
        self.subject = subject
        self.ratio = 8
        self.data_path = data_path
        self.bs = bs
        self.dev = torch.device(dev)
        self.timeseries_length = 125

        # Populated by get_dataloader().
        self.trainloader = None
        self.validloader = None
        self.testloader = None

    def _subject_tensor(self, n_samples):
        """Return an ``(n_samples, 1, 1, timeseries_length)`` tensor filled with the subject id.

        NOTE(review): the trailing ``(1, 1, timeseries_length)`` shape mirrors
        ``MAMEMDataLoader``; confirm it matches what the finetune model expects.
        """
        return torch.full(
            (n_samples, 1, 1, self.timeseries_length),
            int(self.subject),
            device=self.dev,
        )

    def _make_loader(self, *tensors):
        """Wrap ``tensors`` in a ``TensorDataset`` and a shuffling ``DataLoader``."""
        # Only CPU tensors can be pinned; the tensors already live on
        # ``self.dev``, so pinning must be disabled for any non-CPU device.
        pin = self.dev.type == 'cpu' and torch.cuda.is_available()
        return Data.DataLoader(
            dataset=Data.TensorDataset(*tensors),
            batch_size=self.bs,
            shuffle=True,
            num_workers=0,
            pin_memory=pin,
        )

    def get_dataloader(self):
        """Load the subject file and return ``(trainloader, validloader, testloader)``.

        The split tensors are also stored on the instance as ``x_train``,
        ``y_train``, ``x_valid``, ``y_valid``, ``x_test`` and ``y_test`` (plus
        ``train_subject`` / ``valid_subject`` / ``test_subject`` in finetune
        mode), and the loaders as ``trainloader`` / ``validloader`` /
        ``testloader``.
        """
        mat = io.loadmat(os.path.join(self.data_path, 'U' + f'{int(self.subject):03d}' + '.mat'))

        x = torch.Tensor(mat['x_test']).numpy()
        y = torch.Tensor(mat['y_test']).view(-1).numpy()

        # Random 60/20/20 split: hold out 40%, then halve the held-out part.
        x_train, x_rest, y_train, y_rest = train_test_split(x, y, test_size=0.4)
        x_valid, x_test, y_valid, y_test = train_test_split(x_rest, y_rest, test_size=0.5)

        self.x_train = torch.Tensor(x_train).to(self.dev)
        self.y_train = torch.Tensor(y_train).long().to(self.dev)
        self.x_valid = torch.Tensor(x_valid).to(self.dev)
        self.y_valid = torch.Tensor(y_valid).long().to(self.dev)
        self.x_test = torch.Tensor(x_test).to(self.dev)
        self.y_test = torch.Tensor(y_test).long().to(self.dev)

        if self.finetune:
            # One subject row per sample: TensorDataset requires every tensor
            # to share the same first dimension, so a batch-sized tensor
            # cannot be used here.
            self.train_subject = self._subject_tensor(self.x_train.shape[0])
            self.valid_subject = self._subject_tensor(self.x_valid.shape[0])
            self.test_subject = self._subject_tensor(self.x_test.shape[0])
            self.trainloader = self._make_loader(self.x_train, self.train_subject, self.y_train)
            self.validloader = self._make_loader(self.x_valid, self.valid_subject, self.y_valid)
            self.testloader = self._make_loader(self.x_test, self.test_subject, self.y_test)
        else:
            self.trainloader = self._make_loader(self.x_train, self.y_train)
            self.validloader = self._make_loader(self.x_valid, self.y_valid)
            self.testloader = self._make_loader(self.x_test, self.y_test)

        return self.trainloader, self.validloader, self.testloader
