import torch
import math

from convexrobust.utils import torch_utils

from convexrobust.model.base_certifiable import BaseCertifiable, Certificate, Norm

import sys
sys.path.append('../../lib')
from orthconv.utils import margin_loss

def custom_loss(yhat, y):
    """Training loss handed to ``BaseCertifiable`` by ``CayleyCertifiable``.

    Delegates to ``orthconv.utils.margin_loss`` with the fixed positional
    hyperparameters ``(0.5, 1.0, 1.0)``. Defined with ``def`` rather than a
    lambda so the function has a real name and is picklable by reference.

    NOTE(review): the meaning of the three constants is defined by
    ``margin_loss`` (not visible here) — document them once confirmed.
    """
    return margin_loss(yhat, y, 0.5, 1.0, 1.0)

class CayleyCertifiable(BaseCertifiable):
    """Multi-logit certifiable classifier trained with ``custom_loss``.

    Certificates are l2 radii derived from the gap between the two largest
    outputs of ``forward_balanced``.
    """

    def __init__(self, **kwargs):
        """Forward ``kwargs`` to ``BaseCertifiable``.

        Fixes ``single_logit=False`` and ``custom_loss=custom_loss``.
        """
        super().__init__(single_logit=False, custom_loss=custom_loss, **kwargs)

    def certify(self, x, target):
        """Predict the class of a single input and attach an l2 certificate.

        Args:
            x: A batch that must contain exactly one input (``x.shape[0] == 1``).
            target: Not used by this implementation.

        Returns:
            A tuple ``(prediction, certificate)``:

            - ``prediction`` is ``forward_balanced(x).argmax(dim=1)``.
            - ``certificate`` is ``Certificate.from_l2`` built from
              ``(top1 - top2) / sqrt(2)`` and ``self.datamodule.in_n``.
              Here ``top1`` and ``top2`` are the two largest outputs.

        Raises:
            ValueError: If ``x`` does not hold exactly one input.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if x.shape[0] != 1:
            raise ValueError(
                f'certify expects a batch of exactly one input, got {x.shape[0]}'
            )

        pred = self.forward_balanced(x)

        # Only the two largest outputs matter, so take them directly with
        # topk instead of sorting every class.
        top2 = torch.topk(pred, 2, dim=1).values
        certified_margin = top2[:, 0] - top2[:, 1]

        # NOTE(review): dividing the top-1/top-2 gap by sqrt(2) yields an l2
        # radius only if forward_balanced is 1-Lipschitz in l2 (e.g. built
        # from Cayley/orthogonal layers) — confirm against the model.
        certificate = Certificate.from_l2(
            certified_margin.item() / math.sqrt(2), self.datamodule.in_n
        )

        return pred.argmax(dim=1), certificate
