import os
import sys
import pickle
import argparse

import numpy as np
import torch
import torch.nn as nn
from torchmetrics import AUROC
from sklearn.metrics import accuracy_score, roc_auc_score

from adaptive import AdaptiveSelection, MaskLayer, MaskingPretrainer

sys.path.append('../')
from data import DenseDatasetSelected, data_split, get_xy

# Command-line interface: `--gpu` picks the CUDA device index (defaults to 0).
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', default=0, type=int)


if __name__ == '__main__':
    # Parse args
    args = parser.parse_args()

    # Load dataset
    dataset = DenseDatasetSelected('../../datasets/diabetes.csv')
    d_in = dataset.X.shape[1]  # 45
    d_out = len(np.unique(dataset.Y))  # 2

    # Split dataset
    train_dataset, val_dataset, test_dataset = data_split(dataset, random_state=0)
    print(f'Train samples = {len(train_dataset)}, val samples = {len(val_dataset)}, test samples = {len(test_dataset)}')

    # Per-feature mean/std, computed on the training split only so that no
    # statistics leak in from the validation/test data.
    x, y = get_xy(train_dataset)
    mean = np.mean(x, axis=0)
    std = np.std(x, axis=0)
    # Constant features have zero std; leave them unscaled instead of dividing by 0.
    std[std == 0] = 1

    # Normalize via the original dataset. Presumably the splits reference
    # `dataset`, so they see the normalized features too -- TODO confirm
    # against data_split.
    dataset.X = (dataset.X - mean) / std

    # Setup
    max_features = 35
    device = torch.device('cuda', args.gpu)
    # Create the output directory before training so a long run cannot fail at save time.
    os.makedirs('results', exist_ok=True)

    # Set up architecture
    hidden = 128
    dropout = 0.3

    # Predictor: input is d_in * 2 wide because MaskLayer(append=True) is used below
    # (presumably the features concatenated with their mask -- TODO confirm).
    predictor = nn.Sequential(
        nn.Linear(d_in * 2, hidden),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden, hidden),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden, d_out))

    # Selector: same masked input, one output logit per candidate feature.
    selector = nn.Sequential(
        nn.Linear(d_in * 2, hidden),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden, hidden),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden, d_in))

    # Tie weights
    # selector[0] = predictor[0]
    # selector[3] = predictor[3]

    # Pretrain predictor
    mask_layer = MaskLayer(append=True)
    pretrain = MaskingPretrainer(predictor, mask_layer).to(device)
    pretrain.fit(train_dataset,
                 val_dataset,
                 mbsize=128,
                 lr=1e-3,
                 nepochs=100,
                 loss_fn=nn.CrossEntropyLoss(),
                 # val_loss_fn=AUROC(num_classes=2),
                 # val_loss_mode='max',
                 patience=5,
                 verbose=False)

    # Train adaptive selection
    gafs = AdaptiveSelection(selector, predictor, mask_layer).to(device)
    gafs.fit(train_dataset,
             val_dataset,
             mbsize=128,
             lr=1e-3,
             nepochs=250,
             max_features=max_features,
             loss_fn=nn.CrossEntropyLoss(),
             # val_loss_fn=AUROC(num_classes=2),
             # val_loss_mode='max',
             patience=5,
             verbose=False)

    # Inference mode from here on: disables the Dropout layers in the predictor
    # and selector so that evaluation is deterministic.
    gafs.eval()

    # For saving results
    results = {
        'auroc': {},
        'acc': {}
    }

    # Generate results on the test split for several feature budgets.
    num_features = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 25]
    x, y = get_xy(test_dataset)
    x_test = torch.tensor(x, device=device)  # built once, reused for every budget
    with torch.no_grad():
        for num in num_features:
            pred, _, _ = gafs(x_test, max_features=num)
            pred = pred.softmax(dim=1).cpu().numpy()
            # For binary targets roc_auc_score requires a 1-D positive-class score;
            # the full probability matrix is only accepted in the multiclass case.
            score = pred[:, 1] if pred.shape[1] == 2 else pred
            auroc = roc_auc_score(y, score, multi_class='ovr')
            acc = accuracy_score(y, pred.argmax(axis=1))
            print(f'Num = {num}, AUROC = {100*auroc:.2f}, Acc = {100*acc:.2f}')
            results['auroc'][num] = auroc
            results['acc'][num] = acc

    with open('results/adaptive_results.pkl', 'wb') as f:
        pickle.dump(results, f)

    # Find most common selections at each step. top_list[k] holds the indices
    # (sorted) of the k features with the highest mean of `m` over the validation
    # set when k features are allowed; top_list[0] is an empty placeholder so that
    # indices line up with k. Assumes m is a (batch, d_in) selection mask -- TODO confirm.
    # NOTE(review): k runs up to d_in, which can exceed the max_features budget
    # used in training; confirm that is intended.
    x, y = get_xy(val_dataset)
    x_val = torch.tensor(x, device=device)  # hoisted out of the loop below
    top_list = [[]]
    with torch.no_grad():
        for num in range(1, d_in + 1):
            _, _, m = gafs(x_val, max_features=num)
            p = m.mean(dim=0).cpu().numpy()
            top_list.append(np.sort(np.argsort(p)[-num:]))

    # Save results
    with open('results/adaptive_frequent_selections.pkl', 'wb') as f:
        pickle.dump(top_list, f)

    # Save model (moved to CPU first so the saved tensors are not tied to a CUDA device)
    gafs.cpu()
    torch.save(gafs, 'results/adaptive_trained.pt')
