# Libraries
import argparse
import json
import os
import sys

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.linear_model import LogisticRegressionCV

# Custom Imports
sys.path.append('./utils')
from train_utils import *

# Arguments
parser = argparse.ArgumentParser()
parser.add_argument('dataset_name', type=str)
# Optional seed for the training-set shuffle and the saga solver. The default
# (None) keeps the original unseeded behaviour.
parser.add_argument('--seed', type=int, default=None)

args = parser.parse_args()

# Paths (relative to the current working directory)
save_data_path_train = '../data/' + args.dataset_name + '_TRAIN/'
save_data_path_test = '../data/' + args.dataset_name + '_TEST/'
save_model_path = '../results/trained_models'
save_results_path = '../results/results_csv/'
save_final_results_path = '../results/final_results_csv'

# Load y. map_location='cpu' lets tensors saved from a GPU run still be
# converted with .numpy() below.
y_train = torch.load(save_data_path_train + 'y_tensor.pt', map_location='cpu').to(torch.float32)
y_test = torch.load(save_data_path_test + 'y_tensor.pt', map_location='cpu').to(torch.float32)

# find_best_accu_eval comes from utils/train_utils. Only the selected model
# version is used by this script.
_test_for_max_eval, version = find_best_accu_eval(args.dataset_name, save_results_path)

# Load ngrams for the selected version
model_suffix = f'_{args.dataset_name}_{version}.pt'
ngrams_low = torch.load(save_model_path + '/ngrams_bot' + model_suffix, map_location='cpu')
ngrams_mid = torch.load(save_model_path + '/ngrams_mid' + model_suffix, map_location='cpu')

# Concatenate both feature sets column-wise. detach()/cpu() ensure .numpy()
# works even if the saved tensors carry autograd history.
ngrams = torch.cat([ngrams_low, ngrams_mid], dim=1).detach().cpu().numpy()

n_train, n_test = len(y_train), len(y_test)
if len(ngrams) != n_train + n_test:
    raise ValueError(
        f'Expected {n_train + n_test} ngram rows '
        f'({n_train} train + {n_test} test), got {len(ngrams)}.'
    )

# Split train/test.
# NOTE(review): rows are assumed to be stacked train-then-test. Confirm
# against the script that produced the ngram tensors.
ngrams_train = ngrams[:n_train]
ngrams_test = ngrams[n_train:]

# Shuffle the training set. LogisticRegressionCV's default CV splitter does
# not shuffle, so row order decides the folds. Test order has no effect on
# accuracy, so the test set is not permuted.
rng = np.random.default_rng(args.seed)
permutation_train = rng.permutation(n_train)

# Assign X_train and X_test
X_train = ngrams_train[permutation_train]
X_test = ngrams_test

# Flatten labels to 1-D. A (n, 1) label array would otherwise broadcast the
# prediction comparison below into an (n, n) matrix.
y_real_train = y_train.numpy()[permutation_train].ravel()
y_real_test = y_test.numpy().ravel()


# Classification
list_l1_ratio = [0., 0.10, 0.25, 0.50, 0.75, 0.90, 0.98]
list_geomspace_C = np.geomspace(1e-4, 1e6, num=100, endpoint=False)

# NOTE(review): `multi_class` is deprecated in recent scikit-learn releases.
# Dropping it would switch multiclass problems to the multinomial loss and
# change results, so it is kept here. Revisit when upgrading scikit-learn.
clf = LogisticRegressionCV(max_iter=1000000,
                           penalty='elasticnet',
                           multi_class='ovr',
                           solver='saga',
                           l1_ratios=list_l1_ratio,
                           Cs=list_geomspace_C,
                           n_jobs=-1,
                           random_state=args.seed).fit(X_train, y_real_train)

learner = clf.predict(X_train)
prediction = clf.predict(X_test)

# Accuracy = fraction of correctly predicted labels.
Accuracy_Train = np.mean(learner == y_real_train)
Accuracy_Test = np.mean(prediction == y_real_test)

df_results = pd.DataFrame({'Accuracy_Train': [Accuracy_Train],
                           'Accuracy_Test': [Accuracy_Test]})
# Create the output directory if missing so a long fit is not lost at save time.
os.makedirs(save_final_results_path, exist_ok=True)
df_results.to_csv(save_final_results_path + '/results_' + args.dataset_name + '.csv')


