class OrthoLinear(nn.Module):
    """Classification head built from a single orthogonal layer.

    The wrapped layer is either ``CayleyLinear`` (``conv_name == 'cayley'``)
    or a 1x1 convolution class looked up in ``conv_mapping``.

    Args:
        num_features: number of input features / channels.
        num_classes: number of output logits.
        conv_name: ``'cayley'`` or a key of ``conv_mapping``.
        **kwargs: forwarded to the convolution constructor (ignored for
            the ``'cayley'`` variant).
    """

    def __init__(self, num_features, num_classes, conv_name, **kwargs):
        super(OrthoLinear, self).__init__()
        if conv_name == 'cayley':
            self.ortho = CayleyLinear(num_features, num_classes)
        else:
            conv_module = conv_mapping[conv_name]
            self.ortho = conv_module(num_features, num_classes, kernel_size=1,
                                     stride=1, padding=0, **kwargs)

    def forward(self, features):
        """Apply the wrapped layer and flatten everything after the batch dim."""
        logits = self.ortho(features)
        return torch.flatten(logits, start_dim=1)

    def certificates(self, features, y=None, L=1.):
        """Return ``(logits, certs)`` for a batch.

        ``certs[i]`` is the margin between the logit of class ``y[i]`` and
        the largest remaining logit, divided by ``sqrt(2) * L``.

        Args:
            features: input accepted by :meth:`forward`.
            y: optional integer class indices, one per sample; defaults to
                the predicted class (``argmax`` of the logits).
            L: scaling factor applied to the margin (used as a Lipschitz
                constant of whatever precedes this head -- NOTE(review):
                inferred from the formula, confirm against callers).
        """
        logits = self.forward(features)
        if y is None:
            y = torch.argmax(logits, dim=1)
        # gather/scatter need the index on the same device as the logits.
        y_col = y.to(logits.device).unsqueeze(1)

        logits_y = logits.gather(1, y_col).squeeze(1)
        # Hide the reference class so the row-wise max is the runner-up.
        logits_nextmax = logits.scatter(1, y_col, float('-inf')).max(dim=1)[0]

        certs = (logits_y - logits_nextmax) / ((2.0 ** 0.5) * L)
        return logits, certs
    
    
class NormalizedLinear(nn.Linear):
    """Linear layer whose weight rows are rescaled to unit L2 norm."""

    def forward(self, features):
        """Flatten ``features`` and apply the row-normalised weight.

        The normalised weight is cached on ``self.weight_lln`` because
        :meth:`certificates` reuses it.
        """
        features = torch.flatten(features, start_dim=1)
        weight_norm = torch.norm(self.weight, dim=1, keepdim=True)
        self.weight_lln = self.weight / weight_norm
        return F.linear(features, self.weight_lln, self.bias)

    def certificates(self, features, y=None, L=1.):
        """Return ``(logits, certs)`` for a batch.

        For every sample and every class ``j != y`` the margin
        ``logit_y - logit_j`` is divided by ``L * ||w_y - w_j||`` (rows of
        the normalised weight); ``certs`` is the minimum of these ratios
        over ``j``.

        Args:
            features: input accepted by :meth:`forward`.
            y: optional integer class indices, one per sample; defaults to
                the predicted class.
            L: scaling factor multiplied into the denominator.
        """
        logits = self.forward(features)
        if y is None:
            y = torch.argmax(logits, dim=1)
        y = y.to(logits.device)
        y_col = y.unsqueeze(1)

        batch_size, num_classes = logits.shape
        class_idxs = torch.arange(num_classes, device=logits.device)
        class_idxs = class_idxs.expand(batch_size, -1)
        # Per row: every class index except the reference class.
        other = class_idxs[class_idxs != y_col].view(batch_size, num_classes - 1)

        logits_diff = logits.gather(1, y_col) - logits.gather(1, other)

        # Distances are computed only for the (y, other) pairs, so both
        # operands of the division have shape (batch, num_classes - 1).
        weight = self.weight_lln
        weight_diff = weight[y].unsqueeze(1) - weight[other]
        weight_diff_norm = torch.norm(weight_diff, dim=2)

        all_certs = logits_diff / (weight_diff_norm * L)
        certs = torch.min(all_certs, dim=1)[0]
        return logits, certs
    
    
class CurvDiag(nn.Module):
    """Head computing ``linear(softplus(scale * x))`` with a clamped scale.

    Args:
        num_features: number of input features.
        num_classes: number of output logits.
        sn_bound: the per-feature ``scale`` is clamped to
            ``[-sn_bound, sn_bound]`` on every forward pass.
    """

    def __init__(self, num_features, num_classes, sn_bound=2):
        super(CurvDiag, self).__init__()
        self.sn_bound = sn_bound
        self.scale = nn.Parameter(torch.randn(1, num_features), requires_grad=True)
        self.linear = nn.Linear(num_features, num_classes)
        self.reset_parameters()

    def reset_parameters(self):
        """Draw weight and bias of ``linear`` from U(-1/sqrt(fan_in), 1/sqrt(fan_in))."""
        # nn.Linear stores its weight as (out_features, in_features), so
        # shape[1] is the fan-in (the number of input features).
        fan_in = self.linear.weight.shape[1]
        std = 1 / (fan_in ** 0.5)
        nn.init.uniform_(self.linear.weight, -std, std)
        if self.linear.bias is not None:
            self.linear.bias.data.uniform_(-std, std)

    def forward(self, features):
        """Return the logits for ``features`` (flattened after the batch dim)."""
        features = torch.flatten(features, start_dim=1)
        scale = torch.clamp(self.scale, min=-self.sn_bound, max=self.sn_bound)
        x = F.softplus(scale * features)
        return self.linear(x)

    def certificates(self, features, y=None, L=1.):
        """Return ``(logits, certs)`` for a batch.

        ``certs[i]`` is the margin between the logit of class ``y[i]`` and
        the largest remaining logit, divided by ``sqrt(2) * L``.

        Args:
            features: input accepted by :meth:`forward`.
            y: optional integer class indices; defaults to the prediction.
            L: scaling factor applied to the margin.
        """
        logits = self.forward(features)
        if y is None:
            y = torch.argmax(logits, dim=1)
        y_col = y.to(logits.device).unsqueeze(1)

        logits_y = logits.gather(1, y_col).squeeze(1)
        # Mask the reference class with -inf; subtracting a large constant
        # instead gives a wrong runner-up once the margin exceeds it.
        logits_nextmax = logits.scatter(1, y_col, float('-inf')).max(dim=1)[0]

        certs = (logits_y - logits_nextmax) / ((2.0 ** 0.5) * L)
        return logits, certs


class CurvDiag(nn.Module):
    """Head computing ``linear(softplus_beta(scale * x))`` with a clamped scale.

    Args:
        num_features: number of input features.
        num_classes: number of output logits.
        sn_bound: the per-feature ``scale`` is clamped to
            ``[-sn_bound, sn_bound]`` on every forward pass.
    """

    def __init__(self, num_features, num_classes, sn_bound=2.):
        super(CurvDiag, self).__init__()
        self.sn_bound = sn_bound
        self.scale = nn.Parameter(torch.randn(1, num_features), requires_grad=True)

        self.beta = 0.25
        self.activation = torch.nn.Softplus(beta=self.beta)
        # The second derivative of softplus with parameter beta is at most beta / 4.
        self.hessian_bound = 0.25 * self.beta
        self.linear = nn.Linear(num_features, num_classes)

    def forward(self, features):
        """Return the logits for ``features`` (flattened after the batch dim)."""
        features = torch.flatten(features, start_dim=1)
        scale = torch.clamp(self.scale, min=-self.sn_bound, max=self.sn_bound)
        # Use the configured activation so that ``beta`` and
        # ``hessian_bound`` describe the function actually applied.
        x = self.activation(scale * features)
        return self.linear(x)

    def certificates(self, features, y=None, L=1.):
        """Return ``(logits, certs)`` where ``certs`` is a placeholder.

        No certificate is computed here: ``certs`` is a vector of ones with
        one entry per sample, on the device and dtype of ``features``.
        ``y`` and ``L`` are accepted for signature compatibility and are
        not used.
        """
        logits = self.forward(features)
        batch_size = logits.shape[0]
        certs = torch.ones(batch_size, device=features.device, dtype=features.dtype)
        return logits, certs
    
    
    
    
            # Detached training-loop snippet (its enclosing function is not in this
            # file view). When the configured last-layer name starts with 'curv',
            # the gradient norm over ``amp.master_params(opt)`` is clipped to 10.
            # NOTE(review): ``amp`` is presumably NVIDIA apex's mixed-precision
            # module -- confirm against the training script.
            if args.last_layer[:4] == 'curv':
#                 total_norm = 0.
#                 for p in amp.master_params(opt):
#                     param_norm = p.grad.detach().data.norm(2)
#                     total_norm += param_norm.item() ** 2
#                 total_norm = total_norm ** 0.5
#                 print(total_norm)
#                 if total_norm!=total_norm:
#                     quit()
                    
                torch.nn.utils.clip_grad_norm_(amp.master_params(opt), 10.)
        
        
        
        
        # Detached debugging snippet (its enclosing method is not in this file
        # view; ``bounds``, ``w2``, ``w2_diffs`` and ``u``..``u4`` must already be
        # defined there). ``all_c_T`` is the element-wise max of ``bounds`` and
        # its transpose. If more than 10 of its entries are exactly zero, the
        # snippet prints diagnostics for every class pair (i, j), i < j, whose
        # bounds are zero in both directions, then terminates the process.
        all_c_T = torch.max(bounds, bounds.T)
        if (all_c_T == 0).sum() > 10:
            print_stats(bounds, 'bounds')
            print_stats(all_c_T, 'all_c_T')
            
            for i in range(self.num_classes):
                for j in range(i+1, self.num_classes):
                    # ``i < j`` always holds here because j starts at i + 1.
                    if ((bounds[i,j] == 0) and (bounds[j,i] == 0)) and (i < j):
                        print(i, j)
                        print_stats(w2[i, :] - w2[j, :])
                        print_stats(w2[j, :] - w2[i, :])
                        
                        print_stats(w2_diffs[i, j, :], '{:d}, {:d}'.format(i, j))
                        print_stats(w2_diffs[j, i, :], '{:d}, {:d}'.format(j, i))
                        
                        print(bounds[i,j], bounds[j,i])
                        
                        # Each of u, u0..u3 is reshaped to
                        # (num_classes, num_classes, -1) and reduced to a
                        # per-pair norm; u4 is reduced without a reshape.
                        u = u.view(self.num_classes, self.num_classes, -1)
                        u_norms = torch.norm(u, dim=2)
                        print('u', u_norms[i, j].detach().item(), u_norms[j, i].detach().item())
                        
                        u0 = u0.view(self.num_classes, self.num_classes, -1)
                        u0_norms = torch.norm(u0, dim=2)
                        print('u0', u0_norms[i, j].detach().item(), u0_norms[j, i].detach().item())
                                                
                        u1 = u1.view(self.num_classes, self.num_classes, -1)
                        u1_norms = torch.norm(u1, dim=2)
                        print('u1', u1_norms[i, j].detach().item(), u1_norms[j, i].detach().item())
                        
                        u2 = u2.view(self.num_classes, self.num_classes, -1)
                        u2_norms = torch.norm(u2, dim=2)
                        print('u2', u2_norms[i, j].detach().item(), u2_norms[j, i].detach().item())
                        
                        u3 = u3.view(self.num_classes, self.num_classes, -1)
                        u3_norms = torch.norm(u3, dim=2)
                        print('u3', u3_norms[i, j].detach().item(), u3_norms[j, i].detach().item())
                        
                        u4_norms = torch.norm(u4, dim=2)
                        print('u4', u4_norms[i, j].detach().item(), u4_norms[j, i].detach().item())
                                                
            # Hard stop after dumping the diagnostics.
            quit()


            
#         nocurv_i, nocurv_j = torch.where(curv_bound == 0)
#         w2_diff_nocurv = w2_diff[nocurv_i, nocurv_j]
#         b2_diff_nocurv = b2_diff[nocurv_i, nocurv_j]

#         x_nocurv = x[nocurv_i, nocurv_j]
#         x_act_nocurv, x_lin1_nocurv = self._intermediate(x_nocurv)
#         logits_diff_nocurv = torch.sum(x_act_nocurv * w2_diff_nocurv, dim=1) + b2_diff_nocurv
        
#         grad_nocurv = self._gradient(x_lin1_nocurv, w2_diff_nocurv)
        
#         grad_norm_nocurv = torch.norm(grad_nocurv, dim=1)
        
#         certs_nocurv = torch.abs(logits_diff_nocurv)/grad_norm_nocurv
#         certs_nocurv = torch.sign(logits_diff_nocurv) * certs_nocurv
        
        
        

class CRC_Diag(nn.Module):
    """Head ``linear2(x + activation(w1 * x + b1))`` with curvature-based certificates.

    ``w1`` and ``b1`` act element-wise on the features (a diagonal first
    layer). ``certificates`` returns, next to the logits, a signed
    certificate for every (reference class, other class) pair.

    Args:
        num_features: number of input features.
        num_classes: number of output logits.
        w1_bound: ``w1`` is clamped to ``[-w1_bound, w1_bound]`` in
            :meth:`forward`; ``None`` means no clamping.
        act_beta, act_thresh: forwarded to the project's ``Softplus``.
        eps: floor used before taking square roots of certificate values.
        min_curv: floor applied to curvature bounds before inverting them.
        inv_eps: floor applied to ``1 + eta * hessian`` before inverting it.
        grad_tolerance: inner-loop stopping threshold on the dual gradient norm.
        outer_iters: number of bisection steps on the dual variable ``eta``.
        train_inner_iters, eval_inner_iters: maximum Newton steps per
            bisection step in training / evaluation mode.
    """

    def __init__(self, num_features, num_classes, w1_bound=None, act_beta=1., act_thresh=20., eps=1e-6, min_curv=1e-6,
                 inv_eps=1e-4, grad_tolerance=1e-3, outer_iters=10, train_inner_iters=5, eval_inner_iters=10):
        super(CRC_Diag, self).__init__()
        self.num_features = num_features
        self.num_classes = num_classes

        # The ``np.float`` alias was removed in NumPy 1.24; the builtin
        # float yields the same value.
        if w1_bound is None:
            w1_bound = float('inf')
        self.w1_bound = w1_bound

        self.w1 = nn.Parameter(torch.ones(1, num_features), requires_grad=True)
        self.b1 = nn.Parameter(torch.zeros(1, num_features), requires_grad=True)

        self.activation = Softplus(beta=act_beta, threshold=act_thresh)
        self.linear2 = nn.Linear(num_features, num_classes)

        self.eps = eps
        self.min_curv = min_curv
        self.inv_eps = inv_eps
        self.grad_tolerance = grad_tolerance

        self.outer_iters = outer_iters
        self.train_inner_iters = train_inner_iters
        self.eval_inner_iters = eval_inner_iters

    def _intermediate(self, x):
        """Return ``(x_lin1, x_act)``: pre-activation and residual activation.

        Reads ``self._w1``, the clamped copy of ``w1`` that :meth:`forward`
        stores, so :meth:`forward` must have run at least once.
        """
        x = torch.flatten(x, start_dim=1)
        x_lin1 = (self._w1 * x) + self.b1
        x_act = x + self.activation(x_lin1)
        return x_lin1, x_act

    def forward(self, features):
        """Return the logits and cache the clamped ``w1`` on ``self._w1``."""
        self._w1 = torch.clamp(self.w1, min=-self.w1_bound, max=self.w1_bound)
        _, x_act = self._intermediate(features)
        return self.linear2(x_act)

    def _gradient(self, x_lin1, w_last):
        """Input gradient of ``sum(w_last * x_act)`` at pre-activation ``x_lin1``."""
        act_grad = self.activation.gradient(x_lin1)
        return w_last * (1 + act_grad * self._w1)

    def _hessian(self, x_lin1, w_last):
        """Diagonal of the input Hessian of ``sum(w_last * x_act)``."""
        act_hess = self.activation.hessian(x_lin1)
        return w_last * act_hess * (self._w1 * self._w1)

    def curv_bounds(self):
        """Return a ``(num_classes, num_classes)`` matrix of curvature bounds.

        Entry ``[i, j]`` is the largest value over features of
        ``hessian_bound * |min(w2[i] - w2[j], 0)| * w1**2``.
        """
        w1 = self._w1
        w2 = self.linear2.weight

        w2_diffs = w2[:, None, :] - w2[None, :, :]
        # Keep only the negative part of every pairwise weight difference.
        w2_diffs = (w2_diffs < 0) * w2_diffs
        w2_diffs_flat = w2_diffs.view(-1, self.num_features)

        bounds = self.activation.hessian_bound * w2_diffs_flat * w1 * w1
        bounds = torch.abs(bounds)
        bounds, _ = torch.max(bounds, dim=1)
        return bounds.view(self.num_classes, self.num_classes)

    @torch.autograd.no_grad()
    def _crc_certificates(self, x, m, M, w2_diff, b2_diff):
        """Bisection on ``eta`` with inner Newton steps on ``delta``.

        Args:
            x: ``(N, num_features)`` inputs, one row per (sample, other class) pair.
            m, M: ``(N,)`` curvature bounds for the two directions of each pair.
            w2_diff, b2_diff: per-row weight / bias difference defining the
                logit difference ``sum(x_act * w2_diff) + b2_diff``.

        Returns:
            ``(eta, delta, grad_dual_norm, logits_sign)`` from the last
            iteration; ``logits_sign`` is the sign of the logit difference
            at the unperturbed ``x``.

        Raises:
            RuntimeError: if the inverted Newton scaling contains NaN or inf.
        """
        x_lin1, x_act = self._intermediate(x)
        logits_diff = torch.sum(x_act * w2_diff, dim=1) + b2_diff
        logits_sign = torch.sign(logits_diff)

        delta = torch.zeros_like(x)

        m = torch.clamp(m, min=self.min_curv)
        M = torch.clamp(M, min=self.min_curv)

        # Initial bracket for eta: [-1/M, 0] where the logit difference is
        # negative, [0, 1/m] where it is positive.
        eta_min = (logits_diff < 0) * (-torch.reciprocal(M))
        eta_max = (logits_diff > 0) * torch.reciprocal(m)

        inner_iters = self.train_inner_iters if self.training else self.eval_inner_iters

        for _ in range(self.outer_iters):
            eta = 0.5 * (eta_min + eta_max)
            eta_mul = eta[:, None]

            # One extra pass so that grad_dual_norm and x_n_act always
            # correspond to the final delta.
            for j in range(inner_iters + 1):
                x_n_lin1, x_n_act = self._intermediate(x + delta)

                grad = self._gradient(x_n_lin1, w2_diff)
                grad_dual = delta + (eta_mul * grad)
                grad_dual_norm = torch.norm(grad_dual, dim=1)
                if (j == inner_iters) or torch.all(grad_dual_norm < self.grad_tolerance):
                    break

                hess = eta_mul * self._hessian(x_n_lin1, w2_diff)
                inv_hess1p = torch.reciprocal(torch.clamp(1 + hess, min=self.inv_eps))
                if not torch.all(torch.isfinite(inv_hess1p)):
                    # Surface the failure to the caller instead of killing
                    # the interpreter.
                    raise RuntimeError(
                        'CRC_Diag: non-finite inverse Hessian in the Newton step')

                delta = delta - (inv_hess1p * grad_dual)

            logits_diff = torch.sum(x_n_act * w2_diff, dim=1) + b2_diff

            # Shrink the bracket according to the sign at the perturbed point.
            ge_indicator = (logits_diff > 0)
            eta_min[ge_indicator] = eta[ge_indicator]
            eta_max[~ge_indicator] = eta[~ge_indicator]

        return eta, delta, grad_dual_norm, logits_sign

    def _tight_certificates(self, features, m, M, w2_diff, b2_diff):
        """Signed certificates of shape ``(batch, num_classes - 1)`` from the dual solve."""
        num_features = self.num_features
        num_classes = self.num_classes

        # One copy of every sample per competing class, flattened to rows.
        x = torch.flatten(features, start_dim=1)[:, None, :]
        x = x.repeat(1, num_classes - 1, 1)
        x = x.view(-1, num_features)

        m = m.flatten()
        M = M.flatten()

        w2_diff = w2_diff.view(-1, num_features)
        b2_diff = b2_diff.view(-1)

        eta, delta, grad_dual_norm, logits_sign = self._crc_certificates(x, m, M, w2_diff, b2_diff)

        # Re-evaluated outside no_grad so the result is differentiable with
        # respect to the parameters (eta and delta are constants).
        _, x_n_act = self._intermediate(x + delta)
        logits_diff = torch.sum(x_n_act * w2_diff, dim=1) + b2_diff

        certs = (delta * delta).sum(dim=1) + (2 * eta * logits_diff)
        # Values at or below eps map to 0; the clamp keeps sqrt away from 0.
        certs = (certs > self.eps) * torch.sqrt(torch.clamp(certs, min=self.eps))
        certs = certs * logits_sign
        return certs.view(-1, num_classes - 1)

    def _fast_certificates(self, features, m, M, w2_diff, b2_diff):
        """Closed-form signed certificates of shape ``(batch, num_classes - 1)``."""
        num_features = self.num_features
        num_classes = self.num_classes

        x_lin1, x_act = self._intermediate(features)
        logits_diff = torch.sum(x_act[:, None, :] * w2_diff, dim=2) + b2_diff
        logits_sign = torch.sign(logits_diff)
        curv_bound = (m * (logits_sign > 0)) + (M * (logits_sign <= 0))

        # Insert the class axis before tiling so that rows are ordered
        # sample-major, matching ``w2_diff.view(-1, num_features)``.
        x_lin1 = x_lin1[:, None, :].repeat(1, num_classes - 1, 1)
        x_lin1 = x_lin1.view(-1, num_features)
        w2_diff = w2_diff.view(-1, num_features)

        grad = self._gradient(x_lin1, w2_diff)
        grad = grad.view(-1, num_classes - 1, num_features)
        grad_norm = torch.norm(grad, dim=2)

        c_i, c_j = torch.nonzero(curv_bound, as_tuple=True)
        nc_i, nc_j = torch.where(curv_bound == 0)

        logits_diff = torch.abs(logits_diff)
        # Zero curvature: first-order bound |margin| / ||grad||.
        certs_nc = (logits_diff[nc_i, nc_j] / grad_norm[nc_i, nc_j])

        logits_diff_c = logits_diff[c_i, c_j]
        curv_bound_c = curv_bound[c_i, c_j]
        curv_bound_c = torch.clamp(curv_bound_c, min=self.min_curv)
        grad_norm_c = grad_norm[c_i, c_j]

        # Non-zero curvature: (sqrt(2 * margin * K + g^2) - g) / K.
        certs_c = ((2 * logits_diff_c * curv_bound_c) + (grad_norm_c * grad_norm_c))
        certs_c = torch.clamp(certs_c, min=self.eps)
        certs_c = (torch.sqrt(certs_c) - grad_norm_c) / curv_bound_c

        certs = torch.empty_like(curv_bound)
        certs[c_i, c_j] = certs_c
        certs[nc_i, nc_j] = certs_nc
        return certs * logits_sign

    def certificates(self, features, y=None):
        """Return ``(logits, logits_cert)``.

        ``logits_cert`` has the shape of ``logits``; it is 0 at the
        reference class ``y`` and holds the negated certificate at every
        other class. ``y`` defaults to the predicted class.
        """
        logits = self.forward(features)
        if y is None:
            y = torch.argmax(logits, dim=1)
        other = other_classes(y, self.num_classes)

        y = y.unsqueeze(1)
        w2 = self.linear2.weight
        w2_diff = w2[y] - w2[other]

        b2 = self.linear2.bias
        b2_diff = b2[y] - b2[other]

        all_curvatures = self.curv_bounds()
        m = all_curvatures[y, other]
        M = all_curvatures[other, y]

        certs = self._tight_certificates(features, m, M, w2_diff, b2_diff)

        batch_size = y.shape[0]
        # Keep both index tensors on the same device.
        batch_idxs = torch.arange(batch_size, device=other.device).unsqueeze(1)
        logits_cert = torch.zeros_like(logits)
        logits_cert[batch_idxs, other] = -certs
        return logits, logits_cert

        
    
class LipPool(nn.Module):
    """Pairs the two halves of an axis and reduces each pair with a learned angle.

    The input is split in two halves ``x`` and ``y`` along ``axis``; every
    ``(x, y)`` pair is mapped to one of three expressions depending on the
    pair's polar angle and the group's ``theta`` (clamped to ``[0, 2*pi]``).

    Args:
        channels: size of the axis being reduced; must be even.
        stride: must be 2.
        theta: initial angle shared by all groups; when ``None`` each group
            draws its own angle uniformly from ``[0, pi)``.
    """

    def __init__(self, channels, stride=2, theta=None):
        super(LipPool, self).__init__()
        assert stride == 2, stride
        assert (channels % stride) == 0
        self.num_groups = channels // stride

        # Prefer the GPU as before, but stay usable on CPU-only hosts.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        if theta is None:
            init = np.pi * torch.rand(self.num_groups).to(device)
        else:
            init = theta * torch.ones(self.num_groups).to(device)
        self.theta = nn.Parameter(init, requires_grad=True)

    def forward(self, z, axis=1, verbose=False):
        """Reduce ``z`` along ``axis`` to half its size.

        ``verbose`` is accepted for signature compatibility and is unused.
        """
        x, y = z.split(z.shape[axis] // 2, axis)
        # Polar angle of every pair, mapped to [0, 2*pi).
        z_theta = torch.atan2(y, x) % (2 * np.pi)

        # Broadcast theta along ``axis`` only.
        theta_shape = ([1] * axis) + [self.num_groups] + ([1] * (z.ndim - axis - 1))
        theta = torch.clamp(self.theta, 0., 2 * np.pi).view(theta_shape)

        half_sin = torch.sin(0.5 * theta)
        half_cos = torch.cos(0.5 * theta)
        out1 = (-x * half_sin) + (y * half_cos)
        out2 = torch.stack([x, y], dim=1).norm(dim=1)
        out3 = (-x * half_sin) - (y * half_cos)

        theta_leq_pi = (theta <= np.pi)

        # theta <= pi: three angular sectors; they are disjoint and cover
        # every angle by construction.
        upper = 0.5 * (theta + np.pi)
        select0_1 = (z_theta <= upper)
        select0_2 = torch.logical_and(z_theta > upper, z_theta < 0.5 * (3 * np.pi - theta))
        select0_3 = torch.logical_and(~select0_1, ~select0_2)
        dists0 = (select0_1 * out1) + (select0_2 * out2) + (select0_3 * out3)

        # theta > pi: the middle sector uses the negated norm.
        select1_1 = torch.logical_and(z_theta >= 0.5 * (theta - np.pi), z_theta < np.pi)
        select1_3 = torch.logical_and(z_theta <= 0.5 * (5 * np.pi - theta), z_theta >= np.pi)
        select1_2 = torch.logical_and(~select1_1, ~select1_3)
        dists1 = (select1_1 * out1) + (select1_2 * (-out2)) + (select1_3 * out3)

        return theta_leq_pi * dists0 + (~theta_leq_pi) * dists1
    
    
        
class LipPool(nn.Module):
    """Pairs the two halves of an axis and reduces each pair with a learned angle.

    The input is split in two halves ``x`` and ``y`` along ``axis``; every
    ``(x, y)`` pair is mapped to one of three expressions depending on the
    pair's polar angle and the group's ``theta`` (clamped to ``[0, pi]``).

    Args:
        channels: size of the axis being reduced; must be even.
        stride: must be 2.
        theta: initial angle shared by all groups; when ``None`` each group
            draws its own angle uniformly from ``[0, pi)``.
    """

    def __init__(self, channels, stride=2, theta=None):
        super(LipPool, self).__init__()
        assert stride == 2, stride
        assert (channels % stride) == 0
        self.num_groups = channels // stride

        # Prefer the GPU as before, but stay usable on CPU-only hosts.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        if theta is None:
            init = np.pi * torch.rand(self.num_groups).to(device)
        else:
            init = theta * torch.ones(self.num_groups).to(device)
        self.theta = nn.Parameter(init, requires_grad=True)

    def forward(self, z, axis=1, verbose=False):
        """Reduce ``z`` along ``axis`` to half its size.

        ``verbose`` is accepted for signature compatibility and is unused.
        """
        x, y = z.split(z.shape[axis] // 2, axis)
        # Polar angle of every pair, mapped to [0, 2*pi).
        z_theta = torch.atan2(y, x) % (2 * np.pi)

        # Broadcast theta along ``axis`` only.
        theta_shape = ([1] * axis) + [self.num_groups] + ([1] * (z.ndim - axis - 1))
        theta = torch.clamp(self.theta, 0., np.pi).view(theta_shape)

        half_sin = torch.sin(0.5 * theta)
        half_cos = torch.cos(0.5 * theta)
        out1 = (-x * half_sin) + (y * half_cos)
        out2 = torch.stack([x, y], dim=1).norm(dim=1)
        out3 = (-x * half_sin) - (y * half_cos)

        # Three angular sectors; the third is the complement of the first two.
        upper = 0.5 * (theta + np.pi)
        select0_1 = (z_theta <= upper)
        select0_2 = torch.logical_and(z_theta > upper, z_theta <= 0.5 * (3 * np.pi - theta))
        select0_3 = torch.logical_and(~select0_1, ~select0_2)

        return (select0_1 * out1) + (select0_2 * out2) + (select0_3 * out3)


class SN_Linear(nn.Linear):
    """Linear layer whose spectral norm estimate is clamped to ``bound``.

    The largest singular value is estimated by power iteration on the
    buffers ``_u`` / ``_v``; the weight is rescaled so that this estimate
    lies in ``[-bound, bound]``.

    Args:
        num_features: number of input features.
        num_classes: number of output logits.
        bias: whether the layer has a bias.
        init_iters: power iterations run at construction / ``update_sigma``.
        update_iters: power iterations run per training forward pass.
        bound: clamp for the spectral norm estimate; ``None`` means
            unbounded.
        normalize: if true, weight rows are rescaled to unit L2 norm first.
    """

    def __init__(self, num_features, num_classes, bias=True, init_iters=50, update_iters=1, bound=None, normalize=False):
        super().__init__(num_features, num_classes, bias)

        self.init_iters = init_iters
        self.update_iters = update_iters

        # The ``np.float`` alias was removed in NumPy 1.24; the builtin
        # float yields the same value.
        if bound is None:
            bound = float('inf')
        self.bound = bound
        self.normalize = normalize

        nn.init.orthogonal_(self.weight)
        self._initialize_singular_vectors()
        self.update_sigma()

    def _get_weight(self):
        """Return the raw weight, row-normalised when ``normalize`` is set."""
        if self.normalize:
            weight_norm = torch.norm(self.weight, dim=1, keepdim=True)
            return self.weight / weight_norm
        return self.weight

    def _initialize_singular_vectors(self):
        """Register random unit vectors ``_u`` (outputs) and ``_v`` (inputs)."""
        weight = self._get_weight()
        num_classes, num_features = weight.shape

        u = weight.new_empty((num_classes)).normal_(0, 1)
        self.register_buffer('_u', F.normalize(u, dim=0))

        v = weight.new_empty((num_features)).normal_(0, 1)
        self.register_buffer('_v', F.normalize(v, dim=0))

    @torch.autograd.no_grad()
    def _power_method(self, num_iters):
        """Run ``num_iters`` power iterations, updating ``_u`` and ``_v``."""
        weight = self._get_weight()
        for _ in range(num_iters):
            self._v = F.normalize(torch.mv(weight.T, self._u), dim=0)
            self._u = F.normalize(torch.mv(weight, self._v), dim=0)

    def update_sigma(self):
        """Refresh the singular vectors with ``init_iters`` power iterations."""
        self._power_method(num_iters=self.init_iters)

    def compute_weight(self, return_sigma=False):
        """Return the rescaled weight.

        In training mode the singular vectors are first advanced by
        ``update_iters`` power iterations. With ``return_sigma=True`` the
        clamped estimate (the norm of the returned weight, not the raw
        estimate) is returned as a second value.
        """
        weight = self._get_weight()
        if self.training:
            self._power_method(num_iters=self.update_iters)

        sigma = torch.dot(self._u, torch.mv(weight, self._v))

        correction = torch.clamp(sigma, -self.bound, self.bound)
        weight = correction * (weight / sigma)

        if return_sigma:
            return weight, correction
        return weight

    def forward(self, features):
        """Return the logits; the weight used is cached on ``_weight_n``."""
        features = torch.flatten(features, start_dim=1)

        # ``bound`` is normally a float (``inf`` when unbounded); ``None``
        # can only appear if it was reassigned after construction.
        if self.bound is None:
            self._weight_n = self._get_weight()
        else:
            self._weight_n = self.compute_weight()
        return F.linear(features, self._weight_n, self.bias)

    def certificates(self, features, y=None):
        """Return ``(logits, certs)`` with ``certs`` of shape ``(batch, num_classes - 1)``.

        For every sample and every class ``j != y`` the entry is
        ``(logit_y - logit_j) / ||w_y - w_j||`` using the rows of the
        weight applied in :meth:`forward`. ``y`` defaults to the
        predicted class.
        """
        logits = self.forward(features)
        if y is None:
            y = torch.argmax(logits, dim=1)
        y = y.to(logits.device)
        y_col = y.unsqueeze(1)

        batch_size, num_classes = logits.shape
        class_idxs = torch.arange(num_classes, device=logits.device)
        class_idxs = class_idxs.expand(batch_size, -1)
        # Per row: every class index except the reference class.
        other = class_idxs[class_idxs != y_col].view(batch_size, num_classes - 1)

        logits_diff = logits.gather(1, y_col) - logits.gather(1, other)

        # Squared pairwise row distances via ||a||^2 + ||b||^2 - 2 a.b
        weight = self._weight_n
        gram = weight.mm(weight.T)
        weight_sqsum = torch.diag(gram)
        weight_pdists = weight_sqsum[:, None] + weight_sqsum[None, :] - (2 * gram)

        norm_sq_diff = weight_pdists[y_col, other]
        # Rounding can push tiny distances below zero; avoid NaN from sqrt.
        norm_diff = torch.sqrt(torch.clamp(norm_sq_diff, min=0.))

        certs = logits_diff / norm_diff
        return logits, certs

    

    def _fast_certificates(self, features, y, other):
        # Closed-form signed certificates for every (sample, other class) pair.
        #
        # NOTE(review): this method reads ``self.num_features``,
        # ``self.num_classes``, ``self.linear1``, ``self.linear2``,
        # ``self.activation``, ``self.curvature_bounds`` and
        # ``self._intermediate``; none of them are defined by the class it
        # currently sits under, so it looks pasted from another class.
        #
        # Parameters (shapes presumed from the indexing below -- TODO confirm):
        #   features: batch of inputs passed to ``self._intermediate``.
        #   y:        reference class indices, presumably (batch, 1).
        #   other:    competing class indices, presumably (batch, num_classes - 1).
        # Returns a tensor shaped like the curvature bound, holding the
        # certificate multiplied by the sign of the logit difference.
        batch_size = features.shape[0]
        batch_idxs = torch.arange(batch_size).unsqueeze(1)
        num_features = self.num_features
        num_classes = self.num_classes
        
        w2 = self.linear2.weight        
        b2 = self.linear2.bias
                
        m, M = self.curvature_bounds(y, other)
        
        x_lin1, x_act = self._intermediate(features)
        logits = F.linear(x_act, w2, b2)
        logits_y = logits[batch_idxs, y]
        logits_other = logits[batch_idxs, other]
        logits_diff = logits_y - logits_other
        logits_sign = torch.sign(logits_diff)
        # Use m where the reference class wins, M otherwise.
        curv_bound = (m * (logits_sign > 0)) + (M * (logits_sign <= 0))
        curv_bound = torch.abs(curv_bound)
        
        # Gradient of the logit difference w.r.t. the input: weight
        # difference times activation gradient, mapped through linear1.
        grad = w2[y] - w2[other]
        act_grad = self.activation.gradient(x_lin1)
        grad = grad * act_grad[:, None, :]
        grad_batched = grad.view(-1, grad.shape[2])
        grad_batched = torch.mm(grad_batched, self.linear1.weight)
        grad = grad_batched.view(-1, num_classes - 1, num_features)
        grad_norm = torch.norm(grad, dim=2)
                
        # Entries with and without curvature use different formulas.
        curv_i, curv_j = torch.nonzero(curv_bound, as_tuple=True)
        nocurv_i, nocurv_j = torch.where(curv_bound == 0)
        
        logits_diff = torch.abs(logits_diff)
        # Zero curvature: |margin| / ||grad||.
        certs_nocurv = (logits_diff[nocurv_i, nocurv_j]/grad_norm[nocurv_i, nocurv_j])
        
        logits_diff_curv = logits_diff[curv_i, curv_j]
        m_curv = curv_bound[curv_i, curv_j]
        grad_norm_curv = grad_norm[curv_i, curv_j]
        # Non-zero curvature: (sqrt(2 * margin * K + g^2) - g) / K.
        certs_curv = ((2 * logits_diff_curv * m_curv) + (grad_norm_curv * grad_norm_curv))
        certs_curv = (torch.sqrt(certs_curv) - grad_norm_curv) / m_curv
        
        certs = torch.empty_like(curv_bound)
        certs[curv_i, curv_j] = certs_curv
        certs[nocurv_i, nocurv_j] = certs_nocurv
        certs = certs * logits_sign
        return certs

    
def _fast_curvature_bounds(self, y, other):
    w1 = self.linear1.weight
    w2 = self.linear2.weight
    hess = self.activation.hessian_bound

    w1_u = torch.mv(w1, self._u)
    w2_y = w2[y]
    w2_other = w2[other]
    w2_diffs = w2_y - w2_other

    diag_neg = w2_diffs * (w2_diffs < 0)
    diag_pos = w2_diffs * (w2_diffs > 0)

    w1_u = w1_u[None, None, :]
    m = hess * torch.sum(w1_u * diag_neg * w1_u, dim=2)
    M = hess * torch.sum(w1_u * diag_pos * w1_u, dim=2)

    m, M = torch.abs(m), torch.abs(M)
    return m, M

def _fast_curvature_bounds(self, y, other):
    w1 = self.linear1.weight
    w2 = self.linear2.weight
    hess = self.activation.hessian_bound

    sigma = torch.dot(self._v, torch.mv(w1, self._u))

    w2_y = w2[y]
    w2_other = w2[other]
    w2_diffs = w2_y - w2_other

    diag_neg, _ = torch.min(w2_diffs * (w2_diffs < 0), dim=2)
    diag_pos, _ = torch.max(w2_diffs * (w2_diffs > 0), dim=2)

    m = hess * sigma * sigma * torch.abs(diag_neg)
    M = hess * sigma * sigma * torch.abs(diag_pos)
    return m, M