import torch
import torch.nn as nn
import numpy as np


class ERM(nn.Module):
    """Plain empirical-risk-minimisation wrapper around a model and a loss.

    Args:
        clf: Model invoked as ``clf(data)`` to produce logits.
        criterion: Loss invoked as ``criterion(logits, labels.float())``.

    Attributes:
        device: Device of the first parameter found at construction time, or
            CPU when no parameter exists. It is a plain attribute set once in
            ``__init__`` and is not refreshed if the module is moved later.
    """

    def __init__(self, clf, criterion):
        super().__init__()
        self.clf = clf
        self.criterion = criterion
        # A parameter-free ``clf`` would make a bare ``next(...)`` raise
        # StopIteration; fall back to CPU instead of failing construction.
        first_param = next(self.parameters(), None)
        self.device = torch.device('cpu') if first_param is None else first_param.device

    def __loss__(self, clf_logits, clf_labels):
        """Compute the prediction loss.

        Labels are reshaped to the logits' shape when the two tensors differ
        in rank, and are cast to float before being passed to the criterion.

        Returns:
            A ``(loss, loss_dict)`` pair: the loss as returned by the
            criterion, and a dict holding its ``.item()`` value under both
            ``'loss'`` and ``'pred'``.
        """
        if len(clf_logits.shape) != len(clf_labels.shape):
            clf_labels = clf_labels.reshape(clf_logits.shape)
        pred_loss = self.criterion(clf_logits, clf_labels.float())
        loss_value = pred_loss.item()
        return pred_loss, {'loss': loss_value, 'pred': loss_value}

    def forward_pass(self, data, epoch, phase):
        """Run ``clf`` on ``data`` and score the logits against ``data.y``.

        ``epoch`` and ``phase`` are accepted but not used by this method.

        Returns:
            ``(loss, loss_dict, clf_logits)``.
        """
        clf_logits = self.clf(data)
        loss, loss_dict = self.__loss__(clf_logits, data.y)
        return loss, loss_dict, clf_logits


class ERM_L2(nn.Module):
    """ERM with a norm penalty on the model parameters.

    The penalty is the sum of ``torch.norm(param)`` over the parameters
    handed to ``__loss__``, scaled by ``reg`` and added to the prediction loss.

    Args:
        clf: Model invoked as ``clf(data)`` to produce logits.
        criterion: Loss invoked as ``criterion(logits, labels.float())``.
        reg: Multiplier applied to the summed parameter norms.

    Attributes:
        device: Device of the first parameter found at construction time, or
            CPU when no parameter exists. It is a plain attribute set once in
            ``__init__`` and is not refreshed if the module is moved later.
    """

    def __init__(self, clf, criterion, reg):
        super().__init__()
        self.clf = clf
        self.criterion = criterion
        # A parameter-free ``clf`` would make a bare ``next(...)`` raise
        # StopIteration; fall back to CPU instead of failing construction.
        first_param = next(self.parameters(), None)
        self.device = torch.device('cpu') if first_param is None else first_param.device
        self.reg = reg

    def __loss__(self, clf_logits, clf_labels, params):
        """Compute prediction loss plus the scaled parameter-norm penalty.

        Args:
            clf_logits: Model outputs passed to the criterion.
            clf_labels: Targets; cast to float before the criterion is called.
            params: Iterable of tensors whose norms are summed.

        Returns:
            ``(all_loss, loss_dict)`` where ``loss_dict`` reports the total,
            the prediction loss and the unscaled penalty as Python numbers.
        """
        pred_loss = self.criterion(clf_logits, clf_labels.float())

        reg_loss = 0
        for param in params:
            reg_loss = reg_loss + torch.norm(param)

        all_loss = pred_loss + self.reg * reg_loss
        # ``reg_loss`` stays the int 0 when ``params`` is empty, so use float()
        # (valid for both an int and a 0-dim tensor) rather than ``.item()``.
        return all_loss, {'loss': all_loss.item(), 'pred': pred_loss.item(), 'L^2_reg': float(reg_loss)}

    def forward_pass(self, data, epoch, phase):
        """Run ``clf`` on ``data``, reshape logits to ``data.y``'s shape, and score.

        ``epoch`` and ``phase`` are accepted but not used by this method.

        Returns:
            ``(loss, loss_dict, clf_logits)``.
        """
        clf_logits = self.clf(data).reshape(data.y.shape)
        loss, loss_dict = self.__loss__(clf_logits, data.y, list(self.clf.parameters()))
        return loss, loss_dict, clf_logits


class ERM_L2_SP(ERM_L2):
    """ERM with a penalty on the distance from the starting-point weights.

    The penalty is the sum of ``torch.norm(param - start)`` over paired
    current and starting parameters, scaled by ``reg``.

    The starting point is a snapshot of ``clf``'s parameters taken when this
    wrapper is constructed, so any pretrained weights must already be loaded
    into ``clf`` before it is wrapped.

    Args:
        clf: Model invoked as ``clf(data)`` to produce logits.
        criterion: Loss invoked as ``criterion(logits, labels.float())``.
        reg: Multiplier applied to the summed distance norms.
    """

    def __init__(self, clf, criterion, reg):
        super().__init__(clf, criterion, reg)
        # Store detached copies: keeping the live Parameter objects would make
        # ``param - sp`` identically zero, because both names would refer to
        # the same tensors that the optimizer updates in place.
        self.starting_point = [p.detach().clone() for p in self.clf.parameters()]

    def __loss__(self, clf_logits, clf_labels, params):
        """Compute prediction loss plus the scaled distance-to-start penalty.

        Args:
            clf_logits: Model outputs passed to the criterion.
            clf_labels: Targets; cast to float before the criterion is called.
            params: Iterable of current parameters, paired positionally with
                the stored starting point via ``zip``.

        Returns:
            ``(all_loss, loss_dict)`` where ``loss_dict`` reports the total,
            the prediction loss and the unscaled penalty as Python numbers.
        """
        pred_loss = self.criterion(clf_logits, clf_labels.float())

        reg_loss = 0
        for param, sp in zip(params, self.starting_point):
            # The snapshot is a plain list, so it does not follow the module
            # through ``.to(...)``; align it with the live parameter here.
            anchor = sp.to(device=param.device, dtype=param.dtype)
            reg_loss = reg_loss + torch.norm(param - anchor)

        all_loss = pred_loss + self.reg * reg_loss
        # ``reg_loss`` stays the int 0 when nothing was paired, so use float()
        # (valid for both an int and a 0-dim tensor) rather than ``.item()``.
        return all_loss, {'loss': all_loss.item(), 'pred': pred_loss.item(), 'L^2_SP_reg': float(reg_loss)}

