import torch
import torch.nn as nn
import math
import torch.optim as optim
import numpy as np
#import matplotlib.pyplot as plt
import random
from scipy.optimize import linear_sum_assignment
import scipy
#from joblib import Parallel, delayed;
#import multiprocessing;
#import itertools

def fc_1to2layer_exact(wt, bt, w01, w02, act_scale, act_shift, device):
    """Solve for BN scales so that a random two-layer block reproduces ``wt``.

    The last column of ``wt`` and of ``w01`` is treated as an extra input
    column; every target row is normalised by its last entry, which leaves a
    linear system in the first-layer scales.  ``gamma1[0]`` is fixed to 1.

    Args:
        wt: target weights, shape (n2, n1 + 1).
        bt: target biases, shape (n2,).
        w01: first-layer weights, shape (m, n1 + 1).  The system solved below
            is square only when m - 1 == n2 * n1.
        w02: second-layer weights, shape (n2, m).
        act_scale: unused; kept for interface compatibility.
        act_shift: value written into ``bias1`` for every hidden neuron.
        device: unused; new tensors follow the dtype/device of the inputs.

    Returns:
        (gamma1, bias1, gamma2, bias2) with lengths (m, m, n2, n2).
    """
    gamma2 = wt[:, -1]
    # Normalise every target row by its last entry.
    wt = torch.einsum('ij,i->ij', wt[:, :-1], 1 / gamma2)
    x0 = w01[:, -1]
    w01 = w01[:, :-1]
    # M[i, j, :] @ gamma1 == 0  <=>  entry (i, j) of w02 @ diag(gamma1) @ w01
    # has the same ratio to the last column as the normalised target.
    M = torch.einsum('ik,kj->ijk', w02, w01) - torch.einsum('ik,k,ij->ijk', w02, x0, wt)
    del w01
    M = M.reshape([M.size(0) * M.size(1), M.size(2)])
    # gamma1[0] == 1 moves the first column to the right-hand side.
    z = -M[:, 0]
    gamma1 = torch.linalg.solve(M[:, 1:], z)
    del M
    # Constants are built with gamma1's dtype/device: a bare torch.tensor([1])
    # or torch.ones(n) is CPU/float32 and breaks float64 or CUDA inputs.
    gamma1 = torch.cat((gamma1.new_ones(1), gamma1), 0)
    gamma2 = gamma2 / torch.einsum('ik,k->i', w02, x0 * gamma1)
    bias1 = torch.ones_like(gamma1) * act_shift
    bias2 = bt - torch.matmul(w02, bias1)
    return gamma1, bias1, gamma2, bias2

def fc_1to2layer_exact_remove_conditions(wt, bt, w01, w02, act_scale, act_shift, device):
    """Solve for BN scales of a random two-layer block so that it matches ``wt``.

    For each target row the entry of largest magnitude is taken as pivot; the
    row is normalised by it and gamma2 absorbs the pivot value.  The remaining
    ratio conditions form a homogeneous linear system in gamma1 with
    ``gamma1[0]`` fixed to 1.  Only the (#hidden - 1) equations with the largest
    l1 norm are kept so that the system is square.

    Args:
        wt: target weights, shape (n_out, n_in).  Rows whose l1 norm is
            <= 1e-7 are dropped and receive gamma2 = 0 and bias2 = bt.
        bt: target biases, shape (n_out,).
        w01: first-layer weights, shape (m_total, n_in).
        w02: second-layer weights, shape (n_out, m_total).
        act_scale: unused; kept for interface compatibility.
        act_shift: value written into ``bias1`` for every hidden neuron.
        device: unused; new tensors follow the dtype/device of the inputs.

    Returns:
        (gamma1, bias1, gamma2, bias2) with lengths
        (m_total, m_total, n_out, n_out).  Hidden neurons beyond the first
        n_kept * (n_in - 1) + 1 get gamma1 = 0.
    """
    # Drop (numerically) all-zero target rows.
    row_mass = torch.sum(torch.abs(wt), dim=1)
    nout = wt.size(0)
    ind_keep = torch.arange(nout, device=wt.device)[row_mass > 0.0000001]
    wt = wt[ind_keep, :]
    removal = nout - wt.size(0)
    if removal > 0:
        w02 = w02[ind_keep, :]
    # One unknown per non-pivot target entry, plus gamma1[0] (fixed to 1).
    mfull = wt.size(0) * (wt.size(1) - 1) + 1
    mold = w01.size(0)
    m = min(mfull, mold)
    if m < mold:
        # Surplus hidden neurons would make the system non-square/singular
        # (previously this truncation only happened when a row was removed);
        # they are switched off with gamma1 = 0 below.
        w01 = w01[:m, :]
        w02 = w02[:, :m]
    if m < mfull:
        # Too few hidden neurons: only use equations of the largest target entries.
        _, dropin_ind = torch.sort(torch.abs(wt.flatten()), descending=True)
        dropin_ind = dropin_ind[:(m + wt.size(0) - 1)]

    # Pivot of each row = entry of largest magnitude (first one on ties).
    rows = torch.arange(wt.size(0), device=wt.device)
    piv = torch.argmax(torch.abs(wt), dim=1)
    # xim[i, k] = w01[k, piv[i]] * w02[i, k]; dtype/device follow the inputs.
    xim = torch.transpose(w01[:, piv], 0, 1) * w02
    gamma2 = wt[rows, piv]
    wt = torch.einsum('ij,i->ij', wt, 1 / gamma2)

    # M[i, j, :] @ gamma1 == 0 exactly when row i of w02 @ diag(gamma1) @ w01
    # is a multiple of (normalised) target row i.
    M = torch.einsum('ik,kj->ijk', w02, w01) - torch.einsum('ik,ij->ijk', xim, wt)
    del w01
    M = M.reshape([M.size(0) * M.size(1), M.size(2)])
    if m < mfull:
        M = M[dropin_ind, :]
    # Keep the rows with the largest l1 norm; pivot rows are zero up to rounding.
    trows = M.size(1) - 1
    _, order = torch.sort(torch.sum(torch.abs(M), dim=1), descending=True)
    M = M[order[:trows], :]
    z = -M[:, 0]
    gamma1 = torch.linalg.solve(M[:, 1:], z)
    del M
    gamma1 = torch.cat((gamma1.new_ones(1), gamma1), 0)
    gamma2 = gamma2 / torch.einsum('ik,k->i', xim, gamma1)
    # NOTE(review): only the first m hidden neurons enter bshift; confirm how
    # switched-off neurons (gamma1 = 0, bias1 = act_shift) are treated downstream.
    bshift = torch.sum(w02, dim=1)

    if removal > 0:
        gamma2_full = gamma2.new_zeros(nout)
        gamma2_full[ind_keep] = gamma2
        gamma2 = gamma2_full
        bshift_full = bshift.new_zeros(nout)
        bshift_full[ind_keep] = bshift
        bshift = bshift_full
    if m < mold:
        gamma1_full = gamma1.new_zeros(mold)
        gamma1_full[:m] = gamma1
        gamma1 = gamma1_full
    bias1 = torch.ones_like(gamma1) * act_shift
    bias2 = bt - bshift * act_shift
    return gamma1, bias1, gamma2, bias2

def fc_1to2layer_exact_remove_conditions_precond(wt, bt, w01, w02, act_scale, act_shift, device):
    """Like ``fc_1to2layer_exact_remove_conditions`` but SVD-preconditions the solve.

    The square system A @ gamma1[1:] = z is solved as
    gamma1[1:] = C @ solve(A @ C, z) with C = V diag(1/S) V^T from the SVD of
    A.  If C is not finite, or the SVD / preconditioned solve raises, the plain
    system is solved instead.  The condition number S[0] / S[-1] is printed.

    Args:
        wt: target weights, shape (n_out, n_in).  Rows whose l1 norm is
            <= 1e-7 are dropped and receive gamma2 = 0.
        bt, act_scale, act_shift: unused; kept for interface compatibility.
        w01: first-layer weights, shape (m_total, n_in).
        w02: second-layer weights, shape (n_out, m_total).
        device: unused; new tensors follow the dtype/device of the inputs.

    Returns:
        (gamma1, gamma2, condNumber).  gamma1 has length m_total (surplus
        hidden neurons get 0), gamma2 has length n_out, and condNumber is
        ``torch.tensor([0])`` when the SVD path failed.
    """
    # Drop (numerically) all-zero target rows.
    row_mass = torch.sum(torch.abs(wt), dim=1)
    nout = wt.size(0)
    ind_keep = torch.arange(nout, device=wt.device)[row_mass > 0.0000001]
    wt = wt[ind_keep, :]
    removal = nout - wt.size(0)
    if removal > 0:
        w02 = w02[ind_keep, :]
    mfull = wt.size(0) * (wt.size(1) - 1) + 1
    mold = w01.size(0)
    m = min(mfull, mold)
    if m < mold:
        # Surplus hidden neurons would make the system non-square/singular.
        w01 = w01[:m, :]
        w02 = w02[:, :m]
    if m < mfull:
        _, dropin_ind = torch.sort(torch.abs(wt.flatten()), descending=True)
        dropin_ind = dropin_ind[:(m + wt.size(0) - 1)]

    # Pivot of each row = entry of largest magnitude.
    rows = torch.arange(wt.size(0), device=wt.device)
    piv = torch.argmax(torch.abs(wt), dim=1)
    xim = torch.transpose(w01[:, piv], 0, 1) * w02
    gamma2 = wt[rows, piv]
    wt = torch.einsum('ij,i->ij', wt, 1 / gamma2)

    M = torch.einsum('ik,kj->ijk', w02, w01) - torch.einsum('ik,ij->ijk', xim, wt)
    del w01
    M = M.reshape([M.size(0) * M.size(1), M.size(2)])
    if m < mfull:
        M = M[dropin_ind, :]
    # Keep the rows with the largest l1 norm; pivot rows are zero up to rounding.
    trows = M.size(1) - 1
    _, order = torch.sort(torch.sum(torch.abs(M), dim=1), descending=True)
    M = M[order[:trows], :]
    z = -M[:, 0]
    A = M[:, 1:]
    del M

    print("cond number M:")
    try:
        _, S, Vh = torch.linalg.svd(A)
        condNumber = S[0] / S[-1]
        print(condNumber)
        ss = S.size(0)
        # For a square full-rank A, A @ C has unit singular values.
        C = torch.transpose(Vh, 0, 1)[:, :ss] / S @ Vh[:ss, :]
        if torch.isfinite(C).all():
            gamma1 = C @ torch.linalg.solve(A @ C, z)
        else:
            gamma1 = torch.linalg.solve(A, z)
    except (RuntimeError, IndexError):
        # A is left untouched above, so this really is the unpreconditioned system.
        condNumber = torch.tensor([0])
        gamma1 = torch.linalg.solve(A, z)

    gamma1 = torch.cat((gamma1.new_ones(1), gamma1), 0)
    gamma2 = gamma2 / torch.einsum('ik,k->i', xim, gamma1)
    if removal > 0:
        gamma2_full = gamma2.new_zeros(nout)
        gamma2_full[ind_keep] = gamma2
        gamma2 = gamma2_full
    if m < mold:
        gamma1_full = gamma1.new_zeros(mold)
        gamma1_full[:m] = gamma1
        gamma1 = gamma1_full
    return gamma1, gamma2, condNumber


def fc_1to2layer_exact_project(wt, bt, w01, w02, act_scale, act_shift, device):
    """Least-squares variant of the pivot-normalised BN-scale solve.

    Instead of picking a square subset of equations, the normal equations
    M^T M gamma1 = 0 (with gamma1[0] fixed to 1) are solved, where
    M[i, j, k] = w02[i, k] w01[k, j] - xim[i, k] wt[i, j].  M^T M is assembled
    term by term without materialising the (n_out, n_in, m) tensor M.

    Args:
        wt: target weights, shape (n_out, n_in).  Rows whose l1 norm is
            <= 1e-7 are dropped and receive gamma2 = 0 and bias2 = bt.
        bt: target biases, shape (n_out,).
        w01: first-layer weights, shape (m_total, n_in).
        w02: second-layer weights, shape (n_out, m_total).
        act_scale: unused; kept for interface compatibility.
        act_shift: value written into ``bias1`` for every hidden neuron.
        device: unused; new tensors follow the dtype/device of the inputs.

    Returns:
        (gamma1, bias1, gamma2, bias2) with lengths
        (m_total, m_total, n_out, n_out); surplus hidden neurons get gamma1 = 0.
    """
    row_mass = torch.sum(torch.abs(wt), dim=1)
    nout = wt.size(0)
    ind_keep = torch.arange(nout, device=wt.device)[row_mass > 0.0000001]
    wt = wt[ind_keep, :]
    removal = nout - wt.size(0)
    # mold used to be referenced without ever being assigned (NameError as
    # soon as a target row was removed).
    mold = w01.size(0)
    mfull = wt.size(0) * (wt.size(1) - 1) + 1
    m = min(mfull, mold)
    if removal > 0:
        w02 = w02[ind_keep, :]
    if m < mold:
        # More unknowns than independent equations would make M^T M singular.
        w01 = w01[:m, :]
        w02 = w02[:, :m]

    # Pivot of each row = entry of largest magnitude.
    rows = torch.arange(wt.size(0), device=wt.device)
    piv = torch.argmax(torch.abs(wt), dim=1)
    xim = torch.transpose(w01[:, piv], 0, 1) * w02
    gamma2 = wt[rows, piv]
    wt = torch.einsum('ij,i->ij', wt, 1 / gamma2)

    # Cross terms -sum_ij w02[i,k] w01[k,j] xim[i,m] wt[i,j] and its transpose.
    MtM = -torch.einsum('ik,im->km', w02 * torch.einsum('mj,ij->im', w01, wt), xim)
    MtM = MtM + torch.transpose(MtM, 0, 1)
    # sum_ij w02[i,k] w02[i,m] w01[k,j] w01[m,j] factorises over i and j.
    MtM = MtM + torch.einsum('ik,im->km', w02, w02) * torch.einsum('kj,mj->km', w01, w01)
    # sum_ij xim[i,k] xim[i,m] wt[i,j]^2.
    MtM = MtM + torch.einsum('im,ik->mk', torch.einsum('i,im->im', torch.sum(wt ** 2, dim=1), xim), xim)
    del w01
    del wt
    bshift = torch.sum(w02, dim=1)
    del w02
    # gamma1[0] == 1: solve the normal equations for the remaining entries.
    gamma1 = torch.linalg.solve(MtM[1:, 1:], -MtM[1:, 0])
    del MtM
    gamma1 = torch.cat((gamma1.new_ones(1), gamma1), 0)
    gamma2 = gamma2 / torch.einsum('ik,k->i', xim, gamma1)

    if removal > 0:
        gamma2_full = gamma2.new_zeros(nout)
        gamma2_full[ind_keep] = gamma2
        gamma2 = gamma2_full
        bshift_full = bshift.new_zeros(nout)
        bshift_full[ind_keep] = bshift
        bshift = bshift_full
    if m < mold:
        gamma1_full = gamma1.new_zeros(mold)
        gamma1_full[:m] = gamma1
        gamma1 = gamma1_full
    bias1 = torch.ones_like(gamma1) * act_shift
    bias2 = bt - bshift * act_shift
    return gamma1, bias1, gamma2, bias2

#Note: the solution computed above is unfortunately usually numerically unstable.



def fc_1to2layer_exact_project_without_gamma2(wt, w01, w02, device):
    """Least-squares first-layer scales with no second-layer scale.

    Solves the normal equations of
        min_g || w02 @ diag(g) @ w01 - wt ||_F^2.

    Args:
        wt: target weights, shape (n_out, n_in).  Rows whose l1 norm is
            <= 1e-7 are dropped before solving.
        w01: first-layer weights, shape (m_total, n_in).
        w02: second-layer weights, shape (n_out, m_total).
        device: unused; the result follows the dtype/device of the inputs.

    Returns:
        gamma1 of length m_total.  When rows were dropped, only the first
        min(n_kept * (n_in - 1) + 1, m_total) hidden neurons are fitted and
        the rest get 0.
    """
    row_mass = torch.sum(torch.abs(wt), dim=1)
    nout = wt.size(0)
    ind_keep = torch.arange(nout, device=wt.device)[row_mass > 0.0000001]
    wt = wt[ind_keep, :]
    removal = nout - wt.size(0)
    mold = w01.size(0)
    if removal > 0:
        mfull = wt.size(0) * (wt.size(1) - 1) + 1
        m = min(mfull, mold)
        w01 = w01[:m, :]
        w02 = w02[:, :m][ind_keep, :]

    # Normal matrix (w02^T w02) * (w01 w01^T) (elementwise) and right-hand side.
    MtM = torch.einsum('ik,im->km', w02, w02) * torch.einsum('kj,mj->km', w01, w01)
    z = torch.sum(w02 * torch.einsum('ij,mj->im', wt, w01), dim=0)
    del w01
    del wt
    del w02
    gamma1 = torch.linalg.solve(MtM, z)
    del MtM
    if removal > 0:
        # Pad with the solution's own dtype/device (torch.zeros(mold) used to
        # downcast float64 results to float32 and move CUDA results to CPU).
        gamma1_full = gamma1.new_zeros(mold)
        gamma1_full[:m] = gamma1
        gamma1 = gamma1_full
    return gamma1

def fc_1to2layer_project_gamma1(wt, w01, w02, gamma2, device):
    """Return the least-squares first-layer scales for fixed second-layer scales.

    Solves the normal equations of
        min_g || diag(gamma2) @ w02 @ diag(g) @ w01 - wt ||_F^2

    Args:
        wt: target weights, shape (n2, n1).
        w01: first-layer weights, shape (m, n1).
        w02: second-layer weights, shape (n2, m).
        gamma2: fixed second-layer scales, shape (n2,).
        device: unused.

    Returns:
        gamma1 of shape (m,).
    """
    scaled_w02 = torch.einsum('i,im->im', gamma2, w02)
    w02_gram = torch.einsum('ik,im->km', scaled_w02, scaled_w02)
    w01_gram = torch.einsum('kj,mj->km', w01, w01)
    rhs = torch.einsum('ij,im,mj->m', wt, scaled_w02, w01)
    return torch.linalg.solve(w02_gram * w01_gram, rhs)

def fc_1to2layer_project_gamma2(wt, w01, w02, gamma1, device):
    """Return per-row second-layer scales that best match ``wt`` for fixed gamma1.

    With r = w02 @ diag(gamma1) @ w01, each gamma2[i] minimises
    ||gamma2[i] * r[i, :] - wt[i, :]||^2, i.e. <r_i, wt_i> / <r_i, r_i>.
    A zero row of r gives nan/inf for that entry.

    Args:
        wt: target weights, shape (n2, n1).
        w01: first-layer weights, shape (m, n1).
        w02: second-layer weights, shape (n2, m).
        gamma1: fixed first-layer scales, shape (m,).
        device: unused.

    Returns:
        gamma2 of shape (n2,).
    """
    reconstruction = torch.einsum('im,m,mj->ij', w02, gamma1, w01)
    projection = (reconstruction * wt).sum(dim=1)
    energy = (reconstruction ** 2).sum(dim=1)
    return projection / energy

def fc_1to2layer_project_iterative(wt, w01, w02, gamma1, gamma2, rep, device):
    """Alternate the closed-form refits of gamma2 and gamma1 for ``rep`` rounds.

    Each round first refits gamma2 for the current gamma1
    (fc_1to2layer_project_gamma2), then gamma1 for that new gamma2
    (fc_1to2layer_project_gamma1).

    Returns:
        (gamma2, gamma1) -- note the order.  For rep <= 0 the inputs are
        returned unchanged.
    """
    for _round in range(rep):
        gamma2 = fc_1to2layer_project_gamma2(wt, w01, w02, gamma1, device)
        gamma1 = fc_1to2layer_project_gamma1(wt, w01, w02, gamma2, device)
    return gamma2, gamma1

def fc_1to2layer_exact_project_l2(wt, w01, w02, gamma2, device, iters=1):
    """Alternating least-squares fit of (gamma1, gamma2).

    Model: diag(gamma2) @ w02 @ diag(gamma1) @ w01 ≈ wt.  Each iteration
    solves the normal equations for gamma1 given gamma2, then refits gamma2
    row by row given gamma1 (same formula as fc_1to2layer_project_gamma2).

    Args:
        wt: target weights, shape (n_out, n_in).  Rows whose l1 norm is
            <= 1e-7 are dropped.
        w01: first-layer weights, shape (m_total, n_in).
        w02: second-layer weights, shape (n_out, m_total).
        gamma2: initial second-layer scales, either of length n_out or already
            restricted to the kept rows.
        device: unused.
        iters: number of alternations (>= 1); the default 1 matches the
            previous hard-coded behaviour.

    Returns:
        (gamma1, gamma2) for the kept hidden neurons / output rows only (no
        padding back to the original sizes).

    Raises:
        ValueError: if iters < 1.
    """
    if iters < 1:
        raise ValueError("iters must be >= 1")
    row_mass = torch.sum(torch.abs(wt), dim=1)
    nout = wt.size(0)
    ind_keep = torch.arange(nout, device=wt.device)[row_mass > 0.0000001]
    wt = wt[ind_keep, :]
    removal = nout - wt.size(0)
    if removal > 0:
        mfull = wt.size(0) * (wt.size(1) - 1) + 1
        m = min(mfull, w01.size(0))
        w01 = w01[:m, :]
        w02 = w02[:, :m][ind_keep, :]
        if gamma2.size(0) == nout:
            # gamma2 has to line up with the kept rows of w02 (this used to
            # raise a shape mismatch in the einsum below).
            gamma2 = gamma2[ind_keep]

    w01_gram = torch.einsum('kj,mj->km', w01, w01)
    # M2[i, j, m] = w02[i, m] * w01[m, j]; M2 @ gamma1 is the unscaled product.
    M2 = torch.einsum('im,mj->ijm', w02, w01)
    # V[i, m] = w02[i, m] * <wt[i, :], w01[m, :]>.
    V = w02 * torch.einsum('ij,mj->im', wt, w01)
    del wt
    del w01
    for i in range(iters):
        print("iterate: ", i)
        lhs = torch.einsum('i,ik,im->km', gamma2 ** 2, w02, w02) * w01_gram
        gamma1 = torch.linalg.solve(lhs, torch.einsum('i,im->m', gamma2, V))
        w0 = torch.einsum('ijm,m->ij', M2, gamma1)
        # Per-row denominator; summing over all rows (as before) mis-scaled gamma2.
        gamma2 = torch.matmul(V, gamma1) / torch.sum(w0 ** 2, dim=1)
    return gamma1, gamma2

def fc_1to2layer_learn_GD_LBFGS_without_gamma2(wt, w01, w02, gamma1init, epochs=50, device="cpu"):
    """Fit first-layer scales with L-BFGS so that w02 @ diag(gamma1) @ w01 ≈ wt.

    Minimises the mean squared error over all entries; there is no
    second-layer scale.

    Args:
        wt: target weights, shape (n2, n1).
        w01: first-layer weights, shape (m, n1).
        w02: second-layer weights, shape (n2, m).
        gamma1init: starting point, shape (m,).  It is copied, not modified.
        epochs: number of ``LBFGS.step`` calls (each runs up to 40 iterations).
        device: device on which the optimisation runs.

    Returns:
        gamma1, a leaf tensor of shape (m,) with ``requires_grad=True``.
    """
    with torch.no_grad():
        w01 = torch.transpose(w01, 0, 1).to(device)  # (n1, m)
        w02 = torch.transpose(w02, 0, 1).to(device)  # (m, n2)
        wt = torch.transpose(wt, 0, 1).to(device)    # (n1, n2)
    m = w01.size(1)
    gamma1 = torch.ones(m, requires_grad=True, device=device)
    # .clone(): .detach() shares storage with the caller's tensor and L-BFGS
    # updates its parameter in place, which used to overwrite gamma1init.
    gamma1.data = gamma1init.detach().clone().to(device)

    opt = torch.optim.LBFGS([gamma1], lr=1, max_iter=40, max_eval=None, tolerance_grad=1e-07,
                            tolerance_change=1e-09, history_size=100, line_search_fn='strong_wolfe')

    def closure():
        opt.zero_grad()
        objective = torch.mean((torch.matmul(gamma1 * w01, w02) - wt) ** 2)
        objective.backward()
        return objective

    for _ in range(epochs):
        opt.step(closure)
    return gamma1

def fc_1to2layer_learn_GD_LBFGS_cond_expensive(wt, w01, w02, gamma1init, epochs=50, device="cpu", precond=False):
    """Fit the BN scales of a random two-layer block to a target layer with L-BFGS.

    Model: diag(gamma2) @ w02 @ diag(gamma1) @ w01 ≈ wt (mean squared error).

    * ``precond=False``: gamma2 is fixed to ones and gamma1 is fitted directly.
    * ``precond=True``: a multi-stage (comparatively expensive) schedule:
        1. SVD-precondition the linear map gamma1 -> w02 @ diag(gamma1) @ w01,
           warm-start gamma1 from the pseudo-inverse, and optimise with gamma2
           eliminated in closed form per output row;
        2. freeze that gamma2 (fold it into w02) and refit gamma1;
        3. re-precondition with gamma2 folded in and refit;
        4. fine-tune gamma1 in the original coordinates;
        5. unfold gamma2 and fine-tune with gamma2 eliminated per row again.
      Progress is printed after each stage.

    Args:
        wt: target weights, shape (n2, n1).
        w01: first-layer weights, shape (m, n1).
        w02: second-layer weights, shape (n2, m).
        gamma1init: starting gamma1, shape (m,).  It is copied, not modified.
        epochs: number of ``LBFGS.step`` calls per stage; must be >= 1.
        device: device on which the optimisation runs.
        precond: select the schedule described above.

    Returns:
        (gamma1, gamma2, condNumber).  gamma1 is a leaf tensor with
        ``requires_grad=True``.  With precond=True, gamma2 is the closed-form
        row scale evaluated just before the last L-BFGS step; otherwise ones.
        condNumber is S[0] / S[-1] from the last SVD, or ``torch.tensor([0])``
        when no SVD was computed or it failed.

    Raises:
        ValueError: if epochs < 1.
    """
    if epochs < 1:
        raise ValueError("epochs must be >= 1")
    with torch.no_grad():
        w01 = torch.transpose(w01, 0, 1).to(device)  # (n1, m)
        w02 = torch.transpose(w02, 0, 1).to(device)  # (m, n2)
        wt = torch.transpose(wt, 0, 1).to(device)    # (n1, n2)
    n2, n1 = wt.size(1), wt.size(0)
    m = w01.size(1)
    gamma1 = torch.ones(m, requires_grad=True, device=device)
    # .clone(): .detach() shares storage with the caller's tensor and L-BFGS
    # updates its parameter in place, which used to overwrite gamma1init.
    gamma1.data = gamma1init.detach().clone().to(device)
    # Sentinel for "no SVD"; previously unassigned when precond=False, so the
    # default call raised UnboundLocalError at the return statement.
    condNumber = torch.tensor([0])

    def new_lbfgs():
        return torch.optim.LBFGS([gamma1], lr=1, max_iter=40, max_eval=None, tolerance_grad=1e-07,
                                 tolerance_change=1e-09, history_size=100, line_search_fn='strong_wolfe')

    def closure_for(model):
        """Closure for ``opt.step``: MSE between model() and wt, with backward."""
        def closure():
            opt.zero_grad()
            objective = torch.mean((model() - wt) ** 2)
            objective.backward()
            return objective
        return closure

    def svd_precondition(M, target):
        """Precondition the flat system M @ gamma1 ≈ target.

        On success gamma1 is warm-started in preconditioned coordinates and
        (M @ C, C, cond) is returned with C = V diag(1/S) V^T (map back with
        gamma1 = C @ gamma1).  If the SVD raises or the warm start contains
        NaN, (M, I, cond) is returned and gamma1 is left unchanged.
        """
        identity = torch.eye(m, dtype=M.dtype, device=M.device)
        print("cond number M:")
        try:
            U, S, Vh = torch.linalg.svd(M)
            ss = S.size(0)
            cond = S[0] / S[-1]
            print(cond)
            V = torch.transpose(Vh, 0, 1)[:, :ss]
            start = V @ torch.transpose(U, 0, 1)[:ss, :] @ target
            C = V / S @ Vh[:ss, :]
        except (RuntimeError, IndexError):
            return M, identity, torch.tensor([0])
        if torch.isnan(start).any():
            return M, identity, cond
        gamma1.data = start.detach().to(device)
        return M @ C, C, cond

    if precond:
        # Back to (n2, n1); wt keeps this layout for the whole branch (the
        # former extra transposes made stage 2 compare (n2, n1) with (n1, n2)).
        wt = torch.transpose(wt, 0, 1)
        target = wt.reshape((-1,))

        # Stage 1: preconditioned fit with gamma2 eliminated per row.
        M = torch.einsum('ki,jk->ijk', w02, w01).reshape((n2 * n1, m))
        M, C, condNumber = svd_precondition(M, target)
        print("error: ", torch.sqrt(torch.mean((M @ gamma1 - target) ** 2)))
        M = M.reshape((n2, n1, -1))

        def stage1_model():
            rij = torch.einsum('ijk,k->ij', M, gamma1)
            row_scale = torch.sum(rij * wt, dim=1) / torch.sum(rij ** 2, dim=1)
            return torch.einsum('i,ijk,k->ij', row_scale, M, gamma1)

        opt = new_lbfgs()
        closure = closure_for(stage1_model)
        for _ in range(epochs):
            rij = torch.einsum('ijk,k->ij', M, gamma1)
            gamma2 = torch.sum(rij * wt, dim=1) / torch.sum(rij ** 2, dim=1)
            wproxy = torch.einsum('i,ijk,k->ij', gamma2, M, gamma1)
            ll = torch.mean((wproxy - wt) ** 2)
            opt.step(closure)
        print("error cond: ", ll)
        gamma1.data = C @ gamma1.data
        wproxy = torch.einsum('i,ki,jk,k->ij', gamma2, w02, w01, gamma1)
        print("error start: ", torch.sqrt(torch.mean((wproxy - wt) ** 2)))

        # Stage 2: freeze gamma2 by folding it into the columns of w02 (m, n2).
        gamma2fixed = gamma2.data.clone().detach()
        w02 = w02 * gamma2fixed

        def folded_model():
            return torch.einsum('ki,jk,k->ij', w02, w01, gamma1)

        opt = new_lbfgs()
        closure = closure_for(folded_model)
        for _ in range(epochs):
            ll = torch.sqrt(torch.mean((folded_model() - wt) ** 2))
            opt.step(closure)

        # Stage 3: re-precondition with gamma2 folded in.
        M = torch.einsum('ki,jk->ijk', w02, w01).reshape((n2 * n1, m))
        M, C, condNumber = svd_precondition(M, target)
        err = torch.sqrt(torch.mean((M @ gamma1 - target) ** 2))
        print("error svd 2: ", err)
        M = M.reshape((n2, n1, -1))

        def stage3_model():
            return torch.einsum('ijk,k->ij', M, gamma1)

        opt = new_lbfgs()
        closure = closure_for(stage3_model)
        for _ in range(epochs):
            ll = torch.mean((stage3_model() - wt) ** 2)
            opt.step(closure)
        print("error cond 2: ", ll)
        gamma1.data = C @ gamma1.data
        print("error start 2: ", torch.sqrt(torch.mean((folded_model() - wt) ** 2)))

        # Stage 4: fine-tune gamma1 in the original coordinates.
        opt = new_lbfgs()
        closure = closure_for(folded_model)
        for _ in range(epochs):
            ll = torch.sqrt(torch.mean((folded_model() - wt) ** 2))
            opt.step(closure)
        print("error finetune gamma1: ", ll)

        # Stage 5: unfold gamma2 and fine-tune with gamma2 eliminated per row.
        w02 = w02 / gamma2fixed

        def stage5_model():
            rij = torch.einsum('ki,jk,k->ij', w02, w01, gamma1)
            row_scale = torch.sum(rij * wt, dim=1) / torch.sum(rij ** 2, dim=1)
            return torch.einsum('i,ki,jk,k->ij', row_scale, w02, w01, gamma1)

        opt = new_lbfgs()
        closure = closure_for(stage5_model)
        for _ in range(epochs):
            rij = torch.einsum('ki,jk,k->ij', w02, w01, gamma1)
            gamma2 = torch.sum(rij * wt, dim=1) / torch.sum(rij ** 2, dim=1)
            wproxy = torch.einsum('i,ki,jk,k->ij', gamma2, w02, w01, gamma1)
            ll = torch.sqrt(torch.mean((wproxy - wt) ** 2))
            opt.step(closure)
        print("error finetune gamma1 and gamma2: ", ll)
    else:
        wt = torch.transpose(wt, 0, 1)  # back to (n2, n1)

        def plain_model():
            return torch.einsum('ki,jk,k->ij', w02, w01, gamma1)

        opt = new_lbfgs()
        closure = closure_for(plain_model)
        for _ in range(epochs):
            ll = torch.sqrt(torch.mean((plain_model() - wt) ** 2))
            opt.step(closure)
        gamma2 = torch.ones(n2, dtype=gamma1.dtype, device=gamma1.device)
    print("error final: ", ll)
    return gamma1, gamma2, condNumber


def fc_1to2layer_learn_GD_LBFGS_implicit_gamma2(wt, w01, w02, gamma1init, epochs=50, device="cpu", precond=False):
    """Fit the first-layer scales ``gamma1`` so the two-layer proxy matches ``wt``.

    With the tensors indexed as passed in, the proxy weight for output ``i`` and
    input ``j`` is ``sum_k w02[i, k] * gamma1[k] * w01[k, j]``, optionally times a
    per-output scale ``gamma2[i]``.

    With ``precond=True`` the scales ``gamma2`` are not optimized directly: for
    the current ``gamma1`` each ``gamma2[i]`` is the least-squares optimal factor
    of proxy row ``i`` against target row ``i``. L-BFGS first runs in
    SVD-preconditioned coordinates (started from the linear least-squares
    solution) and then again in the original coordinates.
    With ``precond=False`` only ``gamma1`` is fitted and ``gamma2`` is all ones.

    Args:
        wt: target weights, ``(n2, n1)``.
        w01: first-layer weights, ``(m, n1)``.
        w02: second-layer weights, ``(n2, m)``.
        gamma1init: initial ``gamma1`` (length ``m``); copied, never modified.
        epochs: number of L-BFGS ``step`` calls per optimization phase.
        device: device on which the optimization runs.
        precond: enable SVD preconditioning and the implicit ``gamma2`` solve.

    Returns:
        ``(gamma1, gamma2, condNumber)``. ``gamma2`` is consistent with the final
        ``gamma1``. ``condNumber`` is ``S[0] / S[-1]`` for the singular values of
        the linear map ``gamma1 -> proxy weights`` when ``precond`` is set and the
        SVD succeeds, otherwise ``torch.tensor([0])``.
    """
    with torch.no_grad():
        w01 = torch.transpose(w01, 0, 1).to(device)  # (n1, m)
        w02 = torch.transpose(w02, 0, 1).to(device)  # (m, n2)
        wt = torch.transpose(wt, 0, 1).to(device)    # (n1, n2)
    n2, n1 = wt.size(1), wt.size(0)
    m = w01.size(1)
    # From here on the target is (n2, n1), the layout produced by the einsums below.
    wt = torch.transpose(wt, 0, 1)
    gamma1 = torch.ones(m, requires_grad=True, device=device)
    # clone(): L-BFGS updates gamma1 in place, which must not write into gamma1init.
    gamma1.data = gamma1init.detach().to(device).clone()
    # Sentinel for "no condition number available" (precond=False or failed SVD).
    condNumber = torch.tensor([0])

    def plain_proxy():
        # rij[i, j] = sum_k w02[k, i] * w01[j, k] * gamma1[k]  (transposed layouts).
        return torch.einsum('ki,jk,k->ij', w02, w01, gamma1)

    def optimal_gamma2(rij):
        # Row-wise least-squares scale: argmin_g sum_j (g * rij[i, j] - wt[i, j])**2.
        return torch.sum(rij * wt, dim=1) / torch.sum(rij ** 2, dim=1)

    def implicit_proxy():
        rij = plain_proxy()
        return optimal_gamma2(rij)[:, None] * rij

    def fit(objective_fn):
        # Minimize objective_fn() over gamma1 with `epochs` L-BFGS steps.
        opt = torch.optim.LBFGS([gamma1], lr=1, max_iter=40, max_eval=None, tolerance_grad=1e-07,
                                tolerance_change=1e-09, history_size=100, line_search_fn='strong_wolfe')

        def closure():
            opt.zero_grad()
            objective = objective_fn()
            objective.backward()
            return objective

        for _ in range(epochs):
            opt.step(closure)

    if precond:
        # The proxy is linear in gamma1: vec(proxy) = M @ gamma1, row index i * n1 + j.
        M = torch.einsum('ki,jk->ijk', w02, w01).reshape((n2 * n1, m))
        print("cond number M:")
        try:
            U, S, Vh = torch.linalg.svd(M, full_matrices=False)
            condNumber = S[0] / S[-1]
            print(condNumber)
            V = torch.transpose(Vh, 0, 1)
            # Least-squares start in the preconditioned variable g, where M @ C = U @ Vh.
            gamma1.data = (V @ (torch.transpose(U, 0, 1) @ wt.reshape((-1,)))).detach().to(device)
            # gamma1 = C @ g with C = V diag(1/S) Vh.
            C = (V / S) @ Vh
            M = M @ C
        except RuntimeError:  # e.g. torch.linalg.LinAlgError when the SVD does not converge
            C = torch.eye(m, device=device)
            condNumber = torch.tensor([0])
        print("error: ", torch.sqrt(torch.mean((M @ gamma1 - wt.reshape((-1,))) ** 2)))
        M = M.reshape((n2, n1, -1))

        def cond_objective():
            rij = torch.einsum('ijk,k->ij', M, gamma1)
            return torch.mean((optimal_gamma2(rij)[:, None] * rij - wt) ** 2)

        fit(cond_objective)
        with torch.no_grad():
            print("error cond: ", cond_objective())
            # Map the preconditioned variable back to gamma1 in original coordinates.
            gamma1.data = C @ gamma1.data
            print("error start: ", torch.sqrt(torch.mean((implicit_proxy() - wt) ** 2)))
        fit(lambda: torch.mean((implicit_proxy() - wt) ** 2))
        with torch.no_grad():
            # Recomputed after the last step so gamma2 matches the returned gamma1.
            gamma2 = optimal_gamma2(plain_proxy())
    else:
        fit(lambda: torch.mean((plain_proxy() - wt) ** 2))
        gamma2 = torch.ones(n2, device=device)

    with torch.no_grad():
        ll = torch.sqrt(torch.mean((gamma2[:, None] * plain_proxy() - wt) ** 2))
    print("error final: ", ll)
    return gamma1, gamma2, condNumber

def project_perm_mat(mat):
    """Overwrite ``mat`` in place with a greedily chosen 0/1 permutation matrix.

    Rows are visited top to bottom. Each row receives a single 1 in the column,
    among the first ``mat.size(0)`` columns not yet claimed by an earlier row,
    where that row has its largest entry (the first such column on ties). All
    other entries of the row become 0. Because the choice is greedy, the result
    need not maximise the total selected mass over all permutations.

    Args:
        mat: matrix tensor; only the first ``mat.size(0)`` columns are
            candidates. Modified in place through ``.data``.

    Returns:
        ``mat`` itself.
    """
    unclaimed = list(range(mat.size(0)))
    for row in range(mat.size(0)):
        # Position (within `unclaimed`) of the row's largest still-available entry.
        pick = torch.argmax(mat[row, unclaimed]).item()
        col = unclaimed.pop(pick)
        mat.data[row, :] = 0
        mat.data[row, col] = 1
    return mat

def fc_1to2layer_learn_GD_LBFGS_without_gamma2_permute_tune(reg, epochsPermute, wt, w01, w02, gamma1init, epochs=50, device="cpu"):
    """Learn ``gamma1`` together with a relaxed output permutation, then refit ``gamma1``.

    Phase 1 runs ``epochsPermute`` L-BFGS steps on ``gamma1`` and an ``n2 x n2``
    matrix ``perm``, minimizing

        mean((w01^T diag(gamma1) w02^T perm - wt^T)**2) + reg * penalty(perm).

    ``penalty`` is ``2 * sum|perm|`` minus the sums of the column and row L2
    norms. It is zero exactly when every row and column has at most one nonzero
    entry, and it is rescaled by its value at the random start. After each step,
    NaNs are reset to the last good values, and ``perm`` is clipped to be
    nonnegative and L2-normalized by rows, then by columns.

    Phase 2 rounds ``perm`` in place to a 0/1 permutation matrix with
    ``project_perm_mat`` and applies it to ``w02``. ``gamma1`` is then refitted
    from ``gamma1init`` with ``epochs`` L-BFGS steps.

    Args:
        reg: weight of the permutation penalty.
        epochsPermute: number of L-BFGS steps in phase 1.
        wt: target weights, ``(n2, n1)``.
        w01: first-layer weights, ``(m, n1)``.
        w02: second-layer weights, ``(n2, m)``.
        gamma1init: initial ``gamma1`` (length ``m``); copied, never modified.
        epochs: number of L-BFGS steps in phase 2.
        device: device on which every tensor is created.

    Returns:
        ``(gamma1, perm_t, loss)``: the refitted ``gamma1``, the transposed
        permutation matrix and the RMS error of the permuted proxy.
    """
    with torch.no_grad():
        w01 = torch.transpose(w01, 0, 1).to(device)  # (n1, m)
        w02 = torch.transpose(w02, 0, 1).to(device)  # (m, n2)
        wt = torch.transpose(wt, 0, 1).to(device)    # (n1, n2)
    n2 = wt.size(1)
    m = w01.size(1)

    def make_lbfgs(params):
        return torch.optim.LBFGS(params, lr=1, max_iter=40, max_eval=None, tolerance_grad=1e-07,
                                 tolerance_change=1e-09, history_size=100, line_search_fn='strong_wolfe')

    def raw_penalty(M):
        # >= 0 because the L1 norm dominates the L2 norm of every row/column.
        return 2 * torch.sum(torch.abs(M)) - (torch.sum(torch.sqrt(torch.sum(M ** 2, dim=0)))
                                              + torch.sum(torch.sqrt(torch.sum(M ** 2, dim=1))))

    def renormalize_perm():
        # Clip to nonnegative, then scale rows and afterwards columns to unit L2 norm.
        # An all-zero row/column is replaced by the uniform unit vector 1/sqrt(n2).
        perm.data = torch.clamp(perm.data, min=0)
        norms = torch.sqrt(torch.sum(perm.data ** 2, dim=1))
        empty = norms == 0
        if empty.any():
            norms[empty] = np.sqrt(n2)
            perm.data[empty, :] = 1.0
        perm.data = torch.einsum('ij,i->ij', perm.data, 1 / norms)
        norms = torch.sqrt(torch.sum(perm.data ** 2, dim=0))
        empty = norms == 0
        if empty.any():
            norms[empty] = np.sqrt(n2)
            perm.data[:, empty] = 1.0
        perm.data = torch.einsum('ij,j->ij', perm.data, 1 / norms)

    gamma1 = torch.ones(m, requires_grad=True, device=device)
    gamma1.data = gamma1init.detach().to(device).clone()
    perm = torch.ones((n2, n2), requires_grad=True, device=device)
    # Random tensors are drawn on `device`: assigning a CPU tensor to `.data` would
    # move perm to the CPU and break the matmuls with the other tensors.
    torch.randn((n2, n2), device=device)  # discarded draw, kept so seeded runs reproduce
    perm.data = torch.randn((n2, n2), device=device)
    perm.data = torch.einsum('ij,i->ij', perm.data, 1 / torch.sqrt(torch.sum(perm.data ** 2, dim=1)))
    perm.data = torch.einsum('ij,j->ij', perm.data, 1 / torch.sqrt(torch.sum(perm.data ** 2, dim=0)))

    # Scale the penalty by its value at the start and by the initial fit error.
    pen_norm = raw_penalty(perm.data)
    w_norm = torch.mean((torch.matmul(torch.matmul(gamma1 * w01, w02), perm / pen_norm) - wt) ** 2).detach().item()
    pen_norm = pen_norm / max(w_norm, 0.1)

    def penalty(M):
        return max(raw_penalty(M) * 0.1 / pen_norm, 0.0)

    def perm_closure():
        opt_perm.zero_grad()
        objective = torch.mean((torch.matmul(torch.matmul(gamma1 * w01, w02), perm) - wt) ** 2) + reg * penalty(perm)
        objective.backward()
        return objective

    opt_perm = make_lbfgs([gamma1, perm])
    permOld = perm.detach().clone()
    for _ in range(epochsPermute):
        opt_perm.step(perm_closure)
        # A diverged step leaves NaNs: fall back to the initial gamma1 / last good perm.
        if torch.isnan(gamma1).any():
            gamma1.data = gamma1init.detach().to(device).clone()
        if torch.isnan(perm).any():
            perm.data = permOld.data
        renormalize_perm()
        permOld = perm.detach().clone()

    # Round to a 0/1 permutation matrix (in place) and fold it into w02.
    perm = project_perm_mat(perm)
    with torch.no_grad():
        w02.data = torch.matmul(w02.data, perm.data)

    # Refit gamma1 from its initial value with the permutation fixed.
    gamma1.data = gamma1init.detach().to(device).clone()
    opt_refit = make_lbfgs([gamma1])

    def refit_closure():
        opt_refit.zero_grad()
        objective = torch.mean((torch.matmul(gamma1 * w01, w02) - wt) ** 2)
        objective.backward()
        return objective

    for _ in range(epochs):
        opt_refit.step(refit_closure)
    with torch.no_grad():
        loss = torch.sqrt(torch.mean((torch.matmul(gamma1 * w01, w02) - wt) ** 2))
    print("Res: " + str(reg) + " " + str(epochsPermute) + ": " + str(loss.item()))
    return gamma1, torch.transpose(perm, 0, 1), loss

def fc_1to2layer_learn_GD_LBFGS_permute_tune(reg, epochsPermute, wt, w01, w02, gamma1init, gamma2init, epochs=50, device="cpu"):
    """Learn ``gamma1``, ``gamma2`` and a relaxed output permutation, then refit the scales.

    The proxy weights are ``w01^T diag(gamma1) w02^T diag(gamma2)`` and are
    compared against ``wt^T``.

    Phase 1 runs ``epochsPermute`` L-BFGS steps on ``gamma1``, ``gamma2`` and an
    ``n2 x n2`` matrix ``perm``, minimizing
    ``mean((proxy @ perm - wt^T)**2) + reg * penalty(perm)``. ``penalty`` is
    ``2 * sum|perm|`` minus the sums of the column and row L2 norms, rescaled by
    its value at the random start. After each step, NaNs are reset, and ``perm``
    is clipped to be nonnegative and L2-normalized by rows, then by columns.

    Phase 2 rounds ``perm`` in place to a 0/1 permutation matrix with
    ``project_perm_mat`` and applies it to ``w02``. ``gamma2`` is reset to
    ``1/sqrt(n2)`` and ``gamma1`` to a random unit vector. Then ``epochs`` L-BFGS
    steps are run first on ``gamma1`` alone and then on both scales.

    Args:
        reg: weight of the permutation penalty.
        epochsPermute: number of L-BFGS steps in phase 1.
        wt: target weights, ``(n2, n1)``.
        w01: first-layer weights, ``(m, n1)``.
        w02: second-layer weights, ``(n2, m)``.
        gamma1init, gamma2init: initial scales; copied, never modified.
        epochs: number of L-BFGS steps per refit stage in phase 2.
        device: device on which every tensor is created.

    Returns:
        ``(gamma2, gamma1, perm_t, loss)``. ``loss`` is the RMS error of the final
        proxy including ``gamma2``, i.e. of the model actually returned.
    """
    with torch.no_grad():
        w01 = torch.transpose(w01, 0, 1).to(device)  # (n1, m)
        w02 = torch.transpose(w02, 0, 1).to(device)  # (m, n2)
        wt = torch.transpose(wt, 0, 1).to(device)    # (n1, n2)
    n2 = wt.size(1)
    m = w01.size(1)

    def make_lbfgs(params):
        return torch.optim.LBFGS(params, lr=1, max_iter=40, max_eval=None, tolerance_grad=1e-07,
                                 tolerance_change=1e-09, history_size=100, line_search_fn='strong_wolfe')

    def raw_penalty(M):
        # >= 0; zero iff every row and column of M has at most one nonzero entry.
        return 2 * torch.sum(torch.abs(M)) - (torch.sum(torch.sqrt(torch.sum(M ** 2, dim=0)))
                                              + torch.sum(torch.sqrt(torch.sum(M ** 2, dim=1))))

    def renormalize_perm():
        # Clip to nonnegative, then scale rows and afterwards columns to unit L2 norm.
        # An all-zero row/column is replaced by the uniform unit vector 1/sqrt(n2).
        perm.data = torch.clamp(perm.data, min=0)
        norms = torch.sqrt(torch.sum(perm.data ** 2, dim=1))
        empty = norms == 0
        if empty.any():
            norms[empty] = np.sqrt(n2)
            perm.data[empty, :] = 1.0
        perm.data = torch.einsum('ij,i->ij', perm.data, 1 / norms)
        norms = torch.sqrt(torch.sum(perm.data ** 2, dim=0))
        empty = norms == 0
        if empty.any():
            norms[empty] = np.sqrt(n2)
            perm.data[:, empty] = 1.0
        perm.data = torch.einsum('ij,j->ij', perm.data, 1 / norms)

    def uniform_gamma2():
        return torch.ones(n2, device=device) / np.sqrt(n2)

    def proxy():
        # (gamma1 * w01) @ w02 with column i of w02 scaled by gamma2[i].
        return torch.matmul(gamma1 * w01, torch.einsum('ji,i->ji', w02, gamma2))

    gamma1 = torch.ones(m, requires_grad=True, device=device)
    gamma1.data = gamma1init.detach().to(device).clone()
    gamma2 = torch.ones(n2, requires_grad=True, device=device)
    gamma2.data = gamma2init.detach().to(device).clone()
    perm = torch.ones((n2, n2), requires_grad=True, device=device)
    # Random tensors are drawn on `device`: assigning a CPU tensor to `.data` would
    # move perm to the CPU and break the matmuls with the other tensors.
    torch.randn((n2, n2), device=device)  # discarded draw, kept so seeded runs reproduce
    perm.data = torch.randn((n2, n2), device=device)
    perm.data = torch.einsum('ij,i->ij', perm.data, 1 / torch.sqrt(torch.sum(perm.data ** 2, dim=1)))
    perm.data = torch.einsum('ij,j->ij', perm.data, 1 / torch.sqrt(torch.sum(perm.data ** 2, dim=0)))

    # Scale the penalty by its value at the start and by the initial fit error.
    pen_norm = raw_penalty(perm.data)
    w_norm = torch.mean((torch.matmul(proxy(), perm / pen_norm) - wt) ** 2).detach().item()
    pen_norm = pen_norm / max(w_norm, 1.0)

    def penalty(M):
        return max(raw_penalty(M) * 0.1 / pen_norm, 0.0)

    def perm_closure():
        opt_perm.zero_grad()
        objective = torch.mean((torch.matmul(proxy(), perm) - wt) ** 2) + reg * penalty(perm)
        objective.backward()
        return objective

    opt_perm = make_lbfgs([gamma1, gamma2, perm])
    permOld = perm.detach().clone()
    for _ in range(epochsPermute):
        opt_perm.step(perm_closure)
        # A diverged step leaves NaNs: reset the affected parameters.
        if torch.isnan(gamma1).any():
            gamma1.data = gamma1init.detach().to(device).clone()
        if torch.isnan(gamma2).any():
            gamma2.data = uniform_gamma2()
        if torch.isnan(perm).any():
            perm.data = permOld.data
        renormalize_perm()
        permOld = perm.detach().clone()

    # Round to a 0/1 permutation matrix (in place) and fold it into w02.
    perm = project_perm_mat(perm)
    with torch.no_grad():
        w02.data = torch.matmul(w02.data, perm.data)

    # Restart the scales: uniform gamma2, random unit-norm gamma1.
    gamma2.data = uniform_gamma2()
    gamma1.data = torch.randn(m, device=device)
    gamma1.data = gamma1.data / torch.sqrt(torch.sum(gamma1.data ** 2))

    def refit(params):
        # `epochs` L-BFGS steps on `params` for the unpenalized permuted fit.
        opt = make_lbfgs(params)

        def closure():
            opt.zero_grad()
            objective = torch.mean((proxy() - wt) ** 2)
            objective.backward()
            return objective

        for _ in range(epochs):
            opt.step(closure)

    refit([gamma1])
    refit([gamma1, gamma2])
    with torch.no_grad():
        # Includes gamma2, so callers ranking runs by this loss rank the returned model.
        loss = torch.sqrt(torch.mean((proxy() - wt) ** 2))
    print("Res: " + str(reg) + " " + str(epochsPermute) + ": " + str(loss.item()))
    return gamma2, gamma1, torch.transpose(perm, 0, 1), loss

def fc_1to2layer_learn_GD_LBFGS_exact_permute_tune(reg, epochsPermute, wt, w01, w02, gamma1init, epochs=50, device="cpu"):
    """Work-in-progress permutation search in which gamma1/gamma2 are not optimized with perm.

    Intended flow, as far as the code shows:
    - L-BFGS learns a relaxed permutation matrix ``perm``.
    - After each permutation step, ``gamma1``/``gamma2`` are re-derived by helper calls.
    - ``perm`` is rounded with ``project_perm_mat`` and applied to ``w02``.
    - ``gamma1`` and ``gamma2`` are fine-tuned with L-BFGS.

    NOTE(review): as written this function cannot run. Several lines below use
    undefined names or contain typos; see the NOTE(review) comments.

    Returns:
        ``(gamma1, gamma2, perm_t, loss)`` where ``perm_t`` is the transposed
        permutation matrix.
    """
    #Inputs:
    #wt: target weights
    #w01 and w02: (random) weight tensors in network with adjusted BN parametes
    #NOTE(review): bt, act_scale and act_shift are not parameters of this function,
    #but they are used further below.
    with torch.no_grad():
        w01 = torch.transpose(w01,0,1).to(device)
        w02 = torch.transpose(w02,0,1).to(device)
        wt = torch.transpose(wt,0,1).to(device)
    n2, n1 = wt.size(1), wt.size(0)
    m = w01.size(1)
    # NOTE(review): gamma1/gamma2 do not require grad, yet they are handed to LBFGS for the
    # final fine-tuning; backward() there fails unless the loop rebinds them to tensors that do.
    gamma1 = torch.ones(m, requires_grad=False, device=device)
    gamma1.data = gamma1init.detach().to(device).clone()
    # NOTE(review): torch.sqrt needs a Tensor; with the int n2 this raises TypeError
    # (np.sqrt(n2) was probably intended).
    gamma2 = torch.ones(n2, requires_grad=False, device=device)/torch.sqrt(n2)
    perm = torch.ones((n2,n2), requires_grad=True, device=device)
    # NOTE(review): this first draw is overwritten below; both draws are created on the CPU
    # regardless of `device`.
    perm.data = torch.randn((n2,n2))
    #torch.abs(torch.randn((n2,n2), requires_grad=True, device=device))
    #reg = 0.000000001 #0.0000001
    #normalize rows and columns
    perm.data = torch.randn((n2,n2))
    sumVec = torch.sqrt(torch.sum(perm.data**2,dim=1))
    perm.data = torch.einsum('ij,i->ij', perm.data, 1/sumVec)
    sumVec = torch.sqrt(torch.sum(perm.data**2,dim=0))
    perm.data = torch.einsum('ij,j->ij', perm.data, 1/sumVec)
    # Unscaled permutation penalty: 2*L1 norm minus the sums of column and row L2 norms.
    def penalty(M):
        pen = 2*torch.sum(torch.abs(M)) - (torch.sum(torch.sqrt(torch.sum(M**2,dim=0)))+ torch.sum(torch.sqrt(torch.sum(M**2,dim=1))))
        return pen#pen*(10*(-6)/(n2*n2)) #pen*(10*(-5)/(n2*n2)) #pen*(10*(-6)/(n2*n2))
    pen_norm = penalty(perm.data)
    # NOTE(review): `einsum` is not imported at the top of the file (torch.einsum was
    # probably intended), so this line raises NameError unless it is defined elsewhere.
    w_norm = torch.mean((torch.matmul(torch.matmul(gamma1*w01,einsum('ji,i->ji', w02, gamma2)),perm/pen_norm)-wt)**2).detach().item()
    #print("weight norm: " + str(w_norm))
    pen_norm = pen_norm/max(w_norm, 0.1)
    # Penalty rescaled by its initial value and the initial fit error; clipped at 0.
    def penalty(M):
        pen = 2*torch.sum(torch.abs(M)) - (torch.sum(torch.sqrt(torch.sum(M**2,dim=0)))+ torch.sum(torch.sqrt(torch.sum(M**2,dim=1))))
        return max(pen*0.1/pen_norm,0.0) #pen*(10*(-6)/(n2*n2)) #pen*(10*(-5)/(n2*n2)) #pen*(10*(-6)/(n2*n2))
    # def penalty(M):
    #     pen = 2*torch.sum(torch.abs(M)) - (torch.sum(torch.sqrt(torch.sum(M**2,dim=0)))+ torch.sum(torch.sqrt(torch.sum(M**2,dim=1))))
    #     return pen/(n2*n2)

    def closure():
        # NOTE(review): the bare names in this no_grad block are no-op expressions.
        with torch.no_grad():
            gamma1
            w01
            w02
            gamma2
        opt.zero_grad()
        # NOTE(review): `torch,einsum` (comma, not dot) passes `torch` as an extra argument
        # to torch.matmul; torch.einsum was probably intended.
        objective = torch.mean((torch.matmul(torch.matmul(gamma1*w01,torch,einsum('ji,i->ji', w02, gamma2)),perm)-wt)**2) + reg*penalty(perm)
        objective.backward()
        return objective

    opt = torch.optim.LBFGS([perm], lr=1, max_iter=40, max_eval=None, tolerance_grad=1e-07, tolerance_change=1e-09, history_size=100, line_search_fn='strong_wolfe')
    #history_lbfgs = []
    permOld = perm.clone()
    for i in range(epochsPermute): #range(15): #range(epochs):
        #ll = torch.sqrt(torch.mean((torch.matmul(gamma1*w01,w02)-wt)**2))
        #history_lbfgs.append(ll.item())
        opt.step(closure)
        # Restore the last good perm if the step produced NaNs.
        if torch.sum(torch.isnan(perm))>0:
            #print("nan produced")
            perm.data = permOld.data
            #print(perm.data)
        #permBefore = perm.clone()
        # Clip to nonnegative, then L2-normalize rows and then columns; an all-zero
        # row/column is replaced by ones before normalization.
        perm.data = torch.clamp(perm.data, min=0)
        #perm.data = torch.clamp(perm.data, min=0, max=10)
        sumVec = torch.sqrt(torch.sum(perm.data**2,dim=1))
        ind = torch.arange(n2)
        ind = ind[sumVec==0]
        #sumVec[sumVec==0] = 1
        if len(ind) > 0:
            sumVec[ind] = np.sqrt(n2)
            for j in ind:
                perm.data[j,:] = 1.0
        perm.data = torch.einsum('ij,i->ij', perm.data, 1/sumVec)
        sumVec = torch.sqrt(torch.sum(perm.data**2,dim=0))
        ind = torch.arange(n2)
        ind = ind[sumVec==0]
        #sumVec[sumVec==0] = 1
        if len(ind) > 0:
            sumVec[ind] = np.sqrt(n2)
            for j in ind:
                perm.data[:,j] = 1.0
        #sumVec[sumVec==0] = np.sqrt(n2)
        #perm.data[:,sumVec==0] = 1.0
        perm.data = torch.einsum('ij,j->ij', perm.data, 1/sumVec)
        #print(perm)
        #if torch.sum(torch.isnan(perm))>0:
        #    print("nan after")
        #    print(torch.clamp(permBefore.data, min=0))
        #    perm.data = permOld.data
        permOld = perm.clone()
        #update gamma1 and gamma2
        # NOTE(review): bare `w02` is a no-op expression.
        w02
        # NOTE(review): bt, act_scale and act_shift are not defined in this function (NameError).
        gamma1, _, gamma2, _ = fc_1to2layer_exact_remove_conditions(wt, bt, w01, w02, act_scale, act_shift, device)
        # NOTE(review): fc_1to2layer_learn_GD_LBFGS is not defined in the visible part of the
        # file; confirm its signature matches this call.
        gamma1, _, gamma2, _ = fc_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, gamma1, gamma2, epochs, device)

        # NOTE(review): proxy_target_init_flexible is not defined in the visible code and is
        # called without arguments; its result overwrites the gamma1/gamma2 computed above.
        _, _, gamma1, gamma2 = proxy_target_init_flexible()

        #print(perm)
    #fix permutation matrix and retrain
    #project to closest permutation matrix
    perm = project_perm_mat(perm)
    #retrain
    with torch.no_grad():
        w02.data = torch.matmul(w02.data,perm.data)
    #gamma = gamma1.clone()
    #gamma1 = torch.ones(m, requires_grad=True, device=device)
    #gamma1.data = gamma.data #.clone()
    #gamma1.data = gamma1init.detach().to(device).clone()
    opt = torch.optim.LBFGS([gamma1, gamma2], lr=1, max_iter=40, max_eval=None, tolerance_grad=1e-07, tolerance_change=1e-09, history_size=100, line_search_fn='strong_wolfe')
    def closure():
        opt.zero_grad()
        # NOTE(review): same `torch,einsum` typo as in the first closure.
        objective = torch.mean((torch.matmul(gamma1*w01,torch,einsum('ji,i->ji', w02, gamma2))-wt)**2)
        objective.backward()
        return objective

    for i in range(epochs):
        opt.step(closure)
        #print(torch.sqrt(torch.mean((torch.matmul(gamma1*w01,w02)-wt)**2)))
    # NOTE(review): this RMS error omits gamma2, so it does not measure the model fitted above.
    loss = torch.sqrt(torch.mean((torch.matmul(gamma1*w01,w02)-wt)**2))
    print("Res: " + str(reg) + " " + str(epochsPermute) + ": " + str(loss.item()))
    return gamma1, gamma2, torch.transpose(perm,0,1), loss



def fc_1to2layer_learn_GD_LBFGS_without_gamma2_permute_learn_tune(epochsPermute, wt, w01, w02, gamma1init, epochs=50, device="cpu"):
    """Like the fixed-``reg`` permutation search, but the penalty weight ``reg`` is learned too.

    Phase 1 runs ``epochsPermute`` L-BFGS steps on ``gamma1``, an ``n2 x n2``
    matrix ``perm`` and a scalar ``reg`` (initialized to 1), minimizing

        mean((w01^T diag(gamma1) w02^T perm - wt^T)**2) + reg * penalty(perm).

    ``penalty`` is ``2 * sum|perm|`` minus the sums of the column and row L2
    norms, rescaled by its value at the random start. After each step:
    - NaNs are reset to safe values.
    - ``reg`` is floored at 0.01.
    - ``perm`` is clipped to be nonnegative and L2-normalized by rows, then by columns.

    Phase 2 rounds ``perm`` in place to a 0/1 permutation matrix with
    ``project_perm_mat`` and applies it to ``w02``. ``gamma1`` is then refitted
    from ``gamma1init`` with ``epochs`` L-BFGS steps.

    Args:
        epochsPermute: number of L-BFGS steps in phase 1.
        wt: target weights, ``(n2, n1)``.
        w01: first-layer weights, ``(m, n1)``.
        w02: second-layer weights, ``(n2, m)``.
        gamma1init: initial ``gamma1`` (length ``m``); copied, never modified.
        epochs: number of L-BFGS steps in phase 2.
        device: device on which every tensor is created.

    Returns:
        ``(gamma1, perm_t, loss)``: the refitted ``gamma1``, the transposed
        permutation matrix and the RMS error of the permuted proxy.
    """
    with torch.no_grad():
        w01 = torch.transpose(w01, 0, 1).to(device)  # (n1, m)
        w02 = torch.transpose(w02, 0, 1).to(device)  # (m, n2)
        wt = torch.transpose(wt, 0, 1).to(device)    # (n1, n2)
    n2 = wt.size(1)
    m = w01.size(1)

    def make_lbfgs(params):
        return torch.optim.LBFGS(params, lr=1, max_iter=40, max_eval=None, tolerance_grad=1e-07,
                                 tolerance_change=1e-09, history_size=100, line_search_fn='strong_wolfe')

    def raw_penalty(M):
        # >= 0; zero iff every row and column of M has at most one nonzero entry.
        return 2 * torch.sum(torch.abs(M)) - (torch.sum(torch.sqrt(torch.sum(M ** 2, dim=0)))
                                              + torch.sum(torch.sqrt(torch.sum(M ** 2, dim=1))))

    def renormalize_perm():
        # Clip to nonnegative, then scale rows and afterwards columns to unit L2 norm.
        # An all-zero row/column is replaced by the uniform unit vector 1/sqrt(n2).
        perm.data = torch.clamp(perm.data, min=0)
        norms = torch.sqrt(torch.sum(perm.data ** 2, dim=1))
        empty = norms == 0
        if empty.any():
            norms[empty] = np.sqrt(n2)
            perm.data[empty, :] = 1.0
        perm.data = torch.einsum('ij,i->ij', perm.data, 1 / norms)
        norms = torch.sqrt(torch.sum(perm.data ** 2, dim=0))
        empty = norms == 0
        if empty.any():
            norms[empty] = np.sqrt(n2)
            perm.data[:, empty] = 1.0
        perm.data = torch.einsum('ij,j->ij', perm.data, 1 / norms)

    gamma1 = torch.ones(m, requires_grad=True, device=device)
    gamma1.data = gamma1init.detach().to(device).clone()
    perm = torch.ones((n2, n2), requires_grad=True, device=device)
    # Random start on `device` (a CPU tensor assigned to `.data` would move perm off
    # `device`), with rows and then columns scaled to unit L2 norm.
    perm.data = torch.randn((n2, n2), device=device)
    perm.data = torch.einsum('ij,i->ij', perm.data, 1 / torch.sqrt(torch.sum(perm.data ** 2, dim=1)))
    perm.data = torch.einsum('ij,j->ij', perm.data, 1 / torch.sqrt(torch.sum(perm.data ** 2, dim=0)))
    reg = torch.ones(1, requires_grad=True, device=device)

    # Scale the penalty by its value at the start and by the initial fit error.
    pen_norm = raw_penalty(perm.data)
    w_norm = torch.mean((torch.matmul(torch.matmul(gamma1 * w01, w02), perm / pen_norm) - wt) ** 2).detach().item()
    pen_norm = pen_norm / max(w_norm, 0.1)

    def penalty(M):
        return max(raw_penalty(M) * 0.1 / pen_norm, 0.0)

    def perm_closure():
        opt_perm.zero_grad()
        objective = torch.mean((torch.matmul(torch.matmul(gamma1 * w01, w02), perm) - wt) ** 2) + reg * penalty(perm)
        objective.backward()
        return objective

    opt_perm = make_lbfgs([gamma1, perm, reg])
    permOld = perm.detach().clone()
    for _ in range(epochsPermute):
        opt_perm.step(perm_closure)
        # A diverged step leaves NaNs: reset the affected parameters.
        if torch.isnan(gamma1).any():
            gamma1.data = gamma1init.detach().to(device).clone()
        if torch.isnan(perm).any():
            perm.data = permOld.data
        if torch.isnan(reg).any():
            reg.data = torch.tensor([1.0], device=device)
        # penalty >= 0, so the objective always favours a smaller reg; keep a floor.
        reg.data = torch.clamp(reg.data, min=0.01)
        renormalize_perm()
        permOld = perm.detach().clone()

    # Round to a 0/1 permutation matrix (in place) and fold it into w02.
    perm = project_perm_mat(perm)
    with torch.no_grad():
        w02.data = torch.matmul(w02.data, perm.data)

    # Refit gamma1 from its initial value with the permutation fixed.
    gamma1.data = gamma1init.detach().to(device).clone()
    opt_refit = make_lbfgs([gamma1])

    def refit_closure():
        opt_refit.zero_grad()
        objective = torch.mean((torch.matmul(gamma1 * w01, w02) - wt) ** 2)
        objective.backward()
        return objective

    for _ in range(epochs):
        opt_refit.step(refit_closure)
    with torch.no_grad():
        loss = torch.sqrt(torch.mean((torch.matmul(gamma1 * w01, w02) - wt) ** 2))
    print("Res: " + str(reg.detach().item()) + " " + str(epochsPermute) + ": " + str(loss.item()))
    return gamma1, torch.transpose(perm, 0, 1), loss


def fc_1to2layer_learn_GD_LBFGS_without_gamma2_permute(wt, w01, w02, gamma1init, epochs=50, device="cpu", njobs=4):
    """Search over the number of permutation epochs and keep the best fit.

    For each value in ``epochs_exp``, one run of
    ``fc_1to2layer_learn_GD_LBFGS_without_gamma2_permute_learn_tune`` is made;
    that function learns its penalty weight itself. Runs are executed in
    parallel with joblib (``njobs`` workers) when joblib is installed, and
    sequentially otherwise. The run with the smallest returned loss wins (the
    first one on ties); runs with a NaN loss never win.

    Returns:
        ``(gamma1, perm)`` of the best run.

    Raises:
        RuntimeError: if no run produced a finite loss.
    """
    epochs_exp = [20, 21, 22, 23, 24, 25, 26, 27, 30, 35, 50, 100, 150, 200]
    # A fixed-penalty search (reg in e.g. [50, 15, 10, 5, 1, 0.1, 0.05, 0.01, 0.001, 0.0001]
    # times epochs_exp) can be run the same way with
    # fc_1to2layer_learn_GD_LBFGS_without_gamma2_permute_tune.
    tune = fc_1to2layer_learn_GD_LBFGS_without_gamma2_permute_learn_tune
    try:
        # Optional: only used to spread the independent runs over worker processes.
        from joblib import Parallel, delayed
    except ImportError:
        results = [tune(ep, wt, w01, w02, gamma1init, epochs, device) for ep in epochs_exp]
    else:
        results = Parallel(n_jobs=njobs)(
            delayed(tune)(ep, wt, w01, w02, gamma1init, epochs, device) for ep in epochs_exp)

    gamma1, perm, lossMin = None, None, float("inf")
    for gamma1_run, perm_run, loss_run in results:
        loss_value = loss_run.item()
        if loss_value < lossMin:  # False for NaN, so diverged runs are skipped
            gamma1, perm, lossMin = gamma1_run.clone(), perm_run.clone(), loss_value
    if gamma1 is None:
        raise RuntimeError("no permutation run produced a finite loss")
    return gamma1, perm

def fc_1to2layer_learn_GD_LBFGS_permute(wt, w01, w02, gamma1init, gamma2init, epochs=50, device="cpu", njobs=4):
    """Search over (regularisation, epochs) pairs for the permutation learner in parallel.

    Every pair from ``reg_exp`` x ``epochs_exp`` is passed as the first two positional
    arguments to ``fc_1to2layer_learn_GD_LBFGS_permute_tune`` (defined elsewhere in this
    module). Each call is expected to return a sequence whose first four entries are
    ``gamma2``, ``gamma1``, ``perm`` and a one-element loss tensor.

    Args:
        wt, w01, w02, gamma1init, gamma2init, epochs, device: forwarded unchanged.
        njobs: number of parallel workers passed to ``Parallel``.

    Returns:
        ``(gamma2, gamma1, perm)`` (cloned) from the run with the smallest loss.

    Raises:
        RuntimeError: if no run produced a loss below infinity (e.g. every loss is NaN).
    """
    epochs_exp = [20, 21, 22, 23, 24, 25, 26, 27, 30, 35, 50, 100, 150, 200]
    reg_exp = [50, 15, 10, 5, 1, 0.1, 0.05, 0.01, 0.001, 0.0001]
    # Same ordering as itertools.product(reg_exp, epochs_exp), without needing itertools.
    param = [(reg, ep) for reg in reg_exp for ep in epochs_exp]
    results = Parallel(n_jobs=njobs)(
        delayed(fc_1to2layer_learn_GD_LBFGS_permute_tune)(
            reg, ep, wt, w01, w02, gamma1init, gamma2init, epochs, device
        )
        for reg, ep in param
    )
    # +inf sentinel: the previous finite sentinel left the outputs unbound when every
    # loss exceeded it.
    best = None
    loss_min = math.inf
    for res in results:
        loss = res[3].item()
        if loss < loss_min:
            best = res
            loss_min = loss
    if best is None:
        raise RuntimeError("fc_1to2layer_learn_GD_LBFGS_permute: no run produced a finite loss")
    return best[0].clone(), best[1].clone(), best[2].clone()


def rand_permute_parallel(w02, func, nperm=12, njobs=4):
    """Evaluate ``func`` on ``nperm`` permutations of ``w02`` along dim 0 and keep the best.

    Seeds are ``offset + 0, ..., offset + nperm - 1`` with a random ``offset`` in [0, 500].
    Seed 0 (reachable only when the offset is 0) evaluates the unpermuted tensor. ``func``
    is expected to return an indexable object whose first entry is a loss comparable with
    ``<``. Calls to ``func`` that raise are treated as failed and skipped.

    Args:
        w02: tensor whose first dimension is permuted (any rank >= 1).
        func: callable applied to the permuted tensor.
        nperm: number of permutations to evaluate.
        njobs: number of parallel workers passed to ``Parallel``.

    Returns:
        ``(func(w02[ind]), ind)`` for the permutation with the smallest loss.

    Raises:
        RuntimeError: if every evaluation failed or none produced a loss below infinity.
    """
    offset = random.randint(0, 500)
    n2 = w02.size(0)

    def ff(seed):
        random.seed(seed)
        if seed < 0.5:
            ind = torch.arange(n2)
        else:
            # torch.randperm does not use Python's `random` state; seed a dedicated
            # generator so each task gets its own reproducible permutation (worker
            # processes may otherwise start from the same default torch seed).
            gen = torch.Generator()
            gen.manual_seed(seed)
            ind = torch.randperm(n2, generator=gen)
        try:
            out = func(w02[ind])
        except Exception:
            # A failing permutation (e.g. a singular solve inside func) is skipped.
            out = None
        return out, ind

    results = Parallel(n_jobs=njobs)(delayed(ff)(ii + offset) for ii in range(nperm))
    best = None
    loss_min = math.inf
    for out, ind in results:
        if out is None:
            continue
        if out[0] < loss_min:
            best = (out, ind)
            loss_min = out[0]
    if best is None:
        raise RuntimeError("rand_permute_parallel: no permutation produced a usable loss")
    return best


# def fc_1to2layer_learn_GD_LBFGS_without_gamma2_permute(wt, w01, w02, gamma1init, epochs=50, device="cpu"):
#     #Inputs:
#     #wt: target weights
#     #bt: target biases
#     #w01 and w02: (random) weight tensors in network with adjusted BN parametes
#     #act_scale, act_shift: parameters to shift the pre-activations of the first layer to a linear region
#     #Outputs:
#     #effective batch norm parameters for the neurons of the first and second layer: gamma1, beta1, gamma2, beta2
#     #later problem: take standard deviation and mean of BN statistics based on data batches into account
#     #probably always work with effective parameters
#     with torch.no_grad():
#         w01 = torch.transpose(w01,0,1).to(device)
#         w02 = torch.transpose(w02,0,1).to(device)
#         wt = torch.transpose(wt,0,1).to(device)
#     n2, n1 = wt.size(1), wt.size(0)
#     m = w01.size(1)
#     gamma1 = torch.ones(m, requires_grad=True, device=device)
#     gamma1.data = gamma1init.detach().to(device)
#     perm = torch.ones((n2,n2), requires_grad=True, device=device)
#     perm.data = torch.randn((n2,n2))
#     #torch.abs(torch.randn((n2,n2), requires_grad=True, device=device))
#     reg = 0.000000001 #0.0000001
#     def penalty(M):
#         pen = 2*torch.sum(torch.abs(M)) - (torch.sum(torch.sqrt(torch.sum(M**2,dim=0)))+ torch.sum(torch.sqrt(torch.sum(M**2,dim=1))))
#         return pen/(n2*n2)
#
#     def closure():
#         opt.zero_grad()
#         objective = torch.mean((torch.matmul(torch.matmul(gamma1*w01,w02),perm)-wt)**2) + reg*penalty(perm)
#         objective.backward()
#         return objective
#
#     opt = torch.optim.LBFGS([gamma1, perm], lr=1, max_iter=40, max_eval=None, tolerance_grad=1e-07, tolerance_change=1e-09, history_size=100, line_search_fn='strong_wolfe')
#     #history_lbfgs = []
#     permOld = perm.clone()
#     for i in range(40): #range(15): #range(epochs):
#         #ll = torch.sqrt(torch.mean((torch.matmul(gamma1*w01,w02)-wt)**2))
#         #history_lbfgs.append(ll.item())
#         opt.step(closure)
#         if torch.sum(torch.isnan(perm))>0:
#             print("nan produced")
#             perm.data = permOld.data
#             #print(perm.data)
#         #permBefore = perm.clone()
#         perm.data = torch.clamp(perm.data, min=0)
#         #perm.data = torch.clamp(perm.data, min=0, max=10)
#         sumVec = torch.sqrt(torch.sum(perm.data**2,dim=1))
#         ind = torch.arange(n2)
#         ind = ind[sumVec==0]
#         #sumVec[sumVec==0] = 1
#         if len(ind) > 0:
#             sumVec[ind] = np.sqrt(n2)
#             for j in ind:
#                 perm.data[j,:] = 1.0
#         perm.data = torch.einsum('ij,i->ij', perm.data, 1/sumVec)
#         sumVec = torch.sqrt(torch.sum(perm.data**2,dim=0))
#         ind = torch.arange(n2)
#         ind = ind[sumVec==0]
#         #sumVec[sumVec==0] = 1
#         if len(ind) > 0:
#             sumVec[ind] = np.sqrt(n2)
#             for j in ind:
#                 perm.data[:,j] = 1.0
#         #sumVec[sumVec==0] = np.sqrt(n2)
#         #perm.data[:,sumVec==0] = 1.0
#         perm.data = torch.einsum('ij,j->ij', perm.data, 1/sumVec)
#         print(perm)
#         #if torch.sum(torch.isnan(perm))>0:
#         #    print("nan after")
#         #    print(torch.clamp(permBefore.data, min=0))
#         #    perm.data = permOld.data
#         permOld = perm.clone()
#         #print(perm)
#     #fix permutation matrix and retrain
#     #project to closest permutation matrix
#     perm = project_perm_mat(perm)
#     #retrain
#     with torch.no_grad():
#         w02 = torch.matmul(w02,perm)
#     #gamma = gamma1.clone()
#     #gamma1 = torch.ones(m, requires_grad=True, device=device)
#     #gamma1.data = gamma.data #.clone()
#     opt = torch.optim.LBFGS([gamma1], lr=1, max_iter=40, max_eval=None, tolerance_grad=1e-07, tolerance_change=1e-09, history_size=100, line_search_fn='strong_wolfe')
#     #gamma1.data = gamma1init.detach().to(device)
#     def closure():
#         opt.zero_grad()
#         objective = torch.mean((torch.matmul(gamma1*w01,w02)-wt)**2)
#         objective.backward()
#         return objective
#
#     for i in range(epochs):
#         opt.step(closure)
#     #print("loss: " + str(torch.sqrt(torch.mean((torch.matmul(gamma1*w01,w02)-wt)**2))))
#     return gamma1, torch.transpose(perm,0,1)

def fc_1to2layer_learn_GD_LBFGS_without_gamma2_prune_narrow(mt, wt, w01, w02, gamma1init, epochs=50, epochsFinetune=10, device="cpu"):
    """Fit hidden-unit scales gamma1 with L-BFGS, then shrink the hidden layer from the end.

    The objective is ``mean((transpose(w01) * gamma1) @ transpose(w02) - transpose(wt))**2``.
    After the initial fit (``epochs`` L-BFGS steps), the last hidden unit is dropped
    repeatedly until ``mt`` units remain. Each smaller problem is re-fitted for
    ``epochsFinetune`` steps, warm-started from the previous scales, and its RMS error
    is printed.

    Args:
        mt: target number of hidden units.
        wt: target weights (outputs x inputs).
        w01: first-layer weights (hidden x inputs).
        w02: second-layer weights (outputs x hidden).
        gamma1init: initial scales, one per hidden unit.
        epochs: L-BFGS steps for the initial fit.
        epochsFinetune: L-BFGS steps after each removal.
        device: device on which the optimisation runs.

    Returns:
        The final gamma1 leaf tensor (length ``min(mt, hidden)`` when ``mt >= 1``).
    """
    with torch.no_grad():
        left = torch.transpose(w01, 0, 1).to(device)
        right = torch.transpose(w02, 0, 1).to(device)
        target = torch.transpose(wt, 0, 1).to(device)

    def rms(scales, a, b):
        return torch.sqrt(torch.mean((torch.matmul(scales * a, b) - target) ** 2))

    def run_lbfgs(scales, a, b, n_steps):
        # A fresh optimiser per problem size; its history persists across its own steps.
        optimizer = torch.optim.LBFGS([scales], lr=1, max_iter=40, max_eval=None, tolerance_grad=1e-07, tolerance_change=1e-09, history_size=100, line_search_fn='strong_wolfe')

        def closure():
            optimizer.zero_grad()
            loss = torch.mean((torch.matmul(scales * a, b) - target) ** 2)
            loss.backward()
            return loss

        for _ in range(n_steps):
            optimizer.step(closure)

    width = left.size(1)
    gamma1 = torch.ones(width, requires_grad=True, device=device)
    gamma1.data = gamma1init.detach().to(device)
    run_lbfgs(gamma1, left, right, epochs)

    for k in range(width - 1, mt - 1, -1):
        previous = gamma1.clone()
        gamma1 = torch.ones(k, requires_grad=True, device=device)
        gamma1.data = previous[:k].detach().to(device)
        # Drop the trailing hidden unit from both layers.
        left = left[:, :k]
        right = right[:k, :]
        run_lbfgs(gamma1, left, right, epochsFinetune)
        print("k ", k, ": ", rms(gamma1, left, right))
    return gamma1

def delete(arr: torch.Tensor, ind: int, dim: int) -> torch.Tensor:
    """Return a copy of ``arr`` with the ``ind``-th slice along ``dim`` removed.

    Negative ``ind`` and ``dim`` count from the end, as in ordinary indexing.
    Previously a negative ``dim`` silently removed nothing.

    Args:
        arr: input tensor.
        ind: index of the slice to remove along ``dim``.
        dim: dimension along which to remove.

    Returns:
        A new tensor whose size along ``dim`` is one smaller.

    Raises:
        IndexError: if ``dim`` or ``ind`` is out of range.
    """
    ndim = arr.ndim
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} out of range for tensor with {ndim} dimensions")
    dim %= ndim
    size = arr.size(dim)
    if not -size <= ind < size:
        raise IndexError(f"index {ind} out of range for dimension {dim} of size {size}")
    ind %= size
    keep = torch.tensor([i for i in range(size) if i != ind], dtype=torch.long, device=arr.device)
    return arr.index_select(dim, keep)

def fc_1to2layer_learn_GD_LBFGS_without_gamma2_prune_mag(mt, wt, w01, w02, gamma1init, epochs=50, epochsFinetune=10, device="cpu"):
    """Fit hidden-unit scales gamma1 with L-BFGS, then prune by smallest |gamma1|.

    The objective is ``mean((transpose(w01) * gamma1) @ transpose(w02) - transpose(wt))**2``.
    After the initial fit, the hidden unit with the smallest absolute scale is removed
    repeatedly until ``mt`` units remain. After each removal the remaining scales are
    re-fitted for ``epochsFinetune`` L-BFGS steps and the RMS error is printed.

    Args:
        mt: target number of hidden units.
        wt: target weights (outputs x inputs).
        w01: first-layer weights (hidden x inputs).
        w02: second-layer weights (outputs x hidden).
        gamma1init: initial scales, one per hidden unit.
        epochs: L-BFGS steps for the initial fit.
        epochsFinetune: L-BFGS steps after each removal.
        device: device on which the optimisation runs.

    Returns:
        ``(gamma1, w01_kept, w02_kept)``, with the kept weights in the original
        (hidden x inputs) and (outputs x hidden) orientation.
    """
    with torch.no_grad():
        left = torch.transpose(w01, 0, 1).to(device)
        right = torch.transpose(w02, 0, 1).to(device)
        target = torch.transpose(wt, 0, 1).to(device)

    def rms(scales, a, b):
        return torch.sqrt(torch.mean((torch.matmul(scales * a, b) - target) ** 2))

    def run_lbfgs(scales, a, b, n_steps):
        optimizer = torch.optim.LBFGS([scales], lr=1, max_iter=40, max_eval=None, tolerance_grad=1e-07, tolerance_change=1e-09, history_size=100, line_search_fn='strong_wolfe')

        def closure():
            optimizer.zero_grad()
            loss = torch.mean((torch.matmul(scales * a, b) - target) ** 2)
            loss.backward()
            return loss

        for _ in range(n_steps):
            optimizer.step(closure)

    width = left.size(1)
    gamma1 = torch.ones(width, requires_grad=True, device=device)
    gamma1.data = gamma1init.detach().to(device)
    run_lbfgs(gamma1, left, right, epochs)

    for k in range(width - 1, mt - 1, -1):
        previous = gamma1.clone()
        drop = torch.argmin(torch.abs(previous)).item()
        # Boolean mask selecting every current unit except the weakest one.
        keep = torch.ones(k + 1, dtype=torch.bool, device=previous.device)
        keep[drop] = False
        gamma1 = torch.ones(k, requires_grad=True, device=device)
        gamma1.data = previous[keep].detach().to(device)
        left = left[:, keep]
        right = right[keep, :]
        run_lbfgs(gamma1, left, right, epochsFinetune)
        print("k ", k, ": ", rms(gamma1, left, right))
    return gamma1, torch.transpose(left, 0, 1).to(device), torch.transpose(right, 0, 1).to(device)

def fc_1to2layer_learn_GD_LBFGS_without_gamma2_wider(wt, w01, w02, gamma1init, feat_add, epochs=50, device="cpu"):
    """Fit hidden scales plus coefficients for extra feature tensors with L-BFGS.

    Minimises ``mean((transpose(w01) * g[:m]) @ transpose(w02)
    + einsum('m,mij->ji', g[m:], feat_add) - transpose(wt))**2``. The first ``m`` entries
    are initialised from ``gamma1init`` and the extra coefficients start at zero.

    Args:
        wt: target weights (outputs x inputs).
        w01: first-layer weights (hidden x inputs).
        w02: second-layer weights (outputs x hidden).
        gamma1init: initial hidden scales (length hidden).
        feat_add: extra features, indexed by dim 0; each ``feat_add[a]`` is shaped like ``wt``.
        epochs: number of L-BFGS steps.
        device: device on which the optimisation runs.

    Returns:
        ``(hidden_scales, extra_coefficients)`` as slices of one jointly optimised tensor.
    """
    with torch.no_grad():
        left = torch.transpose(w01, 0, 1).to(device)
        right = torch.transpose(w02, 0, 1).to(device)
        target = torch.transpose(wt, 0, 1).to(device)
        extra = feat_add.to(device)
    hidden = left.size(1)
    n_extra = extra.size(0)

    # One parameter vector: [hidden scales | extra-feature coefficients (start at 0)].
    gamma1 = torch.zeros(hidden + n_extra, requires_grad=True, device=device)
    gamma1.data[:hidden] = gamma1init.detach().to(device)
    optimizer = torch.optim.LBFGS([gamma1], lr=1, max_iter=40, max_eval=None, tolerance_grad=1e-07, tolerance_change=1e-09, history_size=100, line_search_fn='strong_wolfe')

    def closure():
        optimizer.zero_grad()
        fitted = torch.matmul(gamma1[:hidden] * left, right) + torch.einsum('m,mij->ji', gamma1[hidden:], extra)
        loss = torch.mean((fitted - target) ** 2)
        loss.backward()
        return loss

    for _ in range(epochs):
        optimizer.step(closure)
    return gamma1[:hidden], gamma1[hidden:]

def fc_1to2layer_learn_GD_LBFGS_wider(wt, w01, w02, feat_add, epochs=50, device="cpu"):
    """Jointly fit output scales, hidden scales and extra-feature coefficients with L-BFGS.

    Minimises ``mean((einsum('i,m,im,mj->ij', g2, g1[:m], w02, w01)
    + einsum('i,m,mij->ij', g2, g1[m:], feat_add) - wt)**2``.
    Initialisation:
    - ``g1[:m]`` is random normal, normalised to unit norm together with the zero-initialised extra coefficients.
    - ``g2`` is the constant unit-norm vector.

    Args:
        wt: target weights (outputs x inputs); not transposed here.
        w01: first-layer weights (hidden x inputs).
        w02: second-layer weights (outputs x hidden).
        feat_add: extra features, indexed by dim 0; each ``feat_add[a]`` is shaped like ``wt``.
        epochs: number of L-BFGS steps.
        device: device on which the optimisation runs.

    Returns:
        ``(gamma2, hidden_scales, extra_coefficients)``.
    """
    with torch.no_grad():
        w01 = w01.to(device)
        w02 = w02.to(device)
        wt = wt.to(device)
        feat_add = feat_add.to(device)
    n_out = wt.size(0)
    hidden = w01.size(0)
    n_extra = feat_add.size(0)

    gamma1 = torch.zeros(hidden + n_extra, requires_grad=True, device=device)
    gamma1.data[:hidden] = torch.randn(hidden, device=device)
    gamma1.data = gamma1.data / torch.sqrt(torch.sum(gamma1.data ** 2))
    gamma2 = torch.ones(n_out, requires_grad=True, device=device)
    gamma2.data = gamma2.data / torch.sqrt(torch.sum(gamma2.data ** 2))

    optimizer = torch.optim.LBFGS([gamma1, gamma2], lr=1, max_iter=40, max_eval=None, tolerance_grad=1e-07, tolerance_change=1e-09, history_size=100, line_search_fn='strong_wolfe')

    def closure():
        optimizer.zero_grad()
        through_hidden = torch.einsum('i,m,im,mj->ij', gamma2, gamma1[:hidden], w02, w01)
        through_extra = torch.einsum('i,m,mij->ij', gamma2, gamma1[hidden:], feat_add)
        loss = torch.mean((through_hidden + through_extra - wt) ** 2)
        loss.backward()
        return loss

    for _ in range(epochs):
        optimizer.step(closure)
    return gamma2, gamma1[:hidden], gamma1[hidden:]


# def fc_1to2layer_learn_GD_LBFGS_without_gamma2_wider(wt, w01, w02, gamma1init, feat_add, epochs=50, device="cpu"):
#     #Inputs:
#     #wt: target weights
#     #bt: target biases
#     #w01 and w02: (random) weight tensors in network with adjusted BN parametes
#     #act_scale, act_shift: parameters to shift the pre-activations of the first layer to a linear region
#     #Outputs:
#     #effective batch norm parameters for the neurons of the first and second layer: gamma1, beta1, gamma2, beta2
#     #later problem: take standard deviation and mean of BN statistics based on data batches into account
#     #probably always work with effective parameters
#     with torch.no_grad():
#         w01 = torch.transpose(w01,0,1).to(device)
#         w02 = torch.transpose(w02,0,1).to(device)
#         wt = torch.transpose(wt,0,1).to(device)
#     n2, n1 = wt.size(1), wt.size(0)
#     m = w01.size(1)
#     gamma1 = torch.ones(m, requires_grad=True, device=device)
#     gamma1.data = gamma1init.detach().to(device)
#     add_dim = feat_add.size(0)
#     gamma_add = torch.ones(add_dim, requires_grad=True, device=device) #torch.zeros(add_dim, requires_grad=True, device=device) #torch.ones(add_dim, requires_grad=True, device=device) #
#     #gamma_add.data = gamma1init[:add_dim].detach().to(device)
#     def closure():
#         opt.zero_grad()
#         objective = torch.mean((torch.matmul(gamma1*w01,w02)+torch.einsum('m,mij->ji', gamma_add, feat_add)-wt)**2)
#         objective.backward()
#         return objective
#
#     opt = torch.optim.LBFGS([gamma1, gamma_add], lr=1, max_iter=40, max_eval=None, tolerance_grad=1e-07, tolerance_change=1e-09, history_size=100, line_search_fn='strong_wolfe')
#     history_lbfgs = []
#     for i in range(epochs):
#         ll = torch.sqrt(torch.mean((torch.matmul(gamma1*w01,w02)+torch.einsum('m,mij->ji', gamma_add, feat_add)-wt)**2))
#         history_lbfgs.append(ll.item())
#         opt.step(closure)
#     plt.semilogy(history_lbfgs, label='L-BFGS')
#     plt.legend()
#     plt.show()
#     return gamma1, gamma_add

def fc_1to2layer_learn_GD_LBFGS_feat(wt, feat_add, epochs=50, device="cpu", plot=True):
    """Fit coefficients so that a linear combination of feature tensors matches ``wt``.

    Minimises ``mean((einsum('m,mij->ij', gamma1, feat_add) - wt) ** 2)`` with L-BFGS,
    starting from a random unit-norm ``gamma1``.

    Args:
        wt: 2-D target tensor (i x j).
        feat_add: 3-D stack of candidate features (m x i x j).
        epochs: number of ``LBFGS.step`` calls (each up to 40 inner iterations).
        device: device on which the fit runs.
        plot: if True and a module-level ``plt`` (matplotlib.pyplot) is defined, plot the
            per-epoch RMS error on a log scale. Otherwise plotting is skipped. Previously
            this raised NameError whenever ``plt`` was not imported.

    Returns:
        ``gamma1``: leaf tensor of length ``feat_add.size(0)`` (requires grad).
    """
    with torch.no_grad():
        wt = wt.to(device)
        feat_add = feat_add.to(device)
    m = feat_add.size(0)
    gamma1 = torch.ones(m, requires_grad=True, device=device)
    # Draw directly on `device`; a CPU tensor here broke the fit for non-CPU devices.
    gamma1.data = torch.randn(m, device=device)
    gamma1.data = gamma1.data / torch.sqrt(torch.sum(gamma1.data ** 2))

    def residual():
        return torch.einsum('m,mij->ij', gamma1, feat_add) - wt

    def closure():
        opt.zero_grad()
        objective = torch.mean(residual() ** 2)
        objective.backward()
        return objective

    opt = torch.optim.LBFGS([gamma1], lr=1, max_iter=40, max_eval=None, tolerance_grad=1e-07, tolerance_change=1e-09, history_size=100, line_search_fn='strong_wolfe')
    history_lbfgs = []
    for _ in range(epochs):
        # Monitoring only: no autograd graph needed.
        with torch.no_grad():
            history_lbfgs.append(torch.sqrt(torch.mean(residual() ** 2)).item())
        opt.step(closure)

    if plot:
        # matplotlib is optional in this module (its import is commented out).
        pyplot = globals().get("plt")
        if pyplot is not None:
            pyplot.semilogy(history_lbfgs, label='L-BFGS')
            pyplot.legend()
            pyplot.show()
    return gamma1


def fc_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, gamma1init, gamma2init, epochs=50, device="cpu"):
    """Fit hidden and output scales of a two-layer linear map to a target with L-BFGS.

    Minimises ``mean(((transpose(w01) * g1) @ (transpose(w02) * g2) - transpose(wt)) ** 2)``
    over ``g1`` (one per hidden unit) and ``g2`` (one per output), starting from
    ``gamma1init`` and ``gamma2init``.

    Args:
        wt: target weights (outputs x inputs).
        bt: target biases (length outputs).
        w01: first-layer weights (hidden x inputs).
        w02: second-layer weights (outputs x hidden).
        act_scale: unused here.
        act_shift: constant used as every first-layer bias.
        gamma1init, gamma2init: initial scales.
        epochs: number of L-BFGS steps.
        device: device on which the optimisation runs.

    Returns:
        ``(gamma1, bias1, gamma2, bias2)``:
        - ``bias1`` is ``act_shift`` for every hidden unit.
        - ``bias2 = bt - w02 @ bias1``.
    """
    with torch.no_grad():
        left = w01.transpose(0, 1).to(device)
        right = w02.transpose(0, 1).to(device)
        target = wt.transpose(0, 1).to(device)
    hidden = left.size(1)
    n_out = target.size(1)

    gamma1 = torch.ones(hidden, requires_grad=True, device=device)
    gamma2 = torch.ones(n_out, requires_grad=True, device=device)
    gamma1.data = gamma1init.detach().to(device)
    gamma2.data = gamma2init.detach().to(device)

    optimizer = torch.optim.LBFGS([gamma1, gamma2], lr=1, max_iter=40, max_eval=None, tolerance_grad=1e-07, tolerance_change=1e-09, history_size=100, line_search_fn='strong_wolfe')

    def closure():
        optimizer.zero_grad()
        diff = torch.matmul(gamma1 * left, gamma2 * right) - target
        loss = torch.mean(diff ** 2)
        loss.backward()
        return loss

    for _ in range(epochs):
        optimizer.step(closure)

    bias1 = torch.ones(hidden, device=device) * act_shift
    # right.transpose(0, 1) is w02 in its original (outputs x hidden) orientation.
    bias2 = bt.to(device) - torch.matmul(right.transpose(0, 1), bias1)
    return gamma1, bias1, gamma2, bias2

# convolutional networks

def conv_1to2layer_exact(wt, bt, w01, w02, act_scale, act_shift, device):
    """Solve for BN scales that make a (1x1 conv -> kxk conv) pair reproduce a target conv.

    With kernels flattened, look for gamma1 (length m) and gamma2 (length n2) with

        gamma2[i] * sum_k w02[i, k, l] * gamma1[k] * w01[k, j] == wt[i, j, l]

    for each output channel i, input channel j and kernel position l.
    - For every kept output channel, gamma2 is anchored on the largest-magnitude target entry.
    - gamma1[0] is fixed to 1.
    - The other m - 1 entries solve the square system formed by the m - 1 equations with
      the largest absolute row sums. Zero rows, such as the anchors, are therefore discarded.

    Args:
        wt: target conv weights, n2 x n1 x k1 x k2. NaN entries are treated as 0; this
            happens on a private copy, so the caller's tensor is not modified.
        bt: target biases, length n2.
        w01: fixed first-layer weights, m x n1 x 1 x 1.
        w02: fixed second-layer weights, n2 x m x k1 x k2.
        act_scale: unused here.
        act_shift: constant used as every first-layer bias.
        device: device for the anchor tensors (xim, gamma2).

    Returns:
        ``(gamma1, bias1, gamma2, bias2)``.
        - Target output channels that are all (near) zero get gamma2 = 0.
        - If any were dropped, only the first min(m, n2_kept * (n1*k1*k2 - 1) + 1) hidden
          channels are used, and the remaining gamma1 entries are 0.
    """
    # Flatten kernel dims. Clone first: reshape may return a view, and zeroing NaNs
    # in place would otherwise modify the caller's tensor.
    wt = wt.reshape([wt.size(0), wt.size(1), -1]).clone()
    wt[torch.isnan(wt)] = 0
    w01 = w01.reshape([w01.size(0), w01.size(1), -1])
    w02 = w02.reshape([w02.size(0), w02.size(1), -1])
    # Remove zero target channels.
    x = torch.sum(torch.abs(wt), dim=(1, 2))
    nout = wt.size(0)
    ind_keep = torch.arange(nout, device=wt.device)[x > 0.0000001]
    wt = wt[ind_keep, :, :]
    removal = nout - wt.size(0)
    if removal > 0:
        mold = w01.size(0)
        mfull = wt.size(0) * (wt.size(1) * wt.size(2) - 1) + 1
        m = min(mfull, w01.size(0))
        w01 = w01[:m, :, :]
        w02 = w02[:, :m, :]
        w02 = w02[ind_keep, :, :]
    # Anchor each output channel i on its largest |wt| entry (c, k).
    xim = torch.zeros((wt.size(0), w01.size(0)), dtype=wt.dtype, device=device)
    gamma2 = torch.zeros(wt.size(0), dtype=wt.dtype, device=device)
    for i in range(wt.size(0)):
        wi = torch.abs(wt[i, :, :])
        c, k = divmod(wi.argmax().item(), wi.shape[1])
        xim[i, :] = w01[:, c, 0] * w02[i, :, k]
        gamma2[i] = wt[i, c, k]
    wt = torch.einsum('ijk,i->ijk', wt, 1 / gamma2)
    # Row (i, j, l) of M, dotted with gamma1, is zero iff equation (i, j, l) holds.
    M = torch.einsum('ikl,kjm->ijlk', w02, w01) - torch.einsum('ik,ijl->ijlk', xim, wt)
    del w01
    M = M.reshape([M.size(0) * M.size(1) * M.size(2), M.size(3)])
    print(M.size())
    # Keep exactly m - 1 rows with the largest absolute row sums, in their original
    # order. Using a threshold (x >= thr) kept extra rows on ties and made the
    # system non-square.
    trows = M.size(1) - 1
    row_score = torch.sum(torch.abs(M), dim=1)
    _, order = torch.sort(row_score, descending=True)
    rows, _ = torch.sort(order[:trows])
    M = M[rows, :]
    del row_score, order, rows
    z = -M[:, 0]
    # Scaling often leads to better approximation error after transformation with gamma2.
    scale = 100
    gamma1 = torch.linalg.solve(scale * M[:, 1:], scale * z)
    gamma1 = torch.cat((torch.ones(1, dtype=gamma1.dtype, device=gamma1.device), gamma1), 0)
    del M
    gamma2 = gamma2 / (torch.einsum('ik,k->i', xim, gamma1))
    bshift = torch.sum(w02, dim=(1, 2))
    if removal > 0:
        # Scatter results back to the full channel counts (dropped entries stay 0).
        gamma2out = gamma2.new_zeros(nout)
        gamma2out[ind_keep] = gamma2
        gamma2 = gamma2out
        gamma1out = gamma1.new_zeros(mold)
        gamma1out[:m] = gamma1
        gamma1 = gamma1out
        bshiftout = bshift.new_zeros(nout)
        bshiftout[ind_keep] = bshift
        bshift = bshiftout
    bias1 = gamma1.new_ones(len(gamma1)) * act_shift
    bias2 = bt - bshift * act_shift
    return gamma1, bias1, gamma2, bias2

def conv_1to2layer_exact_pruned(wt, bt, w01, w02, act_scale, act_shift, device):
    """Like conv_1to2layer_exact, but drop equations for small targets when capacity is short.

    Goal: gamma2[i] * sum_k w02[i, k, l] * gamma1[k] * w01[k, j] == wt[i, j, l].
    - If the number of hidden channels m is below
      mfull = n2_kept * (n1 * k1 * k2 - 1) + 1, only the m + n2_kept - 1 equations for the
      largest |wt| entries are considered.
    - From those, the m - 1 rows with the largest absolute row sums form a square system
      for gamma1[1:]. gamma1[0] is fixed to 1.

    Args:
        wt: target conv weights, n2 x n1 x k1 x k2. NaN entries are treated as 0; this
            happens on a private copy, so the caller's tensor is not modified.
        bt: target biases, length n2.
        w01: fixed first-layer weights, m x n1 x 1 x 1.
        w02: fixed second-layer weights, n2 x m x k1 x k2.
        act_scale: unused here.
        act_shift: constant used as every first-layer bias.
        device: device for the anchor tensors (xim, gamma2).

    Returns:
        ``(gamma1, bias1, gamma2, bias2)``. Zero target channels get gamma2 = 0. Hidden
        channels beyond the truncation (only when channels were dropped) get gamma1 = 0.
    """
    # Flatten kernel dims. Clone first: reshape may return a view, and zeroing NaNs
    # in place would otherwise modify the caller's tensor.
    wt = wt.reshape([wt.size(0), wt.size(1), -1]).clone()
    wt[torch.isnan(wt)] = 0
    w01 = w01.reshape([w01.size(0), w01.size(1), -1])
    w02 = w02.reshape([w02.size(0), w02.size(1), -1])
    # Remove zero target channels.
    x = torch.sum(torch.abs(wt), dim=(1, 2))
    nout = wt.size(0)
    ind_keep = torch.arange(nout, device=wt.device)[x > 0.0000001]
    wt = wt[ind_keep, :, :]
    removal = nout - wt.size(0)
    mfull = wt.size(0) * (wt.size(1) * wt.size(2) - 1) + 1
    mold = w01.size(0)
    m = mold
    if removal > 0:
        m = min(mfull, mold)
        w01 = w01[:m, :, :]
        w02 = w02[:, :m, :]
        w02 = w02[ind_keep, :, :]
    # Not enough capacity: keep only the equations for the largest |wt| entries.
    # Flattened wt order (i, j, l) matches the row order of M below.
    if m < mfull:
        _, dropin_ind = torch.sort(torch.abs(wt.flatten()), descending=True)
        dropin_ind = dropin_ind[:(m + wt.size(0) - 1)]

    # Anchor each output channel i on its largest |wt| entry (c, k).
    xim = torch.zeros((wt.size(0), w01.size(0)), dtype=wt.dtype, device=device)
    gamma2 = torch.zeros(wt.size(0), dtype=wt.dtype, device=device)
    for i in range(wt.size(0)):
        wi = torch.abs(wt[i, :, :])
        c, k = divmod(wi.argmax().item(), wi.shape[1])
        xim[i, :] = w01[:, c, 0] * w02[i, :, k]
        gamma2[i] = wt[i, c, k]
    wt = torch.einsum('ijk,i->ijk', wt, 1 / gamma2)
    # Row (i, j, l) of M, dotted with gamma1, is zero iff equation (i, j, l) holds.
    M = torch.einsum('ikl,kjm->ijlk', w02, w01) - torch.einsum('ik,ijl->ijlk', xim, wt)
    del w01
    M = M.reshape([M.size(0) * M.size(1) * M.size(2), M.size(3)])
    if m < mfull:
        M = M[dropin_ind, :]
    # Keep exactly m - 1 rows (largest absolute row sums) so the system is square.
    trows = M.size(1) - 1
    x = torch.sum(torch.abs(M), dim=1)
    _, ii = torch.sort(x, descending=True)
    M = M[ii[:trows], :]
    del ii, x
    z = -M[:, 0]
    # Scaling often leads to better approximation error after transformation with gamma2.
    scale = 100
    gamma1 = torch.linalg.solve(scale * M[:, 1:], scale * z)
    gamma1 = torch.cat((torch.ones(1, dtype=gamma1.dtype, device=gamma1.device), gamma1), 0)
    del M
    gamma2 = gamma2 / (torch.einsum('ik,k->i', xim, gamma1))
    bshift = torch.sum(w02, dim=(1, 2))
    if removal > 0:
        # Scatter results back to the full channel counts (dropped entries stay 0).
        gamma2out = gamma2.new_zeros(nout)
        gamma2out[ind_keep] = gamma2
        gamma2 = gamma2out
        gamma1out = gamma1.new_zeros(mold)
        gamma1out[:m] = gamma1
        gamma1 = gamma1out
        bshiftout = bshift.new_zeros(nout)
        bshiftout[ind_keep] = bshift
        bshift = bshiftout
    bias1 = gamma1.new_ones(len(gamma1)) * act_shift
    bias2 = bt - bshift * act_shift
    return gamma1, bias1, gamma2, bias2

def conv_1to2layer_project_gamma1(wt, w01, w02, gamma2, device):
    #Least-squares projection of the first-layer scales gamma1 with gamma2 held fixed, i.e.
    #  min_gamma1 sum_{i,j,l} (sum_m gamma2[i]*w02[i,m,l]*gamma1[m]*w01[m,j,p] - wt[i,j,l])^2,
    #solved through its normal equations MtM @ gamma1 = z.
    #Inputs:
    #wt: target conv weights (n2 x n1 x ...); trailing kernel dims are flattened and NaN entries
    #    are treated as 0 (the caller's tensor is not modified)
    #w01: (random) first-layer weights (m x n1 x ...); the normal equations below sum over w01's
    #     flattened kernel axis, which matches the model only for 1x1 kernels -- TODO confirm
    #w02: (random) second-layer weights (n2 x m x ...)
    #gamma2: fixed second-layer scales, length n2
    #device: unused, kept for interface compatibility
    #Output:
    #gamma1: length w01.size(0), same dtype/device as the solve result. When all-zero target
    #        channels are dropped, only the first min(mfull, m) entries are solved for; the rest are 0.
    wt = wt.reshape([wt.size(0), wt.size(1), -1])
    # out-of-place: reshape may return a view, and an in-place write would alter the caller's wt
    wt = torch.where(torch.isnan(wt), torch.zeros_like(wt), wt)
    w01 = w01.reshape([w01.size(0), w01.size(1), -1])
    w02 = w02.reshape([w02.size(0), w02.size(1), -1])
    #remove (numerically) all-zero target channels
    x = torch.sum(torch.abs(wt), dim=(1, 2))
    nout = wt.size(0)
    # CPU index tensor: valid for indexing tensors on any device
    ind_keep = torch.arange(nout)[(x > 0.0000001).cpu()]
    wt = wt[ind_keep, :, :]
    removal = nout - wt.size(0)
    mold = w01.size(0)
    m = mold
    if removal > 0:
        # number of free parameters needed for an exact fit of the remaining channels
        mfull = wt.size(0)*(wt.size(1)*wt.size(2)-1)+1
        m = min(mfull, mold)
        w01 = w01[:m, :, :]
        w02 = w02[ind_keep, :m, :]
        gamma2 = gamma2[ind_keep]

    # fold the fixed gamma2 into the second layer
    w02 = torch.einsum('i,iml->iml', gamma2, w02)
    # normal equations: Gram matrix of the rank-one terms and the right-hand side
    MtM = torch.einsum('ikl,iml->km', w02, w02)*torch.einsum('kjl,mjl->km', w01, w01)
    z = torch.einsum('ijl,iml,mjp->m', wt, w02, w01)
    gamma1 = torch.linalg.solve(MtM, z)
    del MtM
    if removal > 0:
        # scatter back with the solution's dtype/device (a default float32 CPU buffer would
        # silently downcast or fail on other devices)
        gamma1out = torch.zeros(mold, dtype=gamma1.dtype, device=gamma1.device)
        gamma1out[:m] = gamma1
        gamma1 = gamma1out
    return gamma1

def conv_1to2layer_project_gamma2(wt, w01, w02, gamma1, device):
    #Closed-form per-channel least-squares fit of the second-layer scales gamma2 with gamma1 fixed:
    #  gamma2[i] = <w0_i, wt_i> / ||w0_i||^2,  w0 = einsum(w02, gamma1, w01).
    #Inputs:
    #wt: target conv weights (n2 x n1 x ...); trailing kernel dims are flattened and NaN entries
    #    are treated as 0 (the caller's tensor is not modified)
    #w01: (random) first-layer weights (m x n1 x ...); its flattened kernel axis is summed out,
    #     which matches the model only for 1x1 kernels -- TODO confirm
    #w02: (random) second-layer weights (n2 x m x ...)
    #gamma1: fixed first-layer scales, length m
    #device: unused, kept for interface compatibility
    #Output:
    #gamma2: length n2; channels whose target is (numerically) all zero get 0.
    wt = wt.reshape([wt.size(0), wt.size(1), -1])
    # out-of-place: reshape may return a view, and an in-place write would alter the caller's wt
    wt = torch.where(torch.isnan(wt), torch.zeros_like(wt), wt)
    w01 = w01.reshape([w01.size(0), w01.size(1), -1])
    w02 = w02.reshape([w02.size(0), w02.size(1), -1])
    #remove (numerically) all-zero target channels
    x = torch.sum(torch.abs(wt), dim=(1, 2))
    nout = wt.size(0)
    # CPU index tensor: valid for indexing tensors on any device
    ind_keep = torch.arange(nout)[(x > 0.0000001).cpu()]
    wt = wt[ind_keep, :, :]
    removal = nout - wt.size(0)
    if removal > 0:
        # only the first mfull hidden units are needed for the remaining channels
        mfull = wt.size(0)*(wt.size(1)*wt.size(2)-1)+1
        m = min(mfull, w01.size(0))
        w01 = w01[:m, :, :]
        w02 = w02[ind_keep, :m, :]
        gamma1 = gamma1[:m]
    # effective two-layer weights for gamma2 = 1
    w0 = torch.einsum('iml,m,mjp->ijl', w02, gamma1, w01)
    gamma2 = torch.sum(w0*wt, dim=(1, 2))/torch.sum(w0**2, dim=(1, 2))
    if removal > 0:
        # the buffer must match gamma2's dtype: index_put_ refuses mismatched dtypes
        gamma2out = torch.zeros(nout, dtype=gamma2.dtype, device=gamma2.device)
        gamma2out[ind_keep] = gamma2
        gamma2 = gamma2out
    return gamma2

def conv_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, gamma1init, gamma2init, epochs=50, device="cpu"):
    #Fit the scales gamma1 (length m) and gamma2 (length n2) with L-BFGS so that
    #einsum('i,imk,m,mjc->ijk', gamma2, w02, gamma1, w01) approximates wt in mean squared error.
    #Inputs:
    #wt: conv target tensor of dimension n2 x n1 x k1 x k2
    #bt: conv target biases, length n2
    #w01, w02: (random) weight tensors whose BN parameters are adjusted; w02 is n2 x m x k1 x k2 and
    #          w01 is expected to be m x n1 x 1 x 1 (its flattened kernel axis is summed out)
    #act_scale: unused, kept for interface compatibility
    #act_shift: constant first-layer bias that shifts pre-activations into a linear region
    #gamma1init, gamma2init: initial values; they are copied and never modified
    #epochs: number of L-BFGS steps (each runs up to 40 inner iterations)
    #device: device on which the optimisation runs
    #Outputs:
    #gamma1, bias1, gamma2, bias2 (gamma1/gamma2 are leaf tensors with requires_grad=True)
    #flatten kernel dimension
    with torch.no_grad():
        w01 = w01.reshape([w01.size(0),w01.size(1),-1]).to(device)
        w02 = w02.reshape([w02.size(0),w02.size(1),-1]).to(device)
        wt = wt.reshape([wt.size(0),wt.size(1),-1]).to(device)
    n2, n1 = wt.size(0), wt.size(1)
    m = w01.size(0)
    nk = wt.size(2)
    # Copy the initial values: LBFGS updates its parameters in place, so a parameter sharing
    # storage with gamma1init/gamma2init would overwrite the caller's tensors.
    gamma1 = gamma1init.detach().to(device).clone().requires_grad_(True)
    gamma2 = gamma2init.detach().to(device).clone().requires_grad_(True)

    def closure():
        opt.zero_grad()
        objective = torch.sum((torch.einsum('i,imk,m,mjc->ijk', gamma2, w02, gamma1, w01)-wt)**2)/(n1*n2*nk)
        objective.backward()
        return objective

    opt = torch.optim.LBFGS([gamma1,gamma2], lr=1, max_iter=40, max_eval=None, tolerance_grad=1e-07, tolerance_change=1e-09, history_size=100, line_search_fn='strong_wolfe')
    for _ in range(epochs):
        opt.step(closure)
    bias1 = torch.ones(m,device=device)*act_shift
    # constant contribution of bias1 passed through the second layer
    bias2 = bt.to(device)-torch.einsum('imk,m->i',w02,bias1)
    return gamma1, bias1, gamma2, bias2

def conv_1to2layer_learn_GD_LBFGS_without_gamma2(wt, w01, w02, gamma1init, epochs=50, device="cpu"):
    #Fit only the first-layer scales gamma1 with L-BFGS so that
    #einsum('imk,m,mjc->ijk', w02, gamma1, w01) approximates wt in mean squared error.
    #Inputs:
    #wt: conv target tensor of dimension n2 x n1 x k1 x k2
    #w01, w02: (random) weight tensors whose BN parameters are adjusted; w02 is n2 x m x k1 x k2 and
    #          w01 is expected to be m x n1 x 1 x 1 (its flattened kernel axis is summed out)
    #gamma1init: initial value for gamma1; it is copied and never modified
    #epochs: number of L-BFGS steps (each runs up to 40 inner iterations)
    #device: device on which the optimisation runs
    #Output:
    #gamma1: leaf tensor with requires_grad=True
    #flatten kernel dimension
    with torch.no_grad():
        w01 = w01.reshape([w01.size(0),w01.size(1),-1]).to(device)
        w02 = w02.reshape([w02.size(0),w02.size(1),-1]).to(device)
        wt = wt.reshape([wt.size(0),wt.size(1),-1]).to(device)
    # Copy the initial value: LBFGS updates its parameters in place, so a parameter sharing
    # storage with gamma1init would overwrite the caller's tensor.
    gamma1 = gamma1init.detach().to(device).clone().requires_grad_(True)

    def closure():
        opt.zero_grad()
        objective = torch.mean((torch.einsum('imk,m,mjc->ijk', w02, gamma1, w01)-wt)**2)
        objective.backward()
        return objective

    opt = torch.optim.LBFGS([gamma1], lr=1, max_iter=40, max_eval=None, tolerance_grad=1e-07, tolerance_change=1e-09, history_size=100, line_search_fn='strong_wolfe')
    for _ in range(epochs):
        opt.step(closure)
    return gamma1

def conv_1to2layer_learn_GD_LBFGS_without_gamma2_wider(wt, w01, w02, feat, epochs=50, device="cpu"):
    #Fit first-layer scales plus coefficients for extra feature tensors with L-BFGS:
    #  einsum('imk,m,mjc->ijk', w02, gamma1, w01) + einsum('mijk,m->ijk', feat, gamma_add) ~ wt
    #in mean squared error.
    #Inputs:
    #wt: conv target tensor of dimension n2 x n1 x k1 x k2 (kernel dims are flattened)
    #w01, w02: (random) weight tensors; w02 is n2 x m x k1 x k2, w01 is expected to be m x n1 x 1 x 1
    #feat: extra features indexed along dim 0; it is NOT reshaped, so it presumably must already be
    #      add_dim x n2 x n1 x (k1*k2) to match the flattened wt -- TODO confirm with callers
    #epochs: number of L-BFGS steps (each runs up to 40 inner iterations)
    #device: device on which the optimisation runs
    #Outputs:
    #gamma1: starts as a unit-norm random vector (draws from torch's global RNG)
    #gamma_add: starts at zero
    def flatten_kernel(t):
        # collapse every axis after the first two into a single trailing axis
        return t.reshape([t.size(0), t.size(1), -1])

    with torch.no_grad():
        first = flatten_kernel(w01).to(device)
        second = flatten_kernel(w02).to(device)
        target = flatten_kernel(wt).to(device)
        extra = feat.to(device)

    hidden = first.size(0)
    draw = torch.randn(hidden, device=device)
    scales = (draw/torch.sqrt(torch.sum(draw**2))).requires_grad_(True)
    coeffs = torch.zeros(extra.size(0), requires_grad=True, device=device)

    def closure():
        opt.zero_grad()
        approx = torch.einsum('imk,m,mjc->ijk', second, scales, first)+torch.einsum('mijk,m->ijk', extra, coeffs)
        loss = torch.mean((approx-target)**2)
        loss.backward()
        return loss

    opt = torch.optim.LBFGS([scales, coeffs], lr=1, max_iter=40, max_eval=None, tolerance_grad=1e-07, tolerance_change=1e-09, history_size=100, line_search_fn='strong_wolfe')
    for _ in range(epochs):
        opt.step(closure)
    return scales, coeffs

def conv_1to2layer_learn_GD_LBFGS_wider(wt, w01, w02, feat, epochs=50, device="cpu"):
    #Jointly fit per-output scales gamma2, first-layer scales gamma1 and extra-feature coefficients
    #gamma_add with L-BFGS:
    #  gamma2[i] * (einsum(w02, gamma1, w01) + einsum(feat, gamma_add))[i] ~ wt[i]
    #in mean squared error.
    #Inputs:
    #wt: conv target tensor of dimension n2 x n1 x k1 x k2 (kernel dims are flattened)
    #w01, w02: (random) weight tensors; w02 is n2 x m x k1 x k2, w01 is expected to be m x n1 x 1 x 1
    #feat: extra features indexed along dim 0; it is NOT reshaped, so it presumably must already be
    #      add_dim x n2 x n1 x (k1*k2) to match the flattened wt -- TODO confirm with callers
    #epochs: number of L-BFGS steps (each runs up to 40 inner iterations)
    #device: device on which the optimisation runs
    #Outputs (note the order): gamma2, gamma1, gamma_add
    #gamma2 starts at ones/sqrt(n2), gamma1 at a unit-norm random vector (torch's global RNG),
    #gamma_add at zero.
    def flatten_kernel(t):
        # collapse every axis after the first two into a single trailing axis
        return t.reshape([t.size(0), t.size(1), -1])

    with torch.no_grad():
        first = flatten_kernel(w01).to(device)
        second = flatten_kernel(w02).to(device)
        target = flatten_kernel(wt).to(device)
        extra = feat.to(device)

    draw = torch.randn(first.size(0), device=device)
    inner = (draw/torch.sqrt(torch.sum(draw**2))).requires_grad_(True)
    coeffs = torch.zeros(extra.size(0), requires_grad=True, device=device)
    ones = torch.ones(target.size(0), device=device)
    outer = (ones/torch.sqrt(torch.sum(ones**2))).requires_grad_(True)

    def closure():
        opt.zero_grad()
        approx = torch.einsum('i,imk,m,mjc->ijk', outer, second, inner, first)+torch.einsum('i,mijk,m->ijk', outer, extra, coeffs)
        loss = torch.mean((approx-target)**2)
        loss.backward()
        return loss

    opt = torch.optim.LBFGS([inner, coeffs, outer], lr=1, max_iter=40, max_eval=None, tolerance_grad=1e-07, tolerance_change=1e-09, history_size=100, line_search_fn='strong_wolfe')
    for _ in range(epochs):
        opt.step(closure)
    return outer, inner, coeffs

def proxy_target_exact(wt, seed, epochs, device):
    #Approximate wt by gamma2*w02*gamma1*w01 with random w01/w02 of the width mfull needed for an
    #exact fit: the exact linear solve initialises an L-BFGS fine-tuning run.
    #Inputs:
    #wt: target weights, n2 x n1 x k1 x k2 (conv) or n2 x n1 (fc)
    #seed: seeds both Python's `random` and torch's RNG (w01/w02 are drawn with torch.randn)
    #epochs: number of L-BFGS steps
    #device: forwarded to the helper routines
    #Outputs:
    #rmse, wproxy, [w02, w01, gamma2, gamma1]
    random.seed(seed)
    # random.seed does not affect torch.randn; seed torch as well so `seed` controls w01/w02
    torch.manual_seed(seed)
    act_scale = 1
    act_shift = 5
    if len(wt.size()) > 2:
        n2, n1, k1, k2 = wt.size()
        m=n2*(n1*k1*k2-1)+1
        w01 = torch.randn([m,n1,1,1])/math.sqrt(2*n1)
        w02 = torch.randn([n2,m,k1,k2])/math.sqrt(2*m*k1*k2)
        bt = torch.ones(n2,device=device)
        gamma1, bias1, gamma2, bias2 = conv_1to2layer_exact(wt, bt, w01, w02, act_scale, act_shift, device)
        gamma1, bias1, gamma2, bias2 = conv_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, gamma1, gamma2, epochs, device)
        wproxy = torch.einsum('i,ikmn,k,kjlb->ijmn', gamma2, w02, gamma1, w01)
    else:
        n2, n1 = wt.size()
        m=n2*(n1-1)+1
        w01 = torch.randn([m,n1])/math.sqrt(2*n1)
        w02 = torch.randn([n2,m])/math.sqrt(2*m)
        bt = torch.ones(n2,device=device)
        gamma1, bias1, gamma2, bias2 = fc_1to2layer_exact(wt, bt, w01, w02, act_scale, act_shift, device)
        gamma1, bias1, gamma2, bias2 = fc_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, gamma1, gamma2, epochs, device)
        wproxy = torch.einsum('i,im,m,mj->ij', gamma2, w02, gamma1, w01)
    rmse =  torch.sqrt(torch.mean((wt-wproxy)**2)).detach()
    return rmse, wproxy, [w02, w01, gamma2, gamma1]


def proxy_target(wt, seed, epochs, m, device):
    #Approximate wt by gamma2*w02*gamma1*w01 with random w01/w02 of hidden width m, fitting the
    #scales with L-BFGS from an all-ones start.
    #Inputs:
    #wt: target weights, n2 x n1 x k1 x k2 (conv) or n2 x n1 (fc)
    #seed: seeds both Python's `random` and torch's RNG (w01/w02 are drawn with torch.randn)
    #epochs: number of L-BFGS steps
    #m: hidden width of the proxy
    #device: forwarded to the helper routines
    #Outputs:
    #rmse, wproxy, [w02, w01, gamma2, gamma1]
    random.seed(seed)
    # random.seed does not affect torch.randn; seed torch as well so `seed` controls w01/w02
    torch.manual_seed(seed)
    act_scale = 1
    act_shift = 5
    if len(wt.size()) > 2:
        n2, n1, k1, k2 = wt.size()
        w01 = torch.randn([m,n1,1,1])/math.sqrt(2*n1)
        w02 = torch.randn([n2,m,k1,k2])/math.sqrt(2*m*k1*k2)
        gamma1 = torch.ones(m,device=device)
        gamma2 = torch.ones(n2,device=device)
        bt = torch.ones(n2,device=device)
        gamma1, bias1, gamma2, bias2 = conv_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, gamma1, gamma2, epochs, device)
        wproxy = torch.einsum('i,ikmn,k,kjlb->ijmn', gamma2, w02, gamma1, w01)
    else:
        n2, n1 = wt.size()
        w01 = torch.randn([m,n1])/math.sqrt(2*n1)
        # fc layer has no kernel dims (k1/k2 are undefined here); scale like the other fc branches
        w02 = torch.randn([n2,m])/math.sqrt(2*m)
        bt = torch.ones(n2,device=device)
        gamma1 = torch.ones(m,device=device)
        gamma2 = torch.ones(n2,device=device)
        gamma1, bias1, gamma2, bias2 = fc_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, gamma1, gamma2, epochs, device)
        wproxy = torch.einsum('i,im,m,mj->ij', gamma2, w02, gamma1, w01)
    rmse =  torch.sqrt(torch.mean((wt-wproxy)**2)).detach()
    return rmse, wproxy, [w02, w01, gamma2, gamma1]

def proxy_target_init_exact(wt, seed, epochs, m, device):
    #Approximate wt by gamma2*w02*gamma1*w01 with random w01/w02 and keep the best of three
    #initialisation strategies (lowest RMSE):
    #  1. exact solve of a reduced linear system, then L-BFGS fine-tuning
    #  2. least-squares projection of gamma1 given gamma2 from 1, then L-BFGS fine-tuning
    #  3. projection of gamma1 (with gamma2 = 1), then projection of gamma2 (no fine-tuning)
    #Inputs:
    #wt: target weights, n2 x n1 x k1 x k2 (conv) or n2 x n1 (fc); NaN entries are set to 0 in place
    #seed: seed for Python's `random` and torch's RNG (used to draw w01, w02)
    #epochs: number of L-BFGS steps per fine-tuning run
    #m: requested hidden width, capped at mfull (the width needed for an exact fit)
    #device: forwarded to the helper routines
    #Outputs:
    #rmse: RMSE of the best proxy; wproxy: the best proxy;
    #[w02, w01, gamma2, gamma1]: the factors that reproduce that best proxy
    random.seed(seed)
    wt[torch.isnan(wt)] = 0
    act_scale = 1
    act_shift = 5
    if len(wt.size()) > 2:
        n2, n1, k1, k2 = wt.size()
        mfull=n2*(n1*k1*k2-1)+1
        m = min(m,mfull)
        torch.manual_seed(seed)
        w01 = torch.randn([m,n1,1,1])/math.sqrt(2*n1)
        w02 = torch.randn([n2,m,k1,k2])/math.sqrt(2*m*k1*k2)
        bt = torch.ones(n2,device=device)

        def evaluate(g2, g1):
            # proxy weights for the given scales and their RMSE w.r.t. wt
            wp = torch.einsum('i,ikmn,k,kjlb->ijmn', g2, w02, g1, w01)
            return torch.sqrt(torch.mean((wt-wp)**2)).detach(), wp

        #Alternative 1: exact solution of the pruned system, fine-tuned with L-BFGS
        gamma1, _, gamma2, _ = conv_1to2layer_exact_pruned(wt, bt, w01, w02, act_scale, act_shift, device)
        gamma1, _, gamma2, _ = conv_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, gamma1, gamma2, epochs, device)
        rmse, wproxy = evaluate(gamma2, gamma1)
        #Alternative 2: project gamma1 for the fine-tuned gamma2, then fine-tune again.
        #gamma2 is passed as a copy because the optimiser may write into its initial values.
        gamma1_2 = conv_1to2layer_project_gamma1(wt, w01, w02, gamma2, device)
        gamma1_2, _, gamma2_2, _ = conv_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, gamma1_2, gamma2.detach().clone(), epochs, device)
        rmse2, wproxy2 = evaluate(gamma2_2, gamma1_2)
        if rmse2 < rmse:
            rmse, wproxy, gamma1, gamma2 = rmse2, wproxy2, gamma1_2, gamma2_2
        #Alternative 3: alternating projections starting from gamma2 = 1
        gamma1_2 = conv_1to2layer_project_gamma1(wt, w01, w02, torch.ones(wt.size(0)), device)
        gamma2_2 = conv_1to2layer_project_gamma2(wt, w01, w02, gamma1_2, device)
        rmse2, wproxy2 = evaluate(gamma2_2, gamma1_2)
        if rmse2 < rmse:
            rmse, wproxy, gamma1, gamma2 = rmse2, wproxy2, gamma1_2, gamma2_2
    else:
        n2, n1 = wt.size()
        mfull=n2*(n1-1)+1
        m = min(m,mfull)
        torch.manual_seed(seed)
        w01 = torch.randn([m,n1])/math.sqrt(2*n1)
        w02 = torch.randn([n2,m])/math.sqrt(2*m)
        bt = torch.ones(n2,device=device)

        def evaluate(g2, g1):
            # proxy weights for the given scales and their RMSE w.r.t. wt
            wp = torch.einsum('i,im,m,mj->ij', g2, w02, g1, w01)
            return torch.sqrt(torch.mean((wt-wp)**2)).detach(), wp

        #Alternative 1: exact solution after removing conditions, fine-tuned with L-BFGS
        gamma1, _, gamma2, _ = fc_1to2layer_exact_remove_conditions(wt, bt, w01, w02, act_scale, act_shift, device)
        gamma1, _, gamma2, _ = fc_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, gamma1, gamma2, epochs, device)
        rmse, wproxy = evaluate(gamma2, gamma1)
        #Alternative 2: project gamma1 for the fine-tuned gamma2, then fine-tune again.
        #gamma2 is passed as a copy because the optimiser may write into its initial values.
        gamma1_2 = fc_1to2layer_exact_project_without_gamma2(wt, w01, torch.einsum('i,im->im',gamma2,w02), device)
        gamma1_2, _, gamma2_2, _ = fc_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, gamma1_2, gamma2.detach().clone(), epochs, device)
        rmse2, wproxy2 = evaluate(gamma2_2, gamma1_2)
        if rmse2 < rmse:
            rmse, wproxy, gamma1, gamma2 = rmse2, wproxy2, gamma1_2, gamma2_2
        #Alternative 3: project gamma1 ignoring gamma2, then project gamma2
        gamma1_2 = fc_1to2layer_exact_project_without_gamma2(wt, w01, w02, device)
        gamma2_2 = fc_1to2layer_project_gamma2(wt, w01, w02, gamma1_2, device)
        rmse2, wproxy2 = evaluate(gamma2_2, gamma1_2)
        if rmse2 < rmse:
            rmse, wproxy, gamma1, gamma2 = rmse2, wproxy2, gamma1_2, gamma2_2
    return rmse, wproxy, [w02, w01, gamma2, gamma1]

def proxy_target_lower_mem(wt, seed, epochs, m, device):
    #Approximate wt by gamma2*w02*gamma1*w01 with random w01/w02. It uses a single, lower-memory
    #strategy: project gamma1 (gamma2 = 1), project gamma2, then fine-tune both with L-BFGS.
    #Inputs:
    #wt: target weights, n2 x n1 x k1 x k2 (conv) or n2 x n1 (fc); NaN entries are set to 0 in place
    #seed: seed for Python's `random` and torch's RNG (used to draw w01, w02)
    #epochs: number of L-BFGS steps
    #m: requested hidden width, capped at mfull (the width needed for an exact fit)
    #device: forwarded to the helper routines
    #Outputs:
    #rmse, wproxy, [w02, w01, gamma2, gamma1] (the factors reproduce wproxy)
    random.seed(seed)
    wt[torch.isnan(wt)] = 0
    act_scale = 1
    act_shift = 5
    if len(wt.size()) > 2:
        n2, n1, k1, k2 = wt.size()
        mfull=n2*(n1*k1*k2-1)+1
        m = min(m,mfull)
        torch.manual_seed(seed)
        w01 = torch.randn([m,n1,1,1])/math.sqrt(2*n1)
        w02 = torch.randn([n2,m,k1,k2])/math.sqrt(2*m*k1*k2)
        bt = torch.ones(n2,device=device)
        gamma1 = conv_1to2layer_project_gamma1(wt, w01, w02, torch.ones(wt.size(0)), device)
        gamma2 = conv_1to2layer_project_gamma2(wt, w01, w02, gamma1, device)
        # keep the gamma2 that L-BFGS optimised jointly with gamma1 (same as the fc branch)
        gamma1, bias1, gamma2, bias2 = conv_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, gamma1, gamma2, epochs, device)
        wproxy = torch.einsum('i,ikmn,k,kjlb->ijmn', gamma2, w02, gamma1, w01)
        rmse =  torch.sqrt(torch.mean((wt-wproxy)**2)).detach()
    else:
        n2, n1 = wt.size()
        mfull=n2*(n1-1)+1
        m = min(m,mfull)
        torch.manual_seed(seed)
        w01 = torch.randn([m,n1])/math.sqrt(2*n1)
        w02 = torch.randn([n2,m])/math.sqrt(2*m)
        bt = torch.ones(n2,device=device)
        gamma1 = fc_1to2layer_exact_project_without_gamma2(wt, w01, w02, device)
        gamma2 = fc_1to2layer_project_gamma2(wt, w01, w02, gamma1, device)
        gamma1, bias1, gamma2, bias2 = fc_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, gamma1, gamma2, epochs, device)
        wproxy = torch.einsum('i,im,m,mj->ij', gamma2, w02, gamma1, w01)
        rmse =  torch.sqrt(torch.mean((wt-wproxy)**2)).detach()
    return rmse, wproxy, [w02, w01, gamma2, gamma1]


def proxy_target_init_flexible_fc_precond(wt, w01, w02, epochs, m, device, precond):
    #Fit an fc proxy gamma2*w02*gamma1*w01 for the given factors w01 (m x n1) and w02 (n2 x m).
    #It first runs an exact solve (preconditioned variant when `precond` is true), then fine-tunes
    #with L-BFGS. The RMSE before and after fine-tuning is printed.
    #Inputs:
    #wt: fc target weights n2 x n1; NaN entries are set to 0 in place
    #w01, w02: fixed proxy factors
    #epochs: number of L-BFGS steps
    #m: not used by this routine; kept for interface compatibility
    #device: forwarded to the helper routines
    #precond: choose the preconditioned exact solver (which also reports a condition number)
    #Outputs:
    #rmse: RMSE after fine-tuning; a NaN value is replaced by the RMSE of the all-zero proxy
    #wproxy: fine-tuned proxy weights (returned unchanged, even when rmse was NaN)
    #condNumber: from the preconditioned solver, or torch.tensor([0]) otherwise
    random.seed(1)
    wt[torch.isnan(wt)] = 0
    shift = 5
    scale = 1
    n_out, _ = wt.size()
    bias_target = torch.ones(n_out, device=device)

    def fit_error(g2, g1):
        # proxy weights for the given scales and their RMSE w.r.t. wt
        approx = torch.einsum('i,im,m,mj->ij', g2, w02, g1, w01)
        return torch.sqrt(torch.mean((wt-approx)**2)).detach(), approx

    if not precond:
        condNumber = torch.tensor([0])
        g1, _, g2, _ = fc_1to2layer_exact_remove_conditions(wt, bias_target, w01, w02, scale, shift, device)
    else:
        g1, g2, condNumber = fc_1to2layer_exact_remove_conditions_precond(wt, bias_target, w01, w02, scale, shift, device)

    err, _ = fit_error(g2, g1)
    print("error exact:")
    print(err)
    print("######################")

    g1, _, g2, _ = fc_1to2layer_learn_GD_LBFGS(wt, bias_target, w01, w02, scale, shift, g1, g2, epochs, device)
    err, approx = fit_error(g2, g1)
    print("finetuned exact: " + str(err))
    if torch.isnan(err):
        err = torch.sqrt(torch.mean(wt**2)).detach()
    return err, approx, condNumber


def proxy_target_init_flexible(wt, w01, w02, epochs, m, device):
    #Approximate wt by gamma2*w02*gamma1*w01 for the given factors w01, w02 and keep the best of
    #three strategies (lowest RMSE):
    #  1. exact solve of a reduced linear system, then L-BFGS fine-tuning
    #  2. least-squares projection of gamma1 given gamma2 from 1, then L-BFGS fine-tuning
    #  3. projection of gamma1 (gamma2 = 1 / no gamma2), then projection of gamma2
    #A NaN RMSE is replaced by the RMSE of the all-zero proxy so the comparison still works.
    #Inputs:
    #wt: target weights, n2 x n1 x k1 x k2 (conv) or n2 x n1 (fc); NaN entries are set to 0 in place
    #w01, w02: fixed proxy factors
    #epochs: number of L-BFGS steps per fine-tuning run
    #m: not used (the width comes from w01/w02); kept for interface compatibility
    #device: forwarded to the helper routines
    #Outputs:
    #rmse, wproxy, [w02, w01, gamma2, gamma1] where the gammas belong to the returned wproxy
    random.seed(1)
    wt[torch.isnan(wt)] = 0
    act_scale = 1
    act_shift = 5
    conv = len(wt.size()) > 2
    spec = 'i,ikmn,k,kjlb->ijmn' if conv else 'i,im,m,mj->ij'

    def evaluate(g2, g1):
        # proxy for the given scales and its RMSE; NaN RMSE falls back to the zero-proxy RMSE
        wp = torch.einsum(spec, g2, w02, g1, w01)
        err = torch.sqrt(torch.mean((wt-wp)**2)).detach()
        if torch.isnan(err):
            err = torch.sqrt(torch.mean(wt**2)).detach()
        return err, wp

    if conv:
        n2, _, _, _ = wt.size()
        torch.manual_seed(1)
        bt = torch.ones(n2,device=device)
        #Alternative 1
        gamma1, _, gamma2, _ = conv_1to2layer_exact_pruned(wt, bt, w01, w02, act_scale, act_shift, device)
        gamma1, _, gamma2, _ = conv_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, gamma1, gamma2, epochs, device)
        rmse, wproxy = evaluate(gamma2, gamma1)
        #Alternative 2 (gamma2 is copied: the optimiser may write into its initial values)
        gamma1_2 = conv_1to2layer_project_gamma1(wt, w01, w02, gamma2, device)
        gamma1_2, _, gamma2_2, _ = conv_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, gamma1_2, gamma2.detach().clone(), epochs, device)
        rmse2, wproxy2 = evaluate(gamma2_2, gamma1_2)
        if rmse2 < rmse:
            rmse, wproxy, gamma1, gamma2 = rmse2, wproxy2, gamma1_2, gamma2_2
        #Alternative 3
        gamma1_2 = conv_1to2layer_project_gamma1(wt, w01, w02, torch.ones(wt.size(0)), device)
        gamma2_2 = conv_1to2layer_project_gamma2(wt, w01, w02, gamma1_2, device)
        rmse2, wproxy2 = evaluate(gamma2_2, gamma1_2)
        if rmse2 < rmse:
            rmse, wproxy, gamma1, gamma2 = rmse2, wproxy2, gamma1_2, gamma2_2
    else:
        n2, _ = wt.size()
        torch.manual_seed(1)
        bt = torch.ones(n2,device=device)
        #Alternative 1:
        gamma1, _, gamma2, _ = fc_1to2layer_exact_remove_conditions(wt, bt, w01, w02, act_scale, act_shift, device)
        # report the raw (unsanitised) error of the exact solve
        rmse = torch.sqrt(torch.mean((wt-torch.einsum(spec, gamma2, w02, gamma1, w01))**2)).detach()
        print("error exact:")
        print(rmse)
        gamma1, _, gamma2, _ = fc_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, gamma1, gamma2, epochs, device)
        rmse, wproxy = evaluate(gamma2, gamma1)
        #Alternative 2: projections (gamma2 is copied: the optimiser may write into its initial values)
        gamma1_2 = fc_1to2layer_exact_project_without_gamma2(wt, w01, torch.einsum('i,im->im',gamma2,w02), device)
        gamma1_2, _, gamma2_2, _ = fc_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, gamma1_2, gamma2.detach().clone(), epochs, device)
        rmse2, wproxy2 = evaluate(gamma2_2, gamma1_2)
        if rmse2 < rmse:
            rmse, wproxy, gamma1, gamma2 = rmse2, wproxy2, gamma1_2, gamma2_2
        gamma1_2 = fc_1to2layer_exact_project_without_gamma2(wt, w01, w02, device)
        gamma2_2 = fc_1to2layer_project_gamma2(wt, w01, w02, gamma1_2, device)
        rmse2, wproxy2 = evaluate(gamma2_2, gamma1_2)
        if rmse2 < rmse:
            rmse, wproxy, gamma1, gamma2 = rmse2, wproxy2, gamma1_2, gamma2_2
    return rmse, wproxy, [w02, w01, gamma2, gamma1]



def proxy_target_init_flexible_lower_mem(wt, w01, w02, epochs, m, device):
    """Fit per-neuron scalings so a two-layer proxy approximates ``wt``.

    The first ``m`` hidden units of ``w01``/``w02`` are kept.  Scalings
    ``gamma1`` (hidden layer) and ``gamma2`` (output layer) are initialised by
    projection and then refined with the matching ``*_learn_GD_LBFGS`` helper.
    If the refined proxy has a NaN RMSE, the refinement is restarted from
    all-ones scalings.

    Args:
        wt: target weights, 2-D ``(n2, n1)`` or 4-D ``(n2, n1, k1, k2)``.
            NaN entries are set to 0 in place (the caller's tensor changes).
        w01: first-layer weights; rows beyond ``m`` are dropped.
        w02: second-layer weights; columns (axis 1) beyond ``m`` are dropped.
        epochs: forwarded to the LBFGS refinement.
        m: requested hidden width, capped at ``mfull``.
        device: forwarded to the helpers; also used for the target bias.

    Returns:
        ``(rmse, wproxy, [w02, w01, gamma2, gamma1])``: the detached RMSE of
        the proxy against ``wt``, the proxy itself and its factors.
    """
    random.seed(1)
    wt[torch.isnan(wt)] = 0
    act_scale = 1
    act_shift = 5

    if len(wt.size()) > 2:
        n2, n1, k1, k2 = wt.size()
        mfull = n2*(n1*k1*k2-1)+1
        m = min(m, mfull)
        w01 = w01[:m]
        w02 = w02[:, :m, :, :]
        torch.manual_seed(1)
        bt = torch.ones(n2, device=device)
        gamma1 = conv_1to2layer_project_gamma1(wt, w01, w02, torch.ones(wt.size(0)), device)
        gamma2 = conv_1to2layer_project_gamma2(wt, w01, w02, gamma1, device)
        # Keep gamma1 and gamma2 from the same refinement so the proxy is the
        # one the optimiser actually fitted (the refined gamma2 used to be
        # discarded while the refined gamma1 was kept).
        gamma1, bias1, gamma2, bias2 = conv_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, gamma1, gamma2, epochs, device)
        wproxy = torch.einsum('i,ikmn,k,kjlb->ijmn', gamma2, w02, gamma1, w01)
        rmse = torch.sqrt(torch.mean((wt-wproxy)**2)).detach()
        if torch.isnan(rmse):
            # Projection start diverged: retry from unit scalings.
            gamma1, bias1, gamma2, bias2 = conv_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, torch.ones(m), torch.ones(n2), epochs, device)
            wproxy = torch.einsum('i,ikmn,k,kjlb->ijmn', gamma2, w02, gamma1, w01)
            rmse = torch.sqrt(torch.mean((wt-wproxy)**2)).detach()
    else:
        n2, n1 = wt.size()
        mfull = n2*(n1-1)+1
        m = min(m, mfull)
        w01 = w01[:m, :]
        # w02 is 2-D (n2, m) in the fully connected case; a 4-index slice
        # raised "IndexError: too many indices".
        w02 = w02[:, :m]
        torch.manual_seed(1)
        bt = torch.ones(n2, device=device)
        gamma1 = fc_1to2layer_exact_project_without_gamma2(wt, w01, w02, device)
        gamma2 = fc_1to2layer_project_gamma2(wt, w01, w02, gamma1, device)
        gamma1, bias1, gamma2, bias2 = fc_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, gamma1, gamma2, epochs, device)
        wproxy = torch.einsum('i,im,m,mj->ij', gamma2, w02, gamma1, w01)
        rmse = torch.sqrt(torch.mean((wt-wproxy)**2)).detach()
        if torch.isnan(rmse):
            # Projection start diverged: retry from unit scalings.
            gamma1, bias1, gamma2, bias2 = fc_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, torch.ones(m), torch.ones(n2), epochs, device)
            wproxy = torch.einsum('i,im,m,mj->ij', gamma2, w02, gamma1, w01)
            rmse = torch.sqrt(torch.mean((wt-wproxy)**2)).detach()
    return rmse, wproxy, [w02, w01, gamma2, gamma1]

def proxy_target_init_flexible_lower_mem_rand(wt, w01, w02, epochs, m, device):
    """Return a projection-based hidden-layer scaling ``gamma1`` for ``wt``.

    Only the first ``m`` hidden units of ``w01``/``w02`` are used, with ``m``
    capped at ``numel(wt)``.  NaN entries of the projected ``gamma1`` are
    replaced by 1.

    Args:
        wt: target weights, 2-D ``(n2, n1)`` or 4-D ``(n2, n1, k1, k2)``.
            NaN entries are set to 0 in place (the caller's tensor changes).
        w01, w02: first/second-layer weights.
        epochs: unused; kept for signature compatibility.
        m: requested hidden width.
        device: forwarded to the projection helper.

    Returns:
        ``gamma1`` as returned by the projection helper, with NaNs set to 1.
    """
    random.seed(1)
    wt[torch.isnan(wt)] = 0
    if len(wt.size()) > 2:
        n2, n1, k1, k2 = wt.size()
        mfull = n2*n1*k1*k2
        m = min(m, mfull)
        w01 = w01[:m]
        w02 = w02[:, :m, :, :]
        torch.manual_seed(1)
        gamma1 = conv_1to2layer_project_gamma1(wt, w01, w02, torch.ones(wt.size(0)), device)
    else:
        n2, n1 = wt.size()
        mfull = n2*n1
        m = min(m, mfull)
        w01 = w01[:m, :]
        # w02 is 2-D (n2, m) here; a 4-index slice raised IndexError.
        w02 = w02[:, :m]
        torch.manual_seed(1)
        gamma1 = fc_1to2layer_exact_project_without_gamma2(wt, w01, w02, device)
    gamma1[torch.isnan(gamma1)] = 1
    return gamma1

# def proxy_target_cond(wt, w01, w02, epochs, m, device):
#     random.seed(1)
#     wt[torch.isnan(wt)] = 0
#     act_scale = 1
#     act_shift = 5
#     if len(wt.size()) > 2:
#         n2, n1, k1, k2 = wt.size()
#         mfull=n2*(n1*k1*k2-1)+1
#         #print(m)
#         m = min(m,mfull)
#         torch.manual_seed(1)
#         bt = torch.ones(n2,device=device)
#         gamma1, bias1, gamma2, bias2 = conv_1to2layer_exact_pruned(wt, bt, w01, w02, act_scale, act_shift, device)
#         gamma1, bias1, gamma2, bias2 = conv_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, gamma1, gamma2, epochs, device)
#         wproxy = torch.einsum('i,ikmn,k,kjlb->ijmn', gamma2, w02, gamma1, w01)
#         rmse = torch.sqrt(torch.mean((wt-wproxy)**2)).detach()
#         if torch.isnan(rmse):
#             rmse = torch.sqrt(torch.mean(wt**2)).detach()
#         #Alternative
#         gamma1_2 = conv_1to2layer_project_gamma1(wt, w01, w02, gamma2, device)
#         gamma1_2, bias1, gamma2_2, bias2 = conv_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, gamma1_2, gamma2, epochs, device)
#         wproxy2 = torch.einsum('i,ikmn,k,kjlb->ijmn', gamma2_2, w02, gamma1_2, w01)
#         rmse2 =  torch.sqrt(torch.mean((wt-wproxy2)**2)).detach()
#         if torch.isnan(rmse2):
#             rmse2 = torch.sqrt(torch.mean(wt**2)).detach()
#         if rmse2 < rmse:
#             rmse = rmse2
#             wproxy = wproxy2
#         gamma1_2 = conv_1to2layer_project_gamma1(wt, w01, w02, torch.ones(wt.size(0)), device)
#         gamma2_2 = conv_1to2layer_project_gamma2(wt, w01, w02, gamma1_2, device)
#         wproxy2 = torch.einsum('i,ikmn,k,kjlb->ijmn', gamma2_2, w02, gamma1_2, w01)
#         rmse2 =  torch.sqrt(torch.mean((wt-wproxy2)**2)).detach()
#         if torch.isnan(rmse2):
#             rmse2 = torch.sqrt(torch.mean(wt**2)).detach()
#         if rmse2 < rmse:
#             rmse = rmse2
#             wproxy = wproxy2
#     else:
#         n2, n1 = wt.size()
#         mfull=n2*(n1-1)+1
#         m = min(m,mfull)
#         torch.manual_seed(1)
#         bt = torch.ones(n2,device=device)
#         #Alternative 1:
#         gamma1, bias1, gamma2, bias2 = fc_1to2layer_exact_remove_conditions(wt, bt, w01, w02, act_scale, act_shift, device)
#         gamma1, bias1, gamma2, bias2 = fc_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, gamma1, gamma2, epochs, device)
#         wproxy = torch.einsum('i,im,m,mj->ij', gamma2, w02, gamma1, w01)
#         rmse =  torch.sqrt(torch.mean((wt-wproxy)**2)).detach()
#         #print("exact: " + str(rmse))
#         if torch.isnan(rmse):
#             rmse = torch.sqrt(torch.mean(wt**2)).detach()
#         #Alternative 2: projections
#         gamma1_2 = fc_1to2layer_exact_project_without_gamma2(wt, w01, torch.einsum('i,im->im',gamma2,w02), device)
#         gamma1_2, bias1, gamma2_2, bias2 = fc_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, gamma1_2, gamma2, epochs, device)
#         wproxy2 = torch.einsum('i,im,m,mj->ij', gamma2_2, w02, gamma1_2, w01)
#         rmse2 =  torch.sqrt(torch.mean((wt-wproxy2)**2)).detach()
#         if torch.isnan(rmse2):
#             rmse2 = torch.sqrt(torch.mean(wt**2)).detach()
#         if rmse2 < rmse:
#             rmse = rmse2
#             wproxy = wproxy2
#         gamma1_2 = fc_1to2layer_exact_project_without_gamma2(wt, w01, w02, device)
#         gamma2_2 = fc_1to2layer_project_gamma2(wt, w01, w02, gamma1_2, device)
#         wproxy2 = torch.einsum('i,im,m,mj->ij', gamma2_2, w02, gamma1_2, w01)
#         rmse2 =  torch.sqrt(torch.mean((wt-wproxy2)**2)).detach()
#         if torch.isnan(rmse2):
#             rmse2 = torch.sqrt(torch.mean(wt**2)).detach()
#         if rmse2 < rmse:
#             rmse = rmse2
#             wproxy = wproxy2
#     return rmse, wproxy, [w02, w01, gamma2, gamma1]


def proxy_target(wt, seed, epochs, m, device):
    """Choose between the exact and the lower-memory proxy construction.

    The exact variant is used only when ``wt`` has at most 80000 entries AND
    the requested width ``m`` is at least half of ``full_dim(wt)``; every
    other case is handled by ``proxy_target_lower_mem``.
    """
    is_small = torch.numel(wt) <= 80000
    # `and` short-circuits: full_dim is only evaluated for small targets.
    if is_small and m/full_dim(wt) >= 0.5:
        return proxy_target_init_exact(wt, seed, epochs, m, device)
    return proxy_target_lower_mem(wt, seed, epochs, m, device)

def proxy_target_permute(nper, wt, seed, epochs, m, init, device):
    """Fit a proxy for ``wt`` and try ``nper`` random output permutations.

    For ``init`` in {"He", "ortho", "ER"}:
      1. draw ``w01``/``w02`` with the matching ``init_*`` helper;
      2. reorder the rows of ``w02`` by a linear assignment between the rows
         of ``wt`` and the rows of ``w02 @ w01`` (Euclidean distance);
      3. fit the proxy, then refit after ``nper`` successive random row
         permutations of ``w02``, keeping the fit with the lowest RMSE.
    Any other ``init`` is delegated to ``proxy_target_ER_equiv_permute``.

    Returns:
        ``(rmse, wproxy, x)`` of the best fit.  ``x`` is the parameter list
        of that same fit.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    mfull = full_dim(wt)
    if init not in ("He", "ortho", "ER"):
        return proxy_target_ER_equiv_permute(wt, seed, epochs, m, device)
    initializers = {"He": init_He, "ortho": init_ortho, "ER": init_ER}
    w01, w02 = initializers[init](wt.size(), min(m, mfull), seed)

    if len(wt.size()) > 2:
        proxy = torch.einsum('ikmn,kjlp->ijmn', w02, w01)
        dist_mat_rows = torch.cdist(wt.reshape(wt.size(0), -1), proxy.reshape(wt.size(0), -1), p=2)
    else:
        dist_mat_rows = torch.cdist(wt, torch.einsum('ik,kj->ij', w02, w01), p=2)
    _, col_ind = linear_sum_assignment(dist_mat_rows.cpu().detach().numpy())
    w02 = w02[col_ind]

    rmse, wproxy, x = proxy_target_init_flexible_lower_mem_permute(wt, w01, w02, epochs, m, device)
    for i in range(nper):
        torch.manual_seed(seed+i)
        ind = torch.randperm(wt.size(0))
        w02 = w02[ind]
        rmse_new, wproxy_new, x_new = proxy_target_init_flexible_lower_mem_permute(wt, w01, w02, epochs, m, device)
        if rmse_new < rmse:
            rmse = rmse_new
            wproxy = wproxy_new.clone()
            # Return the parameters that produced the best proxy, not the
            # initial ones.
            x = x_new
    return rmse, wproxy, x


def proxy_target_accurate(wt, seed, epochs, m, init, device):
    """Fit a proxy for ``wt`` after matching output rows to the target.

    For ``init`` in {"He", "ortho", "ER"}, random factors are drawn with the
    matching ``init_*`` helper.  The rows of ``w02`` are then reordered by a
    linear assignment between the rows of ``wt`` and of ``w02 @ w01``, and
    ``proxy_target_init_flexible`` is run.  Any other ``init`` is delegated
    to ``proxy_target_ER_equiv_permute``.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    mfull = full_dim(wt)
    if init not in ("He", "ortho", "ER"):
        return proxy_target_ER_equiv_permute(wt, seed, epochs, m, device)
    initializers = {"He": init_He, "ortho": init_ortho, "ER": init_ER}
    w01, w02 = initializers[init](wt.size(), min(m, mfull), seed)
    try:
        if len(wt.size()) > 2:
            proxy = torch.einsum('ikmn,kjlp->ijmn', w02, w01)
            dist_mat_rows = torch.cdist(wt.reshape(wt.size(0), -1), proxy.reshape(wt.size(0), -1), p=2)
        else:
            dist_mat_rows = torch.cdist(wt, torch.einsum('ik,kj->ij', w02, w01), p=2)
        _, col_ind = linear_sum_assignment(dist_mat_rows.cpu().detach().numpy())
        w02 = w02[col_ind]
    except ValueError:
        # Best effort: linear_sum_assignment rejects some cost matrices (e.g.
        # containing NaN); then the unmatched order of w02 is used.
        pass
    return proxy_target_init_flexible(wt, w01, w02, epochs, m, device)

def init_He(dim, m, seed):
    """Draw He-scaled Gaussian factors for a two-layer version of ``dim``.

    ``dim`` is ``(n2, n1)`` (linear) or ``(n2, n1, k1, k2)`` (conv).

    Returns:
        ``(w01, w02)`` with shapes ``(m, n1[, 1, 1])`` and
        ``(n2, m[, k1, k2])``.  ``w01`` is divided by ``sqrt(n1)`` and ``w02``
        by ``sqrt(fan_in/2)``, where ``fan_in`` is ``m`` (linear) or
        ``m*k1*k2`` (conv).
    """
    random.seed(seed)
    torch.manual_seed(seed)
    if len(dim) <= 2:
        n2, n1 = dim
        first_shape, second_shape = [m, n1], [n2, m]
        fan_in = m
    else:
        n2, n1, k1, k2 = dim
        first_shape, second_shape = [m, n1, 1, 1], [n2, m, k1, k2]
        fan_in = m*k1*k2
    # Draw order (w01, then w02) fixes which values a given seed produces.
    first = torch.randn(first_shape)/math.sqrt(n1)
    second = torch.randn(second_shape)/math.sqrt(fan_in/2)
    return first, second

def init_uni(dim, m, seed):
    """Draw zero-mean uniform factors for a two-layer version of ``dim``.

    Entries are uniform on ``[-1, 1)`` and then divided by
    ``sqrt(n1/3)`` (``w01``) or ``sqrt(fan_in/6)`` (``w02``).  Here
    ``fan_in`` is ``m`` (linear) or ``m*k1*k2`` (conv).

    Returns:
        ``(w01, w02)`` with shapes ``(m, n1[, 1, 1])`` and
        ``(n2, m[, k1, k2])``.
    """
    random.seed(seed)
    torch.manual_seed(seed)

    def centred(shape):
        # uniform on [-1, 1)
        return 2*torch.rand(shape)-1

    if len(dim) > 2:
        n2, n1, k1, k2 = dim
        w01 = centred([m, n1, 1, 1])/math.sqrt(n1/3)
        w02 = centred([n2, m, k1, k2])/math.sqrt(m*k1*k2/6)
        return w01, w02
    n2, n1 = dim
    w01 = centred([m, n1])/math.sqrt(n1/3)
    w02 = centred([n2, m])/math.sqrt(m/6)
    return w01, w02

def init_uni_pos(dim, m, seed):
    """Draw non-negative uniform factors for a two-layer version of ``dim``.

    Entries are uniform on ``[0, 1)`` (so the mean is positive), divided by
    ``sqrt(n1/3)`` for ``w01`` and by ``sqrt(fan_in/6)`` for ``w02``.  Here
    ``fan_in`` is ``m`` (linear) or ``m*k1*k2`` (conv).

    Returns:
        ``(w01, w02)`` with shapes ``(m, n1[, 1, 1])`` and
        ``(n2, m[, k1, k2])``.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    if len(dim) <= 2:
        n2, n1 = dim
        sizes = ([m, n1], [n2, m])
        fan_in = m
    else:
        n2, n1, k1, k2 = dim
        sizes = ([m, n1, 1, 1], [n2, m, k1, k2])
        fan_in = m*k1*k2
    w01 = torch.rand(sizes[0])/math.sqrt(n1/3)
    w02 = torch.rand(sizes[1])/math.sqrt(fan_in/6)
    return w01, w02

def init_ortho(dim, m, seed):
    """Draw (semi-)orthogonal factors for a two-layer version of ``dim``.

    Each factor is ``U[:, :r] @ Vh[:r, :]`` from the SVD of a Gaussian
    matrix, with ``r`` the smaller matrix side.  The result has orthonormal
    columns (tall case) or orthonormal rows (wide case).

    Conv case ``dim = (n2, n1, k1, k2)``: ``w02`` is built as an
    ``(n2*k1*k2, m)`` matrix whose rows are flattened ``(n2, k1, k2)``
    indices.  It is reshaped to ``(n2, k1, k2, m)`` and then permuted to
    ``(n2, m, k1, k2)``, so the orthogonality holds along the hidden axis.

    Returns:
        ``(w01, w02)`` with shapes ``(m, n1[, 1, 1])`` and
        ``(n2, m[, k1, k2])``.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    if len(dim) > 2:
        n2, n1, k1, k2 = dim
        rank = min(m, n2*k1*k2)
        svd = torch.linalg.svd(torch.randn(n2*k1*k2, m))
        w02 = svd[0][:, :rank] @ svd[2][:rank, :]
        # A direct reshape to (n2, m, k1, k2) would mix the (n2, k1, k2) row
        # index with the m column index and destroy the orthogonal structure.
        w02 = w02.reshape((n2, k1, k2, m)).permute(0, 3, 1, 2).contiguous()
        svd = torch.linalg.svd(torch.randn(m, n1))
        rank = min(m, n1)
        w01 = (svd[0][:, :rank] @ svd[2][:rank, :]).reshape((m, n1, 1, 1))
    else:
        n2, n1 = dim
        rank = min(m, n2)
        svd = torch.linalg.svd(torch.randn(n2, m))
        w02 = svd[0][:, :rank] @ svd[2][:rank, :]
        svd = torch.linalg.svd(torch.randn(m, n1))
        rank = min(m, n1)
        w01 = svd[0][:, :rank] @ svd[2][:rank, :]
    return w01, w02


def th_delete(tensor, indices):
    """Return ``tensor`` without the entries at positions ``indices``.

    The boolean mask has ``tensor.numel()`` entries and is applied as
    ``tensor[mask]``, so this is meant for 1-D tensors.
    """
    keep = torch.full((tensor.numel(),), True, dtype=torch.bool)
    keep[indices] = False
    return tensor[keep]

def init_precond(dim, m, seed):
    """Build a structured two-layer factorisation for a target of shape ``dim``.

    Linear case ``dim = (n2, n1)``: random orthogonal ``U`` (n1 x n1) and
    ``U2`` (n2 x n2) are drawn.  In the full network, hidden unit
    ``i*n2 + j`` reads row ``U[i, :]`` and writes column ``U2[:, j]``.
      * ``m < n1*n2``: one random input block ``input_full`` becomes a pivot.
        Its first unit writes ``U2.sum(dim=1)``, and up to ``n2-1`` of its
        other units are removed.  The remaining removals are drawn at random
        from the other blocks, so ``m`` units are left (for ``m >= 1``).
      * ``m > n1*n2``: output columns keep cycling through ``U2``, and the
        surplus units read randomly chosen rows of ``U``.

    Conv case ``dim = (n2, n1, k1, k2)``: (semi-)orthogonal factors, built as
    in ``init_ortho``.  ``w02`` comes from an ``(n2*k1*k2, m)`` matrix laid
    out as ``(n2, m, k1, k2)``.

    Returns:
        ``(w01, w02)``.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    if len(dim) > 2:
        n2, n1, k1, k2 = dim
        rank = min(m, n2*k1*k2)
        svd = torch.linalg.svd(torch.randn(n2*k1*k2, m))
        w02 = svd[0][:, :rank] @ svd[2][:rank, :]
        # Rows are flattened (n2, k1, k2) indices: restore that layout before
        # moving the hidden axis to position 1 (a direct reshape scrambles it).
        w02 = w02.reshape((n2, k1, k2, m)).permute(0, 3, 1, 2).contiguous()
        svd = torch.linalg.svd(torch.randn(m, n1))
        rank = min(m, n1)
        w01 = (svd[0][:, :rank] @ svd[2][:rank, :]).reshape((m, n1, 1, 1))
        return w01, w02

    n2, n1 = dim
    svd = torch.linalg.svd(torch.randn(n1, n1))
    U = svd[0] @ svd[2]
    svd = torch.linalg.svd(torch.randn(n2, n2))
    U2 = svd[0] @ svd[2]
    remove = n2*n1-m
    if remove >= 0:
        # Full network: every (input direction, output direction) pair.
        w02 = U2.repeat(1, n1)
        w01 = torch.zeros((n1*n2, n1))
        for i in range(n1):
            w01[(i*n2):((i+1)*n2), :] = U[i, :].repeat(n2, 1)
        if remove > 0:
            input_full = torch.randint(0, n1, (1,)).item()
            # Pivot unit writes the sum of all output directions.
            w02[:, input_full*n2] = torch.sum(U2, dim=1)
            ind_rem = torch.arange(input_full*n2+1, input_full*n2+1+min(remove, n2-1))
            if remove > n2-1:
                # Remaining removals come at random from the other blocks.
                rem = torch.cat([torch.arange(0, input_full*n2), torch.arange((input_full+1)*n2, n1*n2)])
                rem = rem[torch.randperm(rem.size(0))]
                rem = rem[:(remove-n2+1)]
                ind_rem = torch.cat([ind_rem, rem])
            keep = ~torch.isin(torch.arange(0, n1*n2), ind_rem)
            w02 = w02[:, keep]
            w01 = w01[keep, :]
    else:
        k = math.ceil(m/n2)
        w02 = U2.repeat(1, k)
        w02 = w02[:, :m]
        w01 = torch.zeros((m, n1))
        for i in range(n1):
            w01[(i*n2):((i+1)*n2), :] = U[i, :].repeat(n2, 1)
        if m > n1*n2:
            subset = torch.randint(0, n1, (m-n1*n2,))
            w01[(n1*n2):, :] = U[subset, :]
    return w01, w02

def init_precond_umiss(dim, m, seed):
    """Like the structured precond construction, but also return removed units.

    Linear case ``dim = (n2, n1)``: same construction as ``init_precond``
    (orthogonal ``U``/``U2``, pivot block, random removals).  When more than
    ``n2-1`` units are removed, ``Umiss`` holds, per randomly removed unit
    ``r``, the flattened rank-one matrix ``outer(w02[:, r], w01[r, :])``.
    It has shape ``(n1*n2, #removed-at-random)``.  Otherwise ``Umiss`` is
    ``torch.tensor([0])``.

    Conv case: (semi-)orthogonal factors with ``w02`` laid out as
    ``(n2, m, k1, k2)``, and ``Umiss = torch.tensor([0])``.

    Returns:
        ``(w01, w02, Umiss)``.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    Umiss = torch.tensor([0])
    if len(dim) > 2:
        n2, n1, k1, k2 = dim
        rank = min(m, n2*k1*k2)
        svd = torch.linalg.svd(torch.randn(n2*k1*k2, m))
        w02 = svd[0][:, :rank] @ svd[2][:rank, :]
        # Rows are flattened (n2, k1, k2) indices: restore that layout before
        # moving the hidden axis to position 1 (a direct reshape scrambles it).
        w02 = w02.reshape((n2, k1, k2, m)).permute(0, 3, 1, 2).contiguous()
        svd = torch.linalg.svd(torch.randn(m, n1))
        rank = min(m, n1)
        w01 = (svd[0][:, :rank] @ svd[2][:rank, :]).reshape((m, n1, 1, 1))
        return w01, w02, Umiss

    n2, n1 = dim
    svd = torch.linalg.svd(torch.randn(n1, n1))
    U = svd[0] @ svd[2]
    svd = torch.linalg.svd(torch.randn(n2, n2))
    U2 = svd[0] @ svd[2]
    remove = n2*n1-m
    if remove >= 0:
        # Full network: hidden unit i*n2 + j maps U[i, :] -> U2[:, j].
        w02 = U2.repeat(1, n1)
        w01 = torch.zeros((n1*n2, n1))
        for i in range(n1):
            w01[(i*n2):((i+1)*n2), :] = U[i, :].repeat(n2, 1)
        if remove > 0:
            input_full = torch.randint(0, n1, (1,)).item()
            # Pivot unit writes the sum of all output directions.
            w02[:, input_full*n2] = torch.sum(U2, dim=1)
            ind_rem = torch.arange(input_full*n2+1, input_full*n2+1+min(remove, n2-1))
            if remove > n2-1:
                rem = torch.cat([torch.arange(0, input_full*n2), torch.arange((input_full+1)*n2, n1*n2)])
                rem = rem[torch.randperm(rem.size(0))]
                rem = rem[:(remove-n2+1)]
                ind_rem = torch.cat([ind_rem, rem])
                # Rank-one contributions of the randomly removed units.
                Umiss = torch.einsum('ik,kj->ijk', w02[:, rem], w01[rem, :])
                Umiss = Umiss.reshape((n1*n2, len(rem)))
            keep = ~torch.isin(torch.arange(0, n1*n2), ind_rem)
            w02 = w02[:, keep]
            w01 = w01[keep, :]
    else:
        k = math.ceil(m/n2)
        w02 = U2.repeat(1, k)
        w02 = w02[:, :m]
        w01 = torch.zeros((m, n1))
        for i in range(n1):
            w01[(i*n2):((i+1)*n2), :] = U[i, :].repeat(n2, 1)
        if m > n1*n2:
            subset = torch.randint(0, n1, (m-n1*n2,))
            w01[(n1*n2):, :] = U[subset, :]
    return w01, w02, Umiss

def init_id(dim, m, seed):
    """Structured two-layer factorisation built from identity bases.

    Linear case ``dim = (n2, n1)``: with ``U = I_n1`` and ``U2 = I_n2``, the
    full network has ``n1*n2`` hidden units, and unit ``i*n2 + j`` connects
    input ``i`` to output ``j`` (so ``w02 @ w01`` is all ones).
      * ``m < n1*n2``: a random input block becomes a pivot whose first unit
        writes to all outputs, and up to ``n2-1`` of its other units are
        removed.  Further removals are random among the other blocks.
      * ``m > n1*n2``: output columns keep cycling, and the surplus units
        read random inputs.

    Conv case ``dim = (n2, n1, k1, k2)``: (semi-)orthogonal random factors,
    with ``w02`` laid out as ``(n2, m, k1, k2)``.

    Returns:
        ``(w01, w02)``.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    if len(dim) > 2:
        n2, n1, k1, k2 = dim
        rank = min(m, n2*k1*k2)
        svd = torch.linalg.svd(torch.randn(n2*k1*k2, m))
        w02 = svd[0][:, :rank] @ svd[2][:rank, :]
        # Rows are flattened (n2, k1, k2) indices: restore that layout before
        # moving the hidden axis to position 1 (a direct reshape scrambles it).
        w02 = w02.reshape((n2, k1, k2, m)).permute(0, 3, 1, 2).contiguous()
        svd = torch.linalg.svd(torch.randn(m, n1))
        rank = min(m, n1)
        w01 = (svd[0][:, :rank] @ svd[2][:rank, :]).reshape((m, n1, 1, 1))
        return w01, w02

    n2, n1 = dim
    U = torch.eye(n1)
    U2 = torch.eye(n2)
    remove = n2*n1-m
    if remove >= 0:
        w02 = U2.repeat(1, n1)
        w01 = torch.zeros((n1*n2, n1))
        for i in range(n1):
            w01[(i*n2):((i+1)*n2), :] = U[i, :].repeat(n2, 1)
        if remove > 0:
            input_full = torch.randint(0, n1, (1,)).item()
            # Pivot unit writes to every output.
            w02[:, input_full*n2] = torch.sum(U2, dim=1)
            ind_rem = torch.arange(input_full*n2+1, input_full*n2+1+min(remove, n2-1))
            if remove > n2-1:
                rem = torch.cat([torch.arange(0, input_full*n2), torch.arange((input_full+1)*n2, n1*n2)])
                rem = rem[torch.randperm(rem.size(0))]
                rem = rem[:(remove-n2+1)]
                ind_rem = torch.cat([ind_rem, rem])
            keep = ~torch.isin(torch.arange(0, n1*n2), ind_rem)
            w02 = w02[:, keep]
            w01 = w01[keep, :]
    else:
        k = math.ceil(m/n2)
        w02 = U2.repeat(1, k)
        w02 = w02[:, :m]
        w01 = torch.zeros((m, n1))
        for i in range(n1):
            w01[(i*n2):((i+1)*n2), :] = U[i, :].repeat(n2, 1)
        if m > n1*n2:
            subset = torch.randint(0, n1, (m-n1*n2,))
            w01[(n1*n2):, :] = U[subset, :]
    return w01, w02

def init_ONB(dim, m, seed):
    """Draw factors with orthonormal rows/columns for a two-layer ``dim``.

    Each factor is ``U[:, :r] @ Vh[:r, :]`` from the SVD of a Gaussian
    matrix, with ``r`` the smaller matrix side.  In the conv case, ``w02`` is
    an ``(n2*k1*k2, m)`` matrix whose rows are flattened ``(n2, k1, k2)``
    indices; it is laid out as ``(n2, m, k1, k2)`` while keeping the
    orthonormality along ``m``.

    Returns:
        ``(w01, w02)`` with shapes ``(m, n1[, 1, 1])`` and
        ``(n2, m[, k1, k2])``.
    """
    random.seed(seed)
    torch.manual_seed(seed)

    def truncated_polar(mat, rank):
        u, _, vh = torch.linalg.svd(mat)
        return u[:, :rank] @ vh[:rank, :]

    if len(dim) > 2:
        n2, n1, k1, k2 = dim
        w02 = truncated_polar(torch.randn(n2*k1*k2, m), min(m, n2*k1*k2))
        # Restore the (n2, k1, k2) row layout before moving m to axis 1; a
        # direct reshape to (n2, m, k1, k2) scrambles rows and columns.
        w02 = w02.reshape((n2, k1, k2, m)).permute(0, 3, 1, 2).contiguous()
        w01 = truncated_polar(torch.randn(m, n1), min(m, n1)).reshape((m, n1, 1, 1))
    else:
        n2, n1 = dim
        w02 = truncated_polar(torch.randn(n2, m), min(m, n2))
        w01 = truncated_polar(torch.randn(m, n1), min(m, n1))
    return w01, w02

def init_ER(dim, m, seed):
    """Draw i.i.d. 0/1 masks (Bernoulli(p)) as two-layer factors of ``dim``.

    ``p`` is 0.8 when ``m < 1000`` and 0.5 otherwise.  ``w02`` is drawn
    before ``w01``.

    Returns:
        ``(w01, w02)`` with shapes ``(m, n1[, 1, 1])`` and
        ``(n2, m[, k1, k2])``; entries are 1.0 or 0.0.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    p = 0.8 if m < 1000 else 0.5

    def bernoulli(shape):
        return torch.where(torch.rand(shape) <= p, 1.0, 0.0)

    if len(dim) <= 2:
        n2, n1 = dim
        w02 = bernoulli((n2, m))
        w01 = bernoulli((m, n1))
    else:
        n2, n1, k1, k2 = dim
        w02 = bernoulli((n2, m, k1, k2))
        w01 = bernoulli((m, n1, 1, 1))
    return w01, w02

def init_ER_norm(dim, m, seed):
    """He-initialised factors (from ``init_He``) with a random 0/1 mask.

    Each weight is kept with probability ``p`` (0.8 if ``m < 1000``, else
    0.5) and zeroed otherwise.  The mask for ``w02`` is drawn before the one
    for ``w01``.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    keep_prob = 0.8 if m < 1000 else 0.5
    w01, w02 = init_He(dim, m, seed)
    if len(dim) <= 2:
        n2, n1 = dim
        shape02, shape01 = (n2, m), (m, n1)
    else:
        n2, n1, k1, k2 = dim
        shape02, shape01 = (n2, m, k1, k2), (m, n1, 1, 1)
    w02 = w02*torch.where(torch.rand(shape02) <= keep_prob, 1.0, 0.0)
    w01 = w01*torch.where(torch.rand(shape01) <= keep_prob, 1.0, 0.0)
    return w01, w02

def init_ER_sign(dim, m, seed):
    """Sign factors (from ``init_sign``) with a random 0/1 mask.

    Each weight is kept with probability ``p`` (0.8 if ``m < 1000``, else
    0.5) and zeroed otherwise.  The mask for ``w02`` is drawn before the one
    for ``w01``.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    keep_prob = 0.8 if m < 1000 else 0.5
    w01, w02 = init_sign(dim, m, seed)
    if len(dim) <= 2:
        n2, n1 = dim
        shape02, shape01 = (n2, m), (m, n1)
    else:
        n2, n1, k1, k2 = dim
        shape02, shape01 = (n2, m, k1, k2), (m, n1, 1, 1)
    w02 = w02*torch.where(torch.rand(shape02) <= keep_prob, 1.0, 0.0)
    w01 = w01*torch.where(torch.rand(shape01) <= keep_prob, 1.0, 0.0)
    return w01, w02

def init_sign(dim, m, seed):
    """Draw random +1/-1 factors for a two-layer version of ``dim``.

    Each entry is +1.0 or -1.0 with probability 1/2.

    Returns:
        ``(w01, w02)`` with shapes ``(m, n1[, 1, 1])`` and
        ``(n2, m[, k1, k2])``.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    if len(dim) > 2:
        n2, n1, k1, k2 = dim
        shape01, shape02 = [m, n1, 1, 1], [n2, m, k1, k2]
    else:
        n2, n1 = dim
        shape01, shape02 = [m, n1], [n2, m]
    # torch.rand is uniform on [0, 1): the former threshold 0 made virtually
    # every entry +1; thresholding at 0.5 gives fair random signs.
    w01 = torch.where(torch.rand(shape01) > 0.5, 1.0, -1.0)
    w02 = torch.where(torch.rand(shape02) > 0.5, 1.0, -1.0)
    return w01, w02

def init_aij(dim, m, seed, p=None, device=None):
    """Draw random 0/1 adjacency factors for a two-layer version of ``dim``.

    TODO: a randomisation equivalent to training with other parameters would
    not need an explicit construction of w01 and w02.

    The body previously referenced ``p`` and ``device`` without defining
    them, so calls failed with NameError.  Both are now optional keywords:
    ``p`` defaults to the rule used by ``init_ER`` (0.8 if ``m < 1000``,
    else 0.5), and ``device=None`` leaves the tensors where ``torch.rand``
    created them.  NOTE(review): confirm these defaults fit the intended
    "aij" scheme.

    Returns:
        ``(w01, w02)`` with 0.0/1.0 entries, shapes ``(m, n1[, 1, 1])`` and
        ``(n2, m[, k1, k2])``.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    if p is None:
        p = 0.8 if m < 1000 else 0.5
    if len(dim) > 2:
        n2, n1, k1, k2 = dim
        w02 = torch.where(torch.rand((n2, m, k1, k2)) <= p, 1.0, 0.0)
        w01 = torch.where(torch.rand((m, n1, 1, 1)) <= p, 1.0, 0.0)
    else:
        n2, n1 = dim
        w02 = torch.where(torch.rand((n2, m)) <= p, 1.0, 0.0)
        w01 = torch.where(torch.rand((m, n1)) <= p, 1.0, 0.0)
    if device is not None:
        w02 = w02.to(device)
        w01 = w01.to(device)
    return w01, w02

def proxy_target_init_flexible_lower_mem_permute(wt, w01, w02, epochs, m, device):
    """Fit per-neuron scalings for given (possibly permuted) ``w01``/``w02``.

    ``w01`` and ``w02`` are used as passed (no truncation to ``m``).  The
    scalings ``gamma1``/``gamma2`` are initialised by projection and refined
    with the matching ``*_learn_GD_LBFGS`` helper.  If the refined proxy has
    a NaN RMSE, the refinement is restarted from all-ones scalings.

    Args:
        wt: target weights, 2-D ``(n2, n1)`` or 4-D ``(n2, n1, k1, k2)``.
            NaN entries are set to 0 in place (the caller's tensor changes).
        w01, w02: first/second-layer weights.
        epochs: forwarded to the LBFGS refinement.
        m: hidden width used for the all-ones restart, capped at ``mfull``.
        device: forwarded to the helpers; also used for the target bias.

    Returns:
        ``(rmse, wproxy, [w02, w01, gamma2, gamma1])``.
    """
    random.seed(1)
    wt[torch.isnan(wt)] = 0
    act_scale = 1
    act_shift = 5
    if len(wt.size()) > 2:
        n2, n1, k1, k2 = wt.size()
        mfull = n2*(n1*k1*k2-1)+1
        m = min(m, mfull)
        torch.manual_seed(1)
        bt = torch.ones(n2, device=device)
        gamma1 = conv_1to2layer_project_gamma1(wt, w01, w02, torch.ones(wt.size(0)), device)
        gamma2 = conv_1to2layer_project_gamma2(wt, w01, w02, gamma1, device)
        # Keep gamma1 and gamma2 from the same refinement; the refined gamma2
        # used to be discarded, pairing refined gamma1 with the projection.
        gamma1, bias1, gamma2, bias2 = conv_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, gamma1, gamma2, epochs, device)
        wproxy = torch.einsum('i,ikmn,k,kjlb->ijmn', gamma2, w02, gamma1, w01)
        rmse = torch.sqrt(torch.mean((wt-wproxy)**2)).detach()
        if torch.isnan(rmse):
            # Projection start diverged: retry from unit scalings.
            gamma1, bias1, gamma2, bias2 = conv_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, torch.ones(m), torch.ones(n2), epochs, device)
            wproxy = torch.einsum('i,ikmn,k,kjlb->ijmn', gamma2, w02, gamma1, w01)
            rmse = torch.sqrt(torch.mean((wt-wproxy)**2)).detach()
    else:
        n2, n1 = wt.size()
        mfull = n2*(n1-1)+1
        m = min(m, mfull)
        torch.manual_seed(1)
        bt = torch.ones(n2, device=device)
        gamma1 = fc_1to2layer_exact_project_without_gamma2(wt, w01, w02, device)
        gamma2 = fc_1to2layer_project_gamma2(wt, w01, w02, gamma1, device)
        gamma1, bias1, gamma2, bias2 = fc_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, gamma1, gamma2, epochs, device)
        wproxy = torch.einsum('i,im,m,mj->ij', gamma2, w02, gamma1, w01)
        rmse = torch.sqrt(torch.mean((wt-wproxy)**2)).detach()
        if torch.isnan(rmse):
            # Projection start diverged: retry from unit scalings.
            gamma1, bias1, gamma2, bias2 = fc_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, torch.ones(m), torch.ones(n2), epochs, device)
            wproxy = torch.einsum('i,im,m,mj->ij', gamma2, w02, gamma1, w01)
            rmse = torch.sqrt(torch.mean((wt-wproxy)**2)).detach()
    return rmse, wproxy, [w02, w01, gamma2, gamma1]

def proxy_target_ortho(wt, seed, epochs, m, device):
    """Build random (semi-)orthogonal factor weights and fit them to ``wt``.

    Each factor is obtained from a Gaussian matrix G via its SVD as
    U[:, :r] @ Vh[:r, :] with r = min(rows, cols), i.e. the nearest matrix
    with orthonormal rows/columns. The fit itself is delegated to
    ``proxy_target_init_flexible`` (or its lower-memory variant for targets
    with more than 80000 entries).

    Inputs:
        wt: target weights, (n2, n1) for fc or (n2, n1, k1, k2) for conv.
        seed: seed for ``random`` and ``torch``.
        epochs: passed through to the fitting routine.
        m: requested hidden width; capped at n2*(fan_in-1)+1.
        device: passed through to the fitting routine.
    Returns:
        Whatever the selected ``proxy_target_init_flexible*`` routine returns.
    """
    dd = torch.numel(wt)
    random.seed(seed)
    torch.manual_seed(seed)
    if len(wt.size()) > 2:
        n2, n1, k1, k2 = wt.size()
        mfull = n2*(n1*k1*k2-1)+1
        m = min(m, mfull)
        rank = min(m, n2*k1*k2)
        gaus = torch.randn(n2*k1*k2, m)
        svd = torch.linalg.svd(gaus)
        # The full SVD gives U (P x P) and Vh (m x m); they are only
        # multiplicable after truncation to the shared rank (unless P == m).
        w02 = svd[0][:, :rank] @ svd[2][:rank, :]
        w02 = w02.reshape((n2, m, k1, k2))
        gaus = torch.randn(m, n1)
        svd = torch.linalg.svd(gaus)
        rank = min(m, n1)
        w01 = svd[0][:, :rank] @ svd[2][:rank, :]
        w01 = w01.reshape((m, n1, 1, 1))
    else:
        n2, n1 = wt.size()
        mfull = n2*(n1-1)+1
        m = min(m, mfull)
        rank = min(m, n2)
        gaus = torch.randn(n2, m)
        svd = torch.linalg.svd(gaus)
        w02 = svd[0][:, :rank] @ svd[2][:rank, :]
        gaus = torch.randn(m, n1)
        svd = torch.linalg.svd(gaus)
        rank = min(m, n1)
        w01 = svd[0][:, :rank] @ svd[2][:rank, :]
    # the lower-memory variant is used for large targets
    if dd <= 80000:
        return proxy_target_init_flexible(wt, w01, w02, epochs, m, device)
    else:
        return proxy_target_init_flexible_lower_mem(wt, w01, w02, epochs, m, device)

def proxy_target_ER(wt, seed, epochs, m, p, device):
    """Fit ``wt`` starting from Erdos-Renyi 0/1 factor weights.

    Every entry of the two factor tensors is 1 with probability ``p`` and 0
    otherwise; the fit is delegated to ``proxy_target_init_flexible`` (or the
    lower-memory variant when ``wt`` has more than 80000 entries).

    Inputs:
        wt: target weights, (n2, n1) for fc or (n2, n1, k1, k2) for conv.
        seed: seed for ``random`` and ``torch``.
        epochs: passed through to the fitting routine.
        m: requested hidden width; capped at n2*(fan_in-1)+1.
        p: edge probability.
        device: device the factor tensors are moved to.
    """
    numel = torch.numel(wt)
    random.seed(seed)
    torch.manual_seed(seed)
    if len(wt.size()) > 2:
        n2, n1, k1, k2 = wt.size()
        m = min(m, n2*(n1*k1*k2-1)+1)
        shape2, shape1 = (n2, m, k1, k2), (m, n1, 1, 1)
    else:
        n2, n1 = wt.size()
        m = min(m, n2*(n1-1)+1)
        shape2, shape1 = (n2, m), (m, n1)
    # w02 is drawn before w01 so the RNG stream is consumed in a fixed order
    w02 = torch.where(torch.rand(shape2) <= p, 1.0, 0.0).to(device)
    w01 = torch.where(torch.rand(shape1) <= p, 1.0, 0.0).to(device)
    fit = proxy_target_init_flexible if numel <= 80000 else proxy_target_init_flexible_lower_mem
    return fit(wt, w01, w02, epochs, m, device)

def proxy_target_aij(wt, seed, epochs, m, device):
    """Proxy that keeps ``m`` randomly chosen target entries (one hidden neuron each).

    Each sampled entry (i, j[, a, b]) of ``wt`` gets its own hidden neuron that
    reads input j and writes output i (at kernel position (a, b) for conv).
    Additionally, n2 random extra links are added, one per output neuron; the
    mask is updated to include the target entries they reach.

    Inputs:
        wt: target weights, (n2, n1) for fc or (n2, n1, k1, k2) for conv.
        seed: seed for ``random`` and ``torch``.
        epochs: unused (kept for interface compatibility).
        m: number of sampled entries; capped at n2*(fan_in-1)+1.
        device: unused here.
    Returns:
        (rmse, wproxy, [w02, w01, gamma2, gamma1]) where ``wproxy = mask*wt``,
        ``gamma2`` is all ones and ``gamma1[j]`` is the target value of the
        entry assigned to hidden neuron j (extra links are not accounted for).
    """
    dd = torch.numel(wt)
    random.seed(seed)
    torch.manual_seed(seed)
    if len(wt.size()) > 2:
        n2, n1, k1, k2 = wt.size()
        mfull = n2*(n1*k1*k2-1)+1
        m = min(m, mfull)
        # fill random target links in
        indflat = torch.randperm(dd)[:m]
        x = torch.zeros(dd)
        x[indflat] = 1.0
        mask = x.reshape(wt.size())
        ind = torch.where(mask > 0.5)
        del x
        w02 = torch.zeros((n2, m, k1, k2))
        w01 = torch.zeros((m, n1, 1, 1))
        w02[ind[0], torch.arange(m), ind[2], ind[3]] = 1
        # hidden neuron j reads only its own input channel ind[1][j]
        # (indexing with ':' here would connect every neuron to every sampled channel)
        w01[torch.arange(m), ind[1], 0, 0] = 1
        # add n2 free connections to ensure that we can use gamma2
        # each edge leads to exactly one of the output neurons
        indn2 = torch.randint(0, m, (n2,))
        ind1 = torch.randint(0, k1, (n2,))
        ind2 = torch.randint(0, k2, (n2,))
        w02[torch.arange(n2), indn2, ind1, ind2] = 1
        mask[torch.arange(n2), ind[1][indn2], ind1, ind2] = 1
        del ind, ind1, ind2, indn2
    else:
        n2, n1 = wt.size()
        mfull = n2*(n1-1)+1
        m = min(m, mfull)
        indflat = torch.randperm(dd)[:m]
        x = torch.zeros(dd)
        x[indflat] = 1.0
        mask = x.reshape(wt.size())
        ind = torch.where(mask > 0.5)
        del x
        w02 = torch.zeros((n2, m))
        w01 = torch.zeros((m, n1))
        w02[ind[0], torch.arange(m)] = 1
        w01[torch.arange(m), ind[1]] = 1
        # add additional n2 random connections, each leading to exactly one output neuron
        indn2 = torch.randint(0, m, (n2,))
        w02[torch.arange(n2), indn2] = 1
        mask[torch.arange(n2), ind[1][indn2]] = 1
        del ind
    wproxy = mask*wt
    rmse = torch.sqrt(torch.mean((wt-wproxy)**2)).detach()
    gamma2 = torch.ones(n2)  # actually more complicated -> would need to resolve potential loops etc.
    # torch.where enumerates the sampled entries in row-major order (sorted flat
    # index), which is the order hidden neurons were assigned in; gamma1 must
    # follow that order, not the random order of indflat.
    gamma1 = wt.flatten()[torch.sort(indflat).values]
    return rmse, wproxy, [w02, w01, gamma2, gamma1]

def proxy_target_ER_equiv(wt, seed, epochs, m, device):
    """Random-support proxy: keep ``m + n2`` uniformly sampled entries of ``wt``.

    Inputs:
        wt: target weight tensor (first dim is the output dim n2).
        seed: seed for ``random`` and ``torch``.
        epochs, device: unused (kept for interface compatibility).
        m: number of kept entries beyond the n2 extra ones.
    Returns:
        (rmse, wproxy, [1, 1, gamma2, gamma1]) with ``gamma2`` all ones and
        ``gamma1`` the nonzero entries of ``wproxy`` in row-major order.
    """
    total = torch.numel(wt)
    random.seed(seed)
    torch.manual_seed(seed)
    keep = torch.randperm(total)[:(m + wt.size(0))]
    flat_mask = torch.zeros(total)
    flat_mask[keep] = 1.0
    wproxy = wt * flat_mask.reshape(wt.size())
    rmse = torch.sqrt(torch.mean((wt - wproxy)**2)).detach()
    values = wproxy.flatten()
    gamma1 = values[torch.abs(values) > 0]
    gamma2 = torch.ones(wt.size(0))
    return rmse, wproxy, [1, 1, gamma2, gamma1]

def proxy_target_ER_equiv_permute(wt, seed, epochs, m, device):
    """Random-support proxy whose mask rows are matched to the large entries of ``wt``.

    A random 0/1 mask with ``m + n2`` ones is drawn. Its output rows are then
    permuted by a linear-sum assignment against a magnitude mask of ``wt``
    (entries above a density-dependent threshold), minimising row-wise
    Euclidean distance. The proxy is ``wt`` restricted to the permuted mask.

    Inputs:
        wt: target weights (first dim is the output dim n2; extra dims are flattened for matching).
        seed: seed for ``random`` and ``torch``.
        epochs, device: unused (kept for interface compatibility).
        m: number of kept entries beyond n2; capped at numel(wt) - n2.
    Returns:
        (rmse, wproxy, [1, 1, gamma2, gamma1]) with ``gamma2`` all ones and
        ``gamma1`` the nonzero entries of ``wproxy``.
    """
    total = torch.numel(wt)
    n_out = wt.size(0)
    m = min(m, total - n_out)
    random.seed(seed)
    torch.manual_seed(seed)
    n_keep = m + n_out
    density = n_keep / total
    chosen = torch.randperm(total)[:n_keep]
    flat = torch.zeros(total)
    flat[chosen] = 1.0
    mask = flat.reshape(wt.size())
    magnitudes, _ = torch.sort(torch.abs(wt).flatten(), descending=True)
    cut = magnitudes[min(int(total*2*density/(1+density)), total-1)]
    target_mask = torch.where(torch.abs(wt) > cut, 1.0, 0.0)
    # reshape is a no-op for 2-D weights and flattens kernels for conv weights
    cost = torch.cdist(target_mask.reshape(n_out, -1), mask.reshape(n_out, -1), p=2)
    _, assignment = linear_sum_assignment(cost.cpu().detach().numpy())
    mask = mask[assignment]
    wproxy = wt * mask
    rmse = torch.sqrt(torch.mean((wt - wproxy)**2)).detach()
    values = wproxy.flatten()
    gamma1 = values[torch.abs(values) > 0]
    gamma2 = torch.ones(n_out)
    return rmse, wproxy, [1, 1, gamma2, gamma1]

def eval_wproxy(wt, wproxy):
    """Score an explicitly given proxy ``wproxy`` against the target ``wt``.

    Returns:
        (rmse, wproxy, [1, 1, gamma2, gamma1]) where rmse is the root mean
        squared difference, ``gamma2`` is all ones (length wt.size(0)) and
        ``gamma1`` holds the nonzero entries of ``wproxy`` in row-major order.
    """
    nonzero = wproxy.flatten()
    nonzero = nonzero[torch.abs(nonzero) > 0]
    err = torch.sqrt(torch.mean((wt - wproxy)**2)).detach()
    return err, wproxy, [1, 1, torch.ones(wt.size(0)), nonzero]

def proxy_target_ER_equiv_rand_permute(wt, seed, epochs, m, nperm, device):
    """Random-support proxy, matched to ``wt`` and refined by random row permutations.

    A random mask with ``m + n2 - 1`` ones is drawn, its rows are matched to a
    magnitude mask of ``wt`` by linear-sum assignment, and then ``nperm - 1``
    further random row permutations of that mask are tried; the proxy with the
    smallest RMSE is kept.

    Inputs:
        wt: target weights (first dim is the output dim n2).
        seed: seed for ``random`` and ``torch``.
        epochs, device: unused (kept for interface compatibility).
        m: requested number of entries; capped at numel(wt) - n2.
        nperm: total number of candidate masks (1 = matched mask only).
    Returns:
        (rmse, wproxy, [w02, w01, gamma2, gamma1]) for the best candidate, as
        produced by ``eval_wproxy``.
    """
    dd = torch.numel(wt)
    n2 = wt.size(0)
    m = min(m, dd - n2)
    random.seed(seed)
    torch.manual_seed(seed)
    mm = m + n2 - 1
    p = mm/dd
    indflat = torch.randperm(dd)[:mm]
    x = torch.zeros(dd)
    x[indflat] = 1.0
    mask = x.reshape(wt.size())
    x, _ = torch.sort(torch.abs(wt).flatten(), descending=True)
    thr = x[min(int(dd*2*p/(1+p)), dd-1)]
    maskt = torch.where(torch.abs(wt) > thr, 1.0, 0.0)
    # reshape is a no-op for 2-D weights and flattens kernels for conv weights
    dist_mat_rows = torch.cdist(maskt.reshape(n2, -1), mask.reshape(n2, -1), p=2)
    _, col_ind = linear_sum_assignment(dist_mat_rows.cpu().detach().numpy())
    mask = mask[col_ind]
    rmse, wproxy, [w02, w01, gamma2, gamma1] = eval_wproxy(wt, wt*mask)
    for _ in range(1, nperm):
        ind = torch.randperm(n2)
        rmse_new, wproxy_new, [w02, w01, gamma2_new, gamma1_new] = eval_wproxy(wt, mask[ind]*wt)
        if rmse_new < rmse:
            rmse, wproxy = rmse_new, wproxy_new
            # keep gamma1 paired with the proxy it was computed from
            gamma2, gamma1 = gamma2_new, gamma1_new
    return rmse, wproxy, [w02, w01, gamma2, gamma1]

def proxy_target_id_ER_permute(wt, seed, epochs, m, nperm, device):
    """Proxy from ``init_id`` factors with output rows permuted by ``id_permute``.

    Inputs:
        wt: (n2, n1) target weights.
        seed: passed to ``init_id``.
        epochs, nperm, device: unused (kept for interface compatibility).
        m: hidden width passed to ``init_id``.
    Returns:
        (rmse, wproxy, [w02, w01, gamma2, gamma1]) where
        ``wproxy = (w02 @ w01) * wt``, rmse is a Python float and both gammas
        are all ones.
    """
    w01, w02 = init_id(wt.size(), m, seed)
    # id_permute returns (err, permuted_w02, col_ind); only the weights are needed here
    _, w02, _ = id_permute(w01, w02, wt)
    wproxy = torch.einsum('ik,kj->ij', w02, w01)*wt
    rmse = torch.sqrt(torch.mean((wt-wproxy)**2)).item()
    gamma2 = torch.ones(wt.size(0))
    gamma1 = torch.ones(m)
    return rmse, wproxy, [w02, w01, gamma2, gamma1]

def id_permute(w01, w02, wt):
    """Permute the rows of ``w02`` so that the mask ``w02 @ w01`` covers ``wt`` well.

    Two candidates are compared by the RMSE of ``wt - (w02_perm @ w01) * wt``:
      1. a greedy pairwise-swap search that reduces the squared mass of
         ``wt`` left uncovered by the mask, and
      2. the greedy result further permuted by a linear-sum assignment of the
         mask rows to a magnitude mask of ``wt``.
    The better candidate is returned. Both candidate errors are printed.

    Inputs:
        w01: (m, n1) first-layer weights.
        w02: (n2, m) second-layer weights.
        wt: (n2, n1) target weights.
    Returns:
        (err, w02_perm, col_ind): the RMSE of the chosen candidate, the
        permuted weights, and a numpy index array with
        ``w02_perm == w02[col_ind]``.
    """
    n2, n1, m = w02.size(0), w01.size(1), w01.size(0)
    # mask[r, j] = 1 - (w02 @ w01)[r, j]: target mass NOT covered when row r of the product is used
    mask = torch.ones((n2, n1))-torch.einsum('ik,kj->ij', w02, w01)
    deg = torch.sum(wt**2, dim=1)
    # Greedy: visit output rows by decreasing squared norm and swap in the best remaining row
    order_match = torch.argsort(deg, descending=True)
    perm_Greedy = torch.arange(n2)
    remain = torch.arange(n2)
    for i in order_match:
        i = i.item()
        sc = torch.sum((mask[perm_Greedy, :]*wt)**2).item()
        ind_current = perm_Greedy[i].item()
        pc = ind_current
        ic = i
        for p in remain:
            p = p.item()
            ip = torch.where(perm_Greedy == p)[0].item()
            # total uncovered mass after swapping positions i and ip
            x = sc+torch.sum((mask[p, :]*wt[i, :])**2)+torch.sum((mask[ind_current, :]*wt[ip, :])**2)
            x = x-torch.sum((mask[ind_current, :]*wt[i, :])**2)-torch.sum((mask[p, :]*wt[ip, :])**2)
            if x.item() < sc:
                sc = x.item()
                pc = p
                ic = ip
        perm_Greedy[i] = pc
        perm_Greedy[ic] = ind_current
        remain = remain[remain != pc]
    w02 = w02[perm_Greedy]
    err1 = torch.sqrt(torch.mean((mask[perm_Greedy]*wt)**2)).item()
    print("error: ", err1)

    # Second candidate: match rows of the greedy mask to the largest-magnitude pattern of wt
    mask = torch.einsum('ik,kj->ij', w02, w01)
    x, _ = torch.sort(torch.abs(wt).flatten(), descending=True)
    dd = n1*n2
    p = (m+n2-1)/dd
    thr = min(int(dd*2*p/(1+p)), dd-1)
    thr = x[thr]
    maskt = torch.where(torch.abs(wt) > thr, 1.0, 0.0)
    dist_mat_rows = torch.cdist(maskt, mask, p=2)
    _, col_lsa = linear_sum_assignment(dist_mat_rows.cpu().detach().numpy())
    mask = mask[col_lsa]
    err = torch.sqrt(torch.mean((mask*wt-wt)**2)).item()
    print("error: ", err)
    # keep the candidate with the smaller error and report the permutation actually applied
    if err < err1:
        w02 = w02[col_lsa]
        col_ind = perm_Greedy.numpy()[col_lsa]
    else:
        err = err1
        col_ind = perm_Greedy.numpy()
    return err, w02, col_ind


def svd_permute_fc(w01,w02,wt):
    n2, n1, m = w02.size(0), w01.size(1), w01.size(0)
    M = torch.einsum('ik,kj->ijk', w02, w01)
    M = M.reshape((n2*n1,m))
    print("cond number M:")
    try:
        U, S, Vh = torch.linalg.svd(M)
        ss = S.size(0)
        print(S[0]/S[-1])
        #gamma1 =  torch.transpose(Vh,0,1)[:,:ss]/S @ torch.transpose(U,0,1)[:ss,:] @ wt.reshape((-1,))
        #del S, Vh
        #torch.cuda.empty_cache()
        Umiss = U[:,ss:]
        #del U
        #torch.cuda.empty_cache()
        Umiss = Umiss.reshape((n2,n1,-1))
        #Greedy
        order_match = torch.argsort(torch.sum(Umiss**2,dim=(1,2)),descending=True)
        perm_Greedy = torch.arange(n2)
        remain = torch.arange(n2)
        for i in order_match:
            i = i.item()
            #print(i)
            back = torch.einsum('ijk,ij->k',Umiss,wt[perm_Greedy])
            #print("loss: ", torch.sum(back**2))
            ind_current = perm_Greedy[i].item()
            score = back-torch.einsum('jk,j->k',Umiss[i,:,:],wt[ind_current,:])
            sc = torch.sum(back**2).item() #torch.zeros(remain.size(0))
            pc = ind_current
            ic = i
            for p in remain:
                p = p.item()
                ip = torch.where(perm_Greedy==p)[0].item()
                x = score+torch.einsum('jk,j->k',Umiss[i,:,:],wt[p,:])+torch.einsum('jk,j->k',Umiss[ip,:,:],wt[ind_current,:])
                x = x-torch.einsum('jk,j->k',Umiss[ip,:,:],wt[p,:])
                x = torch.sum(x**2)
                if x.item() < sc:
                    sc = x.item()
                    pc = p
                    ic = ip
            perm_Greedy[i] = pc
            perm_Greedy[ic] = ind_current
            remain = remain[remain!=pc]
        #print(remain)
        #print("hello")
        #print(perm_Greedy)
        #col_ind = torch.arange(n2)
        #col_ind = col_ind[perm_Greedy]
        col_ind = torch.argsort(perm_Greedy)
        w02 = w02[col_ind]
        gamma1 =  torch.transpose(Vh,0,1)[:,:ss]/S @ torch.transpose(U,0,1)[:ss,:] @ wt[perm_Greedy].reshape((-1,))
    except:
        col_ind = torch.arange(n2)
        gamma1 = torch.randn(m)
        gamma1 = gamma1/torch.sqrt(torch.sum(gamma1**2))
    #col_ind = perm_Greedy
    #print(col_ind)
    #print(col_ind)
    #score = torch.sum(torch.einsum('jk,pj->pk',Umiss[i,:,:],wt[remain,:])**2,dim=1)
    # old = perm_Greedy[i]
    # perm_Greedy[i] = remain[torch.argmin(score)]
    # perm_Greedy[perm_Greedy[i]] = old
    # remain = remain[remain!=perm_Greedy[i]]
    # score = torch.einsum('ijk,pj->ipk',Umiss,wt)
    # back = torch.einsum('iik->k',score)
    # dd = torch.eye(n2,dtype=torch.double)*torch.sum(back**2)#torch.zeros((n2,n2),dtype=float64) #torch.eye(n2)*torch.sum(back**2)
    # for i in range(n2):
    #     for j in range(i):
    #         x = torch.sum((back-score[i,i,:]-score[j,j,:]+score[i,j,:]+score[j,i,:])**2)
    #         dd.data[i,j] = x.item()
    #         dd.data[j,i] = x.item()
    # print("First")
    # print(torch.sum(back**2))
    # dist_mat_rows = dd #torch.sum(dist_mat_rows**2,dim=2)
    # row_ind, col_ind = linear_sum_assignment(dist_mat_rows.cpu().detach().numpy())
    #print(col_ind)
    #perm_inverse = torch.arange(n2)
    #wt = wt[col_ind]
    #dist_mat_rows = torch.einsum('ijk,pj->ipk',Umiss[col_ind],wt)
    #score = torch.einsum('ijk,pj->ipk',Umiss[col_ind],wt)
    #back = torch.einsum('iik->k',score)
    #print("After perm")
    #print(torch.sum(back**2))
    return w02, gamma1

def svd_permute_cond_fc(w01,w02,wt):
    n2, n1, m = w02.size(0), w01.size(1), w01.size(0)
    M = torch.einsum('ik,kj->ijk', w02, w01)
    M = M.reshape((n2*n1,m))
    #print("cond number M:")
    try:
        U, S, Vh = torch.linalg.svd(M)
        ss = S.size(0)
        condNumber=S[0]/S[-1]
        Umiss = U[:,ss:]
        Umiss = Umiss.reshape((n2,n1,-1))
        #Greedy
        order_match = torch.argsort(torch.sum(Umiss**2,dim=(1,2)),descending=True)
        perm_Greedy = torch.arange(n2)
        remain = torch.arange(n2)
        for i in order_match:
            i = i.item()
            back = torch.einsum('ijk,ij->k',Umiss,wt[perm_Greedy])
            ind_current = perm_Greedy[i].item()
            score = back-torch.einsum('jk,j->k',Umiss[i,:,:],wt[ind_current,:])
            sc = torch.sum(back**2).item() #torch.zeros(remain.size(0))
            pc = ind_current
            ic = i
            for p in remain:
                p = p.item()
                ip = torch.where(perm_Greedy==p)[0].item()
                x = score+torch.einsum('jk,j->k',Umiss[i,:,:],wt[p,:])+torch.einsum('jk,j->k',Umiss[ip,:,:],wt[ind_current,:])
                x = x-torch.einsum('jk,j->k',Umiss[ip,:,:],wt[p,:])
                x = torch.sum(x**2)
                if x.item() < sc:
                    sc = x.item()
                    pc = p
                    ic = ip
            perm_Greedy[i] = pc
            perm_Greedy[ic] = ind_current
            remain = remain[remain!=pc]
        col_ind = torch.argsort(perm_Greedy)
    except:
        col_ind = torch.arange(n2)
        condNumber=torch.tensor([0])
        Vh=0
        S=0
        U=0
    return col_ind, condNumber, Vh, S, U
    #w02, gamma1, gamma1Perm, condNumber, C, M@C

def svd_permute_precond_fc(Umiss, wt):
    """Greedy row permutation of ``wt`` for a precomputed complement basis ``Umiss``.

    ``Umiss`` is reshaped to (n2, n1, k). Rows are visited in order of
    decreasing squared norm of their ``Umiss`` slice. At each visit the
    pairwise swap with a still-unplaced row that most reduces
    ||sum_i Umiss[i]^T wt[perm[i]]||^2 is applied.

    Inputs:
        Umiss: tensor reshapeable to (n2, n1, k).
        wt: (n2, n1) target weights.
    Returns:
        (col_ind, condNumber) where ``col_ind = argsort(perm)`` and
        ``condNumber`` is the placeholder ``tensor([1])`` (no SVD is computed here).
    """
    rows, cols = wt.size(0), wt.size(1)
    basis = Umiss.reshape((rows, cols, -1))
    cond = torch.tensor([1])
    perm = torch.arange(rows)
    unplaced = torch.arange(rows)
    visit = torch.argsort(torch.sum(basis**2, dim=(1, 2)), descending=True)
    for pos in visit:
        pos = pos.item()
        residual = torch.einsum('ijk,ij->k', basis, wt[perm])
        held = perm[pos].item()
        partial = residual - torch.einsum('jk,j->k', basis[pos, :, :], wt[held, :])
        best_cost = torch.sum(residual**2).item()
        best_val, best_pos = held, pos
        for cand in unplaced:
            cand = cand.item()
            cand_pos = torch.where(perm == cand)[0].item()
            # residual after exchanging the rows at pos and cand_pos
            trial = (partial
                     + torch.einsum('jk,j->k', basis[pos, :, :], wt[cand, :])
                     + torch.einsum('jk,j->k', basis[cand_pos, :, :], wt[held, :])
                     - torch.einsum('jk,j->k', basis[cand_pos, :, :], wt[cand, :]))
            cost = torch.sum(trial**2).item()
            if cost < best_cost:
                best_cost = cost
                best_val, best_pos = cand, cand_pos
        perm[pos] = best_val
        perm[best_pos] = held
        unplaced = unplaced[unplaced != best_val]
    return torch.argsort(perm), cond


def proxy_target_gamma1(wt, seed, epochs, m, device):
    """Fit only the hidden-layer scales ``gamma1`` of a random two-layer fc factorisation.

    Three fits are run and their errors are printed:
      1. projection (``fc_1to2layer_exact_project_without_gamma2``),
      2. LBFGS fine-tuning started from that projection,
      3. LBFGS started from all-ones.
    Only the third result is returned.

    Inputs:
        wt: (n2, n1) target weights.
        seed: torch seed for the random factors.
        epochs: passed to the LBFGS routine.
        m: hidden width, capped at n2*n1.
        device: passed through to the helpers.
    Returns:
        (rmse, wproxy, [w02, w01, gamma1]) with ``wproxy = w02 @ diag(gamma1) @ w01``.
    """
    n2, n1 = wt.size()
    m = min(m, n2*n1)
    torch.manual_seed(seed)
    w01 = torch.randn([m, n1])/math.sqrt(2*n1)
    w02 = torch.randn([n2, m])/math.sqrt(2*m)

    def _approx_and_err(scales):
        approx = torch.einsum('im,m,mj->ij', w02, scales, w01)
        return approx, torch.sqrt(torch.mean((wt-approx)**2)).detach()

    gamma1 = fc_1to2layer_exact_project_without_gamma2(wt, w01, w02, device)
    _, err = _approx_and_err(gamma1)
    print("project error: ", err)
    gamma1 = fc_1to2layer_learn_GD_LBFGS_without_gamma2(wt, w01, w02, gamma1, epochs, device)
    _, err = _approx_and_err(gamma1)
    print("finetune error: ", err)
    gamma1 = fc_1to2layer_learn_GD_LBFGS_without_gamma2(wt, w01, w02, torch.ones(m), epochs, device)
    wproxy, rmse = _approx_and_err(gamma1)
    print("LBFGS error: ", rmse)
    return rmse, wproxy, [w02, w01, gamma1]


def proxy_target_wider(wt, seed, epochs, m, device):
    """Fit a (possibly over-wide) random two-layer fc factorisation to ``wt``.

    Random factors are drawn and ``w02`` is row-permuted by ``permute_outer``.
    Then ``gamma1`` is fitted with LBFGS and ``gamma2`` is projected. Both fits
    use the target and ``w02`` scaled by 100. Intermediate errors are printed.

    Inputs:
        wt: (n2, n1) target weights.
        seed: torch seed.
        epochs: passed to the LBFGS routine.
        m: hidden width (not capped, so it may exceed n2*n1).
        device: passed through to the helpers.
    Returns:
        (rmse, wproxy, [w02, w01, gamma1]). ``rmse``/``wproxy`` include the
        projected ``gamma2``, which is NOT part of the returned parameter list.
    """
    n2, n1 = wt.size()
    torch.manual_seed(seed)
    w01 = torch.randn([m, n1])/math.sqrt(2*n1)
    w02 = torch.randn([n2, m])/math.sqrt(2*m)
    gamma1 = 0.1*torch.randn(m)
    w02 = permute_outer(wt, w02, torch.einsum('m,mj->mj', gamma1, w01))
    wproxy = torch.einsum('im,m,mj->ij', w02, gamma1, w01)
    # Debug probe of a few fixed entries; entries outside a small layer are
    # skipped instead of raising IndexError.
    probes = [(0, 0), (1, 0), (5, 0), (5, 1), (3, 2)]
    print("Hidden mean:", [wproxy[r, c] for r, c in probes if r < n2 and c < n1])
    rmse = torch.sqrt(torch.mean((wt-wproxy)**2)).detach()
    print("Error init: ", rmse)
    # the fit is run on a 100x rescaled problem (target and w02)
    gamma1 = fc_1to2layer_learn_GD_LBFGS_without_gamma2(100*wt, w01, 100*w02, gamma1, epochs, device)
    wproxy = torch.einsum('im,m,mj->ij', w02, gamma1, w01)
    rmse = torch.sqrt(torch.mean((wt-wproxy)**2)).detach()
    print("LBFGS error: ", rmse)
    gamma2 = fc_1to2layer_project_gamma2(100*wt, w01, 100*w02, gamma1, device)
    wproxy = torch.einsum('i,im,m,mj->ij', gamma2, w02, gamma1, w01)
    rmse = torch.sqrt(torch.mean((wt-wproxy)**2)).detach()
    print("LBFGS error with gamma2: ", rmse)
    return rmse, wproxy, [w02, w01, gamma1]

def proxy_target_LBFGS_RF(wt, w01, w02, epochs, m, add_dim, device):
    """Fit BN-scale parameters of a two-layer factorisation via a widened problem, then compress to m features.

    Conv branch:
      - ``gamma1``/``gamma2`` are fitted with ``conv_1to2layer_learn_GD_LBFGS``.
      - The coefficients of features m.. are folded onto the first m features.
      - The factors are truncated to m features and the fit is repeated.

    FC branch:
      - ``w01``/``w02`` are scaled by 100.
      - ``add_dim`` extra features are built as random combinations of
        neighbouring base features (``GammaMat``), and combination
        coefficients are fitted with ``fc_1to2layer_learn_GD_LBFGS_feat``.
      - The coefficients are mapped back to ``gamma1`` through ``GammaMat``.
      - ``gamma1`` and then ``gamma1``/``gamma2`` are refined with LBFGS on
        the first m features.

    Errors are printed after each stage. A NaN RMSE is replaced by the RMS of ``wt``.

    NOTE: NaNs in ``wt`` are zeroed in place, i.e. the caller's tensor is modified.

    Returns:
        (rmse, wproxy, [w02, w01, gamma2, gamma1]). In the fc branch ``w02``
        and ``w01`` are the 100x-scaled, untruncated factors.
    """
    random.seed(1)
    torch.manual_seed(1)
    wt[torch.isnan(wt)] = 0
    act_scale = 1
    act_shift = 5
    if len(wt.size()) > 2:
        n2, n1, k1, k2 = wt.size()
        mwider = w01.size(0) + add_dim
        # NOTE(review): gamma1 has w01.size(0)+add_dim entries, which only lines
        # up with w01/w02 when add_dim == 0 — confirm intended.
        gamma1 = torch.randn(mwider)
        gamma1 = gamma1/torch.sqrt(torch.sum(gamma1**2))
        gamma2 = torch.ones(n2)/np.sqrt(n2)
        bt = torch.zeros(n2, device=device)
        gamma1, bias1, gamma2, bias2 = conv_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, gamma1, gamma2, epochs, device)
        wproxy = torch.einsum('i,ikmn,k,kjlb->ijmn', gamma2, w02, gamma1, w01)
        rmse = torch.sqrt(torch.mean((wt-wproxy)**2)).detach()
        if torch.isnan(rmse):
            rmse = torch.sqrt(torch.mean(wt**2)).detach()
        print("Wider error:", rmse)
        # compress: project the contribution of features m.. onto the first m features
        norm2inv = 1/(torch.einsum('imkl->m', w02[:, :m, :, :]**2)*torch.einsum('mjpq->m', w01[:m, :, :, :]**2))
        with torch.no_grad():
            # in-place update of a possibly grad-tracking leaf tensor
            gamma1[:m] = gamma1[:m] + torch.einsum('inkl,imkl,njpq,mjpq,n,m->m', w02[:, m:, :, :], w02[:, :m, :, :], w01[m:, :, :, :], w01[:m, :, :, :], gamma1[m:], norm2inv)
        gamma1 = gamma1[:m]
        w02 = w02[:, :m, :, :]
        w01 = w01[:m, :, :, :]
        # retrain compressed network with better initialization
        gamma1, bias1, gamma2, bias2 = conv_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, gamma1, gamma2, epochs, device)
        wproxy = torch.einsum('i,ikmn,k,kjlb->ijmn', gamma2, w02, gamma1, w01)
        rmse = torch.sqrt(torch.mean((wt-wproxy)**2)).detach()
        if torch.isnan(rmse):
            rmse = torch.sqrt(torch.mean(wt**2)).detach()
        print("LBFGS error:", rmse)
    else:
        n2, n1 = wt.size()
        mwider = w01.size(0)
        gamma1 = torch.randn(w01.size(0))
        print("mwider: ", mwider)
        print("m: ", m)
        gamma1 = gamma1/torch.sqrt(torch.sum(gamma1**2))
        gamma2 = torch.ones(n2)/np.sqrt(n2)
        bt = torch.ones(n2, device=device)
        w02 = w02*100
        w01 = w01*100
        # GammaMat (m x (m+add_dim)): identity for the base features plus small random
        # combinations of each feature with its successor for the added ones
        dd = torch.diag(torch.ones(m-1), diagonal=1) - torch.eye(m)
        dd = dd[:, :add_dim].float()
        GammaMat = torch.cat([torch.eye(m).float(), dd*torch.randn((m, add_dim))*0.05], dim=1)
        feat_add = torch.einsum('ma,im,mj->aij', GammaMat, w02[:, :m], w01[:m, :])
        # optimize overparameterized problem (rescaled by 100)
        gamma_add = fc_1to2layer_learn_GD_LBFGS_feat(100*wt, feat_add*(100/np.sqrt(n2)), epochs, device)
        wproxy = (torch.einsum('m,mij->ij', gamma_add, feat_add))/np.sqrt(n2)
        rmse = torch.sqrt(torch.mean((wt-wproxy)**2)).detach()
        if torch.isnan(rmse):
            rmse = torch.sqrt(torch.mean(wt**2)).detach()
        print("Wider error:", rmse)
        # compress: map the combination coefficients back onto the m base features
        gamma1.data = torch.matmul(GammaMat, gamma_add.data)
        wproxy = torch.einsum('i,im,m,mj->ij', gamma2, w02[:, :m], gamma1, w01[:m, :])
        rmse = torch.sqrt(torch.mean((wt-wproxy)**2)).detach()
        if torch.isnan(rmse):
            rmse = torch.sqrt(torch.mean(wt**2)).detach()
        print("Compression error:", rmse)
        # retrain compressed network with better initialization
        gamma1 = fc_1to2layer_learn_GD_LBFGS_without_gamma2(100*wt, w01[:m, :], 100*w02[:, :m]/np.sqrt(n2), gamma1, epochs, device)
        gamma1, bias1, gamma2, bias2 = fc_1to2layer_learn_GD_LBFGS(wt, bt, w01[:m, :], w02[:, :m], act_scale, act_shift, gamma1, gamma2, epochs, device)
        wproxy = torch.einsum('i,im,m,mj->ij', gamma2, w02[:, :m], gamma1, w01[:m, :])
        rmse = torch.sqrt(torch.mean((wt-wproxy)**2)).detach()
        if torch.isnan(rmse):
            rmse = torch.sqrt(torch.mean(wt**2)).detach()
        print("LBFGS error:", rmse)
    return rmse, wproxy, [w02, w01, gamma2, gamma1]

def update_features(w02,w01,m,epochs,device):
    """Express every extra hidden unit (index >= m) through the first m units.

    For each extra unit ``u = m + k`` the rank-one matrix
    ``outer(w02[:, u], w01[u, :])`` is fitted as
    ``einsum('m,im,mj->ij', g, w02[:, :m], w01[:m, :])`` by
    ``fc_1to2layer_learn_GD_LBFGS_without_gamma2`` (target and ``w02`` slice
    both passed scaled by 100). The coefficients found for one unit are the
    starting point for the next one.

    Returns:
        GammaMat: tensor (m, w01.size(0) - m); column k holds the coefficients
            for unit m + k.
        feat_add: tensor (w01.size(0) - m, w02.size(0), w01.size(1)); slice k
            is the fitted approximation of unit m + k's rank-one matrix.

    Prints the RMS fit error of every extra unit.
    """
    n_extra = w01.size(0) - m
    out_dim = w02.size(0)
    in_dim = w01.size(1)
    base_w02 = w02[:, :m]
    base_w01 = w01[:m, :]

    GammaMat = torch.zeros((m, n_extra))
    feat_add = torch.zeros((n_extra, out_dim, in_dim), requires_grad=False)

    # warm start: carried over from one extra unit to the next
    coeffs = torch.ones(m) / np.sqrt(m)
    for k in range(n_extra):
        unit = m + k
        target = torch.outer(w02[:, unit], w01[unit, :])
        coeffs = fc_1to2layer_learn_GD_LBFGS_without_gamma2(100 * target, base_w01, 100 * base_w02, coeffs, epochs, device)
        GammaMat[:, k] = coeffs.data
        approx = torch.einsum('m,im,mj->ij', coeffs.data, base_w02, base_w01)
        feat_add[k, :, :] = approx.detach()
        err = torch.sqrt(torch.mean((target - feat_add[k, :, :]) ** 2)).detach()
        print("err add feat " + str(k) + ": " + str(err))
    return GammaMat, feat_add

def proxy_target_LBFGS_extend(wt, w01, w02, epochs, epochsFinetune, m, device):
    """Fit a two-layer proxy to ``wt`` using all hidden units, then compress it to ``m`` units.

    Only the per-unit scales (gamma1, gamma2) and biases are fitted here;
    ``w01`` and ``w02`` are used as given (up to a fixed scale factor).

    Dense case (2-D ``wt``):
      1. Fit gamma1 for every unit of ``w01``.
      2. Express each extra unit (index >= m) through the first ``m`` units
         with ``update_features``, fold those coefficients into the first
         ``m`` scales and fine-tune.
      3. Refit from a fresh random start with
         ``fc_1to2layer_learn_GD_LBFGS_without_gamma2_wider`` (which also
         receives the extra features), fold its ``gamma_add`` output through
         GammaMat and fine-tune again. This last result is returned.

    Conv case (4-D ``wt``):
      Fit all units, add the projection of the dropped units onto the kept
      ones, then retrain the narrow network.

    Args:
        wt: target weights, (n2, n1) or (n2, n1, k1, k2). NaN entries are set
            to 0 in place.
        w01: fixed first-layer weights; dim 0 indexes hidden units.
        w02: fixed second-layer weights; dim 1 indexes hidden units.
        epochs: forwarded to the fitting helpers.
        epochsFinetune: unused.
        m: number of hidden units to keep.
        device: forwarded to the fitting helpers.

    Returns:
        (rmse, wproxy, [w02, w01, gamma2, gamma1]).
        - rmse is the RMS error of the final proxy ``wproxy``, or the RMS of
          ``wt`` if that error is NaN.
        - In the dense case w02/w01 are the 100x-scaled full-width tensors.
        - In the conv case w02/w01 are truncated to the first m units.

    Side effects: reseeds ``random`` and ``torch`` with 1 and prints progress.
    """
    random.seed(1)
    torch.manual_seed(1)
    wt[torch.isnan(wt)] = 0
    act_scale = 1
    act_shift = 5

    def _fit_error(approx, warn_nan=False):
        # RMS error of approx w.r.t. wt; if NaN, fall back to the RMS of wt
        # (the error of an all-zero approximation).
        err = torch.sqrt(torch.mean((wt-approx)**2)).detach()
        if torch.isnan(err):
            if warn_nan:
                print("na")
            err = torch.sqrt(torch.mean(wt**2)).detach()
        return err

    if len(wt.size()) > 2:
        n2, n1, k1, k2 = wt.size()
        mwider = w01.size(0)
        # initialize trainable parameters
        gamma1 = torch.randn(mwider)
        gamma1 = gamma1/torch.sqrt(torch.sum(gamma1**2))
        gamma2 = torch.ones(n2)/np.sqrt(n2)
        bt = torch.zeros(n2,device=device)
        gamma1, bias1, gamma2, bias2 = conv_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, gamma1, gamma2, epochs, device)
        # NOTE(review): this contraction also sums w01 over its trailing axes (l, b); confirm intended.
        wproxy = torch.einsum('i,ikmn,k,kjlb->ijmn', gamma2, w02, gamma1, w01)
        rmse = _fit_error(wproxy)
        print("Wider error:", rmse)
        # Compress: add to each kept unit (< m) the projection coefficient of
        # every dropped unit's rank-one feature onto the kept unit's own one.
        norm2inv = 1/(torch.einsum('imkl->m', w02[:,:m,:,:]**2)*torch.einsum('mjpq->m', w01[:m,:,:,:]**2))
        with torch.no_grad():
            # no_grad permits the in-place update even if gamma1 tracks gradients
            gamma1[:m] = gamma1[:m] + torch.einsum('inkl,imkl,njpq,mjpq,n,m->m', w02[:,m:,:,:], w02[:,:m,:,:], w01[m:,:,:,:], w01[:m,:,:,:], gamma1[m:],norm2inv)
        gamma1 = gamma1[:m]
        w02 = w02[:,:m,:,:]
        w01 = w01[:m,:,:,:]
        # retrain compressed network from the projected initialization
        gamma1, bias1, gamma2, bias2 = conv_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, gamma1, gamma2, epochs, device)
        wproxy = torch.einsum('i,ikmn,k,kjlb->ijmn', gamma2, w02, gamma1, w01)
        rmse = _fit_error(wproxy)
        print("LBFGS error:", rmse)
    else:
        n2, n1 = wt.size()
        mwider = w01.size(0)
        gamma1 = torch.randn(w01.size(0))
        print("mwider: ", mwider)
        print("m: ", m)
        gamma1 = gamma1/torch.sqrt(torch.sum(gamma1**2))
        gamma2 = torch.ones(n2)/np.sqrt(n2)
        bt = torch.ones(n2,device=device)
        # Both fixed layers are scaled by 100; the targets handed to the
        # fitting helpers below are scaled to match.
        w02 = w02*100
        w01 = w01*100
        # 1) fit scales for all mwider units
        gamma1 = fc_1to2layer_learn_GD_LBFGS_without_gamma2(100*wt, w01, 100*w02/np.sqrt(n2), gamma1, epochs, device)
        rmse = _fit_error(torch.einsum('i,im,m,mj->ij', gamma2, w02, gamma1, w01))
        print("Wider error:", rmse)
        # 2) represent each extra unit (index >= m) through the first m units
        GammaMat, feat_add = update_features(w02/np.sqrt(n2),w01,m,epochs,device)
        # fold the extra units' scales into the first m units
        gamma1proxy = gamma1[:m].data + torch.matmul(GammaMat,gamma1[m:].data)
        rmse = _fit_error(torch.einsum('i,im,m,mj->ij', gamma2, w02[:,:m], gamma1proxy, w01[:m,:]))
        print("Proxy compression error:", rmse)
        # fine-tune the narrow network from the folded initialization
        gamma1, bias1, gamma2, bias2 = fc_1to2layer_learn_GD_LBFGS(wt, bt, w01[:m,:], w02[:,:m], act_scale, act_shift, gamma1proxy, gamma2, epochs, device)
        wproxy = torch.einsum('i,im,m,mj->ij', gamma2, w02[:,:m], gamma1, w01[:m,:])
        rmse = _fit_error(wproxy)
        print("LBFGS finetune error:", rmse)
        # 3) refit from a fresh start with the extra features available
        gamma1 = torch.randn(m)
        gamma1 = gamma1/torch.sqrt(torch.sum(gamma1**2))
        gamma1, gamma_add = fc_1to2layer_learn_GD_LBFGS_without_gamma2_wider(100*wt, w01[:m,:], 100/np.sqrt(n2)*w02[:,:m], gamma1, 100*feat_add, epochs, device)
        gamma1.data = gamma1.data + torch.matmul(GammaMat,gamma_add.data)
        gamma2 = torch.ones(n2)/np.sqrt(n2)
        rmse = _fit_error(torch.einsum('i,im,m,mj->ij', gamma2, w02[:,:m], gamma1, w01[:m,:]), warn_nan=True)
        print("Compression error:", rmse)
        # finetune from here
        gamma1, bias1, gamma2, bias2 = fc_1to2layer_learn_GD_LBFGS(wt, bt, w01[:m,:], w02[:,:m], act_scale, act_shift, gamma1, gamma2, epochs, device)
        wproxy = torch.einsum('i,im,m,mj->ij', gamma2, w02[:,:m], gamma1, w01[:m,:])
        rmse = _fit_error(wproxy)
        print("LBFGS error:", rmse)
    return rmse, wproxy, [w02, w01, gamma2, gamma1]


def proxy_target_LBFGS_prune_narrow(wt, w01, w02, epochs, epochsFinetune, m, device):
    """Fit a two-layer proxy to ``wt`` and narrow it to ``m`` hidden units.

    Dense case (2-D ``wt``):
      ``fc_1to2layer_learn_GD_LBFGS_without_gamma2_prune_narrow`` produces the
      first-layer scales; the proxy is then evaluated on units ``:m`` and
      fine-tuned with ``fc_1to2layer_learn_GD_LBFGS``.

    Conv case (4-D ``wt``):
      Fit all units, project the dropped units onto the kept ones, and
      retrain. This is the same projection path as the other proxy_target
      functions; there is no pruning-specific logic here.

    Args:
        wt: target weights, (n2, n1) or (n2, n1, k1, k2). NaN entries are set
            to 0 in place.
        w01: fixed first-layer weights; dim 0 indexes hidden units.
        w02: fixed second-layer weights; dim 1 indexes hidden units.
        epochs: forwarded to the fitting helpers.
        epochsFinetune: forwarded to the prune helper (dense case only).
        m: number of hidden units to keep.
        device: forwarded to the fitting helpers.

    Returns:
        (rmse, wproxy, [w02, w01, gamma2, gamma1]).
        - rmse is the RMS error of ``wproxy``, or the RMS of ``wt`` if NaN.
        - In the dense case w02/w01 are the 100x-scaled full-width tensors.

    Side effects: reseeds ``random`` and ``torch`` with 1 and prints progress.
    """
    random.seed(1)
    torch.manual_seed(1)
    wt[torch.isnan(wt)] = 0
    act_scale = 1
    act_shift = 5

    def _fit_error(approx, warn_nan=False):
        # RMS error of approx w.r.t. wt; if NaN, fall back to the RMS of wt.
        err = torch.sqrt(torch.mean((wt-approx)**2)).detach()
        if torch.isnan(err):
            if warn_nan:
                print("na")
            err = torch.sqrt(torch.mean(wt**2)).detach()
        return err

    if len(wt.size()) > 2:
        n2, n1, k1, k2 = wt.size()
        # gamma1 needs one entry per unit of w01/w02 (the old `+ add_dim`
        # referenced an undefined name)
        mwider = w01.size(0)
        gamma1 = torch.randn(mwider)
        gamma1 = gamma1/torch.sqrt(torch.sum(gamma1**2))
        gamma2 = torch.ones(n2)/np.sqrt(n2)
        bt = torch.zeros(n2,device=device)
        gamma1, bias1, gamma2, bias2 = conv_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, gamma1, gamma2, epochs, device)
        # NOTE(review): this contraction also sums w01 over its trailing axes (l, b); confirm intended.
        wproxy = torch.einsum('i,ikmn,k,kjlb->ijmn', gamma2, w02, gamma1, w01)
        rmse = _fit_error(wproxy)
        print("Wider error:", rmse)
        # Compress: add to each kept unit (< m) the projection coefficient of
        # every dropped unit's rank-one feature onto the kept unit's own one.
        norm2inv = 1/(torch.einsum('imkl->m', w02[:,:m,:,:]**2)*torch.einsum('mjpq->m', w01[:m,:,:,:]**2))
        with torch.no_grad():
            # no_grad permits the in-place update even if gamma1 tracks gradients
            gamma1[:m] = gamma1[:m] + torch.einsum('inkl,imkl,njpq,mjpq,n,m->m', w02[:,m:,:,:], w02[:,:m,:,:], w01[m:,:,:,:], w01[:m,:,:,:], gamma1[m:],norm2inv)
        gamma1 = gamma1[:m]
        w02 = w02[:,:m,:,:]
        w01 = w01[:m,:,:,:]
        # retrain compressed network from the projected initialization
        gamma1, bias1, gamma2, bias2 = conv_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, gamma1, gamma2, epochs, device)
        wproxy = torch.einsum('i,ikmn,k,kjlb->ijmn', gamma2, w02, gamma1, w01)
        rmse = _fit_error(wproxy)
        print("LBFGS error:", rmse)
    else:
        n2, n1 = wt.size()
        mwider = w01.size(0)
        gamma1 = torch.randn(w01.size(0))
        print("mwider: ", mwider)
        print("m: ", m)
        gamma1 = gamma1/torch.sqrt(torch.sum(gamma1**2))
        gamma2 = torch.ones(n2)/np.sqrt(n2)
        bt = torch.ones(n2,device=device)
        # Both fixed layers are scaled by 100; the targets handed to the
        # fitting helpers below are scaled to match.
        w02 = w02*100
        w01 = w01*100
        gamma1 = fc_1to2layer_learn_GD_LBFGS_without_gamma2_prune_narrow(m, 100*wt, w01, 100*w02/np.sqrt(n2), gamma1, epochs, epochsFinetune, device)
        rmse = _fit_error(torch.einsum('i,im,m,mj->ij', gamma2, w02[:,:m], gamma1, w01[:m,:]), warn_nan=True)
        print("narrow error: ", rmse)
        # (The gamma1 histogram plot was removed: `plt` is undefined because
        # the matplotlib import is commented out at the top of the file.)
        gamma1, bias1, gamma2, bias2 = fc_1to2layer_learn_GD_LBFGS(wt, bt, w01[:m,:], w02[:,:m], act_scale, act_shift, gamma1, gamma2, epochs, device)
        wproxy = torch.einsum('i,im,m,mj->ij', gamma2, w02[:,:m], gamma1, w01[:m,:])
        rmse = _fit_error(wproxy)
        print("LBFGS error:", rmse)
    return rmse, wproxy, [w02, w01, gamma2, gamma1]

def proxy_target_LBFGS_prune_mag(wt, w01, w02, epochs, epochsFinetune, m, device):
    """Fit a two-layer proxy to ``wt`` and prune it to ``m`` hidden units.

    Dense case (2-D ``wt``):
      ``fc_1to2layer_learn_GD_LBFGS_without_gamma2_prune_mag`` returns new
      gamma1, w01 and w02 (presumably the pruned units). The proxy built from
      them is fine-tuned with ``fc_1to2layer_learn_GD_LBFGS``.

    Conv case (4-D ``wt``):
      Fit all units, project the dropped units onto the kept ones, and
      retrain. This is the same projection path as the other proxy_target
      functions; there is no magnitude-pruning logic here.

    Args:
        wt: target weights, (n2, n1) or (n2, n1, k1, k2). NaN entries are set
            to 0 in place.
        w01: fixed first-layer weights; dim 0 indexes hidden units.
        w02: fixed second-layer weights; dim 1 indexes hidden units.
        epochs: forwarded to the fitting helpers.
        epochsFinetune: forwarded to the prune helper (dense case only).
        m: number of hidden units to keep.
        device: forwarded to the fitting helpers.

    Returns:
        (rmse, wproxy, [w02, w01, gamma2, gamma1]).
        - rmse is the RMS error of ``wproxy``, or the RMS of ``wt`` if NaN.
        - In the dense case w02/w01 are whatever the prune helper returned.

    Side effects: reseeds ``random`` and ``torch`` with 1 and prints progress.
    """
    random.seed(1)
    torch.manual_seed(1)
    wt[torch.isnan(wt)] = 0
    act_scale = 1
    act_shift = 5

    def _fit_error(approx, warn_nan=False):
        # RMS error of approx w.r.t. wt; if NaN, fall back to the RMS of wt.
        err = torch.sqrt(torch.mean((wt-approx)**2)).detach()
        if torch.isnan(err):
            if warn_nan:
                print("na")
            err = torch.sqrt(torch.mean(wt**2)).detach()
        return err

    if len(wt.size()) > 2:
        n2, n1, k1, k2 = wt.size()
        # gamma1 needs one entry per unit of w01/w02 (the old `+ add_dim`
        # referenced an undefined name)
        mwider = w01.size(0)
        gamma1 = torch.randn(mwider)
        gamma1 = gamma1/torch.sqrt(torch.sum(gamma1**2))
        gamma2 = torch.ones(n2)/np.sqrt(n2)
        bt = torch.zeros(n2,device=device)
        gamma1, bias1, gamma2, bias2 = conv_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, gamma1, gamma2, epochs, device)
        # NOTE(review): this contraction also sums w01 over its trailing axes (l, b); confirm intended.
        wproxy = torch.einsum('i,ikmn,k,kjlb->ijmn', gamma2, w02, gamma1, w01)
        rmse = _fit_error(wproxy)
        print("Wider error:", rmse)
        # Compress: add to each kept unit (< m) the projection coefficient of
        # every dropped unit's rank-one feature onto the kept unit's own one.
        norm2inv = 1/(torch.einsum('imkl->m', w02[:,:m,:,:]**2)*torch.einsum('mjpq->m', w01[:m,:,:,:]**2))
        with torch.no_grad():
            # no_grad permits the in-place update even if gamma1 tracks gradients
            gamma1[:m] = gamma1[:m] + torch.einsum('inkl,imkl,njpq,mjpq,n,m->m', w02[:,m:,:,:], w02[:,:m,:,:], w01[m:,:,:,:], w01[:m,:,:,:], gamma1[m:],norm2inv)
        gamma1 = gamma1[:m]
        w02 = w02[:,:m,:,:]
        w01 = w01[:m,:,:,:]
        # retrain compressed network from the projected initialization
        gamma1, bias1, gamma2, bias2 = conv_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, gamma1, gamma2, epochs, device)
        wproxy = torch.einsum('i,ikmn,k,kjlb->ijmn', gamma2, w02, gamma1, w01)
        rmse = _fit_error(wproxy)
        print("LBFGS error:", rmse)
    else:
        n2, n1 = wt.size()
        mwider = w01.size(0)
        gamma1 = torch.randn(w01.size(0))
        print("mwider: ", mwider)
        print("m: ", m)
        gamma1 = gamma1/torch.sqrt(torch.sum(gamma1**2))
        gamma2 = torch.ones(n2)/np.sqrt(n2)
        bt = torch.ones(n2,device=device)
        # Both fixed layers are scaled by 100; the targets handed to the
        # fitting helpers below are scaled to match.
        w02 = w02*100
        w01 = w01*100
        # NOTE(review): the helper receives w02 scaled by an extra 100/sqrt(n2);
        # if it returns that scaled tensor, gamma2 = 1/sqrt(n2) is applied on
        # top of it below -- confirm the intended scaling.
        gamma1, w01, w02 = fc_1to2layer_learn_GD_LBFGS_without_gamma2_prune_mag(m, 100*wt, w01, 100*w02/np.sqrt(n2), gamma1, epochs, epochsFinetune, device)
        rmse = _fit_error(torch.einsum('i,im,m,mj->ij', gamma2, w02, gamma1, w01), warn_nan=True)
        print("narrow error: ", rmse)
        # (The gamma1 histogram plot was removed: `plt` is undefined because
        # the matplotlib import is commented out at the top of the file.)
        gamma1, bias1, gamma2, bias2 = fc_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, gamma1, gamma2, epochs, device)
        wproxy = torch.einsum('i,im,m,mj->ij', gamma2, w02, gamma1, w01)
        rmse = _fit_error(wproxy)
        print("LBFGS error:", rmse)
    return rmse, wproxy, [w02, w01, gamma2, gamma1]

def proxy_target_LBFGS_overp(wt, w01, w02, epochs, m, add_dim, device):
    """Fit ``wt`` through an over-parameterized proxy, then fold it back to the base units.

    Dense case (2-D ``wt``; ``w01`` must have exactly ``m`` rows and ``w02``
    exactly ``m`` columns, as required by the ``feat_add`` einsum):
      1. Build ``add_dim`` extra features as random sparse combinations
         (GammaMat, Bernoulli p=0.05) of the base units.
      2. Fit the base scales alone.
      3. Fit base scales plus extra-feature coefficients jointly with
         ``fc_1to2layer_learn_GD_LBFGS_without_gamma2_wider``.
      4. Fold the extra coefficients back into the base scales via GammaMat,
         then refit and fine-tune.

    Conv case (4-D ``wt``):
      Fit all units, project units ``m:`` onto units ``:m``, and retrain.

    Args:
        wt: target weights, (n2, n1) or (n2, n1, k1, k2). NaN entries are set
            to 0 in place.
        w01: fixed first-layer weights; dim 0 indexes hidden units.
        w02: fixed second-layer weights; dim 1 indexes hidden units.
        epochs: forwarded to the fitting helpers.
        m: number of base hidden units.
        add_dim: number of extra (dense case) features.
        device: forwarded to the fitting helpers.

    Returns:
        (rmse, wproxy, [w02, w01, gamma2, gamma1]).
        - rmse is the RMS error of ``wproxy``, or the RMS of ``wt`` if NaN.
        - In the dense case w02/w01 are the 100x-scaled tensors.

    Side effects: reseeds ``random`` and ``torch`` with 1 and prints progress.
    """
    random.seed(1)
    torch.manual_seed(1)
    wt[torch.isnan(wt)] = 0
    act_scale = 1
    act_shift = 5

    def _fit_error(approx):
        # RMS error of approx w.r.t. wt; if NaN, fall back to the RMS of wt.
        err = torch.sqrt(torch.mean((wt-approx)**2)).detach()
        if torch.isnan(err):
            err = torch.sqrt(torch.mean(wt**2)).detach()
        return err

    if len(wt.size()) > 2:
        n2, n1, k1, k2 = wt.size()
        # gamma1 needs one entry per unit of w01/w02 for the einsums below;
        # sizing it w01.size(0) + add_dim could never match them.
        gamma1 = torch.randn(w01.size(0))
        gamma1 = gamma1/torch.sqrt(torch.sum(gamma1**2))
        gamma2 = torch.ones(n2)/np.sqrt(n2)
        bt = torch.zeros(n2,device=device)
        gamma1, bias1, gamma2, bias2 = conv_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, gamma1, gamma2, epochs, device)
        # NOTE(review): this contraction also sums w01 over its trailing axes (l, b); confirm intended.
        wproxy = torch.einsum('i,ikmn,k,kjlb->ijmn', gamma2, w02, gamma1, w01)
        rmse = _fit_error(wproxy)
        print("Wider error:", rmse)
        # Compress: add to each kept unit (< m) the projection coefficient of
        # every dropped unit's rank-one feature onto the kept unit's own one.
        norm2inv = 1/(torch.einsum('imkl->m', w02[:,:m,:,:]**2)*torch.einsum('mjpq->m', w01[:m,:,:,:]**2))
        with torch.no_grad():
            # no_grad permits the in-place update even if gamma1 tracks gradients
            gamma1[:m] = gamma1[:m] + torch.einsum('inkl,imkl,njpq,mjpq,n,m->m', w02[:,m:,:,:], w02[:,:m,:,:], w01[m:,:,:,:], w01[:m,:,:,:], gamma1[m:],norm2inv)
        gamma1 = gamma1[:m]
        w02 = w02[:,:m,:,:]
        w01 = w01[:m,:,:,:]
        # retrain compressed network from the projected initialization
        gamma1, bias1, gamma2, bias2 = conv_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, gamma1, gamma2, epochs, device)
        wproxy = torch.einsum('i,ikmn,k,kjlb->ijmn', gamma2, w02, gamma1, w01)
        rmse = _fit_error(wproxy)
        print("LBFGS error:", rmse)
    else:
        n2, n1 = wt.size()
        mwider = w01.size(0) + add_dim
        gamma1 = torch.randn(w01.size(0))
        print("mwider: ", mwider)
        print("m: ", m)
        gamma1 = gamma1/torch.sqrt(torch.sum(gamma1**2))
        gamma2 = torch.ones(n2)/np.sqrt(n2)
        bt = torch.ones(n2,device=device)
        # Both fixed layers are scaled by 100; the targets handed to the
        # fitting helpers below are scaled to match.
        w02 = w02*100
        w01 = w01*100
        # Random sparse mixing matrix: extra feature a = sum_m GammaMat[m, a] * (w02[:, m] x w01[m, :]).
        GammaMat = torch.bernoulli(torch.ones((m,add_dim))*0.05)
        # NOTE(review): each row is *multiplied* by its squared norm rather than
        # normalized by it; confirm this is the intended scaling.
        GammaMat = torch.einsum('ij,i->ij', GammaMat, torch.sum(GammaMat**2,dim=(1)))
        feat_add = torch.einsum('ma,im,mj->aij', GammaMat, w02, w01)
        # fit the base scales alone
        gamma1 = fc_1to2layer_learn_GD_LBFGS_without_gamma2(100*wt, w01, 100*w02/np.sqrt(n2), gamma1, epochs, device)
        print("gamma1: ", torch.mean(gamma1), torch.std(gamma1))
        # reported raw (no NaN fallback), as before
        rmse = torch.sqrt(torch.mean((wt-torch.einsum('i,im,m,mj->ij', gamma2, w02, gamma1, w01))**2)).detach()
        print("narrow error: ", rmse)
        # (Histogram plots were removed throughout: `plt` is undefined because
        # the matplotlib import is commented out at the top of the file.)
        # fit base scales and extra-feature coefficients jointly
        gamma1 = torch.ones(m)
        gamma1, gamma_add = fc_1to2layer_learn_GD_LBFGS_without_gamma2_wider(100*wt, w01, w02*(100/np.sqrt(n2)), gamma1, feat_add*(100/np.sqrt(n2)), epochs, device)
        wproxy = (torch.einsum('im,m,mj->ij', w02, gamma1, w01) + torch.einsum('m,mij->ij', gamma_add, feat_add))/np.sqrt(n2)
        rmse = _fit_error(wproxy)
        print("Wider error:", rmse)
        # compress: fold the extra-feature coefficients back onto the base units
        gamma1.data = gamma1.data + torch.matmul(GammaMat,gamma_add.data)
        wproxy = torch.einsum('i,im,m,mj->ij', gamma2, w02, gamma1, w01)
        rmse = _fit_error(wproxy)
        print("Compression error:", rmse)
        # retrain compressed network with better initialization
        gamma1 = fc_1to2layer_learn_GD_LBFGS_without_gamma2(100*wt, w01, 100*w02/np.sqrt(n2), gamma1, epochs, device)
        gamma1, bias1, gamma2, bias2 = fc_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, gamma1, gamma2, epochs, device)
        wproxy = torch.einsum('i,im,m,mj->ij', gamma2, w02, gamma1, w01)
        rmse = _fit_error(wproxy)
        print("LBFGS error:", rmse)
    return rmse, wproxy, [w02, w01, gamma2, gamma1]

def create_RF_gamma(w02, w01, m, epochs, device, filename):
    """Compute extra-feature coefficients with ``update_features`` and save them.

    The computation was a verbatim copy of ``update_features``; it now
    delegates to it so the two cannot drift apart.

    Args:
        w02, w01, m, epochs, device: as for ``update_features``.
        filename: path prefix; GammaMat is written to ``filename + ".txt"``
            as comma-separated text (one row per base unit).

    Returns:
        (GammaMat, feat_add) exactly as returned by ``update_features``.
    """
    GammaMat, feat_add = update_features(w02, w01, m, epochs, device)
    # GammaMat is a CPU tensor built via .data assignments (no autograd
    # history), so it converts directly to a NumPy array.
    np.savetxt(filename + ".txt", GammaMat.numpy(), delimiter=",")
    return GammaMat, feat_add

def create_RF_parallel(w02, w01, m, epochs, device, njobs=6):
    """Parallel variant of ``update_features`` with an independent random start per unit.

    For each extra unit ``u = m + k`` the rank-one matrix
    ``outer(w02[:, u], w01[u, :])`` is fitted as
    ``einsum('m,im,mj->ij', g, w02[:, :m], w01[:m, :])`` with
    ``fc_1to2layer_learn_GD_LBFGS_without_gamma2``. The target and the
    ``w02`` slice are both passed scaled by 100. Fits are dispatched through
    joblib ``Parallel``.

    Args:
        w02: second-layer weights (n2, w01.size(0)).
        w01: first-layer weights (w01.size(0), n1).
        m: number of base units; units m.. are the extra ones.
        epochs, device: forwarded to the fitting helper.
        njobs: number of parallel jobs (default 6, the previous hard-coded
            value).

    Returns:
        GammaMat: (m, w01.size(0) - m); column k holds the coefficients for
            unit m + k.
        feat_add: (w01.size(0) - m, n2, n1); slice k is the fitted
            approximation of unit m + k's rank-one matrix.
    """
    mwider = w01.size(0)
    add_dim = mwider-m
    n2 = w02.size(0)
    n1 = w01.size(1)

    def compute_feat(w02l, w01l, w02m, w01m):
        # fit m coefficients reproducing the rank-one matrix outer(w02l, w01l)
        wt = torch.outer(w02l,w01l)
        gamma1 = torch.randn(m)/np.sqrt(m)
        gamma1 = fc_1to2layer_learn_GD_LBFGS_without_gamma2(100*wt, w01m, 100*w02m, gamma1, epochs, device)
        xx = torch.einsum('m,im,mj->ij', gamma1.data, w02m, w01m)
        return gamma1.detach(), xx.detach()

    # NOTE(review): Parallel/delayed come from joblib, whose import is commented
    # out at the top of this file -- confirm it is imported elsewhere.
    results = Parallel(n_jobs=njobs)(delayed(compute_feat)(w02[:,ind+m], w01[ind+m,:], w02[:,:m], w01[:m,:]) for ind in range(add_dim))
    GammaMat = torch.zeros((m,add_dim))
    feat_add = torch.zeros((add_dim, n2, n1), requires_grad=False)
    for i, (coeffs, feat) in enumerate(results):
        GammaMat[:,i] = coeffs
        feat_add[i,:,:] = feat
    return GammaMat, feat_add

def proxy_target_LBFGS_RF_wider(wt, w01, w02, epochs, epochsFinetune, m, device, newFeat, filename):
    """Like ``proxy_target_LBFGS_extend``, but with computed or stored extra-feature coefficients.

    Dense case (2-D ``wt``):
      1. Fit scales for every unit of ``w01``.
      2. Obtain GammaMat/feat_add for the extra units (index >= m).
         - If ``newFeat`` is true: from ``create_RF_parallel``.
         - Otherwise: by loading ``filename + ".txt"`` and resizing it to
           (m, w01.size(0) - m).
      3. Fold the extra units' scales into the first ``m`` and fine-tune.
      4. Refit from a fresh start with
         ``fc_1to2layer_learn_GD_LBFGS_without_gamma2_wider``, fold
         ``gamma_add`` through GammaMat and fine-tune again. This last
         result is returned.

    Conv case (4-D ``wt``):
      Fit all units, project the dropped units onto the kept ones, and
      retrain.

    Args:
        wt: target weights, (n2, n1) or (n2, n1, k1, k2). NaN entries are set
            to 0 in place.
        w01: fixed first-layer weights; dim 0 indexes hidden units.
        w02: fixed second-layer weights; dim 1 indexes hidden units.
        epochs: forwarded to the fitting helpers.
        epochsFinetune: unused.
        m: number of hidden units to keep.
        device: forwarded to the fitting helpers.
        newFeat: compute GammaMat (true) or load it from disk (false).
        filename: path prefix of the stored GammaMat text file.

    Returns:
        (rmse, wproxy, [w02, w01, gamma2, gamma1]).
        - rmse is the RMS error of ``wproxy``, or the RMS of ``wt`` if NaN.
        - In the dense case w02/w01 are the 100x-scaled full-width tensors.

    Side effects: reseeds ``random`` and ``torch`` with 1 and prints progress.
    """
    random.seed(1)
    torch.manual_seed(1)
    wt[torch.isnan(wt)] = 0
    act_scale = 1
    act_shift = 5

    def _fit_error(approx, warn_nan=False):
        # RMS error of approx w.r.t. wt; if NaN, fall back to the RMS of wt.
        err = torch.sqrt(torch.mean((wt-approx)**2)).detach()
        if torch.isnan(err):
            if warn_nan:
                print("na")
            err = torch.sqrt(torch.mean(wt**2)).detach()
        return err

    if len(wt.size()) > 2:
        n2, n1, k1, k2 = wt.size()
        mwider = w01.size(0)
        # initialize trainable parameters
        gamma1 = torch.randn(mwider)
        gamma1 = gamma1/torch.sqrt(torch.sum(gamma1**2))
        gamma2 = torch.ones(n2)/np.sqrt(n2)
        bt = torch.zeros(n2,device=device)
        gamma1, bias1, gamma2, bias2 = conv_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, gamma1, gamma2, epochs, device)
        # NOTE(review): this contraction also sums w01 over its trailing axes (l, b); confirm intended.
        wproxy = torch.einsum('i,ikmn,k,kjlb->ijmn', gamma2, w02, gamma1, w01)
        rmse = _fit_error(wproxy)
        print("Wider error:", rmse)
        # Compress: add to each kept unit (< m) the projection coefficient of
        # every dropped unit's rank-one feature onto the kept unit's own one.
        norm2inv = 1/(torch.einsum('imkl->m', w02[:,:m,:,:]**2)*torch.einsum('mjpq->m', w01[:m,:,:,:]**2))
        with torch.no_grad():
            # no_grad permits the in-place update even if gamma1 tracks gradients
            gamma1[:m] = gamma1[:m] + torch.einsum('inkl,imkl,njpq,mjpq,n,m->m', w02[:,m:,:,:], w02[:,:m,:,:], w01[m:,:,:,:], w01[:m,:,:,:], gamma1[m:],norm2inv)
        gamma1 = gamma1[:m]
        w02 = w02[:,:m,:,:]
        w01 = w01[:m,:,:,:]
        # retrain compressed network from the projected initialization
        gamma1, bias1, gamma2, bias2 = conv_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, gamma1, gamma2, epochs, device)
        wproxy = torch.einsum('i,ikmn,k,kjlb->ijmn', gamma2, w02, gamma1, w01)
        rmse = _fit_error(wproxy)
        print("LBFGS error:", rmse)
    else:
        n2, n1 = wt.size()
        mwider = w01.size(0)
        gamma1 = torch.randn(w01.size(0))
        print("mwider: ", mwider)
        print("m: ", m)
        add_dim = mwider-m
        gamma1 = gamma1/torch.sqrt(torch.sum(gamma1**2))
        gamma2 = torch.ones(n2)/np.sqrt(n2)
        bt = torch.ones(n2,device=device)
        # Both fixed layers are scaled by 100; the targets handed to the
        # fitting helpers below are scaled to match.
        w02 = w02*100
        w01 = w01*100
        # 1) fit scales for all mwider units
        gamma1 = fc_1to2layer_learn_GD_LBFGS_without_gamma2(100*wt, w01, 100*w02/np.sqrt(n2), gamma1, epochs, device)
        rmse = _fit_error(torch.einsum('i,im,m,mj->ij', gamma2, w02, gamma1, w01))
        print("Wider error:", rmse)
        # 2) representation of the extra units through the first m units
        if newFeat:
            GammaMat, feat_add = create_RF_parallel(w02/np.sqrt(n2),w01,m,epochs,device)
        else:
            # ndmin=2 keeps single-row / single-column files two-dimensional
            GammaMat = torch.tensor(np.loadtxt(filename + ".txt", delimiter=",", dtype=np.float64, ndmin=2),requires_grad=False, dtype=torch.float)
            gm, ga = GammaMat.size()
            if gm > m:
                # keep the first m rows, rescaled by sqrt(gm/m)
                GammaMat = GammaMat[:m,:]*np.sqrt(gm/m)
            elif gm < m:
                # Append m-gm rows: entry (r, c) is drawn uniformly from the
                # existing entries of column c. (Indexing GammaMat[ind] with
                # the 2-D index produced a 3-D tensor and torch.cat failed.)
                ind = torch.multinomial(torch.ones((m-gm,gm)), ga, replacement=True)
                GammaMat = torch.cat([GammaMat, torch.gather(GammaMat, 0, ind)], dim=0)*np.sqrt(gm/m)
            if ga > add_dim:
                GammaMat = GammaMat[:,:add_dim]
            if ga < add_dim:
                # new columns are filled with entries resampled from all of GammaMat
                distr = GammaMat.flatten()
                addGamma = distr[torch.multinomial(torch.ones(len(distr)), (add_dim-ga)*m, replacement=True)]
                GammaMat = torch.cat([GammaMat, addGamma.reshape((m,add_dim-ga))], dim=1)
            feat_add = torch.einsum('im,mj,ma->aij', w02[:,:m]/np.sqrt(n2), w01[:m,:], GammaMat)
        # 3) fold the extra units' scales into the first m units
        gamma1proxy = gamma1[:m].data + torch.matmul(GammaMat,gamma1[m:].data)
        rmse = _fit_error(torch.einsum('i,im,m,mj->ij', gamma2, w02[:,:m], gamma1proxy, w01[:m,:]))
        print("Proxy compression error:", rmse)
        # fine-tune the narrow network from the folded initialization
        gamma1, bias1, gamma2, bias2 = fc_1to2layer_learn_GD_LBFGS(wt, bt, w01[:m,:], w02[:,:m], act_scale, act_shift, gamma1proxy, gamma2, epochs, device)
        wproxy = torch.einsum('i,im,m,mj->ij', gamma2, w02[:,:m], gamma1, w01[:m,:])
        rmse = _fit_error(wproxy)
        print("LBFGS finetune error:", rmse)
        # 4) refit from a fresh start with the extra features available
        gamma1 = torch.randn(m)
        gamma1 = gamma1/torch.sqrt(torch.sum(gamma1**2))
        gamma1, gamma_add = fc_1to2layer_learn_GD_LBFGS_without_gamma2_wider(100*wt, w01[:m,:], 100/np.sqrt(n2)*w02[:,:m], gamma1, 100*feat_add, epochs, device)
        gamma1.data = gamma1.data + torch.matmul(GammaMat,gamma_add.data)
        gamma2 = torch.ones(n2)/np.sqrt(n2)
        rmse = _fit_error(torch.einsum('i,im,m,mj->ij', gamma2, w02[:,:m], gamma1, w01[:m,:]), warn_nan=True)
        print("Compression error:", rmse)
        # finetune from here
        gamma1, bias1, gamma2, bias2 = fc_1to2layer_learn_GD_LBFGS(wt, bt, w01[:m,:], w02[:,:m], act_scale, act_shift, gamma1, gamma2, epochs, device)
        wproxy = torch.einsum('i,im,m,mj->ij', gamma2, w02[:,:m], gamma1, w01[:m,:])
        rmse = _fit_error(wproxy)
        print("LBFGS error:", rmse)
    return rmse, wproxy, [w02, w01, gamma2, gamma1]

def fc_RF_parallel(w02, w01, m, epochs, device, njobs):
    """Express every surplus hidden unit of a wide FC layer pair via the first m units.

    For each extra hidden unit ``k`` (indices ``m .. w01.size(0)-1``) the
    rank-one matrix ``outer(w02[:, k], w01[k, :])`` is used as a target for
    ``fc_1to2layer_learn_GD_LBFGS_without_gamma2`` with the first ``m`` units
    as the basis. The fits run in parallel with ``njobs`` joblib workers.

    NOTE(review): ``Parallel``/``delayed`` must be in module scope; the joblib
    import in the file header is commented out, so confirm it is imported
    elsewhere in this file.

    Returns:
        GammaMat: (m, add_dim) tensor; column ``k`` holds the coefficients
            returned for extra unit ``m + k``.
        feat_add: (add_dim, n2, n1) tensor; slice ``k`` is
            ``einsum('m,im,mj->ij', GammaMat[:, k], w02[:, :m], w01[:m, :])``.
    """
    n_extra = w01.size(0) - m
    n_out = w02.size(0)
    n_in = w01.size(1)
    basis_out = w02[:, :m]
    basis_in = w01[:m, :]

    def fit_unit(col_out, row_in, base_out, base_in):
        target = torch.outer(col_out, row_in)
        coeffs = torch.randn(m) / np.sqrt(m)
        # rescale so that the target handed to the optimiser has mean |entry| == 10;
        # base_out gets the same factor, so coeffs still reproduce the unscaled target
        scale = 10 / torch.mean(torch.abs(target))
        coeffs = fc_1to2layer_learn_GD_LBFGS_without_gamma2(scale * target, base_in, scale * base_out, coeffs, epochs, device)
        approx = torch.einsum('m,im,mj->ij', coeffs.data, base_out, base_in)
        return coeffs.detach(), approx.detach()

    tasks = (delayed(fit_unit)(w02[:, m + k], w01[m + k, :], basis_out, basis_in) for k in range(n_extra))
    results = Parallel(n_jobs=njobs)(tasks)

    GammaMat = torch.zeros((m, n_extra))
    feat_add = torch.zeros((n_extra, n_out, n_in), requires_grad=False)
    for k, (coeffs, approx) in enumerate(results):
        GammaMat[:, k] = coeffs
        feat_add[k, :, :] = approx
    return GammaMat, feat_add

def conv_RF_parallel(w02, w01, m, epochs, device, njobs):
    """Convolutional counterpart of ``fc_RF_parallel``.

    For each extra hidden unit ``k`` (indices ``m .. w01.size(0)-1``) the
    target ``einsum('ikc,jpq->ijkc', w02[:, k], w01[k])`` is fitted by
    ``conv_1to2layer_learn_GD_LBFGS_without_gamma2`` using the first ``m``
    hidden units as the basis. The fits run in parallel with ``njobs`` joblib
    workers. The first-layer kernel indices ``p, q`` do not appear in the
    einsum outputs, so they are summed out, both in the target and in the
    reconstruction.

    NOTE(review): ``Parallel``/``delayed`` must be in module scope; the joblib
    import in the file header is commented out, so confirm it is imported
    elsewhere in this file.

    Returns:
        GammaMat: (m, add_dim) tensor of per-extra-unit coefficients.
        feat_add: (add_dim, n2, n1, w02.size(2)*w02.size(3)) tensor holding the
            reconstructions, with the two trailing kernel dims flattened.
    """
    n_extra = w01.size(0) - m
    n_out = w02.size(0)
    n_in = w01.size(1)
    kernel_numel = w02.size(2) * w02.size(3)
    basis_out = w02[:, :m, :, :]
    basis_in = w01[:m, :, :, :]

    def fit_unit(col_out, row_in, base_out, base_in):
        target = torch.einsum('ikc,jpq->ijkc', col_out, row_in)
        # random start on the unit sphere
        coeffs = torch.randn(m)
        coeffs = coeffs / torch.sqrt(torch.sum(coeffs ** 2))
        # rescale so that the target handed to the optimiser has mean |entry| == 1000
        scale = 1000 / torch.mean(torch.abs(target))
        coeffs = conv_1to2layer_learn_GD_LBFGS_without_gamma2(scale * target, base_in, scale * base_out, coeffs, epochs, device)
        approx = torch.einsum('m,imkc,mjpq->ijkc', coeffs.data, base_out, base_in)
        approx = approx.reshape([approx.size(0), approx.size(1), -1])
        return coeffs.detach(), approx.detach()

    tasks = (
        delayed(fit_unit)(w02[:, m + k, :, :], w01[m + k, :, :, :], basis_out, basis_in)
        for k in range(n_extra)
    )
    results = Parallel(n_jobs=njobs)(tasks)

    GammaMat = torch.zeros((m, n_extra))
    feat_add = torch.zeros((n_extra, n_out, n_in, kernel_numel), requires_grad=False)
    for k, (coeffs, approx) in enumerate(results):
        GammaMat[:, k] = coeffs
        feat_add[k, :, :, :] = approx
    return GammaMat, feat_add

def proxy_target_LBFGS_RF_parallel(wt, w01, w02, epochs, m, device, njobs):
    """Approximate target weights ``wt`` with rescaled products of two fixed random layers.

    Within this function only ``gamma1`` (one scale per hidden unit) and
    ``gamma2`` (one scale per output unit) are updated, through the
    ``*_learn_GD_LBFGS*`` helpers. ``w01``/``w02`` are only globally rescaled
    and, when there are surplus hidden units, truncated. If ``w01`` has more
    than ``m`` hidden units (``add_dim > 0``), the surplus units are first
    expressed through the first ``m`` units (``conv_RF_parallel`` /
    ``fc_RF_parallel``), and the resulting coefficients are folded back into
    ``gamma1`` before the final fit.

    Args:
        wt: target weights, unpacked as (n2, n1, k1, k2) when it has more than
            two dims (conv branch), otherwise as (n2, n1) (FC branch).
            NaN entries are set to 0 IN PLACE, so the caller's tensor changes.
        w01: first-layer random weights; dim 0 indexes hidden units.
        w02: second-layer random weights; dim 1 indexes hidden units.
        epochs: forwarded to the LBFGS helpers.
        m: number of hidden units used in the final proxy.
        device: forwarded to the helpers and used for the zero bias vector.
        njobs: number of joblib workers for the surplus-unit fits.

    Returns:
        (rmse, wproxy, [w02, w01, gamma2, gamma1]). ``wproxy`` is the
        reconstructed weight tensor. ``rmse`` is its RMS error against ``wt``;
        if that error is NaN, the RMS of ``wt`` itself is returned instead.
        ``w02``/``w01`` are the rescaled (and possibly truncated) layers that
        ``wproxy`` is built from.
    """
    # fixed seeds so repeated calls draw the same random initialisations
    random.seed(1)
    torch.manual_seed(1)
    # NOTE: modifies the caller's ``wt`` in place
    wt[torch.isnan(wt)] = 0
    # act_scale/act_shift are only forwarded to the *_learn_GD_LBFGS helpers below
    act_scale = 1
    act_shift = 5
    if len(wt.size()) > 2:
        # convolutional branch
        n2, n1, k1, k2 = wt.size()
        #mfull=n2*(n1*k1*k2-1)+1
        #mwider = n2*n1*k1*k2+500
        #m = min(m,n2*n1*k1*k2) #min(m,mfull)
        mwider = w01.size(0) #+ add_dim
        # number of surplus hidden units beyond the m that are kept
        add_dim = mwider-m
        #initialize trainable parameters: unit-norm random gamma1, uniform gamma2
        gamma1 = torch.randn(mwider)
        gamma1 = gamma1/torch.sqrt(torch.sum(gamma1**2))
        gamma2 = torch.ones(n2)/np.sqrt(n2)
        bt = torch.zeros(n2,device=device)
        # rescale both random layers so that (mean|w02| + mean|w01|)/2 == 1
        fact = 2/(torch.mean(torch.abs(w02)) + torch.mean(torch.abs(w01)))
        w02 = w02*fact
        w01 = w01*fact
        ###
        # from here on ``fact`` is a fixed gain applied to the target and to w02
        # in each helper call; wproxy below is built from the unscaled w02 and
        # compared to wt directly
        fact=100
        #gamma1 = conv_1to2layer_learn_GD_LBFGS_without_gamma2(fact*wt, w01, fact/np.sqrt(n2)*w02, gamma1, epochs, device)
        #wproxy = torch.einsum('i,ikmn,k,kjlb->ijmn', gamma2, w02, gamma1, w01)
        #rmse =  torch.sqrt(torch.mean((wt-wproxy)**2)).detach()
        #if torch.isnan(rmse):
        #    rmse = torch.sqrt(torch.mean(wt**2)).detach()
        #print("Wider error:", rmse)
        #represent additional random features
        if add_dim>0:
            # coefficients (GammaMat) and reconstructions (feat_add) of the surplus units
            GammaMat, feat_add = conv_RF_parallel(w02, w01, m, epochs, device, njobs)
            #gamma1.data = gamma1.data[:m] + torch.matmul(GammaMat,gamma1.data[m:])
            #wproxy = torch.einsum('imkc,m,mjlb->ijkc', w02[:,:m,:,:]/np.sqrt(n2), gamma1, w01[:m,:,:,:])
            #rmse =  torch.sqrt(torch.mean((wt-wproxy)**2)).detach()
            #if torch.isnan(rmse):
            #    print("na")
            #    rmse = torch.sqrt(torch.mean(wt**2)).detach()
            #print("Compression error proxy:", rmse)
            #gamma1 = torch.randn(m)
            #gamma1 = gamma1/torch.sqrt(torch.sum(gamma1**2))
            #compress: keep only the first m hidden units
            w02 = w02[:,:m,:,:]
            w01 = w01[:m,:,:,:]
            gamma2 = torch.ones(n2)/np.sqrt(n2)
            gamma1, gamma_add = conv_1to2layer_learn_GD_LBFGS_without_gamma2_wider(fact*wt, w01, fact/np.sqrt(n2)*w02, fact/np.sqrt(n2)*feat_add, epochs, device)
            #gamma2, gamma1, gamma_add = conv_1to2layer_learn_GD_LBFGS_wider(100*wt, w01[:m,:,:,:], 100*w02[:,:m,:,:], 100*feat_add, epochs, device)
            # fold the surplus-unit weights back into the m kept units
            gamma1.data = gamma1.data + torch.matmul(GammaMat,gamma_add.data)
            #wproxy = torch.einsum('imkc,m,mjlb->ijkc', w02/np.sqrt(n2), gamma1, w01)
            #rmse =  torch.sqrt(torch.mean((wt-wproxy)**2)).detach()
            #if torch.isnan(rmse):
            #    print("na")
            #    rmse = torch.sqrt(torch.mean(wt**2)).detach()
            #print("Compression error:", rmse)
            #retrain compressed network with better initialization
        # refine gamma1 with gamma2 fixed at 1/sqrt(n2), then refine gamma1 and gamma2 jointly
        gamma1 = conv_1to2layer_learn_GD_LBFGS_without_gamma2(fact*wt, w01, fact/np.sqrt(n2)*w02, gamma1, epochs, device)
        gamma1, bias1, gamma2, bias2 = conv_1to2layer_learn_GD_LBFGS(fact*wt, bt, w01, fact*w02, act_scale, act_shift, gamma1, gamma2, epochs, device)
        # first-layer kernel indices l, b are summed out by this einsum
        wproxy = torch.einsum('i,ikmn,k,kjlb->ijmn', gamma2, w02, gamma1, w01)
        rmse =  torch.sqrt(torch.mean((wt-wproxy)**2)).detach()
        if torch.isnan(rmse):
            # fall back to the error of the all-zero approximation
            rmse = torch.sqrt(torch.mean(wt**2)).detach()
        #print("LBFGS error:", rmse)
    else:
        # fully connected branch
        #print(100*wt[0,:15])
        n2, n1 = wt.size()
        #mfull=n2*(n1-1)+1
        #m = min(m,n2*n1)
        mwider = w01.size(0)
        #gamma1 = torch.randn(mwider)
        #print("mwider: ", mwider)
        #print("m: ", m)
        # number of surplus hidden units beyond the m that are kept
        add_dim = mwider-m
        gamma1 = torch.randn(mwider)
        gamma1 = gamma1/torch.sqrt(torch.sum(gamma1**2))
        gamma2 = torch.ones(n2)/np.sqrt(n2)
        #w02 = w02*100
        #w01 = w01*100
        # rescale both random layers so that (mean|w02| + mean|w01|)/2 == 1
        fact = 2/(torch.mean(torch.abs(w02)) + torch.mean(torch.abs(w01)))
        w02 = w02*fact
        w01 = w01*fact
        # fixed gain applied to the target and to w02 in each helper call
        fact=100
        #find representation of additional features
        if add_dim>0:
            GammaMat, feat_add = fc_RF_parallel(w02, w01, m, epochs, device, njobs)
            # keep only the first m hidden units
            w01 = w01[:m,:]
            w02 = w02[:,:m]
            gamma1 = torch.randn(m)
            gamma1 = gamma1/torch.sqrt(torch.sum(gamma1**2))
            gamma1, gamma_add = fc_1to2layer_learn_GD_LBFGS_without_gamma2_wider(fact*wt, w01[:m,:], fact/np.sqrt(n2)*w02[:,:m], gamma1, fact/np.sqrt(n2)*feat_add, epochs, device)
            gamma2 = torch.ones(n2)/np.sqrt(n2)

            #gamma2, gamma1, gamma_add = fc_1to2layer_learn_GD_LBFGS_wider(100*wt, w01[:m,:], 100*w02[:,:m], 100*feat_add, epochs, device)
            # fold the surplus-unit weights back into the m kept units
            gamma1.data = gamma1.data + torch.matmul(GammaMat,gamma_add.data)
            #rmse =  torch.sqrt(torch.mean((wt-torch.einsum('i,im,m,mj->ij', gamma2, w02[:,:m], gamma1, w01[:m,:]))**2)).detach()
            #if torch.isnan(rmse):
            #    print("na")
            #    rmse = torch.sqrt(torch.mean(wt**2)).detach()
            #print("Compression error:", rmse)

        #finetune from here: gamma1 with gamma2 fixed, then gamma1 and gamma2 jointly
        gamma1 = fc_1to2layer_learn_GD_LBFGS_without_gamma2(fact*wt, w01, fact*w02/np.sqrt(n2), gamma1, epochs, device)
        gamma1, _, gamma2, _ = fc_1to2layer_learn_GD_LBFGS(fact*wt, torch.zeros(n2,device=device), w01[:m,:], fact*w02[:,:m], act_scale, act_shift, gamma1, gamma2, epochs, device)
        wproxy = torch.einsum('i,im,m,mj->ij', gamma2, w02[:,:m], gamma1, w01[:m,:])
        rmse =  torch.sqrt(torch.mean((wt-wproxy)**2)).detach()
        if torch.isnan(rmse):
            # fall back to the error of the all-zero approximation
            rmse = torch.sqrt(torch.mean(wt**2)).detach()
        #print("LBFGS error:", rmse)
    return rmse, wproxy, [w02, w01, gamma2, gamma1]


def proxy_target_LBFGS_RF_parallel_permute_develop(wt, w01, w02, epochs, m, device, njobs, permute):
    """Development variant of the random-feature proxy fit with an optional row-permutation search.

    The convolutional branch (``wt`` with more than two dims) behaves like the
    conv branch of ``proxy_target_LBFGS_RF_parallel``; ``permute`` is ignored
    there.

    The fully connected branch (2-D ``wt``) proceeds as follows.
      1. Fit ``gamma1``/``gamma2`` on the first ``m`` units with the identity
         ordering, then print "Error identity".
      2. Print the error of ``proxy_target_init_flexible`` on the first
         ``mexact = min(full_dim(wt), m)`` units ("Error exact").
      3. If ``permute`` is truthy, search over row orderings of ``w02`` with
         ``rand_permute_parallel`` (``njobs`` candidates, ``njobs`` workers),
         print diagnostics, and reorder the rows of ``w02`` by ``out[1]``
         (presumably the best permutation found; confirm in
         ``rand_permute_parallel``). Then refit and print more diagnostics.
      4. Scale the rows of ``w02`` by the current ``gamma2`` (``w02add``).
         Express any surplus hidden units beyond ``m`` through the first ``m``
         (``fc_RF_parallel``) and fold them into ``gamma1``. Finally refine
         ``gamma1`` and ``gamma2`` with LBFGS.

    Args:
        wt: target weights; (n2, n1, k1, k2) for conv or (n2, n1) for FC.
            NaN entries are set to 0 IN PLACE, so the caller's tensor changes.
        w01: first-layer random weights; dim 0 indexes hidden units.
        w02: second-layer random weights; dim 1 indexes hidden units.
        epochs: forwarded to the LBFGS helpers.
        m: number of hidden units used in the final proxy.
        device: forwarded to the helpers and used for zero bias vectors.
        njobs: joblib workers (also the number of permutation candidates).
        permute: enable the permutation search (FC branch only).

    Returns:
        (rmse, wproxy, [w02, w01, gamma2, gamma1]). ``rmse`` falls back to the
        RMS of ``wt`` when the reconstruction error is NaN.
    """
    random.seed(1)
    torch.manual_seed(1)
    # NOTE: modifies the caller's ``wt`` in place
    wt[torch.isnan(wt)] = 0
    act_scale = 1
    act_shift = 5
    if len(wt.size()) > 2:
        n2, n1, k1, k2 = wt.size()
        mwider = w01.size(0)
        add_dim = mwider-m
        # unit-norm random gamma1, uniform gamma2
        gamma1 = torch.randn(mwider)
        gamma1 = gamma1/torch.sqrt(torch.sum(gamma1**2))
        gamma2 = torch.ones(n2)/np.sqrt(n2)
        bt = torch.zeros(n2,device=device)
        # rescale both random layers so that (mean|w02| + mean|w01|)/2 == 1
        fact = 2/(torch.mean(torch.abs(w02)) + torch.mean(torch.abs(w01)))
        w02 = w02*fact
        w01 = w01*fact
        # fixed gain applied to the target and to w02 in each helper call
        fact=100
        if add_dim>0:
            # represent the surplus hidden units through the first m ones
            GammaMat, feat_add = conv_RF_parallel(w02, w01, m, epochs, device, njobs)
            w02 = w02[:,:m,:,:]
            w01 = w01[:m,:,:,:]
            gamma2 = torch.ones(n2)/np.sqrt(n2)
            gamma1, gamma_add = conv_1to2layer_learn_GD_LBFGS_without_gamma2_wider(fact*wt, w01, fact/np.sqrt(n2)*w02, fact/np.sqrt(n2)*feat_add, epochs, device)
            gamma1.data = gamma1.data + torch.matmul(GammaMat,gamma_add.data)
        # retrain the compressed network from the better initialisation
        gamma1 = conv_1to2layer_learn_GD_LBFGS_without_gamma2(fact*wt, w01, fact/np.sqrt(n2)*w02, gamma1, epochs, device)
        gamma1, _, gamma2, _ = conv_1to2layer_learn_GD_LBFGS(fact*wt, bt, w01, fact*w02, act_scale, act_shift, gamma1, gamma2, epochs, device)
        wproxy = torch.einsum('i,ikmn,k,kjlb->ijmn', gamma2, w02, gamma1, w01)
        rmse =  torch.sqrt(torch.mean((wt-wproxy)**2)).detach()
        if torch.isnan(rmse):
            rmse = torch.sqrt(torch.mean(wt**2)).detach()
    else:
        n2, n1 = wt.size()
        mwider = w01.size(0)
        add_dim = mwider-m
        gamma1 = torch.randn(mwider)
        gamma1 = gamma1/torch.sqrt(torch.sum(gamma1**2))
        gamma2 = torch.ones(n2)/np.sqrt(n2)
        # rescale both random layers so that (mean|w02| + mean|w01|)/2 == 1
        fact = 2/(torch.mean(torch.abs(w02)) + torch.mean(torch.abs(w01)))
        w02 = w02*fact
        w01 = w01*fact
        fact=100
        # baseline: identity ordering of the output rows
        gamma1 = fc_1to2layer_learn_GD_LBFGS_without_gamma2(fact*wt, w01[:m,:], fact*w02[:,:m]/np.sqrt(n2), gamma1[:m], epochs, device)
        gamma1, _, gamma2, _ = fc_1to2layer_learn_GD_LBFGS(fact*wt, torch.zeros(n2,device=device), w01[:m,:], fact*w02[:,:m], act_scale, act_shift, gamma1, gamma2, epochs, device)
        rmse =  torch.sqrt(torch.mean((wt-torch.einsum('i,im,m,mj->ij', gamma2, w02[:,:m], gamma1[:m], w01[:m,:]))**2)).detach()
        print("Error identity:", rmse)
        mexact = min(full_dim(wt),m)
        rmse, _, _ = proxy_target_init_flexible(wt, w01[:mexact,:], w02[:,:mexact], epochs, mexact, "cpu")
        print("Error exact:", rmse)
        if permute:
            def exact_error(x):
                # error of the exact fit for a candidate ordering x of w02[:, :mexact]
                rmse, _, _ = proxy_target_init_flexible(wt, w01[:mexact,:], x, epochs, mexact, "cpu")
                return rmse
            out = rand_permute_parallel(w02[:,:mexact], exact_error, njobs, njobs)
            print(out)
            rmse, _, _ = proxy_target_init_flexible(wt, w01[:mexact,:], w02[out[1],:mexact], epochs, mexact, "cpu")
            print(rmse)
            print("rand perm LBFGS")

            def lbfgs_error(x):
                # error of the LBFGS fit for a candidate ordering x of w02[:, :m]
                gamma1 = torch.randn(mwider)
                gamma1 = gamma1/torch.sqrt(torch.sum(gamma1**2))
                gamma1 = fc_1to2layer_learn_GD_LBFGS_without_gamma2(fact*wt, w01[:m,:], fact/np.sqrt(n2)*x, gamma1[:m], epochs, device)
                gamma1, _, gamma2, _ = fc_1to2layer_learn_GD_LBFGS(fact*wt, torch.zeros(n2,device=device), w01[:m,:], fact*x, act_scale, act_shift, gamma1, torch.ones(n2)/np.sqrt(n2), epochs, device)
                rmse =  torch.sqrt(torch.mean((wt-torch.einsum('i,im,m,mj->ij', gamma2, x, gamma1[:m], w01[:m,:]))**2)).detach()
                return rmse

            out = rand_permute_parallel(w02[:,:m], lbfgs_error, njobs, njobs)
            print(out)
            w02 = w02[out[1],:]
            rmse =  torch.sqrt(torch.mean((wt-torch.einsum('i,im,m,mj->ij', gamma2, w02[:,:m], gamma1[:m], w01[:m,:]))**2)).detach()
            if torch.isnan(rmse):
                rmse = torch.sqrt(torch.mean(wt**2)).detach()
            print("Wider error:", rmse)
            gamma1 = torch.randn(mwider)
            gamma1 = gamma1/torch.sqrt(torch.sum(gamma1**2))
            gamma2 = torch.ones(n2)/np.sqrt(n2)
            gamma1 = fc_1to2layer_learn_GD_LBFGS_without_gamma2(fact*wt, w01[:m,:], fact/np.sqrt(n2)*w02[:,:m], gamma1[:m], epochs, device)
            gamma1, _, gamma2, _ = fc_1to2layer_learn_GD_LBFGS(fact*wt, torch.zeros(n2,device=device), w01[:m,:], fact*w02[:,:m], act_scale, act_shift, gamma1[:m], gamma2, epochs, device)
            rmse =  torch.sqrt(torch.mean((wt-torch.einsum('i,im,m,mj->ij', gamma2, w02[:,:m], gamma1[:m], w01[:m,:]))**2)).detach()
            print("Wider error with g2:", rmse)
            # exact fit after permutation
            rmse, _, _ = proxy_target_init_flexible(wt, w01[:mexact,:], w02[:,:mexact], epochs, mexact, "cpu")
            print("Error exact permuted:", rmse)
        # Fix: w02add used to be assigned only inside ``if permute:``, so permute=False
        # raised UnboundLocalError below. Build it from the current gamma2/w02 on both paths.
        w02add = torch.einsum('i,ij->ij', gamma2, w02)
        if add_dim>0:
            # represent the surplus hidden units through the first m ones
            GammaMat, feat_add = fc_RF_parallel(w02add, w01, m, epochs, device, njobs)
            w01 = w01[:m,:]
            w02 = w02[:,:m]
            w02add = w02add[:,:m]
            gamma1 = torch.randn(m)
            gamma1 = gamma1/torch.sqrt(torch.sum(gamma1**2))
            gamma1, gamma_add = fc_1to2layer_learn_GD_LBFGS_without_gamma2_wider(fact*wt, w01[:m,:], fact*w02add[:,:m], gamma1[:m], fact*feat_add, epochs, device)
            # fold the surplus-unit weights back into the m kept units
            gamma1.data = gamma1.data + torch.matmul(GammaMat,gamma_add.data)
            rmse =  torch.sqrt(torch.mean((wt-torch.einsum('im,m,mj->ij', w02add, gamma1, w01))**2)).detach()
            print("Compression error:", rmse)

        # finetune from here
        gamma1 = fc_1to2layer_learn_GD_LBFGS_without_gamma2(fact*wt, w01, fact*w02add, gamma1, epochs, device)
        gamma1, _, gamma2, _ = fc_1to2layer_learn_GD_LBFGS(fact*wt, torch.zeros(n2,device=device), w01[:m,:], fact*w02[:,:m], act_scale, act_shift, gamma1, gamma2, epochs, device)
        wproxy = torch.einsum('i,im,m,mj->ij', gamma2, w02[:,:m], gamma1, w01[:m,:])
        rmse =  torch.sqrt(torch.mean((wt-wproxy)**2)).detach()
        if torch.isnan(rmse):
            rmse = torch.sqrt(torch.mean(wt**2)).detach()
    return rmse, wproxy, [w02, w01, gamma2, gamma1]

def proxy_target_LBFGS_precond_fc(wt, w01, w02, epochs, device, precond):
    """Fit the FC proxy ``gamma2 * w02 * gamma1 * w01`` to ``wt`` and report its RMSE.

    Both random layers are first rescaled so that the average of their mean
    absolute entries is 1. ``gamma1`` starts as a random unit vector and is
    then initialised in one of two ways:
      * ``precond`` truthy: ``fc_1to2layer_learn_GD_LBFGS_cond_expensive``,
        which also supplies the starting ``gamma2`` and a condition number;
      * otherwise: ``fc_1to2layer_learn_GD_LBFGS_without_gamma2`` with
        ``gamma2`` fixed at ``1/sqrt(n2)``, and the condition number reported
        as ``tensor([0])``.
    Both paths finish with a joint refinement by ``fc_1to2layer_learn_GD_LBFGS``.
    No NaN handling is applied to ``wt`` or to the returned ``rmse`` here.

    Returns:
        (rmse, condNumber, wproxy), where ``wproxy`` is built from the
        rescaled layers and the learned gammas.
    """
    layer_norm = 2 / (torch.mean(torch.abs(w02)) + torch.mean(torch.abs(w01)))
    w02 = w02 * layer_norm
    w01 = w01 * layer_norm
    # gain applied to the target and to w02 before fitting; wproxy below uses the
    # unscaled w02 and is compared to wt directly
    gain = 100
    n2 = w02.size(0)
    act_scale, act_shift = 1, 0

    gamma1 = torch.randn(w02.size(1))
    gamma1 = gamma1 / torch.sqrt(torch.sum(gamma1 ** 2))

    if not precond:
        gamma1 = fc_1to2layer_learn_GD_LBFGS_without_gamma2(gain * wt, w01, gain / np.sqrt(n2) * w02, gamma1, epochs, device)
        gamma2_start = torch.ones(n2) / np.sqrt(n2)
        condNumber = torch.tensor([0])
    else:
        gamma1, gamma2_start, condNumber = fc_1to2layer_learn_GD_LBFGS_cond_expensive(wt, w01, w02, gamma1, epochs, device, precond)

    gamma1, _, gamma2, _ = fc_1to2layer_learn_GD_LBFGS(gain * wt, torch.zeros(n2, device=device), w01, gain * w02, act_scale, act_shift, gamma1, gamma2_start, epochs, device)

    wproxy = torch.einsum('i,im,m,mj->ij', gamma2, w02, gamma1, w01)
    rmse = torch.sqrt(torch.mean((wt - wproxy) ** 2)).detach()
    return rmse, condNumber, wproxy

def proxy_target_LBFGS_RF_parallel_permute(wt, w01, w02, epochs, m, device, nperm, njobs):
    """Random-feature proxy fit with an optional random search over output-row orderings.

    Approximates ``wt`` by ``gamma2 * w02 * gamma1 * w01``, with one scale per
    output unit (``gamma2``) and one per hidden unit (``gamma1``), using the
    first ``m`` hidden units of the globally rescaled random layers.

    If ``nperm > 0``, ``rand_permute_parallel`` evaluates the LBFGS fit for
    ``nperm`` candidate row orderings of ``w02[:, :m]`` with ``njobs`` workers.
    ``out[0]`` is used as the fit's (rmse, wproxy, gamma1, gamma2) tuple and
    ``out[1]`` as the row index that reorders ``w02`` (presumably the best
    permutation found; confirm in ``rand_permute_parallel``). If ``w01`` has
    more than ``m`` hidden units (``add_dim > 0``), the surplus units are then
    expressed through the first ``m`` (``conv_RF_parallel`` /
    ``fc_RF_parallel``), folded into ``gamma1``, and the fit is refined.

    Args:
        wt: target weights; (n2, n1, k1, k2) for conv or (n2, n1) for FC.
            NaN entries are set to 0 IN PLACE, so the caller's tensor changes.
        w01: first-layer random weights; dim 0 indexes hidden units.
        w02: second-layer random weights; dim 1 indexes hidden units.
        epochs: forwarded to the LBFGS helpers.
        m: number of hidden units used in the final proxy.
        device: forwarded to the helpers and used for zero bias vectors.
        nperm: number of candidate orderings (0 disables the search).
        njobs: number of joblib workers.

    Returns:
        (rmse, wproxy, [w02, w01, gamma2, gamma1]).
    """
    random.seed(1)
    torch.manual_seed(1)
    # NOTE: modifies the caller's ``wt`` in place
    wt[torch.isnan(wt)] = 0
    act_scale = 1
    act_shift = 5
    if len(wt.size()) > 2:
        n2, n1, k1, k2 = wt.size()
        mwider = w01.size(0)
        add_dim = mwider-m
        gamma1 = torch.randn(mwider)
        gamma1 = gamma1/torch.sqrt(torch.sum(gamma1**2))
        gamma2 = torch.ones(n2)/np.sqrt(n2)
        bt = torch.zeros(n2,device=device)
        # rescale both random layers so that (mean|w02| + mean|w01|)/2 == 1
        fact = 2/(torch.mean(torch.abs(w02)) + torch.mean(torch.abs(w01)))
        w02 = w02*fact
        w01 = w01*fact
        # fixed gain applied to the target and to w02 in each helper call
        fact=100

        def func(x):
            # Fit gamma1/gamma2 for a candidate second layer x (a row ordering of w02[:, :m]).
            # Fix: the original version ignored x and always fitted the unpermuted
            # w02[:, :m], so the search was meaningless and the returned gammas did not
            # match the permuted w02. The FC branch below already uses x this way.
            gamma1 = torch.randn(m)
            gamma1 = 0.1*gamma1/torch.sqrt(torch.sum(gamma1**2))
            gamma2 = torch.ones(n2)/np.sqrt(n2)
            gamma1 = conv_1to2layer_learn_GD_LBFGS_without_gamma2(fact*wt, w01[:m,:,:,:], fact/np.sqrt(n2)*x, gamma1[:m], epochs, device)
            gamma1, _, gamma2, _ = conv_1to2layer_learn_GD_LBFGS(fact*wt, bt, w01[:m,:,:,:], fact*x, act_scale, act_shift, gamma1, gamma2, epochs, device)
            wproxy = torch.einsum('i,ikmn,k,kjlb->ijmn', gamma2, x, gamma1, w01[:m,:,:,:])
            rmse =  torch.sqrt(torch.mean((wt-wproxy)**2)).detach()
            if torch.isnan(rmse):
                rmse = torch.sqrt(torch.mean(wt**2)).detach()
            return rmse, wproxy, gamma1, gamma2

        if nperm > 0:
            out = rand_permute_parallel(w02[:,:m,:,:], func, nperm, njobs)
            w02 = w02[out[1],:,:,:]
            rmse = out[0][0]
            wproxy = out[0][1]
            gamma1 = out[0][2]
            gamma2 = out[0][3]
        elif add_dim <= 0:
            # Fix: with no permutation search and no surplus units, nothing assigned
            # rmse/wproxy, so the return raised UnboundLocalError. Fit the identity ordering.
            rmse, wproxy, gamma1, gamma2 = func(w02[:,:m,:,:])
        if add_dim>0:
            w02add = torch.einsum('i,ijkl->ijkl', gamma2, w02)
            GammaMat, feat_add = conv_RF_parallel(w02add, w01, m, epochs, device, njobs)
            # compress: keep only the first m hidden units
            w02add = w02add[:,:m,:,:]
            w02 = w02[:,:m,:,:]
            w01 = w01[:m,:,:,:]
            gamma1, gamma_add = conv_1to2layer_learn_GD_LBFGS_without_gamma2_wider(fact*wt, w01, fact*w02add, fact*feat_add, epochs, device)
            # fold the surplus-unit weights back into the m kept units
            gamma1.data = gamma1.data + torch.matmul(GammaMat,gamma_add.data)
            # retrain the compressed network from the better initialisation
            gamma1 = conv_1to2layer_learn_GD_LBFGS_without_gamma2(fact*wt, w01, fact*w02add, gamma1, epochs, device)
            gamma1, _, gamma2, _ = conv_1to2layer_learn_GD_LBFGS(fact*wt, bt, w01, fact*w02, act_scale, act_shift, gamma1, gamma2, epochs, device)
            wproxy = torch.einsum('i,ikmn,k,kjlb->ijmn', gamma2, w02, gamma1, w01)
            rmse =  torch.sqrt(torch.mean((wt-wproxy)**2)).detach()
            if torch.isnan(rmse):
                rmse = torch.sqrt(torch.mean(wt**2)).detach()
    else:
        n2, n1 = wt.size()
        mwider = w01.size(0)
        add_dim = mwider-m
        gamma1 = torch.randn(mwider)
        gamma1 = gamma1/torch.sqrt(torch.sum(gamma1**2))
        gamma2 = torch.ones(n2)/np.sqrt(n2)
        # rescale both random layers so that (mean|w02| + mean|w01|)/2 == 1
        fact = 2/(torch.mean(torch.abs(w02)) + torch.mean(torch.abs(w01)))
        w02 = w02*fact
        w01 = w01*fact
        fact=100
        if nperm > 0:
            def func(x):
                # fit gamma1/gamma2 for a candidate second layer x (a row ordering of w02[:, :m])
                gamma1 = torch.randn(m)
                gamma1 = 0.1*gamma1/torch.sqrt(torch.sum(gamma1**2))
                gamma1 = fc_1to2layer_learn_GD_LBFGS_without_gamma2(fact*wt, w01[:m,:], fact/np.sqrt(n2)*x, gamma1[:m], epochs, device)
                gamma1, _, gamma2, _ = fc_1to2layer_learn_GD_LBFGS(fact*wt, torch.zeros(n2,device=device), w01[:m,:], fact*x, act_scale, act_shift, gamma1, torch.ones(n2)/np.sqrt(n2), epochs, device)
                wproxy = torch.einsum('i,im,m,mj->ij', gamma2, x, gamma1[:m], w01[:m,:])
                rmse =  torch.sqrt(torch.mean((wt-wproxy)**2)).detach()
                return rmse, wproxy, gamma1, gamma2

            out = rand_permute_parallel(w02[:,:m], func, nperm, njobs)
            w02 = w02[out[1],:]
            rmse = out[0][0]
            wproxy = out[0][1]
            gamma1 = out[0][2]
            gamma2 = out[0][3]
        else:
            # identity ordering
            gamma1 = torch.randn(mwider)
            gamma1 = gamma1/torch.sqrt(torch.sum(gamma1**2))
            gamma1 = fc_1to2layer_learn_GD_LBFGS_without_gamma2(fact*wt, w01[:m,:], fact/np.sqrt(n2)*w02[:,:m], gamma1[:m], epochs, device)
            gamma1, _, gamma2, _ = fc_1to2layer_learn_GD_LBFGS(fact*wt, torch.zeros(n2,device=device), w01[:m,:], fact*w02[:,:m], act_scale, act_shift, gamma1, torch.ones(n2)/np.sqrt(n2), epochs, device)
            wproxy = torch.einsum('i,im,m,mj->ij', gamma2, w02[:,:m], gamma1[:m], w01[:m,:])
            rmse =  torch.sqrt(torch.mean((wt-wproxy)**2)).detach()

        if add_dim>0:
            w02add = torch.einsum('i,ij->ij', gamma2, w02)
            GammaMat, feat_add = fc_RF_parallel(w02add, w01, m, epochs, device, njobs)
            # keep only the first m hidden units
            w01 = w01[:m,:]
            w02 = w02[:,:m]
            w02add = w02add[:,:m]
            gamma1 = torch.randn(m)
            gamma1 = gamma1/torch.sqrt(torch.sum(gamma1**2))
            gamma1, gamma_add = fc_1to2layer_learn_GD_LBFGS_without_gamma2_wider(fact*wt, w01[:m,:], fact*w02add[:,:m], gamma1[:m], fact*feat_add, epochs, device)
            # fold the surplus-unit weights back into the m kept units
            gamma1.data = gamma1.data + torch.matmul(GammaMat,gamma_add.data)

            # finetune from here
            gamma1 = fc_1to2layer_learn_GD_LBFGS_without_gamma2(fact*wt, w01, fact*w02add, gamma1, epochs, device)
            gamma1, _, gamma2, _ = fc_1to2layer_learn_GD_LBFGS(fact*wt, torch.zeros(n2,device=device), w01[:m,:], fact*w02[:,:m], act_scale, act_shift, gamma1, gamma2, epochs, device)
            wproxy = torch.einsum('i,im,m,mj->ij', gamma2, w02[:,:m], gamma1, w01[:m,:])
            rmse =  torch.sqrt(torch.mean((wt-wproxy)**2)).detach()
            if torch.isnan(rmse):
                rmse = torch.sqrt(torch.mean(wt**2)).detach()
    return rmse, wproxy, [w02, w01, gamma2, gamma1]

def proxy_target_LBFGS_init(wt, w01, w02, gamma1, M, C, epochs, m, device):
    """Fit scale vectors gamma1/gamma2 so that a fixed two-layer factorization approximates ``wt``.

    Seeds ``random`` and ``torch`` with 1 and zeroes NaN entries of ``wt`` in place (the
    caller's tensor is modified). ``w01`` and ``w02`` are rescaled by
    ``2 / (mean|w02| + mean|w01|)``; the rescaled tensors are the ones returned.

    2-D ``wt`` (``(n2, n1)``; the einsums below imply ``w02`` is ``(n2, w01.size(0))`` and
    ``w01`` is ``(w01.size(0), n1)``): all ``w01.size(0)`` hidden units are kept. A random
    unit-norm ``gamma1`` is drawn, ``gamma2`` is set row-wise to the least-squares scale
    ``<r_i, wt_i> / ||r_i||^2`` with ``r = w02 @ diag(gamma1) @ w01``, and both are then
    refined with ``fc_1to2layer_learn_GD_LBFGS``.

    4-D ``wt``: when ``add_dim = w01.size(0) - m > 0`` the extra hidden units are expressed
    through the first ``m`` (``conv_RF_parallel``), the factorization is truncated to ``m``
    units and the scales are refined with the conv L-BFGS helpers.

    NOTE(review): the ``gamma1``, ``M`` and ``C`` arguments are not used (``gamma1`` is
    overwritten immediately; ``M``/``C`` only appear in commented-out code).
    NOTE(review): ``nperm`` is hard-coded to 0, so both ``if nperm > 0`` branches are dead.
    NOTE(review): for 4-D ``wt`` with ``add_dim <= 0`` neither ``rmse`` nor ``wproxy`` is
    assigned, so the final ``return`` raises UnboundLocalError.

    Returns:
        ``(rmse, wproxy, [w02, w01, gamma2, gamma1])`` where ``rmse`` is the root mean square
        of ``wt - wproxy`` (4-D path: replaced by the RMS of ``wt`` if it is NaN).
    """
    random.seed(1)
    torch.manual_seed(1)
    # Zero NaNs in place; this mutates the tensor the caller passed in.
    wt[torch.isnan(wt)] = 0
    act_scale = 1
    act_shift = 5
    # Hard-coded: random-permutation search disabled, single job.
    nperm=0
    njobs=1
    if len(wt.size()) > 2:
        n2, n1, k1, k2 = wt.size()
        mwider = w01.size(0) #+ add_dim
        # Number of hidden units beyond the m that are kept in the end.
        add_dim = mwider-m
        #initialize trainable parameters in a trainable way
        gamma1 = torch.randn(mwider)
        gamma1 = gamma1/torch.sqrt(torch.sum(gamma1**2))
        gamma2 = torch.ones(n2)/np.sqrt(n2)
        bt = torch.zeros(n2,device=device)
        # Rescale both weight tensors by a common factor based on their mean absolute values.
        fact = 2/(torch.mean(torch.abs(w02)) + torch.mean(torch.abs(w01)))
        w02 = w02*fact
        w01 = w01*fact
        ###
        # From here on ``fact`` is the factor applied to wt/w02 in the L-BFGS calls.
        fact=100
        #represent additional random features
        if nperm > 0:
            # Dead code while nperm == 0.
            # NOTE(review): func ignores its argument ``x`` and always uses w02[:,:m,:,:],
            # so the permutation search could not have any effect here.
            def func(x):
                gamma1 = torch.randn(m) #torch.randn(mwider)
                gamma1 = 0.1*gamma1/torch.sqrt(torch.sum(gamma1**2))
                gamma2 = torch.ones(n2)/np.sqrt(n2)
                gamma1 = conv_1to2layer_learn_GD_LBFGS_without_gamma2(fact*wt, w01[:m,:,:,:], fact/np.sqrt(n2)*w02[:,:m,:,:], gamma1[:m], epochs, device)
                gamma1, _, gamma2, _ = conv_1to2layer_learn_GD_LBFGS(fact*wt, bt, w01[:m,:,:,:], fact*w02[:,:m,:,:], act_scale, act_shift, gamma1, gamma2, epochs, device)
                wproxy = torch.einsum('i,ikmn,k,kjlb->ijmn', gamma2, w02[:,:m,:,:], gamma1, w01[:m,:,:,:])
                rmse =  torch.sqrt(torch.mean((wt-wproxy)**2)).detach()
                if torch.isnan(rmse):
                    rmse = torch.sqrt(torch.mean(wt**2)).detach()
                return rmse, wproxy, gamma1, gamma2
            out = rand_permute_parallel(w02[:,:m,:,:], func, nperm, njobs)
            #print(out[0][0])
            w02 = w02[out[1],:,:,:]
            rmse = out[0][0]
            wproxy = out[0][1]
            gamma1 = out[0][2]
            gamma2 = out[0][3]
        if add_dim>0:
            # Scale the rows of w02 by gamma2 before expressing the extra units.
            w02add = torch.einsum('i,ijkl->ijkl', gamma2, w02)
            GammaMat, feat_add = conv_RF_parallel(w02add, w01, m, epochs, device, njobs)
            #compress
            w02add = w02add[:,:m,:,:]
            w02 = w02[:,:m,:,:]
            w01 = w01[:m,:,:,:]
            gamma1, gamma_add = conv_1to2layer_learn_GD_LBFGS_without_gamma2_wider(fact*wt, w01, fact*w02add, fact*feat_add, epochs, device)
            #gamma2, gamma1, gamma_add = conv_1to2layer_learn_GD_LBFGS_wider(100*wt, w01[:m,:,:,:], 100*w02[:,:m,:,:], 100*feat_add, epochs, device)
            # Fold the coefficients of the extra units back onto the first m units.
            gamma1.data = gamma1.data + torch.matmul(GammaMat,gamma_add.data)
            gamma1 = conv_1to2layer_learn_GD_LBFGS_without_gamma2(fact*wt, w01, fact*w02add, gamma1, epochs, device)
            gamma1, _, gamma2, _ = conv_1to2layer_learn_GD_LBFGS(fact*wt, bt, w01, fact*w02, act_scale, act_shift, gamma1, gamma2, epochs, device)
            # Proxy kernel; w01 is summed over its last two axes by this contraction.
            wproxy = torch.einsum('i,ikmn,k,kjlb->ijmn', gamma2, w02, gamma1, w01)
            rmse =  torch.sqrt(torch.mean((wt-wproxy)**2)).detach()
            if torch.isnan(rmse):
                # Fallback: RMS of the target itself.
                rmse = torch.sqrt(torch.mean(wt**2)).detach()
        #print("LBFGS error:", rmse)
    else:
        #print(100*wt[0,:15])
        n2, n1 = wt.size()
        mwider = w01.size(0)
        # Unused in this branch: all mwider hidden units are kept.
        add_dim = mwider-m
        gamma1 = torch.randn(mwider)
        gamma1 = gamma1/torch.sqrt(torch.sum(gamma1**2))
        gamma2 = torch.ones(n2)/np.sqrt(n2)
        #w02 = w02*100
        #w01 = w01*100
        fact = 2/(torch.mean(torch.abs(w02)) + torch.mean(torch.abs(w01)))
        #fact=1
        w02 = w02*fact
        w01 = w01*fact
        fact=100
        if nperm > 0:
            # Dead code while nperm == 0.
            def func(x):
                gamma1 = torch.randn(m) #torch.randn(mwider)
                gamma1 = 0.1*gamma1/torch.sqrt(torch.sum(gamma1**2))
                gamma1 = fc_1to2layer_learn_GD_LBFGS_without_gamma2(fact*wt, w01[:m,:], fact/np.sqrt(n2)*x, gamma1[:m], epochs, device)
                gamma1, _, gamma2, _ = fc_1to2layer_learn_GD_LBFGS(fact*wt, torch.zeros(n2,device=device), w01[:m,:], fact*x, act_scale, act_shift, gamma1, torch.ones(n2)/np.sqrt(n2), epochs, device)
                wproxy = torch.einsum('i,im,m,mj->ij', gamma2, x, gamma1[:m], w01[:m,:])
                rmse =  torch.sqrt(torch.mean((wt-wproxy)**2)).detach()
                return rmse, wproxy, gamma1, gamma2

            out = rand_permute_parallel(w02[:,:m], func, nperm, njobs)
            w02 = w02[out[1],:]
            rmse = out[0][0]
            wproxy = out[0][1]
            gamma1 = out[0][2]
            gamma2 = out[0][3]
        else:
            #gamma1 = torch.randn(mwider)
            #gamma1 = gamma1/torch.sqrt(torch.sum(gamma1**2))
            #finetuning
            #gamma1 = fc_1to2layer_learn_GD_LBFGS_M(wt, M, gamma1, epochs, device)
            #gamma1 = C@gamma1
            # r = w02 @ diag(gamma1) @ w01, then gamma2_i = <r_i, wt_i> / ||r_i||^2 (row-wise least squares).
            rij = torch.einsum('im,m,mj->ij', w02, gamma1, w01)
            gamma2 = torch.sum(rij*wt,dim=1)/torch.sum(rij**2,dim=1)
            #gamma1 = fc_1to2layer_learn_GD_LBFGS_M(wt, M, C, gamma1, gamma2, epochs, device)
            #gamma1 = fc_1to2layer_learn_GD_LBFGS_without_gamma2(fact*wt, w01, fact*w02, gamma1, epochs, device)
            gamma1, _, gamma2, _ = fc_1to2layer_learn_GD_LBFGS(fact*wt, torch.zeros(n2,device=device), w01, fact*w02, act_scale, act_shift, gamma1, gamma2, epochs, device)
            wproxy = torch.einsum('i,im,m,mj->ij', gamma2, w02, gamma1, w01)
            # NOTE(review): unlike the 4-D path, there is no NaN fallback for rmse here.
            rmse =  torch.sqrt(torch.mean((wt-wproxy)**2)).detach()

    return rmse, wproxy, [w02, w01, gamma2, gamma1]

def conv_RF_parallel_accurate(w02, w01, m, epochs, device, njobs):
    """Express every extra hidden unit (index >= m) of a conv factorization through the first m units.

    For each extra unit ``j`` in ``range(m, w01.size(0))`` the kernel
    ``einsum('ikc,jpq->ijkc', w02[:, j], w01[j])`` is approximated by
    ``einsum('m,imkc,mjpq->ijkc', g, w02[:, :m], w01[:m])``. ``g`` is initialised by
    ``conv_1to2layer_project_gamma1`` and refined with
    ``conv_1to2layer_learn_GD_LBFGS_without_gamma2``, both on a copy of the problem scaled
    by ``1000 / mean|target|``. The units are processed with joblib-style ``Parallel``.

    Returns:
        GammaMat: ``(m, add_dim)`` tensor; column ``i`` holds ``g`` for extra unit ``m + i``.
        feat_add: ``(add_dim, w02.size(0), w01.size(1), w02.size(2) * w02.size(3))`` tensor;
            entry ``i`` is the fitted kernel for extra unit ``m + i`` with its last two axes
            flattened.
    """
    total_units = w01.size(0)
    n_extra = total_units - m
    n_out = w02.size(0)
    n_in = w01.size(1)

    def _fit_extra_unit(w02_col, w01_row, w02_base, w01_base):
        target = torch.einsum('ikc,jpq->ijkc', w02_col, w01_row)
        # The original drew a random start here that is never used; the draw is kept so the
        # global torch RNG advances exactly as before.
        torch.randn(m)
        scale = 1000/torch.mean(torch.abs(target))
        coeffs = conv_1to2layer_project_gamma1(scale*target, w01_base, scale*w02_base, torch.ones(target.size(0)), device)
        coeffs = conv_1to2layer_learn_GD_LBFGS_without_gamma2(scale*target, w01_base, scale*w02_base, coeffs, epochs, device)
        fitted = torch.einsum('m,imkc,mjpq->ijkc', coeffs.data, w02_base, w01_base)
        fitted = fitted.reshape([fitted.size(0), fitted.size(1), -1])
        return coeffs.detach(), fitted.detach()

    jobs = (
        delayed(_fit_extra_unit)(w02[:, unit, :, :], w01[unit, :, :, :], w02[:, :m, :, :], w01[:m, :, :, :])
        for unit in range(m, total_units)
    )
    results = Parallel(n_jobs=njobs)(jobs)

    GammaMat = torch.zeros((m, n_extra))
    feat_add = torch.zeros((n_extra, n_out, n_in, w02.size(2)*w02.size(3)), requires_grad=False)
    for col, (coeffs, fitted) in enumerate(results):
        GammaMat[:, col] = coeffs
        feat_add[col, :, :, :] = fitted
    return GammaMat, feat_add

def fc_RF_parallel_accurate(w02, w01, m, epochs, device, njobs):
    """Express every extra hidden unit (index >= m) of a dense factorization through the first m units.

    For each extra unit ``j`` in ``range(m, w01.size(0))`` the rank-one matrix
    ``outer(w02[:, j], w01[j])`` is approximated by ``w02[:, :m] @ diag(g) @ w01[:m]``.
    ``g`` is initialised by ``fc_1to2layer_exact_project_without_gamma2`` and refined with
    ``fc_1to2layer_learn_GD_LBFGS_without_gamma2``, both on a copy of the problem scaled
    by ``10 / mean|target|``. The units are processed with joblib-style ``Parallel``.

    Returns:
        GammaMat: ``(m, add_dim)`` tensor; column ``i`` holds ``g`` for extra unit ``m + i``.
        feat_add: ``(add_dim, w02.size(0), w01.size(1))`` tensor of the fitted matrices.
    """
    total_units = w01.size(0)
    n_extra = total_units - m
    n_out = w02.size(0)
    n_in = w01.size(1)

    def _fit_extra_unit(w02_col, w01_row, w02_base, w01_base):
        target = torch.outer(w02_col, w01_row)
        # Unused random start from the original; the draw is kept so the global torch RNG
        # advances exactly as before.
        torch.randn(m)
        scale = 10/torch.mean(torch.abs(target))
        coeffs = fc_1to2layer_exact_project_without_gamma2(scale*target, w01_base, scale*w02_base, device)
        coeffs = fc_1to2layer_learn_GD_LBFGS_without_gamma2(scale*target, w01_base, scale*w02_base, coeffs, epochs, device)
        fitted = torch.einsum('m,im,mj->ij', coeffs.data, w02_base, w01_base)
        return coeffs.detach(), fitted.detach()

    jobs = (
        delayed(_fit_extra_unit)(w02[:, unit], w01[unit, :], w02[:, :m], w01[:m, :])
        for unit in range(m, total_units)
    )
    results = Parallel(n_jobs=njobs)(jobs)

    GammaMat = torch.zeros((m, n_extra))
    feat_add = torch.zeros((n_extra, n_out, n_in), requires_grad=False)
    for col, (coeffs, fitted) in enumerate(results):
        GammaMat[:, col] = coeffs
        feat_add[col, :, :] = fitted
    return GammaMat, feat_add

def proxy_target_accurate_permute(wt, w01, w02, epochs, m, device, nperm, njobs):
    """Fit scale vectors gamma1/gamma2 so that a fixed two-layer factorization approximates ``wt``.

    The 2-D proxy is ``wproxy = diag(gamma2) @ w02 @ diag(gamma1) @ w01``. For 4-D ``wt`` the
    proxy is the analogous contraction ``einsum('i,ikmn,k,kjlb->ijmn', ...)``.

    Steps:
      1. Seed ``random`` and ``torch`` with 1 and zero NaN entries of ``wt`` in place (the
         caller's tensor is modified).
      2. Rescale ``w01`` and ``w02`` by ``2 / (mean|w02| + mean|w01|)``.
      3. Fit the first ``mexact = min(full_dim(wt), m)`` hidden units with
         ``proxy_target_init_flexible``. When ``nperm > 0`` this is searched over random row
         permutations of ``w02`` via ``rand_permute_parallel``, and the returned permutation
         is applied to ``w02`` once.
      4. If ``w01`` has more than ``m`` hidden units, the extra units are expressed through
         the first ``m`` (``*_RF_parallel_accurate``). The factorization is then truncated to
         ``m`` units and gamma1/gamma2 are refined with the L-BFGS helpers.

    Returns:
        ``(rmse, wproxy, [w02, w01, gamma2, gamma1])``. After step 4, ``rmse`` is the root mean
        square of ``wt - wproxy`` (replaced by the RMS of ``wt`` if NaN). Without step 4, the
        values come from step 3 and ``w01``/``w02`` are the full rescaled tensors.
    """
    random.seed(1)
    torch.manual_seed(1)
    # Zero NaNs in place; this mutates the tensor the caller passed in.
    wt[torch.isnan(wt)] = 0
    act_scale = 1
    act_shift = 5
    # Factor multiplying the target and the second-layer weights in the L-BFGS calls.
    fact = 100

    if len(wt.size()) > 2:
        # Unpacking also enforces a 4-D target.
        n2, _, _, _ = wt.size()
        mexact = min(full_dim(wt), m)
        mwider = w01.size(0)
        add_dim = mwider - m
        # This random draw is never used, but it advances the global torch RNG. It is kept so
        # every later random draw, and therefore the result, is unchanged.
        torch.randn(mwider)
        bt = torch.zeros(n2, device=device)
        norm = 2/(torch.mean(torch.abs(w02)) + torch.mean(torch.abs(w01)))
        w02 = w02*norm
        w01 = w01*norm

        # Initial fit on the first mexact hidden units.
        if nperm > 0:
            def func(x):
                return proxy_target_init_flexible(wt, w01[:mexact, :, :, :], x, epochs, mexact, "cpu")
            out = rand_permute_parallel(w02[:, :mexact, :, :], func, nperm, njobs)
            w02 = w02[out[1], :, :, :]
            rmse = out[0][0]
            wproxy = out[0][1]
            gamma1 = out[0][2][3]
            gamma2 = out[0][2][2]
        else:
            # This previously passed the undefined name ``x`` (NameError whenever nperm == 0).
            rmse, wproxy, [_, _, gamma2, gamma1] = proxy_target_init_flexible(
                wt, w01[:mexact, :, :, :], w02[:, :mexact, :, :], epochs, mexact, "cpu")

        # Express the extra hidden units through the first m, truncate to m, and refine.
        if add_dim > 0:
            w02add = torch.einsum('i,ijkl->ijkl', gamma2, w02)
            GammaMat, feat_add = conv_RF_parallel_accurate(w02add, w01, m, epochs, device, njobs)
            w02add = w02add[:, :m, :, :]
            w02 = w02[:, :m, :, :]
            w01 = w01[:m, :, :, :]
            gamma1, gamma_add = conv_1to2layer_learn_GD_LBFGS_without_gamma2_wider(
                fact*wt, w01, fact*w02add, fact*feat_add, epochs, device)
            # Fold the coefficients of the extra units back onto the first m units.
            gamma1.data = gamma1.data + torch.matmul(GammaMat, gamma_add.data)
            gamma1 = conv_1to2layer_learn_GD_LBFGS_without_gamma2(fact*wt, w01, fact*w02add, gamma1, epochs, device)
            gamma1, _, gamma2, _ = conv_1to2layer_learn_GD_LBFGS(
                fact*wt, bt, w01, fact*w02, act_scale, act_shift, gamma1, gamma2, epochs, device)
            wproxy = torch.einsum('i,ikmn,k,kjlb->ijmn', gamma2, w02, gamma1, w01)
            rmse = torch.sqrt(torch.mean((wt - wproxy)**2)).detach()
            if torch.isnan(rmse):
                # Fallback: RMS of the target itself.
                rmse = torch.sqrt(torch.mean(wt**2)).detach()
    else:
        # Unpacking also enforces a 2-D target.
        n2, _ = wt.size()
        mexact = min(full_dim(wt), m)
        mwider = w01.size(0)
        add_dim = mwider - m
        # Unused draw kept only to leave the torch RNG stream unchanged (see the 4-D branch).
        torch.randn(mwider)
        norm = 2/(torch.mean(torch.abs(w02)) + torch.mean(torch.abs(w01)))
        w02 = w02*norm
        w01 = w01*norm

        # Initial fit on the first mexact hidden units.
        if nperm > 0:
            def func(x):
                return proxy_target_init_flexible(wt, w01[:mexact, :], x, epochs, mexact, "cpu")
            out = rand_permute_parallel(w02[:, :mexact], func, nperm, njobs)
            # Apply the selected row permutation exactly once. It used to be applied twice,
            # which left w02 out of sync with the fitted gamma1/gamma2.
            w02 = w02[out[1], :]
            rmse = out[0][0]
            wproxy = out[0][1]
            gamma1 = out[0][2][3]
            gamma2 = out[0][2][2]
        else:
            rmse, wproxy, [_, _, gamma2, gamma1] = proxy_target_init_flexible(
                wt, w01[:mexact, :], w02[:, :mexact], epochs, mexact, "cpu")

        # Express the extra hidden units through the first m, truncate to m, and refine.
        if add_dim > 0:
            w02add = torch.einsum('i,ij->ij', gamma2, w02)
            GammaMat, feat_add = fc_RF_parallel_accurate(w02add, w01, m, epochs, device, njobs)
            w01 = w01[:m, :]
            w02 = w02[:, :m]
            w02add = w02add[:, :m]
            gamma1 = torch.randn(m)
            gamma1 = gamma1/torch.sqrt(torch.sum(gamma1**2))
            gamma1, gamma_add = fc_1to2layer_learn_GD_LBFGS_without_gamma2_wider(
                fact*wt, w01[:m, :], fact*w02add[:, :m], gamma1[:m], fact*feat_add, epochs, device)
            # Fold the coefficients of the extra units back onto the first m units.
            gamma1.data = gamma1.data + torch.matmul(GammaMat, gamma_add.data)
            gamma1 = fc_1to2layer_learn_GD_LBFGS_without_gamma2(fact*wt, w01, fact*w02add, gamma1, epochs, device)
            gamma1, _, gamma2, _ = fc_1to2layer_learn_GD_LBFGS(
                fact*wt, torch.zeros(n2, device=device), w01[:m, :], fact*w02[:, :m],
                act_scale, act_shift, gamma1, gamma2, epochs, device)
            wproxy = torch.einsum('i,im,m,mj->ij', gamma2, w02[:, :m], gamma1, w01[:m, :])
            rmse = torch.sqrt(torch.mean((wt - wproxy)**2)).detach()
            if torch.isnan(rmse):
                # Fallback: RMS of the target itself.
                rmse = torch.sqrt(torch.mean(wt**2)).detach()
    return rmse, wproxy, [w02, w01, gamma2, gamma1]


def proxy_target_accurate_permute_precond(wt, w01, w02, epochs, m, device, nperm, njobs, loss_scale=100):
    """Variant of the accurate permute fit that does NOT rescale ``w01``/``w02`` beforehand.

    Fits scale vectors gamma1/gamma2 so that ``diag(gamma2) @ w02 @ diag(gamma1) @ w01``
    (2-D) or ``einsum('i,ikmn,k,kjlb->ijmn', ...)`` (4-D) approximates ``wt``:
      1. Seed ``random`` and ``torch`` with 1 and zero NaN entries of ``wt`` in place (the
         caller's tensor is modified).
      2. Fit the first ``mexact = min(full_dim(wt), m)`` hidden units with
         ``proxy_target_init_flexible``. When ``nperm > 0`` this is searched over random row
         permutations of ``w02`` via ``rand_permute_parallel``, and the returned permutation
         is applied to ``w02`` once.
      3. If ``w01`` has more than ``m`` hidden units, the extra units are expressed through
         the first ``m``. The factorization is then truncated to ``m`` units and gamma1/gamma2
         are refined with the L-BFGS helpers.

    Args:
        loss_scale: factor multiplying the target and second-layer weights in the L-BFGS
            calls. The previous code referenced an undefined ``fact`` here (NameError whenever
            extra units existed). The default matches the value used by the non-preconditioned
            variant.

    Returns:
        ``(rmse, wproxy, [w02, w01, gamma2, gamma1])``. After step 3, ``rmse`` is the root mean
        square of ``wt - wproxy`` (replaced by the RMS of ``wt`` if NaN). Otherwise the values
        come from step 2.
    """
    random.seed(1)
    torch.manual_seed(1)
    # Zero NaNs in place; this mutates the tensor the caller passed in.
    wt[torch.isnan(wt)] = 0
    act_scale = 1
    act_shift = 5
    fact = loss_scale

    if len(wt.size()) > 2:
        # Unpacking also enforces a 4-D target.
        n2, _, _, _ = wt.size()
        mexact = min(full_dim(wt), m)
        mwider = w01.size(0)
        add_dim = mwider - m
        # This random draw is never used, but it advances the global torch RNG. It is kept so
        # every later random draw, and therefore the result, is unchanged.
        torch.randn(mwider)
        bt = torch.zeros(n2, device=device)

        # Initial fit on the first mexact hidden units.
        if nperm > 0:
            def func(x):
                return proxy_target_init_flexible(wt, w01[:mexact, :, :, :], x, epochs, mexact, "cpu")
            out = rand_permute_parallel(w02[:, :mexact, :, :], func, nperm, njobs)
            w02 = w02[out[1], :, :, :]
            rmse = out[0][0]
            wproxy = out[0][1]
            gamma1 = out[0][2][3]
            gamma2 = out[0][2][2]
        else:
            # This previously passed the undefined name ``x`` (NameError whenever nperm == 0).
            rmse, wproxy, [_, _, gamma2, gamma1] = proxy_target_init_flexible(
                wt, w01[:mexact, :, :, :], w02[:, :mexact, :, :], epochs, mexact, "cpu")

        # Express the extra hidden units through the first m, truncate to m, and refine.
        if add_dim > 0:
            w02add = torch.einsum('i,ijkl->ijkl', gamma2, w02)
            GammaMat, feat_add = conv_RF_parallel_accurate(w02add, w01, m, epochs, device, njobs)
            w02add = w02add[:, :m, :, :]
            w02 = w02[:, :m, :, :]
            w01 = w01[:m, :, :, :]
            gamma1, gamma_add = conv_1to2layer_learn_GD_LBFGS_without_gamma2_wider(
                fact*wt, w01, fact*w02add, fact*feat_add, epochs, device)
            # Fold the coefficients of the extra units back onto the first m units.
            gamma1.data = gamma1.data + torch.matmul(GammaMat, gamma_add.data)
            gamma1 = conv_1to2layer_learn_GD_LBFGS_without_gamma2(fact*wt, w01, fact*w02add, gamma1, epochs, device)
            gamma1, _, gamma2, _ = conv_1to2layer_learn_GD_LBFGS(
                fact*wt, bt, w01, fact*w02, act_scale, act_shift, gamma1, gamma2, epochs, device)
            wproxy = torch.einsum('i,ikmn,k,kjlb->ijmn', gamma2, w02, gamma1, w01)
            rmse = torch.sqrt(torch.mean((wt - wproxy)**2)).detach()
            if torch.isnan(rmse):
                # Fallback: RMS of the target itself.
                rmse = torch.sqrt(torch.mean(wt**2)).detach()
    else:
        # Unpacking also enforces a 2-D target.
        n2, _ = wt.size()
        mexact = min(full_dim(wt), m)
        mwider = w01.size(0)
        add_dim = mwider - m
        # Unused draw kept only to leave the torch RNG stream unchanged (see the 4-D branch).
        torch.randn(mwider)

        # Initial fit on the first mexact hidden units.
        if nperm > 0:
            def func(x):
                return proxy_target_init_flexible(wt, w01[:mexact, :], x, epochs, mexact, "cpu")
            out = rand_permute_parallel(w02[:, :mexact], func, nperm, njobs)
            # Apply the selected row permutation exactly once. It used to be applied twice,
            # which left w02 out of sync with the fitted gamma1/gamma2.
            w02 = w02[out[1], :]
            rmse = out[0][0]
            wproxy = out[0][1]
            gamma1 = out[0][2][3]
            gamma2 = out[0][2][2]
        else:
            # This previously passed the undefined name ``x`` (NameError whenever nperm == 0).
            rmse, wproxy, [_, _, gamma2, gamma1] = proxy_target_init_flexible(
                wt, w01[:mexact, :], w02[:, :mexact], epochs, mexact, "cpu")

        # Express the extra hidden units through the first m, truncate to m, and refine.
        if add_dim > 0:
            w02add = torch.einsum('i,ij->ij', gamma2, w02)
            GammaMat, feat_add = fc_RF_parallel_accurate(w02add, w01, m, epochs, device, njobs)
            w01 = w01[:m, :]
            w02 = w02[:, :m]
            w02add = w02add[:, :m]
            gamma1 = torch.randn(m)
            gamma1 = gamma1/torch.sqrt(torch.sum(gamma1**2))
            gamma1, gamma_add = fc_1to2layer_learn_GD_LBFGS_without_gamma2_wider(
                fact*wt, w01[:m, :], fact*w02add[:, :m], gamma1[:m], fact*feat_add, epochs, device)
            # Fold the coefficients of the extra units back onto the first m units.
            gamma1.data = gamma1.data + torch.matmul(GammaMat, gamma_add.data)
            gamma1 = fc_1to2layer_learn_GD_LBFGS_without_gamma2(fact*wt, w01, fact*w02add, gamma1, epochs, device)
            gamma1, _, gamma2, _ = fc_1to2layer_learn_GD_LBFGS(
                fact*wt, torch.zeros(n2, device=device), w01[:m, :], fact*w02[:, :m],
                act_scale, act_shift, gamma1, gamma2, epochs, device)
            wproxy = torch.einsum('i,im,m,mj->ij', gamma2, w02[:, :m], gamma1, w01[:m, :])
            rmse = torch.sqrt(torch.mean((wt - wproxy)**2)).detach()
            if torch.isnan(rmse):
                # Fallback: RMS of the target itself.
                rmse = torch.sqrt(torch.mean(wt**2)).detach()
    return rmse, wproxy, [w02, w01, gamma2, gamma1]


def proxy_target_LBFGS_overparam(wt, seed, epochs, m, init, device, add_dim, nperm, njobs, accurate, precond):
    """Build a random (m + add_dim)-unit two-layer factorization and fit its scales to ``wt``.

    Seeds ``random`` and ``torch`` with ``seed``. For a known ``init`` name, ``w01``/``w02``
    come from ``init_<name>(wt.size(), m + add_dim, seed)``. The rows of ``w02`` are then
    reordered:
      * 4-D ``wt``: ``linear_sum_assignment`` on the pairwise distances between the flattened
        rows of ``wt`` and of ``einsum('ikmn,kjlp->ijmn', w02, w01)``;
      * 2-D ``wt``: ``svd_permute_fc`` for He/ortho/uni/sign/precond, ``id_permute`` for id,
        and ``linear_sum_assignment`` on row distances for the ER variants.
    A ValueError during reordering leaves ``w02`` in its initial order. The fit is done by
    ``proxy_target_accurate_permute`` when ``accurate`` is true, otherwise by
    ``proxy_target_LBFGS_RF_parallel_permute``. Any other ``init`` is delegated to
    ``proxy_target_ER_equiv_rand_permute``.

    NOTE(review): ``precond`` is ignored; the previous code overwrote it with hard-coded
    values before use.

    Returns:
        ``(rmse, wproxy, x)`` as produced by the selected fitting routine.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    mfull = full_dim_wide(wt)  # NOTE(review): computed but not used below.
    shape = wt.size()
    width = m + add_dim
    # Lambdas defer the name lookup, so an initializer is only resolved when requested.
    initializers = {
        "He": lambda: init_He(shape, width, seed),
        "ortho": lambda: init_ortho(shape, width, seed),
        "ER": lambda: init_ER(shape, width, seed),
        "uni": lambda: init_uni(shape, width, seed),
        "ER_norm": lambda: init_ER_norm(shape, width, seed),
        "ER_sign": lambda: init_ER_sign(shape, width, seed),
        "sign": lambda: init_sign(shape, width, seed),
        "precond": lambda: init_precond(shape, width, seed),
        "id": lambda: init_id(shape, width, seed),
    }

    if init not in initializers:
        rmse, wproxy, x = proxy_target_ER_equiv_rand_permute(wt, seed, epochs, m, nperm, device)
        return rmse, wproxy, x

    w01, w02 = initializers[init]()

    if len(wt.size()) > 2:
        try:
            flat_target = wt.reshape(wt.size(0), -1)
            flat_init = torch.einsum('ikmn,kjlp->ijmn', w02, w01).reshape(wt.size(0), -1)
            cost = torch.cdist(flat_target, flat_init, p=2)
            _, col_ind = linear_sum_assignment(cost.cpu().detach().numpy())
            w02 = w02[col_ind]
        except ValueError:
            pass  # keep the initial row order
    else:
        try:
            if init in ("He", "ortho", "uni", "sign", "precond"):
                w02, _ = svd_permute_fc(w01, w02, wt)
            elif init == "id":
                _, w02, _ = id_permute(w01, w02, wt)
            else:
                cost = torch.cdist(wt, torch.einsum('ik,kj->ij', w02, w01), p=2)
                _, col_ind = linear_sum_assignment(cost.cpu().detach().numpy())
                w02 = w02[col_ind]
        except ValueError:
            pass  # keep the initial row order

    fit = proxy_target_accurate_permute if accurate else proxy_target_LBFGS_RF_parallel_permute
    rmse, wproxy, x = fit(wt, w01, w02, epochs, m, device, nperm, njobs)
    return rmse, wproxy, x

def proxy_target_LBFGS_precond(wt, seed, epochs, m, init, device, accurate):
    """Fit a two-layer proxy to ``wt`` with/without conditioning and permutation.

    Draws random factors ``w01`` (inner) and ``w02`` (outer) with the
    initialisation scheme ``init``. It then runs the fitting routine four
    times:

    1. plain fit (``err``);
    2. conditioned fit (``errCond``);
    3. fit after permuting the rows of ``w02`` (``errPerm``);
    4. conditioned fit after that permutation (``errPermCond``).

    Args:
        wt: target weight tensor. Only 2-D (fully connected) targets are
            supported here.
        seed: seed for ``random`` / ``torch`` and for the initialiser.
        epochs: optimisation budget forwarded to the fitting routines.
        m: hidden width of the proxy.
        init: one of "He", "ortho", "ER", "uni", "ER_norm", "ER_sign",
            "sign", "precond", "id".
        device: torch device forwarded to the fitting routines.
        accurate: if true, fit with ``proxy_target_init_flexible_fc_precond``;
            otherwise with ``proxy_target_LBFGS_precond_fc``. For
            ``init == "id"`` the permuted fit always uses the latter.

    Returns:
        ``(err, errCond, errPerm, errPermCond, condNumber)``.
        - The ``err*`` values are the first element returned by the fitting
          routine (presumably an RMSE; TODO confirm).
        - ``condNumber`` comes from the permutation helper, or from the last
          conditioned fit when ``accurate`` is true. It is forced to
          ``torch.tensor([1])`` when no permutation is applied.

    Raises:
        NotImplementedError: if ``init`` is not a supported scheme, or if
            ``wt`` has more than two dimensions. Neither case can produce
            the five return values.
    """
    supported_inits = ("He", "ortho", "ER", "uni", "ER_norm", "ER_sign", "sign", "precond", "id")
    if init not in supported_inits:
        raise NotImplementedError(
            f"proxy_target_LBFGS_precond: unsupported init {init!r}; "
            f"expected one of {supported_inits}"
        )
    if len(wt.size()) > 2:
        raise NotImplementedError(
            "proxy_target_LBFGS_precond: only 2-D (fully connected) targets are supported"
        )

    random.seed(seed)
    torch.manual_seed(seed)
    size = wt.size()
    if init == "He":
        w01, w02 = init_He(size, m, seed)
    elif init == "ortho":
        w01, w02 = init_ortho(size, m, seed)
    elif init == "uni":
        w01, w02 = init_uni(size, m, seed)
    elif init == "precond":
        w01, w02 = init_precond(size, m, seed)
    elif init == "id":
        w01, w02 = init_id(size, m, seed)
    elif init == "ER":
        w01, w02 = init_ER(size, m, seed)
    elif init == "ER_norm":
        w01, w02 = init_ER_norm(size, m, seed)
    elif init == "ER_sign":
        w01, w02 = init_ER_sign(size, m, seed)
    else:  # init == "sign" (validated above)
        w01, w02 = init_sign(size, m, seed)

    def _fit(w01_fit, w02_fit, precond):
        # Return the full 3-tuple of the fitting routine selected by `accurate`.
        if accurate:
            return proxy_target_init_flexible_fc_precond(wt, w01_fit, w02_fit, epochs, m, device, precond)
        return proxy_target_LBFGS_precond_fc(wt, w01_fit, w02_fit, epochs, device, precond)

    # Baseline fits on the unpermuted factors.
    err, _, _ = _fit(w01, w02, False)
    errCond, _, cond_fit = _fit(w01, w02, True)
    if accurate:
        condNumber = cond_fit
        print("condNumber: ", condNumber)

    if init == "id":
        errPermCond, _, col_ind = id_permute(w01, w02, wt)
        errPermCond = torch.tensor([errPermCond])
        condNumber = torch.tensor([1])
        # The L-BFGS routine is used here regardless of `accurate`.
        errPerm, _, _ = proxy_target_LBFGS_precond_fc(wt, w01, w02[col_ind], epochs, device, False)
    elif init == "precond":
        # Re-draw the factors together with the matrix that drives the permutation.
        w01, w02, Umiss = init_precond_umiss(size, m, seed)
        # NOTE(review): the accurate path requires >0 elements, the L-BFGS
        # path >1; kept as-is, confirm which threshold is intended.
        min_umiss = 0 if accurate else 1
        if Umiss.numel() > min_umiss:
            col_ind, condNumber = svd_permute_precond_fc(Umiss, wt)
            errPerm, _, _ = _fit(w01, w02[col_ind], False)
            errPermCond, _, cond_fit = _fit(w01, w02[col_ind], True)
            if accurate:
                condNumber = cond_fit
        else:
            # Nothing to permute: reuse the unpermuted results.
            errPerm = err
            errPermCond = errCond
            condNumber = torch.tensor([1])
    else:
        col_ind, condNumber, _, _, _ = svd_permute_cond_fc(w01, w02, wt)
        errPerm, _, _ = _fit(w01, w02[col_ind], False)
        errPermCond, _, cond_fit = _fit(w01, w02[col_ind], True)
        if accurate:
            condNumber = cond_fit
            print("condNumber permuted: ", condNumber)
    return err, errCond, errPerm, errPermCond, condNumber

def proxy_target_LBFGS_precond_noPerm(wt, seed, epochs, m, init, device, accurate):
    """Fit a two-layer proxy to ``wt`` with and without conditioning (no permutation study).

    Draws random factors ``w01`` (inner) and ``w02`` (outer) with the
    initialisation scheme ``init``.
    - For 2-D targets it runs ``proxy_target_LBFGS_precond_fc`` twice
      (unconditioned, then conditioned).
    - For 4-D targets it only reorders the rows of ``w02`` via a linear
      assignment on row distances.
    - If ``accurate`` is true, ``err``/``errCond``/``condNumber`` are then
      (re)computed with ``proxy_target_init_flexible_fc_precond``.

    Args:
        wt: target weight tensor (2-D, or 4-D when ``accurate`` is true).
        seed: seed for ``random`` / ``torch`` and for the initialiser.
        epochs: optimisation budget forwarded to the fitting routines.
        m: hidden width of the proxy.
        init: one of "He", "ortho", "ER", "uni", "ER_norm", "ER_sign",
            "sign", "precond", "id".
        device: torch device forwarded to the fitting routines.
        accurate: whether to (re)fit with
            ``proxy_target_init_flexible_fc_precond``.

    Returns:
        ``(err, errCond, condNumber)``.

    Raises:
        NotImplementedError: if ``init`` is not a supported scheme, or if
            ``wt`` has more than two dimensions and ``accurate`` is false.
            Neither case can produce the return values.
    """
    supported_inits = ("He", "ortho", "ER", "uni", "ER_norm", "ER_sign", "sign", "precond", "id")
    if init not in supported_inits:
        raise NotImplementedError(
            f"proxy_target_LBFGS_precond_noPerm: unsupported init {init!r}; "
            f"expected one of {supported_inits}"
        )
    is_conv = len(wt.size()) > 2
    if is_conv and not accurate:
        raise NotImplementedError(
            "proxy_target_LBFGS_precond_noPerm: targets with more than two "
            "dimensions are only handled when accurate=True"
        )

    random.seed(seed)
    torch.manual_seed(seed)
    size = wt.size()
    if init == "He":
        w01, w02 = init_He(size, m, seed)
    elif init == "ortho":
        w01, w02 = init_ortho(size, m, seed)
    elif init == "uni":
        w01, w02 = init_uni(size, m, seed)
    elif init == "precond":
        w01, w02 = init_precond(size, m, seed)
    elif init == "id":
        w01, w02 = init_id(size, m, seed)
    elif init == "ER":
        w01, w02 = init_ER(size, m, seed)
    elif init == "ER_norm":
        w01, w02 = init_ER_norm(size, m, seed)
    elif init == "ER_sign":
        w01, w02 = init_ER_sign(size, m, seed)
    else:  # init == "sign" (validated above)
        w01, w02 = init_sign(size, m, seed)

    if is_conv:
        # Reorder output rows of w02 so the flattened product rows are closest to wt's rows.
        try:
            dist_mat_rows = torch.cdist(
                wt.reshape(wt.size(0), -1),
                torch.einsum('ikmn,kjlp->ijmn', w02, w01).reshape(wt.size(0), -1),
                p=2,
            )
            _, col_ind = linear_sum_assignment(dist_mat_rows.cpu().detach().numpy())
            w02 = w02[col_ind]
        except ValueError:
            # Best effort: keep the original order if the assignment fails.
            pass
    else:
        err, _, _ = proxy_target_LBFGS_precond_fc(wt, w01, w02, epochs, device, False)
        # NOTE(review): condNumber is taken from the *second* return value
        # here; confirm that this routine returns a condition number there.
        errCond, condNumber, _ = proxy_target_LBFGS_precond_fc(wt, w01, w02, epochs, device, True)

    if accurate:
        # NOTE(review): for 2-D targets this overwrites the L-BFGS results
        # computed just above; kept as-is to preserve call order.
        err, _, _ = proxy_target_init_flexible_fc_precond(wt, w01, w02, epochs, m, device, False)
        errCond, _, condNumber = proxy_target_init_flexible_fc_precond(wt, w01, w02, epochs, m, device, True)
        print("condNumber: ", condNumber)
    return err, errCond, condNumber


#permute

#prune
    # REVIEW:

# def find_max(k,w02,w01):
#     i = torch.argmax(torch.abs(w02[:,k]))
#     j = torch.argmax(torch.abs(w01[k,:]))
#     return np.array([i,j])
#
# def proxy_target_LBFGS_init(wt, w01, w02, epochs, m, device):
#     random.seed(1)
#     torch.manual_seed(1)
#     wt[torch.isnan(wt)] = 0
#     act_scale = 1.0
#     act_shift = 0.0
#     if len(wt.size()) > 2:
#         n2, n1, k1, k2 = wt.size()
#         mfull=n2*(n1*k1*k2-1)+1
#         #print(m)
#         m = min(m,mfull)
#         bt = torch.ones(n2,device=device)
#
#
#         #optimal permutation for gamma2 = ones
#         #maskt = torch.where(torch.abs(wt)>0.1*torch.mean(torch.abs(wt)),1.0,0.0)
#         #dist_mat_rows = torch.cdist(wt,torch.einsum('ikmn,kjlb->ijmn', w02, w01),p=2)
#         #row_ind, col_ind = linear_sum_assignment(dist_mat_rows.cpu().detach().numpy())
#         #w02 = w02[col_ind]
#         gamma1 = conv_1to2layer_project_gamma1(wt, w01, w02, torch.ones(wt.size(0)), device)
#         gamma2 = conv_1to2layer_project_gamma2(wt, w01, w02, gamma1, device)
#
#
#         gamma1, bias1, gamma2, bias2 = conv_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, gamma1, gamma2, epochs, device)
#         wproxy = torch.einsum('i,ikmn,k,kjlb->ijmn', gamma2, w02, gamma1, w01)
#         rmse =  torch.sqrt(torch.mean((wt-wproxy)**2)).detach()
#         if torch.isnan(rmse):
#             #rmse2 = 100000
#             gamma1, bias1, gamma2, bias2 = conv_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, torch.ones(m), torch.ones(n2), epochs, device)
#             wproxy = torch.einsum('i,ikmn,k,kjlb->ijmn', gamma2, w02, gamma1, w01)
#             rmse =  torch.sqrt(torch.mean((wt-wproxy)**2)).detach()
#     else:
#         n2, n1 = wt.size()
#         mfull=n2*(n1-1)+1
#         #m = min(m,mfull)
#         #permute out neurons assuming gamma2=ones
#         #w02 = permute_outer(wt, w02, w01)
#         #assume that gamma1 matches wt values with largest possible in wproxy for a single gamma1
#         wpr = np.zeros(wt.size())
#         ind_gamma = [find_max(k,w02,w01) for k in range(m)]
#         #for k in range(m):
#         #    i = torch.argmax(torch.abs(w02[:,k]))
#         #    j = torch.argmax(torch.abs(w01[k,:]))
#         #    wpr[i,j] = 1
#         wpr[ind_gamma] = 1
#         #w02 = permute_outer_match(wt, w02, w01)
#         #dist_mat_rows = torch.cdist(wt,wproxy,p=2)
#         dist_mat_rows = torch.abst(wt)*wpr
#         dist_mat_rows = scipy.spatial.distance.cdist(wt.detach().numpy(),wpr,lambda x, y: np.sum(np.abs(x*y)))
#         dist_mat_rows = np.max(dist_mat_rows)-dist_mat_rows
#         row_ind, col_ind = linear_sum_assignment(dist_mat_rows.cpu().detach().numpy())
#         w02 = w02[col_ind]
#         #match largest weights first
#         torch.sort(torch.abs(wt),descending=True)
#
#         #gamma1 = fc_1to2layer_exact_project_without_gamma2(wt, w01, w02, device)
#         #gamma2 = fc_1to2layer_project_gamma2(wt, w01, w02, gamma1, device)
#         bt = torch.ones(n2,device=device)
#         gamma1, bias1, gamma2, bias2 = fc_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, gamma1, gamma2, epochs, device)
#         wproxy = torch.einsum('i,im,m,mj->ij', gamma2, w02, gamma1, w01)
#         rmse =  torch.sqrt(torch.mean((wt-wproxy)**2)).detach()
#         if torch.isnan(rmse):
#             #rmse2 = 100000
#             gamma1, bias1, gamma2, bias2 = fc_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, torch.ones(m), torch.ones(n2), epochs, device)
#             wproxy = torch.einsum('i,im,m,mj->ij', gamma2, w02, gamma1, w01)
#             rmse =  torch.sqrt(torch.mean((wt-wproxy)**2)).detach()
#     return rmse, wproxy, [w02, w01, gamma2, gamma1]

def full_dim(wt):
    """Return ``rows * (fan_in - 1) + 1`` for the target tensor ``wt``.

    ``wt`` must be 2-D ``(n2, n1)``, giving fan_in ``n1``, or 4-D
    ``(n2, n1, k1, k2)``, giving fan_in ``n1 * k1 * k2``. Tensors of any
    other rank make the shape unpacking raise ``ValueError``.
    """
    shape = wt.size()
    if len(shape) <= 2:
        rows, cols = shape
        fan_in = cols
    else:
        rows, cols, k1, k2 = shape
        fan_in = cols * k1 * k2
    return rows * (fan_in - 1) + 1

def full_dim_wide(wt):
    """Return the number of entries of a 2-D or 4-D target tensor ``wt``.

    The product is formed from the unpacked shape, so tensors of any other
    rank raise ``ValueError``.
    """
    shape = wt.size()
    if len(shape) <= 2:
        rows, cols = shape
        return rows * cols
    rows, cols, k1, k2 = shape
    return rows * cols * k1 * k2

def permute_outer(wt, w02, w01):
    """Reorder the first dimension of ``w02`` so its product with ``w01`` matches ``wt``.

    Procedure:
    - Build one row per output unit for the target and for the product
      ``w02 * w01``.
      - 2-D: the product is ``einsum('ik,kj->ij', w02, w01)``.
      - 4-D: ``wt`` and ``einsum('ikmn,kjlp->ijmn', w02, w01)`` are flattened
        to ``(wt.size(0), -1)``.
    - Multiply every row by its own squared L2 norm.
    - Match rows by minimum total Euclidean distance with
      ``linear_sum_assignment``.
    - Return ``w02`` indexed by the resulting permutation.

    NOTE(review): rows are *multiplied* by their squared norm, not
    normalised; confirm this weighting is intended.

    If a ``ValueError`` is raised (e.g. ``linear_sum_assignment`` rejecting a
    cost matrix containing NaN), ``w02`` is returned unchanged.
    """
    def _weight_rows(mat):
        # Scale each row by its own squared L2 norm.
        return torch.einsum('ij,i->ij', mat, torch.sum(mat ** 2, dim=1))

    try:
        if len(wt.size()) > 2:
            target = _weight_rows(wt.reshape(wt.size(0), -1))
            proxy = torch.einsum('ikmn,kjlp->ijmn', w02, w01).reshape(target.size(0), -1)
        else:
            target = _weight_rows(wt)
            proxy = torch.einsum('ik,kj->ij', w02, w01)
        proxy = _weight_rows(proxy)
        costs = torch.cdist(target, proxy, p=2)
        _, perm = linear_sum_assignment(costs.cpu().detach().numpy())
        return w02[perm]
    except ValueError:
        return w02

# def init_LBFGS(wt, w02, w01, nsub):
#     epochs = 50
#     nsub = min(nsub,wt.size(0))
#     #if len(wt.size()) > 2:
#         # wt.reshape(wt.size(0),-1)
#         # w02.reshape(w02.size(0),-1)
#         # w01.reshape(w01.size(0),-1)
#     #else:
#     for i in range(nsub):
#         _, _, out = proxy_target_init_flexible(wt[sub,:], w01[sub1,:], w02[sub,sub1], epochs, len(sub1), "cpu")
#         gamma2[sub] = out[2]
#         gamma1[sub1] = out[3]
#
#     return gamma1, gamma2


# def proxy_target_init_flexible(wt, w01, w02, epochs, m, device):
#     random.seed(1)
#     wt[torch.isnan(wt)] = 0
#     act_scale = 1
#     act_shift = 5
#     if len(wt.size()) > 2:
#         n2, n1, k1, k2 = wt.size()
#         mfull=n2*(n1*k1*k2-1)+1
#         #print(m)
#         m = min(m,mfull)
#         torch.manual_seed(1)
#         bt = torch.ones(n2,device=device)
#         gamma1, bias1, gamma2, bias2 = conv_1to2layer_exact_pruned(wt, bt, w01, w02, act_scale, act_shift, device)
#         gamma1, bias1, gamma2, bias2 = conv_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, gamma1, gamma2, epochs, device)
#         wproxy = torch.einsum('i,ikmn,k,kjlb->ijmn', gamma2, w02, gamma1, w01)
#         rmse = torch.sqrt(torch.mean((wt-wproxy)**2)).detach()
#         if torch.isnan(rmse):
#             rmse = torch.sqrt(torch.mean(wt**2)).detach()
#         #Alternative
#         gamma1_2 = conv_1to2layer_project_gamma1(wt, w01, w02, gamma2, device)
#         gamma1_2, bias1, gamma2_2, bias2 = conv_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, gamma1_2, gamma2, epochs, device)
#         wproxy2 = torch.einsum('i,ikmn,k,kjlb->ijmn', gamma2_2, w02, gamma1_2, w01)
#         rmse2 =  torch.sqrt(torch.mean((wt-wproxy2)**2)).detach()
#         if torch.isnan(rmse2):
#             rmse2 = torch.sqrt(torch.mean(wt**2)).detach()
#         if rmse2 < rmse:
#             rmse = rmse2
#             wproxy = wproxy2
#         gamma1_2 = conv_1to2layer_project_gamma1(wt, w01, w02, torch.ones(wt.size(0)), device)
#         gamma2_2 = conv_1to2layer_project_gamma2(wt, w01, w02, gamma1_2, device)
#         wproxy2 = torch.einsum('i,ikmn,k,kjlb->ijmn', gamma2_2, w02, gamma1_2, w01)
#         rmse2 =  torch.sqrt(torch.mean((wt-wproxy2)**2)).detach()
#         if torch.isnan(rmse2):
#             rmse2 = torch.sqrt(torch.mean(wt**2)).detach()
#         if rmse2 < rmse:
#             rmse = rmse2
#             wproxy = wproxy2
#     else:
#         n2, n1 = wt.size()
#         mfull=n2*(n1-1)+1
#         m = min(m,mfull)
#         torch.manual_seed(1)
#         bt = torch.ones(n2,device=device)
#         #Alternative 1:
#         gamma1, bias1, gamma2, bias2 = fc_1to2layer_exact_remove_conditions(wt, bt, w01, w02, act_scale, act_shift, device)
#         gamma1, bias1, gamma2, bias2 = fc_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, gamma1, gamma2, epochs, device)
#         wproxy = torch.einsum('i,im,m,mj->ij', gamma2, w02, gamma1, w01)
#         rmse =  torch.sqrt(torch.mean((wt-wproxy)**2)).detach()
#         if torch.isnan(rmse):
#             rmse = torch.sqrt(torch.mean(wt**2)).detach()
#         #Alternative 2: projections
#         gamma1_2 = fc_1to2layer_exact_project_without_gamma2(wt, w01, torch.einsum('i,im->im',gamma2,w02), device)
#         gamma1_2, bias1, gamma2_2, bias2 = fc_1to2layer_learn_GD_LBFGS(wt, bt, w01, w02, act_scale, act_shift, gamma1_2, gamma2, epochs, device)
#         wproxy2 = torch.einsum('i,im,m,mj->ij', gamma2_2, w02, gamma1_2, w01)
#         rmse2 =  torch.sqrt(torch.mean((wt-wproxy2)**2)).detach()
#         if torch.isnan(rmse2):
#             rmse2 = torch.sqrt(torch.mean(wt**2)).detach()
#         if rmse2 < rmse:
#             rmse = rmse2
#             wproxy = wproxy2
#         gamma1_2 = fc_1to2layer_exact_project_without_gamma2(wt, w01, w02, device)
#         gamma2_2 = fc_1to2layer_project_gamma2(wt, w01, w02, gamma1_2, device)
#         wproxy2 = torch.einsum('i,im,m,mj->ij', gamma2_2, w02, gamma1_2, w01)
#         rmse2 =  torch.sqrt(torch.mean((wt-wproxy2)**2)).detach()
#         if torch.isnan(rmse2):
#             rmse2 = torch.sqrt(torch.mean(wt**2)).detach()
#         if rmse2 < rmse:
#             rmse = rmse2
#             wproxy = wproxy2
#     return rmse, wproxy, [w02, w01, gamma2, gamma1]
