# In[1]:
import os

# libraries 
import numpy as np
import dill
import pickle
from datetime import datetime
import matplotlib
import matplotlib.pyplot as plt
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
import time
import scipy
from scipy.sparse.linalg import LinearOperator
from scipy.fftpack import dct
import torch
import sklearn.linear_model
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import argparse
import random




def parse_args(argv=None):
    """Parse command-line options and seed every random number generator.

    All options except ``--seed`` are stored as lists (``nargs``), so callers
    index them, e.g. ``ARGS.lr[0]`` for the non-convex run and ``ARGS.lr[1]``
    for the convex run.

    Args:
      argv: optional list of argument strings; ``None`` reads ``sys.argv[1:]``.
        Passing an explicit list lets the script be driven from a notebook,
        where ``sys.argv`` holds the kernel's own flags.

    Returns:
      argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser()
    # Optional options get list defaults because the script indexes them
    # (ARGS.dataset[0], ARGS.solver[1], ARGS.schedule[0]); a None default made
    # omitting them a TypeError. `choices` rejects values the script cannot
    # handle up front instead of failing much later with a NameError.
    parser.add_argument('--dataset', type=str, nargs=1, default=['CIFAR10'],
                        choices=['CIFAR10', 'CIFAR100', 'FMNIST'])
    parser.add_argument('--num_neurons', nargs=1, type=int, required=True)
    parser.add_argument('--num_path', nargs=1, type=int, required=True)
    parser.add_argument('--P', nargs=1, type=int, required=True)
    parser.add_argument('--n_epochs', nargs=2, type=int, required=True)
    parser.add_argument('--lr', type=float, nargs=2, required=True)
    parser.add_argument('--solver', type=str, nargs=2, default=['sgd', 'sgd'],
                        choices=['sgd', 'adam', 'adagrad', 'adadelta', 'LBFGS'])
    # 0: constant learning rate, 1: ReduceLROnPlateau, 2: ExponentialLR
    parser.add_argument('--schedule', type=int, nargs=1, default=[0], choices=[0, 1, 2])
    parser.add_argument('--save', nargs=1, type=int, required=True)
    parser.add_argument('--seed', type=int, default=42)
    args = parser.parse_args(argv)

    # seed python, numpy and torch so runs are reproducible
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    return args


ARGS = parse_args()  # parsed once when the script starts; this also seeds random/numpy/torch

# In[2]:

class PrepareData(Dataset):
    """Map-style dataset over paired inputs ``X`` and targets ``y``.

    NumPy arrays are wrapped with ``torch.from_numpy`` (sharing memory); values
    that are already tensors are stored unchanged.
    """

    @staticmethod
    def _as_tensor(data):
        # tensors pass through untouched, everything else is treated as a NumPy array
        return data if torch.is_tensor(data) else torch.from_numpy(data)

    def __init__(self, X, y):
        self.X = self._as_tensor(X)
        self.y = self._as_tensor(y)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        sample = (self.X[idx], self.y[idx])
        return sample


# Samples convolutional random layer weights to sample hyperplane arrangements
def generate_conv_sign_patterns(A2, P2, m1, verbose=False, kernel_size=3):
    """Sample random convolution-like first-layer vectors and dense second-layer vectors.

    Each of the ``m1`` first-layer vectors of each of the ``P2`` arrangements is
    a ``c*p1*p2`` vector that is zero except for one randomly placed
    ``kernel_size x kernel_size`` spatial window (across all channels) filled
    with standard-normal entries, i.e. a random filter applied at one location.
    Each second-layer vector is a dense standard-normal vector of length ``m1``.

    Args:
      A2: data of shape (n, c, p1, p2) (NumPy array or torch tensor); only its
        shape is used.
      P2: number of sampled arrangements.
      m1: number of first-layer neurons per arrangement.
      verbose: print the (c, p1, p2) input shape when True.
      kernel_size: side of the square window; 3 reproduces the previous
        hard-coded 3x3 filter.

    Returns:
      umat1: array of shape (c*p1*p2, m1, P2).
      umat2: array of shape (m1, P2).

    Raises:
      ValueError: if the window does not fit inside a p1 x p2 image.
    """
    _, c, p1, p2 = A2.shape
    if verbose:
        print((c, p1, p2))
    d = int(c * p1 * p2)
    fs = int(kernel_size)
    if not 1 <= fs <= min(p1, p2):
        raise ValueError("kernel_size={} does not fit in a {}x{} image".format(kernel_size, p1, p2))

    umat1 = np.zeros((d, m1, P2))
    umat2 = np.zeros((m1, P2))
    for i in range(P2):
        for j in range(m1):
            # top-left corner of the window; the order of RNG calls matches the
            # original implementation so seeded runs stay reproducible
            ind1 = np.random.randint(0, p1 - fs + 1)
            ind2 = np.random.randint(0, p2 - fs + 1)
            u1p = np.zeros((c, p1, p2))
            u1p[:, ind1:ind1 + fs, ind2:ind2 + fs] = np.random.normal(0, 1, (c * fs * fs, 1)).reshape(c, fs, fs)
            umat1[:, j, i] = u1p.reshape(d)
        umat2[:, i] = np.random.normal(0, 1, (m1, 1)).reshape(-1)
    return umat1, umat2

# Converts digits to one hot encoded representation
def one_hot(labels, num_classes=10):
    """Embed integer class labels as one-hot rows.

    Args:
      labels: (integer tensor) class labels, sized [N,].
      num_classes: (int) number of classes.

    Returns:
      (tensor) float encoded labels, sized [N, #classes], on the same device
      as ``labels``.
    """
    # Build the identity on the labels' device: callers pass labels that were
    # already moved to CUDA, and indexing a CPU tensor with CUDA indices is
    # rejected by recent PyTorch versions.
    eye = torch.eye(num_classes, device=labels.device)
    return eye[labels.long()]


# In[3]:
    
#=====================================STANDARD NON-CONVEX NETWORK=====================================

# Defines loss function
def loss_func_nonstandard(yest, y, model, beta, K):
    """Squared-error loss plus a path-norm penalty for the parallel network.

    loss = 0.5 * ||yest - y||_F^2
           + 0.5 * beta * sum_k ||W1[:,:,k] * W2[:,k]^T||_F * ||W3[k,:]||_1

    The result is summed over the batch, not averaged (callers divide by the
    batch size).

    Args:
      yest: predictions, shape (n, C).
      y: targets with the same shape as ``yest``.
      model: object exposing ``W1`` (d, m, K), ``W2`` (m, K) and ``W3`` (K, C).
      beta: regularization strength.
      K: number of parallel paths to penalize.

    Returns:
      Scalar tensor.
    """
    loss = 0.5 * torch.norm(yest - y, p='fro') ** 2
    for k in range(K):
        # scale each first-layer column by its second-layer weight
        scaled_w1 = torch.mul(model.W1[:, :, k], model.W2[:, k].reshape(1, -1))
        # The product of the two norms equals sqrt(||.||_F^2 * ||.||_1^2), but it
        # avoids the infinite derivative of sqrt at 0, which produced NaN
        # gradients once a path's weights became exactly zero.
        path_norm = torch.norm(scaled_w1, p='fro') * torch.norm(model.W3[k, :], p=1)
        loss = loss + 0.5 * beta * path_norm
    return loss

# Obtains validation results
def validation_nonstandard(model, testloader, beta, K, num_classes, device):
    """Evaluate the parallel network on every batch of ``testloader``.

    Args:
      model: network whose call returns (n, num_classes) scores.
      testloader: iterable of ``(x, label)`` batches.
      beta, K: forwarded to ``loss_func_nonstandard``.
      num_classes: number of classes for the one-hot targets.
      device: device the batches are moved to.

    Returns:
      (test_loss, test_correct): the loss summed over all batches, and the
      total number of correctly classified samples (a Python float; 0 for an
      empty loader).
    """
    test_loss = 0
    test_correct = 0

    # Inference only: no_grad avoids building and holding autograd graphs.
    with torch.no_grad():
        for _x, _y in testloader:
            _x = _x.to(device)
            _y = _y.to(device)
            _x = _x.view(_x.shape[0], -1)

            # a single forward pass per batch (the model used to be run twice)
            yhat = model(_x).float()
            loss = loss_func_nonstandard(yhat, one_hot(_y, num_classes).to(device), model, beta, K)
            test_loss += loss.item()
            test_correct += torch.eq(torch.argmax(yhat, dim=1), _y).float().sum().item()
    return test_loss, test_correct
 

# Defines the parallel ReLU network architecture
class Net(nn.Module):
    """Sum of ``K`` parallel ReLU paths.

    Path ``k`` computes the per-sample scalar ``relu(relu(x @ W1[:,:,k]) @ W2[:,k])``
    and multiplies it by the row ``W3[k,:]``; the K path outputs are summed.

    Args:
      D_in: flattened input dimension.
      num_neurons: first hidden-layer width of each path.
      K: number of parallel paths.
      sigma: scale of the Gaussian initialization.
      num_classes: output dimension.
    """

    def __init__(self, D_in, num_neurons, K, sigma, num_classes=10):
        super(Net, self).__init__()
        self.W1 = torch.nn.Parameter(data=sigma * torch.randn(D_in, num_neurons, K), requires_grad=True)
        self.W2 = torch.nn.Parameter(data=sigma * torch.randn(num_neurons, K), requires_grad=True)
        self.W3 = torch.nn.Parameter(data=sigma * torch.randn(K, num_classes), requires_grad=True)

    def forward(self, x):
        """Return the (n, num_classes) output for an input flattened to (n, D_in)."""
        x = x.reshape(x.size(0), -1)
        K = self.W1.size(2)
        num_classes = self.W3.size(1)
        # allocate on the input's device rather than hard-coding 'cuda', so the
        # network also runs on CPU
        y_pred = torch.zeros(x.shape[0], num_classes, device=x.device)
        for k in range(K):
            o_k = F.relu(torch.matmul(x, self.W1[:, :, k]))
            o_k = torch.matmul(o_k, self.W2[:, k])
            o_k = torch.reshape(F.relu(o_k), (-1, 1))
            o_k = torch.matmul(o_k, torch.reshape(self.W3[k, :], (1, -1)))
            y_pred = y_pred + o_k
        return y_pred


# Main function to solve optimization problem
def sgd_solver_pytorch_v2(ds, ds_test, num_epochs, num_neurons,K, beta, sigma,
                         learning_rate, batch_size, solver_type, schedule, 
                          LBFGS_param, verbose=False, 
                        num_classes=10, D_in=3*1024, test_len=10000, 
                          train_len=50000, device='cuda'):
    """Train the non-convex parallel ReLU network ``Net`` with a torch optimizer.

    Args:
      ds, ds_test: training / test loaders yielding ``(x, label)`` batches.
      num_epochs: number of passes over ``ds``.
      num_neurons, K, sigma: width, number of paths and init scale of ``Net``.
      beta: path-norm regularization strength.
      learning_rate: step size (LBFGS is created without it and uses its own default).
      batch_size: training batch size; only used to size the per-iteration arrays.
      solver_type: "sgd", "adam", "adagrad", "adadelta" or "LBFGS".
      schedule: 0 = constant lr, 1 = ReduceLROnPlateau on the last minibatch
        loss, 2 = ExponentialLR with gamma 0.99; both stepped once per epoch.
      LBFGS_param: (history_size, max_iter) for LBFGS.
      verbose: forwarded to ReduceLROnPlateau.
      num_classes, D_in: output dimension and flattened input dimension.
      test_len, train_len: dataset sizes used to normalize the test metrics and
        to size the per-iteration arrays.
      device: torch device string.

    Returns:
      (losses, accs, losses_test/test_len, accs_test/test_len, times, model):
      per-minibatch training loss/accuracy, per-epoch test loss/accuracy
      (index 0 is before training), wall-clock timestamps and the trained model.

    Raises:
      ValueError: for an unknown ``solver_type`` or a ``schedule`` above 2.
    """
    device = torch.device(device)

    # create the model
    model = Net(D_in, num_neurons, K, sigma, num_classes).to(device)

    # select optimizer
    if solver_type == "sgd":
        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)
    elif solver_type == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    elif solver_type == "adagrad":
        optimizer = torch.optim.Adagrad(model.parameters(), lr=learning_rate)
    elif solver_type == "adadelta":
        optimizer = torch.optim.Adadelta(model.parameters(), lr=learning_rate)
    elif solver_type == "LBFGS":
        optimizer = torch.optim.LBFGS(model.parameters(), history_size=LBFGS_param[0], max_iter=LBFGS_param[1])
    else:
        raise ValueError("unknown solver_type: {!r}".format(solver_type))

    # select a schedule for the learning rate (done before any training so an
    # invalid value fails immediately instead of after the first epoch)
    scheduler = None
    if schedule == 1:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                               verbose=verbose,
                                                               factor=0.5,
                                                               eps=1e-12)
    elif schedule == 2:
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.99)
    elif schedule > 0:
        raise ValueError("unknown schedule: {!r}".format(schedule))

    # arrays for saving the loss and accuracy
    losses = np.zeros((int(num_epochs*np.ceil(train_len / batch_size))))
    accs = np.zeros(losses.shape)
    losses_test = np.zeros((num_epochs+1))
    accs_test = np.zeros((num_epochs+1))
    times = np.zeros((losses.shape[0]+1))
    times[0] = time.time()

    model.eval()
    losses_test[0], accs_test[0] = validation_nonstandard(model, ds_test, beta, K, num_classes, device) # loss on the entire test set

    # main training loop
    iter_no = 0
    for i in range(num_epochs):
        model.train()

        for _x, _y in ds:
            _x = _x.to(device)
            _y = _y.to(device)
            targets = one_hot(_y, num_classes).to(device)

            # forward pass on the current weights; also gives the logged
            # minibatch loss and accuracy
            yhat = model(_x).float()
            loss = loss_func_nonstandard(yhat, targets, model, beta, K)/len(_y)
            correct = torch.eq(torch.argmax(yhat, dim=1), torch.squeeze(_y)).float().sum()/len(_y)

            if solver_type == "LBFGS":
                # LBFGS re-evaluates the objective several times per step, so
                # step() must be given a closure (calling it without one raises)
                def closure():
                    optimizer.zero_grad()
                    obj = loss_func_nonstandard(model(_x).float(), targets, model, beta, K)/len(_y)
                    obj.backward()
                    return obj
                optimizer.step(closure)
            else:
                optimizer.zero_grad() # zero the gradients on each pass before the update
                loss.backward() # backpropagate the loss through the model
                optimizer.step() # update the weights w.r.t. the loss

            losses[iter_no] = loss.item() # loss on the minibatch
            accs[iter_no] = correct.item() # fraction of correctly classified samples on the minibatch

            iter_no += 1
            times[iter_no] = time.time()

        model.eval()

        # get test loss and accuracy
        losses_test[i+1], accs_test[i+1] = validation_nonstandard(model, ds_test, beta, K, num_classes, device) # loss on the entire test set

        print("Epoch [{}/{}], loss: {} acc: {}, test loss: {} test acc: {}".format(i, num_epochs,
                np.round(losses[iter_no-1], 3), np.round(accs[iter_no-1], 3),
                np.round(losses_test[i+1]/test_len, 3), np.round(accs_test[i+1]/test_len, 3)))

        if scheduler is not None:
            if schedule == 1:
                # ReduceLROnPlateau monitors the last minibatch training loss
                scheduler.step(losses[iter_no-1])
            else:
                # ExponentialLR takes no metric: a positional value would be
                # treated as the deprecated `epoch` argument
                scheduler.step()

    return losses, accs, losses_test/test_len, accs_test/test_len, times, model


# In[4]:
#=====================================CONVEX NETWORK=====================================

# Defines the equivalent convex architecture
class custom_cvx_layer(torch.nn.Module):
    """Convex reformulation of the parallel ReLU network over sampled arrangements.

    For each arrangement ``i`` the forward pass fixes two masks computed from
    the input and the random vectors ``u_vectors1[:,:,i]`` / ``u_vectors2[:,i]``:

      * first layer  (n, m1): ``x @ u1_i >= 0``
      * second layer (n,):    ``relu(x @ u1_i) @ u2_i >= 0``

    and adds ``sum_j D2 * D1[:, j] * (x @ (v[j, i] - w[j, i]))`` to the output.
    """

    def __init__(self, d, num_neurons, P2, sigma, num_classes=10):
        """Create the variables ``v`` and ``w``, each of shape
        (num_neurons, P2, d, num_classes), initialized as ``sigma * N(0, 1)``."""
        super(custom_cvx_layer, self).__init__()
        # num_neurons x P2 x d x C
        self.v = torch.nn.Parameter(data=sigma*torch.randn(num_neurons,  P2, d, num_classes), requires_grad=True)
        self.w = torch.nn.Parameter(data=sigma*torch.randn(num_neurons,  P2, d, num_classes), requires_grad=True)

    def forward(self, x, u_vectors1, u_vectors2):
        """Return (n, C) predictions for ``x`` flattened to (n, d).

        ``u_vectors1`` is indexed as (d, m1, P2) and ``u_vectors2`` as (m1, P2).
        """
        x = x.view(x.shape[0], -1) # n x d
        # allocate on the input's device (was hard-coded to 'cuda')
        y_pred = torch.zeros(x.shape[0], self.v.shape[3], device=x.device)

        # iterate over this layer's own number of arrangements; the previous
        # code read a module-level global `P2` that only exists in the script
        for i in range(self.v.shape[1]):
            pre1 = torch.matmul(x, u_vectors1[:, :, i]) # n x m1, computed once
            sign_patterns1 = pre1 >= 0
            sign_patterns2 = torch.matmul(F.relu(pre1), u_vectors2[:, i]) >= 0 # n

            Xv_w = torch.matmul(x, self.v[:, i, :, :] - self.w[:, i, :, :]) # m1 x n x C
            DXv_w1 = torch.mul(sign_patterns1.unsqueeze(2), Xv_w.permute(1, 0, 2)) # n x m1 x C
            DXv_w2 = torch.mul(sign_patterns2.unsqueeze(1).unsqueeze(2), DXv_w1) # n x m1 x C
            y_pred = y_pred + torch.sum(DXv_w2, dim=1, keepdim=False)
        return y_pred

# Defines the loss function for the convex program
def loss_func_cvxproblem(yhat, y, model, _x, u_vectors1, u_vectors2, beta, rho, device):
    """Objective of the convex program: fit, group-norm penalty and constraint penalty.

    loss = 0.5 * ||yhat - y||_F^2
         + 0.5 * beta * sum over (arrangement, class) of the Frobenius norms of
           ``v`` and ``w`` taken over the (neuron, input) axes
         + rho * hinge penalties on the sign constraints of every arrangement

    The hinge terms are ``relu(X z - 2 D1 X z)`` for the first-layer mask D1 and
    the same form, applied to the neuron sum, for the second-layer mask D2
    (z = v or w): the amount by which ``(2D - I) X z >= 0`` is violated.

    Args:
      yhat, y: predictions and one-hot targets, shape (n, C).
      model: object exposing ``v`` and ``w`` indexed as (m1, P2, d, C).
      _x: input batch; flattened to (n, d).
      u_vectors1: first-layer arrangement vectors, indexed as (d, m1, P2).
      u_vectors2: second-layer arrangement vectors, indexed as (m1, P2).
      beta, rho: penalty weights; the constraint terms are skipped when rho <= 0.
      device: kept for interface compatibility; no longer used because the
        hinge is computed with ``F.relu`` instead of a freshly allocated zero.

    Returns:
      Scalar tensor, summed over the batch (not averaged).
    """
    _x = _x.view(_x.shape[0], -1)

    # term 1 training loss
    loss = 0.5 * torch.norm(yhat - y, p='fro')**2

    # term 2 regularization penalty
    loss = loss + 0.5*beta * torch.sum(torch.norm(model.v, dim=(0,2), p='fro'))
    loss = loss + 0.5*beta * torch.sum(torch.norm(model.w, dim=(0,2), p='fro'))

    # term 3 for penalizing the violated constraints
    if rho > 0:
        for i in range(u_vectors1.shape[2]):
            pre1 = torch.matmul(_x, u_vectors1[:, :, i]) # n x m1, computed once
            mask1 = (pre1 >= 0).unsqueeze(2) # n x m1 x 1
            mask2 = (torch.matmul(F.relu(pre1), u_vectors2[:, i]) >= 0).unsqueeze(1).unsqueeze(2) # n x 1 x 1

            Xv = torch.matmul(_x, model.v[:, i, :, :]).permute(1, 0, 2) # n x m1 x C
            Xw = torch.matmul(_x, model.w[:, i, :, :]).permute(1, 0, 2) # n x m1 x C
            DXv1 = torch.mul(mask1, Xv)
            DXw1 = torch.mul(mask1, Xw)

            # first-layer constraint violations
            loss = loss + rho * torch.sum(F.relu(Xv - 2*DXv1)) + rho * torch.sum(F.relu(Xw - 2*DXw1))

            # second-layer constraint violations on the neuron sums (n x C)
            DXv2 = torch.sum(torch.mul(mask2, DXv1), dim=1)
            DXw2 = torch.sum(torch.mul(mask2, DXw1), dim=1)
            relu_term21 = F.relu(torch.sum(DXv1, dim=1) - 2*DXv2)
            relu_term22 = F.relu(torch.sum(DXw1, dim=1) - 2*DXw2)

            loss = loss + rho * torch.sum(relu_term21) + rho * torch.sum(relu_term22)

    return loss


# Obtains validation results
def validation_cvxproblem(model, testloader, u_vectors1, u_vectors2, beta, rho, num_classes,  device):
    """Evaluate the convex model on every batch of ``testloader``.

    Args:
      model: convex layer called as ``model(x, u_vectors1, u_vectors2)``.
      testloader: iterable of ``(x, label)`` batches.
      u_vectors1, u_vectors2, beta, rho: forwarded to the model / objective.
      num_classes: number of classes for the one-hot targets.
      device: device the batches are moved to.

    Returns:
      (test_loss, test_correct): the convex objective summed over batches
      (its regularization terms are therefore added once per batch) and the
      total number of correct predictions as a Python float.
    """
    test_loss = 0
    test_correct = 0

    with torch.no_grad():
        for _x, _y in testloader:
            _x = _x.to(device)
            _y = _y.to(device)
            _x = _x.view(_x.shape[0], -1)

            # a single forward pass per batch (the model used to be run twice
            # with the first result discarded)
            yhat = model(_x, u_vectors1, u_vectors2).float()
            loss = loss_func_cvxproblem(yhat, one_hot(_y, num_classes= num_classes).to(device), model, _x, u_vectors1, u_vectors2, beta, rho, device)
            test_loss += loss.item()
            test_correct += torch.eq(torch.argmax(yhat, dim=1), _y).float().sum().item()

    return test_loss, test_correct

# Main function to solve the convex optimization problem
def sgd_solver_cvxproblem(ds, ds_test, num_epochs, P2, num_neurons, beta, sigma, 
                       learning_rate, batch_size, rho, u_vectors1, u_vectors2, 
                          solver_type,schedule, LBFGS_param, verbose=False,
                         n=50000, d=3072, num_classes=10, device='cuda', test_len=10000):
    """Train the convex reformulation (``custom_cvx_layer``) with a torch optimizer.

    Args:
      ds, ds_test: training / test loaders yielding ``(x, label)`` batches.
      num_epochs: number of passes over ``ds``.
      P2, num_neurons, sigma: number of arrangements, neurons per arrangement
        and init scale of the convex variables.
      beta, rho: regularization and constraint-penalty weights.
      learning_rate: step size (LBFGS is created without it and uses its own default).
      batch_size: training batch size; only used to size the per-iteration arrays.
      u_vectors1, u_vectors2: sampled arrangement vectors passed to the model.
      solver_type: "sgd", "adam", "adagrad", "adadelta" or "LBFGS".
      schedule: 0 = constant lr, 1 = ReduceLROnPlateau on the last minibatch
        loss, 2 = ExponentialLR with gamma 0.99; both stepped once per epoch.
      LBFGS_param: (history_size, max_iter) for LBFGS.
      verbose: forwarded to ReduceLROnPlateau.
      n, d, num_classes: training-set size, input dimension, output dimension.
      device: torch device string.
      test_len: test-set size used to normalize the test metrics (10000 was
        previously hard-coded).

    Returns:
      (accs, accs_test/test_len, times, losses, losses_test/test_len).

    Raises:
      ValueError: for an unknown ``solver_type`` or a ``schedule`` above 2.
    """
    device = torch.device(device)

    # create the model
    model = custom_cvx_layer(d, num_neurons, P2, sigma, num_classes).to(device)

    # select optimizer
    if solver_type == "sgd":
        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)
    elif solver_type == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    elif solver_type == "adagrad":
        optimizer = torch.optim.Adagrad(model.parameters(), lr=learning_rate)
    elif solver_type == "adadelta":
        optimizer = torch.optim.Adadelta(model.parameters(), lr=learning_rate)
    elif solver_type == "LBFGS":
        optimizer = torch.optim.LBFGS(model.parameters(), history_size=LBFGS_param[0], max_iter=LBFGS_param[1])
    else:
        raise ValueError("unknown solver_type: {!r}".format(solver_type))

    # arrays for saving the loss and accuracy
    losses = np.zeros((int(num_epochs*np.ceil(n / batch_size))))
    accs = np.zeros(losses.shape)
    losses_test = np.zeros((num_epochs+1))
    accs_test = np.zeros((num_epochs+1))
    times = np.zeros((losses.shape[0]+1))
    times[0] = time.time()

    # select a learning rate schedule
    scheduler = None
    if schedule == 1:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                               verbose=verbose,
                                                               factor=0.5,
                                                               eps=1e-12)
    elif schedule == 2:
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.99)
    elif schedule > 0:
        raise ValueError("unknown schedule: {!r}".format(schedule))

    model.eval()
    losses_test[0], accs_test[0] = validation_cvxproblem(model, ds_test, u_vectors1, u_vectors2, beta, rho, num_classes, device) # loss on the entire test set

    # main training loop
    iter_no = 0
    print('starting training')
    for i in range(num_epochs):
        model.train()
        for _x, _y in ds:
            _x = _x.to(device)
            _y = _y.to(device)
            targets = one_hot(_y, num_classes).to(device)

            # forward pass on the current weights; also gives the logged
            # minibatch loss and accuracy
            yhat = model(_x, u_vectors1, u_vectors2).float()
            loss = loss_func_cvxproblem(yhat, targets, model, _x, u_vectors1, u_vectors2, beta, rho, device)/len(_y)
            correct = torch.eq(torch.argmax(yhat, dim=1), _y).float().sum()/len(_y) # accuracy

            if solver_type == "LBFGS":
                # LBFGS re-evaluates the objective several times per step, so
                # step() must be given a closure (calling it without one raises)
                def closure():
                    optimizer.zero_grad()
                    obj = loss_func_cvxproblem(model(_x, u_vectors1, u_vectors2).float(), targets, model, _x,
                                               u_vectors1, u_vectors2, beta, rho, device)/len(_y)
                    obj.backward()
                    return obj
                optimizer.step(closure)
            else:
                optimizer.zero_grad() # zero the gradients on each pass before the update
                loss.backward() # backpropagate the loss through the model
                optimizer.step() # update the weights w.r.t. the loss

            losses[iter_no] = loss.item() # loss on the minibatch
            accs[iter_no] = correct.item() # fraction of correctly classified samples on the minibatch

            iter_no += 1
            times[iter_no] = time.time()

        model.eval()
        # get test loss and accuracy
        losses_test[i+1], accs_test[i+1] = validation_cvxproblem(model, ds_test, u_vectors1, u_vectors2, beta, rho, num_classes, device) # loss on the entire test set

        print("Epoch [{}/{}], TRAIN: cvx loss:  {} acc: {}. TEST: cvx loss: {} acc: {}".format(i, num_epochs,
                 np.round(losses[iter_no-1], 3), np.round(accs[iter_no-1], 3),
                np.round(losses_test[i+1]/test_len, 3), np.round(accs_test[i+1]/test_len, 3)))

        if scheduler is not None:
            if schedule == 1:
                # ReduceLROnPlateau monitors the last minibatch training loss
                scheduler.step(losses[iter_no-1])
            else:
                # ExponentialLR takes no metric: a positional value would be
                # treated as the deprecated `epoch` argument
                scheduler.step()

    return accs, accs_test/test_len, times, losses, losses_test/test_len



# In[5]:


# cifar-10 -- using the version downloaded from "http://www.cs.toronto.edu/~kriz/cifar.html"

import os 
# __file__ is undefined when these cells run inside a notebook kernel, so fall
# back to the current working directory there.
directory = os.path.dirname(os.path.realpath(__file__)) if '__file__' in globals() else os.getcwd()

import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision


normalize = transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276])


# select a dataset; every branch defines dataset, train_dataset, test_dataset,
# num_classes, n (training-set size) and d (flattened input dimension)
if ARGS.dataset[0]=='CIFAR10':
    dataset='cifar10'
    cifar_transform = transforms.Compose([transforms.ToTensor(), normalize])
    train_dataset = datasets.CIFAR10(directory, train=True, download=True, transform=cifar_transform)
    test_dataset = datasets.CIFAR10(directory, train=False, download=True, transform=cifar_transform)
    num_classes=10
    n=50000
    d=3072

elif ARGS.dataset[0]=='CIFAR100':
    dataset='cifar100'
    cifar_transform = transforms.Compose([transforms.ToTensor(), normalize])
    train_dataset = datasets.CIFAR100(directory, train=True, download=True, transform=cifar_transform)
    test_dataset = datasets.CIFAR100(directory, train=False, download=True, transform=cifar_transform)
    num_classes=100
    n=50000
    d=3072

elif ARGS.dataset[0]=='FMNIST':
    dataset='fmnist'
    normalize = transforms.Normalize((0.5,), (0.5,))
    fmnist_transform = transforms.Compose([transforms.ToTensor(), normalize])
    train_dataset = torchvision.datasets.FashionMNIST(
        root='./data_mnist', train=True, download=True, transform=fmnist_transform)
    test_dataset = torchvision.datasets.FashionMNIST(
        root='./data_mnist', train=False, download=True, transform=fmnist_transform)
    num_classes=10
    n=60000
    d=784

else:
    # fail here with a clear message instead of a NameError on train_dataset later
    raise ValueError("unsupported --dataset {!r}; expected CIFAR10, CIFAR100 or FMNIST".format(ARGS.dataset[0]))

# In[6]
# data extraction
print('Extracting the data')
# Load the whole training set as a single batch: batch_size equals the dataset
# length, so the loader yields exactly one (A, y) pair.
dummy_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=len(train_dataset), shuffle=False,
    pin_memory=True, sampler=None)
A, y = next(iter(dummy_loader))
# Keep the un-flattened images for the convolutional sign-pattern sampler.
# A is never modified in place below, so a reference suffices; the previous
# detach().clone() duplicated the entire training set in memory.
Apatch = A
A = A.view(A.shape[0], -1)
print(A.shape)

# In[7]:

# problem parameters
num_neurons = ARGS.num_neurons[0]
verbose = True  # SET verbose to True to see progress
beta = 1e-3  # regularization parameter
num_epochs1 = ARGS.n_epochs[0]
batch_size = 1000
K = ARGS.num_path[0]  # number of subnetworks for the non-convex problem
sigma1 = 5e-3  # initialization magnitude
# learning rate schedule (0: Nothing, 1: ReduceLROnPlateau, 2: ExponentialLR);
# --schedule is optional on the command line, so fall back to 0 when absent
schedule = ARGS.schedule[0] if ARGS.schedule is not None else 0


# In[8]:

# prepare training and test loaders
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True,
    pin_memory=True, sampler=None)

test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size , shuffle=False,
    pin_memory=True)


solver_type = ARGS.solver[0] # pick: "sgd", "adam", "adagrad", "adadelta", "LBFGS"
# `schedule` is taken from --schedule in the parameters cell; it used to be
# reset to 0 here, which silently ignored the command-line flag.
LBFGS_param = [10, 4] # these parameters are for the LBFGS solver
learning_rate = ARGS.lr[0]

print(solver_type+'-mu={}'.format(learning_rate))

results_noncvx = sgd_solver_pytorch_v2(train_loader, test_loader, num_epochs1, num_neurons, K, beta, sigma1,
                         learning_rate, batch_size, solver_type, schedule,
                          LBFGS_param, verbose=True, 
                        num_classes=num_classes, D_in=d, train_len=n )


# In[9]:

# problem parameters
P2 = ARGS.P[0]  # number of hyperplane arrangements for the convex problem
rho = 1e-5  # coefficient to penalize the violated constraints
solver_type = ARGS.solver[1]  # pick: "sgd" or "LBFGS"
LBFGS_param = [10, 4]  # these parameters are for the LBFGS solver
learning_rate = ARGS.lr[1]  # 1e-6 for sgd
num_epochs2, batch_size = ARGS.n_epochs[1], 1000
sigma2 = 1e-5  # initialization magnitude


#  Convex
print('Generating sign patterns')
u_vector_list1, u_vector_list2 = generate_conv_sign_patterns(Apatch, P2, num_neurons, verbose)
# the last arrangement's second-layer vector is overwritten with all ones
u_vector_list2[:, -1] = np.ones((num_neurons,))

u1_vector, u2_vector = (torch.from_numpy(u).float().to('cuda')
                        for u in (u_vector_list1, u_vector_list2))

# the convex solver iterates over the pre-flattened training tensors
ds_train = DataLoader(PrepareData(X=A, y=y), batch_size=batch_size, shuffle=True)

print('Convex Random-mu={}'.format(learning_rate))

results_cvx = sgd_solver_cvxproblem(
    ds_train, test_loader, num_epochs2, P2, num_neurons, beta, sigma2,
    learning_rate, batch_size, rho, u1_vector, u2_vector, solver_type, schedule, LBFGS_param,
    verbose=True, n=n, d=d, num_classes=num_classes, device='cuda')
    

# saves the results
if ARGS.save[0] == 1:
    import pickle
    from datetime import datetime

    now = datetime.now()
    stamp = now.strftime("%d-%m-%Y_%H-%M-%S")

    # drop the trained model (last element) and keep the loss/accuracy/time arrays
    results_noncvxv2 = results_noncvx[:5]
    print('Saving the objects')
    out_path = 'neurips_results_parallel_path_' + ARGS.solver[1] + '_' + dataset + '_' + stamp + '.pt'
    torch.save([num_epochs1, num_epochs2, results_noncvxv2, results_cvx], out_path)

