import torch.utils.data as data
import torch
import pandas as pd
import numpy as np

from scipy.io import arff
from sklearn import preprocessing


class CustomValidDataset(data.Dataset):
    """Validation subset holding a fixed number of random samples per class.

    The source is iterated once, all (sample, target) pairs are gathered into
    two tensors, and ``n_samples_per_class`` samples are drawn without
    replacement (via ``np.random.choice``) for every distinct target value.

    Attributes set by the constructor:
        data: tensor of the selected samples, or ``None`` if the source
            yielded nothing.
        labels: tensor of the matching targets (stored as floats because the
            targets go through ``torch.Tensor``), or ``None``.
    """

    def __init__(self, data_loader, n_samples_per_class=10):
        """Build the subset.

        Args:
            data_loader: iterable yielding ``(sample, target)`` pairs of
                tensors. Either batched pairs (target has a leading batch
                axis) or single items (target is a 0-d tensor) are accepted.
            n_samples_per_class: how many samples to keep for each class.

        Raises:
            ValueError: if some class has fewer than ``n_samples_per_class``
                samples.
        """
        super(CustomValidDataset, self).__init__()

        self.data = None
        self.labels = None
        dataset_data, dataset_targets = self.separate_data_from_loader(data_loader)
        self.create_valid_dataset(dataset_data=dataset_data,
                                  dataset_targets=dataset_targets,
                                  n_samples_per_class=n_samples_per_class)

    def __getitem__(self, index):
        """Return ``(sample_tensor, label)`` with the label as a Python int."""
        return self.data[index], int(self.labels[index])

    def __len__(self):
        """Return the number of selected samples (0 if nothing was selected)."""
        if self.data is None:
            return 0
        return len(self.data)

    def separate_data_from_loader(self, data_loader):
        """Collect every sample and target yielded by ``data_loader``.

        Batches are concatenated along their leading axis rather than
        stacked, so a smaller final batch does not produce a ragged array.
        An item whose target is 0-d is treated as a batch of one.

        Returns:
            A ``(samples, targets)`` pair of tensors whose leading axis
            indexes individual samples. Both are empty tensors when the
            source yields nothing.
        """
        samples = []
        targets = []
        for sample, target in data_loader:
            sample = sample.numpy()
            target = target.numpy()
            if target.ndim == 0:
                # Un-batched item: add a leading axis so it concatenates
                # like a batch of size one.
                sample = sample[np.newaxis]
                target = target[np.newaxis]
            samples.append(sample)
            targets.append(target)

        if not samples:
            return torch.Tensor(), torch.Tensor()
        return (torch.Tensor(np.concatenate(samples, axis=0)),
                torch.Tensor(np.concatenate(targets, axis=0)))

    def create_valid_dataset(self, dataset_data, dataset_targets, n_samples_per_class=10):
        """Draw ``n_samples_per_class`` random samples for each class.

        The selection is appended to whatever ``self.data`` / ``self.labels``
        already hold. Nothing is modified if an error is raised.

        Args:
            dataset_data: tensor of samples.
            dataset_targets: tensor of targets aligned with ``dataset_data``.
            n_samples_per_class: number of samples to draw per distinct target.

        Raises:
            ValueError: if a class has fewer than ``n_samples_per_class``
                samples (sampling is without replacement).
        """
        picked_data = [] if self.data is None else [self.data]
        picked_labels = [] if self.labels is None else [self.labels]

        for label in torch.unique(dataset_targets):
            mask = dataset_targets == label
            class_data = dataset_data[mask]
            class_targets = dataset_targets[mask]

            available = len(class_targets)
            if n_samples_per_class > available:
                raise ValueError(
                    "Class {} has only {} sample(s); cannot draw {} without "
                    "replacement.".format(label.item(), available, n_samples_per_class)
                )

            choices = np.random.choice(available, size=n_samples_per_class, replace=False)
            picked_data.append(class_data[choices])
            picked_labels.append(class_targets[choices])

        # Concatenate once instead of growing the tensors class by class.
        if picked_data:
            self.data = torch.cat(picked_data, 0)
            self.labels = torch.cat(picked_labels, 0)

class CustomDataset(data.Dataset):
    """Tabular dataset read from an ARFF or comma-separated file.

    The last column of the file is taken as the label; every other column is
    a feature. After construction ``self.data`` is a float ``DataFrame`` of
    features and ``self.labels`` is an integer NumPy array.
    """

    def __init__(self, load_path, norm=None):
        """Load the file at ``load_path`` and optionally rescale it.

        Args:
            load_path: path to the file. A name ending in ``.arff`` is parsed
                with ``scipy.io.arff``; anything else is read as a
                header-less, comma-separated file.
            norm: ``'minmax'`` applies ``MinMaxScaler``, ``'scaler'`` applies
                ``StandardScaler``; ``None`` or any other value leaves the
                values untouched.
        """
        super().__init__()

        if load_path.endswith(".arff"):
            records, _meta = arff.loadarff(load_path)
            table = pd.DataFrame(records, dtype=float)
        else:
            raw = pd.read_csv(load_path, sep=",", header=None)
            table = pd.DataFrame(raw, dtype=float)

        # Labels are captured before any scaling so they keep their raw values.
        self.labels = table.iloc[:, -1].values

        # The scaler is fitted on the full table (label column included);
        # the label column is discarded again when features are sliced below.
        if norm == 'minmax':
            fitted = preprocessing.MinMaxScaler().fit(table)
            table = fitted.transform(table)
        elif norm == 'scaler':
            fitted = preprocessing.StandardScaler().fit(table)
            table = fitted.transform(table)

        # Scalers return a plain array, so wrap the result in a DataFrame again.
        table = pd.DataFrame(table, dtype=float)
        self.data = table.iloc[:, :-1]

        if load_path != "":
            self.labels = self.labels.astype(int)

    def __getitem__(self, index):
        """Return ``(features, label)`` for row ``index`` as two tensors."""
        row = self.data.iloc[index]
        label = int(self.labels[index])
        return torch.tensor(row), torch.tensor(label)

    def __len__(self):
        """Return the number of rows in the dataset."""
        return len(self.data)
