import os  # type: ignore
from typing import Tuple

import numpy as np  # type: ignore
import torch

from data.tabular import TabularDataset

from torch import Tensor as T  # short alias for torch.Tensor used in the type hints below


class HiggsDataset(TabularDataset):
    """
    HIGGS binary-classification dataset held in memory as float32 tensors.

    The backing file is a 2-D array whose column 0 is the label and whose
    remaining columns are the features.

    NOTE:
     - this dataset was used in this paper: https://www.nature.com/articles/ncomms5308
     - UCI repository: https://archive.ics.uci.edu/ml/datasets/HIGGS

    Args:
        path: root data directory; the array is read from
            ``<path>/higgs/higgs-<split>.npy``.
        test: use the ``test`` split when True, the ``train`` split when False.
        tiny: use the ``tiny`` split instead; takes precedence over ``test``.
    """

    def __init__(self, path: str, test: bool = True, tiny: bool = False):
        if tiny:
            file_suffix = "tiny"
        else:
            file_suffix = "test" if test else "train"

        file_path = os.path.join(path, "higgs", f"higgs-{file_suffix}.npy")
        dataset = self._load_array(file_path)
        # Column 0 holds the label, the remaining columns are features.
        self.x = torch.from_numpy(dataset[:, 1:]).float()
        self.y = torch.from_numpy(dataset[:, 0]).float()
        self.name = "higgs"

    @staticmethod
    def _load_array(file_path: str) -> np.ndarray:
        """Return the 2-D array stored at ``file_path``.

        The ``.npy`` extension denotes NumPy's binary format, which
        ``np.loadtxt`` cannot parse, so ``np.load`` is tried first. With
        ``allow_pickle=False``, ``np.load`` raises ``ValueError`` for content
        that is not an ``.npy``/``.npz`` file; in that case the file is read
        as whitespace-delimited text, which is what this loader always did.
        """
        try:
            return np.load(file_path, allow_pickle=False)
        except ValueError:
            # ndmin=2 keeps a single-row file shaped (1, n_cols) so the
            # column slicing in __init__ still works.
            return np.loadtxt(file_path, ndmin=2)

    def __getitem__(self, i: int) -> Tuple[T, T]:
        """Return the ``(features, label)`` pair at index ``i``."""
        return self.x[i], self.y[i]

    def __len__(self) -> int:
        """Return the number of rows (samples) in the dataset."""
        return self.x.size(0)

    def __str__(self) -> str:
        """Summarise the fraction of labels equal to 1 and to 0."""
        # .item() yields plain floats rather than tensor reprs; the double
        # precision mean keeps short, exact-looking ratios. An empty dataset
        # gives nan, as the original tensor division did.
        positive = (self.y == 1).double().mean().item()
        negative = (self.y == 0).double().mean().item()
        return f"positive: {positive} negative: {negative}"
