import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from Models.oni_lib import ONI_Linear

# N: number of input vectors (rows of x) the attention layer is probed with.
# D: feature dimension of each input vector (columns of x).
N, D = 10, 2

class KimsqrtAtt(nn.Module):
    """Score f[a, b] = -sqrt(||x[a] - x[b]||^2 + 1e-8): negative smoothed L2 distance."""

    def __init__(self,):
        print ("KimsqrtAtt")
        super().__init__()

    def forward(self, x):
        # Broadcast (1, n, d) - (n, 1, d) -> (n, n, d) pairwise differences.
        pairwise = x[None] - x[:, None]
        squared = pairwise.pow(2).sum(2)
        return -(squared + 1e-8).sqrt()

    def get_jacobian(self, x, i, j, k):
        """Return d f[i, k] / d x[j] as a vector shaped like ``x[0]``.

        x[j] influences f[i, k] only when it is exactly one of the two
        endpoints; when i == j == k the score is the constant -sqrt(1e-8).
        """
        is_row = j == i
        is_col = j == k
        if is_row == is_col:
            return torch.zeros_like(x[0])
        partner = x[k] if is_row else x[i]
        delta = partner - x[j]
        # The extra 1e-8 outside the sqrt keeps the division away from zero.
        length = torch.sqrt((delta ** 2).sum() + 1e-8) + 1e-8
        return delta / length

class KimAtt(nn.Module):
    """Score f[a, b] = -||x[a] - x[b]||^2 (negative squared L2 distance)."""

    def __init__(self,):
        print ("KimAtt")
        super().__init__()

    def forward(self, x):
        gaps = x[None] - x[:, None]
        return -gaps.pow(2).sum(2)

    def get_jacobian(self, x, i, j, k):
        """Return d f[i, k] / d x[j] as a vector shaped like ``x[0]``."""
        if (j == i) == (j == k):
            # Either x[j] is not an endpoint, or i == j == k and f[i, i] == 0.
            return torch.zeros_like(x[0])
        other = x[k] if j == i else x[i]
        return 2 * (other - x[j])

class DotAtt(nn.Module):
    """Score f[a, b] = x[a] . x[b] (plain dot product)."""

    def __init__(self,):
        print ("DotAtt")
        super().__init__()

    def forward(self, x):
        return torch.matmul(x, x.T)

    def get_jacobian(self, x, i, j, k):
        """Return d f[i, k] / d x[j] as a vector shaped like ``x[0]``."""
        if j == i == k:
            return 2*x[j]  # d(x . x)/dx = 2x
        if j == i:
            return x[k]
        if j == k:
            return x[i]
        return torch.zeros_like(x[0])

class CosAtt(nn.Module):
    """Score f[a, b] = cosine similarity of x[a] and x[b], smoothed with 1e-8 terms."""

    def __init__(self,):
        print ("CosAtt")
        super().__init__()

    def forward(self, x):
        lengths = torch.sqrt(x.pow(2).sum(1) + 1e-8) + 1e-8
        gram = x @ x.T
        return gram / lengths[None, :] / lengths[:, None]

    def get_jacobian(self, x, i, j, k):
        """Return d f[i, k] / d x[j] as a vector shaped like ``x[0]``.

        With a = x[j] and b its partner in the pair:
            d cos(a, b) / da = b / (|a||b|) - (a . b) a / (|a|^3 |b|).
        f[i, i] is (up to smoothing) the constant 1, so its gradient is zero.
        """
        if (j == i) == (j == k):
            return torch.zeros_like(x[0])
        partner = x[k] if j == i else x[i]
        own_len = torch.sqrt((x[j]**2).sum()+1e-8)+1e-8
        partner_len = torch.sqrt((partner**2).sum()+1e-8)+1e-8
        inner = (x[j]*partner).sum()
        return partner / (own_len * partner_len) - (inner*x[j])/( (own_len**3)*partner_len )

class MultiAtt(nn.Module):
    """Bilinear score f[a, b] = x[a]^T W x[b] with a learned D x D matrix W.

    W is initialised as the Cayley transform (I + A)(I - A)^{-1} of the
    skew-symmetric matrix A = w - w^T. That matrix is orthogonal but in general
    NOT symmetric, so gradients w.r.t. the second argument need W^T.
    """

    def __init__(self,):
        print ("MultiAtt")
        super(MultiAtt, self).__init__()
        w = torch.FloatTensor(D,D).normal_()
        w = (torch.eye(D)+w-w.T) @ torch.linalg.inv(torch.eye(D)-w+w.T)
        self.w = nn.Parameter(w)

    def forward(self, x):
        return x @ self.w @ x.T

    def get_jacobian(self, x, i, j, k):
        """Return d f[i, k] / d x[j] as a vector shaped like ``x[0]``.

        Since f[i, k] = x[i]^T W x[k]:
          d/dx[i] = W x[k],   d/dx[k] = W^T x[i],   and for i == k: (W + W^T) x[i].
        """
        if j != i and j != k:
            return torch.zeros_like(x[0])
        if j == i == k:
            # Quadratic form: both factors depend on x[j].
            return (self.w + self.w.T) @ x[j]
        if j == i:
            return self.w @ x[k]
        # j == k and j != i: x[j] is the right-hand factor.
        return self.w.T @ x[i]

class AddAtt(nn.Module):
    """Score f[a, b] = (x[a] + x[b]) . w with a learned (D, 1) vector w (unit norm at init)."""

    def __init__(self,):
        print ("AddAtt")
        super().__init__()
        direction = torch.FloatTensor(D,1).normal_()
        self.w = nn.Parameter(direction / direction.view(-1).norm())

    def forward(self, x):
        pair_sums = x[None] + x[:, None]
        return torch.matmul(pair_sums, self.w).squeeze(2)

    def get_jacobian(self, x, i, j, k):
        """Return d f[i, k] / d x[j] as a vector shaped like ``x[0]``.

        The score is linear, so the gradient is w times the number of times
        x[j] appears in the sum (0, 1, or 2 when i == j == k).
        """
        hits = (j == i) + (j == k)
        if hits == 0:
            return torch.zeros_like(x[0])
        w_vec = self.w.squeeze(1)
        return 2*w_vec if hits == 2 else w_vec

class ReluAddAtt(nn.Module):
    """Score f[a, b] = -relu((x[a] + x[b]) . w) with a learned (D, 1) vector w.

    w is initialised as a random unit vector.
    """

    def __init__(self,):
        print ("ReluAddAtt")
        super(ReluAddAtt, self).__init__()
        w = torch.FloatTensor(D,1).normal_()
        w = w / w.view(-1).norm()
        self.w = nn.Parameter(w)

    def forward(self, x):
        return -F.relu((x.unsqueeze(0)+x.unsqueeze(1)) @ self.w).squeeze(2)

    def get_jacobian(self, x, i, j, k):
        """Return d f[i, k] / d x[j] as a vector shaped like ``x[0]``.

        The gradient is -w (or -2w when i == j == k) where the ReLU input
        (x[i] + x[k]) . w is non-negative, and zero otherwise.
        """
        if j != i and j != k:
            return torch.zeros_like(x[0])
        w = self.w.squeeze(1)
        if j == i == k:
            pre_act = (2 * x[j]) @ w
            scale = -2
        else:
            partner = x[k] if j == i else x[i]
            pre_act = (x[j] + partner) @ w
            scale = -1
        # The ReLU acts on the scalar projection onto w, so the whole gradient
        # is gated by a single scalar test rather than a per-coordinate mask.
        # ">= 0" picks the "active" subgradient exactly at the kink.
        return scale * w * (pre_act >= 0)

class AttLayer(nn.Module):
    """Attention layer y = softmax(f(x), dim=1) @ x with no query/key/value projections.

    The score function f is supplied by ``att_calc_module``. It must map an
    (n, d) input to an (n, n) score matrix and provide
    ``get_jacobian(x, i, j, k)`` returning d f[i, k] / d x[j] as a length-d vector.
    """

    def __init__(self, att_calc_module=None):
        super(AttLayer, self).__init__()
        # Any of KimAtt, KimsqrtAtt, DotAtt, CosAtt, MultiAtt, AddAtt or
        # ReluAddAtt can be passed in; ReluAddAtt is the default.
        if att_calc_module is None:
            att_calc_module = ReluAddAtt()
        self.att_calc_module = att_calc_module

    def forward(self, x):
        assert len(x.shape) == 2
        att_score = self.att_calc_module(x)
        att_prob = F.softmax(att_score, dim=1)
        y = att_prob @ x
        return y

    def get_jacobian(self, x):
        """Return the Jacobian of ``forward`` at x as an (n*d, n*d) matrix.

        Entry [i*d + a, j*d + b] is d y[i, a] / d x[j, b]. This is the same
        layout as ``torch.autograd.functional.jacobian(self, x).reshape(n*d, n*d)``.

        With p = softmax(f) and g_ilj = d f[i, l] / d x[j], block (i, j) is
            p[i, j] * I + sum_l p[i, l] * (x[l] - y[i]) g_ilj^T.
        """
        n, d = x.shape
        att_prob = F.softmax(self.att_calc_module(x), dim=1)
        y = att_prob @ x
        eye = torch.eye(d, dtype=x.dtype, device=x.device)
        score_grad = self.att_calc_module.get_jacobian

        row_blocks = []
        for i in range(n):
            blocks = []
            for j in range(n):
                if i != j:
                    # For j != i only f[i, j] depends on x[j], so the sum over l
                    # collapses to the single term l = j.
                    g = score_grad(x, i, j, j)
                    block = att_prob[i, j] * (eye + torch.outer(x[j] - y[i], g))
                else:
                    # x[i] enters every score f[i, k] of row i.
                    block = att_prob[i, i] * eye
                    for k in range(n):
                        g = score_grad(x, i, i, k)
                        block = block + att_prob[i, k] * torch.outer(x[k] - y[i], g)
                blocks.append(block)
            row_blocks.append(torch.cat(blocks, dim=1))
        return torch.cat(row_blocks, dim=0)

def calc_att_grad_bound(layer, n=None, d=None, steps=10000, lr=1e-2, verbose=True):
    """Search for the input that maximises the spectral norm of ``layer.get_jacobian``.

    Starts from x ~ 0.1 * N(0, 1) of shape (n, d) and runs Adam gradient
    ascent on ``torch.linalg.norm(layer.get_jacobian(x), ord=2)``. If
    ``get_jacobian`` is the layer's exact Jacobian, the result is an empirical
    lower bound on the layer's L2 Lipschitz constant.

    Args:
        layer: object with ``get_jacobian(x)`` returning a 2-D tensor that is
            differentiable w.r.t. x.
        n: number of input rows; defaults to the module-level ``N``.
        d: feature dimension; defaults to the module-level ``D``.
        steps: number of ascent steps.
        lr: Adam learning rate.
        verbose: if True, print the step index and current norm every step.

    Returns:
        Detached 0-dim tensor: the largest spectral norm observed over all
        steps, including the one evaluated after the final update.
    """
    n = N if n is None else n
    d = D if d is None else d
    x = nn.Parameter(torch.FloatTensor(n, d).normal_() * 0.1)
    opt = torch.optim.Adam([x], lr=lr)

    best = None
    for step in range(steps):
        norm = torch.linalg.norm(layer.get_jacobian(x), ord=2)
        if verbose:
            print(step, norm.item())
        current = norm.detach()
        # Adam's trajectory is not monotone, so keep the best value seen.
        best = current if best is None else torch.maximum(best, current)
        opt.zero_grad()
        (-norm).backward()  # ascend by minimising the negated norm
        opt.step()

    with torch.no_grad():
        final = torch.linalg.norm(layer.get_jacobian(x), ord=2)
    return final if best is None else torch.maximum(best, final)

def main():
    """Build an ``AttLayer`` and print the largest Jacobian spectral norm found for it.

    The search itself is performed by ``calc_att_grad_bound``.
    """
    layer = AttLayer()
    max_lip = calc_att_grad_bound(layer)
    print(max_lip)


if __name__ == '__main__':
    main()
