#!/usr/bin/env python
# -*- coding: utf-8 -*-

""" read datasets from existing files"""

import numpy as np
from sklearn.preprocessing import StandardScaler


class DataGenerator(object):
    """Load a dataset and its ground-truth graphs from files and sample from it.

    Two data matrices are read with ``np.load``: an "index" variant and a
    plain variant.  Both must be 2-D with the same number of rows, and the
    index variant must have exactly one more column than the plain one.
    NOTE(review): presumably the extra column is an index/time variable --
    confirm against the code that produces these files.

    Attributes set by the constructor:
        inputdata_index, inputdata: the (optionally standardized) data.
        datasize_index, d_index:    shape of ``inputdata_index``.
        datasize, d:                shape of ``inputdata``.
        dag_index_true, dag_true:   binarized ground-truth adjacency
                                    matrices (``int32``, entries 0/1).
    """

    def __init__(self, data_index_path, data_path, dag_index_path, dag_path,
                 normalize_flag=False, transpose_flag=False):
        """Read the data and graph files and prepare them for sampling.

        Args:
            data_index_path: file holding the data matrix with the extra column.
            data_path: file holding the plain data matrix.
            dag_index_path: file holding the graph matching ``data_index_path``.
            dag_path: file holding the graph matching ``data_path``.
            normalize_flag: if true, standardize both data matrices
                column-wise with ``StandardScaler``.
            transpose_flag: if true, transpose both graph matrices before
                binarizing them.

        Raises:
            ValueError: if the two data matrices disagree on the number of
                rows, or the index matrix does not have exactly one more
                column than the plain matrix.
        """
        self.inputdata_index = np.load(data_index_path)
        self.inputdata = np.load(data_path)
        self.datasize_index, self.d_index = self.inputdata_index.shape
        self.datasize, self.d = self.inputdata.shape

        # Explicit checks instead of ``assert``: assertions are removed when
        # Python runs with -O, which would let inconsistent files through.
        if self.datasize != self.datasize_index:
            raise ValueError(
                "row count mismatch: {!r} has {} rows but {!r} has {}".format(
                    data_path, self.datasize,
                    data_index_path, self.datasize_index))
        if self.d != self.d_index - 1:
            raise ValueError(
                "column count mismatch: {!r} has {} columns, expected {} "
                "(one fewer than the {} columns of {!r})".format(
                    data_path, self.d, self.d_index - 1,
                    self.d_index, data_index_path))

        if normalize_flag:
            # Each matrix is standardized independently, column by column.
            self.inputdata_index = StandardScaler().fit_transform(self.inputdata_index)
            self.inputdata = StandardScaler().fit_transform(self.inputdata)

        dag_index_true = np.load(dag_index_path)
        dag_true = np.load(dag_path)
        if transpose_flag:
            dag_index_true = np.transpose(dag_index_true)
            dag_true = np.transpose(dag_true)

        # (i,j)=1 => node i -> node j
        # Entries whose magnitude exceeds 1e-3 count as edges; the rest are 0.
        self.dag_true = np.int32(np.abs(dag_true) > 1e-3)
        self.dag_index_true = np.int32(np.abs(dag_index_true) > 1e-3)

    def gen_instance_graph(self, dimension, with_index=True):
        """Draw ``dimension`` random rows (with replacement) and transpose them.

        Args:
            dimension: number of rows to sample.
            with_index: sample from ``inputdata_index`` when true, otherwise
                from ``inputdata``.

        Returns:
            Array with one row per data column and one column per sampled
            row, i.e. shape ``(d_index, dimension)`` or ``(d, dimension)``.
        """
        source = self.inputdata_index if with_index else self.inputdata
        # Row indices are drawn uniformly from [0, datasize) via the global
        # NumPy RNG, so results are reproducible with ``np.random.seed``.
        rows = np.random.randint(self.datasize, size=dimension)
        return source[rows].T

    # Generate random batch for training procedure
    def train_batch(self, batch_size, dimension, with_index=True):
        """Return a list of ``batch_size`` independent random samples.

        Each element is the result of one ``gen_instance_graph(dimension,
        with_index)`` call.
        """
        return [
            self.gen_instance_graph(dimension, with_index)
            for _ in range(batch_size)
        ]
