import argparse

# Parse the command line before the heavy imports below, presumably so that
# `--help` and usage errors return without loading torch/matplotlib.
parser = argparse.ArgumentParser()
# Restricting `mode` to the supported values makes a typo a usage error;
# otherwise the training loop would match no branch and silently re-record
# the previous accuracy on every step.
parser.add_argument('mode', help='which method to use to learn a `get`: (a) `transfer`, (b) `vae`, or (c) `putter`', choices=('transfer', 'vae', 'putter'))
parser.add_argument('steps', help='the number of repetitions of training `get`', type=int)
parser.add_argument('name', help='the name of the run (can have several to run multiple runs)', nargs='+')
parser.add_argument('-d', '--device', help='torch device to run with', default='cpu')
args = parser.parse_args()
if args.steps < 0:
    # A negative count would run no steps and give the plot an inverted x range.
    parser.error('steps must be a non-negative integer')

import training_schedules
import mnist_manipulator
import torch
import matplotlib.pyplot as plt
import numpy as np
import sys

class Logger(object):
    """Tee for stdout: echo every write to the previous stdout and to a log file.

    Intended to be installed as ``sys.stdout = Logger(name)``. Whatever stream
    is ``sys.stdout`` at construction time is kept as ``terminal``. The log
    file is ``outputs/logs/{name}.txt``, relative to the working directory. It
    is opened for writing (truncating any previous log) and flushed after
    every write, so its contents stay current while the run is in progress.

    Call ``close()`` when the run is finished to release the log file; writing
    after that raises ``ValueError`` from the closed file.
    """

    def __init__(self, name):
        self.terminal = sys.stdout
        # Explicit UTF-8 so non-ASCII output cannot fail on platforms whose
        # default file encoding is narrower.
        self.log = open(f"outputs/logs/{name}.txt", "w", encoding="utf-8")

    def write(self, message):
        """Write ``message`` to the terminal stream and to the log file."""
        self.terminal.write(message)
        self.log.write(message)
        self.log.flush()

    def flush(self):
        """Flush both the terminal stream and the log file."""
        self.terminal.flush()
        self.log.flush()

    def close(self):
        """Close the log file; the wrapped terminal stream is left open."""
        self.log.close()

# Unpack the parsed command-line options into the names the driver loop uses.
mode, steps, names = args.mode, args.steps, args.name
device = torch.device(args.device)

# Load MNIST twice on `device`. The first argument is presumably the batch
# size (64 vs 512); confirm against mnist_manipulator. From the second load
# only the last two returned values (train/val) are kept.
(images, labels,
 test_images, test_labels,
 train, val) = mnist_manipulator.load_mnist_dataset(64, device)
(_, _, _, _,
 big_train, big_val) = mnist_manipulator.load_mnist_dataset(512, device)

def _save_accuracy_plot(name, accs, steps):
    """Plot ``accs`` against step index and save to outputs/plots/{name}_accs.png."""
    fig = plt.figure()
    try:
        axes = fig.gca()
        axes.plot(accs, '.-')
        axes.set_title(f"{name} accuracy")
        axes.set_xlabel("Steps")
        axes.set_ylabel("Accuracy")
        axes.set_xticks(list(range(steps + 1)))
        # A y tick every 10%, labelled as a percentage ("0%" ... "100%").
        axes.set_yticks([i / 10 for i in range(11)], [f"{i * 10}%" for i in range(11)])
        axes.set_ylim(0.0, 1.0)
        axes.set_xlim(0, steps)
        axes.grid()
        fig.savefig(f"outputs/plots/{name}_accs.png", dpi=300)
    finally:
        plt.close(fig)


# Reject an unsupported mode before any training starts. Otherwise no branch
# below would assign `acc`, and the previous value would be silently appended.
if mode not in ("putter", "transfer", "vae"):
    raise ValueError(f"unknown mode {mode!r}; expected 'transfer', 'vae' or 'putter'")

for name in names:
    # Tee this run's stdout into outputs/logs/{name}.txt. The stream that was
    # active before is restored afterwards. Without that, the next run's Logger
    # would wrap this one, and its output would also land in this run's log.
    previous_stdout = sys.stdout
    logger = Logger(name)
    sys.stdout = logger
    try:
        # accs[0]: value returned by the initial getter training;
        # accs[k]: value returned by training step k (1-based).
        accs = []

        acc, getter = training_schedules.train_getter_initial(test_images, test_labels, train, val, device, name)
        accs.append(acc)

        for idx in range(steps):
            if mode == "putter":
                putter = training_schedules.train_putter_from_getter(images, labels, train, val, big_train, big_val, device, getter, name, idx)
                acc, getter = training_schedules.train_getter_from_putter(test_images, test_labels, train, val, device, putter, name, idx)
            elif mode == "transfer":
                acc, getter = training_schedules.train_getter_from_getter(test_images, test_labels, train, val, device, getter, name, idx)
            else:  # mode == "vae" (validated above)
                _, decoder = training_schedules.train_vae_from_getter(train, val, device, getter, name, idx)
                acc, getter = training_schedules.train_getter_from_vae(test_images, test_labels, train, val, device, decoder, name, idx)
            accs.append(acc)

        np.save(f"outputs/data/{name}_accs.npy", np.array(accs))
        _save_accuracy_plot(name, accs, steps)
    finally:
        sys.stdout = previous_stdout
        logger.log.close()
