import argparse
import os

# Command-line interface. torch and the project modules are imported only after
# parsing (further down the file), presumably so `--help` responds without loading torch.
parser = argparse.ArgumentParser(
    description='Train (or load pretrained) MNIST putter/getter models, then evaluate and plot them.'
)
parser.add_argument('--device', help='which torch device to run on', default='cpu')
parser.add_argument(
    '--epochs', type=int,
    help='number of epochs to train for (required unless --getter and --putter are given)',
)
parser.add_argument('--dataset', help='where to download or find the MNIST dataset', default='./datasets')
parser.add_argument('--outdir', help='directory in which to put output files', default='.')
parser.add_argument('--getter', help='specify a pretrained getter')
parser.add_argument('--putter', help='specify a pretrained putter')
args = parser.parse_args()

# Pretrained weights are only used when BOTH files are given. Supplying just one
# would otherwise be silently ignored and both models retrained from scratch.
if (args.getter is None) != (args.putter is None):
    parser.error('--getter and --putter must be given together')

# Without pretrained weights the models are trained, and the training routines
# receive args.epochs directly; fail fast here instead of deep inside training.
if args.getter is None and args.epochs is None:
    parser.error('--epochs is required when training (no --getter/--putter given)')

# All checkpoints and figures are written into --outdir; make sure it exists.
os.makedirs(args.outdir, exist_ok=True)

import torch
import mnist_manipulator
import model_figure5a

# Resolve the requested compute device, then fetch MNIST from --dataset.
device = torch.device(args.device)

# load_mnist_dataset yields six values: two image/label pairs plus train/val
# objects. The leading 64 is passed positionally — presumably the batch size
# (TODO confirm against mnist_manipulator).
(
    images, labels,
    test_images, test_labels,
    train, val,
) = mnist_manipulator.load_mnist_dataset(64, device, root=args.dataset)


if args.getter is None or args.putter is None:
    # No pretrained weights supplied: train every model from scratch.
    encoder = model_figure5a.Encoder().to(device)
    decoder = model_figure5a.Decoder().to(device)

    # Stage 1: unsupervised autoencoder pre-training (both loss weights set to 1.0).
    print("Pre-training autoencoder unsupervised:")
    mnist_manipulator.train_autoencoder(
        train, val, encoder, decoder,
        epochs = args.epochs, lr = 0.001, weight_l2 = 1.0, weight_ssim = 1.0
    )

    # The putter and getter are both built around the same pre-trained encoder instance.
    putter = model_figure5a.Putter(encoder, decoder).to(device)
    getter = model_figure5a.Getter(encoder).to(device)

    # Stage 2: train the manipulator (putter/getter pair).
    print("Training putter and getter:")
    # A one-example supervised subset is still passed because train_putter_getter
    # takes it as an argument; with w_supervised=0.0 below it presumably has no
    # effect on training.
    supervised_batch, supervised_labels = mnist_manipulator.make_supervised_subset(images, labels, 1)
    mnist_manipulator.train_putter_getter(
        train, val, device,
        supervised_batch, supervised_labels,
        putter, getter,
        epochs = args.epochs, lr = 0.001,
        w_getput = 10.0, w_putget = 10.0,
        w_putput = 10.0, w_undo = 10.0,
        w_dist1 = 0.0, w_dist2 = 0.0,
        w_supervised = 0.0, w_classifier = 1.0
    )

    # Persist every trained component into --outdir.
    torch.save(encoder.state_dict(), os.path.join(args.outdir, "trained_encoder.pt"))
    torch.save(decoder.state_dict(), os.path.join(args.outdir, "trained_decoder.pt"))
    torch.save(getter.state_dict(), os.path.join(args.outdir, "trained_getter.pt"))
    torch.save(putter.state_dict(), os.path.join(args.outdir, "trained_putter.pt"))
else:
    # Both --putter and --getter were given: rebuild the exact architecture used
    # by the training branch above (model_figure5a) so the saved state dicts match.
    encoder = model_figure5a.Encoder().to(device)
    decoder = model_figure5a.Decoder().to(device)
    putter = model_figure5a.Putter(encoder, decoder).to(device)
    # NOTE(security): torch.load unpickles the file; only load checkpoints you trust.
    putter.load_state_dict(torch.load(args.putter, map_location=device))
    getter = model_figure5a.Getter(encoder).to(device)
    getter.load_state_dict(torch.load(args.getter, map_location=device))

print("Evaluating and plotting:")

# Measure the getter on the test_images/test_labels split returned by
# load_mnist_dataset, so the number printed as "Test Accuracy" really comes from it.
acc, fig = mnist_manipulator.plot_getter(test_images, test_labels, getter)
fig1, fig2 = mnist_manipulator.plot_putter(val, getter, putter, device)

print(f"Test Accuracy = {100.0*acc:.1f}%")

# Write each figure into --outdir at 300 dpi.
for figure, filename in (
    (fig, "mnist_getter.png"),
    (fig1, "mnist_putter_get_put.png"),
    (fig2, "mnist_putter_put_all.png"),
):
    figure.savefig(os.path.join(args.outdir, filename), dpi=300)


