import torchvision.transforms as transforms
import numpy as np
from numpy import matlib
import torch
from prior_MNIST import VAE
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import torchvision
import torch
from torch.autograd import Variable
import torch.optim as optim
from PIL import Image, ImageFilter
import models
from models import A_matrix, GenericStackedNet, make_G_redundant, make_G_redundant_list
import utils
from utils import decode_updated, update_generic_A, set_up_A_model, pz_maker, estimateSNR, estimate_lipschitz, estimate_S_REC, estimateCinv
from utils import find_best_z, loss_in_ys, recon_err_DL, recon_err_A, compose2, recon_err_DL_svd
from utils import make_shape, make_A_shapes, visualize_compare, visualize_filters, transform, make_new_MNIST_datasets
import copy
import time
import gc
#import scipy
import sklearn.decomposition
from sklearn.datasets import make_sparse_coded_signal
from sklearn.decomposition import MiniBatchDictionaryLearning
from opts import learn_A_autodiff, learn_A_altmin
from ksvd import ApproximateKSVD
import argparse
import pickle
import os
import numpy.linalg as LA
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Command-line options for the dictionary-learning experiments. Boolean flags select
# which learning methods run (--ksvd, --MOD, --autodiff, --altmin); --A / --G choose
# the ground-truth dictionary and the generative model used to synthesize data.
# Parsed once, at import time.
parser = argparse.ArgumentParser(description='Dictionary learning experiments')
parser.add_argument('--ksvd', dest='ksvd', action='store_true', help='Run k-SVD dictionary learning baseline', default=False)
parser.add_argument('--topk', dest='topk', action='store_true', help='Use only top k components for DL recon', default=False)
parser.add_argument('--dlsparsity', dest='dlsparsity', action='store_true', help='For any DL methods, plot how sparse they are', default=False)
parser.add_argument('--MOD', dest='MOD', action='store_true', help='Run MOD dictionary learning baseline', default=False)
parser.add_argument('--autodiff', dest='autodiff', action='store_true', help='Run autodiff with G', default=False)
parser.add_argument('--projectA', dest='projectA', action='store_true', help='Project A close to initialization, only possible for autodiff at this point', default=False)
parser.add_argument('--altmin', dest='altmin', action='store_true', help='Run alternating minimization with G', default=False)
parser.add_argument('--name', default='0', type=str, help='name of run')
parser.add_argument('--saveplots', dest='saveplots', action='store_true', help='save plots', default=False)
parser.add_argument('--epochs', default=5, type=int, help='total epochs to run')
parser.add_argument('--batchfortesting', default=3, type=int, help='how many batches to draw for testing')
parser.add_argument('--A', default='rand', type=str,
                    help='type of A: rand (also the fallback for unrecognized values), ortho, shapes, small_shapes, MNIST_vae, MNIST_real, CelebA_real')
parser.add_argument('--sizesmallshapes', default=8, type=int,
                    help='side length (in pixels) of the square images used when --A small_shapes')
parser.add_argument('--G', default='dependencies', type=str,
                    help='type of generative model G: dependencies, few_dependencies, polynomials, group_sparsity, MNIST_vae, MNIST_real')
parser.add_argument('--MNIST_flip', dest='MNIST_flip', action='store_true', help='Flip MNIST', default=False)
parser.add_argument('--MNIST_blur', dest='MNIST_blur', action='store_true', help='Blur MNIST', default=False)
parser.add_argument('--MNIST_color', dest='MNIST_color', action='store_true', help='Color jitter MNIST', default=False)
parser.add_argument('--P', default=10, type=int, help='number of samples per batch / epoch')
parser.add_argument('--k', default=3, type=int, help='dimensionality of latent variable. overwritten when not a choice based on A or G')
parser.add_argument('--n', default=10, type=int, help='dimensionality of coefficient space / number of dictionary vectors. ignored for certain G options')
parser.add_argument('--m', default=9, type=int, help='number of measurements. must be m>n for ksvd to not complain')
parser.add_argument('--noise', default=0.1, type=float, help='sampling noise')
parser.add_argument('--seed', default=0, type=int, help='random seed')
parser.add_argument('--numseeds', default=1, type=int, help='number of random seeds to try')
parser.add_argument('--init', default=-1, type=float, help='initialization noise; totally random if <0')
parser.add_argument('--numinits', default=1, type=int, help='number of random initializations to try')

args = parser.parse_args()

# Main experiment driver. For every (seed, initialization) pair:
#   1. build a generative model G and a ground-truth dictionary A_true,
#   2. draw noisy samples from GenericStackedNet(A_true, G) applied to latent draws pz(P)
#      (or real MNIST batches when --G MNIST_real),
#   3. run the selected dictionary-learning methods (k-SVD, MOD, autodiff, alt-min),
#   4. save loss curves, visualizations and a pickle of results under
#      results/<name>/Seed_<seed_ind>/Init_<init_ind>/.

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def normalize(x, fx):
    """Rescale ``fx`` so that its torch.norm (over all entries) equals that of ``x``."""
    return fx * torch.norm(x) / torch.norm(fx)


def _save_sparsity_plot(topKrange, errs, xlabel, path):
    """Plot log10 reconstruction error against the number of kept components and save to ``path``."""
    fig = plt.figure()
    plt.plot(list(topKrange), np.log10(errs))
    plt.xlabel(xlabel)
    plt.ylabel('log10 Recon. Error')
    plt.savefig(path)
    plt.close(fig)


def _save_col_correspondences(A_got, A_ref, path):
    """Save a heatmap of |A_got^T A_ref| after normalizing both with utils.normalize_cols."""
    got_normed = utils.normalize_cols(A_got)
    ref_normed = utils.normalize_cols(A_ref)
    innerprods = torch.abs(torch.matmul(torch.transpose(got_normed, 0, 1), ref_normed))
    fig = plt.figure(figsize=(9, 9))
    fig.add_subplot(1, 1, 1)
    plt.imshow(innerprods.detach().cpu())
    plt.colorbar()
    fig.savefig(path)
    # Close explicitly: pyplot keeps every open figure alive, and this runs for
    # every seed/init, so unclosed figures accumulate.
    plt.close(fig)


for seed_ind in range(args.numseeds):
    for init_ind in range(args.numinits):
        this_seed = args.seed + seed_ind
        np.random.seed(this_seed)
        torch.manual_seed(this_seed)

        savedir = f'results/{args.name}/Seed_{seed_ind}/Init_{init_ind}/'
        os.makedirs(name=savedir, exist_ok=True)
        filename = savedir + 'DataDump.pkl'
        saveplotname = savedir if args.saveplots else None

        savedict = {'args': args, 'this_seed': this_seed, 'init_ind': init_ind}

        k = args.k
        m = args.m
        P = args.P

        # ---------------- Generative model G ----------------
        # "Redundant" models built from a list of functions; they use n = (len(funclist) + 1) * k coefficients.
        redundant_funclists = {
            'dependencies': [
                lambda x: normalize(x, x**2),
                lambda x: normalize(x, torch.sin(x)),
                lambda x: normalize(x, x),
                lambda x: normalize(x, torch.exp(x)),
                lambda x: normalize(x, x**3),
            ],
            'few_dependencies': [
                lambda x: normalize(x, x**2),
                lambda x: normalize(x, torch.sin(x)),
            ],
            'polynomials': [
                lambda x: normalize(x, x**2),
                lambda x: normalize(x, 3 * (x**3) - 2 * (x**2)),
                lambda x: normalize(x, 10 * (x**2) - x),
            ],
        }

        true_A_exists = True
        if args.G in redundant_funclists:
            funclist = redundant_funclists[args.G]
            n = (len(funclist) + 1) * k
            G = make_G_redundant_list(funclist)
        elif args.G == 'group_sparsity':
            n = args.n
            adjmat = (torch.rand(n, k) > 0.4).float().to(device)
            G = models.make_G_bipartite(adjmat, nonlinear_func='square_and_sum_and_relu')
        elif args.G in ('MNIST_vae', 'MNIST_real'):
            vae = VAE().to(device)
            # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
            vae.load_state_dict(torch.load('model.pt', map_location=device))
            k = 20
            m = 784
            numrow = 28
            numcol = 28
            n = 400
            G = vae.decode_without_last
            if args.G == 'MNIST_real':
                # Real images: no ground-truth dictionary is available.
                true_A_exists = False
                MNIST_DIR = "/mnt/home/hlawrence/ceph/datasets"
                batch_size = P
                ds, dl, dl_iterator, ds_test, dl_test, dl_test_iterator = make_new_MNIST_datasets(
                    lambda im: transform(im, flip=args.MNIST_flip, blur=args.MNIST_blur, colorjitter=args.MNIST_color),
                    batch_size, MNIST_DIR=MNIST_DIR)
                print('just defined dl_iterator', type(dl_iterator))
        else:
            raise ValueError(f'Unknown --G option: {args.G!r}')

        print('n', n, 'm', m, 'P', P)
        pz = pz_maker(k, device=device)

        # ---------------- Ground-truth dictionary A_true (m x n) ----------------
        if args.A in ('shapes', 'small_shapes'):
            side = 28 if args.A == 'shapes' else args.sizesmallshapes
            numrow = side
            numcol = side
            m = side * side
            if args.G in ('MNIST_vae', 'MNIST_real'):
                # These G options define no funclist and fix n = 400, so no shapes dictionary fits them.
                raise ValueError(f'--A {args.A} is not compatible with --G {args.G}')
            if args.G == 'group_sparsity':
                if n <= 6:
                    A_true = make_A_shapes(n, 1, numrows=side, numcols=side)
                else:
                    print('n', n, 'np.ceil(n/6)', int(np.ceil(n / 6)))
                    A_true = make_A_shapes(6, int(np.ceil(n / 6)), numrows=side, numcols=side)[0:m, 0:n]
            else:
                A_true = make_A_shapes(len(funclist) + 1, k, numrows=side, numcols=side)
            A_true = utils.normalize_cols(A_true.float().to(device))
        elif args.A == 'ortho':
            mat = np.random.rand(max(m, n), max(m, n))
            orthomat, _, _ = LA.svd(mat)
            A_true = utils.normalize_cols(torch.from_numpy(orthomat[0:m, 0:n])).float().to(device)  # m by n
        elif args.A == 'MNIST_vae':
            A_true = vae.fc4.weight.detach().float().to(device)
        elif args.A in ('MNIST_real', 'CelebA_real'):
            # All-zero placeholder: real data has no ground-truth dictionary.
            A_true = torch.zeros(m, n).float().to(device)
        else:
            A_true = utils.normalize_cols(torch.rand(m, n)).float().to(device)

        A_true_model = set_up_A_model({'n': n, 'm': m, 'device': device, 'type': 'linear', 'init': A_true})
        true_stacked_net = GenericStackedNet(A_true_model, G) if true_A_exists else None

        # ---------------- Initialization A_init ----------------
        # Every init_ind reseeds with the same this_seed, so without a separate stream all
        # initializations would be identical. The global draw is always made, so init 0 reproduces
        # the original behaviour and the global RNG stream after this point is the same for every
        # init. Inits > 0 replace that draw with one from a dedicated generator.
        init_noise = torch.rand(m, n)
        if init_ind > 0:
            init_gen = torch.Generator().manual_seed(this_seed * 10007 + init_ind)
            init_noise = torch.rand(m, n, generator=init_gen)
        if args.init < 0:
            A_init = utils.normalize_cols(init_noise).float().to(device)
        else:
            A_init = utils.normalize_cols(A_true + args.init * init_noise.float().to(device))

        A_init_model = set_up_A_model({'n': n, 'm': m, 'device': device, 'type': 'linear', 'init': A_init})
        if true_A_exists:
            print('Error in A_init', torch.norm(A_true - A_init) / torch.norm(A_true))

        # ---------------- Sample sources ----------------
        if args.G in ('MNIST_real', 'CelebA_real'):
            def make_test_sample_functions(dl_iterator=dl_iterator, dl_test_iterator=dl_test_iterator):
                """Return (train, test) functions that pull the next flattened image batch onto device."""
                def test_sample_func_train(dl_iterator=dl_iterator):
                    with torch.no_grad():
                        temp, _ = next(dl_iterator)
                        return temp.view(temp.shape[0], -1).float().to(device)

                def test_sample_func_test(dl_test_iterator=dl_test_iterator):
                    with torch.no_grad():
                        temp, _ = next(dl_test_iterator)
                        return temp.view(temp.shape[0], -1).float().to(device)

                return test_sample_func_train, test_sample_func_test

        if true_A_exists:
            def test_sample_func_train():
                # Synthetic data: "train" and "test" draws come from the same model.
                with torch.no_grad():
                    return true_stacked_net(pz(P))
            test_sample_func_test = test_sample_func_train
        elif args.G in ('MNIST_real', 'CelebA_real'):
            test_sample_func_train, test_sample_func_test = make_test_sample_functions()

        savedict['A_true'] = A_true
        savedict['A_init'] = A_init

        sampling_noise = args.noise

        def add_noise(invec):
            """Return ``invec`` plus i.i.d. standard normal noise scaled by ``sampling_noise``."""
            with torch.no_grad():
                # randn_like matches the actual batch shape instead of assuming (P, m).
                return invec + sampling_noise * torch.randn_like(invec)

        sampler = compose2(add_noise, test_sample_func_train)

        def reset_real_data():
            """Restart the MNIST train/test iterators and rebuild the noisy sampler on top of them.

            The sampler must be rebuilt too: it wraps a specific train function, so
            resetting only the sample functions would leave k-SVD/MOD/autodiff
            reading from the old, partially consumed iterator.
            """
            train_fn, test_fn = make_test_sample_functions(dl_iterator=iter(dl), dl_test_iterator=iter(dl_test))
            return train_fn, test_fn, compose2(add_noise, train_fn)

        # ---------------- Problem-constant estimates ----------------
        numsamples_for_tests = 1000
        do_estimates = True

        if do_estimates:
            est_L = estimate_lipschitz(G, pz, numsamples=numsamples_for_tests)
            delta = 0.1
            est_gamma = estimate_S_REC(A_true, G, pz, numsamples=numsamples_for_tests, delta=delta)
            est_Cinv_norm = estimateCinv(G, pz, numsamples=numsamples_for_tests)
            SNR = estimateSNR(add_noise, test_sample_func_train)

            # estimateSNR draws from test_sample_func_train; restart the real-data iterators.
            if args.G == 'MNIST_real':
                test_sample_func_train, test_sample_func_test, sampler = reset_real_data()

            print('SNR', SNR, 'L', est_L, 'gamma', est_gamma, 'delta', delta, 'Cinvnorm', est_Cinv_norm)
            with open(savedir + 'Params', 'w') as f:
                f.write('SNR:' + str(SNR) + '\n')
                f.write('Est Lipschitz constant: ' + str(est_L) + '\n')
                f.write('(if MNIST_real, ignore gamma/delta in next line) G is ' + args.G + '\n')
                f.write('Est gamma with delta ' + str(delta) + ':' + str(est_gamma) + '\n')
                f.write('Est C*^{-1} norm ' + str(est_Cinv_norm))
            savedict['SNR'] = SNR
            savedict['est_L'] = est_L
            savedict['delta'] = delta
            savedict['est_gamma'] = est_gamma
            savedict['est_Cinv_norm'] = est_Cinv_norm

        epochs = args.epochs
        num_batch = args.batchfortesting
        topK = k if args.topk else None

        # ---------------- k-SVD baseline ----------------
        if args.ksvd:
            Y = np.array(sampler().detach().cpu())
            SVD_components = []
            SVD_transforms = []
            svd_sample_counts = []
            for i in range(epochs):
                print('KSVD Epoch %d' % i)
                svd_sample_counts.append(Y.shape[0])
                # Refit from scratch on all samples seen so far.
                aksvd = ApproximateKSVD(n_components=n)
                SVD_components.append(aksvd.fit(Y).components_)
                SVD_transforms.append(aksvd.transform)
                newsamples = np.array(sampler().detach().cpu())
                Y = np.concatenate((Y, newsamples), axis=0)

            all_recon_err_svd, all_std_svd, all_std_log_svd = [], [], []
            # NB: the loop variable must not be named `transform` (that would shadow utils.transform).
            for svd_components, svd_transform in zip(SVD_components, SVD_transforms):
                epoch_recon_err_svd, std_svd, std_log_svd = recon_err_DL_svd(
                    DLcomponents=svd_components, DLtransform=svd_transform, test_sample_func=test_sample_func_test,
                    num_batch=num_batch, add_noise_func=add_noise, topK=topK, returnstd=True)
                all_recon_err_svd.append(epoch_recon_err_svd)
                all_std_svd.append(std_svd)
                all_std_log_svd.append(std_log_svd)
            print('SVD recon err, noisy', all_recon_err_svd)

            if args.G == 'MNIST_real':
                test_sample_func_train, test_sample_func_test, sampler = reset_real_data()
            if args.dlsparsity:
                topKrange = range(1, n + 1)
                sparsity_recon_errs = torch.zeros(len(topKrange))
                for ind, topKuse in enumerate(topKrange):
                    # Noisy inputs, consistent with the other DL evaluation calls.
                    sparsity_recon_errs[ind] = recon_err_DL_svd(
                        DLcomponents=SVD_components[-1], DLtransform=SVD_transforms[-1],
                        test_sample_func=test_sample_func_test, num_batch=num_batch * 3,
                        add_noise_func=add_noise, topK=topKuse)
                _save_sparsity_plot(topKrange, sparsity_recon_errs, 'Sparsity, k-SVD', savedir + 'kSVD_Sparsity.png')

            DLcomponents = SVD_components[-1]
            DLtransform = SVD_transforms[-1]

            savedict['all_recon_err_svd'] = all_recon_err_svd
            savedict['all_std_svd'] = all_std_svd
            savedict['all_std_log_svd'] = all_std_log_svd
            savedict['svd_sample_counts'] = svd_sample_counts
            if args.dlsparsity:
                savedict['svd_sparsity'] = sparsity_recon_errs
        else:
            DLcomponents = None
            DLtransform = None

        if args.G == 'MNIST_real':
            test_sample_func_train, test_sample_func_test, sampler = reset_real_data()

        # ---------------- MOD (sklearn mini-batch dictionary learning) baseline ----------------
        if args.MOD:
            dict_learner = MiniBatchDictionaryLearning(n_components=n, transform_algorithm='lasso_lars', random_state=42)

            DL_sample_count = 0
            DL_sample_counts = []
            all_learners = []
            for i in range(epochs):
                print('Epoch %d' % i)
                ys = sampler().detach().cpu()
                DL_sample_count += ys.shape[0]  # should be batch_size
                DL_sample_counts.append(DL_sample_count)
                dict_learner.partial_fit(ys.view(ys.shape[0], -1))
                # partial_fit mutates dict_learner in place; snapshot it so each entry is
                # the learner as of this epoch rather than the same final object.
                all_learners.append(copy.deepcopy(dict_learner))

            all_recon_err_DL, all_std_DL, all_std_log_DL = [], [], []
            for learner in all_learners:
                epoch_recon_err_DL, std_DL, std_log_DL = recon_err_DL(
                    learner, test_sample_func_test, num_batch, add_noise_func=add_noise, topK=topK, returnstd=True)
                all_recon_err_DL.append(epoch_recon_err_DL)
                all_std_DL.append(std_DL)
                all_std_log_DL.append(std_log_DL)

            if args.G == 'MNIST_real':
                test_sample_func_train, test_sample_func_test, sampler = reset_real_data()
            if args.dlsparsity:
                topKrange = range(1, n + 1)
                sparsity_recon_errs = torch.zeros(len(topKrange))
                for ind, topKuse in enumerate(topKrange):
                    sparsity_recon_errs[ind] = recon_err_DL(
                        all_learners[-1], test_sample_func_test, num_batch * 3, add_noise_func=add_noise, topK=topKuse)
                _save_sparsity_plot(topKrange, sparsity_recon_errs, 'Sparsity, MOD', savedir + 'MOD_Sparsity.png')

            dict_learner_MOD = all_learners[-1]

            savedict['all_recon_err_DL'] = all_recon_err_DL
            savedict['all_std_DL'] = all_std_DL
            savedict['all_std_log_DL'] = all_std_log_DL
            savedict['DL_sample_counts'] = DL_sample_counts
            if args.dlsparsity:
                savedict['DL_sparsity'] = sparsity_recon_errs
        else:
            dict_learner_MOD = None

        if args.G == 'MNIST_real':
            test_sample_func_train, test_sample_func_test, sampler = reset_real_data()

        # ---------------- Autodiff learning of A through G ----------------
        if args.autodiff:
            (stacked_net_from_autodiff, ys_auto, outputs, losses_in_y, total_samples_used,
             intermediate_As, sample_counts_A) = learn_A_autodiff(
                sampler, G, A_init, noise_level=0, track_intermediate=True,
                printevery=1e9, printevery_epoch=1, epochs=epochs, perzepochs=1000, lr=1e-2, P=P, k=k,
                projectA=args.projectA, doplot=False, saveplot=saveplotname)
            losses_A, losses_dict, all_std_A, all_std_log_A = [], [], [], []
            print('Computing intermediate losses')
            for intermediate_A in intermediate_As:
                lss_A, dict_loss, std_A, std_log_A = recon_err_A(
                    intermediate_A, test_sample_func_test, num_batch, G, true_A=A_true, k=k,
                    true_stacked_net=true_stacked_net, P=P, pz=pz, add_noise_func=add_noise, iters=1000, lr=1e-2,
                    doplot=False, saveplot=saveplotname, extraname='Autodiff', returnstd=True)
                print('lss_A', lss_A)
                losses_A.append(lss_A)
                losses_dict.append(dict_loss)
                all_std_A.append(std_A)
                all_std_log_A.append(std_log_A)

            savedict['losses_A'] = losses_A
            savedict['losses_dict'] = losses_dict
            savedict['all_std_A'] = all_std_A
            savedict['all_std_log_A'] = all_std_log_A
            savedict['intermediate_As'] = intermediate_As
            savedict['sample_counts_A'] = sample_counts_A
            savedict['losses_in_y'] = losses_in_y

        if args.G == 'MNIST_real':
            test_sample_func_train, test_sample_func_test, sampler = reset_real_data()

        # ---------------- Alternating minimization ----------------
        if args.altmin:
            (meas_errs, zs, meas_losses_in_z_am, meas_err, A_model,
             intermediate_As_am, sample_counts_am) = learn_A_altmin(
                test_sample_func_train, G, A_init, noise_level=0, iterations=epochs, track_intermediate=True,
                lr=1e-3, printevery=1e9, sample_cap=1e10, reuse_training_samples=False,
                P=P, zsqnorm_fac=0, k=k, saveplot=saveplotname)  # no projection is possible anyway
            losses_A_am, losses_dict_am, all_std_am, all_std_log_am = [], [], [], []
            print('Computing intermediate losses')
            for intermediate_A in intermediate_As_am:
                lss_A_am, dict_loss_am, std_am, std_log_am = recon_err_A(
                    intermediate_A, test_sample_func_test, num_batch, G, true_A=A_true,
                    true_stacked_net=true_stacked_net, P=P, pz=pz, k=k, add_noise_func=add_noise, iters=1000, lr=1e-2,
                    doplot=False, saveplot=saveplotname, extraname='Altmin', returnstd=True)
                losses_A_am.append(lss_A_am)
                losses_dict_am.append(dict_loss_am)
                all_std_am.append(std_am)
                all_std_log_am.append(std_log_am)

            A_altmin = intermediate_As_am[-1]

            savedict['losses_A_am'] = losses_A_am
            savedict['losses_dict_am'] = losses_dict_am
            savedict['all_std_am'] = all_std_am
            savedict['all_std_log_am'] = all_std_log_am
            savedict['intermediate_As_am'] = intermediate_As_am
            savedict['sample_counts_am'] = sample_counts_am
            savedict['meas_losses_in_z_am'] = meas_losses_in_z_am
        else:
            A_altmin = None

        # ---------------- Loss curves ----------------
        fig = plt.figure(figsize=(15, 6))
        ax = fig.add_subplot(1, 1, 1)
        fig2 = plt.figure(figsize=(15, 6))
        ax2 = fig2.add_subplot(1, 1, 1)
        if args.MOD:
            ax.errorbar(x=DL_sample_counts, y=np.log10(all_recon_err_DL), yerr=all_std_log_DL, marker='o', linestyle='-', label='MOD')
            ax2.errorbar(x=DL_sample_counts, y=all_recon_err_DL, yerr=all_std_DL, marker='o', linestyle='-', label='MOD')
        if args.ksvd:
            ax.errorbar(x=svd_sample_counts, y=np.log10(all_recon_err_svd), yerr=all_std_log_svd, marker='o', linestyle='-', label='k-SVD')
            ax2.errorbar(x=svd_sample_counts, y=all_recon_err_svd, yerr=all_std_svd, marker='o', linestyle='-', label='k-SVD')
        if args.autodiff:
            ax.errorbar(x=sample_counts_A[0:len(losses_A)], y=np.log10(losses_A), yerr=np.array(all_std_log_A), marker='o', linestyle='-', label='Autodiff')
            ax2.errorbar(x=sample_counts_A[0:len(losses_A)], y=np.array(losses_A), yerr=np.array(all_std_A), marker='o', linestyle='-', label='Autodiff')
        if args.altmin:
            ax.errorbar(x=sample_counts_am, y=np.log10(losses_A_am), yerr=all_std_log_am, marker='o', linestyle='-', label='Alt min')
            ax2.errorbar(x=sample_counts_am, y=np.array(losses_A_am), yerr=all_std_am, marker='o', linestyle='-', label='Alt min')
        ax.set_xlabel('Samples')
        ax.set_ylabel('Log Squared Loss')
        ax.legend()
        fig.savefig(savedir + 'LogLosses.pdf')

        ax2.set_xlabel('Samples')
        ax2.set_ylabel('Squared Loss')
        ax2.legend()
        fig2.savefig(savedir + 'Losses.pdf')

        plt.close(fig)
        plt.close(fig2)

        # NOTE(review): pickle is fine here (we write our own results); never unpickle untrusted files.
        with open(filename, 'wb') as f:
            pickle.dump(savedict, f)

        # ---------------- Image visualizations ----------------
        if (args.A in ('small_shapes', 'shapes', 'MNIST_real') or args.G in ('MNIST_vae', 'MNIST_real')) and args.autodiff:

            # isinstance replaces the old `type(vec) == type(torch.rand(3))` check, which
            # consumed global RNG on every call just to obtain the Tensor type.
            def reshape_visualize_func(vec):
                """Reshape a flat vector (tensor or array) into a numrow x numcol image."""
                if isinstance(vec, torch.Tensor):
                    vec = vec.detach().cpu()
                return vec.reshape(numrow, numcol)

            def log_reshape_func(vec):
                """Like reshape_visualize_func, but shows log10 of absolute values."""
                if isinstance(vec, torch.Tensor):
                    vec = torch.log10(torch.abs(vec.detach().cpu()))
                else:
                    vec = np.log10(np.abs(vec))
                return vec.reshape(numrow, numcol)

            vis_kwargs = dict(
                A=intermediate_As[-1], G=G, numex=3, DLcomponents=DLcomponents, DLtransform=DLtransform,
                topK=topK, dict_learner_MOD=dict_learner_MOD, A_altmin=A_altmin, iters=400, lr=1e-2,
                doplot=False, saveplot=savedir)
            visualize_compare(k, test_sample_func_test, add_noise, reshape_visualize_func=reshape_visualize_func,
                              extraname="", residual=False, **vis_kwargs)
            visualize_compare(k, test_sample_func_test, add_noise, reshape_visualize_func=reshape_visualize_func,
                              extraname="Residual", residual=True, **vis_kwargs)
            visualize_compare(k, test_sample_func_test, add_noise, reshape_visualize_func=log_reshape_func,
                              extraname="LogResidual", residual=True, vminmax=[-10, 1], **vis_kwargs)

            # Now, visualize filters
            visualize_filters(A_true=A_true, A=intermediate_As[-1], A_init=A_init, G=G,
                              reshape_visualize_func=reshape_visualize_func, DLcomponents=DLcomponents,
                              dict_learner_MOD=dict_learner_MOD, A_altmin=A_altmin, saveplot=savedir, extraname="")

        # ---------------- Column correspondence between learned and true A ----------------
        if args.autodiff:
            # See if the learned cols of A align (they likely will not)
            _save_col_correspondences(intermediate_As[-1], A_true, savedir + 'LearnedA_Autodiff_ColCorrespondences.pdf')
            _save_col_correspondences(A_true, A_true, savedir + 'LearnedA_sanitycheck_ColCorrespondences.pdf')

        if args.altmin:
            _save_col_correspondences(intermediate_As_am[-1], A_true, savedir + 'LearnedA_Altmin_ColCorrespondences.pdf')