
import numpy as np
import torch
import torch.nn as nn



class AbstractModel(nn.Module):
    r"""Base class for torch-based models.

    Stores the run configuration, the dataset and the target device, and
    declares the two hooks every concrete model has to provide:
    :meth:`calculate_loss` and :meth:`predict`. Both raise
    ``NotImplementedError`` here.
    """

    def __init__(self, config, dataset):
        r"""Keep references to the configuration and dataset.

        Args:
            config: Mapping-like object; ``config['device']`` is read and
                stored as ``self.device``.
            dataset: Dataset object, stored as-is on ``self.dataset``.
        """
        super().__init__()
        self.config = config
        self.dataset = dataset
        self.device = config['device']

    def calculate_loss(self, x, t, y, w):
        r"""Calculate the training loss for a batch of data.

        Must be overridden by subclasses.

        Args:
            x: Batch input (presumably the feature tensor -- confirm
                against the trainer that calls this).
            t: Second batch input (presumably a treatment indicator --
                TODO confirm).
            y: Third batch input (presumably the target/outcome --
                TODO confirm).
            w: Fourth batch input (presumably per-sample weights --
                TODO confirm).

        Returns:
            The training loss; the original contract states a
            ``torch.Tensor`` of shape ``[]``.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def predict(self, x, t):
        r"""Predict outputs for a batch.

        Must be overridden by subclasses.

        Args:
            x: Batch input (presumably the feature tensor -- confirm).
            t: Second batch input (presumably a treatment indicator --
                TODO confirm).

        Returns:
            The predictions; the original contract states a
            ``torch.Tensor`` of shape ``[batch_size]``.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def __str__(self):
        """Return the module description followed by the trainable-parameter count."""
        # ``numel()`` always yields an int; ``np.prod`` of a 0-dim parameter's
        # empty shape is the float 1.0, which turned the printed total into a
        # float (e.g. "7.0") whenever the model owned a scalar parameter.
        trainable_count = sum(
            p.numel() for p in self.parameters() if p.requires_grad
        )
        return super().__str__() + '\nTrainable parameters: {}'.format(trainable_count)


class SKAbstractModel:
    r"""Base class for models that are not ``nn.Module`` subclasses.

    Exposes the same attributes (``config``, ``dataset``, ``device``) and
    hook methods as the torch base class, plus a no-op :meth:`to`, so that
    calling code can treat both kinds of model alike (presumably wrapping
    scikit-learn style estimators, judging by the name -- confirm).
    """

    def __init__(self, config, dataset):
        r"""Keep references to the configuration and dataset.

        Args:
            config: Mapping-like object; ``config['device']`` is read here
                and ``config['model']`` is read by :meth:`__str__`.
            dataset: Dataset object, stored as-is on ``self.dataset``.
        """
        super().__init__()
        self.dataset = dataset
        self.config = config
        self.device = self.config['device']

    def calculate_loss(self, x, t, y, w):
        r"""Calculate the training loss for a batch of data.

        Must be overridden by subclasses.

        Args:
            x: Batch input (presumably features -- confirm).
            t: Second batch input (presumably treatment -- TODO confirm).
            y: Third batch input (presumably outcome -- TODO confirm).
            w: Fourth batch input (presumably weights -- TODO confirm).

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def predict(self, x, t0, t1):
        r"""Predict outputs for a batch.

        Must be overridden by subclasses.

        Args:
            x: Batch input (presumably features -- confirm).
            t0: Second input (presumably the first treatment setting --
                TODO confirm).
            t1: Third input (presumably the second treatment setting --
                TODO confirm).

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def __str__(self):
        """Return a one-line description built from ``config['model']``.

        Note: the label says "Trainable parameters", but the value printed
        is a single-element list holding the configured model name.
        """
        template = '\nTrainable parameters: {}'
        described = [self.config['model']]
        return template.format(described)

    def to(self, device):
        """Accept a device and return ``self`` unchanged.

        The ``device`` argument is ignored; nothing is moved.
        """
        return self