import torch
import torch.autograd.functional as AF
import foolbox

from jacobian import JacobianReg

from convexrobust.utils import torch_utils

from convexrobust.model.base_certifiable import BaseCertifiable, Certificate, Norm


class ConvexCertifiable(BaseCertifiable):
    """Single-logit binary classifier: a feature map followed by a convex head.

    The model is ``convex_forward(lipschitz_forward(x))``.
    A positive logit is read as "class 1" (label 0). A non-positive logit is
    read as "class 2" (label 1). Only label-0 predictions receive a non-zero
    certificate in ``certify``.

    Subclasses must implement ``convex_forward``. They may override
    ``lipschitz_forward``, whose default is the identity with Lipschitz
    constant 1 in every norm.
    """

    def __init__(self, reg=0.0, **kwargs):
        """Create the model.

        Args:
            reg: Weight of the Jacobian-regularization penalty returned by
                ``extra_loss``. ``0.0`` disables it.
            **kwargs: Forwarded to ``BaseCertifiable``. ``single_logit=True``
                is always passed.

        Raises:
            ValueError: If ``reg`` is negative or NaN.
        """
        # A negative/NaN weight used to skip creating ``jac_reg`` below while
        # still entering the regularization path in ``extra_loss``. That
        # failed much later with an AttributeError, so reject it up front.
        if not reg >= 0.0:
            raise ValueError(f"reg must be a non-negative number, got {reg!r}")

        super().__init__(single_logit=True, **kwargs)

        self.reg = reg
        if self.reg > 0.0:
            self.jac_reg = JacobianReg()

    def extra_loss(self, signal, target):
        """Return the Jacobian-regularization penalty for a batch.

        The penalty is ``self.reg * self.jac_reg(z, convex_forward(z))``,
        where ``z = lipschitz_forward(signal)[0]``. It therefore penalizes the
        Jacobian of the convex head with respect to its input features.

        Args:
            signal: Input batch.
            target: Unused.

        Returns:
            A scalar tensor. When regularization is disabled (``reg == 0``),
            this is a zero tensor on ``signal``'s device.
        """
        if self.reg == 0.0:
            return torch.zeros((), device=signal.device)

        with torch.enable_grad():
            features, _ = self.lipschitz_forward(signal)
            if not features.requires_grad:
                # Detach rather than flipping ``requires_grad`` in place.
                # With the default identity feature map, ``features`` *is*
                # the caller's input tensor, which must not be mutated.
                features = features.detach().requires_grad_(True)
            # Otherwise ``features`` is already in the autograd graph, e.g. a
            # non-leaf from a learned feature map. Setting ``requires_grad``
            # on a non-leaf raises, so it is used as-is.
            # The unsqueeze gives JacobianReg a (batch, 1) output; this
            # assumes convex_forward returns shape (batch,) -- TODO confirm.
            output = self.convex_forward(features).unsqueeze(1)
            return self.reg * self.jac_reg(features, output)

    def certify(self, x, target):
        """Predict a label for a single example and certify it.

        Args:
            x: A batch containing exactly one example (``x.shape[0] == 1``).
            target: Unused.

        Returns:
            ``(label, certificate)``.
            ``label`` is 0 when the class-balanced logit is positive and 1
            otherwise. Only label-0 predictions are certified; label-1
            predictions get ``Certificate.zero()``.
        """
        assert x.shape[0] == 1

        features, lips = self.lipschitz_forward(x)
        # ``class_balance`` is an offset on the logit; it is presumably
        # provided by BaseCertifiable.
        pred = self.convex_forward(features) + self.class_balance

        # positive logit -> class 1 (label 0), non-positive -> class 2 (label 1)
        if pred <= 0:
            # Class 2 is never certified.
            return torch.tensor([1]).long().to(torch_utils.device()), Certificate.zero()

        margin = pred.squeeze(0)
        # Gradient of the head w.r.t. the features. ``class_balance`` is a
        # constant offset, so it does not affect the Jacobian.
        jac = AF.jacobian(self.convex_forward, features, strict=True, create_graph=False).squeeze()

        # Why the radius is margin / (dual norm of the gradient):
        #   If convex_forward is convex, then for a perturbation d,
        #     f(z + d) >= f(z) + grad^T d >= f(z) - ||grad||_q * ||d||_p,
        #   where q is the dual of p (L1 <-> LInf, L2 <-> L2).
        #   So the logit stays positive while ||d||_p < margin / ||grad||_q.
        # Dividing by lips[p] turns this feature-space radius into an
        # input-space radius. This assumes lips[p] bounds the feature map's
        # Lipschitz constant in that norm.
        # A zero gradient gives margin / 0 = inf. Under convexity that is
        # correct: z is then a global minimizer.
        certificate = Certificate({
            Norm.L1: (margin / jac.norm(float('inf'))).item() / lips[Norm.L1],
            Norm.L2: (margin / jac.norm(2)).item() / lips[Norm.L2],
            Norm.LInf: (margin / jac.norm(1)).item() / lips[Norm.LInf],
        })

        # Always label 0 here: the non-positive case returned above.
        return (pred <= 0).long(), certificate

    def forward(self, x):
        """Return the raw single logit ``convex_forward(lipschitz_forward(x))``.

        ``class_balance`` is not added here; ``certify`` adds it.
        """
        features, _ = self.lipschitz_forward(x)
        return self.convex_forward(features)

    def lipschitz_forward(self, x):
        """Apply the feature map.

        Returns:
            ``(features, lips)``. ``lips`` maps each ``Norm`` to the feature
            map's Lipschitz constant in that norm; ``certify`` reads the L1,
            L2 and LInf entries.
            The default is the identity, whose constant is 1 in every norm.
        """
        return x, {Norm.L1: 1, Norm.L2: 1, Norm.LInf: 1}

    def convex_forward(self, x):
        """Return the single logit for features ``x``.

        Must be implemented by subclasses. The certificates produced by
        ``certify`` are only valid if this function is convex in ``x``.
        """
        raise NotImplementedError()
