import torch
import pickle
import argparse
import numpy as np
import torch.nn as nn
from adaptive import BaseModel
from torchvision import transforms
from torchvision.datasets import MNIST
from sklearn.metrics import accuracy_score

import sys
sys.path.append('../')
sys.path.append('../../')
from data import Flatten, ColumnSelector, get_xy
from baselines import DifferentiableSelector, ConcreteMask

# Set up command line arguments
parser = argparse.ArgumentParser(
    description='Train differentiable feature selectors on MNIST and '
                'evaluate predictors restricted to the selected features.')
parser.add_argument('--gpu', type=int, default=0,
                    help='index of the CUDA device to train on '
                         '(default: %(default)s)')
parser.add_argument('--num_restarts', type=int, default=1,
                    help='number of times the predictor is retrained for '
                         'each feature budget; the run with the lowest '
                         'validation loss is kept (default: %(default)s)')


if __name__ == '__main__':
    # Script entry point. For each feature budget: learn a feature subset with
    # a differentiable selector, retrain a predictor restricted to that subset
    # (optionally with several restarts), and record its MNIST test accuracy.
    # All accuracies and selected feature indices are pickled at the end.
    import os  # local import: only needed to create the output directory

    # Parse args
    args = parser.parse_args()
    if args.num_restarts < 1:
        # Without at least one restart no predictor is ever trained; fail
        # before spending time on selector training.
        parser.error('--num_restarts must be at least 1')

    # Create the output directory up front so that a missing or unwritable
    # location is reported before training instead of after the whole sweep.
    results_path = 'results/differentiable_results.pkl'
    os.makedirs(os.path.dirname(results_path), exist_ok=True)

    # Use the requested GPU when CUDA is present, otherwise run on the CPU
    if torch.cuda.is_available():
        device = torch.device('cuda', args.gpu)
    else:
        print('CUDA is not available, falling back to CPU')
        device = torch.device('cpu')

    # Load train dataset, split into train/val
    mnist_dataset = MNIST('/tmp/mnist/', download=True, train=True,
                          transform=transforms.Compose([transforms.ToTensor(), Flatten()]))
    np.random.seed(0)  # fixed seed keeps the train/val split reproducible
    val_inds = np.sort(np.random.choice(len(mnist_dataset), size=10000, replace=False))
    train_inds = np.setdiff1d(np.arange(len(mnist_dataset)), val_inds)
    train_dataset = torch.utils.data.Subset(mnist_dataset, train_inds)
    val_dataset = torch.utils.data.Subset(mnist_dataset, val_inds)
    d_in = 784
    d_out = 10

    # Load test dataset; the test tensors do not depend on the feature budget,
    # so they are extracted once here rather than on every loop iteration.
    test_dataset = MNIST('/tmp/mnist/', download=True, train=False,
                         transform=transforms.Compose([transforms.ToTensor(), Flatten()]))
    x_test, y_test = get_xy(test_dataset)

    # Architecture hyperparameters shared by the selector and predictor nets
    hidden = 512
    dropout = 0.3

    def make_mlp_layers(in_features):
        """Return freshly initialised layers of the two-hidden-layer MLP."""
        return [
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, d_out),
        ]

    # Prepare to train models with feature subsets
    feature_num_list = list(range(5, 35, 5)) + list(range(40, 110, 10))
    print(feature_num_list)
    acc_dict = {}
    features_input = {}

    for num in feature_num_list:
        # Set up architecture (selector layer is built before the network,
        # same construction order as before)
        selector_layer = ConcreteMask(d_in, num)
        model = nn.Sequential(*make_mlp_layers(d_in))

        # Train differentiable selector
        diff_selector = DifferentiableSelector(model, selector_layer).to(device)
        diff_selector.fit(train_dataset,
                          val_dataset,
                          mbsize=128,
                          lr=1e-3,
                          nepochs=250,
                          loss_fn=nn.CrossEntropyLoss(),
                          patience=5,
                          start_temp=1.0,
                          end_temp=0.1,
                          temp_steps=5,
                          verbose=True)

        # Extract top features: the argmax of each logits row is taken as one
        # selected feature index.
        logits = selector_layer.logits.detach().cpu().numpy()
        selected_features = np.sort(logits.argmax(axis=1))
        num_unique = len(np.unique(selected_features))
        if num_unique != num:
            # Explicit check instead of `assert`, which is stripped under -O.
            # The predictor below is built with `num` inputs, so duplicate
            # selections cannot be used.
            raise RuntimeError(
                f'Selector picked {num_unique} unique features, expected {num}')

        # Prepare module to mask all but top features
        # (boolean mask over the d_in inputs, True at the selected indices)
        inds = torch.tensor(np.isin(np.arange(d_in), selected_features), device=device)
        column_selector = ColumnSelector(inds)
        features_input[num] = selected_features

        # For tracking best model
        best_model = None
        best_loss = np.inf

        for _ in range(args.num_restarts):
            # Set up architecture: fresh predictor on the selected columns
            model = nn.Sequential(column_selector, *make_mlp_layers(num))

            # Train model
            basemodel = BaseModel(model).to(device)
            basemodel.fit(train_dataset,
                          val_dataset,
                          mbsize=128,
                          lr=1e-3,
                          nepochs=250,
                          loss_fn=nn.CrossEntropyLoss(),
                          patience=5,
                          verbose=False)

            # Check if best
            val_loss = basemodel.evaluate(val_dataset, nn.CrossEntropyLoss(), 1024)
            if val_loss < best_loss:
                best_model = model
                best_loss = val_loss

        # Get best model; it is still None if every restart produced a
        # validation loss that is not below infinity (e.g. NaN).
        if best_model is None:
            raise RuntimeError(
                f'No restart produced a usable validation loss for {num} features')
        model = best_model

        # Calculate test set performance; no_grad avoids building an autograd
        # graph for the whole test set during pure inference.
        model.eval()
        with torch.no_grad():
            pred = model(x_test.to(device)).softmax(dim=1).cpu().numpy()
        test_acc = accuracy_score(y_test, pred.argmax(axis=1))
        print(f'Acc = {test_acc:.4f}')
        acc_dict[num] = test_acc

    # Save results
    results_dict = {
        'acc': acc_dict,
        'features': features_input
    }
    with open(results_path, 'wb') as f:
        pickle.dump(results_dict, f)
