import torch
import pickle
import argparse
import numpy as np
import torch.nn as nn
from torchvision import transforms
from torchvision.datasets import MNIST
from sklearn.metrics import accuracy_score
from adaptive import AdaptiveSelection, MaskLayer, MaskingPretrainer

import sys
sys.path.append('../')
from data import Flatten, get_xy

# Command line interface: a single integer option, --gpu, defaulting to 0
# (consumed below as the CUDA device index).
parser = argparse.ArgumentParser()
parser.add_argument(
    '--gpu',
    default=0,
    type=int,
)


if __name__ == '__main__':
    # Script entry point: trains an adaptive feature selection model on MNIST,
    # evaluates test accuracy for several feature budgets, records the most
    # frequently selected features per budget, and saves everything under
    # results/.

    # Local import: only this entry point needs it (output directory setup).
    import os

    # Parse args
    args = parser.parse_args()

    # Create the output directory up front so that a missing folder cannot
    # discard the results of the long training runs at save time.
    os.makedirs('results', exist_ok=True)

    # Load train dataset, split into train/val
    transform = transforms.Compose([transforms.ToTensor(), Flatten()])
    mnist_dataset = MNIST('/tmp/mnist/', download=True, train=True,
                          transform=transform)
    # Fixed NumPy seed so the validation split is the same on every run.
    np.random.seed(0)
    # 10000 distinct random examples form the validation set; the rest train.
    val_inds = np.sort(np.random.choice(len(mnist_dataset), size=10000, replace=False))
    train_inds = np.setdiff1d(np.arange(len(mnist_dataset)), val_inds)
    train_dataset = torch.utils.data.Subset(mnist_dataset, train_inds)
    val_dataset = torch.utils.data.Subset(mnist_dataset, val_inds)
    d_in = 784
    d_out = 10

    # Load test dataset
    test_dataset = MNIST('/tmp/mnist/', download=True, train=False,
                         transform=transform)

    # Setup
    max_features = 50
    device = torch.device('cuda', args.gpu)

    # Set up architecture
    hidden = 512
    dropout = 0.3

    # Predictor
    # NOTE(review): the input width is d_in * 2, presumably because
    # MaskLayer(append=True) concatenates the mask to the features -- confirm
    # against adaptive.MaskLayer.
    predictor = nn.Sequential(
        nn.Linear(d_in * 2, hidden),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden, hidden),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden, d_out))

    # Selector: same layout as the predictor, but with one output per input
    # feature (d_in) instead of one per class.
    selector = nn.Sequential(
        nn.Linear(d_in * 2, hidden),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden, hidden),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden, d_in))

    # Tie weights: the selector reuses the predictor's first two Linear
    # modules (same objects), so only the output layers differ.
    selector[0] = predictor[0]
    selector[3] = predictor[3]

    # Pretrain predictor
    mask_layer = MaskLayer(append=True)
    pretrain = MaskingPretrainer(predictor, mask_layer).to(device)
    pretrain.fit(train_dataset,
                 val_dataset,
                 mbsize=128,
                 lr=1e-3,
                 nepochs=100,
                 loss_fn=nn.CrossEntropyLoss(),
                 verbose=True)

    # Train adaptive selection
    gafs = AdaptiveSelection(selector, predictor, mask_layer).to(device)
    gafs.fit(train_dataset,
             val_dataset,
             mbsize=128,
             lr=1e-3,
             nepochs=250,
             max_features=max_features,
             loss_fn=nn.CrossEntropyLoss(),
             verbose=True)

    # For saving results
    results = {
        'acc': {}
    }

    # Materialize the test set and move it to the device once; both
    # evaluation loops below reuse the same tensor.
    # NOTE(review): assumes gafs' forward pass does not modify its input in
    # place -- confirm against adaptive.AdaptiveSelection.
    x, y = get_xy(test_dataset)
    x_test = torch.tensor(x, device=device)

    # Evaluate in inference mode: disable dropout explicitly rather than
    # relying on the state fit() leaves behind, and skip autograd bookkeeping
    # since no gradients are needed here.
    gafs.eval()
    with torch.no_grad():
        # Generate results for budgets 5, 10, ..., 30 and 40, 50, ..., 100.
        num_features = list(range(5, 35, 5)) + list(range(40, 110, 10))
        for num in num_features:
            pred, _, _ = gafs(x_test, max_features=num)
            pred = pred.softmax(dim=1).cpu().numpy()
            acc = accuracy_score(y, pred.argmax(axis=1))
            print(f'Num = {num}, Acc = {100*acc:.2f}')
            results['acc'][num] = acc

        # Find most common selections at each step
        # top_list[k] holds the sorted indices of the k entries with the
        # largest mean of the third model output (presumably the selection
        # mask) over the test set; index 0 is the empty selection.
        top_list = [[]]
        for num in range(max_features):
            _, _, m = gafs(x_test, max_features=num + 1)
            p = m.mean(dim=0).cpu().numpy()
            top_list.append(np.sort(np.argsort(p)[-(num + 1):]))

    with open('results/adaptive_results.pkl', 'wb') as f:
        pickle.dump(results, f)

    # Save results
    with open('results/adaptive_frequent_selections.pkl', 'wb') as f:
        pickle.dump(top_list, f)

    # Save model (moved to CPU first so the saved object is on the CPU)
    gafs.cpu()
    torch.save(gafs, 'results/adaptive_trained.pt')
