import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from Models.oni_lib import ONI_Linear

class RawAttLayer(nn.Module):
    """Parameter-free attention pooling over a 1-D tensor.

    The input serves both as the attention logits and as the values, so the
    output is ``sum(softmax(x) * x)``.
    """

    def __init__(self,):
        super(RawAttLayer, self).__init__()

    def forward(self, x):
        """Return the softmax-weighted sum of ``x`` as a 0-dim tensor.

        Args:
            x: 1-D tensor.

        Raises:
            AssertionError: if ``x`` is not 1-D.
        """
        assert len(x.shape) == 1
        # Pass ``dim`` explicitly: the implicit-dim form of softmax is
        # deprecated and emits a warning on every call. For 1-D input the
        # implicit choice was already dim 0, so the result is unchanged.
        s = F.softmax(x, dim=0)
        y = (s*x).sum()
        return y

def calc_emp_lip(func, x, lr=0.1):
    """Empirically estimate a local Lipschitz constant of ``func`` near ``x``.

    Starting from a small random perturbation ``delta``, runs 20 steps of
    gradient ascent on ``||func(x + delta) - func(x)|| / ||delta||`` and
    returns the largest ratio observed. The distance and the perturbation
    norm are printed at every step.

    Args:
        func: callable mapping a tensor to a tensor. If it is an
            ``nn.Module`` it is moved (in place) to the compute device.
        x: input tensor around which the ratio is probed.
        lr: gradient-ascent step size for ``delta``.

    Returns:
        The largest ratio seen, as a detached 0-dim tensor (or the float
        ``0.0`` if no ratio ever compared greater than zero).
    """
    # Use CUDA when present, otherwise stay on the input's device instead of
    # crashing on CPU-only hosts.
    device = torch.device("cuda") if torch.cuda.is_available() else x.device
    # Plain callables (lambdas, functions) have nothing to move.
    if isinstance(func, nn.Module):
        func.to(device)
    x = x.to(device)

    y = func(x).detach()
    delta = (torch.randn_like(x) * 1e-2).requires_grad_(True)
    max_lip = 0.0
    for _ in range(20):
        y2 = func(x + delta)
        dist = (y2 - y).norm()
        delta_norm = delta.norm()
        print (dist, delta_norm)
        cur_lip = dist / delta_norm
        if cur_lip > max_lip:
            # Detach so the running maximum does not keep a graph alive.
            max_lip = cur_lip.detach()

        # Differentiate w.r.t. delta only; unlike .backward() this does not
        # accumulate gradients into func's parameters.
        (delta_grad,) = torch.autograd.grad(cur_lip, delta)
        delta = (delta + lr * delta_grad).detach().requires_grad_(True)
    return max_lip

def calc_rawatt_grad_bound(N, steps=1000, lr=1e-1):
    """Search for a large gradient norm of ``y = sum(softmax(x) * x)``.

    The gradient of ``y`` w.r.t. ``x`` is evaluated in closed form as
    ``s * (1 + x - y)`` with ``s = softmax(x)``; its Euclidean norm is then
    maximised over ``x`` with SGD. The current ``x`` and norm are printed at
    every step.

    Args:
        N: length of the input vector ``x``.
        steps: number of SGD ascent steps (default matches the original
            hard-coded 1000).
        lr: SGD learning rate.

    Returns:
        0-dim tensor holding the gradient norm at the final ``x``.
    """
    def grad_norm(x):
        # Explicit dim avoids the deprecated implicit-dim softmax warning.
        s = F.softmax(x, dim=0)
        y = (s*x).sum()
        grad = s * (1+x-y)
        return grad.norm()

    x = nn.Parameter(torch.randn(N) * 0.1)
    opt = torch.optim.SGD([x], lr=lr)
    for _ in range(steps):
        # Minimising the negated norm performs gradient ascent on the norm.
        tgt = -grad_norm(x)
        print (x, -tgt)
        opt.zero_grad()
        tgt.backward()
        opt.step()

    return grad_norm(x)

class AttLayer(nn.Module):
    """Attention pooling over the rows of a 2-D tensor.

    Each row is scored by its product with a fixed, randomly drawn,
    unit-norm ``weight`` column; the scores are soft-maxed across rows and
    used to average the rows.
    """

    def __init__(self, dim=1+2+4):
        """Draw a random unit-norm scoring vector of shape ``(dim, 1)``.

        Args:
            dim: number of columns expected in the input (defaults to the
                original hard-coded ``1+2+4``).
        """
        super(AttLayer, self).__init__()
        weight = torch.randn(dim, 1)
        weight = weight / weight.view(-1).norm()
        # Register as a buffer so .to()/.cuda()/.double() carry the weight
        # along with the module; non-persistent keeps state_dict() exactly as
        # it was when the weight was a plain attribute.
        self.register_buffer("weight", weight, persistent=False)

    def forward(self, x):
        """Return the attention-weighted average of the rows of ``x``.

        Args:
            x: 2-D tensor whose second dimension matches ``weight``.

        Returns:
            1-D tensor with one entry per column of ``x``.

        Raises:
            AssertionError: if ``x`` is not 2-D.
        """
        assert len(x.shape) == 2
        # One score per row, normalised across rows (dim 0).
        s = F.softmax(x @ self.weight, dim=0)
        y = (s*x).sum(0)
        return y

def calc_att_grad_bound(N, weight, steps=1000, lr=1e-1):
    """Search for a large gradient norm of weighted attention pooling.

    For ``s = softmax(x @ |weight|)`` over rows and ``y = sum_i s_i * x_i``,
    the matrix ``s_0 * (I + |weight| x_0^T - |weight| y^T)`` associated with
    the first row is formed in closed form, and its spectral norm (``ord=2``)
    is maximised over ``x`` with SGD. The norm is printed at every step and
    the final weight, ``x``, ``s`` and ``y`` are printed at the end.

    Args:
        N: number of rows of the optimised input ``x``.
        weight: scoring column of shape ``(D, 1)``; its absolute value is
            used. ``D`` determines the number of columns of ``x``.
        steps: number of SGD ascent steps (default matches the original
            hard-coded 1000).
        lr: SGD learning rate.

    Returns:
        0-dim tensor with the spectral norm at the final ``x``.
    """
    # Use CUDA when present, otherwise stay on the weight's device instead
    # of crashing on CPU-only hosts. Detach so no gradient flows back into
    # the caller's tensor.
    device = torch.device("cuda") if torch.cuda.is_available() else weight.device
    weight = torch.abs(weight).detach().to(device)
    dim = weight.shape[0]
    # Built once; the original rebuilt it for every row on every call.
    eye = torch.eye(dim, device=device, dtype=weight.dtype)

    def attention(x):
        s = F.softmax((x @ weight).squeeze(1), dim=0).unsqueeze(1)
        y = (s*x).sum(0)
        return s, y

    def first_grad_norm(x, norm_ord=None):
        s, y = attention(x)
        # Only the first row's block is ever measured, so only it is built.
        grad0 = s[0] * (eye + (weight @ x[0].unsqueeze(0)) - weight @ y.unsqueeze(0))
        return torch.linalg.norm(grad0, ord=norm_ord)

    # Sample on the CPU first so the draw does not depend on the device.
    x = nn.Parameter((torch.randn(N, dim, dtype=weight.dtype) * 0.1).to(device))
    opt = torch.optim.SGD([x], lr=lr)
    for _ in range(steps):
        # Minimising the negated norm performs gradient ascent on the norm.
        tgt = -first_grad_norm(x, norm_ord=2)
        print (-tgt)
        opt.zero_grad()
        tgt.backward()
        opt.step()

    print (weight, weight.view(-1).norm())
    print (x)
    s, y = attention(x)
    print (s)
    print (y)
    return first_grad_norm(x, norm_ord=2)

def main():
    """Run the AttLayer experiment and print its results.

    Draws a random 10 x 7 input, prints the layer's output for it, then
    prints the value returned by ``calc_att_grad_bound`` for the layer's
    weight.
    """
    # --- disabled experiment: empirical Lipschitz estimate of ONI_Linear ---
    #layer = ONI_Linear(128, 128)
    #x = torch.FloatTensor(1,128).normal_()
    #max_lip = calc_emp_lip(layer, x)
    #print (max_lip)

    # --- disabled experiment: RawAttLayer gradient bound ---
    #N = 10
    ##x = torch.FloatTensor(N).normal_() * 0.1
    #x = torch.FloatTensor([1.0]+[0.0]*(N-1))
    #layer = RawAttLayer()
    #print (x)
    #print (layer(x))
    ##max_lip = calc_emp_lip(layer, x, lr=0.1)
    ##print (max_lip)
    #max_lip = calc_rawatt_grad_bound(N)
    #print (max_lip)

    # --- active experiment: AttLayer gradient bound ---
    num_rows = 10
    feat_dim = 1 + 2 + 4
    # The input is sampled before the layer is built, so the random draws
    # happen in the same order as before.
    sample = 0.1 * torch.FloatTensor(num_rows, feat_dim).normal_()
    att = AttLayer()
    pooled = att(sample)
    print (pooled)
    bound = calc_att_grad_bound(num_rows, att.weight)
    print (bound)


if __name__ == '__main__':
    main()
