from gcip.preparators.graph.TU import TUPreparator
from gcip.utils.constants import Cte
from gcip.utils.exceptions import WrongLoss
from gcip.utils.io import dict_to_cn


class COX2reparator(TUPreparator):
    """Preparator for the COX2 graph dataset from the TU collection.

    Thin specialisation of ``TUPreparator``: it fixes the dataset name to
    ``'COX2'`` and supplies the ``x`` dimension, the label dimension, the
    accepted loss and the scaler configuration for this dataset.

    NOTE(review): the class name is missing the ``P`` of "Preparator"
    (compare ``TUPreparator``); it is kept unchanged because other code may
    reference it by this exact name.
    """

    def __init__(self,
                 **kwargs):
        """Build the preparator, forwarding all keyword arguments upward.

        ``name`` is always ``'COX2'``, so ``kwargs`` must not contain a
        ``name`` key of its own.
        """
        # Initialised before the parent constructor runs; presumably the
        # parent or later loading code populates it — confirm in TUPreparator.
        self.dataset = None

        super().__init__(name='COX2',
                         **kwargs)

    @classmethod
    def params(cls, dataset):
        """Return the constructor keyword arguments derived from ``dataset``.

        Args:
            dataset: dataset configuration. A plain ``dict`` is first
                converted with ``dict_to_cn``; anything else is passed
                straight to ``TUPreparator.params``.

        Returns:
            dict: keyword arguments for ``__init__``. COX2 defines no
            parameters of its own, so this is exactly the result of
            ``TUPreparator.params``.
        """
        if isinstance(dataset, dict):
            dataset = dict_to_cn(dataset)

        my_dict = {}  # COX2-specific parameters would be added here.
        my_dict.update(TUPreparator.params(dataset))
        return my_dict

    @classmethod
    def loader(cls, dataset):
        """Construct an instance of ``cls`` from a dataset configuration.

        ``cls.params`` is used (rather than naming this class explicitly)
        so that subclasses overriding ``params`` are honoured.
        """
        my_dict = cls.params(dataset)
        return cls(**my_dict)

    def _x_dim(self):
        """Return the width of ``x`` used by the scaler (38 for COX2).

        NOTE(review): presumably the number of features per node in COX2 —
        confirm against the dataset.
        """
        return 38

    def label_dim(self):
        """Return the label dimensionality: one binary target in [0, 1]."""
        return 1  # [0, 1]

    def _loss(self, loss):
        """Resolve the requested loss for this binary-label dataset.

        Args:
            loss: a ``Cte`` loss identifier, or ``Cte.DEFAULT``.

        Returns:
            ``Cte.BCELOGITS`` when ``loss`` is ``Cte.DEFAULT`` or
            ``Cte.BCELOGITS``.

        Raises:
            WrongLoss: for any other loss identifier.
        """
        if loss == Cte.DEFAULT:
            return Cte.BCELOGITS
        if loss == Cte.BCELOGITS:
            return loss
        raise WrongLoss(loss)

    def get_scaler_info(self):
        """Return the scaler specification selected by ``self.scale``.

        Returns:
            dict: ``{'x': (scaler_name, x_dim)}``, where ``scaler_name`` is
            ``'std'`` when ``self.scale == 'default'`` and ``'min0_max1'``
            when ``self.scale == 'min0_max1'``.

        Raises:
            NotImplementedError: for any other ``self.scale`` value.
        """
        if self.scale == 'default':
            scaler = 'std'
        elif self.scale == 'min0_max1':
            scaler = 'min0_max1'
        else:
            raise NotImplementedError(
                f"Unsupported scale {self.scale!r} for COX2; "
                "expected 'default' or 'min0_max1'")
        return {'x': (scaler, self._x_dim())}
