""" Solve the strong dual of the two-layer convolutional, vector-output
ReLU network training problem for denoising. The output (second-layer)
convolution must be a 1x1 convolution.
"""

# libraries 
import os
import numpy as np
import time
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.autograd import Variable
import argparse
import torch.nn as nn
import time
from visualize import visualize, vis_image, to_img
from tensorboardX import SummaryWriter
#torch.cuda.set_device(1)

def add_noise(img, noise_std):
    """Return a copy of ``img`` corrupted by additive Gaussian noise.

    The noise is drawn with ``torch.randn`` (default dtype, CPU) and scaled
    by ``noise_std`` before being added to ``img``.
    """
    perturbation = torch.randn(img.size()) * noise_std
    return img + perturbation

def gen_noise(img, noise_std, dist):
    """Sample a noise tensor with the same shape as ``img``.

    Args:
        img: tensor whose ``size()`` determines the noise shape (any rank).
        noise_std: for ``dist == 'gaussian'`` the standard deviation of
            zero-mean Gaussian noise; otherwise exponential noise with rate
            ``1 / sqrt(noise_std)`` is drawn (mean and std ``sqrt(noise_std)``,
            so it is not zero-mean).
            NOTE(review): in the exponential branch ``noise_std`` acts like a
            variance -- confirm that is intended.
        dist: ``'gaussian'``; any other value selects exponential noise.

    Returns:
        A CPU tensor of shape ``img.size()`` in the default floating dtype.
    """
    if dist == 'gaussian':
        return torch.randn(img.size()) * noise_std

    # Build the rate in the default dtype: a rate derived from an np.float64
    # would make the noise (and every tensor it is added to) float64, which
    # then clashes with the float32 model weights.
    rate = torch.tensor([1.0 / np.sqrt(noise_std)], dtype=torch.get_default_dtype())
    m = torch.distributions.exponential.Exponential(rate)
    # sample() returns shape img.size() + rate.shape; drop that trailing
    # singleton axis regardless of how many dimensions img has.
    return m.sample(img.size()).squeeze(-1)

# functions for generating sign patterns
def check_if_already_exists(element_list, element):
    """Return True when ``element``, converted to a plain list, equals an entry of ``element_list``.

    ``element`` may be any iterable (e.g. a list of bools or a 1-D NumPy
    array); entries are compared with ``==`` against ``list(element)``.
    """
    candidate = list(element)
    return any(existing == candidate for existing in element_list)

class PrepareData(Dataset):
    """Map-style dataset pairing inputs ``X`` with targets ``y``.

    NumPy arrays are wrapped with ``torch.from_numpy``; tensors are stored
    as given. Item ``i`` is the pair ``(X[i], y[i])``.
    """

    def __init__(self, X, y):
        self.X = X if torch.is_tensor(X) else torch.from_numpy(X)
        self.y = y if torch.is_tensor(y) else torch.from_numpy(y)

    def __len__(self):
        # the dataset length is defined by the inputs
        return len(self.X)

    def __getitem__(self, idx):
        inputs = self.X[idx]
        targets = self.y[idx]
        return inputs, targets


class PrepareData3D(Dataset):
    """Map-style dataset returning aligned triples ``(X[i], y[i], z[i])``.

    Each of ``X``, ``y`` and ``z`` may be a NumPy array (wrapped with
    ``torch.from_numpy``) or a tensor (stored unchanged).
    """

    def __init__(self, X, y, z):
        self.X = X if torch.is_tensor(X) else torch.from_numpy(X)
        self.y = y if torch.is_tensor(y) else torch.from_numpy(y)
        self.z = z if torch.is_tensor(z) else torch.from_numpy(z)

    def __len__(self):
        # the dataset length is defined by the first component
        return len(self.X)

    def __getitem__(self, idx):
        first, second, third = self.X[idx], self.y[idx], self.z[idx]
        return first, second, third


def generate_sign_patterns(A, P, args, verbose=False):
    """Sample random hyperplanes and keep those that induce a new sign pattern.

    For each of ``P`` draws ``u ~ N(0, I_d)`` the boolean vector
    ``A_eff @ u >= 0`` is computed, where ``A_eff`` is a random subset of the
    rows of ``A`` (a fraction ``args.sign_subsample_factor`` of them). A draw
    is kept only if its pattern has not been produced by an earlier draw.

    Args:
        A: 2-D tensor of shape (n, d); each row is one (unfolded) input patch.
        P: number of random draws (upper bound on the number kept).
        args: namespace providing ``sign_subsample_factor``.
        verbose: print progress and timing information.

    Returns:
        List of NumPy arrays of shape (d, 1), one per unique sign pattern,
        in the order they were drawn.
    """
    start = time.time()
    d = A.shape[1]

    # Random subset of rows that defines the patterns.
    rows = np.arange(len(A))
    np.random.shuffle(rows)
    rows_selected = rows[:int(args.sign_subsample_factor * len(rows))]
    A_eff = A[rows_selected]

    # Patterns are stored as packed bytes in a set: O(1) membership tests and
    # ~1 bit per row, instead of a Python list of bools per pattern scanned
    # linearly for every draw.
    seen_patterns = set()
    u_vector_list = []  # random vectors that generated the unique patterns

    for _ in range(P):
        # Draw in the default dtype (keeps the RNG stream), then match A's
        # dtype/device so float64 inputs do not raise a dtype mismatch.
        u = torch.randn(d, 1)

        with torch.no_grad():
            mask = (A_eff @ u.to(device=A_eff.device, dtype=A_eff.dtype) >= 0)[:, 0]
        key = np.packbits(mask.cpu().numpy()).tobytes()

        if key not in seen_patterns:
            seen_patterns.add(key)
            u_vector_list.append(u.cpu().numpy())

            if verbose and len(u_vector_list) % 10 == 0:
                print('currently generated', len(u_vector_list), 'sign patterns')

    if verbose:
        print("Number of unique sign patterns generated: " + str(len(seen_patterns)))
        print('Generating sign patterns took', time.time() - start, 'seconds')
    return u_vector_list

# nonconvex model architecture
class ConvNetwork(nn.Module):
    """Two-layer non-convex denoiser: k x k conv (no bias) -> ReLU -> 1x1 conv (no bias).

    Args:
        H: number of hidden channels.
        num_channels: input and output channels.
        kernel_size: spatial size of the first convolution.
        padding: zero padding of the first convolution.

    ``forward`` returns ``(output, hidden_activations)``.
    """

    def __init__(self, H, num_channels=1, kernel_size=3, padding=1):
        super().__init__()
        first_conv = nn.Conv2d(num_channels, H, kernel_size=kernel_size, padding=padding, bias=False)
        self.layer1 = nn.Sequential(first_conv, nn.ReLU())
        self.layer2 = nn.Conv2d(H, num_channels, 1, bias=False)

    def forward(self, x):
        hidden = self.layer1(x)
        return self.layer2(hidden), hidden
    
# dual model: trains both sets of dual weights (w and v) jointly
class JointLayer(torch.nn.Module):
    """Dual model that trains both banks of dual filters (``w`` and ``v``) jointly.

    Args:
        kernel_size: spatial size k of every filter.
        padding: zero padding used by every convolution.
        num_neurons: number of sign patterns P (filters per bank).
        u_vectors: array-like of shape (P, 1, k, k); fixed filters whose
            responses ``>= 0`` define the sign patterns D_i.
    """

    def __init__(self, kernel_size, padding, num_neurons, u_vectors):
        super(JointLayer, self).__init__()
        self.padding = padding
        self.kernel_size = kernel_size
        # dual variables, initialised at zero
        self.v = torch.nn.Parameter(data=torch.zeros(num_neurons, 1, self.kernel_size, self.kernel_size), requires_grad=True)
        self.w = torch.nn.Parameter(data=torch.zeros(num_neurons, 1, self.kernel_size, self.kernel_size), requires_grad=True)

        # frozen sign-pattern generators (requires_grad=False, never updated)
        self.h = torch.nn.Parameter(data=torch.Tensor(u_vectors), requires_grad=False)

    def forward(self, x, args):
        """Return ``(y_pred, penalty)`` for a batch ``x``.

        ``y_pred`` is ``sum_i (D_i X w_i - D_i X v_i)`` summed over the P
        pattern channels (keepdim). ``penalty`` is the total hinge violation
        of the cone constraints ``(2 D_i - I) X w_i >= 0`` and
        ``(2 D_i - I) X v_i >= 0``.

        ``args`` is accepted for backward compatibility and is not used.
        """
        conv = torch.nn.functional.conv2d
        sign_patterns = (conv(x, self.h, padding=self.padding) >= 0).int()  # N x P x H' x W'

        Xw = conv(x, self.w, padding=self.padding)
        Xv = conv(x, self.v, padding=self.padding)

        DXw = torch.mul(sign_patterns, Xw)
        DXv = torch.mul(sign_patterns, Xv)

        # Zero on the activations' own device/dtype: a tensor moved to
        # args.device lives on one GPU and breaks DataParallel replicas on
        # other GPUs. torch.max with a broadcast zero keeps the original math.
        zero = Xw.new_zeros(())
        relu_term_w = torch.max(-2 * DXw + Xw, zero)
        relu_term_v = torch.max(-2 * DXv + Xv, zero)
        penalty = torch.sum(relu_term_w) + torch.sum(relu_term_v)

        DXwv = DXw - DXv  # N x P x H' x W'
        y_pred = torch.sum(DXwv, dim=1, keepdim=True)  # N x 1 x H' x W'

        return y_pred, penalty

def loss_func_cvxproblem(yhat, y, model, penalty, beta, rho, args):
    """Regularised objective of the convex dual program.

    ``0.5 * MSE(yhat, y) + beta * sum_i (||w_i||_2 + ||v_i||_2) + rho * penalty``
    where ``w_i`` / ``v_i`` are the individual dual filters (group lasso:
    one Euclidean norm per filter).

    Args:
        yhat, y: prediction and target tensors.
        model: ``JointLayer`` or a ``DataParallel`` wrapping one
            (``args.parallel`` selects which).
        penalty: scalar hinge penalty returned by the model.
        beta: group-lasso weight.
        rho: weight of the constraint penalty.
        args: namespace providing ``parallel``.
    """
    model_prox = model.module if args.parallel else model

    criterion = nn.MSELoss()
    # term 1: data fit
    loss = 0.5 * criterion(yhat, y)

    # term 2: group lasso. Flatten every axis except the filter axis so each
    # row is exactly one filter; the previous transpose(1, 3) +
    # reshape(-1, P) mixed entries from different filters in one group.
    w = model_prox.w
    v = model_prox.v
    w_norms = torch.norm(w.reshape(w.shape[0], -1), dim=1)
    v_norms = torch.norm(v.reshape(v.shape[0], -1), dim=1)
    loss = loss + beta * torch.sum(w_norms)
    loss = loss + beta * torch.sum(v_norms)

    # term 3: constraint violation
    loss = loss + rho * penalty

    return loss

def validation_loss(yhat, y, model, beta, args):
    """Primal objective of the non-convex network.

    ``0.5 * MSE(yhat, y) + beta/2 * ||W1||_F^2 + beta/2 * ||W2||_F^2`` where
    ``W1`` is ``layer1[0].weight`` and ``W2`` is ``layer2.weight``. ``model``
    is unwrapped via ``.module`` when ``args.parallel`` is set.
    """
    data_term = 0.5 * nn.MSELoss()(yhat, y)

    net = model.module if args.parallel else model
    first_weight = net.layer1[0].weight
    second_weight = net.layer2.weight

    return (data_term
            + beta/2 * torch.norm(first_weight)**2
            + beta/2 * torch.norm(second_weight)**2)

def joint_minimization(train_loader, u_vector_list, args, train_dataset, net_train_loader,
                       test_dataset, test_loader, writer):
    """Train the dual ``JointLayer`` and return its filters as NumPy arrays.

    Each epoch iterates ``net_train_loader`` (batches of ``(clean, label,
    noise)``) and minimises ``loss_func_cvxproblem`` on the noisy inputs.
    It then steps a ReduceLROnPlateau scheduler on the mean epoch cost and
    multiplies ``args.rho`` by ``args.rho_factor`` (``args`` is mutated).
    When verbose, every ``args.printfreq`` epochs the current dual solution
    is evaluated as a primal network with ``validate_model``.

    ``train_loader`` is accepted for backward compatibility and is not used.

    Returns:
        ``(w_value, v_value)``: NumPy copies of the dual filter banks.
    """
    def _unwrap(m):
        # DataParallel keeps the real module under ``.module``
        return m.module if args.parallel else m

    d = int(args.kernel_size**2)
    model = JointLayer(args.kernel_size, args.padding, num_neurons=args.P, u_vectors=u_vector_list)
    if args.parallel:
        # split each batch across the visible GPUs and gather the outputs
        model = nn.DataParallel(model)
    model = model.to(args.device)

    if args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           verbose=args.verbose,
                                                           factor=0.5,
                                                           eps=1e-12,
                                                           patience=30)

    if args.verbose:
        print('training convex program')

    for i in range(args.epochs):
        avg_cost = 0
        avg_penalty = 0
        num_batches = 0
        for _x, _, _z in net_train_loader:
            noisy_img = (_x + _z).to(args.device)
            _y = _x.detach().to(args.device)

            prediction, penalty = model(noisy_img, args)
            # DataParallel returns one penalty per replica; reduce to a scalar
            penalty = torch.sum(penalty)
            cost = loss_func_cvxproblem(prediction, _y, model, penalty, args.beta, args.rho, args)
            avg_cost += cost.item()
            avg_penalty += penalty.item()
            num_batches += 1

            optimizer.zero_grad()
            cost.backward()
            optimizer.step()

        # an empty loader yields zero averages instead of an undefined index
        denom = max(num_batches, 1)
        epoch_cost = avg_cost / denom

        if args.verbose and i % args.printfreq == 0:
            print(i, 'cost', epoch_cost)
            print(i, 'penalty', avg_penalty / denom)

            model_prox = _unwrap(model)
            w_value = model_prox.w.detach().cpu().numpy()
            v_value = model_prox.v.detach().cpu().numpy()
            validate_model(w_value, v_value, train_dataset, net_train_loader, test_dataset, test_loader, writer, i, args, d)

        scheduler.step(epoch_cost)
        # progressively tighten the feasibility penalty
        args.rho *= args.rho_factor
        if args.verbose:
            print('current rho', args.rho)

    model_prox = _unwrap(model)
    w_value = model_prox.w.detach().cpu().numpy()
    v_value = model_prox.v.detach().cpu().numpy()

    return w_value, v_value

def _dual_to_primal(w_value, v_value, epsilon):
    """Convert dual filter banks into primal first- and second-layer weights.

    Each dual filter ``x_i`` (an entry of ``w_value`` or ``v_value``) with
    Frobenius norm ``r_i`` of its channel 0 becomes one hidden unit. Its
    first-layer filter is ``x_i / (sqrt(r_i) + epsilon)``. Its output weight
    is ``+sqrt(r_i)`` for ``w`` filters and ``-sqrt(r_i)`` for ``v`` filters.

    Assumes both banks have shape (P, 1, k, k) as produced by ``JointLayer``.

    Returns:
        ``(u_primal, v_primal)`` float32 tensors of shapes (2P, 1, k, k) and
        (1, 2P, 1, 1).
    """
    w_value = np.asarray(w_value)
    v_value = np.asarray(v_value)
    # Frobenius norm of channel 0 of every filter, vectorised over filters
    w_scale = np.sqrt(np.linalg.norm(w_value[:, 0].reshape(len(w_value), -1), axis=1))
    v_scale = np.sqrt(np.linalg.norm(v_value[:, 0].reshape(len(v_value), -1), axis=1))

    # one scale per filter, broadcast over the remaining axes
    bshape = (-1,) + (1,) * (w_value.ndim - 1)
    first_layer = np.concatenate([w_value / (w_scale + epsilon).reshape(bshape),
                                  v_value / (v_scale + epsilon).reshape(bshape)])
    second_layer = np.concatenate([w_scale, -v_scale])

    u_primal = torch.tensor(first_layer, dtype=torch.float32)
    v_primal = torch.tensor(second_layer, dtype=torch.float32).reshape(1, -1, 1, 1)
    return u_primal, v_primal


def _evaluate_primal(model, loader, beta, args):
    """Average ``validation_loss`` over ``loader`` (batches of ``(clean, label, noise)``).

    Returns ``(loss, prediction, noisy_img, target)``. The last three come
    from the final batch, or are ``None`` if the loader is empty, in which
    case ``loss`` is 0.
    """
    total_loss = 0
    prediction = noisy_img = target = None
    with torch.no_grad():
        for _x, _, _z in loader:
            noisy_img = (_x + _z).to(args.device)
            target = _x.to(args.device)
            prediction, _ = model(noisy_img)
            total_loss += validation_loss(prediction, target, model, beta, args) / len(loader)
    return total_loss, prediction, noisy_img, target


def validate_model(w_value, v_value, train_dataset, train_loader, test_dataset, test_loader, writer, epoch, args, d, save=False):
    """Build the primal ``ConvNetwork`` from dual filters and evaluate it.

    The network is evaluated with ``validation_loss`` on ``train_loader``
    and ``test_loader``. Losses and example images are logged to ``writer``
    at step ``epoch``. When ``save`` is set, the state dict is written to
    ``args.modeldir/args.outputstr``.

    ``train_dataset``, ``test_dataset`` and ``d`` are accepted for backward
    compatibility and are not used.

    Returns:
        The average training loss (0 if ``train_loader`` is empty).
    """
    epsilon = 1e-12
    u_primal, v_primal = _dual_to_primal(w_value, v_value, epsilon)

    # one hidden unit per dual filter (2P in total)
    model = ConvNetwork(u_primal.shape[0], 1, args.kernel_size, args.padding)
    model.layer1[0].weight = nn.Parameter(data=u_primal)
    model.layer2.weight = nn.Parameter(data=v_primal)

    if args.parallel:
        # split each batch across the visible GPUs and gather the outputs
        model = nn.DataParallel(model)
    model = model.to(args.device)

    total_loss, prediction, noisy_img, _y = _evaluate_primal(model, train_loader, args.beta, args)
    print('Model training loss', float(total_loss))

    writer.add_scalar('train_loss/dual', float(total_loss), epoch)
    if prediction is not None:
        vis_image(prediction, writer, epoch, args, 'dual_train_output')
        vis_image(noisy_img, writer, epoch, args, 'dual_train_x')
        vis_image(_y, writer, epoch, args, 'dual_train_y')

    test_total_loss, prediction, noisy_img, _y = _evaluate_primal(model, test_loader, args.beta, args)
    print('Model testing loss', float(test_total_loss))

    if save:
        if args.verbose:
            print('saving model')
        torch.save(model.state_dict(), os.path.join(args.modeldir, args.outputstr))

    writer.add_scalar('test_loss/dual', float(test_total_loss), epoch)
    if prediction is not None:
        vis_image(prediction, writer, epoch, args, 'dual_test_output')
        vis_image(noisy_img, writer, epoch, args, 'dual_test_x')
        vis_image(_y, writer, epoch, args, 'dual_test_y')

    print('number of non-zero dual filters', (torch.abs(v_primal) > epsilon).int().sum().item())

    return total_loss

def main(args):
    """Run the denoising experiment for the convex dual program.

    The steps are:
    - Load MNIST and add noise.
    - Unfold the noisy images into k x k patches.
    - Load cached sign-pattern generators, or sample and cache them.
    - Train the dual with ``joint_minimization``.
    - Evaluate and save the resulting primal network.

    Returns:
        The final training loss reported by ``validate_model``.
    """
    args.problem = 'dual'

    # makedirs also creates missing parent directories
    os.makedirs(args.modeldir, exist_ok=True)

    if args.dataset != 'MNIST':
        raise ValueError('other datasets not yet supported')

    args.dim = 28
    args.chans = 1

    normalize = transforms.Normalize((0.1307,), (0.3081,))
    transform = transforms.Compose([transforms.ToTensor(), normalize])
    train_dataset = datasets.MNIST(args.datadir, train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST(args.datadir, train=False, download=True, transform=transform)

    # single-batch loaders used only to materialise the data as tensors
    dummy_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=int(len(train_dataset) * args.dataset_subsample_factor), shuffle=False,
        pin_memory=True, sampler=None)
    dummy_test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=len(test_dataset), shuffle=False,
        pin_memory=True, sampler=None)

    if args.verbose:
        print('creating train dataset')
    A, y = next(iter(dummy_loader))
    print('original shape', A.shape)
    A_orig = A
    y_orig = y
    train_noise = gen_noise(A, args.noise_std, args.dist)

    A = A + train_noise

    # unfold the noisy images into a tall matrix with one k x k patch per row
    unfold_operator = nn.Unfold(kernel_size=args.kernel_size, padding=args.padding)
    A_unfolded = unfold_operator(A).transpose(1, 2)
    A = A_unfolded.reshape((A_unfolded.shape[0] * A_unfolded.shape[1], -1))

    print('unfolded shape', A.shape)
    d = A.shape[1]

    # clean pixel values, one row per pixel
    y = A_orig.permute(0, 2, 3, 1).reshape((-1, A_orig.shape[1]))
    print('labels unfolded shape', y.shape)

    ds_train = PrepareData(X=A, y=y)
    ds_train = DataLoader(ds_train, batch_size=args.bs * (len(A) // len(A_orig)), shuffle=True)

    train_dataset = PrepareData3D(X=A_orig, y=y_orig, z=train_noise)
    train_loader = DataLoader(train_dataset, batch_size=args.bs, shuffle=True)

    if args.verbose:
        print('creating test dataset')

    A_test, y_test = next(iter(dummy_test_loader))
    test_noise = gen_noise(A_test, args.noise_std, args.dist)
    test_dataset = PrepareData3D(X=A_test, y=y_test, z=test_noise)
    test_loader = DataLoader(test_dataset, batch_size=args.bs, shuffle=False)

    sign_dir = os.path.join(args.signpth, args.dataset)
    sign_file = os.path.join(sign_dir, '{}_{}_{}_{}_u_vector_list.npy'.format(
        args.P, args.kernel_size, args.noise_std, args.dist))

    if os.path.exists(sign_file):
        if args.verbose:
            print('loading sign patterns...')
        u_vector_list = np.load(sign_file)
    else:
        if args.verbose:
            print('generating sign patterns...')
        u_vector_list = generate_sign_patterns(A, args.P, args, args.verbose)
        if len(u_vector_list) != args.P:
            raise ValueError('only {} unique sign patterns found, but P={} were requested'.format(
                len(u_vector_list), args.P))

        if args.verbose:
            print('reshaping u vectors...')
        # cache layout: (d, P), one column per u vector
        u_vector_list = np.asarray(u_vector_list).reshape((args.P, d)).T

        # always cache the patterns; only the progress message depends on verbosity
        if args.verbose:
            print('saving sign patterns...')
        os.makedirs(sign_dir, exist_ok=True)
        np.save(sign_file, u_vector_list)

    # Transpose the (d, P) cache back to (P, d) before reshaping, so that each
    # conv filter is exactly one sampled u vector. Reshaping the (d, P) array
    # directly would interleave entries of different u vectors.
    u_vector_list = u_vector_list.T.reshape((args.P, 1, args.kernel_size, args.kernel_size))

    writer = SummaryWriter(log_dir=args.log_dir)

    # use DataParallel whenever more than one GPU is visible
    args.parallel = torch.cuda.device_count() > 1

    print(args)
    if args.verbose:
        print('Starting training...')

    w_value, v_value = joint_minimization(ds_train, u_vector_list, args, train_dataset, train_loader,
                                          test_dataset, test_loader, writer)

    total_loss = validate_model(w_value, v_value, train_dataset, train_loader, test_dataset, test_loader,
                                writer, args.epochs, args, A.shape[1], save=True)

    return total_loss

if __name__ == "__main__":
    # Executed only when this file is run as a script.

    def _str2bool(value):
        """Parse a command-line boolean.

        ``type=bool`` turns every non-empty string (including "False") into
        True, so the text is parsed explicitly.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1', 'on'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0', 'off'):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))

    parser = argparse.ArgumentParser()

    # Command-line options (all optional)
    parser.add_argument("--optimizer", help='SGD or Adam', default='Adam')
    parser.add_argument("--lr", help='Initial Learning Rate', default=5e-6, type=float)
    parser.add_argument("--momentum", help='Momentum for SGD', default=0.9, type=float)
    parser.add_argument("--epochs", help='Number of epochs to train', default=100, type=int)
    parser.add_argument("--bs", help='Batch size', default=500, type=int)
    parser.add_argument('--dataset', help='Dataset', default='MNIST', type=str)
    parser.add_argument('--beta', help='Regularization parameter', default=1e-5, type=float)
    parser.add_argument('--rho', help='Hinge loss parameter (penalizes infeasible solutions)', default=1e-4, type=float)
    parser.add_argument('--rho_factor', help='Factor by which rho increases each epoch (to force feasibility)', default=1.0, type=float)
    parser.add_argument('--device', help='Device to use', default='cuda', type=str)
    parser.add_argument('--verbose', help='Show intermediate steps', default=True, type=_str2bool)
    parser.add_argument('--modeldir', help='Directory to save final model', default='./models', type=str)
    parser.add_argument('--outputstr', help='String output', default='1by1_conv_dual_model.pkl', type=str)
    parser.add_argument('--datadir', help='Directory of dataset', default='', type=str)
    parser.add_argument('--P', help='number of random relu patterns to use', default=5000, type=int)
    parser.add_argument('--signpth', help='Path to sign patterns, load if exists, otherwise save', default='1by1_sign_patterns', type=str)
    parser.add_argument('--printfreq', help='Frequency to print', default=1, type=int)
    parser.add_argument('--noise_std', help='standard deviation of noise to add if denoising', default=0.5, type=float)
    parser.add_argument('--use_saved_checkpoint', help='use saved checkpoint for initialization', default=True, type=_str2bool)
    parser.add_argument('--log_dir', help='Directory for tensorboard logs', default='./logs', type=str)
    parser.add_argument('--kernel_size', help='Kernel size for convolution of 1st layer', default=3, type=int)
    parser.add_argument('--padding', help='Padding for convolution of 1st layer', default=1, type=int)
    parser.add_argument('--sign_subsample_factor', help='Factor by which to subsample the training set to determine sign patterns', default=1.0, type=float)
    parser.add_argument('--dataset_subsample_factor', help='Factor by which to subsample the dataset for training', default=0.01, type=float)
    parser.add_argument('--dist', type=str, default='gaussian')

    args = parser.parse_args()
    main(args)
