import torch
import numpy as np 
import numpy.linalg as LA
import torch.nn as nn
from torch.nn import Linear, Conv2d, SELU
from torch import sigmoid
import torchvision
import torch.optim as optim
from torch.autograd import Variable
import matplotlib
#matplotlib.use('Agg')
import matplotlib.pyplot as plt
import time 
from PIL import Image
from prior_MNIST import VAE
from torchvision import datasets, transforms
import gc
import copy
from models import GenericStackedNet, DictNet, A_matrix, A_2dconv
from PIL import Image, ImageFilter
from mpl_toolkits.axes_grid1 import make_axes_locatable
    
# Module-wide compute device: the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def transform(inp, flip=False, blur=False, colorjitter=False):
    """Apply the selected augmentations to ``inp`` and convert it to a tensor.

    The enabled steps run in a fixed order: horizontal flip, box blur,
    colour jitter, and finally ``ToTensor`` (which always runs).

    Args:
        inp: Image to process. The blur step calls ``inp.filter`` with a PIL
            ``ImageFilter``, so a PIL image is presumably expected.
        flip: Mirror the image horizontally (applied with probability 1).
        blur: Apply a box blur of radius 2.
        colorjitter: Apply a colour jitter (brightness/contrast/saturation 1,
            hue 0.4).

    Returns:
        The output of ``torchvision.transforms.ToTensor`` on the processed image.
    """
    pipeline = []
    if flip:
        pipeline.append(torchvision.transforms.RandomHorizontalFlip(p=1))
    if blur:
        pipeline.append(lambda image: image.filter(ImageFilter.BoxBlur(2)))
    if colorjitter:
        pipeline.append(torchvision.transforms.ColorJitter(brightness=1, contrast=1, saturation=1, hue=0.4))
    pipeline.append(torchvision.transforms.ToTensor())

    for step in pipeline:
        inp = step(inp)
    return inp

def make_new_MNIST_datasets(transform_func, batch_size, MNIST_DIR="/mnt/home/hlawrence/ceph/datasets"):
    """Build MNIST train and test datasets, loaders and loader iterators.

    Both splits are read from ``MNIST_DIR`` (nothing is downloaded) and use
    ``transform_func`` as the per-sample transform.

    Returns:
        A 6-tuple ``(train_dataset, train_loader, train_iterator,
        test_dataset, test_loader, test_iterator)``.
    """
    def _build_split(is_train):
        # One split: dataset -> DataLoader -> iterator over that loader.
        dataset = torchvision.datasets.MNIST(MNIST_DIR, train=is_train, download=False, transform=transform_func)
        loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
        return dataset, loader, iter(loader)

    train_parts = _build_split(True)
    test_parts = _build_split(False)
    return train_parts + test_parts

def make_shape(diameter=5, rows=10, cols=10, shape='circle'):
    """Rasterise a binary mask of a simple shape centred on a coordinate grid.

    Coordinates run from ``-rows/2`` to ``rows/2`` (``rows`` samples) and from
    ``-cols/2`` to ``cols/2`` (``cols`` samples); a pixel is 1.0 where the
    shape's inequality holds and 0.0 elsewhere.

    Args:
        diameter: Size of the shape; ``diameter / 2`` is used as the radius
            (ignored by the two hourglass shapes, which depend only on the
            grid size).
        rows: Number of samples along the first coordinate.
        cols: Number of samples along the second coordinate.
        shape: One of ``'circle'``, ``'oval'``, ``'square'``, ``'triangle'``,
            ``'hourglass'`` or ``'hourglass2'``.

    Returns:
        A float tensor mask.

    Raises:
        ValueError: If ``shape`` is not one of the supported names.
    """
    rownums = np.linspace(-(rows/2), rows/2, rows)
    colnums = np.linspace(-(cols/2), cols/2, cols)
    # NOTE(review): meshgrid's default 'xy' indexing yields arrays of shape
    # (cols, rows), so for rows != cols the mask is (cols, rows) -- confirm
    # that callers expect this layout.
    Y, X = np.meshgrid(rownums, colnums)
    radius = diameter/2

    if shape == 'circle':
        msk = (Y**2 + X**2 <= (radius)**2)
    elif shape == 'oval':
        msk = ((Y/radius)**2 + (X/(radius/2))**2 <= 1)
    elif shape == 'square':
        msk = (abs(Y) < (radius)) & (abs(X) < radius)
    elif shape == 'triangle':
        msk = (Y <= (2*X + radius)) & (Y <= -2*X + radius) & (Y > -1*radius)
    elif shape in ('hourglass', 'hourglass2'):
        # Both hourglasses are bounded by the same pair of parabolas; the
        # second variant just swaps the roles of the two coordinates.
        R = rows / 2
        D = cols / 2
        a = (R**2) / (D / 2)
        b = (a*D) / 2
        if shape == 'hourglass':
            msk = (a*Y - b <= (X**2)) & (-1*a*Y - b <= (X**2))
        else:
            msk = (a*X - b <= (Y**2)) & (-1*a*X - b <= (Y**2))
    else:
        # Previously an unknown name fell through and crashed on an unbound
        # local at the return statement.
        raise ValueError('Unknown shape %r' % (shape,))
    return torch.tensor(msk).float()
    
def make_A_shapes(numshapes, numsizes, numrows, numcols):
    """Build a dictionary whose columns are flattened shape masks.

    For each of the first ``numshapes`` shape names and each of ``numsizes``
    diameters (evenly spaced up to ``min(numrows, numcols)``, excluding 0),
    one column holds the flattened output of ``make_shape``.

    Args:
        numshapes: How many shape families to use (taken in list order).
        numsizes: How many diameters per shape.
        numrows: Grid rows passed to ``make_shape``.
        numcols: Grid columns passed to ``make_shape``.

    Returns:
        A ``(numrows*numcols, numshapes*numsizes)`` float tensor.

    Raises:
        ValueError: If ``numshapes`` exceeds the number of available shapes.
    """
    shapes = ['circle', 'oval', 'square', 'triangle', 'hourglass', 'hourglass2']
    # The old check compared against 4 and only printed a warning, so a request
    # for more than six shapes crashed with IndexError part-way through.
    if numshapes > len(shapes):
        raise ValueError('Not that many shapes available: requested %d, have %d' % (numshapes, len(shapes)))

    numpix = min(numrows, numcols)
    A = torch.zeros(numrows*numcols, numshapes*numsizes)
    # Drop the leading 0 so every diameter is strictly positive.
    diams = np.linspace(0, numpix, numsizes+1)[1:]
    counter = 0
    for i in range(numshapes):
        for j in range(numsizes):
            # Progress trace kept from the original implementation.
            print('A', A.shape, 'counter', counter, 'diams', len(diams), 'j', j, 'i', i, 'numshapes', numshapes)
            A[:, counter] = make_shape(diams[j], rows=numrows, cols=numcols, shape=shapes[i]).view(-1)
            counter += 1
    return A

def estimateSNR(add_noise, test_sample_func):
    """Estimate the signal-to-noise ratio, in decibels, of a noise model.

    One sample is drawn with ``test_sample_func``, corrupted with
    ``add_noise``, and the noise is recovered as the difference.

    Args:
        add_noise: Callable mapping a clean sample to its noisy version.
        test_sample_func: Zero-argument callable returning a clean sample.

    Returns:
        A 0-d tensor holding ``20 * log10(||signal|| / ||noise||)``.
    """
    sig = test_sample_func()
    noisy = add_noise(sig)
    noise = noisy - sig
    # SNR in dB is 10*log10 of the *power* ratio, i.e. 20*log10 of the norm
    # (amplitude) ratio. The previous factor of 10 on the norm ratio reported
    # half the true value.
    return 20*torch.log10(torch.norm(sig) / torch.norm(noise))

def compose2(f, g):
    """Return a zero-argument callable that evaluates ``f(g())``."""
    def composed():
        inner_result = g()
        return f(inner_result)
    return composed

def round_digits(mat, digits):
    """Round every entry of ``mat`` to ``digits`` decimal places.

    The computation runs with gradient tracking disabled.
    """
    scale = 10**digits
    with torch.no_grad():
        scaled = torch.round(mat * scale)
        return scaled / scale

def find_best_z(stacked_net, ys, k=5, iters=100, printevery=300, zsqnorm_fac=0, lr = 1e-3, doplot=False, digits=-1):
    """Fit latent codes so that ``stacked_net(zs)`` reproduces ``ys``.

    Codes are initialised from a standard normal and optimised with Adam on
    the normalised squared error ``||net(zs) - ys||^2 / ||ys||^2`` plus
    ``zsqnorm_fac * ||zs||``.

    Args:
        stacked_net: Callable module mapping a ``(P, k)`` tensor of codes to
            reconstructions comparable with the flattened ``ys``.
        ys: Targets; flattened to ``(P, -1)``.
        k: Latent dimension.
        iters: Number of optimisation steps.
        printevery: Print the loss every this many steps; ``None`` disables it.
        zsqnorm_fac: Weight of the latent-norm penalty.
        lr: Adam learning rate.
        doplot: Show a plot of log10 loss against iteration.
        digits: If not ``-1``, round the final codes to this many decimals.

    Returns:
        ``(zs, losses)``: the detached ``(P, k)`` codes and the list of
        per-iteration losses (0-d CPU tensors).
    """
    P = ys.shape[0]
    ys = ys.reshape(P, -1)
    # Sample on the CPU, then move next to the targets. The loss below requires
    # the codes' outputs and ``ys`` to share a device anyway, so this matches
    # the module-level ``device`` whenever the previous code worked.
    zs_autodiff = torch.normal(0, 1, (P, k)).to(ys.device).requires_grad_(True)
    # NOTE(review): the network's own parameters are handed to the optimiser
    # too, so they are updated alongside the codes -- confirm this is intended
    # for a function that is meant to run on an already-trained network.
    optimizer = optim.Adam([zs_autodiff] + list(stacked_net.parameters()), lr=lr)
    torch.autograd.set_detect_anomaly(True)
    losses_in_y = []
    for j in range(iters):
        optimizer.zero_grad()
        outputs = stacked_net(zs_autodiff)
        # NOTE(review): the penalty uses the plain norm although the parameter
        # name suggests a squared norm (decode_updated squares it) -- confirm.
        loss_in_y = (torch.norm(outputs - ys)**2 / torch.norm(ys)**2) + zsqnorm_fac*torch.norm(zs_autodiff)
        losses_in_y.append(loss_in_y.detach().cpu())
        if printevery is not None and j % printevery == 0:
            # The old message also formatted an undefined epoch index ``i`` and
            # raised NameError on the very first iteration.
            print('Iter %d Loss (normalized) y %f' % (j, loss_in_y.item()))
        loss_in_y.backward(retain_graph=True)
        optimizer.step()
    if doplot:
        plt.figure()
        plt.plot(torch.log10(torch.tensor(losses_in_y)))
        plt.xlabel('iter')
        plt.ylabel('log10 loss in y')
        plt.show()
    if digits != -1:
        with torch.no_grad():
            zs_autodiff = round_digits(zs_autodiff, digits)
    return zs_autodiff.clone().detach(), losses_in_y

def loss_in_ys(ys, ys_recon):
    """Return the normalised squared error ``||ys - ys_recon||^2 / ||ys||^2``.

    Both inputs are flattened to ``(batch, -1)`` first. If either input is a
    NumPy array the computation is done in NumPy (a tensor argument is
    detached and copied to the CPU); otherwise it is done in torch.

    Args:
        ys: Reference values, tensor or ndarray.
        ys_recon: Reconstructed values, tensor or ndarray.

    Returns:
        A NumPy scalar when NumPy is used, else a 0-d tensor.
    """
    def _as_numpy(values):
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().numpy()
        return np.asarray(values)

    # Callers in this file mix a tensor target with an ndarray reconstruction,
    # so settle on one array library before doing arithmetic.
    if isinstance(ys, np.ndarray) or isinstance(ys_recon, np.ndarray):
        ys = _as_numpy(ys)
        ys_recon = _as_numpy(ys_recon)
        ys = ys.reshape(ys.shape[0], -1)
        ys_recon = ys_recon.reshape(ys_recon.shape[0], -1)
        return np.linalg.norm(ys - ys_recon, ord='fro')**2 / np.linalg.norm(ys, ord='fro')**2

    ys = ys.reshape(ys.shape[0], -1)
    ys_recon = ys_recon.reshape(ys_recon.shape[0], -1)
    return torch.norm(ys - ys_recon)**2 / torch.norm(ys)**2

def recon_err_DL(dict_learner, test_sample_func, num_batch, add_noise_func=None, topK=None, track_coeffs=False, returnstd=False, digits=-1):
    """Average reconstruction error of a fitted dictionary learner.

    For each of ``num_batch`` batches: draw samples, optionally add noise,
    encode with ``dict_learner.transform``, optionally round the codes and keep
    only the ``topK`` largest-magnitude coefficients per sample, reconstruct
    with ``dict_learner.components_`` and score against the noiseless samples
    using ``loss_in_ys``.

    Args:
        dict_learner: Object exposing ``transform(ys)`` and ``components_``.
        test_sample_func: Zero-argument callable returning a batch tensor.
        num_batch: Number of batches to average over.
        add_noise_func: Optional callable that corrupts a clean batch.
        topK: If given, number of coefficients kept per sample.
        track_coeffs: Also return the list of per-batch coefficient arrays.
        returnstd: Also return the std of the errors and of their log10.
        digits: If not ``-1``, round coefficients to this many decimals.

    Returns:
        ``mean_error`` followed, depending on the flags, by ``all_coeffs``
        and/or ``(std(errors), std(log10(errors)))``.
    """
    components = dict_learner.components_
    avgloss = 0
    all_coeffs = []
    all_errs = []
    for _ in range(num_batch):
        clean = test_sample_func()
        # The error target is always the noiseless batch, converted to a 2-D
        # NumPy array so that it matches the NumPy reconstruction below.
        target = np.array(clean.detach().cpu().reshape(clean.shape[0], -1))
        if add_noise_func is not None:
            noisy = add_noise_func(clean).detach().cpu()
            ys = np.array(noisy.reshape(noisy.shape[0], -1))
        else:
            ys = target

        ys_DL = dict_learner.transform(ys)
        if digits != -1:
            ys_DL = np.around(ys_DL, decimals=digits)
        if track_coeffs:
            all_coeffs.append(ys_DL)

        if topK is None:
            full_ys_DL = ys_DL @ components
        else:
            # Keep the topK largest-magnitude coefficients of every row and
            # zero the rest. (The previous code also built an unused index via
            # np.matlib, which is never imported, and crashed on this path.)
            order = np.argsort(np.abs(ys_DL), axis=1)
            keep = order[:, (order.shape[1] - topK):]
            sparse = np.zeros_like(ys_DL)
            np.put_along_axis(sparse, keep, np.take_along_axis(ys_DL, keep, axis=1), axis=1)
            full_ys_DL = sparse @ components

        this_err = loss_in_ys(target, full_ys_DL)
        avgloss += this_err
        all_errs.append(this_err)

    mean_err = avgloss / num_batch
    if returnstd:
        errs = np.array(all_errs)
        if track_coeffs:
            return mean_err, all_coeffs, np.std(errs), np.std(np.log10(errs))
        return mean_err, np.std(errs), np.std(np.log10(errs))
    if track_coeffs:
        return mean_err, all_coeffs
    return mean_err

def recon_err_DL_svd(DLcomponents, DLtransform, test_sample_func, num_batch, add_noise_func=None, topK=None, track_coeffs=False, returnstd=False, digits=-1):
    """Average reconstruction error of an explicit dictionary and encoder.

    Same procedure as ``recon_err_DL`` but with the atoms (``DLcomponents``)
    and the encoder (``DLtransform``) passed separately.

    Args:
        DLcomponents: Array of atoms, one per row.
        DLtransform: Callable mapping a ``(P, features)`` ndarray to codes.
        test_sample_func: Zero-argument callable returning a batch tensor, or
            a ``(batch, zs)`` pair when ``track_coeffs`` is true.
        num_batch: Number of batches to average over.
        add_noise_func: Optional callable that corrupts a clean batch.
        topK: If given, number of coefficients kept per sample.
        track_coeffs: Also return the per-batch coefficient arrays and ``zs``.
        returnstd: Also return the std of the errors and of their log10.
        digits: If not ``-1``, round coefficients to this many decimals.

    Returns:
        ``mean_error`` followed, depending on the flags, by
        ``all_coeffs, zs`` and/or ``(std(errors), std(log10(errors)))``.
    """
    avgloss = 0
    all_coeffs = []
    all_errs = []
    # NOTE(review): only the zs of the *last* batch is returned, as before;
    # confirm whether callers actually want the zs of every batch.
    zs = None
    for _ in range(num_batch):
        if track_coeffs:
            # Previously the pair was only unpacked when a noise function was
            # given; without one the function failed on the undefined ``zs``.
            clean, zs = test_sample_func()
        else:
            clean = test_sample_func()
        # The error target is always the noiseless batch as a 2-D ndarray.
        target = np.array(clean.detach().cpu().reshape(clean.shape[0], -1))
        if add_noise_func is not None:
            noisy = add_noise_func(clean).detach().cpu()
            ys = np.array(noisy.reshape(noisy.shape[0], -1))
        else:
            ys = target

        ys_DL = DLtransform(ys)
        if digits != -1:
            ys_DL = np.around(ys_DL, decimals=digits)
        if track_coeffs:
            all_coeffs.append(ys_DL)

        if topK is None:
            full_ys_DL = ys_DL @ DLcomponents
        else:
            # Keep the topK largest-magnitude coefficients of every row and
            # zero the rest. (The previous code also built an unused index via
            # np.matlib, which is never imported, and crashed on this path.)
            order = np.argsort(np.abs(ys_DL), axis=1)
            keep = order[:, (order.shape[1] - topK):]
            sparse = np.zeros_like(ys_DL)
            np.put_along_axis(sparse, keep, np.take_along_axis(ys_DL, keep, axis=1), axis=1)
            full_ys_DL = sparse @ DLcomponents

        this_err = loss_in_ys(target, full_ys_DL)
        avgloss += this_err
        all_errs.append(this_err)

    mean_err = avgloss / num_batch
    if returnstd:
        errs = np.array(all_errs)
        if track_coeffs:
            return mean_err, all_coeffs, zs, np.std(errs), np.std(np.log10(errs))
        return mean_err, np.std(errs), np.std(np.log10(errs))
    if track_coeffs:
        return mean_err, all_coeffs, zs
    return mean_err

def visualize_compare(k, test_sample_func, add_noise_func, A, G, reshape_visualize_func, numex=3, DLcomponents=None, DLtransform=None, topK=None, dict_learner_MOD=None, A_altmin=None, iters=400, lr=1e-2, doplot=False, saveplot=None, extraname="", residual=False, vminmax=None):
    """Plot reconstructions of one test batch from each available method.

    Two figures with ``numex`` rows and six columns (True, Autodiff, Altmin,
    MOD, kSVD, True (noisy)) are drawn: the first uses a shared colour range,
    the second lets every panel pick its own.

    Args:
        k: Latent dimension used when fitting codes with ``find_best_z``.
        test_sample_func: Zero-argument callable returning a batch tensor.
        add_noise_func: Callable corrupting a clean batch, or ``None`` to use
            the clean batch as the observation.
        A: Dictionary learned by the autodiff method; ``A.shape`` is ``(m, n)``.
        G: Generator combined with the dictionary in ``GenericStackedNet``.
        reshape_visualize_func: Maps one sample to something ``imshow`` accepts.
        numex: Number of examples (rows) to show.
        DLcomponents, DLtransform: Atoms and encoder for the kSVD column.
        topK: If given, keep only this many coefficients per sample for the
            kSVD and MOD columns.
        dict_learner_MOD: Learner with ``transform`` / ``components_`` for the
            MOD column.
        A_altmin: Dictionary for the Altmin column.
        iters, lr: Passed to ``find_best_z``.
        doplot: Call ``plt.show()`` (also forwarded to ``find_best_z``).
        saveplot: Path prefix; when set, both figures are saved as PNG.
        extraname: Inserted between ``saveplot`` and the fixed file names.
        residual: Show reconstruction minus the clean sample instead of the
            reconstruction itself.
        vminmax: Optional ``(vmin, vmax)`` for the shared colour range;
            defaults to the range of the clean batch.
    """
    ys_noiseless = test_sample_func()
    if add_noise_func is not None:
        ys = add_noise_func(ys_noiseless)
    else:
        # Without a noise model the observation is the clean batch itself.
        # Previously ys_noiseless was left undefined on this path.
        ys = ys_noiseless
    # NOTE(review): the old code called ys_noiseless.to(device) and discarded
    # the result, so samples were never moved; that no-op has been dropped.

    column_titles = ['True', 'Autodiff', 'Altmin', 'MOD', 'kSVD', 'True (noisy)']
    # squeeze=False keeps the axes arrays two-dimensional even when numex == 1.
    fig, axes = plt.subplots(numex, 6, figsize=(17,6), squeeze=False)
    fig2, axes2 = plt.subplots(numex, 6, figsize=(17,6), squeeze=False)
    for grid in (axes, axes2):
        for col, title in enumerate(column_titles):
            grid[0, col].set_title(title)
        # Hide ticks on every panel, including the last column, which the old
        # range(5) loop skipped.
        for ax in grid.ravel():
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)

    if vminmax is None:
        vmin = torch.min(ys_noiseless)
        vmax = torch.max(ys_noiseless)
    else:
        vmin = vminmax[0]
        vmax = vminmax[1]

    m, n = A.shape

    def _best_reconstruction(dictionary):
        # Wrap the dictionary in a stacked net, fit codes to ys, reconstruct.
        A_model_forautodiff = set_up_A_model({'n':n, 'm':m, 'device':device, 'type':'linear', 'init': dictionary})
        intermediate_stacked_net = GenericStackedNet(A_model_forautodiff, G)
        intermediate_stacked_net.A.load_state_dict(A_model_forautodiff.state_dict())
        zs, _ = find_best_z(intermediate_stacked_net, ys, k=k, iters=iters, lr=lr, printevery=None, zsqnorm_fac=0, doplot=doplot)
        return intermediate_stacked_net(zs)

    def _sparse_recon(coeffs, atoms):
        # Reconstruct from codes, optionally keeping only the topK
        # largest-magnitude coefficients of each row.
        if topK is None:
            return coeffs @ atoms
        order = np.argsort(np.abs(coeffs), axis=1)
        keep = order[:, (order.shape[1] - topK):]
        sparse = np.zeros_like(coeffs)
        np.put_along_axis(sparse, keep, np.take_along_axis(coeffs, keep, axis=1), axis=1)
        return sparse @ atoms

    def add_colorbar(out_imshow, an_axis, a_fig):
        divider = make_axes_locatable(an_axis)
        cax = divider.append_axes('right', size='5%', pad=0.05)
        a_fig.colorbar(out_imshow, cax=cax, orientation='vertical')

    def _show(row, col, sample):
        # Draw one panel in both figures: shared scale in fig, free in fig2.
        image = reshape_visualize_func(sample)
        add_colorbar(axes[row, col].imshow(image, vmin=vmin, vmax=vmax), axes[row, col], fig)
        add_colorbar(axes2[row, col].imshow(image), axes2[row, col], fig2)

    ys_recon_autodiff = _best_reconstruction(A)

    if residual:
        subtractoff = ys_noiseless
    else:
        subtractoff = torch.zeros(ys_noiseless.shape).to(ys_noiseless.device)

    for i in range(numex):
        _show(i, 0, ys_noiseless[i, ...])
        _show(i, 5, ys[i, ...])
        _show(i, 1, ys_recon_autodiff[i, ...]-subtractoff[i,...])

    if A_altmin is not None:
        ys_recon_altmin = _best_reconstruction(A_altmin)
        for i in range(numex):
            _show(i, 2, ys_recon_altmin[i, ...]-subtractoff[i,...])

    if DLcomponents is not None or dict_learner_MOD is not None:
        # The dictionary-learning baselines work on flattened NumPy batches.
        ys_cpu = ys.detach().cpu()
        ys_np = np.array(ys_cpu.reshape(ys_cpu.shape[0], -1))

    if DLcomponents is not None:
        full_ys_DL = _sparse_recon(DLtransform(ys_np), DLcomponents)
        for i in range(numex):
            _show(i, 4, full_ys_DL[i, ...]-np.array(subtractoff[i,...].cpu()))

    if dict_learner_MOD is not None:
        full_ys_MOD = _sparse_recon(dict_learner_MOD.transform(ys_np), dict_learner_MOD.components_)
        for i in range(numex):
            _show(i, 3, full_ys_MOD[i, ...]-np.array(subtractoff[i,...].cpu()))

    if saveplot is not None:
        fig.savefig(saveplot + extraname + 'DirectCompareVisualization.png')
        fig2.savefig(saveplot + extraname + 'DirectCompareVisualization_diffscales.png')
    if doplot:
        plt.show()
    # Release both figures so repeated calls do not accumulate open figures.
    plt.close(fig)
    plt.close(fig2)

def visualize_filters(A_true, A, A_init, G, reshape_visualize_func, DLcomponents=None, \
        dict_learner_MOD=None, A_altmin=None, doplot=False, saveplot=None, extraname=""):
    """Plot up to ten dictionary atoms from each method in a 6-row grid.

    Rows are, top to bottom: True, Init, Autodiff, Altmin, MOD, kSVD. Rows
    whose dictionary is not supplied are left empty.

    Args:
        A_true, A, A_init, A_altmin: Dictionaries with atoms as columns;
            ``A_true.shape`` is ``(m, n)`` and fixes the number of columns.
        G: Unused; kept for signature compatibility.
        reshape_visualize_func: Maps one atom to something ``imshow`` accepts.
        DLcomponents: kSVD atoms stored as rows.
        dict_learner_MOD: Learner whose ``components_`` stores atoms as rows.
        doplot: Call ``plt.show()``.
        saveplot: Path prefix; when set the figure is saved as PNG.
        extraname: Inserted between ``saveplot`` and the fixed file name.
    """
    m, n = A_true.shape
    ncols = min(n, 10)
    # squeeze=False keeps ``axes`` two-dimensional even when only one atom is
    # shown; otherwise axes[i, j] fails for n == 1.
    fig, axes = plt.subplots(6, ncols, figsize=(15,6), squeeze=False)

    # Remove tick marks but leave the axis objects visible: hiding the whole
    # y-axis (as before) also hides the row labels set below.
    for ax in axes.ravel():
        ax.set_xticks([])
        ax.set_yticks([])

    axes[0, 0].set_ylabel('True')
    axes[1, 0].set_ylabel('Init')
    axes[2, 0].set_ylabel('Autodiff')
    axes[3, 0].set_ylabel('Altmin')
    axes[5, 0].set_ylabel('kSVD')
    axes[4, 0].set_ylabel ('MOD')

    for i in range(ncols):
        axes[0, i].imshow(reshape_visualize_func(A_true[:, i]))
        axes[1, i].imshow(reshape_visualize_func(A_init[:, i]))
        axes[2, i].imshow(reshape_visualize_func(A[:, i]))
        if A_altmin is not None:
            axes[3, i].imshow(reshape_visualize_func(A_altmin[:, i]))
        if DLcomponents is not None:
            # Atoms are rows here, unlike the column convention above.
            axes[5, i].imshow(reshape_visualize_func(DLcomponents[i, :]))
        if dict_learner_MOD is not None:
            axes[4, i].imshow(reshape_visualize_func(dict_learner_MOD.components_[i, :]))

    if saveplot is not None:
        # Save this figure explicitly rather than whichever one is current.
        fig.savefig(saveplot + extraname + 'CompareDictionaries.png')
    if doplot:
        plt.show()
    plt.close(fig)

def recon_err_A(intermediate_A, test_sample_func, num_batch, G, true_A=None, add_noise_func=None, zsqnorm_fac=0, k=None, iters=400, lr=1e-2, doplot=False, saveplot=None, extraname=None, returnstd=False, true_stacked_net=None, P=None, pz=None, digits=-1):
    """Average reconstruction error obtained with a candidate dictionary.

    For each batch, observations are drawn (from ``test_sample_func`` or from
    ``true_stacked_net(pz(P))``), optionally corrupted, and latent codes are
    fitted with ``find_best_z`` through a stacked net built from
    ``intermediate_A`` and ``G``. The error is ``loss_in_ys`` between the
    noiseless observations and the reconstruction.

    Args:
        intermediate_A: Candidate dictionary of shape ``(m, n)``.
        test_sample_func: Zero-argument callable returning a batch tensor.
        num_batch: Number of batches to average over.
        G: Generator combined with the dictionary.
        true_A: If given, the relative squared error of ``intermediate_A``
            against it is also returned.
        add_noise_func: Optional callable that corrupts a clean batch.
        zsqnorm_fac, k, iters, lr, digits: Passed to ``find_best_z``.
        doplot: Show the diagnostic plots.
        saveplot: Path prefix; when set the diagnostic plots are saved.
        extraname: Optional string inserted after ``saveplot`` in file names.
        returnstd: Also return the std of the errors and of their log10.
        true_stacked_net: If given, samples come from it using ``pz(P)``.
        P, pz: Batch size and latent sampler used with ``true_stacked_net``.

    Returns:
        ``mean_error`` followed, depending on the arguments, by ``loss_A``
        and/or ``(std(errors), std(log10(errors)))``.
    """
    avgloss = 0
    if true_A is not None:
        loss_A = torch.norm(true_A - intermediate_A)**2 / torch.norm(true_A)**2
    all_errs = []
    m = intermediate_A.shape[0]
    n = intermediate_A.shape[1]
    for batch in range(num_batch):
        true_zs = None
        if add_noise_func is not None:
            if true_stacked_net is None:
                ys_noiseless = test_sample_func()
            else:
                with torch.no_grad():
                    true_zs = pz(P)
                    ys_noiseless = true_stacked_net(true_zs)
            ys = add_noise_func(ys_noiseless).to(device)
            # Move the clean target only after the noise was added, so the
            # noise function still sees its input where it was sampled. The
            # old ``ys_noiseless.to(device)`` discarded its result.
            target = ys_noiseless.to(device)
        else:
            if true_stacked_net is not None:
                true_zs = pz(P)
                ys = true_stacked_net(true_zs)
            else:
                ys = test_sample_func()
            ys = ys.to(device)
            target = ys

        # Wrap the candidate dictionary in a stacked net for code fitting.
        A_model_forautodiff = set_up_A_model({'n':n, 'm':m, 'device':device, 'type':'linear', 'init': intermediate_A})
        intermediate_stacked_net = GenericStackedNet(A_model_forautodiff, G)
        intermediate_stacked_net.A.load_state_dict(A_model_forautodiff.state_dict())

        zs, losses_in_y = find_best_z(intermediate_stacked_net, ys, k=k, iters=iters, lr=lr, printevery=None, zsqnorm_fac=zsqnorm_fac, doplot=doplot, digits=digits)
        ys_recon = intermediate_stacked_net(zs)
        thiserr = np.array(loss_in_ys(target, ys_recon).data.cpu())
        avgloss += thiserr
        all_errs.append(thiserr)

        if doplot or (saveplot is not None):
            if saveplot is not None:
                prefix = saveplot if extraname is None else saveplot + extraname

            fig = plt.figure()
            plt.plot(torch.log10(torch.tensor(losses_in_y)))
            plt.xlabel('iter')
            plt.ylabel('Log10 loss in y')
            if saveplot is not None:
                plt.savefig(prefix + 'ReconErrA_Batch' + str(batch) + '.png')
            if doplot:
                plt.show()
            plt.close(fig)

            # The latent comparison needs ground-truth codes, which exist only
            # when the samples came from true_stacked_net.
            if true_zs is not None:
                fig = plt.figure()
                plt.subplot(1, 1, 1)
                plt.plot(zs[0,...].detach().cpu(), label='Recovered')
                plt.plot(true_zs[0,...].detach().cpu(), label='True')
                plt.legend()
                if saveplot is not None:
                    # Named by batch; an inner loop variable used to shadow
                    # the batch index, so every batch overwrote the same file.
                    plt.savefig(prefix + 'Compare_z_Batch' + str(batch) + '.png')
                if doplot:
                    plt.show()
                plt.close(fig)

    errs = torch.tensor(np.array(all_errs))
    if returnstd:
        # torch.std needs a tensor; the no-true_A branch used to pass an
        # ndarray and raised TypeError.
        if true_A is not None:
            return avgloss / num_batch, loss_A, torch.std(errs), torch.std(torch.log10(errs))
        return avgloss / num_batch, torch.std(errs), torch.std(torch.log10(errs))
    if true_A is not None:
        return avgloss / num_batch, loss_A
    return avgloss / num_batch

def normalize_cols(mat):
    """Return a copy of tensor ``mat`` with every column scaled to unit norm.

    The result is a new tensor on the same device and with the same dtype.
    """
    normalized = mat.new_zeros(mat.shape)
    num_cols = mat.shape[1]
    for col in range(num_cols):
        column = mat[:, col]
        normalized[:, col] = column / torch.norm(column)
    return normalized

def evaluate_generative_model(G, pz, numsamples=100, device=torch.device('cpu')):
    """Compare the empirical second moment of ``G(z)`` with the identity.

    Draws ``numsamples`` latents from ``pz``, averages the outer products
    ``G(z) G(z)^T`` and takes the singular values of that average minus the
    identity.

    Returns:
        ``(scale_factor, epsilon)`` where ``scale_factor`` is the mean of the
        largest and smallest singular value and ``epsilon`` is the larger of
        ``|s * scale_factor - 1|`` over those two singular values.
    """
    latents = pz(numsamples)  # numsamples by k
    print('shape of zs', latents.shape)
    generated = G(latents)  # numsamples by n
    n = generated.shape[1]

    # Batched outer products, then their average.
    as_columns = generated.view(numsamples, n, 1)
    as_rows = generated.view(numsamples, 1, n)
    outer = torch.matmul(as_columns, as_rows)
    second_moment = torch.sum(outer, dim=0) / numsamples
    print('zs', latents.device, 'Gzs', generated.device, 'res', outer.device, 'EGG', second_moment.device, 'device', device)

    identity = torch.eye(n, n, device=device)
    _, singular_values, _ = torch.svd(second_moment - identity)
    largest = torch.max(singular_values)
    smallest = torch.min(singular_values)
    scale_factor = (largest + smallest) / 2
    deviations = (abs(largest*scale_factor - 1), abs(smallest*scale_factor - 1))
    epsilon = max(deviations)
    return scale_factor, epsilon

def estimate_lipschitz(G, pz, numsamples=1000):
    """Return the largest sampled ratio ``||G(z)|| / ||z||``.

    Note: this estimates the Lipschitz constant times the radius of ``pz``.
    """
    latents = pz(numsamples)  # numsamples by k
    generated = G(latents)  # numsamples by n
    output_norms = torch.norm(generated, dim=1)
    input_norms = torch.norm(latents, dim=1)
    ratios = output_norms / input_norms
    return torch.max(ratios)

def estimateCinv(G, pz, numsamples=1000):
    """Return the reciprocal of the smallest singular value of the empirical
    second moment ``mean(G(z) G(z)^T)`` over ``numsamples`` draws from ``pz``.
    """
    latents = pz(numsamples)  # numsamples by k
    generated = G(latents)  # numsamples by n
    n = generated.shape[-1]
    columns = generated.view(numsamples, n, 1)
    rows = generated.view(numsamples, 1, n)
    outer_products = torch.matmul(columns, rows)
    second_moment = (1/numsamples) * torch.sum(outer_products, 0)
    _, singular_values, _ = torch.svd(second_moment)
    return 1 / torch.min(singular_values)

def estimate_S_REC(A, G, pz, numsamples=5, delta=0.1):
    """Estimate gamma as the minimum of ``(||A d|| + delta) / ||d||`` over
    sampled pairwise differences ``d = G(z_i) - G(z_j)``.

    Args:
        A: Matrix of shape ``(m, n)``.
        G: Generator mapping ``(numsamples, k)`` latents to ``(numsamples, n)``.
        pz: Latent sampler called with ``numsamples``.
        numsamples: Number of latents drawn; all ordered pairs are formed.
        delta: Additive slack in the numerator.

    Returns:
        A 0-d tensor with the smallest ratio over pairs of non-zero norm, or
        ``inf`` when every pairwise difference is zero.
    """
    zs = pz(numsamples)  # numsamples by k
    Gzs = G(zs)  # numsamples by n
    pairdiffs = Gzs[:, None, :] - Gzs[None, :, :]  # numsamples by numsamples by n
    n = pairdiffs.shape[2]
    pairdiffs = pairdiffs.view(numsamples**2, n, 1)
    AGzs = torch.matmul(A, pairdiffs)  # numsamples**2 by m by 1

    pair_norms = torch.norm(pairdiffs[..., 0], dim=1)
    image_norms = torch.norm(AGzs[..., 0], dim=1)
    # Self-pairs (and coincident samples) have zero norm; dividing by them
    # gave inf, or NaN when delta == 0, which poisoned the minimum.
    nonzero = pair_norms > 0
    if not bool(nonzero.any()):
        return torch.full((), float('inf'), dtype=pair_norms.dtype, device=pair_norms.device)
    gamma = torch.min((image_norms[nonzero] + delta) / pair_norms[nonzero])
    return gamma


def pz_maker(k, pz_type='random', device=torch.device('cpu'), n=None):
    """Return a sampler ``pz(P)`` with the latent dimension ``k`` built in.

    The returned callable produces a ``P`` by ``k`` tensor (samples stacked
    along a leading batch dimension, ready to be fed to G):

    * ``'random'``: standard-normal entries.
    * ``'sparse'``: integer indices drawn uniformly from ``[0, n)`` with
      replacement.

    Any other ``pz_type`` yields ``None``.
    """
    def gaussian_sampler(P):
        return torch.randn(P, k, device=device)

    def index_sampler(P):
        # Indices are drawn with replacement.
        return torch.randint(low=0, high=n, size=(P, k), device=device)

    if pz_type == 'random':
        return gaussian_sampler
    if pz_type == 'sparse':
        return index_sampler
    return None

def draw_fresh_samples(A, G, P, pz_func, noise=0, Amodel=False):
    """Draw ``P`` latents and the corresponding (optionally noisy) observations.

    When ``Amodel`` is true ``A`` is called as a model on ``G(zs)``; otherwise
    ``A`` and ``G`` are combined through ``DictNet``. Gaussian noise scaled by
    ``noise`` is added in place. Everything runs without gradient tracking.

    Returns:
        ``(ys, zs)``: the observations (batch dimension first) and the
        latents that produced them, which is useful for debugging.
    """
    with torch.no_grad():
        zs = pz_func(P)  # P by k
        if Amodel:
            ys = A(G(zs))
        else:
            forward_net = DictNet(A, G)
            ys = forward_net(zs)
        perturbation = noise*torch.randn(ys.shape, device=ys.device)
        ys += perturbation
    return ys, zs

class MyLossFn(torch.nn.Module):
    """Relative error ``||outputs - ys|| / ||ys||`` against fixed targets."""

    def __init__(self, ys):
        super().__init__()
        # Targets every forward call is compared with.
        self.ys = ys

    def forward(self, outputs):
        targets = self.ys
        residual_norm = torch.norm(outputs - targets)
        return residual_norm / torch.norm(targets)

def decode(A, G, ys, k, n, maxiter=100, eps=1e-3, lr=1e-3, display=False, pz_pdf=None, device=torch.device('cpu')):
    """Recover latent codes for measurements ``ys`` given matrix ``A`` and ``G``.

    Codes start at zero and are optimised with Adam on the relative error
    ``||A G(zs) - ys|| / ||ys||``.

    Args:
        A: Matrix applied to ``G(zs)`` reshaped to ``(P, n, 1)``.
        G: Generator mapping ``(P, k)`` codes to ``P * n`` values.
        ys: Measurements; ``ys.shape[0]`` is the batch size ``P``.
        k: Latent dimension.
        n: Output dimension of ``G`` per sample.
        maxiter: Number of Adam steps.
        eps: Unused; kept for signature compatibility.
        lr: Adam learning rate.
        display: Print the loss at every step.
        pz_pdf: Unused; kept for signature compatibility.
        device: Device on which the codes are created.

    Returns:
        ``(zs, losses)``: the detached ``(P, k)`` codes and an array of the
        per-iteration losses.
    """
    P = ys.shape[0]
    zs = torch.zeros((P, k), device=device, requires_grad=True)

    optimizer = optim.Adam([zs], lr=lr)
    # NOTE(review): this turns anomaly detection on process-wide and never
    # turns it off again -- confirm it is still wanted outside debugging.
    torch.autograd.set_detect_anomaly(True)

    losses = []
    for i in range(maxiter):
        optimizer.zero_grad()

        Gzs = G(zs)
        outputs = torch.matmul(A, Gzs.view(P, n, 1))
        loss = torch.norm(outputs - ys) / torch.norm(ys)
        loss.backward()
        optimizer.step()

        losses.append(loss.detach().cpu())

        if display:
            print('Iter %d Loss (normalized) %f' % (i, loss))

    losses_arr = np.array(losses)

    # Independent, gradient-free copy of the optimised codes.
    zs_to_return = zs.detach().clone()

    # Explicit release and collection kept from the original implementation.
    del zs
    gc.collect()

    return zs_to_return, losses_arr

def decode_updated(A, G, ys, k, maxiter=100, eps=1e-3, lr=1e-3, display=False, pz_pdf=None, device=torch.device('cpu'), zsqnorm_fac=0, init='zeros'):
    """Recover latent codes for measurements ``ys`` given callables ``A`` and ``G``.

    Codes are optimised with Adam on
    ``||A(G(zs)) - ys||^2 / ||ys||^2 + zsqnorm_fac * ||zs||^2``.

    Args:
        A: Callable applied to ``G(zs)``.
        G: Generator applied to the ``(P, k)`` codes.
        ys: Measurements; ``ys.shape[0]`` is the batch size ``P``.
        k: Latent dimension.
        maxiter: Number of Adam steps.
        eps: Unused; kept for signature compatibility.
        lr: Adam learning rate.
        display: Print the loss at every step.
        pz_pdf: Unused; kept for signature compatibility.
        device: Device on which the codes live.
        zsqnorm_fac: Weight of the squared-norm penalty on the codes.
        init: ``'zeros'`` for a zero start; anything else draws the start
            from a standard normal.

    Returns:
        ``(zs, losses)``: the detached ``(P, k)`` codes and an array of the
        per-iteration losses.
    """
    P = ys.shape[0]

    if init == 'zeros':
        zs = torch.zeros((P, k), device=device, requires_grad=True)
    else:
        # Sample on the CPU and move, as before, so seeded runs draw the same
        # starting point; a plain tensor replaces the deprecated Variable.
        zs = torch.normal(0, 1, (P, k)).to(device).requires_grad_(True)

    optimizer = optim.Adam([zs], lr=lr)
    # NOTE(review): this turns anomaly detection on process-wide and never
    # turns it off again -- confirm it is still wanted outside debugging.
    torch.autograd.set_detect_anomaly(True)

    losses = []
    for i in range(maxiter):
        optimizer.zero_grad()

        outputs = A(G(zs))

        loss = (torch.norm(outputs - ys)**2 / torch.norm(ys)**2) + zsqnorm_fac*(torch.norm(zs)**2)
        loss.backward()
        optimizer.step()
        losses.append(loss.detach().cpu())

        if display:
            print('Iter %d Loss (normalized) %f' % (i, loss))

    losses_arr = np.array(losses)

    return zs.detach(), losses_arr

def update(A, zs, ys, G, eta=0.01):
    """Take one averaged gradient step on matrix ``A``.

    The per-sample quantity ``(A G(z) - y) G(z)^T`` is averaged over the batch
    and subtracted from ``A`` with step size ``eta``. No gradients are tracked.

    Args:
        A: m by n matrix.
        zs: P by k latent codes.
        ys: P by m by 1 measurements (viewed as ``(P, ys.shape[1], 1)``).
        G: Generator applied to ``zs``.
        eta: Step size.

    Returns:
        The updated matrix as a new tensor.
    """
    with torch.no_grad():
        batch = ys.shape[0]
        generated = G(zs).view(zs.shape[0], A.shape[1], 1)  # P by n by 1
        residual = torch.matmul(A, generated) - ys.view(batch, ys.shape[1], 1)
        per_sample_grad = torch.matmul(residual, generated.permute(0, 2, 1))  # P by m by n
        mean_grad = torch.sum(per_sample_grad, dim=0) / batch
        stepped = A - eta*mean_grad

        del A
        gc.collect()

        return stepped

def update_generic_A(A, zs, ys, G, maxiter=5, lr=0.01, display=False, use_optimizer=True):
    """Fit the parameters of model ``A`` for fixed codes ``zs`` and generator ``G``.

    Runs ``maxiter`` steps on ``||A(G(zs)) - ys||^2 / ||ys||^2``.

    Args:
        A: Module whose ``parameters()`` are optimised; called on ``G(zs)``.
        zs: Latent codes, held fixed.
        ys: Target measurements.
        G: Generator applied to ``zs``.
        maxiter: Number of steps.
        lr: Adam learning rate.
        display: Unused; kept for signature compatibility.
        use_optimizer: When true, Adam updates ``A`` each step.

    Returns:
        ``(A, losses)``: the same module and an array of per-iteration losses.
    """
    optimizer = optim.Adam(A.parameters(), lr=lr) if use_optimizer else None
    # NOTE(review): this turns anomaly detection on process-wide and never
    # turns it off again -- confirm it is still wanted outside debugging.
    torch.autograd.set_detect_anomaly(True)

    losses = []
    for _ in range(maxiter):
        if optimizer is not None:
            optimizer.zero_grad()

        # G is evaluated once per step; the old code ran it twice and threw
        # the first result away.
        outputs = A(G(zs))

        loss = torch.norm(outputs - ys)**2 / torch.norm(ys)**2
        loss.backward()
        # NOTE(review): with use_optimizer=False gradients are accumulated but
        # never applied or cleared, exactly as before -- confirm the intent.
        if optimizer is not None:
            optimizer.step()
        losses.append(loss.detach().cpu())

    losses_arr = np.array(losses)

    return A, losses_arr


def project_to_circulant(A):
    """Project a square matrix onto the circulant matrices, in place.

    Every wrapped diagonal ``{(j, (j + i) % n)}`` is replaced by its mean, so
    each row of the result is the previous row shifted right by one.

    Args:
        A: Square tensor; it is modified in place.

    Returns:
        The same tensor ``A``.
    """
    # assume: A is square!
    n = A.shape[1]
    rows = torch.arange(n, device=A.device)
    for i in range(n):
        # Entries of the i-th wrapped diagonal. The previous code ignored the
        # offset i and re-averaged the main diagonal n times.
        cols = (rows + i) % n
        diagonal_mean = A[rows, cols].sum() / n
        A[rows, cols] = diagonal_mean.to(A.dtype)
    return A

def project(A, Ainit, delta):
    """Move each column of ``A`` to distance ``delta`` from the matching
    column of ``Ainit``, along the direction from ``Ainit`` to ``A``.

    For column vectors ``x`` (from ``A``) and ``y`` (from ``Ainit``) the
    result is ``y + delta * (x - y) / ||x - y||``. A column that already
    equals its ``Ainit`` column has no direction and is returned as ``y``.

    Args:
        A: Matrix whose columns are projected.
        Ainit: Reference matrix; its column count drives the loop.
        delta: Target distance from each reference column.

    Returns:
        A new tensor with the shape of ``A``, on the device of ``A``.
    """
    # Preserve floating-point precision of the input; non-float inputs fall
    # back to the default dtype, which is what was always used before.
    out_dtype = A.dtype if A.is_floating_point() else torch.get_default_dtype()
    newA = torch.zeros(A.shape, device=A.device, dtype=out_dtype)
    for i in range(Ainit.shape[1]):
        x = A[:, i]
        y = Ainit[:, i]
        offset = x - y
        distance = torch.norm(offset)
        if distance == 0:
            # 0/0 used to fill this column with NaN.
            newA[:, i] = y
        else:
            newA[:, i] = y + (delta * offset / distance)
    return newA

def set_models_equal(A_tochange, A):
    """Make ``A_tochange`` share the directly-registered parameters of ``A``.

    Every entry of ``A``'s ``_parameters`` mapping is assigned (not copied)
    into the same mapping of ``A_tochange``, which is then returned.
    """
    source_params = vars(A)['_parameters']
    target_params = vars(A_tochange)['_parameters']
    for name in source_params.keys():
        target_params[name] = source_params[name]
    return A_tochange

def random_init_model(A):
    """Replace the directly-registered parameters of ``A`` with fresh
    standard-normal values and return ``A``.

    Each tensor entry of ``A``'s ``_parameters`` mapping is replaced by a new
    tensor of the same shape, dtype and device; ``nn.Parameter`` entries stay
    parameters with the same ``requires_grad`` flag. Non-tensor entries (for
    example a ``None`` bias) are left alone.
    """
    params = vars(A)['_parameters']
    for name, val in list(params.items()):
        # The old exact-type test ``type(val) is torch.Tensor`` never matched
        # nn.Parameter, so ordinary modules were returned untouched.
        if not isinstance(val, torch.Tensor):
            continue
        fresh = torch.randn(val.shape, dtype=val.dtype, device=val.device)
        if isinstance(val, torch.nn.Parameter):
            fresh = torch.nn.Parameter(fresh, requires_grad=val.requires_grad)
        params[name] = fresh
    return A

def get_gaussian_kernel(shp, var=0.8, dofftshift=False):
    """Build a separable Gaussian kernel and return it with shape ``shp``.

    The last two entries of ``shp`` are taken as the kernel's height and
    width. Along each of them ``exp(-x**2 / var)`` is evaluated on an
    fft-shifted grid spanning [-5, 5]; the kernel is the outer product of the
    two profiles.

    Args:
        shp: Target shape; the kernel is reshaped to it before returning.
        var: Width parameter in the exponent.
        dofftshift: Apply ``np.fft.fftshift`` to the 2-D kernel.

    Returns:
        A tensor of shape ``shp``.
    """
    def _axis_profile(length):
        grid = np.fft.fftshift(np.linspace(-5, 5, length))
        return np.exp(-1 * (grid**2)/var)

    width_profile = _axis_profile(shp[-1])
    height_profile = _axis_profile(shp[-2])

    kernel = height_profile.reshape(-1, 1) * width_profile.reshape(1, -1)
    if dofftshift:
        kernel = np.fft.fftshift(kernel)
    return torch.tensor(kernel.reshape(shp))


def add_pattern(zs, m, n, pattern_type='checker'):
    """Multiply the above-threshold entries of ``zs`` by a fixed pattern.

    Entries greater than ``||zs|| / (m * n * 100)`` are multiplied by the
    flattened pattern; all other entries pass through unchanged.

    Args:
        zs: Tensor of samples (P by m*n according to the original notes).
        m, n: Image dimensions. The checkerboard is built with side
            ``int(sqrt(m))``.
        pattern_type: Only ``'checker'`` (a 0/1 checkerboard) is supported.

    Returns:
        A tensor with the shape of ``zs``.

    Raises:
        ValueError: If ``pattern_type`` is not supported.
    """
    if pattern_type == 'checker':
        # NOTE(review): the pattern has int(sqrt(m))**2 entries, which only
        # matches a last dimension of m*n in special cases -- confirm the
        # intended meaning of m and n against the callers.
        sidelen = int(m**0.5)
        pattern = torch.tensor(np.indices((sidelen, sidelen)).sum(axis=0) % 2, device=zs.device)
    else:
        # Previously an unknown type crashed later on an unbound local.
        raise ValueError('Unknown pattern_type %r' % (pattern_type,))
    msk = zs > (torch.norm(zs)/(m*n*100))
    notmsk = ~msk
    return (zs*msk.float()*pattern.view(-1) + zs*notmsk.float()).view(zs.shape)

def set_up_A_model(opts_dict):
    """Construct the model for ``A`` described by ``opts_dict``.

    Keys read from ``opts_dict``:
        ``'n'``, ``'m'``: Dimensions passed to the model constructor.
        ``'device'``: Device the model is moved to.
        ``'type'``: ``'linear'`` (an ``A_matrix``) or ``'2dconv'`` (an
            ``A_2dconv``).
        ``'init'``: For ``'linear'``, a tensor copied into ``A.weight``.
        ``'nonlinearity'``: For ``'2dconv'``, passed to the constructor.

    Returns:
        The constructed model on ``opts_dict['device']``.

    Raises:
        ValueError: If ``opts_dict['type']`` is not a supported type.
    """
    n, m = opts_dict['n'], opts_dict['m']
    device = opts_dict['device']
    model_type = opts_dict['type']
    if model_type == 'linear':
        A_model = A_matrix(n=n, m=m).to(device)
        Ainit = opts_dict['init']
        # Independent copy placed on the model's device, so the weight cannot
        # end up on a different device from the rest of the model.
        A_model.A.weight.data = Ainit.detach().clone().to(device)
        return A_model
    if model_type == '2dconv':
        nonlinearitytype = opts_dict['nonlinearity']
        return A_2dconv(n=n, nonlinearity=nonlinearitytype).to(device)
    # Previously an unknown type crashed on an unbound local at the return.
    raise ValueError('Unknown A model type %r' % (model_type,))

def lightweight_denoising_loss(num_batch, lr, iters, P, k, G, test_sample_func, add_noise_func, do_plots=False, show_every_recon=False):
    """Estimate denoising quality of generator ``G`` by latent-space optimisation.

    For each of ``num_batch`` batches, a clean sample is drawn with
    ``test_sample_func()`` and corrupted with ``add_noise_func``. A latent
    ``(P, k)`` tensor is then initialised from a standard normal and optimised
    with Adam for ``iters`` steps, minimising
    ``||G(z) - ys||^2 / ||ys||^2`` against the *noisy* sample. The score
    appended per batch is ``loss_in_ys(ys_noiseless, outputs)``, where
    ``loss_in_ys`` is defined outside this block.

    Args:
        num_batch: Number of batches to evaluate (must be >= 1 for the
            returned tensors to exist).
        lr: Adam learning rate.
        iters: Optimisation steps per batch (must be >= 1, since ``outputs``
            is only assigned inside the step loop).
        P, k: Shape of the latent tensor being optimised.
        G: Callable mapping the latent tensor to reconstructions.
        test_sample_func: Zero-argument callable returning a clean batch tensor.
        add_noise_func: Callable mapping the clean batch to a noisy batch.
        do_plots: If true, plot log10 of the per-step losses for every batch.
        show_every_recon: If true, show the first reconstruction (reshaped to
            28x28) every 100 steps of the first batch.

    Returns:
        ``(test_losses, ys_noiseless, outputs, ys)``: the list of per-batch
        scores, followed by the clean batch, the reconstruction and the noisy
        batch of the *last* batch. ``outputs`` comes from the forward pass of
        the final step, i.e. before that step's parameter update.
    """
    test_losses = []
    for i in range(num_batch):
        # Tensor.to is out-of-place: rebind the result, or the move is lost.
        ys_noiseless = test_sample_func().to(device)
        ys = add_noise_func(ys_noiseless)

        # Random (non-zero) start; the original author notes the gradient is
        # probably zero when initialising at zero.
        zs_autodiff = Variable(torch.normal(0, 1, (P, k)).to(device), requires_grad=True)
        optimizer = optim.Adam([zs_autodiff], lr=lr)
        # NOTE(review): this switches anomaly detection on globally and never
        # switches it off again; kept to preserve existing behaviour.
        torch.autograd.set_detect_anomaly(True)
        losses_this_epoch = []
        for j in range(iters):
            optimizer.zero_grad()
            outputs = G(zs_autodiff)  # P by m by 1 (per the original comment)
            if i == 0 and show_every_recon and j % 100 == 0:
                plt.imshow(outputs[0, ...].reshape(28, 28).detach().cpu())
                plt.title('iter' + str(j) + 'recon')
                plt.show()

            # Relative squared error against the noisy observation.
            loss_in_y = torch.norm(outputs - ys)**2 / torch.norm(ys)**2
            # Store plain floats so the history converts cleanly to a NumPy array.
            losses_this_epoch.append(loss_in_y.item())
            # NOTE(review): retain_graph=True looks unnecessary because the
            # graph is rebuilt by G(...) each step; left as-is pending confirmation.
            loss_in_y.backward(retain_graph=True)
            optimizer.step()

        if do_plots:
            plt.plot(np.log10(np.array(losses_this_epoch)))
            plt.title('Log10 Losses over iterations in lightweight_denoising_loss')
            plt.show()

        # Score the reconstruction against the clean (noiseless) sample.
        lss = loss_in_ys(ys_noiseless, outputs).data.cpu()
        test_losses.append(lss)

        del zs_autodiff, optimizer
        gc.collect()
    return test_losses, ys_noiseless, outputs, ys

def make_visual_plots(pz, k, n, G, current_A, A_true, Ainit, bigtitle=None, noise_level=None, do_MNIST=False, deconv=False, updated=False, seed=None, device=torch.device('cpu'), do_plots=True):
    """Decode one noisy synthetic measurement under three forward models and compare.

    A latent is drawn with ``pz(1)``, pushed through ``G`` and the true forward
    model ``A_true``, and corrupted with Gaussian noise. The noisy measurement
    is then decoded three times (with ``current_A``, ``A_true`` and ``Ainit``),
    and the resulting latents are pushed back through the various forward
    models. Optionally the results are shown as a row of 28x28 images.

    Two calling conventions are supported, selected by ``updated``:
      * ``updated=False``: the ``A`` arguments are tensors. The measurement is
        ``torch.matmul(A_true, G(z).view(1, -1, 1))``, decoding uses ``decode``
        and forward passes use ``DictNet(A, G)``.
      * ``updated=True``: the ``A`` arguments are callables. The measurement is
        ``A_true(G(z))``, decoding uses ``decode_updated`` and forward passes
        are ``A(G(z))``.

    Args:
        pz: Callable; ``pz(1)`` supplies the ground-truth latent.
        k: Forwarded to the decoder.
        n: Forwarded to ``decode`` (unused when ``updated`` is true).
        G: Generator mapping latents to signals.
        current_A: The learned forward model being evaluated.
        A_true: The ground-truth forward model.
        Ainit: The initial forward model.
        bigtitle: Optional figure super-title.
        noise_level: Scale of the added Gaussian noise. ``None`` means no
            noise (previously ``None`` crashed on ``None * tensor``).
        do_MNIST: Plots are only produced when this and ``do_plots`` are true.
        deconv: If true, add three panels showing ``G`` applied to each
            decoded latent.
        updated: Selects the calling convention described above.
        seed: If given, seeds NumPy and torch before sampling.
        device: Forwarded to the decoder.
        do_plots: See ``do_MNIST``.

    Returns:
        0-dim tensor: the relative L2 error between ``current_A`` applied to
        its own decoded latent and the noiseless true measurement.
    """
    if seed is not None:
        np.random.seed(seed)
        torch.random.manual_seed(seed)

    noise_std = 0.0 if noise_level is None else noise_level

    with torch.no_grad():
        true_z = pz(1)
        from_vae = G(true_z)
        if updated:
            transformed_from_vae = A_true(from_vae)
        else:
            transformed_from_vae = torch.matmul(A_true, from_vae.view(1, -1, 1))
        noisy_transformed_from_vae = transformed_from_vae + noise_std * torch.randn(
            transformed_from_vae.shape, device=transformed_from_vae.device)

    # The three decodes stay in their original order (current_A, A_true,
    # Ainit) so that seeded runs remain comparable if the decoders draw
    # random initialisations.
    zs, _ = _decode_for_plots(current_A, G, noisy_transformed_from_vae, k, n, updated, device)
    with torch.no_grad():
        init_denoised_transformed_from_vae = _forward_for_plots(Ainit, G, zs, updated)

    _print_init_gap(A_true, Ainit, updated)

    zs_from_Atrue, _ = _decode_for_plots(A_true, G, noisy_transformed_from_vae, k, n, updated, device)
    with torch.no_grad():
        corr_ztoo_denoised_transformed_from_vae = _forward_for_plots(A_true, G, zs_from_Atrue, updated)
        denoised_transformed_from_vae = _forward_for_plots(current_A, G, zs, updated)
        corr_denoised_transformed_from_vae = _forward_for_plots(A_true, G, zs, updated)

    zs_from_Ainit, _ = _decode_for_plots(Ainit, G, noisy_transformed_from_vae, k, n, updated, device)
    with torch.no_grad():
        init_ztoo_denoised_transformed_from_vae = _forward_for_plots(Ainit, G, zs_from_Ainit, updated)

    if do_MNIST and do_plots:
        panels = [
            ('Pre-transformed', from_vae),
            ('True', transformed_from_vae),
            ('Noisy', noisy_transformed_from_vae),
            ('A=init', init_denoised_transformed_from_vae),
            ('A=learned', denoised_transformed_from_vae),
            ('A=correct', corr_denoised_transformed_from_vae),
            ('A=correct, new z', corr_ztoo_denoised_transformed_from_vae),
            ('A=init, new z', init_ztoo_denoised_transformed_from_vae),
        ]
        if deconv:
            # Signal-domain reconstructions, i.e. before any forward model.
            with torch.no_grad():
                panels.append(('G(z-A_true)', G(zs_from_Atrue)))
                panels.append(('G(z-A_init)', G(zs_from_Ainit)))
                panels.append(('G(z-learned A)', G(zs)))
        _plot_panel_row(panels, bigtitle)

    with torch.no_grad():
        learned_err = torch.norm(denoised_transformed_from_vae - transformed_from_vae) / torch.norm(transformed_from_vae)
        return learned_err


def _decode_for_plots(A, G, measurement, k, n, updated, device):
    """Recover a latent for ``measurement`` under ``A`` using the matching decoder."""
    if updated:
        return decode_updated(A, G, measurement, k, maxiter=200, lr=5e-2, device=device)
    return decode(A, G, measurement, k, n, maxiter=200, lr=5e-2, device=device)


def _forward_for_plots(A, G, z, updated):
    """Map latent ``z`` to measurement space through ``G`` and forward model ``A``."""
    if updated:
        return A(G(z))
    return DictNet(A, G)(z)


def _print_init_gap(A_true, Ainit, updated):
    """Print the normalised distance between the initial and true forward-model weights."""
    if not updated:
        true_w, init_w = A_true, Ainit
    else:
        # nn.Module stores submodules in `_modules`, not directly in
        # `__dict__`, so the linear model's `A` layer has to be looked up there.
        submodules = A_true.__dict__.get('_modules', {})
        if 'convlayer' in submodules:
            true_w = A_true.convlayer.weight.data
            init_w = Ainit.convlayer.weight.data
        elif 'A' in submodules or 'A' in A_true.__dict__:
            true_w = A_true.A.weight.data
            init_w = Ainit.A.weight.data
        else:
            # Unrecognised model layout: nothing to report.
            return
    print('Ainit vs A_true, normalized', torch.norm(true_w - init_w) / torch.norm(true_w))


def _plot_panel_row(panels, bigtitle=None):
    """Show ``(title, tensor)`` pairs side by side, each viewed as a 28x28 image."""
    plt.figure(figsize=(18, 3))
    numtotal = len(panels)
    for position, (title, image) in enumerate(panels, start=1):
        plt.subplot(1, numtotal, position)
        plt.imshow(image.view(28, 28).detach().cpu())
        plt.title(title)
    if bigtitle is not None:
        plt.suptitle(bigtitle)
    plt.show()
