from tensorflow import keras
import numpy as np
import math
import pandas as pd



class BigDatasetLoader(keras.utils.Sequence):
    """Keras ``Sequence`` that reads mini-batches lazily from CSV files on disk.

    Batch ``idx`` is read from ``<folder_name>/train_data.csv`` and
    ``<folder_name>/train_labels.csv``. The first line of each file is always
    skipped (presumably a header row -- TODO confirm). The batch is then
    optionally standardized with statistics stored next to the data.

    Args:
        folder_name: Directory containing ``train_data.csv``,
            ``train_labels.csv`` and, when standardization is enabled,
            ``X_mean.csv``/``X_std.csv`` and/or ``y_mean.csv``/``y_std.csv``.
        batch_size: Number of rows per batch.
        shuffle: Accepted for API compatibility but currently ignored; batches
            are always served in file order.
        standardize_input: Standardize features as ``(x - X_mean) / X_std``.
        standardize_output: Standardize labels as ``(y - y_mean) / y_std``.
        info_path: CSV whose first row holds ``n_total`` (column 0) and
            ``chunks`` (column 1). Defaults to the location previously
            hard-coded here.
    """

    def __init__(self, folder_name, batch_size, shuffle=True, standardize_input=True,
                 standardize_output=True, info_path="../../data/big/info.csv"):
        # Keras' Sequence/PyDataset base initializes its own worker settings.
        super().__init__()
        self.folder_name = folder_name
        self.batch_size = batch_size
        # Counters kept for external callers; not read anywhere in this class.
        self.batch_count = 0
        self.chunk_count = 0
        self.standardize_input = standardize_input
        self.standardize_output = standardize_output

        info = np.array(pd.read_csv(info_path, header=None))
        self.n_total = info[0][0]
        self.chunks = info[0][1]

        if self.standardize_input:
            self.X_mean = self._read_matrix("X_mean.csv")
            self.X_std = self._read_matrix("X_std.csv")

        if self.standardize_output:
            # squeeze() turns a single-value file into a 0-d array (a scalar).
            self.y_mean = self._read_matrix("y_mean.csv").squeeze()
            self.y_std = self._read_matrix("y_std.csv").squeeze()
        else:
            self.y_mean = 0
            self.y_std = 1

        # Only the first batch is held in memory; it backs n_attributes and sample().
        self.x = self._read_matrix("train_data.csv", skiprows=1, nrows=self.batch_size)
        self.y = self._read_matrix("train_labels.csv", skiprows=1, nrows=self.batch_size)

    def _read_matrix(self, filename, **read_kwargs):
        """Read ``<folder_name>/<filename>`` (no header) into a numpy array."""
        path = self.folder_name + "/" + filename
        return np.array(pd.read_csv(path, header=None, **read_kwargs))

    @property
    def n_attributes(self):
        """Number of feature columns, taken from the in-memory first batch."""
        return self.x.shape[1]

    def __len__(self):
        """Number of batches per epoch; the final batch may be partial."""
        return math.ceil(self.n_total / self.batch_size)

    def __getitem__(self, idx):
        """Return the ``(batch_x, batch_y)`` numpy arrays for batch ``idx``."""
        # Skip the first line plus all rows of earlier batches. A CSV has no row
        # index, so reaching this offset means scanning every preceding line.
        offset = 1 + idx * self.batch_size
        batch_x = self._read_matrix("train_data.csv", skiprows=offset, nrows=self.batch_size)
        batch_y = self._read_matrix("train_labels.csv", skiprows=offset, nrows=self.batch_size)
        return self.standardize_x(batch_x), self.standardize_y(batch_y)

    def on_epoch_end(self):
        """No-op: shuffling is not implemented for this file-order loader."""
        pass

    def standardize_x(self, x):
        """Return ``x`` standardized with X_mean/X_std, or unchanged if disabled."""
        if self.standardize_input:
            return (x - self.X_mean) / self.X_std
        return x

    def standardize_y(self, y):
        """Return ``y`` standardized with y_mean/y_std, or unchanged if disabled."""
        if self.standardize_output:
            return (y - self.y_mean) / self.y_std
        return y

    def estimate_lengthscale(self, n_samples=1000):
        """Estimate a log length-scale with the median pairwise-distance heuristic.

        Args:
            n_samples: Number of rows drawn via ``sample`` for the estimate.

        Returns:
            Half the log of the median squared Euclidean distance over all
            distinct pairs of sampled (standardized) rows.
        """
        X_sample = self.sample(n_samples=n_samples)

        sq_norms = np.sum(X_sample ** 2, axis=1, keepdims=True)
        dist2 = sq_norms - 2.0 * np.dot(X_sample, X_sample.T) + sq_norms.T
        # The expansion can dip slightly below zero from round-off; clamp it.
        dist2 = np.maximum(dist2, 0.0)
        # Strict upper triangle: each unordered pair once, self-distances excluded.
        pairs = np.triu_indices(X_sample.shape[0], 1)

        return 0.5 * np.log(np.median(dist2[pairs]))

    def sample(self, n_samples=100):
        """Return ``n_samples`` randomly chosen (standardized) rows of ``self.x``.

        ``self.x`` only holds the first ``batch_size`` rows of
        ``train_data.csv``, so the sample is drawn from that batch, not the
        whole dataset. Rows are drawn without replacement when enough are
        available, which avoids duplicate rows that would add spurious zero
        distances in ``estimate_lengthscale``. Otherwise they are drawn with
        replacement.
        """
        n_rows = self.x.shape[0]
        rows = np.random.choice(n_rows, size=n_samples, replace=n_samples > n_rows)
        return self.standardize_x(self.x[rows, :])