import torch
import pickle
import argparse
import numpy as np
import torch.nn as nn
from adaptive import MaskLayer
from torchmetrics import Accuracy, AUROC
from sklearn.metrics import accuracy_score, roc_auc_score

# Data imports
import sys
sys.path.append('../')
sys.path.append('../../')
from data import DenseDatasetSelected, data_split, get_x, get_xy
from baselines import IterativeSelector, UniformSampler

# Set up command line arguments
parser = argparse.ArgumentParser(
    description='Train and evaluate the IterativeSelector baseline on the diabetes dataset.')
parser.add_argument('--gpu', type=int, default=0,
                    help='index of the CUDA device to use (default: 0)')


if __name__ == '__main__':
    # Local import keeps this script body self-contained; only needed to
    # create the output directory at the end.
    import os

    # Parse args
    args = parser.parse_args()

    # Load dataset
    dataset = DenseDatasetSelected('../../datasets/diabetes.csv')
    d_in = dataset.X.shape[1]
    d_out = len(np.unique(dataset.Y))
    print(f'X.shape = {dataset.X.shape}')

    # Split dataset
    train_dataset, val_dataset, test_dataset = data_split(dataset, random_state=0)

    # Center every feature with the training-set mean (no variance scaling).
    # This rewrites dataset.X after the split; presumably the split datasets
    # index into the same underlying dataset and therefore see the centered
    # values -- TODO confirm against data_split.
    x = get_x(train_dataset)
    mean = x.mean(axis=0)
    dataset.X = dataset.X - mean

    # Set up architecture. Fall back to CPU so the script still runs on a
    # machine without CUDA.
    if torch.cuda.is_available():
        device = torch.device('cuda', args.gpu)
    else:
        device = torch.device('cpu')
    hidden = 128
    dropout = 0.3
    # NOTE(review): append=True presumably concatenates the mask to the
    # masked input, which would explain the d_in * 2 input width below --
    # confirm against MaskLayer.
    mask_layer = MaskLayer(append=True)
    model = nn.Sequential(
        nn.Linear(d_in * 2, hidden),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden, hidden),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden, d_out))

    # Set up data sampler over the training features
    sampler = UniformSampler(torch.tensor(get_x(train_dataset)))

    # Train model
    iterative = IterativeSelector(model, mask_layer, sampler).to(device)
    iterative.fit(train_dataset,
                  val_dataset,
                  mbsize=128,
                  lr=1e-3,
                  nepochs=100,
                  loss_fn=nn.CrossEntropyLoss(),
                  # val_loss_fn=AUROC(num_classes=2),
                  # val_loss_mode='max',
                  patience=5,
                  verbose=True)

    # Evaluate in eval mode so the Dropout layers above are disabled and test
    # predictions are deterministic.
    iterative.eval()

    # For saving results: metric name -> {feature budget -> score}
    results = {
        'auroc': {},
        'acc': {}
    }

    # Generate results for each feature budget.
    # NOTE(review): budgets go up to 25; assumes d_in >= 25 -- confirm.
    num_features = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 25]
    x, y = get_xy(test_dataset)
    # The test inputs are identical for every budget, so move them once.
    x_test = torch.tensor(x, device=device)
    for num in num_features:
        with torch.no_grad():
            logits = iterative(x_test, max_features=num)
        pred = logits.softmax(dim=1).detach().cpu().numpy()
        # roc_auc_score rejects an (n, 2) score matrix for binary labels and
        # expects the positive-class probability; keep the full matrix for
        # the multiclass one-vs-rest case.
        score = pred[:, 1] if pred.shape[1] == 2 else pred
        auroc = roc_auc_score(y, score, multi_class='ovr')
        acc = accuracy_score(y, pred.argmax(axis=1))
        print(f'Num = {num}, AUROC = {100*auroc:.2f}, Acc = {100*acc:.2f}')
        results['auroc'][num] = auroc
        results['acc'][num] = acc

    # Make sure the output directory exists before writing.
    os.makedirs('results', exist_ok=True)
    with open('results/iterative_results.pkl', 'wb') as f:
        pickle.dump(results, f)
