import pathlib
import os.path as osp
import torch as th

from .ps import Base2DSet
from .cox_utils import Cox

class CoxDist(Base2DSet):
    """Dataset wrapper around a Cox-process target built from ``df_pines.csv``.

    The target density comes from :class:`Cox`. :meth:`get_gt_disc` exposes
    its negated log-density.
    """

    def __init__(self, len_data, dim, is_linear=True, device="cuda"):
        """Build the Cox target and the placeholder data tensor.

        Args:
            len_data: Forwarded to ``Base2DSet.__init__``.
            dim: Length of ``self.data`` and value stored in
                ``self.data_ndim``. Presumably this must match the
                dimensionality of the Cox target (TODO confirm against ``Cox``).
            is_linear: Forwarded to ``Base2DSet.__init__``.
            device: Device on which ``self.data`` is allocated. The default
                ``"cuda"`` reproduces the previous hard-coded ``.cuda()``
                behaviour. Pass ``"cpu"`` to run without a GPU.
        """
        # Resolve the CSV relative to this module so loading does not depend
        # on the current working directory.
        fcsv = osp.join(pathlib.Path(__file__).parent.resolve(), "df_pines.csv")
        # NOTE(review): 40 is passed positionally to Cox. It is presumably the
        # grid resolution per axis; confirm against cox_utils.Cox.
        self.cox = Cox(fcsv, 40, use_whitened=False)

        super().__init__(len_data, is_linear)
        # Allocate directly on the requested device instead of building on
        # CPU and copying. float64 is what the original ``dtype=float`` meant.
        self.data = th.ones(dim, dtype=th.float64, device=device)  # pylint: disable= not-callable
        self.data_ndim = dim

    def get_gt_disc(self, x):
        """Return the negated log-density of the Cox target evaluated at ``x``."""
        return -self.cox.evaluate_log_density(x)