
import numpy as np
import torch
import torch.nn as nn



class AbstractModel(nn.Module):
    """Base class for the torch models in this module.

    Subclasses implement ``calculate_loss`` and ``predict``; this base provides
    parameter counting, adversarial input perturbation and an orthogonality-style
    regularizer over 1-D / 2-D parameters.
    """

    def __init__(self, config, dataset):
        super(AbstractModel, self).__init__()
        self.config = config
        self.dataset = dataset
        self.device = config['device']
        self.loss_type = config['loss_type']

    def calculate_loss(self, x, t, y, w):
        r"""Calculate the training loss for a batch of data.

        Args:
            x: Input tensor of the batch. It is the tensor that
                ``generate_perturbation`` differentiates with respect to.
            t, y, w: Remaining batch tensors. Presumably treatment, target and
                sample weight respectively — TODO confirm against subclasses.

        Returns:
            torch.Tensor: Scalar training loss, shape: []
        """
        raise NotImplementedError

    def predict(self, x, t):
        r"""Predict scores for a batch.

        Args:
            x: Input tensor of the batch.
            t: Second batch tensor (presumably treatment — TODO confirm).

        Returns:
            torch.Tensor: Predicted scores, shape: [batch_size]
        """
        raise NotImplementedError

    def __str__(self):
        """Return the module summary followed by the number of trainable parameters."""
        # numel() is an int even for 0-dim parameters; np.prod(()) would return 1.0
        # and turn the whole count into a float.
        n_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return super().__str__() + '\nTrainable parameters: {}'.format(n_trainable)

    def update_grad(self):
        """Invert ``requires_grad`` on every parameter (trainable <-> frozen).

        Note: this is a toggle, not a freeze. Parameters that were already frozen
        become trainable, so calling it twice restores the original state.
        """
        for param in self.parameters():
            if param.requires_grad:
                param.requires_grad_(False)
            else:
                param.requires_grad_(True)

    def generate_perturbation(self, x, t, y, w, perturbation_type):
        """Return a sign-gradient perturbation of ``x`` as a numpy array.

        The gradient of ``calculate_loss`` with respect to ``x`` is computed with all
        model parameters temporarily frozen. For ``perturbation_type == 'FGSM'`` the
        result is ``+grad_epsilon * sign(grad)``; for any other type it is
        ``-grad_epsilon * sign(grad)``.

        The requires_grad flags of ``x`` and of every parameter are restored
        afterwards, even if ``calculate_loss`` raises. Neither ``x.grad`` nor any
        ``param.grad`` is read or modified.
        """
        params = list(self.parameters())
        saved_param_flags = [p.requires_grad for p in params]
        saved_x_flag = x.requires_grad
        try:
            # Freeze every parameter so the graph only needs to track x.
            for p in params:
                p.requires_grad_(False)
            x.requires_grad_(True)
            loss = self.calculate_loss(x, t, y, w)
            # autograd.grad returns a fresh gradient instead of accumulating into a
            # possibly stale x.grad (which could flip the sign of the perturbation).
            (grad,) = torch.autograd.grad(loss, x, allow_unused=True)
        finally:
            for p, flag in zip(params, saved_param_flags):
                p.requires_grad_(flag)
            x.requires_grad_(saved_x_flag)

        if grad is None:
            # The loss does not depend on x: its gradient is identically zero.
            grad = torch.zeros_like(x)

        if perturbation_type in ['FGSM']:
            grad = self.config['grad_epsilon'] * torch.sign(grad)
        else:
            grad = -self.config['grad_epsilon'] * torch.sign(grad)

        return grad.cpu().numpy()

    def regularization_loss(self):
        """Sum a per-parameter regularizer over all 1-D and 2-D parameters.

        * 1-D parameter p: ``sum(p * p) - 1``.
        * 2-D parameter W: mean squared error between ``W.T @ W`` and the identity.
        * Parameters of any other rank are skipped.

        Returns:
            The summed torch scalar, or ``None`` if the model has no 1-D/2-D parameters.
        """
        regularization_mse = nn.MSELoss(reduction='mean')
        regular_term = None
        for param in self.parameters():
            if param.dim() == 1:
                cur_loss = torch.sum(param * param) - 1
            elif param.dim() == 2:
                # Build the identity directly on the parameter's device/dtype so it
                # matches even if the model was moved or cast after construction.
                identity = torch.eye(param.shape[1], device=param.device, dtype=param.dtype)
                cur_loss = regularization_mse(torch.matmul(param.T, param), identity)
            else:
                continue
            regular_term = cur_loss if regular_term is None else cur_loss + regular_term

        return regular_term


class SKAbstractModel:
    """Base class for models that are not torch ``nn.Module`` instances.

    Exposes the same attribute set (``config``, ``dataset``, ``device``) as the
    torch base, plus a no-op ``to`` so calling code can treat both uniformly.
    """

    def __init__(self, config, dataset):
        super().__init__()
        self.config = config
        self.dataset = dataset
        self.device = config['device']

    def calculate_loss(self, x, t, y, w):
        r"""Compute the training loss for one batch; must be overridden.

        Args:
            x: Batch inputs.
            t, y, w: Remaining batch values (presumably treatment, target and
                sample weight — TODO confirm against subclasses).

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def predict(self, x, t0, t1):
        r"""Produce predictions for one batch; must be overridden.

        Args:
            x: Batch inputs.
            t0, t1: Two additional batch values whose meaning is defined by
                subclasses (not determinable from this base class).

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def __str__(self):
        """Return a one-line summary listing the configured model name."""
        model_names = [self.config['model']]
        summary = '\nTrainable parameters: {}'.format(model_names)
        return summary

    def to(self, device):
        """Ignore ``device`` and return this object unchanged (allows chaining)."""
        return self