from sys import argv 
import pandas as pd
import argparse
import json
import logging
import math
import os

import sys
sys.path.insert(1, '..')

import random
from pathlib import Path

import datasets
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

import torch 
import torch.backends.cudnn as cudnn
import torch.optim
import torch.nn as nn
import train_util

from torch.utils.data import DataLoader, Dataset


import evaluate
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from huggingface_hub import Repository
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    PretrainedConfig,
    SchedulerType,
    default_data_collator,
    get_scheduler,
)
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
from transformers.utils.versions import require_version

from bert2 import Bert

logger = get_logger(__name__)


from sklearn.linear_model import RidgeCV
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn import metrics


import time


# Learning-rate schedule names accepted by --lr_sched; the chosen one is passed
# to train_util.adjust_lr (through lr_hparams) when --optim_choice is "sgd".
lr_scheds= ['wr_default']

# NOTE: --save, --save_model, --model_random and --epsilon_only are string-valued
# flags ("True"/"False"); main() compares them against the literal string "True".
parser = argparse.ArgumentParser(description='Linear Classifier Training')
parser.add_argument('--epochs', default=40, type=int,
                    help='number of total epochs to run')
parser.add_argument('--start_epoch', default=0, type=int,
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--frst_ann', '--fa', default=170, type=int,
                    help='first annealing time')
parser.add_argument('--snd_ann', '--sa', default=245, type=int,
                    help='second annealing time')
# Help text previously claimed a default of 1024, which did not match default=128.
parser.add_argument('--n_batch_train', '--nbt',  default=128, type=int,
                    help='train mini-batch size (default: 128)')
parser.add_argument('--n_batch_test', default=100, type=int,
                    help='test mini-batch size (default: 100)')
parser.add_argument('--path_data', default='./data', type=str,
                    help='path to store data')
parser.add_argument('--optim_choice',   default="sgd", type=str,
                    help='choice of optimizer', choices=["sgd","AdamW"])
parser.add_argument('--lr', '--learning-rate', default=0.0001, type=float,
                    help='initial learning rate')
parser.add_argument('--arch', default="128", type=str,
                    help='choice of architecture')
parser.add_argument('--save', default="False", type=str,
                    help='Save or not')
parser.add_argument('--momentum', '--m', default=0.9, type=float, help='momentum')
parser.add_argument('--weight_decay', '--wd', default=5e-4, type=float,
                    help='weight decay (default: 5e-4)')
parser.add_argument('--print_freq', '-p', default=10, type=int,
                    help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str,
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--name', default='MNLI/linear-classfier', type=str,
                    help='name of experiment')
parser.add_argument(
        "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument('--lr_sched', choices=lr_scheds, default='wr_default', 
                    help=' | '.join(lr_scheds))
parser.add_argument('--seed', '-s', default=0, type=int,
                    help='seed (default: 0)')
parser.add_argument('--save_model', default="False", type=str,
                    help='Save Model')
parser.add_argument('--model_random', default="False", type=str,
                    help='Model Random')

parser.add_argument('--index', default=1, type=int,
                    help='index number')

# When set, main() also loads predicted "epsilon" features and uses this value
# both as a path component and as an n_concat key (so it must be one of "128",
# "64", "32", "16").
parser.add_argument('--epsilon_model', default=None, type=str,
                    help='epsilon features')

parser.add_argument('--epsilon_only', default="False", type= str,
                    help='epsilon only or not')



# Number of feature folders concatenated for each --arch value. main() reads
# n_concat[arch] consecutive folders and sizes the classifier input as
# int(arch) * n_concat[arch].
n_concat={ "128":1,"64":4, "32":15, "16":59}


class linear_classifier(nn.Module):
    """A single affine layer mapping concatenated feature vectors to class logits.

    The input width is ``n_arch_features + n_epsilon_features``, i.e. the
    backbone feature columns plus any appended epsilon feature columns.
    """

    def __init__(self, n_arch_features=0, n_epsilon_features=0, n_classes=3):
        """Build the linear head.

        Args:
            n_arch_features: number of backbone feature inputs.
            n_epsilon_features: number of epsilon feature inputs.
            n_classes: number of output logits. Defaults to 3, the width this
                head always used (presumably the 3 MNLI labels — confirm).
        """
        super(linear_classifier, self).__init__()
        print(n_epsilon_features)
        n_features = n_arch_features + n_epsilon_features
        print("n_features")
        print(n_features)
        self.classifier = nn.Linear(n_features, n_classes)

    def forward(self, x):
        """Return logits of shape (batch, n_classes) for inputs of shape (batch, n_features)."""
        return self.classifier(x)



class Dataset_PreTrainedFeatures(Dataset):
    """Map-style dataset over pre-extracted, headerless feature CSVs.

    Column 0 of the first feature CSV is used as the label. The remaining
    columns of every feature CSV are concatenated horizontally; column 0 of the
    2nd, 3rd, ... feature CSVs is dropped. When ``file_path_epsilon`` is given,
    every column of that CSV is appended after the features. With
    ``only_eps=True`` (honored only together with ``file_path_epsilon``), the
    backbone features are dropped and each row is ``[label, epsilon...]``.

    The result is stored in ``self.data`` (a DataFrame with columns renumbered
    0..n-1). NOTE(review): rows of all CSVs are assumed to be in the same
    sample order — confirm with the feature extraction step.
    """

    def __init__(self, file_paths_feature, file_path_epsilon=None, only_eps=False):
        """Load and concatenate the CSVs.

        Args:
            file_paths_feature: list of feature CSV paths; the first one
                supplies the label column.
            file_path_epsilon: optional path of an epsilon-feature CSV.
            only_eps: if true (and ``file_path_epsilon`` is set), keep only the
                label column from the feature CSVs.
        """
        if file_path_epsilon:
            print(file_paths_feature)
            print(file_path_epsilon)
            if only_eps:
                print("only eps")
                # Only the label column of the first file is needed, so read
                # just that column instead of loading every feature file and
                # then discarding it.
                features = [self._read_csv(file_paths_feature[0], usecols=[0]).iloc[:, 0]]
            else:
                features = self._read_feature_csvs(file_paths_feature)
            features.append(self._read_csv(file_path_epsilon))
        else:
            print(file_paths_feature)
            features = self._read_feature_csvs(file_paths_feature)
        self.data = pd.concat(features, axis=1, ignore_index=True)
        # Column 0 is the label, so the feature count is one less than the column count.
        print("number of features: {val}".format(val=self.data.shape[1] - 1))

    @staticmethod
    def _read_csv(path, **kwargs):
        """Read a headerless CSV with default integer column labels."""
        return pd.read_csv(path, header=None, index_col=False, **kwargs)

    @classmethod
    def _read_feature_csvs(cls, paths):
        """Read feature CSVs, keeping column 0 (the label) only from the first one."""
        features = []
        for i, path in enumerate(paths):
            frame = cls._read_csv(path)
            features.append(frame if i == 0 else frame.iloc[:, 1:])
        return features

    def __len__(self):
        """Return the number of rows (samples)."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(features, label)`` for row ``index``.

        ``features`` is a float32 numpy array of every column after the label;
        ``label`` is column 0 converted to ``int``.
        """
        features = self.data.iloc[index, 1:].astype(np.float32).values
        label = self.data.iloc[index, 0]
        return features, int(label)


def add_def(arr, s):
    """Return a new list where ``s`` is appended to each element of ``arr``.

    main() uses this to turn feature folder paths into CSV paths, e.g.
    ``add_def(["f/"], "test.csv") == ["f/test.csv"]``.
    """
    return [prefix + s for prefix in arr]


def _feature_folder_paths(args, curr_dir):
    """Return the pre-extracted feature folders to load for ``args.arch``.

    Arch "128" uses the single folder ``../features/128/<index>/``. Other archs
    concatenate ``n_concat[arch]`` consecutive folders numbered from 31, offset
    by ``(index - 1) * n_concat[arch]``.
    """
    base = curr_dir + "/../features/" + args.arch + "/"
    if args.arch == "128":
        return [base + str(args.index) + "/"]
    n_folders = n_concat[args.arch]
    offset = (args.index - 1) * n_folders
    return [base + str(i + offset) + "/" for i in range(30 + 1, 30 + 1 + n_folders)]


def _epsilon_folder_path(args, curr_dir):
    """Return the path prefix of the epsilon-feature CSVs, or None if unused.

    The caller appends ``_training.csv`` / ``_test.csv``. The layout differs by
    arch: "128" uses ``predicted_features/128/<epsilon_model>/under_epsilon_<index>``,
    other archs use ``predicted_features/<epsilon_model>/<arch>/over_epsilon_<index>``.
    """
    if not args.epsilon_model:
        return None
    if args.arch == "128":
        return (curr_dir + "/predicted_features/" + args.arch + "/"
                + args.epsilon_model + "/under_epsilon_" + str(args.index))
    return (curr_dir + "/predicted_features/" + args.epsilon_model + "/"
            + args.arch + "/over_epsilon_" + str(args.index))


def _build_datasets(args, feature_folder_paths, epsilon_folder_path):
    """Build ``(trainset, testset)`` from the feature folders and optional epsilon CSVs."""
    train_files = add_def(feature_folder_paths, 'training.csv')
    test_files = add_def(feature_folder_paths, 'test.csv')
    if args.epsilon_only == "True":
        if epsilon_folder_path is None:
            # Previously this fell through to `None + '_training.csv'` (TypeError).
            parser.error("--epsilon_only True requires --epsilon_model")
        print("only eps")
        print(epsilon_folder_path)
        trainset = Dataset_PreTrainedFeatures(train_files, epsilon_folder_path + '_training.csv', True)
        testset = Dataset_PreTrainedFeatures(test_files, epsilon_folder_path + '_test.csv', True)
    elif epsilon_folder_path:
        print("features and eps")
        print(epsilon_folder_path)
        print(feature_folder_paths)
        trainset = Dataset_PreTrainedFeatures(train_files, epsilon_folder_path + '_training.csv')
        testset = Dataset_PreTrainedFeatures(test_files, epsilon_folder_path + '_test.csv')
    else:
        print("only feats")
        print(feature_folder_paths)
        trainset = Dataset_PreTrainedFeatures(train_files)
        testset = Dataset_PreTrainedFeatures(test_files)
    return trainset, testset


def _build_optimizer(args, model):
    """Create the optimizer and, for AdamW, a cosine LR scheduler.

    Returns ``(optimizer, lr_scheduler, lr_hparams)``. For "sgd" the scheduler
    is None; main() instead calls train_util.adjust_lr each epoch with
    ``lr_hparams``. For "AdamW", ``lr_hparams`` is None.
    """
    if args.optim_choice == "sgd":
        optim_hparams = {
            'initial_lr': args.lr,
            'momentum': args.momentum,
            'weight_decay': args.weight_decay,
        }
        lr_hparams = {
            'initial_lr': args.lr,
            'lr_sched': args.lr_sched,
            'frst_ann': args.frst_ann,
            'snd_ann': args.snd_ann,
        }
        optimizer = train_util.create_optimizer(model, args.optim_choice, optim_hparams)
        return optimizer, None, lr_hparams

    # Split weights in two groups, one with weight decay and the other not.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.lr)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=args.epochs)
    return optimizer, lr_scheduler, None


def _optimizer_label(args, optimizer):
    """Return the short optimizer label printed in the final summary."""
    if args.optim_choice != "sgd":
        # AdamW param groups have no 'momentum' key (indexing it raised KeyError),
        # and the GD/SGD(+M) naming does not apply to it.
        return args.optim_choice
    label = "GD" if args.n_batch_train == 1024 else "SGD"
    # As before, the label follows the momentum of the last param group.
    momentum = optimizer.param_groups[-1].get('momentum', 0)
    return label + "M" if momentum != 0 else label


def main():
    """Train a linear classifier on pre-extracted features and report results.

    Parses the CLI arguments, loads the train/test feature CSVs, and trains for
    epochs ``[start_epoch, epochs)``, validating after each one. It then prints
    a summary and writes ``"<train_loss> <val_acc>"`` of the last epoch to
    ``<path_data>/accuracy_<lr>_<weight_decay>``.
    """
    args = parser.parse_args()
    for arg in vars(args):
        print(arg, " : ", getattr(args, arg))

    # Seed the RNGs this script can draw from.
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Prepare the data.
    print("=> creating model '{}'".format(args.arch))
    curr_dir = os.getcwd()
    feature_folder_paths = _feature_folder_paths(args, curr_dir)
    epsilon_folder_path = _epsilon_folder_path(args, curr_dir)
    trainset, testset = _build_datasets(args, feature_folder_paths, epsilon_folder_path)

    train_loader = torch.utils.data.DataLoader(
        trainset, batch_size=args.n_batch_train, shuffle=True, num_workers=4)
    val_loader = torch.utils.data.DataLoader(
        testset, batch_size=args.n_batch_test, shuffle=False, num_workers=4)

    # With --epsilon_only True the backbone features are not loaded, so they
    # contribute no inputs to the classifier.
    feat_cons = 0 if args.epsilon_only == "True" else 1
    if args.epsilon_model:
        model = linear_classifier(
            n_arch_features=int(args.arch) * n_concat[args.arch] * feat_cons,
            n_epsilon_features=int(args.epsilon_model) * n_concat[args.epsilon_model])
    else:
        model = linear_classifier(n_arch_features=int(args.arch) * n_concat[args.arch])

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)
    cudnn.benchmark = True

    # Place the loss on the same device as the model; the previous .cuda() call
    # crashed on CPU-only machines.
    criterion = nn.CrossEntropyLoss(reduction='mean').to(device)

    optimizer, lr_scheduler, lr_hparams = _build_optimizer(args, model)

    test_tab = []
    train_tab = []
    # Bound up front so the summary does not hit a NameError when zero epochs run.
    train_loss = None
    val_acc = None

    for epoch in range(args.start_epoch, args.epochs):
        print("Epoch" + str(epoch))
        if args.optim_choice == "sgd":
            train_util.adjust_lr(optimizer, args.optim_choice, epoch + 1, lr_hparams)
            for param_group in optimizer.param_groups:
                print("LR: " + str(param_group['lr']))
                print("mom: " + str(param_group['momentum']))

        train_loss = train_util.train_loop(
            train_loader,
            model,
            criterion,
            args.optim_choice,
            optimizer,
            epoch,
            device,
            lr_scheduler)

        val_acc = train_util.validate(
            val_loader,
            model,
            criterion,
            epoch,
            device)

        if args.optim_choice == "AdamW":
            if not lr_scheduler:
                raise Exception("AdamW requires defining a lr scheduler")
            lr_scheduler.step()

        train_tab.append(train_loss)
        test_tab.append(val_acc)

    OPT = _optimizer_label(args, optimizer)

    # NOTE(review): `data_folder_path` was never defined, so the script raised a
    # NameError after training and never wrote its results. --path_data
    # ("path to store data") is otherwise unused and is used here as the output
    # directory; confirm this is the intended location.
    data_folder_path = args.path_data
    os.makedirs(data_folder_path, exist_ok=True)

    print("\n")
    print("Final accuracy: {}".format(val_acc))
    print("Seed: {}".format(args.seed))
    print("Dataset: {}".format(data_folder_path))
    print("Architecture: {}".format(args.arch))
    print("Optimization algorithm: {}".format(OPT))
    print("LR: {}; B: {}; M: {}; 1st anneal: {}; 2nd anneal: {}; WD: {}".format(
        args.lr, args.n_batch_train, args.momentum, args.frst_ann, args.snd_ann,
        args.weight_decay))

    result_path = os.path.join(
        data_folder_path, "accuracy_{}_{}".format(args.lr, args.weight_decay))
    with open(result_path, 'w') as f:
        f.write("%s %s" % (train_loss, val_acc))

# Run only when executed as a script, so importing this module does not parse
# sys.argv and start training as a side effect.
if __name__ == "__main__":
    main()
