from gcip.preparators.graph.TU import TUPreparator
from gcip.utils.constants import Cte
from gcip.utils.exceptions import WrongLoss
from gcip.utils.io import dict_to_cn


class PROTEINSPreparator(TUPreparator):
    """Preparator for the PROTEINS graph dataset.

    Specializes ``TUPreparator`` by:
    * fixing the dataset name to ``'PROTEINS'``;
    * declaring a node-feature dimensionality of 3 (``_x_dim``) and a single
      output label (``label_dim``);
    * accepting only the BCE-with-logits loss (``_loss``);
    * providing the per-graph transform and the node-feature scaler config.
    """

    def __init__(self,
                 **kwargs):
        # Set before the parent constructor runs. NOTE(review): presumably
        # populated later by TUPreparator — confirm.
        self.dataset = None

        super().__init__(name='PROTEINS',
                         **kwargs)

    @classmethod
    def params(cls, dataset):
        """Collect constructor keyword arguments from a dataset config.

        Args:
            dataset: Dataset configuration. A plain ``dict`` is first converted
                with ``dict_to_cn``; anything else is passed through as-is.

        Returns:
            dict: Constructor keyword arguments. PROTEINS contributes no
            parameters of its own, so this is a copy of
            ``TUPreparator.params(dataset)``.
        """
        if isinstance(dataset, dict):
            dataset = dict_to_cn(dataset)

        return dict(TUPreparator.params(dataset))

    @classmethod
    def loader(cls, dataset):
        """Build an instance from a dataset config.

        Uses ``cls.params`` rather than naming this class explicitly, so a
        subclass that overrides ``params`` gets its own arguments when loaded.
        """
        return cls(**cls.params(dataset))

    def _x_dim(self):
        """Node-feature dimensionality used for the ``'x'`` scaler."""
        return 3

    def label_dim(self):
        """Output dimensionality: one value for the binary label."""
        return 1  # [0, 1]

    def _loss(self, loss):
        """Resolve the requested loss identifier.

        Args:
            loss: ``Cte.DEFAULT`` or ``Cte.BCELOGITS``.

        Returns:
            ``Cte.BCELOGITS``, the only loss supported for this dataset
            (``Cte.DEFAULT`` maps to it).

        Raises:
            WrongLoss: If any other loss is requested.
        """
        if loss == Cte.DEFAULT:
            return Cte.BCELOGITS
        if loss == Cte.BCELOGITS:
            return loss
        raise WrongLoss(loss)

    def get_transform_fn(self):
        """Return the transform applied to each graph ``data`` object.

        The returned callable sets attributes on ``data`` in place and returns
        the same object. It:
        * adds a trailing size-1 axis to ``data.y`` (``unsqueeze(-1)``);
        * drops the first column of ``data.x``.
        Assumes ``data.y`` / ``data.x`` are tensors (``unsqueeze`` and 2-D
        slicing are used) — TODO confirm.
        """
        def transform(data):
            data.y = data.y.unsqueeze(-1)
            # NOTE(review): presumably removes a leading node attribute so the
            # remaining columns match _x_dim() — confirm against the raw data.
            data.x = data.x[:, 1:]
            return data
        return transform

    def get_scaler_info(self):
        """Describe how node features are scaled for ``self.scale``.

        Returns:
            dict: ``{'x': (scaler_name, dim)}`` where ``dim`` is ``_x_dim()``.
            ``'default'`` maps to the ``'identity'`` scaler; ``'std'`` and
            ``'min0_max1'`` name the scaler directly.

        Raises:
            NotImplementedError: If ``self.scale`` is not one of
                ``'default'``, ``'std'`` or ``'min0_max1'``.
        """
        if self.scale == 'default':
            scaler = 'identity'
        elif self.scale == 'std':
            scaler = 'std'
        elif self.scale == 'min0_max1':
            scaler = 'min0_max1'
        else:
            raise NotImplementedError(
                f"Unsupported scale {self.scale!r} for PROTEINS; expected "
                f"'default', 'std' or 'min0_max1'.")
        # Every branch shares _x_dim() so the dimension is defined in one place.
        return {
            'x': (scaler, self._x_dim())
        }
