from custom_activations import *

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from utils import *

# from last_layers import CRC_Diag

# m = CRC_Diag(1024, 10).cuda()
# x = torch.randn(512, 1024).cuda()
# x.requires_grad = True

# logits = m(x)

# grad1 = torch.autograd.grad(logits.sum(), x)[0]

# grad2 = m.gradient(x)

# print(grad1.shape)
# print(grad2.shape)

# print(torch.allclose(grad1, grad2))

# diff = grad1 - grad2
# print_stats(diff.abs(), 'diff')


# def test_single():
#     n = 1000
#     layer = LipPool(2)
#     l_t = layer.theta.detach().item()

#     t1 = torch.tensor(0.5 * (l_t + np.pi) - ((2 * np.pi) / n) )
#     z1 = torch.stack([torch.cos(t1), torch.sin(t1)], axis=0).cuda()
#     d1 = layer(z1, axis=0)

#     t2 = torch.tensor(0.5 * (l_t + np.pi) + ((2 * np.pi) / n) )
#     z2 = torch.stack([torch.cos(t2), torch.sin(t2)], axis=0).cuda()
#     d2 = layer(z2, axis=0)

#     c_d = (d2 - d1).detach().abs().item()
        
#     print(t1, t2)
    
#     print('{:.3f}, {:.3f}'.format(c_d, l_t))
    
def test(n=100000):
    """Numerically probe how fast ``LipPool(2)`` varies along the unit circle.

    For each of ``n + 1`` evenly spaced values of ``theta`` in [0, 2*pi], a
    fresh ``LipPool(2, theta=theta)`` is evaluated on the samples
    ``cos(t)`` / ``sin(t)`` for ``n + 1`` evenly spaced angles ``t`` in
    [0, 2*pi]. The largest finite-difference slope
    ``|d[k+1] - d[k]| / |t[k+1] - t[k]|`` seen over all layers is tracked and
    one summary line is printed:

        max slope, max |diff| for that layer, that layer's theta,
        flat index of the max slope, angular step 2*pi/n

    Args:
        n: Number of intervals used for both the ``theta`` sweep and the
            angle grid. Defaults to 100000, the previously hard-coded value.

    Raises:
        ValueError: If ``n`` is smaller than 1 (the grid spacing would
            require a division by zero).

    The sample points are moved to the GPU with ``.cuda()``, so a CUDA
    device is required.
    """
    # Local import: no explicit numpy import is visible in this file, so use
    # the stdlib constant instead of relying on `np` leaking in via `import *`.
    import math

    if n < 1:
        raise ValueError("n must be at least 1, got {}".format(n))

    two_pi = 2 * math.pi

    # The angle grid and its cos/sin samples do not depend on the layer, so
    # build them once instead of once per theta.
    # NOTE(review): torch.arange(...)/n yields float32; at the default n the
    # spacing is close to float32 resolution near 2*pi, so the slopes are
    # noisy estimates - confirm this precision is acceptable.
    t = two_pi * (torch.arange(n + 1) / n).cuda()
    x = torch.cos(t)
    y = torch.sin(t)
    diff_t = torch.diff(t)

    max_d = 0
    max_l = 0
    # Sentinels so the summary line can still be printed when no slope ever
    # exceeds 0 (previously this case raised UnboundLocalError).
    max_t = float("nan")
    idx_l = -1

    for i in range(n + 1):
        theta = two_pi * (i / n)
        layer = LipPool(2, theta=theta)
        l_t = layer.theta.detach().item()

        # Rebuilt every iteration so each layer receives a fresh input tensor.
        # NOTE(review): the commented-out test_single() above uses
        # torch.stack here; confirm the concatenated 1-D layout is what
        # LipPool expects for axis=0.
        z = torch.cat([x, y], axis=0)

        d = layer(z, axis=0)
        diff = torch.diff(d)
        c_d = diff.detach().abs().max().item()

        # Finite-difference slope of the layer output w.r.t. the angle.
        sl = (diff / diff_t).detach().abs()
        c_l = sl.max().item()

        if c_l > max_l:
            max_l = c_l
            # Store a plain int rather than keeping a CUDA tensor alive.
            idx_l = torch.argmax(sl).item()
            max_t = l_t
            # max_d is the largest |diff| of the layer with the steepest
            # slope, not the global maximum over all layers.
            max_d = c_d

    r = two_pi / n
    print('{:.3f}, {:.3f}, {:f}, {:d}, {:f}'.format(max_l, max_d, max_t, idx_l, r))
    
# Runs the full sweep whenever this module is executed or imported
# (there is no `if __name__ == "__main__":` guard).
test()

#     print(x.shape, x[4999:5001], y[4999:5001])

# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from torch.autograd import Function

# def softplus(x):
#     beta = 4.
#     x = beta * x
    
#     pos_x = (x >= 0)
#     abs_x = torch.abs(x)
#     eval_neg_x = F.softplus(-abs_x)
    
#     eval_x = pos_x * (x + eval_neg_x) + (~ pos_x) * eval_neg_x
#     return eval_x/beta 


# def test_softplus():
#     n = 1000
#     x1 = 0
#     x2 = 0
#     max_diff = 0
#     for i in range(-n, n):
#         x = i/n
#         v1 = softplus(torch.tensor(x))
#         v2 = softplus(torch.tensor(x - (1/n)))
#         diff = torch.abs(v2 - v1)
#         if diff > max_diff:
#             max_diff = diff
#             x1, x2 = (x - (1/n)), x
#     print(max_diff, x1, x2)
    
# v1 = softplus(torch.tensor(0.1))
# v2 = softplus(torch.tensor(-0.1))
# print(v1, v2)

# test_softplus()