import os

from visualize import visualize, vis_image, to_img, psnr
import argparse
from dual_relu import main as dual_main
import torch
from torch import nn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms
from torchvision.datasets import MNIST, CIFAR10
from torchvision.utils import save_image
from siren_pytorch import Sine
from tensorboardX import SummaryWriter
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import numpy as np
import matplotlib.pyplot as plt
import time

from utils import transforms as T
from utils import subsample as ss
from utils import complex_utils as cplx

# import custom classes
from utils.datasets import SliceData
from utils.subsample import RandomMaskFunc,  PoissonLoadMaskFunc

# Batch size used by the NoisyData loaders.
batch_size = 2

# When True, train()/test() circularly pad each input image by pad_size pixels per side.
circular_pad = False

# Key into the `activations` table that selects the nonlinearity of every network.
act_type = "relu"

pad_size = 1 if circular_pad else 0

class DataTransform:
    """
    Data Transformer for training unrolled reconstruction models.

    Retrospectively undersamples a k-space slice with ``mask_func`` and returns
    the normalised target together with the zero-filled SENSE reconstruction.
    """

    def __init__(self, mask_func, args, use_seed=True):
        """
        Args:
            mask_func (utils.subsample.MaskFunc): A function that can create a mask of
                appropriate shape.
            args (argparse.Namespace): Experiment options. Not read by this class;
                kept so existing callers do not break.
            use_seed (bool): If true, this class computes a pseudo random number generator seed
                from the filename. This ensures that the same mask is used for all the slices of
                a given volume every time.
        """
        self.mask_func = mask_func
        self.use_seed = use_seed

    def __call__(self, kspace, maps, target, attrs, fname, slice):
        """
        Args:
            kspace (numpy.array): Fully sampled k-space for one slice.
            maps (numpy.array): Sensitivity maps used to build ``T.SenseModel``.
            target (numpy.array): Target image.
            attrs (dict): Acquisition related information stored in the HDF5 object (unused).
            fname (str): File name; seeds the mask RNG when ``use_seed`` is True.
            slice (int): Serial number of the slice (unused).
        Returns:
            (tuple): ``(target, zf_image)`` where
                target (torch.Tensor): target divided by the normalisation scale,
                    with the leading singleton batch dimension removed.
                zf_image (torch.Tensor): adjoint SENSE (zero-filled) reconstruction of
                    the normalised, undersampled k-space. It still carries the
                    leading singleton batch dimension added by ``unsqueeze(0)``.
        """
        # Convert everything from numpy arrays to tensors (add a batch dim of 1).
        kspace = cplx.to_tensor(kspace).unsqueeze(0)
        maps = cplx.to_tensor(maps).unsqueeze(0)
        target = cplx.to_tensor(target).unsqueeze(0)

        # Apply mask in k-space
        seed = None if not self.use_seed else tuple(map(ord, fname))
        masked_kspace, mask = ss.subsample(kspace, self.mask_func, seed, mode='2D')

        # Normalize data...
        if 0:
            # NOTE: disabled alternative; it never assigns `scale`, so enabling it
            # as-is would raise NameError below.
            A = T.SenseModel(maps, weights=mask)
            image = A(masked_kspace, adjoint=True)
            magnitude = cplx.abs(image)
        elif 1:
            # ... by magnitude of zero-filled reconstruction: the smallest value
            # among the top 5% magnitudes (~95th percentile). Clamp k to >= 1 so
            # very small inputs do not call topk with k == 0 (empty -> min fails).
            A = T.SenseModel(maps)
            image = A(masked_kspace, adjoint=True)
            magnitude_vals = cplx.abs(image).reshape(-1)
            k = max(1, int(round(0.05 * magnitude_vals.numel())))
            scale = torch.min(torch.topk(magnitude_vals, k).values)
        else:
            # ... by power within calibration region
            calib_size = 10
            calib_region = cplx.center_crop(masked_kspace, [calib_size, calib_size])
            scale = torch.mean(cplx.abs(calib_region)**2)
            scale = scale * (calib_size**2 / kspace.size(-3) / kspace.size(-2))

        # Out-of-place division: cplx.to_tensor may share memory with the caller's
        # numpy array, in which case an in-place `/=` would rescale that array too.
        masked_kspace = masked_kspace / scale
        target = target / scale

        # Zero-filled image from the normalised, masked k-space.
        A = T.SenseModel(maps, weights=mask)
        zf_image = A(masked_kspace, adjoint=True)

        # Get rid of the batch dimension on the target only.
        target = target.squeeze(0)
        return target, zf_image


def create_datasets(args):
    """Build the validation and training ``SliceData`` datasets.

    Each split gets its own ``PoissonLoadMaskFunc`` (configured from
    ``args.accelerations`` and ``args.calib_size``) wrapped in a
    ``DataTransform`` that seeds the mask from the file name.

    Returns:
        tuple: ``(dev_data, train_data)``; the validation split comes first.
    """
    train_mask = PoissonLoadMaskFunc(args.accelerations, args.calib_size)
    dev_mask = PoissonLoadMaskFunc(args.accelerations, args.calib_size)

    def _split(subdir, mask_func):
        # One SliceData rooted at <datadir>/<subdir>.
        return SliceData(
            root=os.path.join(str(args.datadir), subdir),
            transform=DataTransform(mask_func, args, use_seed=True),
            sample_rate=args.sample_rate,
        )

    train_data = _split('train', train_mask)
    dev_data = _split('val', dev_mask)
    return dev_data, train_data


def create_data_loaders(args):
    """Create train, validation and display DataLoaders for fastMRI.

    The train and validation loaders serve each split as a single batch
    (``batch_size=len(split)``). The display loader holds an evenly strided
    subset of the validation split, in batches of 16.

    Returns:
        tuple: ``(train_loader, dev_loader, display_loader)``.
    """
    dev_data, train_data = create_datasets(args)

    # Evenly strided validation subset for visualisation. Clamp the stride to
    # 1: with fewer than 16 slices `len // 16` is 0 and range() raises ValueError.
    stride = max(1, len(dev_data) // 16)
    display_data = [dev_data[i] for i in range(0, len(dev_data), stride)]

    train_loader = DataLoader(
        dataset=train_data,
        #batch_size=args.batch_size,
        # Whole split in one batch; max(1, ...) avoids DataLoader rejecting
        # batch_size=0 for an empty split.
        batch_size=max(1, len(train_data)),
        shuffle=True,
        num_workers=8,
        pin_memory=True,
    )
    dev_loader = DataLoader(
        dataset=dev_data,
        #batch_size=args.batch_size,
        batch_size=max(1, len(dev_data)),
        num_workers=8,
        pin_memory=True,
    )
    display_loader = DataLoader(
        dataset=display_data,
        batch_size=16,
        num_workers=8,
        pin_memory=True,
    )
    return train_loader, dev_loader, display_loader

def add_noise(img,noise_std):
    """Return ``img`` plus i.i.d. Gaussian noise with standard deviation ``noise_std``.

    The noise is drawn directly on ``img``'s device. Previously it was moved to
    the module-level ``device`` global, which only exists when the script runs
    as ``__main__``, so calling this from an import raised NameError.

    Args:
        img (torch.Tensor): clean image batch (any shape).
        noise_std (float): standard deviation of the additive noise.

    Returns:
        torch.Tensor: noisy image with the same shape as ``img``.
    """
    # Keep torch.randn's default floating dtype, as before, so integer inputs still work.
    noise = torch.randn(img.size(), device=img.device) * noise_std
    return img + noise

def gen_noise(img, noise_std):
    """Draw a CPU Gaussian noise tensor shaped like ``img``, scaled by ``noise_std``."""
    shape = img.size()
    return noise_std * torch.randn(shape)

class Quad(nn.Module):
    """Elementwise quadratic activation ``f(x) = x ** 2``."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        return input.pow(2)

# Name -> activation module; the same instance is shared by every network that looks it up.
activations = dict(relu=nn.ReLU(True), sine=Sine(1.), quad=Quad())

#Data
class MNISTData():
    """MNIST train/test DataLoaders with standard mean/std normalisation.

    The training loader shuffles; the test loader does not.
    """

    def __init__(self, train_batch_size, test_batch_size):
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, )),
        ])

        def _loader(is_train, bs):
            dataset = MNIST(root='./data', train=is_train, download=True, transform=transform)
            return torch.utils.data.DataLoader(dataset, batch_size=bs, shuffle=is_train, num_workers=8)

        self.trainloader = _loader(True, train_batch_size)
        self.testloader = _loader(False, test_batch_size)
        
class CIFAR10Data():
    """CIFAR-10 train/test DataLoaders with per-channel mean/std normalisation.

    The training loader shuffles; the test loader does not.
    """

    def __init__(self, train_batch_size, test_batch_size):
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        def _loader(is_train, bs):
            dataset = CIFAR10(root='./data', train=is_train, download=True, transform=transform)
            return torch.utils.data.DataLoader(dataset, batch_size=bs, shuffle=is_train, num_workers=8)

        self.trainloader = _loader(True, train_batch_size)
        self.testloader = _loader(False, test_batch_size)

class FastMRIData():
    """Container exposing prebuilt fastMRI loaders as ``trainloader`` / ``testloader``."""

    def __init__(self, train_loader, test_loader):
        self.trainloader, self.testloader = train_loader, test_loader
        
class NoisyData():
    """Wrap train/test datasets in DataLoaders using the module-level ``batch_size``."""

    def __init__(self, trainset, testset):
        common = dict(batch_size=batch_size, num_workers=8)
        self.trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **common)
        self.testloader = torch.utils.data.DataLoader(testset, shuffle=False, **common)
        
class PrepareData3D(Dataset):
    """Map-style dataset zipping three equally indexed arrays/tensors.

    Numpy inputs are wrapped with ``torch.from_numpy``, which shares memory with
    the array. Tensors are stored as given. Item ``i`` is ``(X[i], y[i], z[i])``.
    """

    def __init__(self, X, y, z):
        self.X, self.y, self.z = [
            v if torch.is_tensor(v) else torch.from_numpy(v) for v in (X, y, z)
        ]

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return tuple(field[idx] for field in (self.X, self.y, self.z))
    
class PrepareData3DDC(Dataset):
    """Map-style dataset zipping five equally indexed arrays/tensors.

    Numpy inputs are wrapped with ``torch.from_numpy``, which shares memory with
    the array. Tensors are stored as given. Item ``i`` is
    ``(X[i], y[i], z[i], t[i], k[i])``.
    """

    def __init__(self, X, y, z,t,k):
        self.X, self.y, self.z, self.t, self.k = [
            v if torch.is_tensor(v) else torch.from_numpy(v) for v in (X, y, z, t, k)
        ]

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return tuple(field[idx] for field in (self.X, self.y, self.z, self.t, self.k))


#network
class FCNet(nn.Module):
    """Fully connected network applying layer1 -> activation -> layer2 repeatedly.

    Operates on flattened images of ``(dim + 2*pad_size)**2 * chans`` features.
    ``forward`` applies the same two layers ``args.num_unrolled`` times; it reads
    the module-level ``args`` global.
    """

    def __init__(self,dim,chans):
        # Was `super(Net, self)`: `Net` is undefined, so construction always raised NameError.
        super(FCNet, self).__init__()
        n_pix = (dim + 2 * pad_size) * (dim + 2 * pad_size)
        # NOTE(review): self.net is built but never used by forward(); kept so
        # parameter creation order and state_dict keys are unchanged.
        self.net = nn.Sequential(
            nn.Conv1d(in_channels = 1, out_channels = 1, kernel_size = 1),
            activations[act_type],
            nn.Linear(n_pix, n_pix))
        self.layer1 = nn.Linear(n_pix * chans, n_pix * chans)
        self.act1 = activations[act_type]
        self.layer2 = nn.Linear(n_pix * chans, n_pix * chans)

    def forward(self, x):
        """Return ``(output, feats)``; ``feats`` is the last hidden activation (None if no iterations ran)."""
        feats = None
        for i in range(args.num_unrolled):
            feats = self.act1(self.layer1(x))
            x = self.layer2(feats)
        return x,feats


#network
class ConvNet0(nn.Module):
    """Residual two-layer conv block on single-channel images.

    7x7 conv (1 -> 1024 maps, no bias) -> activation -> 1x1 conv (1024 -> 1,
    no bias), with the input added back as a skip connection.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Conv2d(1, 1024, kernel_size=7, stride=1, padding=3, bias=False)
        self.act1 = activations[act_type]
        self.layer2 = nn.Conv2d(1024, 1, kernel_size=1, stride=1, padding=0, bias=False)

    def forward(self, x):
        hidden = self.act1(self.layer1(x))
        return self.layer2(hidden) + x
    
class ConvNet0DC(nn.Module):
    """Residual two-channel conv block followed by ``data_consistency``.

    7x7 circular-padded conv (2 -> 1024) -> activation -> 1x1 conv (1024 -> 2),
    plus a skip connection. The result is passed to ``data_consistency``
    together with the block input.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Conv2d(2, 1024, kernel_size=7, stride=1, padding=3, bias=False, padding_mode='circular')
        self.act1 = activations[act_type]
        self.layer2 = nn.Conv2d(1024, 2, kernel_size=1, stride=1, padding=0, bias=False, padding_mode='circular')

    def forward(self, x):
        refined = self.layer2(self.act1(self.layer1(x))) + x
        return data_consistency(refined, x)
    
    
class ConvNet1DC(nn.Module):
    """Step along A^H(A x) - x, then a residual two-channel conv block.

    Used by UnrolledNet0DC. Reads the module-level ``args`` global
    (``args.patch``, ``args.dim``).
    """
    def __init__(self):
        super(ConvNet1DC, self).__init__()

        # 7x7 circular-padded conv (2 -> 1024 maps) and 1x1 conv (1024 -> 2), no biases.
        self.layer1 = nn.Conv2d(in_channels=2, out_channels=1024, kernel_size=7, stride=1, padding=3, bias=False, padding_mode = 'circular')
        self.act1 = activations[act_type]
        self.layer2 = nn.Conv2d(in_channels=1024, out_channels=2, kernel_size=1, stride=1, padding=0, bias=False, padding_mode = 'circular')
    
    def forward(self, x,A):
        """Apply one step with operator ``A`` (a T.SenseModel-like callable with ``adjoint=``).

        Returns a tensor viewed as (-1, 1, args.dim, args.dim).
        """

        # Reinterpret x as (-1, patch, patch, 1, 2); presumably the trailing 2 is
        # real/imag — TODO confirm against utils.transforms.
        x = x.view(-1,args.patch,args.patch,1,2)
        # NOTE(review): mutates the caller's A in place. UnrolledNet0DC passes the
        # same A to every block, so dim 1 is squeezed again on each call; confirm
        # this is safe if A.weights has further singleton dims.
        A.weights = A.weights.squeeze(1)
        grad_x = A(A(x), adjoint=True) - x
        # Step of size 2 along grad_x.
        out = x -2*grad_x
        #out = x
        # NOTE(review): view (not permute) from (..., 1, 2) to (B, 2, patch, patch)
        # reinterprets memory rather than moving the re/im axis to dim 1 — confirm intended.
        out = out.view(-1,2,args.patch,args.patch)
        out = self.layer2(self.act1(self.layer1(out)))
        x = x.view(-1,2,args.patch,args.patch)
        # Residual connection around the conv block.
        out_dc = out + x
        # NOTE(review): reshaped with args.dim rather than args.patch; with
        # dim=320 and patch=80 this only works when the element count divides
        # evenly — confirm intended.
        out_dc = out_dc.view(-1,1,args.dim,args.dim)
        return out_dc

class UnrolledNet0(nn.Module):
    """Chain of ``args.num_unrolled`` ConvNet0 blocks applied in sequence.

    With ``args.share_weights`` every position holds the same ConvNet0 instance;
    otherwise each position gets a freshly constructed block.
    """

    def __init__(self):
        super().__init__()
        n = args.num_unrolled
        blocks = [ConvNet0()] * n if args.share_weights else [ConvNet0() for _ in range(n)]
        self.convblocks = nn.ModuleList(blocks)

    def forward(self, x):
        out = x
        for block in self.convblocks:
            out = block(out)
        return out
    
class UnrolledNet0DC(nn.Module):
    """Chain of ``args.num_unrolled`` ConvNet1DC blocks, each called with operator ``A``.

    Despite the name, the blocks are ConvNet1DC. With ``args.share_weights`` every
    position holds the same instance.
    """

    def __init__(self):
        super().__init__()
        n = args.num_unrolled
        blocks = [ConvNet1DC()] * n if args.share_weights else [ConvNet1DC() for _ in range(n)]
        self.convblocks = nn.ModuleList(blocks)

    def forward(self, x,A):
        out = x
        for block in self.convblocks:
            out = block(out, A)
        return out
    

#network
class ConvNet1(nn.Module):
    """Non-residual conv block: 5x5 conv (1 -> 128, no bias) -> activation -> 1x1 conv (128 -> 1)."""

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Conv2d(1, 128, kernel_size=5, stride=1, padding=2, bias=False)
        self.act1 = activations[act_type]
        self.layer2 = nn.Conv2d(128, 1, kernel_size=1, stride=1, padding=0, bias=False)

    def forward(self, x):
        features = self.layer1(x)
        features = self.act1(features)
        return self.layer2(features)

class UnrolledNet1(nn.Module):
    """Chain of ``args.num_unrolled`` ConvNet1 blocks (shared instance if ``args.share_weights``)."""

    def __init__(self):
        super().__init__()
        n = args.num_unrolled
        blocks = [ConvNet1()] * n if args.share_weights else [ConvNet1() for _ in range(n)]
        self.convblocks = nn.ModuleList(blocks)

    def forward(self, x):
        out = x
        for block in self.convblocks:
            out = block(out)
        return out

#network
class ConvNet2(nn.Module):
    """Non-residual conv block: 3x3 conv (1 -> 1024) -> activation -> 3x3 conv (1024 -> 1), no biases."""

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Conv2d(1, 1024, kernel_size=3, stride=1, padding=1, bias=False)
        self.act1 = activations[act_type]
        self.layer2 = nn.Conv2d(1024, 1, kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, x):
        features = self.layer1(x)
        features = self.act1(features)
        return self.layer2(features)
    
class UnrolledNet2(nn.Module):
    """Chain of ``args.num_unrolled`` ConvNet2 blocks (shared instance if ``args.share_weights``)."""

    def __init__(self):
        super().__init__()
        n = args.num_unrolled
        blocks = [ConvNet2()] * n if args.share_weights else [ConvNet2() for _ in range(n)]
        self.convblocks = nn.ModuleList(blocks)

    def forward(self, x):
        out = x
        for block in self.convblocks:
            out = block(out)
        return out

def preprocess(img):
    """Replace every image in the batch by its rank-``args.lowrank`` PCA approximation.

    Each image's rows are treated as samples for sklearn's PCA. The image is
    overwritten in place with ``inverse_transform(fit_transform(image))``.
    Reads the module-level ``args`` global.

    Args:
        img (torch.Tensor): batch of single-channel images, (N, 1, H, W) or (N, H, W).

    Returns:
        torch.Tensor: the low-rank batch viewed as (N, 1, H, W). It may share
        storage with ``img``, as before.
    """
    # Collapse to (N, H, W) while keeping the batch dimension. The previous
    # torch.squeeze also removed the batch dim when N == 1, so img[i, :, :] failed.
    img = img.reshape(img.size(0), img.size(-2), img.size(-1))
    vis = False  # debug switch: save each low-rank image as a PNG under args.log_dir
    for i in range(img.size(0)):
        img_cpu = img[i].cpu().detach().numpy()
        pca = PCA(n_components=args.lowrank)
        scores = pca.fit_transform(img_cpu)
        pca_inv = pca.inverse_transform(scores)
        img[i] = torch.from_numpy(pca_inv)

        if vis:
            pca_inv -= pca_inv.min()
            pca_inv /= pca_inv.max()
            fig, axs = plt.subplots(1,1)
            axs.imshow(pca_inv.reshape(28,28),cmap='gray')
            plt.savefig(args.log_dir+'/LR_input_np'+args.problem+'_'+str(args.noise_std)+'.png')

    return img.view(img.size(0), 1, img.size(1), img.size(2))

def downsample(img):
    """Low-pass downsample 320x320 images to ``args.patch`` x ``args.patch``.

    The image is reinterpreted as (-1, 320, 320, 1, 2) and transformed with
    ``T.fft2``. The central ``args.patch`` x ``args.patch`` k-space window is
    kept and transformed back with ``T.ifft2``. Reads the module-level ``args``
    global.

    Returns:
        torch.Tensor: shaped (-1, 2, patch, patch) if ``args.multi_chan``,
        else (-1, 1, patch, patch).
    """
    full = 320  # this helper is hard-wired to 320x320 inputs
    img = img.view(-1, full, full, 1, 2)
    imgf = T.fft2(img)
    # Centre window of exactly `patch` samples. This is identical to the previous
    # 160 +/- patch/2 slice for even patch; for odd patch that slice was
    # one sample short and the reshape below failed.
    start = full // 2 - args.patch // 2
    stop = start + args.patch
    cut_imgf = imgf[:, start:stop, start:stop, :, :]
    new_img = T.ifft2(cut_imgf)
    chans = 2 if args.multi_chan else 1
    return torch.reshape(new_img, (-1, chans, args.patch, args.patch))

def data_consistency(output,zf):
    """Replace the k-space of ``output`` with that of ``zf`` wherever ``zf`` has signal.

    Both tensors are viewed as (-1, patch, patch, 1, 2) and transformed with
    ``T.fft2``. Reads the module-level ``args`` global.

    Returns:
        torch.Tensor: the corrected image, shaped (-1, 2, patch, patch) if
        ``args.multi_chan``, else (-1, 1, patch, patch).
    """
    out = output.view(-1, args.patch, args.patch, 1, 2)
    outf = T.fft2(out)
    zff = T.fft2(zf.view(-1, args.patch, args.patch, 1, 2))
    # A location counts as "sampled" when its complex magnitude over the trailing
    # (assumed real/imag) axis is non-negligible. The previous `zff > 1e-8`
    # compared signed components, so sampled entries with negative real/imag
    # parts were never restored.
    sampled = zff.norm(dim=-1, keepdim=True) > 1e-8
    outf = torch.where(sampled, zff, outf)
    new_img = T.ifft2(outf)
    chans = 2 if args.multi_chan else 1
    return torch.reshape(new_img, (-1, chans, args.patch, args.patch))
    
def loss_primal(yhat, y, model, args):
    """Primal objective: 0.5 * MSE(yhat, y) + (beta/2) * sum of squared layer weight norms.

    Args:
        yhat (torch.Tensor): network output.
        y (torch.Tensor): target.
        model: a module with ``layer1``/``layer2`` weights, or with a
            ``convblocks`` list of such modules. With ``args.parallel`` the
            wrapped module is read from ``model.module``.
        args: namespace providing ``parallel`` and ``beta``.

    Returns:
        torch.Tensor: scalar loss.
    """
    model_prox = model.module if args.parallel else model

    # term 1: per-pixel data fidelity (MSELoss averages over batch and pixels)
    loss = 0.5 * nn.MSELoss()(yhat, y)

    # term 2: weight decay on every block's layer1/layer2. Dispatch on the
    # model's structure rather than on args.num_unrolled. The old check crashed
    # for single-block unrolled models and for plain blocks when num_unrolled > 1.
    # Shared-weight chains still count the shared weights once per position.
    if hasattr(model_prox, 'convblocks'):
        blocks = model_prox.convblocks
    else:
        blocks = [model_prox]
    for block in blocks:
        loss = loss + args.beta / 2 * torch.norm(block.layer1.weight) ** 2
        loss = loss + args.beta / 2 * torch.norm(block.layer2.weight) ** 2

    return loss

def train(model, data):
    """Train ``model`` with Adam, evaluating on ``data.testloader`` after every epoch.

    Relies on module-level globals set in ``__main__``: ``args`` and ``device``.

    Args:
        model: network called as ``model(x)``, or ``model(x, A)`` when
            ``args.data_cons`` is set.
        data: object with ``trainloader`` and ``testloader`` iterables. Each batch
            is indexed as ``datum[0]`` (target), ``datum[1]`` (zero-filled input,
            read only in 'recon' mode), ``datum[2]`` (noise) and, with
            ``args.data_cons``, ``datum[3]`` (maps) / ``datum[4]`` (mask).

    Returns:
        torch.Tensor: the loss of the last training batch.

    Raises:
        ValueError: if ``args.mode`` is neither 'den' nor 'recon'.
    """
    trainset = data.trainloader
    testset = data.testloader
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    #scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 100, 0.1)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           verbose=args.verbose,
                                                           factor=0.1,
                                                           patience=20,
                                                           eps=1e-8)
    # Centre crop applied to the noise; matches the 160 +/- patch/2 window.
    half = int(args.patch / 2)
    crop = slice(160 - half, 160 + half)
    writer = SummaryWriter(log_dir=args.log_dir)

    try:
        for epoch in range(args.primal_epochs):
            total_loss = 0
            for i, datum in enumerate(trainset, 0):

                img = datum[0].to(device)
                # zf exists only in recon mode; keep it None otherwise so the
                # fastMRI reshaping below no longer hits an undefined name.
                zf = datum[1].to(device) if args.mode == 'recon' else None
                noise = datum[2].to(device)
                A = None
                if args.data_cons:
                    maps = datum[3].to(device)
                    mask = datum[4].to(device)
                    A = T.SenseModel(maps, weights=mask)

                if args.dataset == 'fastmri':
                    chans = 2 if args.multi_chan else 1
                    img = img.view(-1, chans, args.dim, args.dim)
                    noise = noise.view(-1, chans, args.dim, args.dim)
                    if zf is not None:
                        zf = zf.view(-1, chans, args.dim, args.dim)

                    if args.patch < 320:
                        img = downsample(img)
                        if zf is not None:
                            zf = downsample(zf)
                        noise = noise[:, :, crop, crop]

                if args.lowrank < 28:
                    img = preprocess(img)

                if circular_pad:
                    img = nn.functional.pad(img, 2*(pad_size,)+2*(pad_size,), mode='circular')

                if args.mode == 'den':
                    noisy_img = img + noise
                elif args.mode == 'recon':
                    noisy_img = zf
                else:
                    raise ValueError("unknown args.mode {!r}; expected 'den' or 'recon'".format(args.mode))

                # ===================forward=====================
                if args.data_cons:
                    output = model(noisy_img, A)
                else:
                    output = model(noisy_img)
                loss = loss_primal(output, img, model, args)
                total_loss += loss.item()

                train_psnr = psnr(output, img)

                # ===================backward====================
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                writer.add_scalar('Train_Loss', loss.item(),  epoch*len(trainset) + i)
                writer.add_scalar('Train_PSNR', train_psnr.item(), epoch*len(trainset) + i)
            writer.add_scalar('Train_Loss_Epoch', total_loss/len(trainset),  epoch)
            test_loss = test(model, testset, writer, epoch)
            # ReduceLROnPlateau must be stepped with the monitored metric. The
            # previous scheduler.step(epoch) used the epoch index as the "loss".
            scheduler.step(test_loss)

            # ===================log========================
            print('epoch [{}/{}], train_loss:{:.4f}, test_loss:{:.4f}'.format(epoch+1, args.primal_epochs, total_loss/len(trainset), test_loss))

            if epoch % 5 == 0:
                # Visualise the last training batch as magnitude images. The
                # removed multi_chan `noisy_img[i]` indexing was overwritten
                # immediately and could raise IndexError.
                vis_noisy_img = torch.reshape(noisy_img, (-1, args.patch, args.patch, 2))
                vis_output = torch.reshape(output, (-1, args.patch, args.patch, 2))
                vis_img = torch.reshape(img, (-1, args.patch, args.patch, 2))
                residual = vis_img - vis_noisy_img
                vis_residual = cplx.abs(residual)
                vis_noisy_img = cplx.abs(vis_noisy_img)
                vis_output = cplx.abs(vis_output)
                vis_img = cplx.abs(vis_img)
                all_imgs = torch.cat((vis_img, vis_output, vis_noisy_img, vis_residual), dim=2)
                vis_image(all_imgs,writer,epoch, args,4,'primal_train_allimgs')
                vis_image(vis_img,writer,epoch, args,1,'primal_train_gt')
                vis_image(vis_output,writer,epoch*int(noisy_img.shape[0]/2)+i,args,1,'primal_train_output')
                vis_image(vis_noisy_img,writer,epoch*int(noisy_img.shape[0]/2)+i,args,1,'primal_train_input')
    finally:
        writer.close()

    return loss

def test(model, testset, writer, epoch):
    """Evaluate ``model`` on ``testset`` and log mean loss/PSNR to ``writer``.

    Relies on module-level globals set in ``__main__``: ``args`` and ``device``.
    Items are indexed as in ``train``: ``datum[0]`` target, ``datum[1]`` zero-filled
    input (read only in 'recon' mode), ``datum[2]`` noise, and ``datum[3]`` /
    ``datum[4]`` maps and mask with ``args.data_cons``.

    Returns:
        float: the average of ``loss_primal`` over ``len(testset)`` items.

    Raises:
        ValueError: if ``args.mode`` is neither 'den' nor 'recon'.
    """
    test_loss = 0
    test_psnr = 0
    # Centre crop applied to the noise; matches the 160 +/- patch/2 window.
    half = int(args.patch / 2)
    crop = slice(160 - half, 160 + half)
    with torch.no_grad():
        for i, datum in enumerate(testset, 0):

            img = datum[0].to(device)
            # zf exists only in recon mode; None otherwise (previously the fastMRI
            # branch referenced an undefined zf in 'den' mode).
            zf = datum[1].to(device) if args.mode == 'recon' else None
            noise = datum[2].to(device)
            A = None
            if args.data_cons:
                maps = datum[3].to(device)
                mask = datum[4].to(device)
                A = T.SenseModel(maps, weights=mask)

            if args.dataset == 'fastmri':
                chans = 2 if args.multi_chan else 1
                img = img.view(-1, chans, args.dim, args.dim)
                noise = noise.view(-1, chans, args.dim, args.dim)
                if args.patch < 320:
                    img = downsample(img)
                    noise = noise[:, :, crop, crop]
                if zf is not None:
                    # NOTE(review): unlike train(), zf is always forced to one
                    # channel and downsampled here, even if patch == 320 or
                    # multi_chan is set — confirm intended.
                    zf = downsample(zf.view(-1, 1, args.dim, args.dim))
            if args.lowrank < 28:
                img = preprocess(img)

            if circular_pad:
                img = nn.functional.pad(img, 2*(pad_size,)+2*(pad_size,), mode='circular')

            if args.mode == 'den':
                noisy_img = img + noise
            elif args.mode == 'recon':
                noisy_img = zf
            else:
                raise ValueError("unknown args.mode {!r}; expected 'den' or 'recon'".format(args.mode))

            # ===================forward=====================
            if args.data_cons:
                output = model(noisy_img, A)
            else:
                output = model(noisy_img)
            loss = loss_primal(output, img, model, args)
            test_loss += loss.item()/len(testset)
            test_psnr += psnr(output,img)/len(testset)

            # Every 20th item: log target | output | input magnitudes side by side.
            # (The removed multi_chan `noisy_img[i]` indexing was dead and could
            # raise IndexError.)
            if i % 20 == 0:
                cur_psnr = psnr(output,img)
                vis_noisy_img = torch.reshape(noisy_img, (-1, args.patch, args.patch, 2))
                vis_output = torch.reshape(output, (-1, args.patch, args.patch, 2))
                vis_img = torch.reshape(img, (-1, args.patch, args.patch, 2))
                vis_noisy_img = cplx.abs(vis_noisy_img)
                vis_output = cplx.abs(vis_output)
                vis_img = cplx.abs(vis_img)
                all_imgs = torch.cat((vis_img, vis_output, vis_noisy_img), dim=2)
                vis_image(all_imgs,writer,epoch,args,3,'primal_test_allimgs'+str(i)+'_'+str(cur_psnr))
        writer.add_scalar('Test_Loss', test_loss, epoch)
        writer.add_scalar('Test_PSNR', test_psnr, epoch)
    return test_loss


if __name__ == "__main__":
    
    """ This is executed when run from the command line """

    def _str2bool(value):
        """argparse type for boolean flags.

        ``type=bool`` turns any non-empty string (even "False") into True.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))

    parser = argparse.ArgumentParser()
    # Required positional argument
    parser.add_argument("--alternating", help="Whether to use alternating (default) or joint minimization of dual variables", default=False, type=_str2bool)
    parser.add_argument("--optimizer", help='SGD or Adam', default='Adam')
    parser.add_argument("--lr", help='Initial Learning Rate', default=1e-3, type=float)
    parser.add_argument("--momentum", help='Momentum for SGD', default=0.9, type=float)
    parser.add_argument("--primal_epochs", help='Number of epochs to train for the primal', default=200, type=int)
    parser.add_argument("--epochs", help='Number of epochs to train if training jointly, or number of epochs per variable descent if alternating', default=200, type=int)
    parser.add_argument("--alternations", help='Number of alternations if training alternating', default=10, type=int)
    parser.add_argument("--bs", help='Batch size', default=100, type=int)
    parser.add_argument('--dataset', help='Dataset', default='fastmri', type=str)
    parser.add_argument('--beta', help='Regularization parameter', default=1e-5, type=float)
    parser.add_argument('--rho', help='Hinge loss parameter (penalizes infeasible solutions)', default=1e-8, type=float)
    parser.add_argument('--rho_factor', help='Factor by which rho increases each epoch (to force feasibility)', default=1.1, type=float)
    parser.add_argument('--device', help='Device to use', default='cuda', type=str)
    parser.add_argument('--verbose', help='Show intermediate steps', default=True, type=_str2bool)
    parser.add_argument('--modeldir', help='Directory to save final model', default='./models', type=str)
    parser.add_argument('--outputstr', help='String output', default='dual_model.pkl', type=str)
    parser.add_argument('--datadir', help='Directory of dataset', default='', type=str)
    parser.add_argument('--P', help='number of random relu patterns to use', default=8000, type=int)
    parser.add_argument('--signpth', help='Path to sign patterns, load if exists, otherwise save', default='sign_patterns', type=str)
    parser.add_argument('--lowrank', help='Rank of the PCA preprocessing (values >= 28 disable it)', default=28, type=int)
    parser.add_argument('--printfreq', help='Frequency to print', default=5, type=int)
    parser.add_argument('--mode', help='recon (reconstruction) or den (denoising)', default='recon', type=str)
    parser.add_argument('--noise_std', help='standard deviation of noise to add if denoising', default=0.5, type=float)
    parser.add_argument('--use_saved_checkpoint', help='use saved checkpoint for initialization', default=True, type=_str2bool)
    parser.add_argument('--log_dir', help='Directory for tensorboard logs', default='./logs', type=str)
    parser.add_argument('--run', help='Name of the current run', default='run0', type=str)
    
    #MRI params
    parser.add_argument('--acc', default=3, type=int,
                        help='Range of acceleration rates to simulate in training data.')
    parser.add_argument('--calib-size', type=int, default=16, help='Size of calibration region')
    parser.add_argument('--sample-rate', type=float, default=0.05, help='Fraction of total volumes to include')
    parser.add_argument('--patch', default=80, type=int, help='Patch size for training images')
    parser.add_argument('--num-emaps', type=int, default=1, help='Number of ESPIRiT maps')
    parser.add_argument('--image-mode', type=str, default='separate', help='Magnitude train, separate train or channels train')
    parser.add_argument('--num_unrolled', type=int, default=1, help='Number of unrolled iterations')
    parser.add_argument('--share_weights', type=_str2bool, default=False, help='Weight sharing for unrolled')
    parser.add_argument('--multi_chan', type=_str2bool, default=False, help='Real and imag as channels')
    parser.add_argument('--data_cons', type=_str2bool, default=False, help='Whether to use data consistency')
    
    
    args = parser.parse_args()
    args.run = 'sub50_patch80_poisson_paper_results'
    args.log_dir = './primal_results/'+args.dataset+str(args.acc) +'_'+args.mode+'_'+args.run + '_val'
    args.load_dir = './primal_results/'+args.dataset+str(args.acc) +'_'+args.mode+'_'+args.run 
    args.datadir = './data'
    args.use_saved_checkpoint = False
    args.accelerations = [args.acc,args.acc]  
    
    # makedirs also creates a missing './primal_results' parent (os.mkdir did not).
    os.makedirs(args.log_dir, exist_ok=True)

    #identify GPU
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)
    
    if args.dataset == 'MNIST':
        data = MNISTData(60000, 10000)
        args.dim = 28
        args.chans = 1
    elif args.dataset == 'CIFAR-10':
        data = CIFAR10Data(50000, 10000)
        args.dim = 32
        args.chans = 3
    elif args.dataset == 'fastmri':
        args.datadir = ''
        train_loader, dev_loader, display_loader = create_data_loaders(args)
        data = FastMRIData(train_loader, dev_loader)
        args.dim = 320
        args.chans = 1
   
    # Keep only the last batch of the test loader. This appears to rely on the
    # loaders serving the whole test split as one batch — NOTE(review) confirm.
    if args.data_cons:
        for A_test, y_test,maps_test,mask_test in data.testloader:
            pass
    
        print('generating noisy test set')
        test_noise = gen_noise(A_test, args.noise_std)
        testset = PrepareData3DDC(A_test, y_test, test_noise,maps_test,mask_test)
        print(len(testset))
    else:
        for A_test, y_test in data.testloader:
            pass

        print('generating noisy test set')
        test_noise = gen_noise(A_test, args.noise_std)
        testset = PrepareData3D(A_test, y_test,test_noise)
        print(len(testset))

    model = ConvNet0()

    if torch.cuda.device_count() > 1:
        args.parallel = False
        #model = nn.DataParallel(model)    #parallelize in BATCH DIMENSION. it creates a model variable for each gpu and assigns a piece of data to each gpu -- it splits the models, computes per gpu, and aggregates the results
    else:
        args.parallel = False
        
    model.to(device)   #sends the net modules and buffers into cuda tensors

    #train $ save the model
    epoch = 25
    PATH = args.load_dir+'/.pth'
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(args.load_dir+'/model_'+str(epoch)+'.pth', map_location=device))
    #model.load_state_dict(torch.load(PATH))
    
    writer = SummaryWriter(log_dir=args.log_dir)
    test_loss = test(model,testset,writer,epoch)

    # visualize(model, data.testloader, args)
    # args.problem = 'dual'
    # cvx_loss = dual_main(args)
    # duality_gap = noncvx_loss.item() - cvx_loss.item()
    # print('duality_gap:',duality_gap)
    # writer = SummaryWriter(log_dir=args.log_dir)
    # writer.add_scalar('Duality_Gap', duality_gap)

    # #load and test
    # args.problem = 'primal'
    # model.load_state_dict(torch.load(PATH))
    # t0 = time.time()
    # test(model,data.testloader)
    # device = torch.device("cpu")
    # print('{} seconds'.format(time.time() - t0))
