import argparse
import os

import sys
sys.path.insert(1, '..')

import torch 
import torch.backends.cudnn as cudnn
import torch.optim
import torch.nn as nn
import numpy as np
import pandas as pd
import train_util
#import eval_loss_util


from torch.utils.data import DataLoader, Dataset


import models as models
#from models import vgg16, vgg16_bn, resnet110, resnet110_bn


# Sorted names of every attribute of ``models`` that is callable, written in
# lowercase, and does not start with a double underscore.
model_names = sorted(
    attr_name
    for attr_name, attr_value in vars(models).items()
    if attr_name.islower()
    and not attr_name.startswith("__")
    and callable(attr_value)
)



# Learning-rate schedule names accepted by --lr_sched.
lr_scheds = ['wr_default']

# Command-line interface of the script.
# The switch-like options (--save, --save_model, --model_random,
# --epsilon_only) are declared as strings defaulting to "False"; main()
# compares --epsilon_only against the literal string "True".
parser = argparse.ArgumentParser(description='CIFAR-10 Training')
parser.add_argument('--epochs', default=300, type=int,
                    help='number of total epochs to run')
parser.add_argument('--start_epoch', default=0, type=int,
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--frst_ann', '--fa', default=170, type=int,
                    help='first annealing time')
parser.add_argument('--snd_ann', '--sa', default=245, type=int,
                    help='second annealing time')
# The help text previously advertised a default of 1024 while the real
# default is 128; the message now matches the value.
parser.add_argument('--n_batch_train', '--nbt', default=128, type=int,
                    help='train mini-batch size (default: 128)')
parser.add_argument('--n_batch_test', default=100, type=int,
                    help='test mini-batch size (default: 100)')
parser.add_argument('--path_data', default='./data', type=str,
                    help='path to store data')
parser.add_argument('--optim_choice', default="sgd", type=str,
                    help='choice of optimizer')
parser.add_argument('--lr', '--learning-rate', default=0.0001, type=float,
                    help='initial learning rate')
parser.add_argument('--arch', default="resnet18_bn", type=str,
                    help='choice of architecture')
parser.add_argument('--save', default="False", type=str,
                    help='Save or not')
parser.add_argument('--momentum', '--m', default=0.9, type=float, help='momentum')
parser.add_argument('--weight_decay', '--wd', default=5e-4, type=float,
                    help='weight decay (default: 5e-4)')
parser.add_argument('--print_freq', '-p', default=10, type=int,
                    help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str,
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--name', default='CIFAR-10-VGG16', type=str,
                    help='name of experiment')

parser.add_argument('--lr_sched', choices=lr_scheds, default='wr_default',
                    help=' | '.join(lr_scheds))
parser.add_argument('--seed', '-s', default=0, type=int,
                    help='seed (default: 0)')
parser.add_argument('--save_model', default="False", type=str,
                    help='Save Model')
parser.add_argument('--model_random', default="False", type=str,
                    help='Model Random')

parser.add_argument('--index', default=1, type=int,
                    help='index number')

parser.add_argument('--epsilon_model', default=None, type=str,
                    help='epsilon features')

parser.add_argument('--epsilon_only', default="False", type=str,
                    help='epsilon only or not')


# Per-architecture feature width; main() multiplies it by n_concat[arch] to
# size the input of the linear classifier.
n_features = {
    "vgg16_bn": 512,
    "vgg16_under_d2_bn": 256,
    "vgg16_under_d4_bn": 128,
    "vgg16_under_d8_bn": 64,
    "vgg16_under_d16_bn": 32,
    "resnet18_bn": 512,
    "resnet18_under_d2_bn": 256,
    "resnet18_under_d4_bn": 128,
    "resnet18_under_d8_bn": 64,
    "resnet18_under_d16_bn": 32,
}


# Per-architecture number of feature folders that main() concatenates.
# NOTE: an earlier copy of this table (with 251 / 254 for the two
# *_under_d16_bn entries) was assigned and then immediately overwritten, so
# only the values below ever took effect; the dead copy has been removed.
n_concat = {
    "vgg16_bn": 1,
    "vgg16_under_d2_bn": 4,
    "vgg16_under_d4_bn": 16,
    "vgg16_under_d8_bn": 64,
    "vgg16_under_d16_bn": 250,
    "resnet18_bn": 1,
    "resnet18_under_d2_bn": 4,
    "resnet18_under_d4_bn": 16,
    "resnet18_under_d8_bn": 64,
    "resnet18_under_d16_bn": 250,
}


class linear_classifier(nn.Module):
    """A single fully connected layer with 10 outputs.

    The layer's input width is ``n_arch_features + n_epsilon_features``.
    Construction prints the epsilon feature count, the literal
    ``n_features`` and the combined input width, each on its own line.
    """

    def __init__(self, n_arch_features=0, n_epsilon_features=0):
        super().__init__()
        print(n_epsilon_features)
        input_width = n_arch_features + n_epsilon_features
        # One call, two lines of output: the label followed by the width.
        print("n_features", input_width, sep="\n")
        self.classifier = nn.Linear(input_width, 10)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result."""
        return self.classifier(x)



class Dataset_PreTrainedFeatures(Dataset):
    """Dataset backed by header-less CSV files of pre-computed features.

    Column 0 of the first feature file is used as the label column.  The
    remaining feature files contribute every column except their first one,
    and the optional epsilon file contributes all of its columns.  Everything
    is concatenated column-wise into ``self.data`` (a ``pandas.DataFrame``)
    whose column 0 is the label and whose other columns are the features.

    Args:
        file_paths_feature: sequence of feature CSV paths.
        file_path_epsilon: optional path of an extra CSV whose columns are
            appended after the feature columns.  A falsy value disables it.
        only_eps: when true *and* ``file_path_epsilon`` is given, only the
            label column of the first feature file is kept next to the
            epsilon columns.  Ignored when no epsilon file is given.
    """

    def __init__(self, file_paths_feature, file_path_epsilon=None, only_eps=False):
        print(file_paths_feature)
        if file_path_epsilon:
            print(file_path_epsilon)

        if file_path_epsilon and only_eps:
            print("only eps")
            # Only the labels are needed here, so read just column 0 of the
            # first file instead of parsing every feature file and then
            # discarding the result.
            label_frame = self._read_csv(file_paths_feature[0], usecols=[0])
            frames = [label_frame.iloc[:, 0]]
        else:
            frames = self._read_feature_frames(file_paths_feature)

        if file_path_epsilon:
            frames.append(self._read_csv(file_path_epsilon))

        self.data = pd.concat(frames, axis=1, ignore_index=True)
        # shape[1] also works for a frame without rows, unlike iloc[0, :].
        print("number of features: {val}".format(val=self.data.shape[1] - 1))

    @staticmethod
    def _read_csv(path, **kwargs):
        """Read a header-less CSV without promoting any column to the index."""
        return pd.read_csv(path, header=None, index_col=False, **kwargs)

    @classmethod
    def _read_feature_frames(cls, file_paths_feature):
        """Load the feature files, dropping column 0 from all but the first."""
        frames = []
        for position, path in enumerate(file_paths_feature):
            frame = cls._read_csv(path)
            frames.append(frame if position == 0 else frame.iloc[:, 1:])
        return frames

    def __len__(self):
        """Return the number of rows in ``self.data``."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(features, label)`` for row ``index``.

        ``features`` is a float32 NumPy array built from columns 1 onward and
        ``label`` is column 0 converted to ``int``.
        """
        features = self.data.iloc[index, 1:].astype(np.float32).values
        label = self.data.iloc[index, 0]

        return features, int(label)


def add_def(arr, s):
    """Return a new list holding ``item + s`` for every item of ``arr``."""
    return [item + s for item in arr]


# Architectures that use a single feature folder and the "under_epsilon" files.
_SINGLE_FOLDER_ARCHS = ("resnet18_bn", "vgg16_bn")


def _feature_folder_paths(base_dir, arch, index):
    """Return the feature folders under ``base_dir/random_features/arch``."""
    root = base_dir + "/random_features/" + arch + "/"
    if arch in _SINGLE_FOLDER_ARCHS:
        return [root + str(index) + "/"]
    group_size = n_concat[arch]
    # Folder numbers start at 31; each index selects its own consecutive
    # group of n_concat[arch] folders.
    first = 30 + 1 + (index - 1) * group_size
    return [root + str(number) + "/" for number in range(first, first + group_size)]


def _epsilon_path_prefix(base_dir, arch, epsilon_model, index):
    """Return the epsilon CSV path prefix, or None without an epsilon model."""
    if not epsilon_model:
        return None
    root = base_dir + "/predicted_features/"
    if arch in _SINGLE_FOLDER_ARCHS:
        return root + arch + "/" + epsilon_model + "/under_epsilon_" + str(index)
    return root + epsilon_model + "/" + arch + "/over_epsilon_" + str(index)


def main():
    """Train and evaluate a linear classifier on pre-computed feature CSVs.

    Parses the command line with the module-level ``parser``, builds the
    training and test datasets from CSV files below the current working
    directory, runs ``train_util.train_loop`` / ``train_util.validate`` for
    every epoch in ``[start_epoch, epochs)``, prints a summary and writes the
    last values returned by those two calls to an ``accuracy_<lr>_<wd>`` file.

    Exits through ``parser.error`` when ``--epsilon_only True`` is given
    without ``--epsilon_model``.
    """
    args = parser.parse_args()
    for arg in vars(args):
        print(arg, " : ", getattr(args, arg))

    epsilon_only = args.epsilon_only == "True"
    if epsilon_only and not args.epsilon_model:
        # Fail early with a clear message: without an epsilon model there is
        # no epsilon CSV prefix to load from.
        parser.error("--epsilon_only True requires --epsilon_model")

    # set the seeds
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # prepare the data
    print("=> creating model '{}'".format(args.arch))
    curr_dir = os.getcwd()

    feature_folder_paths = _feature_folder_paths(curr_dir, args.arch, args.index)
    epsilon_folder_path = _epsilon_path_prefix(
        curr_dir, args.arch, args.epsilon_model, args.index)

    if epsilon_only:
        print("only eps")
        print(epsilon_folder_path)
        # In this mode only the label column is taken from this fixed folder.
        feature_folder_paths = [curr_dir + "/features/resnet18_bn/1/"]
        trainset = Dataset_PreTrainedFeatures(
            add_def(feature_folder_paths, 'training.csv'),
            epsilon_folder_path + '_training.csv', True)
        testset = Dataset_PreTrainedFeatures(
            add_def(feature_folder_paths, 'test.csv'),
            epsilon_folder_path + '_test.csv', True)
    elif epsilon_folder_path:
        print("features and eps")
        print(epsilon_folder_path)
        print(feature_folder_paths)
        trainset = Dataset_PreTrainedFeatures(
            add_def(feature_folder_paths, 'training.csv'),
            epsilon_folder_path + '_training.csv')
        testset = Dataset_PreTrainedFeatures(
            add_def(feature_folder_paths, 'test.csv'),
            epsilon_folder_path + '_test.csv')
    else:
        print("only feats")
        print(feature_folder_paths)
        trainset = Dataset_PreTrainedFeatures(
            add_def(feature_folder_paths, 'training.csv'))
        testset = Dataset_PreTrainedFeatures(
            add_def(feature_folder_paths, 'test.csv'))

    # NOTE(review): ``data_folder_path`` was referenced below without ever
    # being defined, so every run ended in a NameError after training.  The
    # first feature folder is assumed to be the intended results directory --
    # confirm.
    data_folder_path = feature_folder_paths[0]

    train_loader = torch.utils.data.DataLoader(
        trainset, batch_size=args.n_batch_train, shuffle=True, num_workers=4)

    val_loader = torch.utils.data.DataLoader(
        testset, batch_size=args.n_batch_test, shuffle=False, num_workers=4)

    # set the model
    n_arch_inputs = n_features[args.arch] * n_concat[args.arch]
    if args.epsilon_model:
        model = linear_classifier(
            # Architecture features are not fed to the model in epsilon-only mode.
            n_arch_features=0 if epsilon_only else n_arch_inputs,
            n_epsilon_features=(n_features[args.epsilon_model]
                                * n_concat[args.epsilon_model]))
    else:
        model = linear_classifier(n_arch_features=n_arch_inputs)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)

    cudnn.benchmark = True

    # Follow the selected device instead of calling .cuda() unconditionally,
    # which raises on hosts without CUDA.
    criterion = nn.CrossEntropyLoss(reduction='mean').to(device)

    optim_hparams = {
        'initial_lr': args.lr,
        'momentum': args.momentum,
        'weight_decay': args.weight_decay,
    }

    lr_hparams = {
        'initial_lr': args.lr,
        'lr_sched': args.lr_sched,
        'frst_ann': args.frst_ann,
        'snd_ann': args.snd_ann,
    }

    optimizer = train_util.create_optimizer(model, args.optim_choice,
                                            optim_hparams)

    # Pre-bind so the summary below still works when the epoch range is empty.
    train_loss = None
    val_acc = None

    for epoch in range(args.start_epoch, args.epochs):
        train_util.adjust_lr(optimizer, args.optim_choice,
                             epoch + 1,
                             lr_hparams)

        print("Epoch" + str(epoch))
        for param_group in optimizer.param_groups:
            print("LR: " + str(param_group['lr']))
            # Not every optimizer stores a 'momentum' entry in its groups.
            print("mom: " + str(param_group.get('momentum')))

        train_loss = train_util.train_loop(
            train_loader,
            model,
            criterion,
            args.optim_choice,
            optimizer,
            epoch,
            device)

        val_acc = train_util.validate(
            val_loader,
            model,
            criterion,
            epoch,
            device)

    # A train batch size of exactly 1024 is labelled "GD", anything else "SGD".
    if args.n_batch_train == 1024:
        OPT_1 = "GD"
    else:
        OPT_1 = "SGD"

    # An "M" suffix marks non-zero momentum; the last parameter group decides.
    OPT = OPT_1
    for param_group in optimizer.param_groups:
        if param_group.get('momentum', 0) != 0:
            OPT = OPT_1 + "M"
        else:
            OPT = OPT_1

    print("\n")
    print("Final accuracy: {}".format(val_acc))
    print("Seed: {}".format(args.seed))
    print("Dataset: {}".format(data_folder_path))
    print("Architecture: {}".format(args.arch))
    print("Optimization algorithm: {}".format(OPT))
    print("LR: {}; B: {}; M: {}; 1st anneal: {}; 2nd anneal: {}; WD: {}".format(
        args.lr, args.n_batch_train, args.momentum, args.frst_ann,
        args.snd_ann, args.weight_decay))

    result_path = os.path.join(
        data_folder_path,
        "accuracy_{}_{}".format(args.lr, args.weight_decay))
    with open(result_path, 'w') as f:
        f.write("%s %s" % (train_loss, val_acc))

# Run training only when executed as a script, so that importing this module
# (e.g. to reuse the dataset or classifier classes) has no side effects.
if __name__ == "__main__":
    main()