import argparse
import os


import torch 
import torch.backends.cudnn as cudnn
import torch.optim
import torch.nn as nn
import numpy as np
import train_util
import data_util
#import eval_loss_util

import models as models
#from models import vgg16, vgg16_bn, resnet110, resnet110_bn


# Sorted names of the lower-case, non-dunder callables exported by the local
# ``models`` package (presumably the architecture constructors — confirm).
model_names = sorted(
    attr_name
    for attr_name, attr_value in models.__dict__.items()
    if attr_name.islower()
    and not attr_name.startswith("__")
    and callable(attr_value)
)


#curr_dir=os.getcwd()+"/new/concat_60/Ridge/Reg/vgg16_bn/different_seed/vgg16_under_d8_bn/"

#with open(curr_dir+'coef', 'rb') as f:
#    a = np.load(f)

##X=[[0 for j in range(3840)] for i in range(512)]
##for i in range(len(a)):
##    X[i%512]=X[i%512]+a[i]

##for i in range(len(X)):
##    X[i]=X[i].tolist()    

##X=np.array(X)
##beta=X

#beta=a

#with open(curr_dir+'intercept', 'rb') as f:
#    b = np.load(f)

##Y=[0 for i in range(512)]
##for i in range(len(b)):
##    Y[i%512]=Y[i%512]+b[i]

##beta_0=Y
#beta_0=b
#beta_0=beta_0.reshape(1,-1)

#class MyEnsemble(nn.Module):
#    def __init__(self, models, n_features):
#        super(MyEnsemble, self).__init__()
#        self.models =models
#        self.overpara = torch.nn.Linear(n_features*len(models),512,bias=True)
#        with torch.no_grad():            
#            self.overpara.weight = torch.nn.Parameter(torch.from_numpy(beta).float())
#            self.overpara.bias = torch.nn.Parameter(torch.from_numpy(beta_0).float())
#        self.overpara.requires_grad_(False)
#        self.classifier =nn.Linear(5120, 10)
#        
#    def forward(self, x):
#        features=[]
#        for model in self.models:
#            model.eval()
#            with torch.no_grad():
#                features.append(model.get_features(x))     
#        con_features = torch.cat(features,dim=1)
#        con_features = self.overpara(con_features)
#        x = self.classifier(con_features)
#        return x



class MyEnsemble(nn.Module):
    """Linear classifier on the concatenated features of frozen models.

    Each wrapped model must provide ``get_features(x)``. The feature tensors
    are concatenated along dim 1 and passed to a single ``nn.Linear``. Only
    that classifier receives gradients.

    Args:
        models: sequence of ``nn.Module`` feature extractors exposing
            ``get_features``.
        n_features: width of each extractor's feature vector (assumed to be
            the same for every extractor).
        n_classes: number of output logits (default 10, as before).
    """

    def __init__(self, models, n_features, n_classes=10):
        super(MyEnsemble, self).__init__()
        # ModuleList (rather than a plain list) registers the extractors as
        # submodules, so .to(device) / .cuda() / state_dict() include them.
        self.models = nn.ModuleList(models)
        # Features are always computed under no_grad, so the extractors are
        # never trained; freeze them explicitly.
        self.models.requires_grad_(False)
        self.classifier = nn.Linear(n_features * len(self.models), n_classes)

    def forward(self, x):
        features = []
        for model in self.models:
            # Keep BN/dropout in inference mode even while the ensemble is
            # in train mode.
            model.eval()
            with torch.no_grad():
                features.append(model.get_features(x))

        con_features = torch.cat(features, dim=1)
        return self.classifier(con_features)




lr_scheds = ['wr_default']
parser = argparse.ArgumentParser(description='CIFAR-10 Training')
parser.add_argument('--epochs', default=300, type=int,
                    help='number of total epochs to run')
parser.add_argument('--start_epoch', default=0, type=int,
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--frst_ann', '--fa', default=170, type=int,
                    help='first annealing time')
parser.add_argument('--snd_ann', '--sa', default=245, type=int,
                    help='second annealing time')
parser.add_argument('--n_batch_train', '--nbt', default=128, type=int,
                    help='train mini-batch size (default: 128)')
parser.add_argument('--n_batch_test', default=100, type=int,
                    help='test mini-batch size (default: 100)')
parser.add_argument('--path_data', default='./data', type=str,
                    help='path to store data')
parser.add_argument('--optim_choice', default="sgd", type=str,
                    help='choice of optimizer')
parser.add_argument('--lr', '--learning-rate', default=0.0001, type=float,
                    help='initial learning rate')
parser.add_argument('--arch', default="vgg16", type=str,
                    help='choice of architecture')
# Boolean-like switches below are compared against the literal string "True".
parser.add_argument('--save', default="False", type=str,
                    help='Save or not')
parser.add_argument('--momentum', '--m', default=0.9, type=float, help='momentum')
parser.add_argument('--weight_decay', '--wd', default=5e-4, type=float,
                    help='weight decay (default: 5e-4)')
parser.add_argument('--print_freq', '-p', default=10, type=int,
                    help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str,
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--name', default='CIFAR-10-VGG16', type=str,
                    help='name of experiment')
parser.add_argument('--dataset', choices=["cifar10", "cifar100"], default="cifar10",
                    help='cifar10 or cifar100')
parser.add_argument('--lr_sched', choices=lr_scheds, default='wr_default',
                    help=' | '.join(lr_scheds))
parser.add_argument('--seed', '-s', default=0, type=int,
                    help='seed (default: 0)')
parser.add_argument('--save_model', default="False", type=str,
                    help='Save Model')
parser.add_argument('--model_random', default="False", type=str,
                    help='Model Random')
parser.add_argument('--n_concat', default=250, type=int,
                    help='number of pretrained models whose features are concatenated (default: 250)')




# Per-model feature width passed to MyEnsemble for each supported --arch.
# NOTE(review): "resnet18_under_bn" breaks the naming pattern of its siblings
# (cf. "vgg16_under_d16_bn"); confirm the intended architecture name.
n_features = {"vgg16_bn": 512, "vgg16_under_d2_bn": 256, "vgg16_under_d4_bn": 128,
              "vgg16_under_d8_bn": 64, "vgg16_under_d16_bn": 32,
              "resnet18_bn": 512, "resnet18_under_d2_bn": 256, "resnet18_under_d4_bn": 128,
              "resnet18_under_d8_bn": 64, "resnet18_under_bn": 32}


args = parser.parse_args()

# Every later per-architecture lookup is keyed on --arch; fail fast with a
# readable message instead of a KeyError.
if args.arch not in n_features:
    parser.error("unsupported --arch {!r}; expected one of: {}".format(
        args.arch, ", ".join(sorted(n_features))))

# Sub-directory of the cwd under which checkpoints and results live
# (set to "" to use the cwd directly).
aug = "/aug"


for arg in vars(args):
    print(arg, " : ", getattr(args, arg))


# Largest --n_concat accepted per architecture. Larger requests are skipped
# with exit status 0, so sweeps over n_concat do not register failures.
_MAX_CONCAT = {
    "vgg16_under_d2_bn": 4, "resnet18_under_d2_bn": 4,
    "vgg16_under_d4_bn": 16, "resnet18_under_d4_bn": 16,
    "vgg16_under_d8_bn": 64, "resnet18_under_d8_bn": 64,
}
_limit = _MAX_CONCAT.get(args.arch)
if _limit is not None and args.n_concat > _limit:
    print("Skipping: --n_concat {} exceeds the limit of {} for {}".format(
        args.n_concat, _limit, args.arch))
    # quit() is a site-module helper for the interactive shell; raise
    # SystemExit directly (same status 0).
    raise SystemExit(0)




def main(start):
    """Train a linear classifier on concatenated features of pretrained models.

    Loads ``args.n_concat`` pretrained networks from
    ``<cwd><aug>/<arch>/result_trained_<i>`` for ``i`` in
    ``start+1 .. start+args.n_concat`` and wraps them in ``MyEnsemble``. The
    ensemble is then trained via ``train_util`` for epochs
    ``args.start_epoch .. args.epochs - 1``. ``--save`` and ``--save_model``
    are not acted on here.

    Args:
        start: checkpoint-index offset; the first checkpoint loaded is
            ``start + 1``.

    Returns:
        ``[train_loss, val_acc]`` from the last epoch. Returns ``None`` when
        ``--model_random True``; in that case the untrained ensemble is saved
        to ``<cwd>/<arch>/result_random_<seed>`` instead.

    Raises:
        ValueError: if ``args.epochs <= args.start_epoch`` (no epoch would run).
    """
    print(start)
    for arg in vars(args):
        print(arg, " : ", getattr(args, arg))

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Honour --path_data. The trailing separator is normalised so the default
    # './data' becomes './data/'. This is in case data_util appends names
    # directly onto the path — TODO confirm whether it needs this.
    path_data = os.path.join(args.path_data, "")
    train_loader, val_loader = data_util.load_data(args.n_batch_train,
                                                   args.n_batch_test,
                                                   args.dataset,
                                                   path_data)

    print("=> creating model '{}'".format(args.arch))

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # torch.load unpickles whole modules, so only load trusted, locally
    # produced checkpoints. map_location places every extractor on `device`
    # and lets GPU-saved checkpoints load on CPU-only machines.
    ckpt_dir = os.getcwd() + aug + "/" + args.arch + "/"
    feature_models = [
        torch.load(ckpt_dir + "result_trained_" + str(i), map_location=device)
        for i in range(start + 1, start + args.n_concat + 1)
    ]

    model = MyEnsemble(models=feature_models, n_features=n_features[args.arch])
    model = model.to(device)

    if args.model_random == "True":
        out_dir = os.getcwd() + "/" + args.arch
        os.makedirs(out_dir, exist_ok=True)
        torch.save(model, out_dir + "/" + "result_random_" + str(args.seed))
        return None

    if args.epochs <= args.start_epoch:
        raise ValueError(
            "--epochs ({}) must be greater than --start_epoch ({})".format(
                args.epochs, args.start_epoch))

    cudnn.benchmark = True

    criterion = nn.CrossEntropyLoss(reduction='mean').to(device)

    optim_hparams = {
        'initial_lr': args.lr,
        'momentum': args.momentum,
        'weight_decay': args.weight_decay
    }

    lr_hparams = {
        'initial_lr': args.lr,
        'lr_sched': args.lr_sched,
        'frst_ann': args.frst_ann,
        'snd_ann': args.snd_ann}

    optimizer = train_util.create_optimizer(model, args.optim_choice,
                                            optim_hparams)

    for epoch in range(args.start_epoch, args.epochs):
        # Called for its effect on the optimizer (presumably setting this
        # epoch's LR — TODO confirm); its return value is not needed.
        train_util.adjust_lr(optimizer, args.optim_choice,
                             epoch + 1,
                             lr_hparams)

        print("Epoch" + str(epoch))
        for param_group in optimizer.param_groups:
            print("LR: " + str(param_group['lr']))
            # Not every optimizer has a 'momentum' entry (e.g. Adam).
            print("mom: " + str(param_group.get('momentum')))

        train_loss = train_util.train_loop(
            train_loader,
            model,
            criterion,
            args.optim_choice,
            optimizer,
            epoch,
            device)

        val_acc = train_util.validate(
            val_loader,
            model,
            criterion,
            epoch,
            device)

    return [train_loss, val_acc]
    
   
results_test = []
results_train = []

# Budget per architecture. size[arch] // n_concat groups of n_concat
# checkpoints are evaluated. NOTE(review): presumably the number of trained
# checkpoints available per architecture — confirm.
size = {"vgg16_bn": 10, "vgg16_under_d2_bn": 50, "vgg16_under_d4_bn": 150,
        "vgg16_under_d8_bn": 320, "vgg16_under_d16_bn": 508,
        "resnet18_bn": 10, "resnet18_under_d2_bn": 50, "resnet18_under_d4_bn": 150,
        "resnet18_under_d8_bn": 320, "resnet18_under_bn": 508}

# Evaluate at most this many groups.
MAX_GROUPS = 5

n_groups = min(size[args.arch] // args.n_concat, MAX_GROUPS)
for freq in range(n_groups):
    # Group `freq` loads checkpoints 31 + freq*n_concat .. 30 + (freq+1)*n_concat.
    result = main(30 + freq * args.n_concat)
    if result is None:
        # --model_random True: main only saved an untrained ensemble.
        break
    tra, tes = result
    results_train.append(tra)
    results_test.append(tes)

    curr_dir = os.getcwd() + aug
    curr_dir = curr_dir + "/linear_feature_learning/" + args.arch + "/" + str(args.n_concat)
    os.makedirs(curr_dir, exist_ok=True)

    path = (curr_dir + "/" + "result_test_" + str(args.lr) + "_"
            + str(args.weight_decay) + "_" + str(args.n_batch_train) + ".txt")
    print(path)

    # Rewritten after every group with all results so far: one
    # "train test" line per group, then the averages of both columns.
    with open(path, 'w') as f:
        for tr, te in zip(results_train, results_test):
            f.write("%s %s \n" % (tr, te))
        f.write("%s %s" % (np.average(np.array(results_train)), np.average(np.array(results_test))))