import torch
import pickle
import argparse
import numpy as np
import torch.nn as nn
from adaptive import BaseModel
from torchmetrics import AUROC
from sklearn.metrics import accuracy_score, roc_auc_score

import sys
sys.path.append('../')
sys.path.append('../../')
from data import DenseDatasetSelected, data_split, get_x, get_xy
from baselines import DifferentiableSelector, ConcreteMask

# Command-line options:
#   --gpu           index of the CUDA device used for all training/evaluation
#   --num_restarts  how many independent training runs to do per feature
#                   subset (the run with the lowest validation loss is kept)
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', default=0, type=int)
parser.add_argument('--num_restarts', default=1, type=int)


def _make_mlp(d_in, d_out, hidden=128, dropout=0.3):
    """Build the two-hidden-layer MLP used by every model in this script.

    Args:
        d_in: number of input features.
        d_out: number of output logits.
        hidden: width of each hidden layer.
        dropout: dropout probability applied after each hidden ReLU.

    Returns:
        nn.Sequential of Linear -> ReLU -> Dropout -> Linear -> ReLU ->
        Dropout -> Linear.
    """
    return nn.Sequential(
        nn.Linear(d_in, hidden),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden, hidden),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden, d_out))


if __name__ == '__main__':
    import os

    # Parse args
    args = parser.parse_args()
    # The device does not depend on the feature budget, so resolve it once.
    device = torch.device('cuda', args.gpu)

    # Load dataset
    dataset = DenseDatasetSelected('../../datasets/spam.csv')
    d_in = dataset.X.shape[1]
    d_out = len(np.unique(dataset.Y))
    all_features = np.array(dataset.features)
    print(f'X.shape = {dataset.X.shape}')

    # Split dataset
    train_dataset_full, val_dataset_full, _ = data_split(dataset, random_state=0)

    # Center features with the training-split mean. Rebinding dataset.X after
    # the split presumably propagates to the split objects (i.e. data_split
    # returns index-based views) -- NOTE(review): confirm in data.data_split.
    x = get_x(train_dataset_full)
    mean = x.mean(axis=0)
    dataset.X = dataset.X - mean

    # Feature budgets to evaluate: 1..10, then 15, 20, 25.
    feature_num_list = list(range(1, 11)) + list(range(15, 30, 5))
    print(feature_num_list)
    auroc_dict = {}
    acc_dict = {}
    features_input = {}

    for num in feature_num_list:
        # Train the differentiable selector on all features. The mask layer is
        # created before the MLP to keep the original initialization order.
        selector_layer = ConcreteMask(d_in, num)
        diff_selector = DifferentiableSelector(
            _make_mlp(d_in, d_out), selector_layer).to(device)
        diff_selector.fit(train_dataset_full,
                          val_dataset_full,
                          mbsize=32,
                          lr=1e-3,
                          nepochs=250,
                          loss_fn=nn.CrossEntropyLoss(),
                          patience=5,
                          verbose=False)

        # Extract top features: argmax over axis 1 picks one feature index per
        # row of the mask logits (presumably shape (num, d_in) -- confirm).
        logits = selector_layer.logits.cpu().data.numpy()
        selected_features = np.sort(logits.argmax(axis=1))
        # Explicit check rather than `assert`, which is stripped under -O.
        if len(np.unique(selected_features)) != num:
            raise RuntimeError(
                f'Selector chose duplicate features for num={num}: '
                f'{selected_features.tolist()}')

        # Prepare dataset with top features (separate name so the centered
        # full-feature `dataset` is not shadowed).
        print(f'# features: {num}')
        selected_dataset = DenseDatasetSelected(
            '../../datasets/spam.csv', all_features[selected_features].tolist())
        # NOTE(review): unlike the full-feature dataset above, this one is not
        # mean-centered. If that is unintended, subtract
        # mean[selected_features] from selected_dataset.X here.
        features_input[num] = selected_features
        train_dataset, val_dataset, test_dataset = data_split(selected_dataset, random_state=0)

        # Train `num_restarts` models and keep the one with lowest val loss.
        best_model = None
        best_loss = np.inf
        for _ in range(args.num_restarts):
            model = _make_mlp(num, d_out)
            basemodel = BaseModel(model).to(device)
            basemodel.fit(train_dataset,
                          val_dataset,
                          mbsize=32,
                          lr=1e-3,
                          nepochs=250,
                          loss_fn=nn.CrossEntropyLoss(),
                          patience=5,
                          verbose=False)

            # Check if best
            val_loss = basemodel.evaluate(val_dataset, nn.CrossEntropyLoss(), 1024)
            if val_loss < best_loss:
                best_model = model
                best_loss = val_loss

        # No model is selected when --num_restarts < 1 or every val loss is NaN
        # (NaN never compares below inf); fail clearly instead of on None.eval().
        if best_model is None:
            raise RuntimeError(
                f'No model selected for num={num}: need --num_restarts >= 1 '
                f'and a finite validation loss.')
        model = best_model

        # Calculate test set performance (inference only, so no autograd graph).
        model.eval()
        x, y = get_xy(test_dataset)
        with torch.no_grad():
            pred = model(torch.tensor(x, device=device)).softmax(dim=1).cpu().numpy()
        # pred[:, 1] is the positive-class probability; this assumes a binary
        # label, as for the spam dataset.
        test_auroc = roc_auc_score(y, pred[:, 1])
        test_acc = accuracy_score(y, pred.argmax(axis=1))
        print(f'AUROC = {test_auroc:.4f}, Acc = {test_acc:.4f}')
        auroc_dict[num] = test_auroc
        acc_dict[num] = test_acc

    # Save results (create the output directory if it does not exist yet).
    results_dict = {
        'auroc': auroc_dict,
        'acc': acc_dict,
        'features': features_input
    }
    os.makedirs('results', exist_ok=True)
    with open('results/differentiable_results.pkl', 'wb') as f:
        pickle.dump(results_dict, f)
