import argparse
import os

import sys
sys.path.insert(1, '..')
import numpy as np
import pandas as pd

import torch 
import torch.backends.cudnn as cudnn
import torch.optim
import torch.nn as nn
import train_util
import math
import random


import datasets
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset




# Learning-rate schedules selectable through --lr_sched.
lr_scheds = ['wr_default']

# Command-line interface of the training script.
# NOTE(review): the description mentions CIFAR-10, while the classifier defined
# in this file has 3 outputs and reads pre-extracted CSV features -- confirm
# whether the description is still accurate.
parser = argparse.ArgumentParser(description='CIFAR-10 Training')

# --- schedule -------------------------------------------------------------
parser.add_argument('--epochs', default=100, type=int,
                    help='number of total epochs to run')
parser.add_argument('--start_epoch', default=0, type=int,
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--frst_ann', '--fa', default=55, type=int,
                    help='first annealing time')
parser.add_argument('--snd_ann', '--sa', default=82, type=int,
                    help='second annealing time')

# --- batching -------------------------------------------------------------
# The help text previously advertised a default of 1024 although the actual
# default is 128.
parser.add_argument('--n_batch_train', '--nbt', default=128, type=int,
                    help='train mini-batch size (default: 128)')
parser.add_argument('--n_batch_test', default=100, type=int,
                    help='test mini-batch size (default: 100)')

# --- optimisation ---------------------------------------------------------
parser.add_argument('--optim_choice', default="sgd", type=str,
                    help='choice of optimizer')
parser.add_argument('--lr', '--learning-rate', default=0.0001, type=float,
                    help='initial learning rate')
parser.add_argument('--d_model', default=128, type=int,
                    help='choice of architecture by dmodel')

parser.add_argument('--momentum', '--m', default=0.9, type=float, help='momentum')
parser.add_argument('--weight_decay', '--wd', default=5e-4, type=float,
                    help='weight decay (default: 5e-4)')

parser.add_argument('--lr_sched', choices=lr_scheds, default='wr_default',
                    help=' | '.join(lr_scheds))
parser.add_argument('--seed', '-s', default=0, type=int,
                    help='seed (default: 0)')

# --- feature selection ----------------------------------------------------
# Defaults are real ints (they used to be the strings "1"), so that
# parser.get_default() agrees with the declared type.
parser.add_argument('--n_concat', default=1, type=int,
                    help='concat number')
parser.add_argument('--index', default=1, type=int,
                    help='index number')




class linear_classifier(nn.Module):
    """Single linear layer mapping concatenated feature vectors to class logits.

    Args:
        concat: number of feature blocks concatenated in each input vector.
        n_features: width of a single feature block.
        n_classes: number of output logits. Defaults to 3, the value that was
            previously hard-coded.
    """

    def __init__(self, concat, n_features, n_classes=3):
        super().__init__()
        # The input width is the total width of all concatenated blocks.
        self.classifier = nn.Linear(concat * n_features, n_classes)

    def forward(self, x):
        """Return the logits produced by the linear layer for input ``x``."""
        return self.classifier(x)



class Dataset_PreTrainedFeatures(Dataset):
    """Dataset of pre-extracted feature vectors stored in header-less CSV files.

    Every CSV is read with ``header=None``; column 0 is used as the label and
    the remaining columns as features. The feature columns of all files are
    concatenated side by side, so sample ``i`` is row ``i`` of every file.
    Labels are taken from the last file in ``file_paths``.

    Args:
        file_paths: iterable of CSV paths; must contain at least one path and
            all files must have the same number of rows.

    Raises:
        ValueError: if ``file_paths`` is empty or the files differ in row count.
    """

    def __init__(self, file_paths):
        file_paths = list(file_paths)
        if not file_paths:
            raise ValueError("file_paths must contain at least one CSV path")

        fets = []
        n_rows = None
        for fpath in file_paths:
            df = pd.read_csv(fpath, header=None, index_col=False)
            if n_rows is None:
                n_rows = len(df)
            elif len(df) != n_rows:
                # pd.concat(axis=1) would silently pad the shorter file with
                # NaN rows, which then flow into training as features.
                raise ValueError(
                    "{} has {} rows, expected {}".format(fpath, len(df), n_rows))
            fets.append(df.iloc[:, 1:])
            self.target = df.iloc[:, 0]

        self.data = pd.concat(fets, axis=1, ignore_index=True)

        # Convert once up front: a pandas row lookup plus dtype cast on every
        # __getitem__ call is far slower than indexing a NumPy array.
        self._features = self.data.to_numpy(dtype=np.float32)
        self._labels = self.target.to_numpy()

        # shape[1] also works when the files contain zero rows.
        print("Number of features: {val}".format(val=self.data.shape[1]))
        print("Data size: {val}".format(val=len(self.data)))

    def __len__(self):
        """Return the number of samples (rows)."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(features, label)`` for row ``index``.

        ``features`` is a 1-D float32 NumPy array (a fresh copy, so callers may
        modify it freely) and ``label`` is a Python int.
        """
        features = self._features[index].copy()
        label = self._labels[index]

        return features, int(label)



def main():
    """Train and evaluate the linear classifier on pre-extracted features.

    Parses the command line, builds the training/test datasets from the CSV
    files under ``../features/<d_model>/<dir>/``, runs the epoch loop through
    the ``train_util`` helpers, prints a summary and writes the last training
    loss and validation accuracy to ``accuracy_<lr>_<weight_decay>`` inside the
    feature folder.
    """
    args = parser.parse_args()
    for arg in vars(args):
        print(arg, " : ", getattr(args, arg))

    # set the seeds
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # prepare the data
    print("=> preparing the data for model {}".format(args.d_model))

    data_folder_path = "../features/" + str(args.d_model) + "/"
    data_paths_training = []
    data_paths_test = []

    # Feature directories are numbered; run `index` uses the n_concat
    # consecutive directories following 30 + (index - 1) * n_concat.
    first_dir = 30 + (args.index - 1) * args.n_concat
    for i in range(1, args.n_concat + 1):
        feature_dir = data_folder_path + str(first_dir + i) + "/"
        data_paths_training.append(feature_dir + 'training.csv')
        print(feature_dir + 'training.csv')
        data_paths_test.append(feature_dir + 'test.csv')

    print("####################################")
    print("Training Data info")
    train_dataset = Dataset_PreTrainedFeatures(data_paths_training)

    print("####################################")
    print("Test Data info")
    eval_dataset = Dataset_PreTrainedFeatures(data_paths_test)

    train_loader = DataLoader(
        train_dataset, shuffle=True, batch_size=args.n_batch_train, num_workers=4)

    val_loader = DataLoader(
        eval_dataset, shuffle=False, batch_size=args.n_batch_test, num_workers=4)

    # create the model
    model = linear_classifier(args.n_concat, args.d_model)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)

    print("Number of Parameters")
    print(sum(p.numel() for p in model.parameters() if p.requires_grad))

    cudnn.benchmark = True

    # Follow the selected device instead of calling .cuda() unconditionally,
    # so the criterion is placed consistently with the model.
    criterion = nn.CrossEntropyLoss(reduction='mean').to(device)

    optim_hparams = {
        'initial_lr': args.lr,
        'momentum': args.momentum,
        'weight_decay': args.weight_decay,
    }

    lr_hparams = {
        'initial_lr': args.lr,
        'lr_sched': args.lr_sched,
        'frst_ann': args.frst_ann,
        'snd_ann': args.snd_ann,
    }

    optimizer = train_util.create_optimizer(model, args.optim_choice,
                                            optim_hparams)

    test_tab = []
    train_tab = []

    # Pre-bind so the summary below still works when the epoch loop runs zero
    # times (start_epoch >= epochs).
    train_loss = None
    val_acc = None

    for epoch in range(args.start_epoch, args.epochs):
        train_util.adjust_lr(optimizer, args.optim_choice,
                             epoch + 1,
                             lr_hparams)

        print("Epoch" + str(epoch))
        for param_group in optimizer.param_groups:
            print("LR: " + str(param_group['lr']))
            # .get: not every optimizer exposes a 'momentum' entry.
            print("mom: " + str(param_group.get('momentum')))

        train_loss = train_util.train_loop(
            train_loader,
            model,
            criterion,
            args.optim_choice,
            optimizer,
            epoch,
            device)

        val_acc = train_util.validate(
            val_loader,
            model,
            criterion,
            epoch,
            device)

        train_tab.append(train_loss)
        test_tab.append(val_acc)

    # NOTE(review): a batch size of exactly 1024 is labelled full-batch "GD";
    # presumably 1024 is the training-set size -- confirm.
    if args.n_batch_train == 1024:
        OPT_1 = "GD"
    else:
        OPT_1 = "SGD"

    # As before, the last parameter group decides whether the "M" suffix is
    # added; OPT now has a value even when there are no parameter groups.
    OPT = OPT_1
    for param_group in optimizer.param_groups:
        if param_group.get('momentum', 0) != 0:
            OPT = OPT_1 + "M"
        else:
            OPT = OPT_1

    print("\n")
    print("Final accuracy: {}".format(val_acc))
    print("Seed: {}".format(args.seed))
    print("Dataset: {}".format(data_folder_path))
    # The parser defines no --arch option (args.arch raised AttributeError
    # here, before the results file was written); the architecture is
    # selected through --d_model.
    print("Architecture: {}".format(args.d_model))
    print("Optimization algorithm: {}".format(OPT))
    print("LR: {}; B: {}; M: {}; 1st anneal: {}; 2nd anneal: {}; WD: {}".format(
        args.lr, args.n_batch_train, args.momentum, args.frst_ann,
        args.snd_ann, args.weight_decay))

    # Persist the last epoch's training loss and validation accuracy.
    with open(data_folder_path + "/accuracy_{}_{}".format(args.lr, args.weight_decay), 'w') as f:
        f.write("%s %s" % (train_loss, val_acc))

# Run training only when executed as a script. Without this guard, importing
# the module starts a training run, and DataLoader worker processes started
# with the "spawn" method re-import the main module and would re-enter main().
if __name__ == "__main__":
    main()
