from sklearn import preprocessing
import numpy as np

class Dataset(object):
    """Sampler for random sub-batches drawn from a collection of tabular datasets.

    The class itself stores nothing. The pair/batch methods read the
    attributes ``trn_data``, ``trn_labels`` and ``trn_names``, or their
    ``tst_`` counterparts when ``test=True``. These must be assigned
    elsewhere, presumably by a subclass or loader (not visible here).

    Each attribute is indexed by an integer dataset id:

    - ``*_data[i]`` is a 2-D feature matrix.
    - ``*_labels[i]`` holds one class label per row of that matrix.
    - ``*_names[i]`` is used only to tell datasets apart when drawing
      negative pairs.
    """

    def __init__(self):
        pass

    @staticmethod
    def _one_hot(labels):
        """Return a dense float one-hot matrix with one column per distinct label.

        Columns follow the sorted order of the distinct labels, the same order
        ``sklearn.preprocessing.OneHotEncoder(categories="auto")`` produces.
        The encoding is done in NumPy to avoid the encoder's ``sparse=``
        keyword, which scikit-learn renamed to ``sparse_output`` (1.2) and
        later removed (1.4).
        """
        flat = np.asarray(labels).reshape(-1)
        categories, codes = np.unique(flat, return_inverse=True)
        return np.eye(categories.shape[0])[codes.reshape(-1)]

    @staticmethod
    def _pad_columns(array, width):
        """Right-pad a 2-D array with zero columns up to ``width`` columns.

        Returns ``array`` unchanged if it already has at least ``width`` columns.
        """
        if array.shape[1] >= width:
            return array
        padded = np.zeros((array.shape[0], width))
        padded[:, :array.shape[1]] = array
        return padded

    def sample_batch(self, data, labels, Ns=None, Ls=None, Ms=None):
        """Draw a random sub-batch of rows, feature columns and one-hot label columns.

        Args:
            data: 2-D feature array, one row per sample.
            labels: class label for each row of ``data``. It is flattened
                before encoding.
            Ns: number of rows to draw, *with* replacement. ``None`` means
                ``data.shape[0]``.
            Ls: number of one-hot label columns to keep. If there are fewer
                distinct labels, zero columns are appended first. ``None``
                means the number of distinct labels.
            Ms: number of feature columns to keep. ``data`` is zero-padded to
                ``Ms`` columns if narrower. ``None`` means ``data.shape[1]``.

        Returns:
            ``(data, labels)`` with shapes ``(Ns, Ms)`` and ``(Ns, Ls)``. The
            same row indices are applied to both, so rows stay aligned.
        """
        labels = self._one_hot(labels)
        Ls = labels.shape[1] if Ls is None else Ls
        Ms = data.shape[1] if Ms is None else Ms
        Ns = data.shape[0] if Ns is None else Ns

        labels = self._pad_columns(labels, Ls)
        data = self._pad_columns(data, Ms)

        # Columns are drawn without replacement: a random subset in random order.
        L_hat = np.random.choice(labels.shape[1], size=Ls, replace=False)
        M_hat = np.random.choice(data.shape[1], size=Ms, replace=False)
        # Rows are drawn with replacement, so Ns may exceed the row count.
        N_hat = np.random.choice(data.shape[0], size=Ns, replace=True)

        return data[N_hat][:, M_hat], labels[N_hat][:, L_hat]

    def sample_batch_pairs(self, Ns, Ls, Ms, positive=False, first_element=None, test=False):
        """Sample two sub-batches, from the same dataset or from differently named ones.

        Args:
            Ns, Ls, Ms: forwarded to :meth:`sample_batch` for both sub-batches.
            positive: if True, both sub-batches come from the same dataset.
                Otherwise the second dataset is drawn uniformly among those
                whose name differs from the first one's.
            first_element: index of the first dataset. Drawn uniformly if None.
            test: read the ``tst_*`` attributes instead of ``trn_*``.

        Returns:
            ``(X_1, Y_1, X_2, Y_2)`` as produced by :meth:`sample_batch`.

        Raises:
            ValueError: if ``positive`` is False and no dataset has a name
                different from the first dataset's.
        """
        if test:
            data, labels, names = self.tst_data, self.tst_labels, self.tst_names
        else:
            data, labels, names = self.trn_data, self.trn_labels, self.trn_names

        n_datasets = len(data)
        if first_element is None:
            first_element = np.random.choice(n_datasets)

        if positive:
            second_element = first_element
        else:
            first_name = names[first_element]
            # Uniform choice among valid candidates gives the same distribution
            # as rejection sampling, but cannot spin forever when none exist.
            candidates = [i for i in range(n_datasets) if names[i] != first_name]
            if not candidates:
                raise ValueError(
                    "cannot sample a negative pair: no dataset has a name "
                    "different from %r" % (first_name,))
            second_element = candidates[np.random.choice(len(candidates))]

        X_1, Y_1 = self.sample_batch(data=data[first_element], labels=labels[first_element],
                                     Ns=Ns, Ls=Ls, Ms=Ms)
        X_2, Y_2 = self.sample_batch(data=data[second_element], labels=labels[second_element],
                                     Ns=Ns, Ls=Ls, Ms=Ms)
        return X_1, Y_1, X_2, Y_2

    def get_batch(self, batch_size, Ns, Ls, Ms, test, stratification_pos_ratio=0.5):
        """Build a batch of sub-batch pairs, positive pairs first, then negative pairs.

        ``int(batch_size * stratification_pos_ratio)`` pairs are positive
        (same dataset). The remaining ``batch_size`` minus that many are
        negative (different names).

        Returns:
            ``(X_1, Y_1, X_2, Y_2, I)``, each stacked with ``np.array``.
            ``I[k]`` is 0 for a positive pair and 1 for a negative pair.
            Stacking needs every sub-batch to have the same shape, which
            holds when ``Ns``, ``Ls`` and ``Ms`` are all given.
        """
        num_pos = int(batch_size * stratification_pos_ratio)
        pair_kinds = [(True, 0)] * num_pos + [(False, 1)] * (batch_size - num_pos)

        list_X_1, list_Y_1, list_X_2, list_Y_2, I = [], [], [], [], []
        for positive, indicator in pair_kinds:
            X_1, Y_1, X_2, Y_2 = self.sample_batch_pairs(Ns=Ns, Ls=Ls, Ms=Ms,
                                                         positive=positive, test=test)
            list_X_1.append(X_1)
            list_Y_1.append(Y_1)
            list_X_2.append(X_2)
            # X_2 must be paired with the second sub-batch's own labels.
            list_Y_2.append(Y_2)
            I.append(indicator)

        return np.array(list_X_1), np.array(list_Y_1), np.array(list_X_2), np.array(list_Y_2), np.array(I)
