from gcip.preparators.graph.TU import TUPreparator
from gcip.utils.constants import Cte
from gcip.utils.exceptions import WrongLoss
from gcip.utils.io import dict_to_cn


class ENZYMESPreparator(TUPreparator):
    """Preparator for the ENZYMES dataset of the TU collection.

    Thin specialisation of ``TUPreparator``: it fixes the dataset name to
    ``'ENZYMES'``, declares the node-feature and label dimensions, the
    supported input scalers, and the accepted losses (cross-entropy only).
    """

    def __init__(self,
                 **kwargs):
        """Create the preparator; all keyword arguments go to ``TUPreparator``."""
        # Set before the parent constructor runs, presumably so that code in
        # TUPreparator.__init__ can rely on the attribute existing — confirm.
        self.dataset = None

        super().__init__(name='ENZYMES',
                         **kwargs)

    @classmethod
    def params(cls, dataset):
        """Build the constructor keyword arguments from a dataset config.

        Args:
            dataset: config object, or a plain ``dict`` which is first
                converted with ``dict_to_cn``.

        Returns:
            dict of keyword arguments accepted by ``__init__``. ENZYMES adds
            no parameters of its own; everything comes from
            ``TUPreparator.params``.
        """
        if isinstance(dataset, dict):
            dataset = dict_to_cn(dataset)

        my_dict = {}  # placeholder for ENZYMES-specific parameters (none yet)
        my_dict.update(TUPreparator.params(dataset))

        return my_dict

    @classmethod
    def loader(cls, dataset):
        """Alternate constructor: build an instance from a dataset config.

        Uses ``cls.params`` (not a hard-coded class name) so that subclasses
        overriding ``params`` are honoured, consistent with ``cls(...)`` below.
        """
        my_dict = cls.params(dataset)

        return cls(**my_dict)

    def _x_dim(self):
        """Return the node-feature dimension.

        NOTE(review): 21 is presumably the continuous node attributes plus the
        one-hot node labels as loaded by TUPreparator — confirm.
        """
        return 21

    def get_scaler_info(self):
        """Return the scaler spec for the node features.

        Returns:
            ``{'x': (scaler_name, x_dim)}`` where ``scaler_name`` is ``'std'``
            for ``scale`` in ``('default', 'std')`` and ``'min0_max1'`` for
            ``scale == 'min0_max1'``.

        Raises:
            NotImplementedError: for any other value of ``self.scale``.
        """
        if self.scale in ('default', 'std'):
            scaler = 'std'
        elif self.scale == 'min0_max1':
            scaler = 'min0_max1'
        else:
            raise NotImplementedError(
                f"Unsupported scale for ENZYMES: {self.scale!r}")
        return {'x': (scaler, self._x_dim())}

    def label_dim(self):
        """Return the label dimension: 6 target classes.

        NOTE(review): presumably class indices 0..5 — confirm against the
        loaded targets.
        """
        return 6

    def _loss(self, loss):
        """Resolve the requested loss identifier.

        ``Cte.DEFAULT`` maps to ``Cte.CE``; ``Cte.CE`` is returned unchanged.

        Raises:
            WrongLoss: for any other loss identifier.
        """
        if loss == Cte.DEFAULT:
            return Cte.CE
        elif loss in (Cte.CE,):
            return loss
        else:
            raise WrongLoss(loss)
