#!/usr/bin/env python
# coding: utf-8

import argparse
import functools
import os
import pickle
import random
import time

import numpy as np
import seaborn as sns
import sklearn
import sklearn.datasets
import sklearn.model_selection
import sklearn.preprocessing
import torch
from matplotlib import pyplot as plt
from torch import nn
from torch.nn import functional as F
from torch_lr_finder import LRFinder

# Apply seaborn's default plotting theme at import time, so it affects every
# matplotlib figure created later in this module.
sns.set()


def get_current_lr(optimizer):
    r"""
    Read the current learning rate from an optimizer.

    Only the first parameter group is inspected, so the result is only
    representative when every group shares the same learning rate.

    Parameters
    ----------
    optimizer : torch.optim.Optimizer
        Optimizer exposing a ``param_groups`` list of dicts with an ``"lr"`` key.

    Returns
    -------
    float
        The ``"lr"`` entry of ``optimizer.param_groups[0]``.
    """
    first_group = optimizer.param_groups[0]
    return first_group["lr"]


def count_parameters(model, only_trainable=True):
    r"""
    Count the scalar parameters held by a model (including its children).

    Parameters
    ----------
    model : torch.nn.Module
        The model whose ``parameters()`` are counted.
    only_trainable : bool, optional
        If ``True`` (default), count only parameters with ``requires_grad``
        set; otherwise count every parameter.

    Returns
    -------
    int
        Total number of elements over the selected parameters.
    """
    params = model.parameters()
    if only_trainable:
        params = (p for p in params if p.requires_grad)
    return sum(p.numel() for p in params)


def set_rng_state(seed):
    r"""
    Seed the random number generators used by this script.

    Seeds, in order, Python's ``random`` module, NumPy's global generator
    and PyTorch (via ``torch.manual_seed``) with the same value.

    Parameters
    ----------
    seed : int
        Seed passed to each generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)


def maxd(z, dim):
    """Return the maximum values of ``z`` along ``dim`` (indices are discarded)."""
    largest, _ = torch.max(z, dim=dim)
    return largest


def mind(z, dim):
    """Return the minimum values of ``z`` along ``dim`` (indices are discarded)."""
    smallest, _ = torch.min(z, dim=dim)
    return smallest


def logistic_and_approx(z, dim):
    """
    Reduce ``z`` along ``dim`` with a piecewise AND-like rule.

    Where every element along ``dim`` is strictly negative the result is
    their sum; everywhere else it is the minimum along ``dim``.
    """
    all_negative = (z < 0).all(dim=dim)
    total = z.sum(dim=dim)
    smallest, _ = torch.min(z, dim=dim)
    return torch.where(all_negative, total, smallest)


def logistic_or_approx(z, dim):
    """
    Reduce ``z`` along ``dim`` with a piecewise OR-like rule.

    Where every element along ``dim`` is strictly positive the result is
    their sum; everywhere else it is the maximum along ``dim``.
    """
    all_positive = (z > 0).all(dim=dim)
    total = z.sum(dim=dim)
    largest, _ = torch.max(z, dim=dim)
    return torch.where(all_positive, total, largest)


def logistic_xnor_approx(z, dim):
    """
    Reduce ``z`` along ``dim`` with an XNOR-like rule.

    The magnitude is the smallest absolute value along ``dim`` and the sign
    is the sign of the product of the elements along ``dim``.
    """
    sign = torch.sign(torch.prod(z, dim=dim))
    magnitude, _ = torch.min(z.abs(), dim=dim)
    return sign * magnitude


def unroll_k(x, k, d):
    """
    Split axis ``d`` of ``x`` into groups of ``k`` consecutive elements.

    Axis ``d`` (size ``n``) is replaced by two axes of sizes ``n // k`` and
    ``k``; the new ``k``-sized axis sits immediately after the group axis.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor; must be viewable with the new shape.
    k : int
        Group size.
    d : int
        Axis to split. Negative values index from the end.

    Returns
    -------
    x : torch.Tensor
        View of the input with the extra axis.
    d_new : int
        Index of the new ``k``-sized axis: ``d + 1`` for non-negative ``d``,
        and ``d`` itself for negative ``d`` (the new axis takes over that
        position when counted from the end).
    """
    shape = x.shape
    # Normalise to a non-negative axis before slicing. With a raw d == -1 the
    # trailing slice shape[d + 1:] would be shape[0:], i.e. the whole shape.
    axis = d + len(shape) if d < 0 else d
    x = x.view(*shape[:axis], shape[axis] // k, k, *shape[axis + 1 :])
    d_new = d if d < 0 else d + 1
    return x, d_new


def squeeze_relu(z, dim):
    """
    Apply ReLU element-wise to ``z``.

    ``dim`` is accepted but not used: despite the function's name, no
    squeezing or reshaping is performed here.
    """
    return torch.relu(z)


class HOActfun(nn.Module):
    """
    Base class for activation functions that act on groups of ``k`` inputs.

    Parameters
    ----------
    k : int, optional
        Number of inputs per group. Default is ``2``.
    dim : int, optional
        Axis along which inputs are grouped. Default is ``1``.
    """

    def __init__(self, k=2, dim=1):
        super().__init__()
        self.k, self.dim = k, dim


class MaxOut(HOActfun):
    """Activation returning the maximum of each group of ``k`` inputs along ``dim``."""

    def forward(self, x):
        grouped, k_axis = unroll_k(x, self.k, self.dim)
        largest, _ = torch.max(grouped, dim=k_axis)
        return largest


class MinOut(HOActfun):
    """Activation returning the minimum of each group of ``k`` inputs along ``dim``."""

    def forward(self, x):
        grouped, k_axis = unroll_k(x, self.k, self.dim)
        smallest, _ = torch.min(grouped, dim=k_axis)
        return smallest


class SignedGeomeanFunc(torch.autograd.Function):
    """
    Signed square root of a product, with a hand-written gradient.

    Forward computes ``sign(p) * sqrt(|p|)`` where ``p`` is the product of
    the input along ``dim``. Backward uses ``0.5 * output / input`` as the
    local gradient, zeroes it wherever the input is exactly zero, and
    optionally clamps it to ``[-clamp_grad, clamp_grad]``.
    """

    @staticmethod
    def forward(ctx, input, dim, keepdim, clamp_grad):
        # Stash everything backward needs; dim is stored as a non-negative index
        ctx.save_for_backward(input)
        ctx.dim = dim if dim >= 0 else dim + input.ndim
        ctx.keepdim = keepdim
        ctx.clamp_grad = clamp_grad
        product = input.prod(dim=dim, keepdim=keepdim)
        return product.sign() * product.abs().sqrt()

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        # Restore the reduced axis so the upstream gradient broadcasts against input
        upstream = grad_output if ctx.keepdim else grad_output.unsqueeze(ctx.dim)

        # Recompute the forward output, keeping the reduced axis
        product = input.prod(dim=ctx.dim, keepdim=True)
        output = product.sign() * product.abs().sqrt()

        # Local gradient; entries at input == 0 (inf/nan) are replaced by 0
        local_grad = (0.5 * output / input).masked_fill(input.abs() == 0, 0)
        limit = ctx.clamp_grad
        if limit is not None:
            local_grad = torch.clamp(local_grad, -limit, limit)

        # Chain rule; one None per non-tensor argument of forward
        return upstream * local_grad, None, None, None


def signed_geomean(x, dim=1, keepdim=False, clamp_grad=None):
    """
    Functional wrapper around ``SignedGeomeanFunc``.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor.
    dim : int, optional
        Axis to reduce. Default is ``1``.
    keepdim : bool, optional
        Whether the reduced axis is retained. Default is ``False``.
    clamp_grad : float or None, optional
        Passed through to ``SignedGeomeanFunc``. Default is ``None``.
    """
    args = (x, dim, keepdim, clamp_grad)
    return SignedGeomeanFunc.apply(*args)


class SignedGeomean(HOActfun):
    """
    Activation applying ``signed_geomean`` to each group of ``k`` inputs.

    Note that ``clamp_grad`` is the first positional parameter; any further
    positional or keyword arguments are forwarded to ``HOActfun``.
    """

    def __init__(self, clamp_grad=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.clamp_grad = clamp_grad

    def forward(self, x):
        grouped, k_axis = unroll_k(x, self.k, self.dim)
        clamp = self.clamp_grad
        return signed_geomean(grouped, k_axis, clamp_grad=clamp)


class AIL_AND(HOActfun):
    """Activation applying ``logistic_and_approx`` to each group of ``k`` inputs."""

    def forward(self, x):
        grouped, k_axis = unroll_k(x, self.k, self.dim)
        result = logistic_and_approx(grouped, k_axis)
        return result


class AIL_OR(HOActfun):
    """Activation applying ``logistic_or_approx`` to each group of ``k`` inputs."""

    def forward(self, x):
        grouped, k_axis = unroll_k(x, self.k, self.dim)
        result = logistic_or_approx(grouped, k_axis)
        return result


class AIL_XNOR(HOActfun):
    """Activation applying ``logistic_xnor_approx`` to each group of ``k`` inputs."""

    def forward(self, x):
        grouped, k_axis = unroll_k(x, self.k, self.dim)
        result = logistic_xnor_approx(grouped, k_axis)
        return result


class MultiActfunDuplicate(HOActfun):
    """
    Apply several reduction functions to the same groups and concatenate.

    Every function in ``actfuns`` is called as ``f(x, dim)`` on the full
    grouped input; the results are concatenated along ``self.dim``.
    """

    def __init__(self, actfuns, **kwargs):
        super().__init__(**kwargs)
        self.actfuns = actfuns

    def forward(self, x):
        grouped, k_axis = unroll_k(x, self.k, self.dim)
        outputs = []
        for actfun in self.actfuns:
            outputs.append(actfun(grouped, k_axis))
        return torch.cat(outputs, dim=self.dim)


class MultiActfunPartition(HOActfun):
    """
    Partition the groups between several reduction functions.

    The input is arranged into groups of ``k`` along ``self.dim``; the groups
    are then divided into ``len(actfuns)`` contiguous sections (as equal in
    size as possible, earlier sections taking any remainder) and the i-th
    function, called as ``f(x, dim)``, reduces the i-th section. The results
    are concatenated along ``self.dim``.

    Parameters
    ----------
    actfuns : sequence of callables
        Reduction functions with signature ``f(tensor, dim)``.
    **kwargs
        Forwarded to ``HOActfun`` (``k``, ``dim``).
    """

    def __init__(self, actfuns, **kwargs):
        super().__init__(**kwargs)
        self.actfuns = actfuns

    def forward(self, x):
        x, d_new = unroll_k(x, self.k, self.dim)
        # The group axis sits immediately before the k-sized axis d_new.
        # The partition has to happen there: splitting the k-axis with a
        # chunk size of len(actfuns) yields a single chunk whenever
        # k <= len(actfuns), so only the first function would ever be used.
        group_dim = d_new - 1
        n_fun = len(self.actfuns)
        base, extra = divmod(x.shape[group_dim], n_fun)
        sizes = [base + 1 if i < extra else base for i in range(n_fun)]
        xs = torch.split(x, sizes, dim=group_dim)
        return torch.cat(
            [f(xi, d_new) for f, xi in zip(self.actfuns, xs)], dim=self.dim
        )


class max_min_duplicate(MultiActfunDuplicate):
    """``MultiActfunDuplicate`` preset using ``maxd`` and ``mind``."""

    def __init__(self, **kwargs):
        reducers = [maxd, mind]
        super().__init__(reducers, **kwargs)


class AIL_AND_OR_duplicate(MultiActfunDuplicate):
    """``MultiActfunDuplicate`` preset using the AND and OR approximations."""

    def __init__(self, **kwargs):
        reducers = [logistic_and_approx, logistic_or_approx]
        super().__init__(reducers, **kwargs)


class AIL_OR_XNOR_duplicate(MultiActfunDuplicate):
    """``MultiActfunDuplicate`` preset using the OR and XNOR approximations."""

    def __init__(self, **kwargs):
        reducers = [logistic_or_approx, logistic_xnor_approx]
        super().__init__(reducers, **kwargs)


class AIL_AND_OR_XNOR_duplicate(MultiActfunDuplicate):
    """``MultiActfunDuplicate`` preset using the AND, OR and XNOR approximations."""

    def __init__(self, **kwargs):
        reducers = [logistic_and_approx, logistic_or_approx, logistic_xnor_approx]
        super().__init__(reducers, **kwargs)


class max_min_partition(MultiActfunPartition):
    """``MultiActfunPartition`` preset using ``maxd`` and ``mind``."""

    def __init__(self, **kwargs):
        reducers = [maxd, mind]
        super().__init__(reducers, **kwargs)


class AIL_OR_XNOR_partition(MultiActfunPartition):
    """``MultiActfunPartition`` preset using the OR and XNOR approximations."""

    def __init__(self, **kwargs):
        reducers = [logistic_or_approx, logistic_xnor_approx]
        super().__init__(reducers, **kwargs)


class AIL_AND_OR_XNOR_partition(MultiActfunPartition):
    """``MultiActfunPartition`` preset using the AND, OR and XNOR approximations."""

    def __init__(self, **kwargs):
        reducers = [logistic_and_approx, logistic_or_approx, logistic_xnor_approx]
        super().__init__(reducers, **kwargs)


def get_data_loader(
    dataset="covtype", validate=False, batch_size=128, test_batch_size=None, seed=0
):
    """
    Build standardised train and evaluation data loaders for a dataset.

    Parameters
    ----------
    dataset : str, optional
        Either ``"iris"`` or ``"covtype"``. Default is ``"covtype"``.
    validate : bool, optional
        If ``True``, the 80% training partition is split again (85%/15%) and
        the held-out 15% is returned as the evaluation set instead of the
        test partition. Default is ``False``.
    batch_size : int or None, optional
        Training batch size; ``None`` means one batch with the whole
        training set. Default is ``128``.
    test_batch_size : int or None, optional
        Evaluation batch size; ``None`` means reuse ``batch_size``.
    seed : int, optional
        Random state used for the stratified split(s). Default is ``0``.

    Returns
    -------
    train_loader, test_loader : torch.utils.data.DataLoader
        The training loader shuffles and drops the last incomplete batch;
        the evaluation loader does neither.

    Raises
    ------
    ValueError
        If ``dataset`` is not a supported name.
    """
    if dataset == "iris":
        features, labels = sklearn.datasets.load_iris(return_X_y=True)
    elif dataset == "covtype":
        features, labels = sklearn.datasets.fetch_covtype(
            return_X_y=True, random_state=0
        )
    else:
        raise ValueError("Unsupported dataset name: {}".format(dataset))
    # Shift labels so the smallest class index is 0 (covtype counts from 1)
    labels -= np.min(labels)

    # First split: train/test. Optional second split: train/val taken from
    # the training partition, which replaces the evaluation set.
    train_fractions = [0.8]
    if validate:
        train_fractions.append(0.85)
    X_fit, y_fit = features, labels
    for fraction in train_fractions:
        X_fit, X_eval, y_fit, y_eval = sklearn.model_selection.train_test_split(
            X_fit,
            y_fit,
            train_size=fraction,
            stratify=y_fit,
            shuffle=True,
            random_state=seed,
        )

    # Standardise using statistics of the training data only
    scaler = sklearn.preprocessing.StandardScaler()
    scaler = scaler.fit(X_fit)
    X_fit = scaler.transform(X_fit)
    X_eval = scaler.transform(X_eval)

    fit_set = torch.utils.data.TensorDataset(
        torch.Tensor(X_fit), torch.tensor(y_fit, dtype=torch.long)
    )
    eval_set = torch.utils.data.TensorDataset(
        torch.Tensor(X_eval), torch.tensor(y_eval, dtype=torch.long)
    )

    n_fit = len(fit_set)
    if batch_size is None:
        batch_size = n_fit
    if test_batch_size is None:
        test_batch_size = batch_size

    train_loader = torch.utils.data.DataLoader(
        fit_set,
        batch_size=min(batch_size, n_fit),
        drop_last=True,
        shuffle=True,
    )
    test_loader = torch.utils.data.DataLoader(
        eval_set,
        batch_size=min(test_batch_size, len(eval_set)),
        drop_last=False,
        shuffle=False,
    )
    return train_loader, test_loader


class MLP(nn.Module):
    """
    Multi-layer perceptron built from Linear -> BatchNorm1d -> activation blocks.

    Parameters
    ----------
    n_input : int
        Number of input features.
    n_output : int
        Number of outputs of the final linear layer.
    n_hidden : int
        Width of every hidden linear layer (before the activation).
    n_layer : int
        Number of hidden blocks; at least one block is always created.
    actfun_factory : callable, optional
        Called with no arguments to create each activation module.
        Default is ``nn.ReLU``.
    k : number, optional
        Factor by which an activation shrinks the feature count; layers that
        follow an activation take ``int(n_hidden / k)`` inputs. Default ``1``.
    """

    def __init__(
        self, n_input, n_output, n_hidden, n_layer, actfun_factory=nn.ReLU, k=1
    ):
        super().__init__()
        self.actfun_factory = actfun_factory
        modules = []
        for i_block in range(max(n_layer, 1)):
            # Only the first block sees the raw inputs; later blocks see the
            # (possibly narrower) output of the previous activation.
            fan_in = n_input if i_block == 0 else int(n_hidden / k)
            modules += [
                nn.Linear(fan_in, n_hidden, bias=False),
                torch.nn.BatchNorm1d(n_hidden, affine=True),
                actfun_factory(),
            ]
        modules.append(nn.Linear(int(n_hidden / k), n_output))
        self.layers = nn.ModuleList(modules)

    def forward(self, x):
        """Pass ``x`` through every layer in order and return the result."""
        for layer in self.layers:
            x = layer(x)
        return x


def _lr_closest_to_loss(lrs, losses, start, stop, target_loss):
    """Return the LR whose loss within ``losses[start:stop]`` is nearest ``target_loss``."""
    segment = np.array(losses[start:stop])
    return lrs[start + np.argmin(np.abs(segment - target_loss))]


def run_lr_finder(
    model,
    train_loader,
    optimizer,
    criterion,
    val_loader=None,
    verbose=True,
    show=True,
    figpth=None,
    device=None,
    recommender="min110",
):
    """
    Run an LR range test and recommend a learning rate.

    The range test sweeps the LR from 1e-7 to 10 over 200 iterations. The
    loss minimum and the loss maximum preceding it are located, and the LRs
    at which the loss has climbed a given fraction of the way from the
    minimum back up to that maximum are looked up between the two.

    Parameters
    ----------
    model, optimizer, criterion, device
        Passed to ``LRFinder``.
    train_loader, val_loader
        Passed to ``LRFinder.range_test``.
    verbose : bool, optional
        Whether to print progress and the candidate learning rates.
    show : bool, optional
        Whether to display the plot with ``plt.show()``.
    figpth : str or None, optional
        If given, the plot is saved to this path (parent directories are
        created as needed).
    recommender : str, optional
        Selection rule. ``"min12"``, ``"min13"``, ``"min14"``, ``"min18"``
        and ``"min110"`` take the smaller of (LR at minimum loss) / 10 and
        the LR at 1/2, 1/3, 1/4, 1/8 or 1/10 of the way up. ``"logmean12"``,
        ``"logmean13"`` and ``"logmean14"`` take the geometric mean of the
        same pair; ``"div10"`` behaves like ``"logmean12"``.

    Returns
    -------
    float
        The recommended learning rate.

    Raises
    ------
    ValueError
        If ``recommender`` is not one of the supported names (raised before
        the range test is run), or if the loss minimum is at the very first
        iteration so that no preceding maximum exists.
    """
    # Loss fractions (of the way from the minimum up to the preceding maximum)
    fraction_of = {
        "12": 0.5,
        "13": 1 / 3,
        "23": 2 / 3,
        "14": 1 / 4,
        "18": 1 / 8,
        "110": 0.1,
    }
    min_rules = {
        "min12": "12",
        "min13": "13",
        "min14": "14",
        "min18": "18",
        "min110": "110",
    }
    logmean_rules = {
        "div10": "12",
        "logmean12": "12",
        "logmean13": "13",
        "logmean14": "14",
    }
    # Fail fast: otherwise an unknown name only surfaces after the whole
    # range test, as an unbound-variable error.
    if recommender not in min_rules and recommender not in logmean_rules:
        raise ValueError("Unsupported LR recommender: {}".format(recommender))

    if verbose:
        print("Running learning rate finder")
    lr_finder = LRFinder(model, optimizer, criterion, device=device)
    lr_finder.range_test(
        train_loader,
        val_loader=val_loader,
        start_lr=1e-7,
        end_lr=10,
        num_iter=200,
        diverge_th=3,
    )
    losses = lr_finder.history["loss"]
    lrs = lr_finder.history["lr"]

    # Find the minimum loss
    min_index = np.argmin(losses)
    if min_index == 0:
        raise ValueError(
            "LR finder loss is minimal at the first iteration; "
            "cannot locate a preceding maximum"
        )
    lr_at_min = lrs[min_index]
    min_loss = losses[min_index]
    # Find the maximum loss preceding the minimum
    max_index = np.argmax(losses[:min_index])
    max_loss = losses[max_index]

    plotting = bool(show or figpth)
    lr_steepest = None
    if plotting:
        if verbose:
            print("Plotting learning rate finder results")
        plt.figure(figsize=(15, 9))
        ax = plt.axes()
        _, lr_steepest = lr_finder.plot(skip_start=0, skip_end=3, log_lr=True, ax=ax)
        # Zoom the y-axis onto the max -> min descent, with a 10% margin
        ylim = np.array([min_loss, max_loss])
        ylim += 0.1 * np.diff(ylim) * np.array([-1, 1])
        plt.ylim(ylim)
        plt.tick_params(reset=True, color=(0.2, 0.2, 0.2))
        plt.tick_params(labelsize=14)
        ax.minorticks_on()
        ax.tick_params(direction="out")

    # LR at each fraction of the way up from the minimum to the maximum
    lr_at = {
        key: _lr_closest_to_loss(
            lrs,
            losses,
            max_index,
            min_index,
            min_loss + fraction * (max_loss - min_loss),
        )
        for key, fraction in fraction_of.items()
    }

    lr_tenth = lr_at_min / 10
    if recommender in min_rules:
        lr_recomend = np.min([lr_tenth, lr_at[min_rules[recommender]]])
    else:
        lr_other = lr_at[logmean_rules[recommender]]
        lr_recomend = np.exp(np.mean([np.log(lr_tenth), np.log(lr_other)]))

    if verbose:
        if lr_steepest is not None:
            print("LR at steepest grad: {:.3e}  (red)".format(lr_steepest))
        print("LR at minimum loss : {:.3e}".format(lr_at_min))
        print("LR a tenth of min  : {:.3e}  (orange)".format(lr_tenth))
        print("LR when 1/10 up    : {:.3e}  (purple)".format(lr_at["110"]))
        print("LR when 1/4 up     : {:.3e}  (yellow)".format(lr_at["14"]))
        print("LR when 1/3 up     : {:.3e}  (blue)".format(lr_at["13"]))
        print("LR when 1/2 up     : {:.3e}  (cyan)".format(lr_at["12"]))
        print("LR when 2/3 up     : {:.3e}  (green)".format(lr_at["23"]))
        print("LR recommended     : {:.3e}  (black)".format(lr_recomend))

    if plotting:
        if lr_steepest is not None:
            ax.axvline(x=lr_steepest, color="red")
        ax.axvline(x=lr_tenth, color="orange")
        ax.axvline(x=lr_at["110"], color="purple")
        ax.axvline(x=lr_at["14"], color="yellow")
        ax.axvline(x=lr_at["13"], color="blue")
        ax.axvline(x=lr_at["12"], color="cyan")
        ax.axvline(x=lr_at["23"], color="green")
        ax.axvline(x=lr_recomend, color="black", ls=":")
    if figpth:
        # Save figure; a bare filename has no directory component to create
        figdir = os.path.dirname(figpth)
        if figdir:
            os.makedirs(figdir, exist_ok=True)
        plt.savefig(figpth)
        if verbose:
            print("LR Finder results saved to {}".format(figpth))
    if show:
        plt.show()

    return lr_recomend


def validate(model, val_loader, criterion, device=None):
    """
    Evaluate a classifier over a data loader.

    The model is switched to eval mode and left in that mode. Per-batch
    losses are weighted by batch size before averaging.

    Parameters
    ----------
    model : callable module
        Maps a batch of inputs to per-class scores (classes on axis 1).
    val_loader : iterable
        Yields ``(input, target)`` batches.
    criterion : callable
        Loss function called as ``criterion(output, target)``.
    device : optional
        Device the batches are moved to before the forward pass.

    Returns
    -------
    test_acc : float
        Fraction of samples whose arg-max prediction equals the target.
    test_loss : float
        Sample-weighted mean of the per-batch losses.
    """
    model.eval()
    loss_sum = 0
    n_hit = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            scores = model(inputs)
            n_batch = len(inputs)
            loss_sum += n_batch * criterion(scores, targets).item()
            predicted = scores.argmax(dim=1, keepdim=True)
            n_hit += predicted.eq(targets.view_as(predicted)).sum().item()
            n_seen += n_batch
    test_loss = loss_sum / n_seen
    test_acc = n_hit / n_seen
    return test_acc, test_loss


def train(
    model,
    train_loader,
    optimizer,
    val_loader=None,
    schedule="onecycle",
    max_lr=1e-2,
    n_epoch=50,
    device=None,
    verbose=True,
    log_interval=1,
):
    """
    Train a classifier with cross-entropy loss.

    Parameters
    ----------
    model : torch.nn.Module
        Model to train; moved to ``device``.
    train_loader : iterable
        Yields ``(input, target)`` training batches.
    optimizer : torch.optim.Optimizer
        Optimizer over the model's parameters.
    val_loader : iterable or None, optional
        If given, the model is evaluated on it after every epoch.
    schedule : str, optional
        ``"onecycle"`` (``OneCycleLR`` stepped every batch) or
        ``"exponential"`` (``ExponentialLR`` with gamma 0.99, stepped every
        epoch).
    max_lr : float, optional
        Peak learning rate for the one-cycle schedule.
    n_epoch : int, optional
        Number of epochs.
    device : optional
        Device for the model and the batches.
    verbose : bool, optional
        Whether to print the schedule, start-up and per-epoch messages.
    log_interval : int, optional
        Print a status line every ``log_interval`` epochs.

    Returns
    -------
    tuple
        ``(train_acc, val_acc, train_loss, val_loss, first_max_val)``.
        The training metrics come from a final evaluation pass over
        ``train_loader``; ``val_acc`` and ``val_loss`` are from the last
        epoch, or ``None`` when no ``val_loader`` was given;
        ``first_max_val`` is the first epoch at which validation accuracy
        reached 1, or ``None``.

    Raises
    ------
    ValueError
        If ``schedule`` is not supported.
    """
    t_start = time.time()
    criterion = nn.CrossEntropyLoss()

    batch_scheduler = None
    epoch_scheduler = None
    if verbose:
        print("Using {} schedule".format(schedule))
    if schedule == "onecycle":
        batch_scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr,
            epochs=n_epoch,
            steps_per_epoch=len(train_loader),
        )
    elif schedule == "exponential":
        epoch_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.99)
    else:
        raise ValueError("Unsupported schedule: {}".format(schedule))

    # val_loss must exist even when there is no validation loader, since it
    # is part of the return value.
    first_max_val = val_acc = val_loss = None
    model.to(device)

    if verbose:
        print("Starting training for {} epochs, device={}".format(n_epoch, device))
    for epoch in range(1, n_epoch + 1):
        t_ep_start = time.time()
        model.train()
        n_correct = 0
        n_sample = 0
        for input, target in train_loader:
            input, target = input.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(input)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            pred = output.argmax(dim=1, keepdim=True)
            n_correct += pred.eq(target.view_as(pred)).sum().item()
            n_sample += len(input)
            if batch_scheduler:
                batch_scheduler.step()

        train_acc = n_correct / n_sample
        if val_loader is not None:
            val_acc, val_loss = validate(
                model, val_loader, criterion=criterion, device=device
            )
            if val_acc >= 1 and first_max_val is None:
                first_max_val = epoch
        else:
            val_acc = None

        if epoch % log_interval == 0 and verbose:
            # Without a validation loader there is no accuracy to format
            if val_acc is None:
                val_acc_str = "   n/a"
            else:
                val_acc_str = "{:6.2f}%".format(100 * val_acc)
            print(
                "Train Epoch:{:3d}/{:3d} ({:4.1f}s, {:5.1f}s)  LR: {:.3e}  Loss: {:.6f}  Train acc: {:6.2f}%  Val acc: {}".format(
                    epoch,
                    n_epoch,
                    time.time() - t_ep_start,
                    time.time() - t_start,
                    get_current_lr(optimizer),
                    loss.item(),
                    100 * train_acc,
                    val_acc_str,
                ),
                flush=True,
            )
        if epoch_scheduler:
            epoch_scheduler.step()

    print("Total training time: {:.1f}s".format(time.time() - t_start))

    train_acc_fin, train_loss_fin = validate(
        model, train_loader, criterion=criterion, device=device
    )

    return train_acc_fin, val_acc, train_loss_fin, val_loss, first_max_val


def _actfun_display_name(actfun_factory):
    """Return a printable name for an activation factory, including partials."""
    name = getattr(actfun_factory, "__name__", None)
    if name is not None:
        return name
    # functools.partial objects have no __name__; describe the wrapped callable
    inner = getattr(actfun_factory, "func", None)
    if inner is None:
        return repr(actfun_factory)
    keywords = getattr(actfun_factory, "keywords", None) or {}
    kw_str = ", ".join("{}={}".format(key, val) for key, val in keywords.items())
    return "{}({})".format(_actfun_display_name(inner), kw_str)


def run_experiment(
    actfun_factory=nn.ReLU,
    effective_k=1,
    seed=0,
    seed_data_partition=0,
    n_latent=50,
    n_layer=2,
    batch_size=128,
    test_batch_size=None,
    n_epoch=50,
    log_interval=1,
    max_lr=None,
    weight_decay=0,
    use_val=False,
    device=None,
    lr_recommender="min12",
    lrfinder_figpth=None,
):
    """
    Train and evaluate one MLP configuration.

    If ``max_lr`` is ``None`` an LR range test is run first, on a freshly
    initialised model and the train/val partitions, and its recommendation
    is used. The model is then re-initialised with the same seed and trained
    with Adam and the one-cycle schedule.

    Parameters
    ----------
    actfun_factory : callable, optional
        Zero-argument factory for the activation modules (a class or a
        ``functools.partial``).
    effective_k : number, optional
        Feature-count reduction factor of the activation, passed to ``MLP``.
    seed : int, optional
        Seed applied before model construction.
    seed_data_partition : int, optional
        Seed for the data partitioning.
    n_latent, n_layer : int, optional
        Hidden width and number of hidden blocks.
    batch_size, test_batch_size, n_epoch, log_interval, weight_decay
        Forwarded to the data loader, the trainer or the optimizer.
    max_lr : float or None, optional
        Peak learning rate; ``None`` triggers the LR finder.
    use_val : bool, optional
        Evaluate on the validation split instead of the test split.
    device : optional
        Device used for training.
    lr_recommender : str, optional
        Recommender name passed to ``run_lr_finder``.
    lrfinder_figpth : str or None, optional
        Where to save the LR finder plot.

    Returns
    -------
    tuple
        ``(param_count, train_acc, val_acc, train_loss, val_loss, max_lr,
        t_train)``.
    """
    t_start = time.time()
    # partial-wrapped factories have no __name__, so resolve the name once
    actfun_name = _actfun_display_name(actfun_factory)
    print(
        "Running {} layer MLP, {} neurons/layer, {} (efk={}, seed={}); {} epochs, {} batch_size".format(
            n_layer,
            n_latent,
            actfun_name,
            effective_k,
            seed,
            n_epoch,
            batch_size,
        ),
        flush=True,
    )

    def init(use_val=False):
        train_loader, test_loader = get_data_loader(
            batch_size=batch_size,
            test_batch_size=test_batch_size,
            seed=seed_data_partition,
            validate=use_val,
        )
        set_rng_state(seed)
        # Input and output sizes are hard-coded (54 features, 7 classes);
        # presumably these match the default covtype dataset.
        model = MLP(
            54,
            7,
            n_latent,
            n_layer,
            actfun_factory=actfun_factory,
            k=effective_k,
        )
        print(model)
        print("{} parameters".format(count_parameters(model)))
        return train_loader, test_loader, model

    if max_lr is None:
        train_loader, test_loader, model = init(use_val=True)
        lr_found = run_lr_finder(
            model,
            train_loader,
            torch.optim.Adam(model.parameters()),
            nn.CrossEntropyLoss(),
            val_loader=None,
            show=False,
            figpth=lrfinder_figpth,
            device=device,
            recommender=lr_recommender,
        )
        print("Time for LR finder: {:.1f}s".format(time.time() - t_start))
        print("Found learning rate: {:.3e}".format(lr_found))
        max_lr = lr_found

    # Start again from a fresh model so the LR search does not affect training
    train_loader, test_loader, model = init(use_val=use_val)

    print("Calling trainer")
    t_start_train = time.time()
    optimizer = torch.optim.Adam(model.parameters(), max_lr, weight_decay=weight_decay)
    train_acc, val_acc, train_loss, val_loss, first_max_val = train(
        model,
        train_loader,
        optimizer,
        val_loader=test_loader,
        schedule="onecycle",
        max_lr=max_lr,
        n_epoch=n_epoch,
        device=device,
        log_interval=log_interval,
    )
    t_train = time.time() - t_start_train

    print(
        "Results for {} layer MLP, {} neurons/layer, {} (efk={}, seed={}); {} epochs".format(
            n_layer,
            n_latent,
            actfun_name,
            effective_k,
            seed,
            n_epoch,
        )
    )
    print(
        "Train acc: {:6.2f}%   Test acc: {:6.2f}%".format(
            100 * train_acc,
            100 * val_acc,
        )
    )
    print("Train time: {:.1f}s".format(t_train))
    print("Total time: {:.1f}s".format(time.time() - t_start), flush=True)
    return (
        count_parameters(model),
        train_acc,
        val_acc,
        train_loss,
        val_loss,
        max_lr,
        t_train,
    )


def actfun_name2factory(name):
    """
    Resolve an activation function name to a factory and its effective ``k``.

    The custom names are matched case-insensitively. Any other name is
    looked up, with its case unchanged, as an attribute of ``torch.nn``.

    Parameters
    ----------
    name : str
        Activation function name.

    Returns
    -------
    tuple
        ``(factory, effective_k)``: a zero-argument callable creating the
        activation module, and the paired reduction factor.

    Raises
    ------
    ValueError
        If the name matches neither a custom activation nor a ``torch.nn``
        attribute.
    """
    key = name.lower()
    if key == "relu":
        return nn.ReLU, 1
    if key == "maxout":
        return MaxOut, 2
    if key == "signedgeomean":
        return SignedGeomean, 2
    if key == "signedgeomean_clamp2":
        return functools.partial(SignedGeomean, clamp_grad=2), 2
    if key == "signedgeomean_clamp10":
        return functools.partial(SignedGeomean, clamp_grad=10), 2
    if key == "ail_and":
        return AIL_AND, 2
    if key == "ail_or":
        return AIL_OR, 2
    if key == "ail_xnor":
        return AIL_XNOR, 2
    if key == "max_min_dup":
        return max_min_duplicate, 1
    if key == "ail_and_or_dup":
        return AIL_AND_OR_duplicate, 1
    if key == "ail_or_xnor_dup":
        return AIL_OR_XNOR_duplicate, 1
    if key == "ail_and_or_xnor_dup":
        return AIL_AND_OR_XNOR_duplicate, 2 / 3
    if key == "max_min_part":
        return max_min_partition, 2
    if key == "ail_or_xnor_part":
        return AIL_OR_XNOR_partition, 2
    if key == "ail_and_or_xnor_part":
        return AIL_AND_OR_XNOR_partition, 2
    # Fall back to anything torch.nn provides under this exact name
    if hasattr(nn, name):
        return getattr(nn, name), 1
    raise ValueError("Unsupported actfun: {}".format(name))


def main():
    """
    Command-line entry point: run one experiment and write its results to CSV.

    The output path is derived from the arguments, under ``results/``. If
    that file already exists the run is skipped unless ``--force`` is given.
    The LR finder plot is saved next to the CSV with a ``.jpg`` extension.
    """
    # Training settings
    parser = argparse.ArgumentParser(
        description="PyTorch combinact covertype experiment"
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=1024,
        metavar="N",
        help="input batch size for training (default: %(default)s)",
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=1000,
        metavar="N",
        help="input batch size for testing (default: %(default)s)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=50,
        metavar="N",
        help="number of epochs to train (default: %(default)s)",
    )
    parser.add_argument(
        "--no-cuda", action="store_true", default=False, help="disables CUDA training"
    )
    parser.add_argument(
        "--force",
        action="store_true",
        default=False,
        help="re-run and overwrite existing runs (default is to skip)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        metavar="S",
        help="random seed (default: %(default)s)",
    )
    # The default of 1 matches the per-epoch logging this script produced
    # while the option was not yet forwarded to the trainer.
    parser.add_argument(
        "--log-interval",
        type=int,
        default=1,
        metavar="N",
        help="how many epochs to wait between logging training status "
        "(default: %(default)s)",
    )
    parser.add_argument(
        "--actfun",
        type=str,
        default="ReLU",
        help="which activation function to use (default: %(default)s)",
    )
    parser.add_argument(
        "--nlatent",
        type=int,
        default=50,
        help="number of units per hidden layer (default: %(default)s)",
    )
    parser.add_argument(
        "--nlayer",
        type=int,
        default=2,
        help="number of hidden layers (default: %(default)s)",
    )
    parser.add_argument(
        "--validate",
        action="store_true",
        help="evaluate on val partition (default: evaluate on test partition)",
    )
    parser.add_argument(
        "--lr-recommender",
        type=str,
        default="min12",
        help="LR recommendation mechanism (default: %(default)s)",
    )
    parser.add_argument(
        "--weight-decay",
        "--wd",
        type=float,
        default=0,
        help="weight decay (default: %(default)s)",
    )
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    device = torch.device("cuda" if use_cuda else "cpu")

    actfun_factory, effk = actfun_name2factory(args.actfun)

    if args.weight_decay == 0:
        wd_str = ""
    else:
        wd_str = "__wd-{}".format(args.weight_decay)
    output_fname = os.path.join(
        "results",
        "covtype-std_{}__lr-{}{}".format(
            "val" if args.validate else "test", args.lr_recommender, wd_str
        ),
        "ep-{}__bs-{}".format(args.epochs, args.batch_size),
        args.actfun,
        "act-{}__nlayer-{}__nlatent-{}__s-{}.csv".format(
            args.actfun,
            args.nlayer,
            args.nlatent,
            args.seed,
        ),
    )

    if os.path.isfile(output_fname) and not args.force:
        print(
            "Aborting run since target file\n  {}\nalready exists".format(output_fname)
        )
        return

    t_start = time.time()
    (
        param_count,
        train_acc,
        val_acc,
        train_loss,
        val_loss,
        max_lr,
        t_train,
    ) = run_experiment(
        actfun_factory=actfun_factory,
        effective_k=effk,
        seed=args.seed,
        n_latent=args.nlatent,
        n_layer=args.nlayer,
        n_epoch=args.epochs,
        batch_size=args.batch_size,
        test_batch_size=args.test_batch_size,
        log_interval=args.log_interval,
        weight_decay=args.weight_decay,
        use_val=args.validate,
        device=device,
        lr_recommender=args.lr_recommender,
        lrfinder_figpth=os.path.splitext(output_fname)[0] + ".jpg",
    )
    t_end = time.time()
    os.makedirs(os.path.dirname(output_fname), exist_ok=True)

    # Column names and values are kept side by side so they cannot drift apart
    eval_name = "val" if args.validate else "test"
    columns = [
        ("n_epoch", args.epochs),
        ("batch_size", args.batch_size),
        ("actfun", args.actfun),
        ("effk", effk),
        ("n_layer", args.nlayer),
        ("n_latent", args.nlatent),
        ("seed", args.seed),
        ("max_lr", max_lr),
        ("weight_decay", args.weight_decay),
        ("param_count", param_count),
        ("train_acc", train_acc),
        (eval_name + "_acc", val_acc),
        ("train_loss", train_loss),
        (eval_name + "_loss", val_loss),
        ("time_train", t_train),
        ("time_total", t_end - t_start),
    ]
    with open(output_fname, "w") as hf:
        hf.write(",".join("{}".format(header) for header, _ in columns) + "\n")
        hf.write(",".join("{}".format(value) for _, value in columns) + "\n")


# Run the command-line experiment only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
