from gcip.preparators.graph.TU import TUPreparator
from gcip.utils.constants import Cte
from gcip.utils.exceptions import WrongLoss
from gcip.utils.io import dict_to_cn


class PTCPreparator(TUPreparator):
    """Preparator for the ``PTC_FM`` graph-classification dataset.

    Thin specialization of :class:`TUPreparator` that fixes the dataset name
    to ``'PTC_FM'``, hard-codes the feature/label dimensions and restricts the
    loss to ``Cte.BCELOGITS``. Presumably ``PTC_FM`` is the TU benchmark
    collection entry of that name — TODO(review): confirm in TUPreparator.
    """

    def __init__(self,
                 **kwargs):
        """Create the preparator.

        Args:
            **kwargs: Forwarded unchanged to ``TUPreparator.__init__``
                together with ``name='PTC_FM'``; normally the dict returned
                by :meth:`params` (see :meth:`loader`).
        """
        # Set before the parent constructor runs; presumably so the attribute
        # exists if TUPreparator.__init__ reads it — TODO(review): confirm.
        self.dataset = None

        super().__init__(name='PTC_FM',
                         **kwargs)

    @classmethod
    def params(cls, dataset):
        """Build the keyword arguments used to construct this preparator.

        Args:
            dataset: Dataset configuration. A plain ``dict`` is first
                converted with ``dict_to_cn``; any other object is passed
                through as-is.

        Returns:
            dict: Constructor kwargs. PTC defines no parameters of its own,
            so this is a fresh dict holding the entries returned by
            ``TUPreparator.params``.
        """
        if isinstance(dataset, dict):
            dataset = dict_to_cn(dataset)

        # PTC-specific parameters would be added here; currently there are none.
        my_dict = {}
        my_dict.update(TUPreparator.params(dataset))

        return my_dict

    @classmethod
    def loader(cls, dataset):
        """Alternate constructor: build an instance from a dataset config.

        Uses ``cls.params`` rather than a hard-coded class name so that a
        subclass overriding :meth:`params` gets its own parameter set, which
        then matches the class being instantiated.

        Args:
            dataset: Dataset configuration, as accepted by :meth:`params`.

        Returns:
            A new instance of ``cls``.
        """
        my_dict = cls.params(dataset)

        return cls(**my_dict)

    def _x_dim(self):
        """Return the dimension of ``x`` (presumably node features); fixed at 18."""
        return 18

    def edge_attr_dim(self):
        """Return the edge-attribute dimension; fixed at 4."""
        return 4

    def label_dim(self):
        """Return the label dimension: one output for labels in {0, 1}."""
        return 1  # [0, 1]

    def _loss(self, loss):
        """Resolve the requested loss identifier for this dataset.

        Args:
            loss: Requested loss identifier (a ``Cte`` constant).

        Returns:
            ``Cte.BCELOGITS`` when ``loss`` is ``Cte.DEFAULT``; otherwise
            ``loss`` itself when it is ``Cte.BCELOGITS``.

        Raises:
            WrongLoss: If ``loss`` is any other value.
        """
        # Only Cte.BCELOGITS (presumably BCE-with-logits) is accepted, which
        # fits the single-output label_dim() above.
        if loss == Cte.DEFAULT:
            return Cte.BCELOGITS
        elif loss in [Cte.BCELOGITS]:
            return loss
        else:
            raise WrongLoss(loss)
