import os

from visualize import visualize, vis_image, to_img, psnr
import argparse
from dual_1by1_conv import main as dual_main
import torch
from torch import nn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms
from torchvision.datasets import MNIST, CIFAR10
from torchvision.utils import save_image
from tensorboardX import SummaryWriter
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import numpy as np
import matplotlib.pyplot as plt
import time


# Module-level switches read by the networks and the train/test loops below.
circular_pad = False  # when True, images are circularly padded before the forward pass

act_type = "relu"  # key used to pick a module from the `activations` table

# One pixel of padding per side when circular padding is on, none otherwise.
pad_size = 1 if circular_pad else 0

def add_noise(img, noise_std):
    """Return ``img`` corrupted by additive zero-mean Gaussian noise.

    Args:
        img: Tensor to corrupt; it is left unmodified.
        noise_std: Standard deviation of the noise.

    Returns:
        A new tensor ``img + noise`` with the same shape as ``img``.
    """
    noise = torch.randn(img.size()) * noise_std
    # Move the noise to wherever ``img`` lives rather than to the module-level
    # ``device`` global, which is only defined when the file runs as a script.
    return img + noise.to(img.device)

def gen_noise(img, noise_std, dist):
    """Sample a noise tensor with the same shape as ``img``.

    Args:
        img: Tensor whose shape determines the shape of the noise.
        noise_std: Noise scale. For ``'gaussian'`` it is the standard
            deviation. For the exponential branch the rate is
            ``1 / sqrt(noise_std)``, i.e. the samples have mean
            ``sqrt(noise_std)``.
            NOTE(review): the two branches scale differently -- confirm that
            this is intended.
        dist: ``'gaussian'`` for normal noise; any other value selects
            exponential noise.

    Returns:
        A tensor of noise shaped like ``img`` (created without an explicit
        device, independent of where ``img`` lives).
    """
    if dist == 'gaussian':
        return torch.randn(img.size()) * noise_std

    rate = torch.tensor([1 / np.sqrt(noise_std)])
    m = torch.distributions.exponential.Exponential(rate)
    # sample() appends the shape of ``rate`` (one trailing axis of size 1).
    # Strip that axis by position from the end so inputs of any rank work,
    # not only 4-D batches.
    return m.sample(img.size()).squeeze(-1)

class Quad(nn.Module):
    """Parameter-free activation that squares its input element-wise."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        """Return ``input`` raised to the second power."""
        squared = input ** 2
        return squared

# Lookup table from activation name (see `act_type`) to a module instance.
# Each instance is created once here, so every network that looks up the same
# key receives the same object.
activations = dict(relu=nn.ReLU(True), quad=Quad())

#Data
class MNISTData:
    """Train and test ``DataLoader`` objects over normalized MNIST.

    Attributes are ``trainloader`` and ``testloader``; neither loader
    shuffles, so batches always come in dataset order.
    """

    def __init__(self, train_batch_size, test_batch_size):
        """Download MNIST into ``./data`` if needed and build both loaders.

        Args:
            train_batch_size: Batch size of ``trainloader``.
            test_batch_size: Batch size of ``testloader``.
        """
        # Tensor conversion followed by per-channel normalization.
        pipeline = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, )),
        ])

        train_split = MNIST(root='./data', train=True, download=True, transform=pipeline)
        self.trainloader = torch.utils.data.DataLoader(
            train_split, batch_size=train_batch_size, shuffle=False, num_workers=8)

        test_split = MNIST(root='./data', train=False, download=True, transform=pipeline)
        self.testloader = torch.utils.data.DataLoader(
            test_split, batch_size=test_batch_size, shuffle=False, num_workers=8)
        
class CIFAR10Data:
    """Train and test ``DataLoader`` objects over normalized CIFAR-10.

    Attributes are ``trainloader`` (shuffled) and ``testloader`` (dataset
    order).
    """

    def __init__(self, train_batch_size, test_batch_size):
        """Download CIFAR-10 into ``./data`` if needed and build both loaders.

        Args:
            train_batch_size: Batch size of ``trainloader``.
            test_batch_size: Batch size of ``testloader``.
        """
        # Tensor conversion followed by per-channel normalization.
        pipeline = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        train_split = CIFAR10(root='./data', train=True, download=True, transform=pipeline)
        self.trainloader = torch.utils.data.DataLoader(
            train_split, batch_size=train_batch_size, shuffle=True, num_workers=8)

        test_split = CIFAR10(root='./data', train=False, download=True, transform=pipeline)
        self.testloader = torch.utils.data.DataLoader(
            test_split, batch_size=test_batch_size, shuffle=False, num_workers=8)


class NoisyData:
    """Wrap prepared train/test datasets in ``DataLoader`` objects."""

    def __init__(self, trainset, testset, batch_size):
        """Build ``trainloader`` and ``testloader``.

        Args:
            trainset: Dataset served in shuffled batches of ``batch_size``.
            testset: Dataset served in order, in batches of 1000.
            batch_size: Training batch size.
        """
        make_loader = torch.utils.data.DataLoader
        self.trainloader = make_loader(trainset, shuffle=True, num_workers=8, batch_size=batch_size)
        self.testloader = make_loader(testset, shuffle=False, num_workers=8, batch_size=1000)
        
class PrepareData3D(Dataset):
    """Dataset yielding ``(X[i], y[i], z[i])`` triples from three aligned arrays."""

    def __init__(self, X, y, z):
        """Store the three arrays as tensors.

        Each argument that is already a tensor is kept as-is; anything else
        is handed to ``torch.from_numpy``.
        """
        self.X = X if torch.is_tensor(X) else torch.from_numpy(X)
        self.y = y if torch.is_tensor(y) else torch.from_numpy(y)
        self.z = z if torch.is_tensor(z) else torch.from_numpy(z)

    def __len__(self):
        """Number of samples, taken from the length of ``X``."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``idx``-th entries of ``X``, ``y`` and ``z`` as a tuple."""
        return tuple(field[idx] for field in (self.X, self.y, self.z))


#network
class FCNet(nn.Module):
    """Two-layer fully connected network operating on flattened images.

    Args:
        dim: Side length of the (square) image before padding.
        chans: Number of image channels.
    """

    def __init__(self, dim, chans):
        # The original called ``super(Net, self)`` with the undefined name
        # ``Net``, so the class could never be instantiated.
        super().__init__()
        # Number of pixels per channel once ``pad_size`` is added on each side.
        pixels = (dim + 2 * pad_size) * (dim + 2 * pad_size)
        # NOTE: ``self.net`` is registered (and so contributes parameters) but
        # forward() below never calls it.
        self.net = nn.Sequential(
            nn.Conv1d(in_channels=1, out_channels=1, kernel_size=1),
            activations[act_type],
            nn.Linear(pixels, pixels))
        self.layer1 = nn.Linear(pixels * chans, pixels * chans)
        self.act1 = activations[act_type]
        self.layer2 = nn.Linear(pixels * chans, pixels * chans)

    def forward(self, x):
        """Return ``(output, hidden_features)`` for input ``x``.

        ``hidden_features`` is the activated output of ``layer1``; ``output``
        is ``layer2`` applied to it.
        """
        feats = self.act1(self.layer1(x))
        x = self.layer2(feats)
        return x, feats


#network
class ConvNet0(nn.Module):
    """Bias-free two-layer CNN: 7x7 conv (1 -> 1024), activation, 1x1 conv (1024 -> 1)."""

    def __init__(self):
        super().__init__()

        width = 1024
        self.layer1 = nn.Conv2d(in_channels=1, out_channels=width, kernel_size=7,
                                stride=1, padding=3, bias=False)
        self.act1 = activations[act_type]
        self.layer2 = nn.Conv2d(in_channels=width, out_channels=1, kernel_size=1,
                                stride=1, padding=0, bias=False)

    def forward(self, x):
        """Apply conv -> activation -> conv to ``x`` and return the result."""
        hidden = self.layer1(x)
        hidden = self.act1(hidden)
        return self.layer2(hidden)

class UnrolledNet0(nn.Module):
    """Chain of ``args.num_unrolled`` ConvNet0 stages applied one after another.

    With ``args.share_weights`` set, every stage is the same ConvNet0
    instance; otherwise each stage is a separate instance.
    """

    def __init__(self):
        super().__init__()

        if args.share_weights:
            shared = ConvNet0()
            stages = [shared for _ in range(args.num_unrolled)]
        else:
            stages = [ConvNet0() for _ in range(args.num_unrolled)]
        self.convblocks = nn.ModuleList(stages)

    def forward(self, x):
        """Feed ``x`` through every stage in order and return the final output."""
        out = x
        for stage in self.convblocks:
            out = stage(out)
        return out

#network
class ConvNet1(nn.Module):
    """Bias-free two-layer CNN with a configurable first layer and optional skip.

    Args:
        kernel_size: Kernel size of the first convolution.
        padding: Padding of the first convolution.
        skip_connection: When truthy, the input is added to the output.
        primal_filters: Number of hidden channels.
    """

    def __init__(self, kernel_size, padding, skip_connection, primal_filters=512):
        super().__init__()

        self.layer1 = nn.Conv2d(in_channels=1, out_channels=primal_filters,
                                kernel_size=kernel_size, stride=1,
                                padding=padding, bias=False)
        self.act1 = activations[act_type]
        # 1x1 convolution collapsing the hidden channels back to one.
        self.layer2 = nn.Conv2d(in_channels=primal_filters, out_channels=1,
                                kernel_size=1, stride=1, padding=0, bias=False)
        self.skip = skip_connection

    def forward(self, x):
        """Return conv -> activation -> conv of ``x``, plus ``x`` when skipping."""
        out = self.layer2(self.act1(self.layer1(x)))
        if not self.skip:
            return out
        # Residual path: accumulate the input into the output tensor in place.
        out += x
        return out

class UnrolledNet1(nn.Module):
    """Chain of ``args.num_unrolled`` ConvNet1 stages applied one after another.

    With ``args.share_weights`` set, every stage is the same ConvNet1
    instance; otherwise each stage is a separate instance. The stages are
    built without a ``primal_filters`` argument, so they use ConvNet1's
    default hidden width.

    Args:
        kernel_size: Forwarded to every ConvNet1 stage.
        padding: Forwarded to every ConvNet1 stage.
        skip_connection: Forwarded to every ConvNet1 stage.
    """

    def __init__(self, kernel_size, padding, skip_connection):
        super().__init__()

        if args.share_weights:
            shared = ConvNet1(kernel_size, padding, skip_connection)
            stages = [shared for _ in range(args.num_unrolled)]
        else:
            stages = [ConvNet1(kernel_size, padding, skip_connection)
                      for _ in range(args.num_unrolled)]
        self.convblocks = nn.ModuleList(stages)

    def forward(self, x):
        """Feed ``x`` through every stage in order and return the final output."""
        out = x
        for stage in self.convblocks:
            out = stage(out)
        return out


#network
class ConvNet2(nn.Module):
    """Bias-free two-layer CNN: 3x3 conv (1 -> 128), activation, 3x3 conv (128 -> 1)."""

    def __init__(self):
        super().__init__()

        width = 128
        self.layer1 = nn.Conv2d(in_channels=1, out_channels=width, kernel_size=3,
                                stride=1, padding=1, bias=False)
        self.act1 = activations[act_type]
        self.layer2 = nn.Conv2d(in_channels=width, out_channels=1, kernel_size=3,
                                stride=1, padding=1, bias=False)

    def forward(self, x):
        """Apply conv -> activation -> conv to ``x`` and return the result."""
        hidden = self.layer1(x)
        hidden = self.act1(hidden)
        return self.layer2(hidden)
    
class UnrolledNet2(nn.Module):
    """Chain of ``args.num_unrolled`` ConvNet2 stages applied one after another.

    With ``args.share_weights`` set, every stage is the same ConvNet2
    instance; otherwise each stage is a separate instance.
    """

    def __init__(self):
        super().__init__()

        if args.share_weights:
            shared = ConvNet2()
            stages = [shared for _ in range(args.num_unrolled)]
        else:
            stages = [ConvNet2() for _ in range(args.num_unrolled)]
        self.convblocks = nn.ModuleList(stages)

    def forward(self, x):
        """Feed ``x`` through every stage in order and return the final output."""
        out = x
        for stage in self.convblocks:
            out = stage(out)
        return out

def preprocess(img):
    """Replace each image of a batch by its rank-``args.lowrank`` PCA reconstruction.

    Every image is treated as a 2-D matrix, fitted with scikit-learn's PCA
    using ``args.lowrank`` components, and mapped back with
    ``inverse_transform``.

    Args:
        img: Batch of single-channel images, expected as ``(batch, 1, H, W)``.
            The reconstruction is written back into this tensor's storage.

    Returns:
        The low-rank batch with shape ``(batch, 1, H, W)``.
    """
    # Drop only the channel axis. A bare torch.squeeze() would also remove
    # the batch axis when the batch holds a single image, and the per-image
    # indexing below would then fail.
    img = torch.squeeze(img, 1)
    for i in range(img.size(0)):
        img_cpu = img[i].cpu().detach().numpy()
        pca = PCA(n_components=args.lowrank)
        vis_out_cpu = pca.fit_transform(img_cpu)
        pca_inv = pca.inverse_transform(vis_out_cpu)
        # The assignment copies the values into the batch tensor.
        img[i] = torch.from_numpy(pca_inv)

        # Debug-only dump of the reconstruction; switched off by default.
        vis = False
        if vis:
            pca_inv -= pca_inv.min()
            pca_inv /= pca_inv.max()
            fig, axs = plt.subplots(1, 1)
            axs.imshow(pca_inv.reshape(28, 28), cmap='gray')
            plt.savefig(args.log_dir+'/LR_input_np'+args.problem+'_'+str(args.noise_std)+'.png')

    # Restore the singleton channel axis.
    return img.unsqueeze(1)

    
def loss_primal(yhat, y, model, args):
    """Regularized training objective of the primal network.

    The loss is ``0.5 * MSE(yhat, y)`` (averaged over batch and pixels) plus
    ``args.beta / 2`` times the squared norm of the ``layer1`` and ``layer2``
    weights of every convolutional block.

    Args:
        yhat: Network output.
        y: Target tensor.
        model: The network, or a wrapper exposing it as ``.module`` when
            ``args.parallel`` is set.
        args: Namespace providing ``parallel`` and ``beta``.

    Returns:
        Scalar loss tensor.
    """
    model_prox = model.module if args.parallel else model

    # Data term, averaged over batch and pixels (per-pixel loss).
    loss = 0.5 * nn.MSELoss()(yhat, y)

    # Decide by the model's structure rather than by ``args.num_unrolled > 1``:
    # the script builds an unrolled model whenever num_unrolled > 0, so a
    # single-stage unrolled model has ``convblocks`` but no top-level
    # ``layer1``/``layer2``.
    if hasattr(model_prox, 'convblocks'):
        # A shared block appears once per unrolled stage and is therefore
        # penalized once per stage.
        blocks = list(model_prox.convblocks)
    else:
        blocks = [model_prox]

    half_beta = args.beta / 2
    for block in blocks:
        loss = loss + half_beta * torch.norm(block.layer1.weight)**2
        loss = loss + half_beta * torch.norm(block.layer2.weight)**2

    return loss

def train(model, data):
    """Train ``model`` as a denoiser on ``data`` and log to TensorBoard.

    Uses Adam with learning rate ``args.primal_lr`` for ``args.primal_epochs``
    epochs, halving the learning rate when the mean training loss plateaus.
    Every ``args.printfreq`` epochs the model is evaluated with ``test`` and a
    summary line is printed.

    Args:
        model: Network to optimize.
        data: Object exposing ``trainloader`` and ``testloader``; element 0
            of every batch is used as the clean image and element 2 as the
            noise to add.

    Returns:
        The loss tensor of the last processed batch, or ``None`` when no
        batch was processed.
    """
    trainset = data.trainloader
    testset = data.testloader
    optimizer = torch.optim.Adam(model.parameters(), lr=args.primal_lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           verbose=args.verbose,
                                                           factor=0.5,
                                                           patience=300,
                                                           eps=1e-15)
    writer = SummaryWriter(log_dir=args.log_dir)

    loss = None
    # Close the writer even when training is interrupted by an error.
    try:
        for epoch in range(args.primal_epochs):
            total_loss = 0
            for i, datum in enumerate(trainset, 0):

                img = datum[0].to(device)
                noise = datum[2].to(device)

                if args.lowrank < 28:
                    img = preprocess(img)

                if circular_pad:
                    img = nn.functional.pad(img, 2*(pad_size,)+2*(pad_size,), mode='circular')

                noisy_img = img + noise

                # ===================forward=====================
                output = model(noisy_img)
                loss = loss_primal(output, img, model, args)
                total_loss += loss.item()

                train_psnr = psnr(output, img)

                # ===================backward====================
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                writer.add_scalar('Train_PSNR', train_psnr.item(), epoch*len(trainset) + i)

            mean_loss = total_loss/len(trainset)
            # ReduceLROnPlateau.step() takes the monitored metric. Passing the
            # epoch index (an ever-increasing number) would look like a metric
            # that never improves, so pass the mean training loss instead.
            scheduler.step(mean_loss)
            writer.add_scalar('train_loss/primal', mean_loss, epoch)

            if epoch % args.printfreq == 0:
                test_loss = test(model, testset, writer, epoch)

                # ===================log========================
                print('epoch [{}/{}], train_loss:{:.4f}, test_loss:{:.4f}'.format(epoch+1, args.primal_epochs, mean_loss, test_loss))

                # Images come from the last training batch of this epoch.
                if epoch % 20 == 0:
                    vis_image(img, writer, epoch, args, 'primal_train_gt')
                    vis_image(output, writer, epoch, args, 'primal_train_output')
                    vis_image(noisy_img, writer, epoch, args, 'primal_train_input')
    finally:
        writer.close()

    return loss

def test(model, testset, writer, epoch):
    """Evaluate ``model`` on ``testset`` and log loss, PSNR and sample images.

    Args:
        model: Network to evaluate.
        testset: Loader of batches; element 0 of every batch is used as the
            clean image and element 2 as the noise to add.
        writer: Summary writer receiving ``test_loss/primal``, ``Test_PSNR``
            and the visualizations.
        epoch: Global step used for all logged values.

    Returns:
        The ``loss_primal`` value averaged over the batches as a Python
        number (0 when the loader yields no batch).
    """
    num_batches = len(testset)
    test_loss = 0
    psnr_total = 0.0
    img = output = noisy_img = None
    with torch.no_grad():
        for datum in testset:

            img = datum[0].to(device)
            noise = datum[2].to(device)

            if args.lowrank < 28:
                img = preprocess(img)

            if circular_pad:
                img = nn.functional.pad(img, 2*(pad_size,)+2*(pad_size,), mode='circular')

            noisy_img = img + noise

            # ===================forward=====================
            output = model(noisy_img)
            loss = loss_primal(output, img, model, args)
            test_loss += loss.item()/num_batches
            # Accumulate so the logged PSNR covers every batch instead of
            # only the last one. Assumes psnr() returns a scalar -- the
            # training loop calls .item() on its result.
            psnr_total += float(psnr(output, img))

    if img is None:
        # Nothing was evaluated, so there is nothing to log or visualize.
        return test_loss

    writer.add_scalar('test_loss/primal', test_loss, epoch)
    writer.add_scalar('Test_PSNR', psnr_total/num_batches, epoch)

    # Images come from the last evaluated batch.
    vis_image(img, writer, epoch, args, 'primal_test_gt')
    vis_image(output, writer, epoch, args, 'primal_test_output')
    vis_image(noisy_img, writer, epoch, args, 'primal_test_input')

    return test_loss


if __name__ == "__main__":
    # Script entry point: parse the options, build a fixed noisy train/test
    # set, train the primal network and save its weights.

    def _str2bool(value):
        """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

        ``type=bool`` would turn every non-empty string -- including
        ``"False"`` -- into ``True``.
        """
        if isinstance(value, bool):
            return value
        text = value.strip().lower()
        if text in ('true', 't', 'yes', 'y', '1'):
            return True
        if text in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))

    parser = argparse.ArgumentParser()
    # All options are optional and fall back to the defaults below.
    parser.add_argument("--alternating", help="Whether to use alternating (default) or joint minimization of dual variables", default=False, type=_str2bool)
    parser.add_argument("--optimizer", help='SGD or Adam', default='Adam')
    parser.add_argument("--lr", help='Initial Learning Rate', default=1e-4, type=float)
    parser.add_argument("--primal_lr", help='Initial Learning Rate', default=1e-3, type=float)
    parser.add_argument("--momentum", help='Momentum for SGD', default=0.9, type=float)
    parser.add_argument("--primal_epochs", help='Number of epochs to train for the primal', default=100, type=int)
    parser.add_argument("--epochs", help='Number of epochs to train if training jointly, or number of epochs per variable descent if alternating', default=100, type=int)
    parser.add_argument("--alternations", help='Number of alternations if training alternating', default=10, type=int)
    parser.add_argument("--bs", help='Batch size', default=25, type=int)
    parser.add_argument('--dataset', help='Dataset', default='MNIST', type=str)
    parser.add_argument('--beta', help='Regularization parameter', default=1e-5, type=float)
    parser.add_argument('--rho', help='Hinge loss parameter (penalizes infeasible solutions)', default=1e-6, type=float)
    parser.add_argument('--rho_factor', help='Factor by which rho increases each epoch (to force feasibility)', default=1.0, type=float)
    parser.add_argument('--device', help='Device to use', default='cuda', type=str)
    parser.add_argument('--verbose', help='Show intermediate steps', default=True, type=_str2bool)
    parser.add_argument('--modeldir', help='Directory to save final model', default='./models', type=str)
    parser.add_argument('--outputstr', help='String output', default='dual_model.pkl', type=str)
    parser.add_argument('--datadir', help='Directory of dataset', default='', type=str)
    parser.add_argument('--P', help='number of random relu patterns to use', default=8000, type=int)
    parser.add_argument('--signpth', help='Path to sign patterns, load if exists, otherwise save', default='sign_patterns', type=str)
    parser.add_argument('--lowrank', help='Number of PCA components kept per image (values below 28 enable the low-rank preprocessing)', default=25, type=int)
    parser.add_argument('--printfreq', help='Frequency to print', default=5, type=int)
    parser.add_argument('--mode', help='cls (classification) or den (denoising)', default='den', type=str)
    parser.add_argument('--noise_std', help='standard deviation of noise to add if denoising', default=0.75, type=float)
    parser.add_argument('--use_saved_checkpoint', help='use saved checkpoint for initialization', default=True, type=_str2bool)
    parser.add_argument('--log_dir', help='Directory for tensorboard logs', default='./logs', type=str)
    parser.add_argument('--run', help='Name of the current run', default='run0', type=str)
    parser.add_argument('--dataset_subsample_factor', help='Factor by which to subsample the dataset for training', default=0.001, type=float)
    parser.add_argument('--kernel_size', default=3, type=int)
    parser.add_argument('--padding', default=1, type=int)
    parser.add_argument('--sign_subsample_factor', help='Factor by which to subsample the training set to determine sign patterns', default=0.1, type=float)
    parser.add_argument('--num_unrolled', type=int, default=0, help='Number of unrolled iterations')
    parser.add_argument('--share_weights', type=_str2bool, default=False, help='Weight sharing for unrolled')
    parser.add_argument('--skip_connection', type=_str2bool, default=False, help='Residual skip connection')
    parser.add_argument('--dist', type=str, default='gaussian')
    parser.add_argument('--primal_filters', type=int, default=25)

    args = parser.parse_args()

    # Hard-coded overrides: these replace whatever was given on the command
    # line for --run, --log_dir and --use_saved_checkpoint.
    args.run = 'unr0'
    args.log_dir = './paper_results/'+args.dataset+'_twolayer_conv_relu_'+str(args.noise_std) + '_'+ str(args.kernel_size)+ '_' + str(args.num_unrolled) + '_' + str(args.dataset_subsample_factor) + '_' + str(args.primal_filters) + '_' + args.dist
    args.use_saved_checkpoint = False

    # makedirs also creates the './paper_results' parent, which os.mkdir
    # would require to exist already.
    os.makedirs(args.log_dir, exist_ok=True)

    # Identify the GPU. NOTE: the device is chosen by CUDA availability;
    # --device is parsed but not consulted here.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(args)

    args.problem = 'primal'

    # The train loader's batch size is the subsampled training-set size, so a
    # single batch is the whole training subset.
    if args.dataset == 'MNIST':
        data = MNISTData(int(60000*args.dataset_subsample_factor), 10000)
        args.dim = 28
        args.chans = 1
    elif args.dataset == 'CIFAR-10':
        data = CIFAR10Data(int(50000*args.dataset_subsample_factor), 10000)
        args.dim = 32
        args.chans = 3
    else:
        raise ValueError("unsupported dataset {!r}; expected 'MNIST' or 'CIFAR-10'".format(args.dataset))

    print('generating noisy train set')
    # Only the first batch of the loader is used for training.
    A, y = next(iter(data.trainloader))

    noise = gen_noise(A, args.noise_std, args.dist)
    print('training set size', A.shape)

    trainset = PrepareData3D(A, y, noise)

    # Walk through the whole test loader and keep its last batch.
    for A_test, y_test in data.testloader:
        pass

    print('generating noisy test set')
    test_noise = gen_noise(A_test, args.noise_std, args.dist)
    testset = PrepareData3D(A_test, y_test, test_noise)

    # The noise is sampled once above and stored with the images, so each
    # image keeps the same noise realization in every epoch.
    data = NoisyData(trainset, testset, args.bs)

    if args.num_unrolled > 0:
        # NOTE: the unrolled model is not given args.primal_filters; its
        # blocks use ConvNet1's default number of filters.
        model = UnrolledNet1(args.kernel_size, args.padding, args.skip_connection)
    else:
        model = ConvNet1(args.kernel_size, args.padding, args.skip_connection, args.primal_filters)

    if torch.cuda.device_count() > 1:
        args.parallel = True
        # DataParallel splits every batch across the visible GPUs.
        model = nn.DataParallel(model)
    else:
        args.parallel = False

    model.to(device)

    # Train and save the model.
    print("starting training")
    noncvx_loss = train(model, data)

    # NOTE(review): the checkpoint file is literally named '.pth' inside the
    # log directory; args.run is not part of the path -- confirm this is
    # intended before changing it.
    PATH = args.log_dir+'/.pth'
    torch.save(model.state_dict(), PATH)
    

    # visualize(model, data.testloader, args)
    # duality_gap = noncvx_loss.item() - cvx_loss.item()
    # print('duality_gap:',duality_gap)
    # writer = SummaryWriter(log_dir=args.log_dir)
    # writer.add_scalar('Duality_Gap', duality_gap)

    # #load and test
    # args.problem = 'primal'
    # model.load_state_dict(torch.load(PATH))
    # t0 = time.time()
    # test(model,data.testloader)
    # device = torch.device("cpu")
    # print('{} seconds'.format(time.time() - t0))
    
