import torch
import numpy as np
import torch.nn as nn
from torchvision import transforms
from torchvision.datasets import MNIST
from models import BaseModel
from data import DenseDatasetSelected, data_split, get_x, get_y
from sklearn.metrics import accuracy_score
from sklearn.metrics import roc_auc_score
import pickle
import pandas as pd

from mlxtend.feature_selection import SequentialFeatureSelector as SFS
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler


# Load the full dataset (every feature) and build the train/val/test splits.
# Feature matrices are wrapped in DataFrames whose columns are the dataset's
# feature names; labels are kept exactly as returned by get_y.
data_name = 'prehospital_resuscitation'
data_dir = f'../datasets/{data_name}.csv'  # path to the CSV file itself
dataset = DenseDatasetSelected(data_dir)
print(dataset.X.shape)
train_dataset, val_dataset, test_dataset = data_split(dataset, random_state=0)
feature_list = dataset.features


def _as_frame(split):
    """Return the split's feature matrix as a DataFrame with named columns."""
    return pd.DataFrame(get_x(split), columns=feature_list)


X_train, y_train = _as_frame(train_dataset), get_y(train_dataset)
X_val, y_val = _as_frame(val_dataset), get_y(val_dataset)
X_test, y_test = _as_frame(test_dataset), get_y(test_dataset)

# True when the target has more than two distinct classes (multiclass rather
# than binary); used later to pick the AUROC variant.
multi_label = np.unique(dataset.Y).size > 2

# Forward sequential feature selection driven by a logistic regression on
# standardized training features (scored with one-vs-one ROC AUC on the
# training split only). Running up to k_features = all columns records the
# chosen subset for every size, keyed by subset size in sfs.subsets_; the
# 'feature_idx' entries are column positions into X_train / dataset.features.
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
# NOTE(review): `multi_class` is deprecated in recent scikit-learn releases;
# it is kept here so the fitted selector stays unchanged.
model = LogisticRegression(multi_class='multinomial')
sfs = SFS(model, k_features=X_train.shape[1], forward=True, verbose=1,
          scoring='roc_auc_ovo', n_jobs=1)
sfs.fit(X_train_scaled, y_train)

# Use a context manager so the pickle is flushed and the handle closed
# deterministically (the original left the file object open).
with open('../UCI_datasets/results/' + data_name + '_forward_LR_selected_features.pkl', 'wb') as fh:
    pickle.dump(sfs.subsets_, fh)

# Subset sizes to evaluate, ascending: every size 1..10, then every 5th size
# from 15 up to (but excluding) the total, then the full feature set.
# Sizes are capped at the number of available features and de-duplicated.
# The previous hard-coded list asked for sizes up to 10 even when the dataset
# had fewer features (KeyError on sfs.subsets_), and repeated 10 when there
# were exactly 10 features.
n_total_features = dataset.X.shape[1]
candidate_sizes = set(range(1, 11)) | set(range(15, n_total_features, 5)) | {n_total_features}
feature_num_list = sorted(k for k in candidate_sizes if k <= n_total_features)
print(feature_num_list)

# For every subset size in feature_num_list, reload the data restricted to the
# features forward selection chose for that size, train a 3-layer MLP, and
# record its test AUROC and accuracy keyed by the subset size.
res_dict_auroc = {}
res_dict_acc = {}
features_input = {}

for feature_num in feature_num_list:
    print('# features: ', feature_num)

    # Column positions (into feature_list) picked by SFS for this size.
    chosen_idx = list(sfs.subsets_[feature_num]['feature_idx'])
    chosen_names = np.array(feature_list)[chosen_idx]
    dataset = DenseDatasetSelected(data_dir, chosen_names)
    features_input[feature_num] = chosen_idx
    print(dataset.X.shape)
    train_dataset, val_dataset, test_dataset = data_split(dataset, random_state=0)

    n_in = dataset.X.shape[1]
    n_classes = len(np.unique(np.array(dataset.Y)))
    net = nn.Sequential(
        nn.Linear(n_in, 2 * n_in), nn.ReLU(), nn.Dropout(0.3),
        nn.Linear(2 * n_in, n_in), nn.ReLU(), nn.Dropout(0.3),
        nn.Linear(n_in, n_classes),
    )
    trainer = BaseModel(net)
    criterion = nn.CrossEntropyLoss()
    trainer.fit(
        train_dataset,
        val_dataset,
        lr=0.001,
        nepochs=1000,
        loss_fn=criterion,
        mbsize=256,
        verbose=False,
    )
    _, logits, labels = trainer.evaluate(test_dataset, criterion, 1024)

    # Softmax turns the model outputs into class probabilities for AUROC;
    # accuracy uses the argmax of the raw outputs.
    labels_np = labels.cpu().numpy()
    probs = nn.functional.softmax(logits, dim=1).cpu().numpy()
    if multi_label:
        test_auroc = roc_auc_score(labels_np, probs, average='macro', multi_class='ovo')
    else:
        # Binary case: score with the probability of class 1.
        test_auroc = roc_auc_score(labels_np, probs[:, 1])
    test_acc = accuracy_score(labels_np, logits.cpu().numpy().argmax(axis=1))

    print('AUROC: ', test_auroc)
    print('Accuracy: ', test_acc)
    res_dict_auroc[feature_num] = test_auroc
    res_dict_acc[feature_num] = test_acc

# Persist the selected feature indices per subset size and the test metrics.
# Each file is written inside a `with` block so it is closed and flushed even
# if a later dump fails (the original never closed its file handles).
result_dict = {
    'res_dict_auroc': res_dict_auroc,
    'res_dict_acc': res_dict_acc,
}

results_prefix = '../UCI_datasets/results/' + data_name + '_forward_nn'
with open(results_prefix + '_selected_features.pkl', 'wb') as fh:
    pickle.dump(features_input, fh)
with open(results_prefix + '_results.pkl', 'wb') as fh:
    pickle.dump(result_dict, fh)