import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image, make_grid
import cvxpy as cvx
import numpy as np
import matplotlib.pyplot as plt
from utils import *

# the iterative shrinkage thresholding algorithm for Lasso
def ista(A, y, lambda_reg, zeta, maxiter):
    """Approximate a Lasso solution with the iterative shrinkage-thresholding algorithm.

    Each iteration takes a gradient step of size ``zeta`` on
    ``(1 / (2m)) * ||A x - y||^2`` and then applies ``shrinkage`` with threshold
    ``zeta * lambda_reg``. ``shrinkage`` comes from ``utils``; presumably it is
    soft-thresholding, i.e. the proximal map of the l1 norm (TODO confirm).

    Args:
        A: measurement matrix of shape (m, n).
        y: measurements of shape (m, 1), on the same device as ``A``.
        lambda_reg: l1 regularisation weight.
        zeta: step size (the gradient is additionally scaled by 1/m).
        maxiter: number of ISTA iterations.

    Returns:
        The estimate ``x`` of shape (n, 1), on ``A``'s device and dtype.
    """
    m, n = A.shape
    # Start from zero on A's device/dtype so the matmuls below line up
    # (works for CPU and GPU tensors alike).
    x_est = torch.zeros((n, 1), device=A.device, dtype=A.dtype)
    step = zeta / m
    A_t = torch.transpose(A, 0, 1)
    threshold = zeta * lambda_reg
    for _ in range(maxiter):
        # Gradient step on the data-fidelity term, then l1 proximal step.
        x_est = x_est - step * torch.matmul(A_t, torch.matmul(A, x_est) - y)
        x_est = shrinkage(x_est, threshold)
    return x_est

# the algorithm used in the CSGM paper
def csgm(A, y, vae, num_restarts=10, num_inner_iter=1000, inner_lr=0.01, regularization=False, latent_dim=20):
    """Recover a signal in the range of a VAE decoder (the algorithm of the CSGM paper).

    For each of ``num_restarts`` random latent codes, minimise
    ``||A G(z) - y||^2`` (plus ``0.1 * ||z||^2`` when ``regularization`` is set)
    over ``z`` with Adam. Then keep the code whose *final* measurement error
    ``||A G(z) - y||^2`` is smallest. Restarts are always compared on the
    unregularised error, so the regulariser does not bias the selection. The
    initial random ``z_best`` is a fallback candidate in that comparison.

    Args:
        A: measurement matrix of shape (m, n).
        y: measurements of shape (m, 1), on the same device as ``A``.
        vae: object whose ``decoder`` maps a (1, latent_dim) code to an output
            with n elements (n = A.shape[1]).
        num_restarts: number of random initialisations (default 10, as in the paper).
        num_inner_iter: Adam steps per restart (default 1000, as in the paper).
        inner_lr: Adam learning rate.
        regularization: add ``0.1 * ||z||^2`` to the optimised loss.
        latent_dim: size of the decoder's latent code (20 for the MNIST VAE).

    Returns:
        ``vae.decoder(z_best)`` reshaped to (n, 1).
    """
    device = A.device
    n = A.shape[1]

    def measurement_error(z):
        # Squared measurement residual ||A G(z) - y||^2, without the regulariser.
        AG_z = torch.mm(A, vae.decoder(z).view(-1, 1).to(device))
        return torch.pow(torch.norm(AG_z - y), 2)

    # Sample on the CPU generator and move, so seeded runs draw the same codes
    # regardless of device.
    z_best = torch.randn(1, latent_dim).to(device)
    with torch.no_grad():
        best_err = measurement_error(z_best)

    for _ in range(num_restarts):
        z = torch.randn(1, latent_dim).to(device).requires_grad_(True)
        optimizer = optim.Adam([z], lr=inner_lr)  # only z is optimised
        for _ in range(num_inner_iter):
            loss = measurement_error(z)
            if regularization:
                loss = loss + 0.1 * torch.pow(torch.norm(z), 2)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Score the z reached *after* the final step. The last `loss` above
        # belongs to the previous iterate.
        with torch.no_grad():
            err = measurement_error(z)
            if err < best_err:
                best_err = err
                z_best = z.detach().clone()
    return vae.decoder(z_best).view(n, 1)

# the main function for MNIST
def solver(m, images, vae, lambda_reg=0.1, zeta=1., num_trials=10, maxiter=20, R=5, Delta=3, sigma=0.1, model='1bit', method='CSGM'):
    """Simulate nonlinear measurements of MNIST images and reconstruct them.

    For each trial a fresh measurement matrix ``A = genA(m, n)`` is drawn (and,
    for the noisy/dithered models, one noise vector shared by all images). Each
    image is then measured and reconstructed.

    Measurement models (``model``):
        '1bit'       : y = sign(A x / ||x||)
        'relu'       : y = max(A x / ||x|| + sigma * N(0, 1), 0)
        '1bitDither' : y = sign(A x + tau), tau ~ U[-lam, lam), lam = R * sqrt(log m)
        'unifQuant'  : y = unif_quant(A x + tau, Delta), tau ~ U[-Delta/2, Delta/2)
    ``T`` is the scaling returned by ``model_T`` from utils. For '1bitDither'
    it is overridden with ``1 / lam``.

    Reconstruction methods (``method``): 'CSGM' (``csgm`` with the VAE prior)
    or 'Lasso' (``ista``). For Lasso, ``lambda_reg`` is scaled by 10 for
    'unifQuant' and by 0.3 for '1bitDither'. The scaled value is computed once
    and reused for every image and trial.

    Returns:
        A numpy array of shape (len(images) * num_trials, 784). Row
        ``t * len(images) + j`` holds the estimate for image j in trial t.
        Returns None (after printing a message) if ``model`` or ``method`` is
        unknown.
    """
    n = 784
    T = model_T(model)
    num_images = len(images)
    res = np.zeros((num_images * num_trials, n))

    # Model-specific Lasso weight, fixed up front so it cannot compound
    # across images/trials.
    if model == 'unifQuant':
        lasso_lambda = lambda_reg * 10
    elif model == '1bitDither':
        lasso_lambda = lambda_reg * 0.3
    else:
        lasso_lambda = lambda_reg

    counter = 0
    for _ in range(num_trials):
        A = genA(m, n).cuda()
        # One noise/dither draw per trial, shared by all images in that trial.
        if model in ['relu', '1bitDither', 'unifQuant']:
            if model == 'relu':
                normal = torch.distributions.Normal(torch.tensor([0.0]), torch.tensor([1.0]))
                noise = sigma * normal.sample((m,)).cuda()  # shape (m, 1)
            elif model == '1bitDither':
                lambda_val = R * np.sqrt(np.log(m))
                noise = (2 * lambda_val * torch.rand((m, 1)).cuda() - lambda_val)  # the random dither
                T = 1.0 / lambda_val
            else:
                noise = (Delta * torch.rand((m, 1)) - 0.5 * Delta).cuda()  # the random dither
            noise.requires_grad_(False)

        for x_star in images:
            x_star = x_star.view(n, 1).cuda()

            if model == '1bit':
                y = torch.sign(torch.mm(A, x_star / torch.norm(x_star)))
            elif model == 'relu':
                y = torch.max(torch.mm(A, x_star / torch.norm(x_star)) + noise, torch.zeros(m, 1).cuda())
            elif model == '1bitDither':
                y = torch.sign(torch.mm(A, x_star).cuda() + noise)
            elif model == 'unifQuant':
                y = unif_quant(torch.mm(A, x_star) + noise, Delta)
            else:
                print('Model not found!')
                return

            y.requires_grad_(False)

            if model in ['1bitDither', 'unifQuant']:
                # Dithered models: the signal norm is not lost, so no 1/||x|| rescaling.
                if method == 'CSGM':
                    x_est = csgm(T * A, y, vae)
                elif method == 'Lasso':
                    x_est = ista(A, y, lasso_lambda, zeta, maxiter)
                else:
                    print('Method not found!')
                    return
            else:
                if method == 'CSGM':
                    x_est = csgm(T * A / torch.norm(x_star), y, vae)
                elif method == 'Lasso':
                    x_est = ista(A, y, lasso_lambda, zeta, maxiter)
                else:
                    print('Method not found!')
                    return

            res[counter] = x_est.view(1, n).detach().cpu().numpy()
            print(counter)
            counter += 1
    return res
    
    
# the algorithm used in the CSGM paper; slightly modified for CelebA
def csgm_celeba(A, y, dcgan, num_restarts=10, num_inner_iter=1000, inner_lr=0.01, regularization=False, latent_dim=100):
    """CSGM reconstruction with a DCGAN generator (CelebA variant of ``csgm``).

    For each of ``num_restarts`` random latent codes of shape
    (1, latent_dim, 1, 1), minimise ``||A G(z) - y||^2`` (plus
    ``0.1 * ||z||^2`` when ``regularization`` is set) over ``z`` with Adam.
    Then keep the code whose final, unregularised measurement error is
    smallest. The initial random ``z_best`` is a fallback candidate.

    Args:
        A: measurement matrix of shape (m, n).
        y: measurements of shape (m, 1), on the same device as ``A``.
        dcgan: callable generator mapping a (1, latent_dim, 1, 1) code to an
            output with n elements (n = A.shape[1]; 12288 for 64x64x3 CelebA).
        num_restarts: number of random initialisations.
        num_inner_iter: Adam steps per restart.
        inner_lr: Adam learning rate.
        regularization: add ``0.1 * ||z||^2`` to the optimised loss.
        latent_dim: number of latent channels (100 for the CelebA DCGAN).

    Returns:
        ``dcgan(z_best)`` reshaped to (n, 1).
    """
    device = A.device
    n = A.shape[1]

    def measurement_error(z):
        # Squared measurement residual ||A G(z) - y||^2, without the regulariser.
        AG_z = torch.mm(A, dcgan(z).view(-1, 1).to(device))
        return torch.pow(torch.norm(AG_z - y), 2)

    # Sample on the CPU generator and move, so seeded runs are device-independent.
    z_best = torch.randn(1, latent_dim, 1, 1).to(device)
    with torch.no_grad():
        best_err = measurement_error(z_best)

    for _ in range(num_restarts):
        z = torch.randn(1, latent_dim, 1, 1).to(device).requires_grad_(True)
        optimizer = optim.Adam([z], lr=inner_lr)  # only z is optimised
        for _ in range(num_inner_iter):
            loss = measurement_error(z)
            if regularization:
                loss = loss + 0.1 * torch.pow(torch.norm(z), 2)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Score the z reached after the final step, on the unregularised error.
        with torch.no_grad():
            err = measurement_error(z)
            if err < best_err:
                best_err = err
                z_best = z.detach().clone()
    return dcgan(z_best).view(n, 1)
    
# the main function for CelebA
def solver_celeba(m, images, dcgan, num_trials=10, maxiter=10, R=10, Delta=100, sigma=0.1, model='1bit'):
    """Simulate nonlinear measurements of CelebA images and reconstruct them with CSGM.

    For each trial a fresh measurement matrix ``A = genA(m, n)`` is drawn. Each
    image is measured under ``model`` and reconstructed with ``csgm_celeba``.
    Unlike the MNIST ``solver``, noise/dither is redrawn for every image.

    Measurement models (``model``):
        '1bit'       : y = sign(A x / ||x||)
        'relu'       : y = max(A x / ||x|| + sigma * N(0, 1), 0)
        '1bitDither' : y = sign(A x + tau), tau ~ U[-lam, lam), lam = R * sqrt(log m)
        'unifQuant'  : y = unif_quant(A x + tau, Delta), tau ~ U[-Delta/2, Delta/2)
    ``T`` comes from ``model_T`` in utils. For '1bitDither' it is overridden
    with ``1 / lam``.

    ``maxiter`` is accepted for interface compatibility but not used here.

    Returns:
        A numpy array of shape (len(images) * num_trials, 12288). Row
        ``t * len(images) + j`` holds the estimate for image j in trial t.
        Returns None (after printing a message) if ``model`` is unknown.
    """
    n = 12288
    T = model_T(model)
    num_images = len(images)
    res = np.zeros((num_images * num_trials, n))
    counter = 0
    for _ in range(num_trials):
        A = genA(m, n).cuda()
        for x_star in images:
            x_star = x_star.view(n, 1).cuda()

            if model == '1bit':
                y = torch.sign(torch.mm(A, x_star / torch.norm(x_star)))
            elif model == 'relu':
                normal = torch.distributions.Normal(torch.tensor([0.0]), torch.tensor([1.0]))
                noise = sigma * normal.sample((m,)).cuda()  # shape (m, 1)
                noise.requires_grad_(False)
                y = torch.max(torch.mm(A, x_star / torch.norm(x_star)) + noise, torch.zeros(m, 1).cuda())
            elif model == '1bitDither':
                lambda_val = R * np.sqrt(np.log(m))
                noise = (2 * lambda_val * torch.rand((m, 1)).cuda() - lambda_val)  # the random dither
                noise.requires_grad_(False)
                y = torch.sign(torch.mm(A, x_star).cuda() + noise)
                T = 1.0 / lambda_val
            elif model == 'unifQuant':
                noise = (Delta * torch.rand((m, 1)) - 0.5 * Delta).cuda()  # the random dither
                noise.requires_grad_(False)
                y = unif_quant(torch.mm(A, x_star) + noise, Delta)
            else:
                print('Model not found!')
                return

            y.requires_grad_(False)

            if model in ['1bitDither', 'unifQuant']:
                # Dithered models keep the signal norm, so no 1/||x|| rescaling.
                x_est = csgm_celeba(T * A, y, dcgan)
            else:
                x_est = csgm_celeba(T * A / torch.norm(x_star), y, dcgan)
            res[counter] = x_est.view(1, n).detach().cpu().numpy()
            print(counter)
            counter += 1
    return res
