import argparse
import os
import pickle

import numpy as np
import sage
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score
from torchvision import transforms
from torchvision.datasets import MNIST

from adaptive import BaseModel

import sys
sys.path.append('../')
from data import Flatten, ColumnSelector, get_x, get_xy

# Set up command line arguments
parser = argparse.ArgumentParser(
    description='Rank MNIST pixels by SAGE value, then retrain MLPs on the '
                'top-k pixels and record test accuracy for each k.')
parser.add_argument('--gpu', type=int, default=0,
                    help='CUDA device index (CPU is used if CUDA is unavailable)')
parser.add_argument('--num_restarts', type=int, default=1,
                    help='training restarts per feature subset; the model with '
                         'the lowest validation loss is kept (must be >= 1)')

# Architecture hyperparameters shared by the full model and every subset model.
HIDDEN = 512
DROPOUT = 0.3


def build_mlp(d_in, d_out, hidden=HIDDEN, dropout=DROPOUT, selector=None):
    """Build the two-hidden-layer ReLU MLP used throughout this experiment.

    Args:
        d_in: number of input features seen by the first linear layer.
        d_out: number of output logits.
        hidden: width of both hidden layers.
        dropout: dropout probability applied after each hidden layer.
        selector: optional module prepended to the network (e.g. a column
            selector that reduces the raw input to ``d_in`` features).

    Returns:
        An ``nn.Sequential`` mapping inputs to unnormalized logits.
    """
    layers = [] if selector is None else [selector]
    layers += [
        nn.Linear(d_in, hidden),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden, hidden),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden, d_out),
    ]
    return nn.Sequential(*layers)


if __name__ == '__main__':
    # Parse args
    args = parser.parse_args()
    if args.num_restarts < 1:
        # With zero restarts no subset model would be trained at all.
        parser.error('--num_restarts must be >= 1')

    # Create the output directory up front so a missing folder cannot crash
    # the script only after every model has already been trained.
    results_path = os.path.join('results', 'sage_results.pkl')
    os.makedirs(os.path.dirname(results_path), exist_ok=True)

    # Load train dataset, split into train/val (fixed seed -> reproducible split)
    mnist_transform = transforms.Compose([transforms.ToTensor(), Flatten()])
    mnist_dataset = MNIST('/tmp/mnist/', download=True, train=True,
                          transform=mnist_transform)
    np.random.seed(0)
    val_inds = np.sort(np.random.choice(len(mnist_dataset), size=10000, replace=False))
    train_inds = np.setdiff1d(np.arange(len(mnist_dataset)), val_inds)
    train_dataset = torch.utils.data.Subset(mnist_dataset, train_inds)
    val_dataset = torch.utils.data.Subset(mnist_dataset, val_inds)
    d_in = 784
    d_out = 10

    # Load test dataset
    test_dataset = MNIST('/tmp/mnist/', download=True, train=False,
                         transform=mnist_transform)

    # Use the requested GPU, falling back to CPU when CUDA is unavailable
    if torch.cuda.is_available():
        device = torch.device('cuda', args.gpu)
    else:
        device = torch.device('cpu')

    # Train a model on all d_in features
    model = build_mlp(d_in, d_out)
    basemodel = BaseModel(model).to(device)
    # NOTE(review): unlike the subset models below, no `patience` is passed
    # here, so BaseModel.fit's default applies — confirm this is intended.
    basemodel.fit(train_dataset,
                  val_dataset,
                  mbsize=128,
                  lr=1e-3,
                  nepochs=250,
                  loss_fn=nn.CrossEntropyLoss(),
                  verbose=False)

    # Calculate SAGE values on the validation set and rank features by them
    # (highest first). The imputer is given the training-set pixel means as
    # default values; eval() disables dropout for the explained model.
    model.eval()
    model_activation = nn.Sequential(model, nn.Softmax(dim=1))
    default_values = get_x(train_dataset).mean(dim=0).numpy()
    imputer = sage.DefaultImputer(model_activation, default_values)
    estimator = sage.PermutationEstimator(imputer, 'cross entropy')
    x, y = get_xy(val_dataset)
    sage_values = estimator(x.numpy(), y.numpy(), thresh=0.01)
    ranked_features = np.argsort(sage_values.values)[::-1].tolist()

    # Prepare to train models with feature subsets
    feature_num_list = list(range(5, 35, 5)) + list(range(40, 110, 10))
    print(feature_num_list)
    acc_dict = {}
    features_input = {}

    for num in feature_num_list:
        # Prepare dataset with top features
        print(f'# features: {num}')
        selected_features = ranked_features[:num]
        # Boolean mask over all d_in pixels marking the top-`num` features
        # (presumably ColumnSelector applies it as a column mask — see data.py).
        inds = torch.tensor(np.isin(np.arange(d_in), selected_features), device=device)
        column_selector = ColumnSelector(inds)
        features_input[num] = selected_features

        # Train several restarts and keep the one with the lowest val loss
        best_model = None
        best_loss = np.inf
        for _ in range(args.num_restarts):
            model = build_mlp(num, d_out, selector=column_selector)
            basemodel = BaseModel(model).to(device)
            basemodel.fit(train_dataset,
                          val_dataset,
                          mbsize=128,
                          lr=1e-3,
                          nepochs=250,
                          loss_fn=nn.CrossEntropyLoss(),
                          patience=5,
                          verbose=False)

            val_loss = basemodel.evaluate(val_dataset, nn.CrossEntropyLoss(), 1024)
            # Always accept the first restart so a NaN loss (for which
            # `val_loss < best_loss` is False) cannot leave best_model None.
            if best_model is None or val_loss < best_loss:
                best_model = model
                best_loss = val_loss

        # Calculate test set performance of the best restart
        model = best_model
        model.eval()
        x, y = get_xy(test_dataset)
        with torch.no_grad():
            logits = model(x.to(device))
        # argmax of the logits equals argmax of their softmax (softmax is monotonic)
        pred = logits.argmax(dim=1).cpu().numpy()
        test_acc = accuracy_score(y.numpy(), pred)
        print(f'Acc = {test_acc:.4f}')
        acc_dict[num] = test_acc

    # Save results
    results_dict = {
        'acc': acc_dict,
        'features': features_input
    }
    with open(results_path, 'wb') as f:
        pickle.dump(results_dict, f)
