from collections import Counter

import numpy as np
import pandas as pd
from scipy import stats

from margflow.datasets.dataset_abstracts import DiscreteSamplesFromFileDataset


class HepmassDataset(DiscreteSamplesFromFileDataset):
    """Dataset assembled from the ``1000_train.csv`` and ``1000_test.csv`` files.

    After the base-class initialiser has run, ``"_hep"`` is appended to
    ``self.dataset_suffix``.
    """

    def __init__(self, args):
        super().__init__(args)
        self.dataset_suffix += "_hep"

    def load_data(self):
        """Read both CSV splits and return the selected rows as a single array.

        Both files are read from ``self.dataset_folder``. For each split, only
        the rows whose first column equals 1 are kept, and that first column is
        then removed. The test split additionally loses its last column. The
        remaining train rows and test rows are stacked along axis 0.

        Returns:
            The NumPy array holding the train rows followed by the test rows.
        """

        def keep_rows_flagged_one(frame):
            # The first column acts as the row selector and is not returned.
            # NOTE(review): presumably this is the class label (1 = signal) --
            # confirm against the dataset description.
            flag_column = frame.columns[0]
            selected = frame[frame[flag_column] == 1]
            return selected.drop(columns=flag_column)

        folder = self.dataset_folder
        train_frame = pd.read_csv(folder / "1000_train.csv", index_col=False)
        test_frame = pd.read_csv(folder / "1000_test.csv", index_col=False)

        train_frame = keep_rows_flagged_one(train_frame)
        test_frame = keep_rows_flagged_one(test_frame)

        # Only the test split drops its trailing column.
        # NOTE(review): presumably the test CSV carries one extra column that
        # the train CSV lacks -- TODO confirm.
        test_frame = test_frame.drop(columns=test_frame.columns[-1])

        # NOTE: an outlier filter (drop rows with any |z-score| >= 3) was
        # sketched here previously but is disabled; rows are returned unfiltered.
        return np.concatenate((train_frame.values, test_frame.values), axis=0)
