import torchvision.transforms as transforms
import numpy as np
import torch
from prior_MNIST import VAE
from utils import pz_maker
import matplotlib.pyplot as plt
import torchvision
import torch
from torch.autograd import Variable
import torch.optim as optim
from PIL import Image, ImageFilter
import models
from models import A_matrix, GenericStackedNet
import utils
from utils import decode_updated, update_generic_A, set_up_A_model
import copy
import time
import gc
import scipy
import sklearn.decomposition
from sklearn.datasets import make_sparse_coded_signal
from sklearn.decomposition import MiniBatchDictionaryLearning


# Run on the GPU whenever CUDA is usable; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def learn_A_autodiff(test_sample_func, G, Ainit_fixed_forautodiff, noise_level, epochs=5, track_intermediate=False,
                     lr=1e-3, printevery=1e9, printevery_epoch=10, perzepochs=1000, sample_cap=1e10, reuse_training_samples=False,
                     P=10, zsqnorm_fac=0, k=5, projectA=False, doplot=False, saveplot=None):
    """Jointly fit latent codes and the matrix A by gradient descent (Adam).

    Every epoch builds a fresh linear A model initialised from the previous
    epoch's weights, stacks it with ``G`` in a ``GenericStackedNet``, draws new
    random latent codes of shape ``(P, k)`` and runs ``perzepochs`` Adam steps
    on the normalised squared measurement error (plus an optional penalty on
    the squared norm of the latent codes).

    Args:
        test_sample_func: Zero-argument callable returning a batch of
            measurements; the result is reshaped to ``(P, -1)`` and must be
            compatible with noise of shape ``(P, m)``.
        G: Second stage handed to ``GenericStackedNet`` (presumably the
            generative prior -- confirm against ``models``).
        Ainit_fixed_forautodiff: Initial A; ``m`` and ``n`` are read from its
            first and second dimensions. It is deep-copied, never modified.
        noise_level: Scale of the ``torch.rand`` (uniform on [0, 1)) noise
            added to each freshly drawn batch.
        epochs: Number of outer epochs.
        track_intermediate: If true, also return A after every epoch and the
            running sample counts.
        lr: Adam learning rate.
        printevery: Print progress when the inner step index is a multiple of
            this value ...
        printevery_epoch: ... and the epoch index is a multiple of this one.
        perzepochs: Number of optimizer steps per epoch.
        sample_cap: Stop drawing new batches once this many samples have been
            used; the most recent noisy batch is then reused unchanged.
        reuse_training_samples: If true, only the first epoch draws a batch.
        P: Batch size.
        zsqnorm_fac: Weight of the squared-norm penalty on the latent codes.
        k: Latent dimension.
        projectA: If true, apply ``utils.project`` to the A weights after
            every optimizer step.
        doplot: Show the per-epoch loss curve.
        saveplot: Filename prefix for saving the per-epoch loss curve, or
            ``None`` to skip saving.

    Returns:
        ``(stacked_net, ys_auto, outputs, losses_in_y, total_samples_used)``,
        followed by ``intermediate_As, sample_counts`` when
        ``track_intermediate`` is true. ``stacked_net``, ``ys_auto`` and
        ``outputs`` are ``None`` if no epoch (or no inner step) was run.

    Raises:
        ValueError: If ``sample_cap`` prevents even the first batch from
            being drawn.
    """
    losses_in_y = []
    autodiffstart = time.time()
    total_samples_used = 0
    m = Ainit_fixed_forautodiff.shape[0]
    n = Ainit_fixed_forautodiff.shape[1]
    Ainit_forautodiff = copy.deepcopy(Ainit_fixed_forautodiff)
    intermediate_As = []
    sample_counts = []
    # Pre-bind everything the return statement needs so that epochs == 0 or
    # perzepochs == 0 does not end in an unbound-name error.
    stacked_net = None
    ys_auto = None
    outputs = None

    for i in range(epochs):
        if (i == 0 or not reuse_training_samples) and total_samples_used < sample_cap:
            ys_auto = test_sample_func().reshape(P, -1).to(device)
            # Out-of-place add: reshape()/to() may return views of the
            # sampler's own tensor, which must not be modified. Noise is only
            # added to freshly drawn batches, so a batch that is reused after
            # the cap is reached does not accumulate noise epoch after epoch.
            ys_auto = ys_auto + noise_level * torch.rand((P, m), device=device)
            total_samples_used += P
        if ys_auto is None:
            raise ValueError('sample_cap must be positive so that at least one batch can be drawn')

        A_model_forautodiff = set_up_A_model({'n': n, 'm': m, 'device': device, 'type': 'linear', 'init': Ainit_forautodiff})
        stacked_net = GenericStackedNet(A_model_forautodiff, G)
        stacked_net.A.load_state_dict(A_model_forautodiff.state_dict())
        # Gradient is probably 0 if you initialize to zero
        zs_autodiff = Variable(torch.normal(0, 1, (P, k)).to(device), requires_grad=True)
        optimizer = optim.Adam([zs_autodiff] + list(stacked_net.parameters()), lr=lr)
        # NOTE: global debugging switch; it stays enabled after this function
        # returns and slows down backward passes.
        torch.autograd.set_detect_anomaly(True)

        losses_this_epoch = []
        for j in range(perzepochs):
            optimizer.zero_grad()
            outputs = stacked_net(zs_autodiff)  # P by m by 1

            loss_in_y = (torch.norm(outputs - ys_auto)**2 / torch.norm(ys_auto)**2) + zsqnorm_fac*torch.norm(zs_autodiff)**2
            loss_value = loss_in_y.detach().cpu()
            losses_this_epoch.append(loss_value)
            losses_in_y.append(loss_value)
            loss_in_y.backward(retain_graph=True)
            optimizer.step()

            if projectA:
                with torch.no_grad():
                    # Write the projection into the existing parameter.
                    # Rebinding the attribute to a new Parameter would leave
                    # the optimizer updating the old, now unused, tensor.
                    weight = stacked_net.A.A.weight
                    weight.copy_(utils.project(weight, Ainit_fixed_forautodiff, torch.norm(weight) / (2*n)))
            if j % printevery == 0 and i % printevery_epoch == 0:
                print('Epoch %d Iter %d Loss (normalized) y %f' % (i, j, loss_in_y))
                print('    z epoch', j, 'zsautodiffnorm', torch.norm(zs_autodiff), type(zs_autodiff.grad))

        if doplot or saveplot is not None:
            # First five and last five losses of this epoch.
            print('losses_this_epoch', losses_this_epoch[0:5], losses_this_epoch[-5:])
            ff = plt.figure(figsize=(12, 6))
            plt.xlabel('Epoch')
            plt.plot(np.log10(losses_this_epoch))
            plt.ylabel('log10(Loss) per epoch')
            if saveplot is not None:
                plt.savefig(saveplot + 'Autodiff_Epoch' + str(i) + '.png')
            if doplot:
                plt.show()
            plt.close(ff)

        # do not delete: used to initialise the next epoch's A model
        Ainit_forautodiff = copy.deepcopy(stacked_net.A.A.weight.detach())
        if track_intermediate:
            intermediate_As.append(Ainit_forautodiff)
            sample_counts.append(total_samples_used)
        del A_model_forautodiff
        gc.collect()
    autodiffend = time.time()

    print('Time elapsed for autodiff (min):', (autodiffend-autodiffstart)/60)
    print('Total samples used:', total_samples_used)
    if track_intermediate:
        return stacked_net, ys_auto, outputs, losses_in_y, total_samples_used, intermediate_As, sample_counts
    return stacked_net, ys_auto, outputs, losses_in_y, total_samples_used


def learn_A_altmin(test_sample_func, G, Ainit_fixed, noise_level, iterations=5, track_intermediate=False,
                   lr=1e-3, printevery=1e9, sample_cap=1e10, reuse_training_samples=False,
                   P=10, zsqnorm_fac=0, k=5, maxiterdecode=1000, maxiterA=1000, doplot=False, saveplot=None):
    """Fit the matrix A by alternating minimisation.

    Each meta-iteration (1) decodes latent codes for the current batch with
    ``decode_updated`` while A is held fixed, then (2) updates A with
    ``update_generic_A`` while the codes are held fixed, and finally records
    the normalised squared measurement error.

    Args:
        test_sample_func: Zero-argument callable returning a batch of
            measurements; the result is reshaped to ``(P, -1)`` and must be
            compatible with noise of shape ``(P, m)``.
        G: Model applied to the latent codes before A (presumably the
            generative prior -- confirm against the callers).
        Ainit_fixed: Initial A; ``m`` and ``n`` are read from its first and
            second dimensions. It is deep-copied, never modified.
        noise_level: Scale of the ``torch.rand`` (uniform on [0, 1)) noise
            added to each freshly drawn batch.
        iterations: Number of meta-iterations.
        track_intermediate: If true, also return A after every
            meta-iteration and the running sample counts.
        lr: Accepted for signature compatibility; currently unused (the A
            step uses a fixed rate of 1e-3 and the decode step 1e-2).
        printevery: Print the measurement error when the meta-iteration index
            is a multiple of this value.
        sample_cap: Accepted for signature compatibility; currently unused.
        reuse_training_samples: If true, only the first meta-iteration draws
            a batch.
        P: Batch size.
        zsqnorm_fac: Forwarded to ``decode_updated``.
        k: Latent dimension, forwarded to ``decode_updated``.
        maxiterdecode: Iteration budget of the decode step.
        maxiterA: Iteration budget of the A update.
        doplot: Show the loss curves of both sub-problems.
        saveplot: Filename prefix for saving the loss curves, or ``None`` to
            skip saving.

    Returns:
        ``(meas_errs, zs, meas_losses_in_z, meas_err, A_model)``, followed by
        ``intermediate_As, sample_counts`` when ``track_intermediate`` is
        true. ``zs``, ``meas_losses_in_z`` and ``meas_err`` are ``None`` if
        ``iterations`` is 0.
    """
    sample_counts = []
    intermediate_As = []
    m = Ainit_fixed.shape[0]
    n = Ainit_fixed.shape[1]
    A_model = set_up_A_model({'n': n, 'm': m, 'device': device, 'type': 'linear', 'init': copy.deepcopy(Ainit_fixed)})
    lr_A = 1e-3
    meas_errs = []
    total_samples_used = 0
    # Both diagnostic figures are produced whenever they are shown or saved.
    want_plots = doplot or saveplot is not None
    # Pre-bind the per-iteration results so that iterations == 0 still returns.
    zs = None
    meas_losses_in_z = None
    meas_err = None

    for it in range(iterations):
        if it == 0 or not reuse_training_samples:
            ys = test_sample_func().reshape(P, -1).to(device)
            # Out-of-place add: reshape()/to() may return views of the
            # sampler's own tensor, which must not be modified.
            ys = ys + noise_level * torch.rand((P, m), device=device)
            total_samples_used += P

        zs, meas_losses_in_z = decode_updated(A_model, G, ys, k, maxiter=maxiterdecode, lr=1e-2, device=device, zsqnorm_fac=zsqnorm_fac, init='rand')
        if want_plots:
            fg = plt.figure()
            plt.plot(np.log10(np.array(meas_losses_in_z)))
            plt.ylabel('log loss')
            plt.title('Meas losses in z ' + str(it))
            if saveplot is not None:
                plt.savefig(saveplot + 'Altmin_z_Epoch' + str(it) + '.png')
            if doplot:
                plt.show()
            plt.close(fg)

        # In practice, Adam works FAR better than SGD?!
        # this is really supposed to be just 1 step in A...departure from theory
        A_model, losses_arr = update_generic_A(A_model, zs, ys, G, maxiter=maxiterA, lr=lr_A)

        # Could add projection step here

        if want_plots:
            fg = plt.figure()
            plt.plot(np.log10(np.array(losses_arr)))
            plt.ylabel('log loss')
            plt.title('Losses in A ' + str(it))
            if saveplot is not None:
                plt.savefig(saveplot + 'Altmin_A_Epoch' + str(it) + '.png')
            if doplot:
                plt.show()
            plt.close(fg)

        if track_intermediate:
            intermediate_As.append(copy.deepcopy(A_model.A.weight.detach()))
            sample_counts.append(total_samples_used)

        meas_err = float(torch.norm(ys - A_model(G(zs)))**2 / torch.norm(ys)**2)
        meas_errs.append(meas_err)

        if it % printevery == 0:
            print('Meta iteration', it, 'Errors | Meas (n; noise): {:0.4f}'.format(meas_err))
    print('Total samples used:', total_samples_used)
    if track_intermediate:
        return meas_errs, zs, meas_losses_in_z, meas_err, A_model, intermediate_As, sample_counts
    return meas_errs, zs, meas_losses_in_z, meas_err, A_model
