""" Solve the strong dual of the two-layer convolutional vector-output
ReLU network training problem for denoising and reconstruction. The output
convolution must be a 1x1 convolution.
"""

# libraries 
import os
import numpy as np
import time
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.autograd import Variable
import argparse
import torch.nn as nn
import time
from visualize import visualize, vis_image, to_img, psnr
from tensorboardX import SummaryWriter
import matplotlib.pyplot as plt
#torch.cuda.set_device(1)

#Import MRI utils
from utils import transforms as T
from utils import subsample as ss
from utils import complex_utils as cplx

# import custom classes
from utils.datasets import SliceData
from utils.subsample import RandomMaskFunc, PoissonLoadMaskFunc

class DataTransform:
    """
    Data transformer that turns one fastMRI slice into a ``(target, zero-filled)``
    pair for the denoising / reconstruction experiments.
    """

    def __init__(self, mask_func, args, use_seed=True, normalization='magnitude'):
        """
        Args:
            mask_func (utils.subsample.MaskFunc): A function that can create a mask of
                appropriate shape.
            args: Parsed command-line arguments. Not used by the transform; kept so
                existing call sites keep working.
            use_seed (bool): If true, the mask seed is derived from the file name, so
                the mask depends only on the file (assuming ``mask_func`` is
                deterministic for a given seed).
            normalization (str): How the intensity scale is chosen.
                ``'magnitude'`` (default): the k-th largest pixel magnitude of the
                zero-filled SENSE reconstruction, with k = 5% of the pixels.
                ``'calib'``: mean power of a 10x10 centre crop of the masked k-space,
                multiplied by ``10**2 / kspace.size(-3) / kspace.size(-2)``.

        Raises:
            ValueError: If ``normalization`` is not one of the options above.
        """
        if normalization not in ('magnitude', 'calib'):
            raise ValueError("normalization must be 'magnitude' or 'calib', got %r" % (normalization,))
        self.mask_func = mask_func
        self.use_seed = use_seed
        self.normalization = normalization

    def __call__(self, kspace, maps, target, attrs, fname, slice):
        """
        Args:
            kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
                data or (rows, cols, 2) for single coil data.
            maps (numpy.array): Sensitivity maps passed to ``T.SenseModel``.
            target (numpy.array): Target image.
            attrs (dict): Acquisition related information. Unused.
            fname (str): File name; seeds the mask when ``use_seed`` is set.
            slice (int): Serial number of the slice. Unused.
        Returns:
            (tuple): ``(target, zf_image)`` where ``target`` is the scaled target with
            its leading batch dimension removed, and ``zf_image`` is the adjoint SENSE
            (zero-filled) reconstruction of the masked, scaled k-space. ``zf_image``
            keeps the leading singleton batch dimension.
        """
        # Convert everything from numpy arrays to tensors (with a batch dimension).
        kspace = cplx.to_tensor(kspace).unsqueeze(0)
        maps = cplx.to_tensor(maps).unsqueeze(0)
        target = cplx.to_tensor(target).unsqueeze(0)

        # Apply mask in k-space.
        seed = tuple(map(ord, fname)) if self.use_seed else None
        masked_kspace, mask = ss.subsample(kspace, self.mask_func, seed, mode='2D')

        # Out-of-place division: ``to_tensor`` may share memory with the caller's
        # arrays, and an in-place ``/=`` would silently rescale them.
        scale = self._compute_scale(kspace, masked_kspace, maps)
        masked_kspace = masked_kspace / scale
        target = target / scale

        # Zero-filled image: adjoint SENSE operator weighted by the sampling mask.
        A = T.SenseModel(maps, weights=mask)
        zf_image = A(masked_kspace, adjoint=True)

        # Get rid of batch dimension on the target only.
        target = target.squeeze(0)
        return target, zf_image

    def _compute_scale(self, kspace, masked_kspace, maps):
        """Return the scalar used to normalize the masked k-space and the target."""
        if self.normalization == 'calib':
            # Power within the calibration region, rescaled by its share of k-space.
            calib_size = 10
            calib_region = cplx.center_crop(masked_kspace, [calib_size, calib_size])
            scale = torch.mean(cplx.abs(calib_region) ** 2)
            return scale * (calib_size ** 2 / kspace.size(-3) / kspace.size(-2))

        # Magnitude of the (unweighted) zero-filled reconstruction: take the
        # smallest of the top 5% magnitudes (~95th percentile). ``k`` is at least 1
        # so tiny inputs do not produce an empty ``topk`` result.
        image = T.SenseModel(maps)(masked_kspace, adjoint=True)
        magnitude_vals = cplx.abs(image).reshape(-1)
        k = max(1, int(round(0.05 * magnitude_vals.numel())))
        return torch.min(torch.topk(magnitude_vals, k).values)

def create_datasets(args):
    """Build the fastMRI ``SliceData`` splits.

    Each split gets its own ``PoissonLoadMaskFunc(args.accelerations,
    args.calib_size)`` and a seeded ``DataTransform``; the data live under
    ``<args.datadir>/train`` and ``<args.datadir>/val``.

    Returns:
        (dev_data, train_data) -- note the validation split comes first.
    """
    # Both masks are built before either dataset, as before.
    masks = {split: PoissonLoadMaskFunc(args.accelerations, args.calib_size)
             for split in ('train', 'val')}

    def _split(name):
        return SliceData(
            root=os.path.join(str(args.datadir), name),
            transform=DataTransform(masks[name], args, use_seed=True),
            sample_rate=args.sample_rate,
        )

    train_data = _split('train')
    dev_data = _split('val')
    return dev_data, train_data


def create_data_loaders(args):
    """Build full-batch train/validation loaders plus a small display loader.

    The train and validation loaders use a batch size equal to the size of
    their split, so one iteration yields the whole split. The display loader
    holds roughly 16 evenly spaced validation slices.

    Args:
        args: Namespace consumed by ``create_datasets``. If it has a
            ``num_workers`` attribute, that value is used for every loader
            (default 8).

    Returns:
        (train_loader, dev_loader, display_loader)
    """
    dev_data, train_data = create_datasets(args)
    # The stride must be >= 1: with fewer than 16 validation slices,
    # ``len(dev_data) // 16`` is 0 and ``range`` would raise ValueError.
    stride = max(1, len(dev_data) // 16)
    display_data = [dev_data[i] for i in range(0, len(dev_data), stride)]
    print(len(dev_data))

    num_workers = getattr(args, 'num_workers', 8)
    train_loader = DataLoader(
        dataset=train_data,
        # Full-batch loading; DataLoader rejects batch_size=0 for an empty split.
        batch_size=max(1, len(train_data)),
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )
    dev_loader = DataLoader(
        dataset=dev_data,
        batch_size=max(1, len(dev_data)),
        num_workers=num_workers,
        pin_memory=True,
    )
    display_loader = DataLoader(
        dataset=display_data,
        batch_size=16,
        num_workers=num_workers,
        pin_memory=True,
    )
    return train_loader, dev_loader, display_loader

def downsample(img, size=320, patch=None):
    """Downsample square complex images by keeping the centre of their 2-D FFT.

    Args:
        img: Tensor holding ``size x size`` complex images as trailing real/imag
            pairs; it is reshaped to ``(-1, size, size, 1, 2)``.
        size: Side length of the input images (default 320).
        patch: Side length of the kept frequency window and of the output.
            Defaults to the module-level ``args.patch`` parsed in ``__main__``.

    Returns:
        Tensor reshaped to ``(-1, 1, patch, patch)``. NOTE(review): this reshape
        folds the trailing real/imag axis into the spatial grid rather than
        into channels; callers reshape back to ``(-1, patch, patch, 2)``.

    Raises:
        ValueError: If ``patch`` is not in ``(0, size]``.

    NOTE(review): assumes ``T.fft2`` returns a centred spectrum, so the central
    crop keeps the low frequencies -- confirm against ``utils.transforms``.
    """
    if patch is None:
        patch = args.patch
    if not 0 < patch <= size:
        raise ValueError('patch must be in (0, %d], got %r' % (size, patch))
    # Centred window; for size=320, patch=80 this is 120:200.
    start = (size - patch) // 2
    img = img.reshape(-1, size, size, 1, 2)
    imgf = T.fft2(img)
    cut_imgf = imgf[:, start:start + patch, start:start + patch, :, :]
    new_img = T.ifft2(cut_imgf)
    return torch.reshape(new_img, (-1, 1, patch, patch))

def add_noise(img, noise_std):
    """Return ``img`` plus i.i.d. Gaussian noise with standard deviation ``noise_std``.

    The noise is drawn with ``img``'s shape and on ``img``'s device, so CUDA
    inputs no longer fail with a device mismatch. On the CPU, the values drawn
    are exactly those of ``torch.randn(img.size())``.
    """
    noise = torch.randn(img.size(), device=img.device) * noise_std
    return img + noise

def gen_noise(img, noise_std):
    """Return Gaussian noise shaped like ``img`` with standard deviation ``noise_std``.

    The noise is created on ``img``'s device. On the CPU, the values drawn are
    exactly those of ``torch.randn(img.size())``.
    """
    return torch.randn(img.size(), device=img.device) * noise_std

# functions for generating sign patterns
def check_if_already_exists(element_list, element):
    """Return True if ``list(element)`` equals some entry of ``element_list``.

    ``element`` may be any iterable (for example a NumPy array or a list of
    bools); it is compared with ``==`` against each stored list in turn, so
    the cost is linear in ``len(element_list)``.
    """
    candidate = list(element)
    for existing in element_list:
        if existing == candidate:
            return True
    return False

class PrepareData(Dataset):
    """Map-style dataset pairing inputs ``X`` with targets ``y``.

    Each of ``X`` and ``y`` may be a tensor (stored as-is) or a NumPy array
    (wrapped with ``torch.from_numpy``, sharing memory).
    """

    def __init__(self, X, y):
        self.X = X if torch.is_tensor(X) else torch.from_numpy(X)
        self.y = y if torch.is_tensor(y) else torch.from_numpy(y)

    def __len__(self):
        n_samples = len(self.X)
        return n_samples

    def __getitem__(self, idx):
        sample, label = self.X[idx], self.y[idx]
        return sample, label


class PrepareData3D(Dataset):
    """Map-style dataset yielding aligned ``(X[i], y[i], z[i])`` triples.

    Each input may be a tensor (stored as-is) or a NumPy array (wrapped with
    ``torch.from_numpy``, sharing memory).
    """

    @staticmethod
    def _as_tensor(array):
        return array if torch.is_tensor(array) else torch.from_numpy(array)

    def __init__(self, X, y, z):
        self.X = self._as_tensor(X)
        self.y = self._as_tensor(y)
        self.z = self._as_tensor(z)

    def __len__(self):
        n_samples = len(self.X)
        return n_samples

    def __getitem__(self, idx):
        return tuple(column[idx] for column in (self.X, self.y, self.z))


def generate_sign_patterns(A, P, verbose=False, sign_subsample_factor=None):
    """Sample ``P`` random directions and keep those that induce new sign patterns.

    A random subset of the rows of ``A`` is drawn once. For each Gaussian
    vector ``u`` of shape ``(d, 1)``, the pattern ``A_sub @ u >= 0`` is
    computed, and ``u`` is kept only if that pattern has not been seen before.

    Args:
        A: 2-D tensor of shape ``(n, d)``.
        P: Number of random directions to sample (upper bound on the output length).
        verbose: Print progress every 100 kept vectors and a final summary.
        sign_subsample_factor: Fraction of rows of ``A`` used to evaluate
            patterns. Defaults to the module-level ``args.sign_subsample_factor``
            parsed in ``__main__``.

    Returns:
        List of NumPy arrays of shape ``(d, 1)``, one per unique pattern, in
        sampling order.
    """
    if sign_subsample_factor is None:
        sign_subsample_factor = args.sign_subsample_factor
    start = time.time()
    _, d = A.shape
    seen_patterns = set()   # hashable patterns for O(1) duplicate checks
    u_vector_list = []      # random vectors that generated the unique patterns

    rows = np.arange(len(A))
    np.random.shuffle(rows)
    rows_selected = rows[:int(sign_subsample_factor * len(rows))]
    A_eff = A[rows_selected]
    t0 = time.time()

    for _ in range(P):
        # Sampled on the CPU with the default dtype, as before; only the matmul
        # copy is moved to A's device/dtype.
        u = torch.randn(d, 1)
        with torch.no_grad():
            pattern = tuple((A_eff @ u.to(A_eff.device, A_eff.dtype) >= 0)[:, 0].tolist())

        if pattern in seen_patterns:
            continue
        seen_patterns.add(pattern)
        u_vector_list.append(u.cpu().data.numpy())

        if verbose and len(u_vector_list) % 100 == 0:
            print('currently generated', len(u_vector_list), 'sign patterns')
            print('{} seconds'.format(time.time() - t0))

    if verbose:
        print("Number of unique sign patterns generated: " + str(len(seen_patterns)))
        print('Generating sign patterns took', time.time() - start, 'seconds')
    return u_vector_list

# nonconvex model architecture
class ConvNetwork(nn.Module):
    """Nonconvex two-layer conv net: bias-free KxK conv + ReLU, then a bias-free 1x1 conv.

    ``forward`` casts its input to float32 and returns ``(output, hidden)``,
    where ``hidden`` is the post-ReLU activation with ``H`` channels.
    """

    def __init__(self, H, num_channels=1, kernel_size=3, padding=1):
        super().__init__()
        first_conv = nn.Conv2d(num_channels, H, kernel_size=kernel_size,
                               padding=padding, bias=False)
        self.layer1 = nn.Sequential(first_conv, nn.ReLU())
        self.layer2 = nn.Conv2d(H, num_channels, 1, bias=False)

    def forward(self, x):
        hidden = self.layer1(x.float())
        return self.layer2(hidden), hidden
    
# for training both sets of dual weights (v and w) together
class JointLayer(torch.nn.Module):
    """Convex-dual two-layer conv model with fixed sign-pattern generators.

    Parameters:
        v, w: Trainable ``(num_neurons, 1, kernel_size, kernel_size)`` filters,
            initialised to zero -- the two groups of dual weights.
        h: Frozen generator filters built from ``u_vectors``; the sign pattern
            of neuron i is ``D_i = 1[conv(x, h_i) >= 0]``.

    ``forward`` returns ``(y_pred, penalty)`` where
    ``y_pred = sum_i D_i * (conv(x, w_i) - conv(x, v_i))`` (summed over neurons,
    keeping a singleton channel dim), and ``penalty`` is the sum of squared
    positive parts of ``(1 - 2 D_i) * conv(x, .)`` for both ``w`` and ``v``.
    """

    def __init__(self, kernel_size, padding, num_neurons, u_vectors):
        super(JointLayer, self).__init__()
        self.padding = padding
        self.kernel_size = kernel_size
        self.v = torch.nn.Parameter(data=torch.zeros(num_neurons, 1, self.kernel_size, self.kernel_size), requires_grad=True)
        self.w = torch.nn.Parameter(data=torch.zeros(num_neurons, 1, self.kernel_size, self.kernel_size), requires_grad=True)
        # Stored as a frozen Parameter (not a buffer) so state_dict keys are unchanged.
        self.h = torch.nn.Parameter(data=torch.Tensor(u_vectors), requires_grad=False)

    def forward(self, x):
        x = x.float()
        conv = torch.nn.functional.conv2d
        sign_patterns = (conv(x, self.h, padding=self.padding) >= 0).int()  # N x P x H x W

        Xw = conv(x, self.w, padding=self.padding)
        Xv = conv(x, self.v, padding=self.padding)

        DXw = torch.mul(sign_patterns, Xw)
        DXv = torch.mul(sign_patterns, Xv)

        # (1 - 2D) * Xu is positive exactly where u's pre-activation sign disagrees
        # with the pattern. relu keeps only those violations and, unlike a max against
        # a constant tensor on ``args.device``, works on whatever device ``x`` is on.
        relu_term_w = torch.relu(Xw - 2 * DXw)
        relu_term_v = torch.relu(Xv - 2 * DXv)
        penalty = torch.sum(relu_term_w ** 2) + torch.sum(relu_term_v ** 2)

        DXwv = DXw - DXv  # N x P x H x W
        y_pred = torch.sum(DXwv, dim=1, keepdim=True)  # N x 1 x H x W

        return y_pred, penalty

def loss_func_cvxproblem(yhat, y, model, penalty, beta, rho, args):
    """Objective of the penalized convex dual problem.

        loss = 0.5 * MSE(yhat, y)
             + beta * sum_i (||w_i||_2 + ||v_i||_2)
             + rho * penalty

    The middle term is a group lasso with one group per neuron: ``w_i`` / ``v_i``
    are the whole filters ``model.w[i]`` / ``model.v[i]``. The previous version took
    the norm across neurons for each kernel tap (``dim=0``), which does not give
    per-neuron group sparsity.

    Args:
        yhat, y: Prediction and target tensors.
        model: ``JointLayer``-like module with ``w`` and ``v`` parameters, or a
            ``DataParallel`` wrapper of one when ``args.parallel`` is set.
        penalty: Hinge penalty returned by the model's forward pass.
        beta: Group-lasso weight.
        rho: Weight of the feasibility penalty.
        args: Namespace providing ``parallel``.
    """
    model_prox = model.module if args.parallel else model

    # term 1: data fit
    loss = 0.5 * nn.MSELoss()(yhat, y)

    # term 2: one L2 norm per neuron's filter (flatten every dim except the neuron dim)
    w_norms = torch.norm(model_prox.w.flatten(start_dim=1), dim=1)
    v_norms = torch.norm(model_prox.v.flatten(start_dim=1), dim=1)
    loss = loss + beta * torch.sum(w_norms)
    loss = loss + beta * torch.sum(v_norms)

    # term 3: penalty for violating the sign-pattern constraints
    loss = loss + rho * penalty
    return loss

def validation_loss(yhat, y, model, beta, args):
    """Half-MSE plus weight decay for the nonconvex ``ConvNetwork``.

        loss = 0.5 * MSE(yhat, y) + beta/2 * ||W1||_F^2 + beta/2 * ||W2||_F^2

    ``W1`` is ``layer1[0].weight`` and ``W2`` is ``layer2.weight``; when
    ``args.parallel`` is set, ``model`` is unwrapped via ``.module`` first.
    """
    net = model.module if args.parallel else model
    data_term = 0.5 * nn.MSELoss()(yhat, y)
    first_decay = beta / 2 * torch.norm(net.layer1[0].weight) ** 2
    second_decay = beta / 2 * torch.norm(net.layer2.weight) ** 2
    return data_term + first_decay + second_decay

def test(test_dataset, test_loader, writer, epoch, args, save=False):
    """Evaluate the nonconvex ``ConvNetwork`` checkpoint saved at ``epoch``.

    Loads ``<args.load_dir>/model_<epoch>.pth`` into a ``ConvNetwork`` with
    ``2 * args.P`` hidden channels and runs it over ``test_loader`` without
    gradients. It prints the batch-averaged ``validation_loss`` and logs it to
    TensorBoard as ``Test_Loss``. It also logs ``Test_PSNR``, which is only
    accumulated in ``recon`` mode. In ``recon`` mode every 20th batch is also
    rendered with ``vis_image``.

    Args:
        test_dataset: Unused; kept for call compatibility.
        test_loader: Yields ``(clean, zero_filled_or_unused, noise_or_unused)`` triples.
        writer: ``SummaryWriter`` to log to. If ``None``, a new one is created
            at ``args.log_dir``.
        epoch: Checkpoint index; also used as the TensorBoard step.
        args: Parsed arguments. ``args.dim`` is set to ``args.patch``.
        save: Unused; kept for call compatibility.
    """
    model = ConvNetwork(2 * args.P, 1, args.kernel_size, args.padding)
    ckpt_path = args.load_dir + '/model_' + str(epoch) + '.pth'
    # Load onto the CPU first so GPU-saved checkpoints also open on CPU-only
    # hosts; the model is moved to args.device below.
    model.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
    # Reuse the caller's writer instead of opening a second one on the same log dir.
    if writer is None:
        writer = SummaryWriter(log_dir=args.log_dir)
    print('model loaded')
    if args.parallel:
        # Splits each batch across GPUs and gathers the results.
        model = nn.DataParallel(model)
    model = model.to(args.device)

    test_total_loss = 0.0
    test_psnr = 0.0
    num_batches = len(test_loader)

    if args.mode in ('den', 'recon'):
        args.dim = args.patch

    with torch.no_grad():
        if args.mode == 'den':
            for _x, _, _z in test_loader:
                noisy_img = (_x + _z).to(args.device)
                _y = _x.to(args.device)

                prediction, _ = model(noisy_img)
                test_total_loss += validation_loss(prediction, _y, model, args.beta, args) / num_batches
        elif args.mode == 'recon':
            for i, (_x, _zf, _) in enumerate(test_loader):
                noisy_img = _zf.to(args.device)
                _y = _x.to(args.device)

                prediction, _ = model(noisy_img)
                test_total_loss += validation_loss(prediction, _y, model, args.beta, args) / num_batches

                cur_psnr = psnr(prediction, _y)
                test_psnr += cur_psnr / num_batches
                if i % 20 == 0:
                    zf_psnr = psnr(noisy_img, _y)
                    # Undo the (-1, 1, patch, patch) layout to get real/imag pairs back.
                    vis_noisy_img = cplx.abs(torch.reshape(noisy_img, (-1, args.patch, args.patch, 2)))
                    vis_output = cplx.abs(torch.reshape(prediction, (-1, args.patch, args.patch, 2)))
                    vis_img = cplx.abs(torch.reshape(_y, (-1, args.patch, args.patch, 2)))

                    all_imgs = torch.cat((vis_img, vis_output, vis_noisy_img), dim=2)
                    vis_image(all_imgs, writer, epoch, args, 3,
                              'dual_test_allimgs' + str(i) + '_' + str(cur_psnr) + '_' + str(zf_psnr))

    # float() also works when no batch ran and the total is still the float 0.0.
    print('Model testing loss', float(test_total_loss))

    writer.add_scalar('Test_Loss', test_total_loss, epoch)
    writer.add_scalar('Test_PSNR', test_psnr, epoch)
    


def _sign_pattern_path(args):
    """Return the ``.npy`` path used to cache sign-pattern generators for ``args``."""
    fname = (str(args.P) + '_' + args.down_mode + args.mode + 'acc' + str(args.accelerations[0])
             + '_ds_size' + str(args.sample_rate) + str(args.kernel_size) + '_u_vector_list.npy')
    return os.path.join(args.signpth, args.dataset, fname)


def _load_or_generate_sign_patterns(A, args):
    """Load cached sign-pattern generators, or sample new ones from ``A`` and cache them.

    ``A`` is an image batch of shape ``(N, 1, H, W)``. When sampling, it is
    unfolded into one row per ``kernel_size x kernel_size`` patch (with
    ``args.padding``), so each row is what a single conv filter sees.

    Returns:
        numpy.ndarray reshaped to ``(args.P, 1, args.kernel_size, args.kernel_size)``.
    """
    path = _sign_pattern_path(args)
    if os.path.exists(path):
        if args.verbose:
            print('loading sign patterns...')
        u_vector_list = np.load(path)
    else:
        if args.verbose:
            print('generating sign patterns...')
        d = args.kernel_size ** 2
        patches = torch.nn.functional.unfold(A.float(), args.kernel_size, padding=args.padding)
        patches = patches.transpose(1, 2).reshape(-1, d)  # one row per patch
        u_vector_list = generate_sign_patterns(patches, args.P, args.verbose)

        if args.verbose:
            print('reshaping u vectors...')
        # Requires exactly args.P unique patterns. The (d, P) layout is kept so
        # files cached in this layout are read back the same way.
        u_vector_list = np.asarray(u_vector_list).reshape((args.P, d)).T

        # Save unconditionally (previously only when verbose).
        if args.verbose:
            print('saving sign patterns...')
        os.makedirs(os.path.dirname(path), exist_ok=True)
        np.save(path, u_vector_list)

    # NOTE(review): reshaping the (d, P) array directly to (P, 1, k, k) does not
    # map each sampled vector to one filter; kept for consistency with cached
    # files -- confirm this is intended.
    return u_vector_list.reshape((args.P, 1, args.kernel_size, args.kernel_size))


def main(args):
    """Build the held-out split and evaluate a saved checkpoint with ``test``.

    ``args`` is updated in place (``datadir``, ``accelerations``, ``load_dir``,
    ``log_dir``, ``parallel``). Sign patterns are loaded from, or generated and
    written to, ``args.signpth``. They are not passed to ``test``, which
    evaluates the nonconvex ``ConvNetwork`` checkpoint.

    Raises:
        ValueError: If ``args.dataset`` is neither ``'MNIST'`` nor ``'fastmri'``.
    """
    os.makedirs(args.modeldir, exist_ok=True)

    args.datadir = './data'
    args.accelerations = [args.acc, args.acc]
    args.load_dir = './dual_results/fastmri' + str(args.acc) + '_sub50_patch80_poisson_paper_results'
    args.log_dir = args.load_dir + '_val'
    os.makedirs(args.log_dir, exist_ok=True)

    if args.dataset == 'MNIST':
        normalize = transforms.Normalize((0.1307,), (0.3081,))
        train_dataset = datasets.MNIST(
            args.datadir, train=True, download=True,
            transform=transforms.Compose([transforms.ToTensor(), normalize]))
        test_dataset = datasets.MNIST(
            args.datadir, train=False, download=True,
            transform=transforms.Compose([transforms.ToTensor(), normalize]))
        # Full-batch loaders: one iteration yields the whole split.
        dummy_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=len(train_dataset), shuffle=False,
            pin_memory=True, sampler=None)
        dummy_test_loader = torch.utils.data.DataLoader(
            test_dataset, batch_size=len(test_dataset), shuffle=False,
            pin_memory=True, sampler=None)
    elif args.dataset == 'fastmri':
        args.datadir = ''
        dummy_loader, dummy_test_loader, display_loader = create_data_loaders(args)
    else:
        raise ValueError('Unsupported dataset: ' + str(args.dataset))

    if args.verbose:
        print('creating test dataset')

    # The loaders are full-batch, so the first batch is the whole split.
    A_test, y_test = next(iter(dummy_test_loader))
    if args.dataset == 'fastmri':
        A_test = downsample(A_test.view(-1, 1, 320, 320))
        y_test = downsample(y_test.view(-1, 1, 320, 320))

    test_noise = gen_noise(A_test, args.noise_std)
    test_dataset = PrepareData3D(X=A_test, y=y_test, z=test_noise)
    test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)
    print(len(test_loader))

    # Not consumed below; kept for the load/cache side effect. Generated from the
    # test inputs because this script has no training set in memory.
    u_vector_list = _load_or_generate_sign_patterns(A_test, args)

    writer = SummaryWriter(log_dir=args.log_dir)

    # Multi-GPU (DataParallel) is disabled regardless of the device count.
    args.parallel = False

    print(args)
    if args.verbose:
        print('Starting training...')
    # Training (joint_minimization) is not run here; this script only evaluates
    # the checkpoint saved at a fixed epoch.
    epoch = 25
    test(test_dataset, test_loader, writer, epoch, args, save=False)

if __name__ == "__main__":
    # Executed when run from the command line. ``args`` is deliberately a
    # module-level name: several helpers above read it as a global.

    def _str2bool(value):
        """Parse a command-line boolean (``type=bool`` would turn "False" into True)."""
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))

    parser = argparse.ArgumentParser()

    # All arguments are optional flags with defaults.
    parser.add_argument("--optimizer", help='SGD or Adam', default='Adam')
    parser.add_argument("--lr", help='Initial Learning Rate', default=5e-7, type=float)
    parser.add_argument("--momentum", help='Momentum for SGD', default=0.9, type=float)
    parser.add_argument("--epochs", help='Number of epochs to train', default=100, type=int)
    parser.add_argument("--bs", help='Batch size', default=2, type=int)
    parser.add_argument('--dataset', help='Dataset', default='fastmri', type=str)
    parser.add_argument('--beta', help='Regularization parameter', default=1e-5, type=float)
    parser.add_argument('--rho', help='Hinge loss parameter (penalizes infeasible solutions)', default=1e-7, type=float)
    parser.add_argument('--rho_factor', help='Factor by which rho increases each epoch (to force feasibility)', default=1, type=float)
    parser.add_argument('--device', help='Device to use', default='cuda', type=str)
    parser.add_argument('--verbose', help='Show intermediate steps', default=True, type=_str2bool)
    parser.add_argument('--modeldir', help='Directory to save final model', default='./models', type=str)
    parser.add_argument('--outputstr', help='String output', default='1by1_conv_dual_model.pkl', type=str)
    parser.add_argument('--datadir', help='Directory of dataset', default='', type=str)
    parser.add_argument('--P', help='number of random relu patterns to use', default=5000, type=int)
    parser.add_argument('--signpth', help='Path to sign patterns, load if exists, otherwise save', default='1by1_sign_patterns', type=str)
    parser.add_argument('--printfreq', help='Frequency to print', default=1, type=int)
    parser.add_argument('--noise_std', help='standard deviation of noise to add if denoising', default=0.5, type=float)
    parser.add_argument('--use_saved_checkpoint', help='use saved checkpoint for initialization', default=True, type=_str2bool)
    parser.add_argument('--log_dir', help='Directory for tensorboard logs', default='./dual_results/', type=str)
    parser.add_argument('--kernel_size', help='Kernel size for convolution of 1st layer', default=7, type=int)
    parser.add_argument('--padding', help='Padding for convolution of 1st layer', default=3, type=int)
    parser.add_argument('--sign_subsample_factor', help='Factor by which to subsample the training set to determine sign patterns', default=0.1, type=float)
    # MRI params
    parser.add_argument('--mode', help='recon (reconstruction) or den (denoising)', default='recon', type=str)
    parser.add_argument('--acc', default=3, type=int,
                        help='Range of acceleration rates to simulate in training data.')
    parser.add_argument('--calib-size', type=int, default=16, help='Size of calibration region')
    parser.add_argument('--sample-rate', type=float, default=0.05, help='Fraction of total volumes to include')
    parser.add_argument('--resolution', default=320, type=int, help='Resolution of images')
    parser.add_argument('--num-emaps', type=int, default=1, help='Number of ESPIRiT maps')
    parser.add_argument('--patch', default=80, type=int, help='Patch size for training images')
    parser.add_argument('--down_mode', default='ds', type=str, help='ds for downsample, patch for patch')

    args = parser.parse_args()
    args.problem = 'dual'
    main(args)
