import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from skorch import NeuralNetClassifier
from skorch.helper import predefined_split
from skorch.callbacks import LRScheduler, EarlyStopping

from mnist_auto_aug.models import LeNet1
from mnist_auto_aug.dataset import make_datasets


def test(model, device, test_loader, valid=False):
    """ Eval over desired set
    """
    set_name = 'Test set'
    if valid:
        set_name = 'Validation set'
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # sum up batch loss
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    test_acc = 100. * correct / len(test_loader.dataset)

    print(
        '\n{}: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            set_name, test_loss, correct, len(test_loader.dataset),
            test_acc
        )
    )
    return test_loss, test_acc


def fit_and_predict(
    model,
    train_set,
    valid_set,
    test_loader,
    epochs,
    model_params=None,
    callbacks=None,
    return_valid_perf=False,
    **kwargs
):
    """Train a model with skorch and evaluate it on the test loader.

    Parameters
    ----------
    model : torch.nn.Module
        Model to train.
    train_set : torch.utils.data.Dataset
        Dataset to train on.
    valid_set : torch.utils.data.Dataset
        Validation set, handed to skorch through ``predefined_split`` and
        optionally evaluated at the end (see ``return_valid_perf``).
    test_loader : torch.utils.data.DataLoader
        Loader used to evaluate the trained model.
    epochs : int
        Number of training epochs.
    model_params : dict | None, optional
        Parameters to pass to skorch.NeuralNetClassifier class. They override
        the defaults set here ('module', 'train_split', 'callbacks'). An
        optional 'test_batch_size' entry sets the batch size of the
        validation loader (64 if absent) and is not forwarded to the
        classifier. The dict itself is not modified. Defaults to None.
    callbacks : list, optional
        List of skorch or pytorch callbacks to use. Defaults to None.
    return_valid_perf : bool, optional
        If True, also evaluate the trained model on ``valid_set``. Defaults
        to False.
    **kwargs
        Extra keyword arguments forwarded to ``NeuralNetClassifier.fit``.

    Returns
    -------
    test_perf : tuple
        ``(loss, accuracy)`` on the test loader, as returned by ``test``.
    valid_perf : tuple
        Same on the validation set. Only returned when ``return_valid_perf``
        is True, in which case the result is ``(test_perf, valid_perf)``.
    """
    # Work on a copy so the caller's dict is left untouched.
    model_params = {} if model_params is None else dict(model_params)
    # 'test_batch_size' only configures the validation loader built below.
    # skorch rejects constructor arguments it does not know, so it must not
    # reach NeuralNetClassifier.
    valid_batch_size = model_params.pop('test_batch_size', 64)

    classifier_params = {
        'module': model,
        'train_split': predefined_split(valid_set),
        'callbacks': callbacks,
        # Not useful here I think
        # 'iterator_train__multiprocessing_context': 'fork'
    }
    classifier_params.update(model_params)
    print(">>> Created NeuralNetClassifier with the following parameters:")
    print(model_params)

    clf = NeuralNetClassifier(**classifier_params)

    # Model training for a specified number of epochs. `y` is None as
    # it is already supplied in the dataset.
    clf.fit(train_set, y=None, epochs=epochs, **kwargs)

    # `module_` is the module skorch actually initialized and trained;
    # `module` is only the constructor argument.
    trained_module = clf.module_

    test_perf = test(trained_module, clf.device, test_loader)
    if not return_valid_perf:
        return test_perf

    valid_loader = torch.utils.data.DataLoader(
        valid_set, batch_size=valid_batch_size
    )
    valid_perf = test(trained_module, clf.device, valid_loader, valid=True)
    return test_perf, valid_perf


def _base_prepare(args):
    """Seed pytorch and build the pieces shared by the training setups.

    Reads ``no_cuda``, ``seed``, ``dry_run`` and (outside of dry runs)
    ``epochs`` from ``args``.

    Returns
    -------
    tuple
        ``(model, epochs, device, use_cuda, num_workers, pin_memory,
        shuffle)`` where ``model`` is a ``LeNet1`` moved to ``device`` and
        ``epochs`` is forced to 1 for dry runs.
    """
    # CUDA is used only when it is both allowed and available.
    if args.no_cuda:
        cuda_enabled = False
    else:
        cuda_enabled = torch.cuda.is_available()

    print(f">>> Seeding pytorch global generator with: {args.seed}")
    torch.manual_seed(args.seed)

    target_device = torch.device("cuda" if cuda_enabled else "cpu")

    # Loader settings are returned unconditionally; whether they are applied
    # is left to the caller.
    loader_workers = 1
    loader_pin_memory = True
    loader_shuffle = False

    # Built after seeding the global generator.
    network = LeNet1().to(target_device)

    if args.dry_run:
        n_epochs = 1
    else:
        n_epochs = args.epochs

    return (
        network,
        n_epochs,
        target_device,
        cuda_enabled,
        loader_workers,
        loader_pin_memory,
        loader_shuffle,
    )


def prepare_skorch_training(args):
    """Prepare model, data, skorch parameters and callbacks for training.

    Besides what ``_base_prepare`` reads, this uses ``use_all_classes``,
    ``seed``, ``test_batch_size``, ``gamma``, ``lr`` and ``batch_size`` from
    ``args``.

    Returns
    -------
    model : torch.nn.Module
        Model built by ``_base_prepare``.
    data : tuple
        ``(train_set, valid_set, test_loader)``; only the test set is wrapped
        in a DataLoader.
    model_params : dict
        Keyword arguments for skorch.NeuralNetClassifier (optimizer, lr,
        batch size, device, plus iterator settings when CUDA is used).
    callbacks : list
        Named skorch callbacks: a StepLR scheduler and early stopping on
        'valid_acc'.
    epochs : int
        Number of training epochs, as computed by ``_base_prepare``.
    classes_to_keep : list | None
        Classes the datasets were restricted to, or None when all classes
        are used.
    """
    # `epochs` already accounts for dry runs, so it is not recomputed here.
    (
        model, epochs, device, use_cuda, num_workers, pin_memory, shuffle
    ) = _base_prepare(args)

    classes_to_keep = None if args.use_all_classes else [8, 4, 6, 9]

    train_set, valid_set, test_set = make_datasets(
        '../data', classes_to_keep=classes_to_keep, random_state=args.seed
    )
    test_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=args.test_batch_size,
    )
    data = (train_set, valid_set, test_loader)

    callbacks = [
        ('lr_scheduler',
         LRScheduler(policy=StepLR, step_size=1, gamma=args.gamma)),
        ('earlystop',
         EarlyStopping(monitor='valid_acc', patience=5, lower_is_better=False))
    ]

    model_params = {
        'optimizer': optim.Adam,
        'lr': args.lr,
        'batch_size': args.batch_size,
        'device': device,
    }
    if use_cuda:
        # Iterator settings are only overridden when training on GPU.
        model_params.update({
            'iterator_train__num_workers': num_workers,
            'iterator_valid__num_workers': num_workers,
            'iterator_train__pin_memory': pin_memory,
            'iterator_valid__pin_memory': pin_memory,
            'iterator_train__shuffle': shuffle,
            'iterator_valid__shuffle': shuffle,
        })
    return model, data, model_params, callbacks, epochs, classes_to_keep
