import torch
import pickle
import argparse
import numpy as np
import torch.nn as nn
from adaptive import BaseModel
from torchvision import transforms
from torchvision.datasets import MNIST
from sklearn.metrics import accuracy_score
from captum.attr import DeepLift

import sys
sys.path.append('../')
from data import Flatten, ColumnSelector, get_xy


# Set up command line arguments
def _positive_int(value):
    """Parse a command-line string as an int, rejecting values below 1.

    Used as an argparse ``type=`` hook: argparse turns both the
    ``ArgumentTypeError`` raised here and the ``ValueError`` from ``int()``
    into a normal usage error instead of a traceback.
    """
    number = int(value)
    if number < 1:
        raise argparse.ArgumentTypeError(f'must be a positive integer, got {value!r}')
    return number


parser = argparse.ArgumentParser(
    description='Rank MNIST pixels with DeepLift and measure the test accuracy '
                'of MLPs retrained on only the top-ranked pixels.')
parser.add_argument('--gpu', type=int, default=0,
                    help='CUDA device index to train on (default: 0)')
# At least one restart is required, otherwise no model is trained per budget.
parser.add_argument('--num_restarts', type=_positive_int, default=1,
                    help='training restarts per feature budget; the model with '
                         'the lowest validation loss is kept (default: 1)')


if __name__ == '__main__':
    # Parse args
    args = parser.parse_args()

    # Load train dataset, split into train/val
    mnist_dataset = MNIST('/tmp/mnist/', download=True, train=True,
                          transform=transforms.Compose([transforms.ToTensor(), Flatten()]))
    np.random.seed(0)
    val_inds = np.sort(np.random.choice(len(mnist_dataset), size=10000, replace=False))
    train_inds = np.setdiff1d(np.arange(len(mnist_dataset)), val_inds)
    train_dataset = torch.utils.data.Subset(mnist_dataset, train_inds)
    val_dataset = torch.utils.data.Subset(mnist_dataset, val_inds)
    d_in = 784
    d_out = 10
    
    # Load test dataset
    test_dataset = MNIST('/tmp/mnist/', download=True, train=False,
                         transform=transforms.Compose([transforms.ToTensor(), Flatten()]))

    # Set up architecture
    device = torch.device('cuda', args.gpu)
    hidden = 512
    dropout = 0.3
    model = nn.Sequential(
        nn.Linear(d_in, hidden),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden, hidden),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden, d_out))

    # Train model
    basemodel = BaseModel(model).to(device)
    basemodel.fit(train_dataset,
                  val_dataset,
                  mbsize=128,
                  lr=1e-3,
                  nepochs=250,
                  loss_fn=nn.CrossEntropyLoss(),
                  verbose=False)

    # Calculate mean abs IntGrad, rank features
    deeplift = DeepLift(model, multiply_by_inputs=False)
    x, y = get_xy(val_dataset)
    x = x.to(device).requires_grad_(True)
    y = y.to(device)
    attr = deeplift.attribute(x, target=y)
    mean_abs = np.abs(attr.cpu().data.numpy()).mean(axis=0)
    ranked_features = np.argsort(mean_abs)[::-1].tolist()

    # Prepare to train models with feature subsets
    feature_num_list = list(range(5, 35, 5)) + list(range(40, 110, 10))
    print(feature_num_list)
    acc_dict = {}
    features_input = {}

    for num in feature_num_list:
        # Prepare module to mask all but top features
        print(f'# features: {num}')
        selected_features = ranked_features[:num]
        inds = torch.tensor(np.isin(np.arange(d_in), selected_features), device=device)
        column_selector = ColumnSelector(inds)
        features_input[num] = selected_features

        # For tracking best model
        best_model = None
        best_loss = np.inf

        for _ in range(args.num_restarts):
            # Set up architecture
            hidden = 512
            dropout = 0.3
            model = nn.Sequential(
                column_selector,
                nn.Linear(num, hidden),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden, hidden),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden, d_out))

            # Train model
            basemodel = BaseModel(model).to(device)
            basemodel.fit(train_dataset,
                          val_dataset,
                          mbsize=128,
                          lr=1e-3,
                          nepochs=250,
                          loss_fn=nn.CrossEntropyLoss(),
                          patience=5,
                          verbose=False)
            
            # Check if best
            val_loss = basemodel.evaluate(val_dataset, nn.CrossEntropyLoss(), 1024)
            if val_loss < best_loss:
                best_model = model
                best_loss = val_loss
                
        # Get best model
        model = best_model

        # Calculate test set performance
        model.eval()
        x, y = get_xy(test_dataset)
        pred = model(x.to(device)).softmax(dim=1).cpu().data.numpy()
        test_acc = accuracy_score(y, pred.argmax(axis=1))
        print(f'Acc = {test_acc:.4f}')
        acc_dict[num] = test_acc

    # Save results
    results_dict = {
        'acc': acc_dict,
        'features': features_input
    }
    with open('results/deeplift_results.pkl', 'wb') as f:
        pickle.dump(results_dict, f)
