import numpy as np
import torch
from scipy.io import loadmat

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler


class BiologicalLoader:
    """Two-sample loader for a labelled dataset stored in a MATLAB ``.mat`` file.

    The file is read with ``scipy.io.loadmat`` and must contain a feature
    matrix under key ``"X"`` (one sample per row) and labels under key ``"Y"``.
    The samples are split into train/test portions; the training rows whose
    label equals ``label1`` become ``self.X`` and those whose label equals
    ``label2`` become ``self.Y``, both as float32 torch tensors.

    Attributes set by the constructor:
        features, labels: the raw arrays read from the file.
        features_train, features_test: the split (optionally standardized) features.
        labels_train, labels_test: the matching label splits.
        X, Y: float32 tensors of training samples for ``label1`` / ``label2``.
        d: number of feature columns.
        groups: one singleton group ``[k]`` per feature index ``k``.
        platform: dict holding the tensor ``dtype`` and the ``device`` argument.
    """

    def __init__(self, name, label1, label2, normalization=False, test_size=0.25, seed=1, device="cpu"):
        """Load ``name``, split it, and build the two per-label training tensors.

        Args:
            name: path (or file-like object) accepted by ``scipy.io.loadmat``.
            label1: label value selecting the training rows stored in ``self.X``.
            label2: label value selecting the training rows stored in ``self.Y``.
            normalization: if True, standardize features with a ``StandardScaler``
                fitted on the training split only and applied to both splits.
            test_size: test-split size, forwarded to ``train_test_split``.
            seed: ``random_state`` for the split, so it is reproducible.
            device: torch device on which ``self.X`` and ``self.Y`` are created.
        """
        contents = loadmat(name)
        self.features = contents["X"]
        self.labels = contents["Y"]

        (
            raw_train,
            raw_test,
            self.labels_train,
            self.labels_test,
        ) = train_test_split(self.features, self.labels, test_size=test_size, random_state=seed)

        if not normalization:
            self.features_train, self.features_test = raw_train, raw_test
        else:
            # Statistics come from the training split only, so nothing about
            # the held-out samples leaks into the scaling.
            standardizer = StandardScaler().fit(raw_train)
            self.features_train = standardizer.transform(raw_train)
            self.features_test = standardizer.transform(raw_test)

        # `[..., 0]` takes the first entry along the last axis; loadmat
        # typically returns labels as an (n, 1) column, giving one label per row.
        train_label = self.labels_train[..., 0]

        def as_tensor(rows):
            # Convert a numpy block of samples to a float32 tensor on `device`.
            return torch.tensor(rows, device=device, dtype=torch.float32)

        self.X = as_tensor(self.features_train[train_label == label1])
        self.Y = as_tensor(self.features_train[train_label == label2])

        self.d = self.X.shape[1]
        self.groups = [[idx] for idx in range(self.d)]

        self.platform = {"dtype": self.X.dtype, "device": device}
        
