import sage
import torch
import pickle
import argparse
import numpy as np
import torch.nn as nn
from adaptive import BaseModel
from torchmetrics import AUROC
from sklearn.metrics import accuracy_score, roc_auc_score

import sys
sys.path.append('../')
from data import DenseDatasetSelected, data_split, get_x, get_y, get_xy

def _build_parser():
    """Create the command-line parser for this script.

    Options:
        --gpu: index passed to ``torch.device('cuda', ...)`` below.
        --num_restarts: how many models are trained per feature-subset size;
            the one with the lowest validation loss is kept.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--gpu', default=0, type=int)
    cli.add_argument('--num_restarts', default=1, type=int)
    return cli


parser = _build_parser()


def build_mlp(n_in, n_out, hidden=128, dropout=0.3):
    """Return the two-hidden-layer ReLU MLP used for every model in this script.

    Args:
        n_in: number of input features.
        n_out: number of output logits (one per class).
        hidden: width of both hidden layers.
        dropout: dropout probability applied after each hidden activation.
    """
    return nn.Sequential(
        nn.Linear(n_in, hidden),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden, hidden),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden, n_out))


def feature_subset_sizes(n_features):
    """Return the sorted, de-duplicated feature-subset sizes to evaluate.

    Sizes are 1-10, then every 5th count starting at 15, then all features.
    Sizes larger than ``n_features`` are dropped, so datasets with few
    features never yield a model whose input width exceeds the columns
    actually selected.
    """
    sizes = set(range(1, 11)) | set(range(15, n_features, 5)) | {n_features}
    return sorted(s for s in sizes if s <= n_features)


def multiclass_auroc(y, probs):
    """AUROC from per-class probabilities ``probs`` (n_samples, n_classes).

    Binary problems use the positive-class column; more than two classes
    use one-vs-rest averaging.
    """
    if probs.shape[1] == 2:
        return roc_auc_score(y, probs[:, 1])
    return roc_auc_score(y, probs, multi_class='ovr')


if __name__ == '__main__':
    import os

    # Parse args
    args = parser.parse_args()
    if args.num_restarts < 1:
        # With zero restarts no model is trained and there is nothing to evaluate.
        parser.error('--num_restarts must be at least 1')

    # Load dataset
    dataset = DenseDatasetSelected('../../datasets/spam.csv')
    d_in = dataset.X.shape[1]
    d_out = len(np.unique(dataset.Y))
    print(f'X.shape = {dataset.X.shape}')

    # Split dataset (same random_state as the per-subset splits below)
    train_dataset, val_dataset, test_dataset = data_split(dataset, random_state=0)

    # Pick the requested GPU, falling back to CPU when CUDA is unavailable
    if torch.cuda.is_available():
        device = torch.device('cuda', args.gpu)
    else:
        device = torch.device('cpu')

    # Train a model on all features (used only to compute SAGE values)
    model = build_mlp(d_in, d_out)
    basemodel = BaseModel(model).to(device)
    basemodel.fit(train_dataset,
                  val_dataset,
                  mbsize=32,
                  lr=1e-3,
                  nepochs=250,
                  loss_fn=nn.CrossEntropyLoss(),
                  verbose=False)

    # Calculate SAGE values on the softmax outputs, then rank features from
    # highest to lowest SAGE value.
    model.eval()
    model_activation = nn.Sequential(model, nn.Softmax(dim=1))
    imputer = sage.MarginalImputer(model_activation, get_x(train_dataset)[:128])
    estimator = sage.PermutationEstimator(imputer, 'cross entropy')
    sage_values = estimator(get_x(val_dataset), get_y(val_dataset), thresh=0.01)
    ranked_features = np.array(dataset.features)[np.argsort(sage_values.values)[::-1]].tolist()

    # Prepare to train models with feature subsets
    feature_num_list = feature_subset_sizes(d_in)
    print(feature_num_list)
    auroc_dict = {}
    acc_dict = {}
    features_input = {}

    for num in feature_num_list:
        # Prepare dataset restricted to the top-`num` ranked features
        print(f'# features: {num}')
        dataset = DenseDatasetSelected('../../datasets/spam.csv', ranked_features[:num])
        features_input[num] = ranked_features[:num]
        train_dataset, val_dataset, test_dataset = data_split(dataset, random_state=0)

        # Train several restarts and keep the one with the lowest validation loss
        best_model = None
        best_loss = np.inf
        for _ in range(args.num_restarts):
            model = build_mlp(num, d_out)
            basemodel = BaseModel(model).to(device)
            basemodel.fit(train_dataset,
                          val_dataset,
                          mbsize=32,
                          lr=1e-3,
                          nepochs=250,
                          loss_fn=nn.CrossEntropyLoss(),
                          patience=5,
                          verbose=False)

            val_loss = basemodel.evaluate(val_dataset, nn.CrossEntropyLoss(), 1024)
            # Always accept the first run so a NaN loss cannot leave best_model as None.
            if best_model is None or val_loss < best_loss:
                best_model = model
                best_loss = val_loss

        model = best_model

        # Calculate test set performance; no gradients are needed for inference.
        model.eval()
        x, y = get_xy(test_dataset)
        with torch.no_grad():
            # Cast explicitly: nn.Linear weights are float32, numpy defaults to float64.
            logits = model(torch.tensor(x, dtype=torch.float32, device=device))
        pred = logits.softmax(dim=1).cpu().numpy()
        test_auroc = multiclass_auroc(y, pred)
        test_acc = accuracy_score(y, pred.argmax(axis=1))
        print(f'AUROC = {test_auroc:.4f}, Acc = {test_acc:.4f}')
        auroc_dict[num] = test_auroc
        acc_dict[num] = test_acc

    # Save results (create the output directory if it does not exist yet)
    results_dict = {
        'auroc': auroc_dict,
        'acc': acc_dict,
        'features': features_input
    }
    os.makedirs('results', exist_ok=True)
    with open('results/sage_results.pkl', 'wb') as f:
        pickle.dump(results_dict, f)
