import torch
from torch import nn
def getPi(mask, dtype=torch.float):
    """Build the permutation matrix described by ``mask`` and its inverse.

    Row ``i`` of the returned matrix ``p`` holds a single 1 in column
    ``mask[i]`` and zeros elsewhere, so for a vector ``v`` of length
    ``len(mask)``, ``p @ v`` equals ``v[mask]``.

    Args:
        mask: 1-D tensor (or sequence) of integers that must be a
            permutation of ``range(len(mask))``.
        dtype: dtype of the returned matrices (default ``torch.float``).

    Returns:
        ``(p, ip)`` where ``p`` is the ``(dim, dim)`` permutation matrix and
        ``ip`` is its exact inverse. Both are CPU tensors that do not share
        storage.

    Raises:
        ValueError: if ``mask`` is not 1-D or is not a permutation of
            ``range(len(mask))``.
    """
    idx = torch.as_tensor(mask).to(device="cpu", dtype=torch.long)
    if idx.dim() != 1:
        raise ValueError(f"mask must be 1-D, got shape {tuple(idx.shape)}")
    dim = idx.shape[0]
    # The matrix is invertible only if every column is hit exactly once.
    if not torch.equal(torch.sort(idx).values, torch.arange(dim)):
        raise ValueError("mask must be a permutation of range(len(mask))")

    p = torch.zeros((dim, dim), dtype=dtype)
    # Vectorised form of: for i in range(dim): p[i, mask[i]] = 1
    p[torch.arange(dim), idx] = 1
    # The inverse of a permutation matrix is its transpose. This is exact and
    # O(dim^2), unlike a general LU-based inverse. Clone into a fresh
    # contiguous buffer so ip is not a view aliasing p's storage.
    ip = p.t().clone(memory_format=torch.contiguous_format)
    return p, ip

# Generate a random dim x dim permutation matrix ("key") and its inverse
# ("unkey") and save both to the current working directory.
# NOTE: torch's RNG is not seeded here, so every run produces a different
# key; seed it (torch.manual_seed) first if a reproducible key is needed.
dim = 768
maskc = torch.randperm(dim)
pc, ipc = getPi(maskc)
# Derive the filenames from dim so they cannot drift out of sync with it
# (identical to './key_768.pt' / './unkey_768.pt' for dim == 768).
torch.save(pc, f'./key_{dim}.pt')
torch.save(ipc, f'./unkey_{dim}.pt')

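# Debug output (disabled): uncomment to inspect the key/unkey matrices and to
# confirm that their product is the identity matrix.
#print('key\n', pc)
#print('unkey\n', ipc)
#
#print('validation:\n', torch.matmul(pc, ipc))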
