from torch.utils.data import Dataset
import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
from sklearn.preprocessing import LabelEncoder
from os.path import join as pjoin

def load_aml():
    """Load the AML expression matrix and split it into three labelled groups.

    Rows of ``./data/AML/preprocessed_data.csv`` are selected with the
    ``condition`` and ``patient_id`` columns of ``./data/AML/metadata.csv``.
    Both files are read with their first column as the row index.

    Returns:
        Three ``(features, labels)`` pairs, in this order:
        healthy samples of patient 2 (label 0), post-transplant samples of
        patient 0 (label 1) and pre-transplant samples of patient 0 (label 2).
        ``features`` is a float32 array; ``labels`` is a float64 vector with
        one entry per selected row.
    """
    data = pd.read_csv("./data/AML/preprocessed_data.csv", index_col=0)
    metadata = pd.read_csv("./data/AML/metadata.csv", index_col=0)

    # The masks below are applied positionally to ``data.values``. When both
    # files carry corresponding row labels but in a different order (or the
    # metadata has extra rows), reorder the metadata to match ``data`` so each
    # mask entry refers to the right sample. If the labels do not correspond,
    # keep the plain positional pairing.
    if (
        not metadata.index.equals(data.index)
        and metadata.index.is_unique
        and data.index.isin(metadata.index).all()
    ):
        metadata = metadata.loc[data.index]

    def _rows(condition, patient_id):
        """Return the float32 rows of ``data`` for one condition/patient pair."""
        mask = (metadata['condition'] == condition) & (metadata['patient_id'] == patient_id)
        return data.values[mask.to_numpy()].astype(np.float32)

    data1 = _rows('Healthy', 2)
    data2 = _rows('Post transplant', 0)
    data3 = _rows('Pre transplant', 0)

    return (data1, np.zeros(data1.shape[0])), (data2, np.ones(data2.shape[0])), (data3, np.ones(data3.shape[0]) * 2)


def load_epithel():
    """Load the three epithelial condition arrays, each paired with a class id.

    Returns:
        A 3-tuple of ``(features, labels)`` pairs in the order Control (0),
        Salmonella (1), Hpoly (2). ``features`` is cast to float32; ``labels``
        is a float64 vector filled with the class id, one entry per row.
    """
    sources = (
        "./data/epithel_new/Control.npy",
        "./data/epithel_new/Salmonella.npy",
        "./data/epithel_new/Hpoly.npy",
    )
    groups = []
    for class_id, npy_path in enumerate(sources):
        arr = np.load(npy_path)
        groups.append((arr.astype('float32'), np.full(arr.shape[0], float(class_id))))
    return tuple(groups)


def load_mice():
    """Load the mice target and background arrays from ``./data/mice``.

    Returns:
        The tuple ``(background, target)`` with the arrays exactly as stored
        in ``background.npy`` and ``target.npy`` (no casting). Background
        comes first in the returned tuple even though target is loaded first.
    """
    target, background = (
        np.load(npy_path)
        for npy_path in ("./data/mice/target.npy", "./data/mice/background.npy")
    )
    return background, target


def load_activity():
    """Load the activity dataset and split it into background and target groups.

    Train and test CSVs are stacked. Every column except the last two becomes
    a feature, and the last column is the activity label. Features are scaled
    to [0, 1] with a single ``MinMaxScaler`` fitted on the stacked data.
    Samples labelled ``'LAYING'`` form the background and all others form the
    target. Labels are integer-encoded with one ``LabelEncoder`` fitted on
    every label, so both groups share a single encoding.

    Returns:
        ``((background, background_labels), (target, target_labels))``.
    """
    # NOTE(review): presumably the second-to-last column is a non-feature id
    # (e.g. subject); it is dropped from the features here — confirm.
    frames = [
        pd.read_csv("./data/activity/train.csv"),
        pd.read_csv("./data/activity/test.csv"),
    ]
    features = MinMaxScaler().fit_transform(
        np.concatenate([frame.iloc[:, :-2].values for frame in frames])
    )
    activities = np.concatenate([frame.iloc[:, -1].values for frame in frames])

    is_background = activities == 'LAYING'
    is_target = activities != 'LAYING'

    encoder = LabelEncoder().fit(activities)
    background_pair = (features[is_background], encoder.transform(activities[is_background]))
    target_pair = (features[is_target], encoder.transform(activities[is_target]))
    return background_pair, target_pair


class SimpleDataset(Dataset):
    """Map-style dataset over an in-memory feature array and label array.

    ``X`` and ``y`` are copied on construction, so later in-place changes to
    the caller's arrays do not affect the dataset. Item ``i`` is the pair
    ``(X[i], y[i])``.

    Raises:
        ValueError: if ``X`` and ``y`` differ in their first dimension.
    """

    def __init__(self, X: np.ndarray, y: np.ndarray):
        # Validate up front. __len__ is driven by X alone, so a shorter y
        # would only surface later as an IndexError inside a DataLoader, and
        # a longer y would have its tail silently ignored.
        if X.shape[0] != y.shape[0]:
            raise ValueError(
                f"X and y must have the same number of rows, "
                f"got {X.shape[0]} and {y.shape[0]}"
            )
        self.X = X.copy()
        self.y = y.copy()

    def __len__(self):
        """Return the number of samples (rows of ``X``)."""
        return self.X.shape[0]

    def __getitem__(self, index):
        """Return the ``(features, label)`` pair at ``index``."""
        return self.X[index], self.y[index]

class LabeledDataset(Dataset):
    """Map-style dataset of ``(features, label)`` pairs held in memory.

    The inputs are copied on construction, so the dataset does not change if
    the caller later modifies its arrays in place. Item ``i`` is
    ``(X[i], y[i])``.

    Raises:
        ValueError: if ``X`` and ``y`` differ in their first dimension.
    """

    def __init__(self, X: np.ndarray, y: np.ndarray):
        # Mismatched lengths would otherwise only fail (or silently drop
        # labels) once a DataLoader iterates past the shorter array, because
        # __len__ looks at X only.
        if X.shape[0] != y.shape[0]:
            raise ValueError(
                f"X and y must have the same number of rows, "
                f"got {X.shape[0]} and {y.shape[0]}"
            )
        self.X = X.copy()
        self.y = y.copy()

    def __len__(self):
        """Return the number of samples (rows of ``X``)."""
        return self.X.shape[0]

    def __getitem__(self, index):
        """Return the ``(features, label)`` pair at ``index``."""
        return self.X[index], self.y[index]