import argparse
import os

os.environ['MKL_THREADING_LAYER'] = 'GNU'


import shutil
from datetime import datetime

import torch 
import torch.backends.cudnn as cudnn
import torch.optim
import torch.nn as nn
import numpy as np
import train_util
import data_util as data_util
#import eval_loss_util
import random
import models as models

import torch.nn.functional as F

from normalized import normalized_gradient


# Command-line interface of the experiment. The parsed namespace is consumed
# by main() and handed as a whole to data_util.load_data().
parser = argparse.ArgumentParser(description='CIFAR-10 Training')

# --- Training schedule -------------------------------------------------------
parser.add_argument('--epochs', default=50, type=int,
                    help='number of total epochs to run')
parser.add_argument('--start_epoch', default=0, type=int,
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--frst_ann', '--fa', default=25, type=int,
                    help='first annealing time')
parser.add_argument('--snd_ann', '--sa', default=40, type=int,
                    help='second annealing time')
parser.add_argument('--train_batch_size', default=1024, type=int,
                    help='train mini-batch size (default: 1024)')
parser.add_argument('--test_batch_size', default=100, type=int,
                    help='test mini-batch size (default: 100)')

# --- Optimizer ---------------------------------------------------------------
parser.add_argument('--optim_choice', default="sgd", type=str,
                    help='choice of optimizer')
parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,
                    help='initial learning rate')
parser.add_argument('--save', default="False", type=str,
                    help='Save or not')
parser.add_argument('--momentum', '--m', default=0., type=float,
                    help='momentum')
# The help text previously advertised 5e-4 although the default is 1e-4.
parser.add_argument('--weight_decay', '--wd', default=1e-4, type=float,
                    help='weight decay (default: 1e-4)')

# --- Bookkeeping -------------------------------------------------------------
parser.add_argument('--print_freq', '-p', default=10, type=int,
                    help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str,
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--name', default='CIFAR-10-VGG16', type=str,
                    help='name of experiment')
# The help text previously advertised 0 although the default is 12.
parser.add_argument('--seed', '-s', default=12, type=int,
                    help='seed (default: 12)')

# --- Synthetic problem definition --------------------------------------------
parser.add_argument('--d', default=200, type=int,
                    help='dimension of the problem')
parser.add_argument('--m_train', default=20000, type=int,
                    help='number of training datapoints in the problem')
parser.add_argument('--m_test', default=2000, type=int,
                    help='number of test datapoints in the problem')
parser.add_argument('--P', default=5, type=int,
                    help='number of patches')
parser.add_argument('--neurons', default=10, type=int,
                    help='number of neurons in the hidden layer')
parser.add_argument('--alpha', default=10, type=int,
                    help='large margin')
# beta and mu are fractional quantities (defaults 0.1 and 0.2). With type=int
# argparse rejected every non-integer value given on the command line, so the
# only usable values were the defaults; they are parsed as floats instead.
parser.add_argument('--beta', default=0.1, type=float,
                    help='small margin')
parser.add_argument('--mu', default=0.2, type=float,
                    help='fraction of datapoints with small margin')
parser.add_argument('--sigma2', default=0.1, type=float,
                    help='variance control of the covariance')
parser.add_argument('--sigma02', default=0.1, type=float,
                    help='variance control of the covariance at init')

# main() treats the feature matrix as having 2*k columns (k features plus
# their negatives).
parser.add_argument('--k', default=1, type=int,
                    help='k')



def _seed_everything(seed):
    """Seed the Python, NumPy and PyTorch RNGs and force deterministic cuDNN."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def _print_banner():
    """Print the wide two-line separator used between log sections."""
    print("#############################################################")
    print("#############################################################")


def _print_rule():
    """Print the narrow two-line separator used in the final analysis."""
    print("###################")
    print("###################")


def _print_parameters(model):
    """Print weight and bias of ``model.fc1`` and ``model.fc2``, in that order."""
    for layer_name in ("fc1", "fc2"):
        layer = getattr(model, layer_name)
        print(layer_name + ".weight")
        print(layer.weight)
        print(layer_name + ".bias")
        print(layer.bias)


def main():
    """Train the two-layer network with normalized gradient descent and
    report which hidden neurons respond to the ground-truth features.

    Steps:
      1. Parse the command line, seed all RNGs and load the data through
         ``data_util.load_data``, which also returns the feature matrix ``W``.
         ``W`` is extended with its negated columns.
      2. Build ``models.nett`` and overwrite its parameters: Gaussian
         ``fc1.weight`` (trainable), ``fc1.bias`` fixed to -1, ``fc2.weight``
         all ones (trainable), ``fc2.bias`` fixed to 0.
      3. Train for ``args.epochs - args.start_epoch`` epochs, printing after
         each epoch how many neurons have an inner product > 1 with every
         column of ``W``. Training stops early if the test accuracy or the
         train loss becomes NaN or +inf.
      4. Print the activated neurons per feature, their correlation with the
         feature, their outgoing ``fc2`` weights and a run summary.

    Everything is reported on stdout; nothing is returned or written to disk.
    """
    args = parser.parse_args()
    for arg in vars(args):
        print(arg, " : ", getattr(args, arg))

    _seed_everything(args.seed)

    train_loader, test_loader, W = data_util.load_data(args)
    # Add the negatives
    # NOTE(review): the analysis below indexes the first 2*k columns of W, so
    # load_data presumably returns a (d, k) matrix -- confirm in data_util.
    W = torch.cat((W, -W), dim=1)

    # Seed is set in data_util as well so set them one more time.
    _seed_everything(args.seed)

    print("first 2 data points of the training dataset")
    print(train_loader.dataset.y[:2])
    print(train_loader.dataset.X[:2])

    _print_banner()
    print("NORM")
    print(torch.norm(train_loader.dataset.X, dim=0))

    _print_banner()
    model = models.nett(args.d, args.neurons, args.P)

    # Parameters as created by the model constructor, before the override.
    _print_parameters(model)

    _print_banner()

    # Label kept so the log layout is unchanged; the matching dump of the
    # initial fc1 weights to ./results is disabled.
    print("fc1.weight")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)

    #####
    # Set the weights of the model as desired. Make some non-differentiable
    #####
    # Other init scales that were tried for fc1.weight:
    #   std=1/np.log(args.d)*1.6, std=1/np.sqrt(np.log(args.d)), std=1
    model.fc1.weight.data = torch.normal(
        mean=0, std=1/np.log(args.d), size=(args.neurons, args.d),
        requires_grad=True, device=device)
    model.fc1.weight.requires_grad = True
    model.fc1.bias.data = -torch.ones(
        size=(args.neurons,), requires_grad=False, device=device)
    model.fc1.bias.requires_grad = False

    model.fc2.weight.data = torch.ones(
        size=(2, args.P*args.neurons), requires_grad=True, device=device)
    model.fc2.weight.requires_grad = True
    model.fc2.bias.data = torch.zeros(
        size=(2,), requires_grad=False, device=device)
    model.fc2.bias.requires_grad = False

    # Parameters actually used for training.
    _print_parameters(model)

    _print_banner()

    # Follow the selected device instead of calling .cuda() unconditionally,
    # which raised on machines without CUDA.
    criterion = nn.CrossEntropyLoss(reduction="mean").to(device)

    # Plain SGD alternative (disabled):
    #   optim_hparams = {'initial_lr': args.lr, 'momentum': args.momentum,
    #                    'weight_decay': args.weight_decay}
    #   optimizer = train_util.create_optimizer(model, args.optim_choice,
    #                                           optim_hparams)
    #   scheduler = torch.optim.lr_scheduler.MultiStepLR(
    #       optimizer, milestones=[args.frst_ann, args.snd_ann], gamma=0.1)

    # Normalized SGD
    optimizer = normalized_gradient(
        model.parameters(), lr=args.lr, momentum=args.momentum,
        weight_decay=args.weight_decay)

    # Per-epoch histories; only consumed by the disabled saving step at the
    # end of this function.
    test_tab = []
    train_tab = []

    # Defined up front so the final summary still works when the epoch range
    # is empty (e.g. --start_epoch >= --epochs).
    train_loss = train_acc = test_acc = None

    for epoch in range(args.start_epoch, args.epochs):
        print("Epoch" + str(epoch))
        for param_group in optimizer.param_groups:
            print("LR: "+str(param_group['lr']))
            print("mom: "+str(param_group['momentum']))

        print("training dataset")
        train_loss, train_acc = train_util.train_loop(
            train_loader,
            model,
            criterion,
            args.optim_choice,
            optimizer,
            epoch,
            device)

        _print_banner()

        print("test dataset")

        test_acc = train_util.validate(
            test_loader,
            model,
            criterion,
            epoch,
            device)

        ## Adjust the step size. If commented out, the same step size
        ## is used throughout the whole training.
        #scheduler.step()

        _print_banner()

        print('Accuracy of the model on the test images: {} %'.format(test_acc))
        print('Accuracy of the model on the train images: {} %'.format(train_acc))

        train_tab.append(train_loss)
        test_tab.append(test_acc)

        # Stop as soon as training has diverged; the offending metric is
        # reported as 0 in the final summary.
        if np.isnan(test_acc) or test_acc == np.inf:
            test_acc = 0
            break
        if np.isnan(train_loss) or train_loss == np.inf:
            train_loss = 0
            break

        # Check the neurons activated for W: per feature column, count the
        # neurons whose inner product with it exceeds 1 (fc1.bias is fixed to
        # -1 above). Pure monitoring, so no autograd graph is needed.
        with torch.no_grad():
            A = model.fc1.weight.cpu()
            o = torch.sum(torch.gt(torch.matmul(A, W), 1), 0)
        print(o)

    torch.set_printoptions(threshold=10_000)

    _print_rule()
    print("features")
    print(W)

    _print_rule()

    A = model.fc1.weight.cpu()

    num_features = args.k*2
    # feature index -> indices of the neurons activated by that feature
    d = {feature: [] for feature in range(num_features)}
    # Neurons activated by at least one feature, in order of first appearance.
    neur = []

    Z = torch.matmul(A, W)
    o = torch.gt(Z, 1)

    # One-off debug print of the first mask entry and of its comparison with
    # True; emitted exactly when the scan below has at least one iteration.
    if args.neurons > 0 and num_features > 0:
        print(o[0][0])
        print(o[0][0] == True)  # noqa: E712 -- prints the tensor comparison
        _print_rule()

    for i in range(args.neurons):
        for j in range(num_features):
            if o[i][j]:
                d[j].append(i)
                if i not in neur:
                    neur.append(i)

    print("Activated neurons for W")
    print(d)

    # Check the correlation between the activated neuron and the feature
    print("###################")
    print("NEURON WEIGHTS")
    for key, activated in d.items():
        if not activated:
            continue
        print("feature")
        print(key)
        print("###################")
        for j in activated:
            print("Neuron index and vector")
            print(j)
            print(A[j])
            print("###################")
            # 2x2 correlation matrix between the neuron's weight vector and
            # the feature column.
            stacked = torch.cat(
                (A[j].reshape(1, -1), W[:, key].reshape(1, -1)), dim=0)
            print(np.corrcoef(stacked.detach().numpy()))
            print("###################")

    _print_rule()

    # Print the weights of the activated neurons
    model.eval()
    with torch.no_grad():
        first, second = model.fc2.weight[0], model.fc2.weight[1]
    # One row per patch, one column per neuron.
    first = first.reshape(args.P, args.neurons)
    second = second.reshape(args.P, args.neurons)
    for i in neur:
        print(i)
        print("FIRST")
        print(first[:, i])
        print("###################")
        print("SECOND")
        print(second[:, i])
        print("###################")
        print("###################")

    _print_rule()

    print("\n")
    print("Final accuracy: {}".format(test_acc))
    print("Train accuracy: {}".format(train_acc))
    print("Train Loss: {}".format(train_loss))
    print("Seed: {}".format(args.seed))
    print("Number of training datapoints: {}".format(args.m_train))
    print("Number of test datapoints: {}".format(args.m_test))
    print("Dimension: {}".format(args.d))
    print("Number of patches: {}".format(args.P))
    print("Number of neurons: {}".format(args.neurons))
    print("Optimization algorithm: {}".format(args.optim_choice))
    print("LR: {}; B: {}; M: {}; 1st anneal: {}; 2nd anneal: {}; WD: {}".format(
        args.lr, args.train_batch_size, args.momentum, args.frst_ann,
        args.snd_ann, args.weight_decay))
    # Number of activated neurons per feature column at the end of training.
    print(torch.sum(o, 0))

    # Saving of the per-epoch histories (disabled):
    #if args.save=="True":
    #    curr_dir=os.getcwd()
    #    save_dir_res="{}/results".format(curr_dir)
    #    if not os.path.exists(save_dir_res):
    #        os.makedirs(save_dir_res)
    #    str_save="OPT_{}_SEED{}".format(
    #                      args.optim_choice ,str(args.seed))
    #    trainstr_save="{}/train_{}.npy".format(save_dir_res,str_save)
    #    teststr_save="{}/test_{}.npy".format(save_dir_res,str_save)
    #    np.save(trainstr_save,np.array(train_tab))
    #    np.save(teststr_save,np.array(test_tab))
            
# Run the experiment only when this file is executed as a script, so that
# importing the module does not parse sys.argv and start a training run.
if __name__ == "__main__":
    main()
