import argparse
# Command-line options.
# NOTE: --lr is parsed but not used by the learning-rate sweep below, which
# iterates over lr_manual_grid instead. --wd is passed to SGD as weight_decay,
# and --group_norm scales the group penalty on the head's conv1 filters.
parser = argparse.ArgumentParser(description='Process command-line arguments')
parser.add_argument('--lr', type=float, required=False, help='Learning rate', default=1)
parser.add_argument('--wd', type=float, required=False, help='Weight decay', default=0)
# type=float (was int): the coefficient is fractional (default 1e-5), and
# int('1e-5') raised an argparse error whenever the value was given explicitly.
parser.add_argument('--group_norm', type=float, required=False, help='Group L1 norm parameter', default=1e-5)
args = parser.parse_args()
import torch
import pyvww
import functools
from functools import partial
import os
import copy
import torch.nn.functional as F
from torch import nn
import torchvision
from torchvision import transforms, models
import numpy as np #mp modify
import os
from datetime import datetime
import matplotlib.pyplot as plt
import csv
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="1"
torch.manual_seed(1) 
# Hyperparameters
batch_size = 32
evaluation_size = 256
display_interval = 5
#lr_manual_grid = [0.008,0.01,0.02,0.04,0.2,0.3,0.4,0.5,0.6]#0.01# args.lr#
#lr_manual_grid linspace from 0.001 to 1 with 50 steps
#lr_manual_grid = np.linspace(0.01,1,20)
lr_manual_grid = np.logspace(-3,1,100)
num_epochs = 10
transform = transforms.Compose([transforms.Resize([256,256]), transforms.ToTensor()])

# cifar10      = torchvision.datasets.CIFAR10('cifar10/', download=True,  train=True,  transform=transform)
# cifar10_eval = torchvision.datasets.CIFAR10('cifar10/', download=False, train=False, transform=transform)

# Download the COCO dataset from here: https://cocodataset.org/#download
cifar10 = pyvww.pytorch.VisualWakeWordsClassification(root="/hdd/dataset/COCO/all2014", 
                    annFile="instances_train.json",transform=transform)
cifar10_eval = pyvww.pytorch.VisualWakeWordsClassification(root="/hdd/dataset/COCO/all2014", 
                   annFile="instances_val.json",transform=transform)
        
# take the first sample_size images
# cifar10 = torch.utils.data.Subset(cifar10, range(sample_size))
# cifar10_eval = torch.utils.data.Subset(cifar10_eval, range(sample_size))
# Resnet18 model
# model = models.resnet18(pretrained=True)
# model.fc = nn.Linear(512, 2)
# model = model.cuda()
# Use a smaller model architecture
class TinyModel(nn.Module):
    """Small conv head producing two logits.

    One 3x3 convolution (24 -> 512 channels) with ReLU, global average
    pooling to 1x1, then a linear layer mapping the 512 pooled features to 2.
    """

    def __init__(self):
        super().__init__()
        self.input_size = 24
        # conv1 is created before fc1 so parameter initialisation draws from
        # the RNG in the same order.
        self.conv1 = nn.Conv2d(self.input_size, 512, 3)
        self.fc1 = nn.Linear(512, 2)

    def forward(self, x):
        pooled = F.adaptive_avg_pool2d(F.relu(self.conv1(x)), (1, 1))
        flat = torch.flatten(pooled, start_dim=1)  # (batch, 512)
        return self.fc1(flat)
class MyConvexNet(nn.Module):
    """Gated two-layer conv head used for the "convex" experiments.

    ``conv1`` (trainable) is multiplied elementwise by a gate computed from
    ``conv2``, whose weight is created with ``requires_grad = False``. After
    global average pooling, the first half of the ``filters`` channels is
    summed into logit 0 and the second half into logit 1.

    Args:
        deep_patterns: if True, the gate comes from a deeper residual stack
            (``conv3``/``conv4`` with batch norm) on top of ``conv2``.
        depth: depth of that stack; only ``depth <= 4`` is implemented.

    The defaults reproduce the previous hard-coded configuration
    (``deep_patterns=False``, ``depth=2``): 40 input channels, 512 filters,
    3x3 kernels, no padding.
    """

    def __init__(self, deep_patterns=False, depth=2):
        super().__init__()
        self.filters = 512
        self.kernel_size = 3
        self.deep_patterns = deep_patterns
        self.depth = depth
        self.input_size = 40
        self.conv1 = nn.Conv2d(self.input_size, self.filters, self.kernel_size, stride=1, padding=0, bias=True)
        self.conv2 = nn.Conv2d(self.input_size, self.filters, self.kernel_size, stride=1, padding=0, bias=True)
        if self.deep_patterns:
            # 'same' padding keeps conv3/conv4 outputs the size of conv2's
            # output, so the residual additions in forward() line up.
            self.padding_size = (self.kernel_size - 1) // 2
            self.bn1 = nn.BatchNorm2d(self.filters, affine=True)
            self.bn2 = nn.BatchNorm2d(self.filters, affine=True)
            self.bn3 = nn.BatchNorm2d(self.filters, affine=True)
            self.bn4 = nn.BatchNorm2d(self.filters, affine=True)
            self.conv3 = nn.Conv2d(self.filters, self.filters, self.kernel_size, stride=1, padding=self.padding_size, bias=True)
            self.conv4 = nn.Conv2d(self.filters, self.filters, self.kernel_size, stride=1, padding=self.padding_size, bias=True)
            self.conv3.weight.requires_grad = False
        self.conv1.weight.requires_grad = True
        self.conv2.weight.requires_grad = False  # gate weights are not trained

    def forward(self, x):
        """Return (batch, 2) logits for an input with ``input_size`` channels."""
        if not self.deep_patterns:
            # sign() gate: each conv1 activation is kept (+1), negated (-1) or zeroed (0).
            x = self.conv1(x) * torch.sign(self.conv2(x))
        elif self.depth <= 3:
            pre = self.conv2(x)
            x = self.conv1(x) * torch.sign(self.bn2(pre + self.conv3(F.relu(self.bn1(pre)))))
        elif self.depth == 4:
            x1 = self.conv1(x)
            # Previously this normalised x with self.bn0, which is never
            # created (AttributeError). conv2 is applied to x directly, as in
            # the other branches.
            x2 = self.conv2(x)
            x3 = x2 + self.bn2(self.conv3(torch.sign(self.bn1(x2))))
            x = x1 * ((x3 + self.bn4(self.conv4(torch.sign(self.bn3(x3))))) >= 0)
        else:
            # Deeper stacks referenced layers (bn5-bn8, conv5-conv8) and locals
            # that were never defined; fail loudly instead of crashing obscurely
            # or silently skipping the gate.
            raise NotImplementedError(
                f"MyConvexNet: depth={self.depth} is not implemented with deep_patterns=True (supported: depth <= 4)")
        x = torch.mean(x.view(x.size(0), x.size(1), -1), dim=2)  # global average pooling -> (batch, filters)
        half = x.shape[1] // 2
        # Fixed second layer: sum the first/second half of the channels.
        return torch.stack((x[:, :half].sum(dim=1), x[:, half:].sum(dim=1)), dim=1)
class MyFCConvexNet(nn.Module):
    """Gated two-layer fully connected head on 576-dim feature vectors.

    ``fc1`` produces 512 hidden units. ``fc2`` (weight created with
    ``requires_grad = False``) produces a 0/1 mask via ``fc2(x) >= 0``. The
    masked units are split into two halves whose sums are the two logits.
    """

    def __init__(self):
        super().__init__()
        self.filters = 512
        self.input_size = 576  # width of the incoming feature vector
        self.fc1 = nn.Linear(self.input_size, self.filters, bias=True)
        self.fc2 = nn.Linear(self.input_size, self.filters, bias=True)
        self.fc2.weight.requires_grad = False  # gate weights are not trained

    def forward(self, x):
        gated = self.fc1(x) * (self.fc2(x) >= 0)
        half = gated.shape[1] // 2
        first_logit = gated[:, :half].sum(dim=1)
        second_logit = gated[:, half:].sum(dim=1)
        return torch.stack((first_logit, second_logit), dim=1)
class MyNONConvexNet(nn.Module):
    """Standard (non-convex) two-layer conv head producing two logits.

    conv1 (40 -> 512 channels, 9x9, no padding, so the input must be at least
    9x9 spatially) with ReLU, global average pooling, then ``fc1`` maps the
    512 pooled features to 2 logits.
    """

    def __init__(self):
        super().__init__()
        self.filters = 512
        self.input_size = 40
        self.kernel_size = 9
        self.conv1 = nn.Conv2d(self.input_size, self.filters, self.kernel_size, stride=1, padding=0, bias=True)
        self.fc1 = nn.Linear(self.filters, 2, bias=True)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = torch.mean(x.view(x.size(0), x.size(1), -1), dim=2)  # global average pooling -> (batch, filters)
        # fc1 used to be skipped, so the head returned 512 scores to a
        # 2-class CrossEntropyLoss; apply it (as MyFCNONConvexNet does with fc2).
        return self.fc1(x)
class MyFCNONConvexNet(nn.Module):
    """Standard two-layer MLP head: Linear(576 -> 512), ReLU, Linear(512 -> 2)."""

    def __init__(self):
        super().__init__()
        self.filters = 512
        # 576 for mobilenetv3small features (768 would be used for ViT).
        self.input_size = 576
        self.fc1 = nn.Linear(self.input_size, self.filters, bias=True)
        self.fc2 = nn.Linear(self.filters, 2, bias=True)

    def forward(self, x):
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)
class MyNONConvexNet3(nn.Module):
    """Non-convex three-layer conv head producing two logits.

    conv1 (24 -> 512, 3x3) + ReLU, conv2 (512 -> 512, 3x3) + ReLU, global
    average pooling, then ``fc1`` maps the 512 pooled features to 2 logits.
    Neither conv is padded, so the input must be at least 5x5 spatially.
    """

    def __init__(self):
        super().__init__()
        self.filters = 512
        self.input_size = 24
        self.kernel_size = 3
        self.conv1 = nn.Conv2d(self.input_size, self.filters, self.kernel_size, stride=1, padding=0, bias=True)
        self.conv2 = nn.Conv2d(self.filters, self.filters, self.kernel_size, stride=1, padding=0, bias=True)
        self.fc1 = nn.Linear(self.filters, 2, bias=True)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = torch.mean(x.view(x.size(0), x.size(1), -1), dim=2)  # global average pooling -> (batch, filters)
        # fc1 used to be skipped, leaving 512 scores for a 2-class loss.
        return self.fc1(x)
class LinearModel(nn.Module):
    """Purely linear baseline on 3-channel images.

    A bias-free 1x1 conv mixes the 3 input channels into one. That channel is
    averaged over all spatial positions, and Linear(1 -> 2) gives the logits.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 1, 1, stride=1, padding=0, bias=False)
        self.fc1 = nn.Linear(1, 2, bias=True)

    def forward(self, x):
        mixed = self.conv1(x)
        pooled = mixed.flatten(start_dim=2).mean(dim=2)  # (batch, 1)
        return self.fc1(pooled)
    
class Identity(nn.Module):
    """Pass-through module: ``forward`` returns its input object unchanged.

    Used below to replace a backbone's ``classifier`` so it outputs features.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x
# ---------------------------------------------------------------------------
# Learning-rate sweep: for every lr in lr_manual_grid, build a fresh model,
# train it for num_epochs with SGD + StepLR, evaluate on the validation split
# and write the accumulated results to a timestamped CSV.
# ---------------------------------------------------------------------------

# Heads appended after replacing the backbone's classifier with Identity
# (576 = pooled-feature width, per the note on MyFCNONConvexNet). Each factory
# returns the list of modules to append after the backbone.
_CLASSIFIER_HEADS = {
    'classifier_linear': lambda: [nn.Linear(576, 2)],
    'classifier_low_rank_linear': lambda: [nn.Linear(576, 512), nn.Linear(512, 2)],
    'classifier_twolayerNONcvx': lambda: [MyFCNONConvexNet()],
    'classifier_twolayerCVX': lambda: [MyFCConvexNet()],
}

# Heads placed on the first N blocks of backbone.features:
# arch name -> (N, head class).
_FEATURE_PREFIX_HEADS = {
    'mobilenetfirstlayer_twolayerCVX': (1, MyConvexNet),
    'mobilenetsecondlayer_twolayerCVX': (2, MyConvexNet),
    'mobilenetthirdlayer_twolayerNONCVX': (3, MyNONConvexNet),
    'mobilenetfourthlayer_twolayerCVX': (4, MyConvexNet),
    'mobilenetfifthlayer_twolayerCVX': (5, MyConvexNet),
    'mobilenetfifthlayer_twolayerNONCVX': (5, MyNONConvexNet),
    'mobilenetthirdlayer_twolayerNONCVX3': (3, MyNONConvexNet3),
    'mobilenetthirdlayer_tinyNONCONVEX': (3, TinyModel),
}


def _build_model(arch):
    """Build the network named ``arch`` on a pretrained MobileNetV3-Small.

    Returns an ``nn.Sequential`` whose last element is the head
    (``model[-1]``). Raises ValueError for an unknown ``arch``; previously an
    unknown name fell through and failed later at ``model[-1]``.
    """
    backbone = models.mobilenet_v3_small(pretrained=True)
    if arch == 'full_linear':
        # The full network ends in torchvision's default 1000-class ImageNet
        # classifier; the previous nn.Linear(768, 2) (ViT width) could not
        # accept that input.
        return nn.Sequential(backbone, nn.Linear(1000, 2))
    if arch == 'full_twolayerncvx':
        # NOTE(review): MyFCNONConvexNet expects 576 inputs, but the full
        # network emits classifier logits -- confirm this combination works.
        return nn.Sequential(backbone, MyFCNONConvexNet())
    if arch in _CLASSIFIER_HEADS:
        backbone.classifier = Identity()  # expose pooled features (ViT would use .heads)
        return nn.Sequential(backbone, *_CLASSIFIER_HEADS[arch]())
    if arch in _FEATURE_PREFIX_HEADS:
        n_blocks, head_cls = _FEATURE_PREFIX_HEADS[arch]
        return nn.Sequential(*list(backbone.features)[:n_blocks], head_cls())
    raise ValueError(f"Unknown arch: {arch!r}")


def train(model, optimizer, loss_fn, train_loader, val_loader, epochs=1, device="cuda", scheduler=None):
    """Train ``model`` for ``epochs`` epochs and track accuracies.

    Every ``display_interval`` training batches, the running training
    accuracy is recorded, and one validation batch is evaluated in eval mode
    without autograd. If ``scheduler`` is given, ``scheduler.step()`` is
    called once at the end of every epoch.

    The loss includes ``args.group_norm`` times the sum of per-filter L2 norms
    of ``model[-1].conv1``'s weight, so the head must expose ``conv1``.

    Returns:
        (train_losses, val_losses, mini_train_losses, mini_validation_losses):
        per-epoch mean training loss; per-epoch *training* accuracy (historical
        name); running training accuracy at each report point; validation
        mini-batch accuracy at each report point. Accuracies are 0-d tensors.
    """
    use_group_norm = True
    print_interval = display_interval
    train_losses = []
    val_losses = []
    mini_train_losses = []
    mini_validation_losses = []
    val_iter = iter(val_loader)
    for epoch in range(epochs):
        print("Epoch", epoch)
        running_loss_train = 0.0
        running_corrects_train = 0
        seen = 0  # samples processed this epoch (the last batch may be smaller)
        model.train()
        for i, (images, labels) in enumerate(train_loader):
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            loss = loss_fn(outputs, labels)
            if use_group_norm:
                # Group penalty: L2 norm of each conv1 output filter, summed; 1-D biases are skipped.
                for param in model[-1].conv1.parameters():
                    if param.dim() > 1:
                        loss = loss + args.group_norm * torch.norm(param, p=2, dim=(1, 2, 3)).sum()
            loss.backward()
            optimizer.step()
            running_loss_train += loss.item() * images.size(0)
            running_corrects_train += torch.sum(preds == labels)
            seen += images.size(0)

            if (i + 1) % print_interval == 0:
                running_acc_train = running_corrects_train.double() / seen
                mini_train_losses.append(running_acc_train)
                try:
                    val_images, val_labels = next(val_iter)
                except StopIteration:
                    # Validation set exhausted: restart from its first batch.
                    val_iter = iter(val_loader)
                    val_images, val_labels = next(val_iter)
                val_images = val_images.to(device)
                val_labels = val_labels.to(device)
                # Eval mode + no_grad: in train mode any BatchNorm layer would
                # update its running statistics from validation images.
                model.eval()
                with torch.no_grad():
                    val_outputs = model(val_images)
                model.train()
                _, val_preds = torch.max(val_outputs, 1)
                val_acc = torch.sum(val_preds == val_labels).double() / val_images.size(0)
                mini_validation_losses.append(val_acc)
                print(f"Epoch: {epoch}, Mini-batch: {i + 1}, Training Running Accuracy: {running_acc_train:.4f}, Validation Mini-batch Accuracy: {val_acc:.4f}")

        epoch_loss_train = running_loss_train / len(train_loader.dataset)
        epoch_acc_train = running_corrects_train.double() / len(train_loader.dataset)
        train_losses.append(epoch_loss_train)
        val_losses.append(epoch_acc_train)  # training accuracy stored under the historical name
        print(f"End of Epoch {epoch}: train_loss={epoch_loss_train:.4f}, train_acc={epoch_acc_train:.4f}")
        if scheduler is not None:
            scheduler.step()

    return train_losses, val_losses, mini_train_losses, mini_validation_losses


def test(model, test_loader, device="cuda"):
    """Return (and print) the accuracy of ``model`` on ``test_loader`` as a 0-d tensor."""
    model.eval()
    running_corrects = 0
    with torch.no_grad():  # inference only: do not build autograd graphs
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            running_corrects += torch.sum(preds == labels)
    test_acc = running_corrects.double() / len(test_loader.dataset)
    print("Test accuracy:", test_acc)
    return test_acc


train_acc_lr = []  # final-epoch training LOSS per lr (historical name)
test_acc_lr = []   # test accuracy (float) per lr
arch = 'mobilenetfifthlayer_twolayerCVX'

# The loaders do not depend on the learning rate, so build them once.
train_loader = torch.utils.data.DataLoader(cifar10, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = torch.utils.data.DataLoader(cifar10_eval, batch_size=evaluation_size, shuffle=False, num_workers=4)
test_loader = torch.utils.data.DataLoader(cifar10_eval, batch_size=batch_size, shuffle=False, num_workers=4)

for lr_manual in lr_manual_grid:
    model = _build_model(arch)
    # Train every parameter, backbone and head. This also re-enables gradients
    # for weights the heads froze in __init__ (e.g. MyConvexNet.conv2.weight).
    # NOTE(review): an older comment said the first block should be frozen;
    # confirm full fine-tuning is intended.
    for param in model.parameters():
        param.requires_grad = True
    model = model.to('cuda')

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr_manual, weight_decay=args.wd)
    # Halve the learning rate after every epoch. train() calls
    # scheduler.step() once per epoch; previously the scheduler was created
    # but never stepped, so the rate stayed constant.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)

    train_losses, val_losses, mini_train_losses, mini_validation_losses = train(
        model, optimizer, loss_fn, train_loader, val_loader, epochs=num_epochs, scheduler=scheduler)
    test_acc = test(model, test_loader)
    print("Final test accuracy:", test_acc)
    dt_string = datetime.now().strftime("%d_%m_%Y_%H_%M_%S")

    save_activations = False
    if save_activations:
        torch.save(model.state_dict(), './results/models/NONCONVEX_model_' + dt_string + '.pt')
        activation_loader = torch.utils.data.DataLoader(cifar10, batch_size=batch_size, shuffle=True, num_workers=4)
        # Flattened outputs of model[0][0] for every training image, collected
        # in chunks and concatenated once (np.append in a loop is quadratic).
        chunks = []
        with torch.no_grad():
            for images, _ in activation_loader:
                activation = model[0][0](images.to('cuda'))
                chunks.append(activation.reshape(activation.size(0), -1).cpu().numpy().ravel())
        activations = np.concatenate(chunks) if chunks else np.array([])
        # Relative path like every other output (was the absolute '/results/...').
        np.savetxt('./results/activations/CONVEX_activations_' + dt_string + '.txt', activations)

    mini_train_losses = [x.item() for x in mini_train_losses]
    mini_validation_losses = [x.item() for x in mini_validation_losses]
    # Fresh figure per lr (lines previously piled up on one figure for the
    # whole sweep); closed once done.
    fig, ax = plt.subplots()
    ax.plot(mini_train_losses, label="train")
    ax.plot(mini_validation_losses, label="val")
    ax.set_xlabel('Mini-batches')
    ax.set_ylabel('Accuracy')
    ax.legend()
    # Uncomment to save the curves:
    # fig.savefig('./results/full_data/CONVEX_mini_train_val_losses_' + dt_string + '.pdf')
    plt.close(fig)

    train_acc_lr.append(train_losses[-1])
    test_acc_lr.append(test_acc.item())
    print("Training loss list:", train_acc_lr)
    print("Validation accuracy list :", test_acc_lr)

    # Per-lr CSV. NOTE: the 'train_acc' row holds training losses; the header
    # is kept so existing result parsers keep working.
    with open('./results/full_data_ubuntu_may29_steplr/CONVEX_train_val_acc_' + dt_string + '.csv', 'w', newline='') as f:
        csv.writer(f).writerows([
            ['learning rate'], [lr_manual],
            ['train_acc'], train_acc_lr,
            ['test_acc'], test_acc_lr,
            ['mini_train_losses'], mini_train_losses,
            ['mini_validation_losses'], mini_validation_losses,
        ])
# ---------------------------------------------------------------------------
# Sweep summary. Uses names left behind by the sweep loop: arch, model,
# dt_string, lr_manual and the mini_* lists come from the LAST learning rate;
# train_acc_lr / test_acc_lr hold one entry per learning rate.
# ---------------------------------------------------------------------------
print("folder: full_data_ubuntu_may29_steplr")
print("Architecture:", arch)
print("Number of epochs:", num_epochs)
print("Batch size:", batch_size)
print("Learning rate grid:", lr_manual_grid)
print("Training loss list:", train_acc_lr)
print("Validation accuracy list :", test_acc_lr)
print("Maximum validation accuracy:", max(test_acc_lr))

# Save the model trained with the last learning rate of the sweep.
torch.save(model.state_dict(), './results/models/CONVEX_model_' + dt_string + '.pt')

# Summary CSV. The 'learning rate' row lists the whole grid so it lines up
# entry-by-entry with the per-lr 'train_acc' / 'test_acc' rows (it previously
# held only the last lr). The 'train_acc' row holds the final-epoch training
# losses (header kept for existing parsers). The mini_* rows are from the
# last lr only. `with` closes the file; no explicit close() is needed.
with open('./results/full_data/CONVEX_train_val_acc_' + dt_string + '.csv', 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(['learning rate'])
    writer.writerow(list(lr_manual_grid))
    writer.writerow(['train_acc'])
    writer.writerow(train_acc_lr)
    writer.writerow(['test_acc'])
    writer.writerow(test_acc_lr)
    writer.writerow(['mini_train_losses'])
    writer.writerow(mini_train_losses)
    writer.writerow(['mini_validation_losses'])
    writer.writerow(mini_validation_losses)


