from collections import Counter

import numpy as np
import pandas as pd
from scipy import stats

from margflow.datasets.dataset_abstracts import DiscreteSamplesFromFileDataset


class MinibooneDataset(DiscreteSamplesFromFileDataset):
    """Dataset of MiniBooNE samples read from a pre-saved ``data.npy`` file.

    Construction is delegated to ``DiscreteSamplesFromFileDataset``; this
    subclass only tags ``dataset_suffix`` with ``"_bon"`` and supplies the
    file-loading step.
    """

    def __init__(self, args):
        """Initialise the parent dataset, then append ``"_bon"`` to the suffix.

        Args:
            args: Configuration object forwarded unchanged to the parent
                constructor.
        """
        super().__init__(args)
        self.dataset_suffix += "_bon"

    def load_data(self):
        """Load and return the contents of ``data.npy``.

        The file is looked up inside ``self.dataset_folder``, which must
        support the ``/`` operator (presumably a ``pathlib.Path`` set by the
        parent class -- confirm there).

        Returns:
            Whatever ``np.load`` yields for ``data.npy``; no filtering or
            other post-processing is applied here.
        """
        # Disabled snippet, kept for reference only (never executed).
        # NOTE(review): this looks like the preprocessing of the raw
        # "MiniBooNE_PID.txt" file that may have produced data.npy -- confirm.
        #
        #   data = np.loadtxt(self.dataset_folder + "/MiniBooNE_PID.txt", skiprows=1)
        #   indices = (data[:, 0] < -100)
        #   data = data[~indices]
        #
        #   remove = []
        #   for i, feature in enumerate(data.T):
        #       c = Counter(feature)
        #       max_count = np.array([v for k, v in sorted(c.items())])[0]
        #       if max_count > 5:
        #           remove.append(i)
        #   data = data[:, np.array([i for i in range(data.shape[1]) if i not in remove])]

        samples_path = self.dataset_folder / "data.npy"
        samples = np.load(samples_path)

        # Disabled z-score outlier filter (n_std = 3), also kept for
        # reference only; its result was never returned.
        #
        #   n_std = 3
        #   data_df = pd.DataFrame(data)
        #   z_scores = np.abs(stats.zscore(data_df.select_dtypes(include=[np.number])))
        #   filtered_df = data_df[(z_scores < n_std).all(axis=1)]

        return samples
