import numpy as np

from margflow.datasets.dataset_abstracts import DiscreteSamplesFromFileDataset


class PowerDataset(DiscreteSamplesFromFileDataset):
    """Samples read from ``data.npy`` with two of the raw columns removed.

    The dataset suffix inherited from the base class is extended with
    ``"_pwr"`` so that anything keyed on the suffix is specific to this
    dataset.
    """

    # Column indices (in the raw array layout) that are discarded on load.
    _DROPPED_COLUMNS = [1, 3]

    def __init__(self, args):
        """Initialise the base dataset and tag the suffix with ``_pwr``.

        Args:
            args: Configuration object forwarded unchanged to the base class.
        """
        super().__init__(args)
        self.dataset_suffix += "_pwr"

    def load_data(self):
        """Load the raw sample matrix and drop the unused columns.

        Returns:
            The array stored in ``<dataset_folder>/data.npy`` with raw
            columns 1 and 3 removed along axis 1. No noise is added.
        """
        raw = np.load(self.dataset_folder / "data.npy")

        # Removing both indices in one call is the same as deleting column 3
        # and then column 1: taking the higher index out first never shifts
        # the lower one.
        samples = np.delete(raw, self._DROPPED_COLUMNS, axis=1)

        # NOTE: an additive-noise step used to be sketched here and is
        # currently disabled. It drew from RandomState(42) uniform noise of
        # scale 0.001 (gap, 1 column), 0.01 (voltage, 1 column), 1.0
        # (sub-metering, 3 columns) and zeros (time, 1 column), stacked in
        # that order -- presumably matching the remaining column layout;
        # confirm before re-enabling.
        return samples
