import torch
import numpy as np
import torch.nn as nn
from torchvision import transforms
from torchvision.datasets import MNIST
from models import BaseModel
from data import DenseDatasetSelected, data_split, get_x, get_y
from sklearn.metrics import accuracy_score
from sklearn.metrics import roc_auc_score
import pickle
import sage

# Experiment configuration: dataset name and the CSV file it is read from.
data_name = 'spam'
data_dir = f'../datasets/{data_name}.csv'

# No feature list is passed here; the per-subset runs further down pass one.
dataset = DenseDatasetSelected(data_dir)

# Network input width is the number of columns in X; output width is the
# number of distinct label values found in Y.
d_in = dataset.X.shape[1]
d_out = len(np.unique(np.array(dataset.Y)))

# More than two distinct labels selects the multi-class AUROC call later on.
multi_label = d_out > 2

# Baseline run: train an MLP on the full dataset, then rank its features by
# the magnitude of their SAGE values.
train_dataset, val_dataset, test_dataset = data_split(dataset, random_state=0)

# d_in -> 2*d_in -> d_in -> d_out, with ReLU and dropout after each hidden layer.
layers = [
    nn.Linear(d_in, d_in * 2),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(d_in * 2, d_in),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(d_in, d_out),
]
model = nn.Sequential(*layers)
basemodel = BaseModel(model)
loss_fn = nn.CrossEntropyLoss()
fit_options = dict(lr=0.001, nepochs=1000, loss_fn=loss_fn, mbsize=256, verbose=False)
basemodel.fit(train_dataset, val_dataset, **fit_options)

# The three outputs are reassigned inside the feature-count loop before they
# are read anywhere; the call is kept so the script's flow is unchanged.
test_loss, y_pre, y_true = basemodel.evaluate(test_dataset, loss_fn, 1024)

# The trained network emits logits; a softmax is appended so the model handed
# to SAGE outputs class probabilities for the 'cross entropy' loss.
model_activation = nn.Sequential(basemodel.model, nn.Softmax(dim=1))

# The first 128 training inputs are the data sample given to the imputer.
background = get_x(train_dataset)[:128]
imputer = sage.MarginalImputer(model_activation, background)
estimator = sage.PermutationEstimator(imputer, 'cross entropy')
sage_values = estimator(get_x(test_dataset), get_y(test_dataset))

# Sort feature names by |SAGE value|, largest first.
descending_by_magnitude = np.argsort(-np.abs(sage_values.values), axis=0)
ranked_features = np.array(dataset.features)[descending_by_magnitude].tolist()

# Schedule of feature-subset sizes to evaluate, in ascending order:
# 1..10, then every multiple of 5 from 15 up to (but excluding) the total
# number of features, and finally the total itself.
#
# The schedule is de-duplicated and clipped to [1, n_features]. Without that,
# a dataset with 10 or fewer features would get counts larger than the number
# of available features plus a repeated entry for the total, which retrains
# the same model several times and stores results under misleading keys.
# For datasets with more than 10 features the list is unchanged.
n_features = dataset.X.shape[1]
candidate_counts = set(range(1, 11)) | set(range(15, n_features, 5)) | {n_features}
feature_num_list = sorted(k for k in candidate_counts if 1 <= k <= n_features)
print(feature_num_list)

# Results keyed by the number of top-ranked features used:
#   res_dict_auroc  -> test AUROC
#   res_dict_acc    -> test accuracy
#   features_input  -> the feature names that were selected
res_dict_auroc, res_dict_acc, features_input = {}, {}, {}

for feature_num in feature_num_list:
    print('# features: ', feature_num)

    # Reload the data restricted to the top `feature_num` ranked features and
    # split it with the same random_state as the baseline run.
    features_input[feature_num] = ranked_features[:feature_num]
    dataset = DenseDatasetSelected(data_dir, ranked_features[:feature_num])
    print(dataset.X.shape)
    train_dataset, val_dataset, test_dataset = data_split(dataset, random_state=0)

    # Fresh MLP sized to the reduced input:
    # d_in -> 2*d_in -> d_in -> d_out, ReLU + dropout after each hidden layer.
    d_in = dataset.X.shape[1]
    d_out = len(np.unique(np.array(dataset.Y)))
    layers = [
        nn.Linear(d_in, d_in * 2),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(d_in * 2, d_in),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(d_in, d_out),
    ]
    model = nn.Sequential(*layers)
    basemodel = BaseModel(model)
    loss_fn = nn.CrossEntropyLoss()
    fit_options = dict(lr=0.001, nepochs=1000, loss_fn=loss_fn, mbsize=256, verbose=False)
    basemodel.fit(train_dataset, val_dataset, **fit_options)

    test_loss, y_pre, y_true = basemodel.evaluate(test_dataset, loss_fn, 1024)

    # Convert once: true labels, and class probabilities from the raw outputs.
    labels = y_true.cpu().numpy()
    probs = nn.functional.softmax(y_pre, dim=1).cpu().numpy()

    if not multi_label:
        # Two classes: score with the probability of class index 1.
        test_auroc = roc_auc_score(labels, probs[:, 1])
    else:
        # More than two classes: macro-averaged one-vs-one AUROC.
        test_auroc = roc_auc_score(labels, probs, average='macro', multi_class='ovo')

    # Predicted class is the argmax over the raw model outputs.
    predicted = y_pre.cpu().numpy().argmax(axis=1)
    test_acc = accuracy_score(labels, predicted)

    print('AUROC: ', test_auroc)
    print('Accuracy: ', test_acc)
    res_dict_auroc[feature_num] = test_auroc
    res_dict_acc[feature_num] = test_acc

# Bundle both metric tables (each keyed by feature count) into one object.
result_dict = {
    'res_dict_auroc': res_dict_auroc,
    'res_dict_acc': res_dict_acc,
}

# Persist the selected feature names and the metrics as two pickle files.
# Context managers guarantee each file is flushed and closed even if
# pickling raises; the original left the handles open.
with open('../UCI_datasets/results/'+data_name+'_global_nn_sage_selected_features.pkl', 'wb') as features_file:
    pickle.dump(features_input, features_file)
with open('../UCI_datasets/results/'+data_name+'_global_nn_sage_results.pkl', 'wb') as results_file:
    pickle.dump(result_dict, results_file)