import pandas as pd
import numpy as np
from scipy import stats

from margflow.datasets.dataset_abstracts import DiscreteSamplesFromFileDataset


class GasDataset(DiscreteSamplesFromFileDataset):
    """Dataset whose samples are read from the pickled ``ethylene_CO`` table."""

    def __init__(self, args):
        """Initialise the parent dataset and append ``_gas`` to the suffix."""
        super().__init__(args)
        self.dataset_suffix += "_gas"

    def load_data(self):
        """Load the pickled table and prune highly correlated columns.

        The ``Meth``, ``Eth`` and ``Time`` columns are discarded first.
        Afterwards, as long as some column has a Pearson correlation above
        0.98 with a column other than itself, the first such column (in
        column order) is dropped and the correlations are recomputed.

        Returns:
            The remaining table as a NumPy array (``DataFrame.values``).
        """
        frame = pd.read_pickle(self.dataset_folder / "ethylene_CO.pickle")
        for unused_column in ("Meth", "Eth", "Time"):
            frame.drop(unused_column, axis=1, inplace=True)

        threshold = 0.98
        while True:
            # Per column: how many columns correlate above the threshold.
            # The diagonal (self-correlation) contributes one, so a count
            # greater than one signals at least one near-duplicate column.
            is_high = frame.corr() > threshold  # pearson correlation
            partner_counts = is_high.values.sum(1)
            redundant = np.flatnonzero(partner_counts > 1)
            if redundant.size == 0:
                break
            # Remove only the first offender, then re-evaluate from scratch.
            first_redundant = frame.columns[redundant[0]]
            frame.drop(first_redundant, axis=1, inplace=True)

        # NOTE: z-score outlier filtering (|z| < 2.5 on every numeric column)
        # was sketched here previously but is not applied.
        return frame.values
