import os

import torch
import mnist_manipulator
from mnist_manipulator import Getter, Putter, Encoder, Decoder, VAEEncoder, VAEDecoder
import matplotlib.pyplot as plt

# Train task (a) from Appendix D.3, saves plots and models to `outputs/`
def train_getter_initial(test_images, test_labels, train, val, device, name) -> "tuple[float, Getter]":
    """Train a fresh Getter with direct supervision (task (a)).

    Runs ``mnist_manipulator.train_getter_supervised`` on ``train``/``val``
    (20 epochs, lr 0.001), evaluates the result with ``plot_getter`` on the
    test images/labels, and writes:

    * ``outputs/plots/{name}_getter_initial.png`` - the evaluation figure
    * ``outputs/models/{name}_getter_initial.pt`` - the getter's state_dict

    Args:
        test_images, test_labels: data passed to ``plot_getter`` for evaluation.
        train, val: training / validation data for the supervised trainer
            (format defined by ``mnist_manipulator``).
        device: torch device the new getter is moved to.
        name: run name used as the prefix of every output filename.

    Returns:
        ``(acc, getter)``: the accuracy reported by ``plot_getter`` (printed
        as ``acc*100`` percent, so presumably a fraction in [0, 1]) and the
        trained getter.
    """
    # Create output folders before training: savefig/torch.save do not create
    # missing directories, and failing only after training would lose the run.
    os.makedirs("outputs/plots", exist_ok=True)
    os.makedirs("outputs/models", exist_ok=True)

    print(f"[{name}] training getter supervised:")
    getter = Getter().to(device)
    mnist_manipulator.train_getter_supervised(train, val, getter, epochs=20, lr=0.001)

    acc, fig = mnist_manipulator.plot_getter(test_images, test_labels, getter)
    fig.savefig(f"outputs/plots/{name}_getter_initial.png", dpi=300)
    plt.close(fig)  # free the figure; many are produced over a full run
    print(f"[{name}] test accuracy = {acc*100.0:.1f}%")

    torch.save(getter.state_dict(), f"outputs/models/{name}_getter_initial.pt")
    return acc, getter

# Train task (b) from Appendix D.3, saves plots and models to `outputs/`
def train_getter_from_putter(test_images, test_labels, train, val, device, putter, name, idx) -> "tuple[float, Getter]":
    """Train a fresh Getter against a fixed Putter (task (b)).

    Calls ``mnist_manipulator.train_putter_getter`` with ``freeze_put=True``,
    so only the new getter is meant to be updated, for 20 epochs at lr 0.001
    with the loss weights listed below. Then evaluates the getter with
    ``plot_getter`` and writes:

    * ``outputs/plots/{name}_getter_{idx}.png`` - the evaluation figure
    * ``outputs/models/{name}_getter_{idx}.pt`` - the getter's state_dict

    Args:
        test_images, test_labels: data used both for the 10-example
            supervised subset and for the final evaluation.
        train, val: training / validation data for ``train_putter_getter``.
        device: torch device the new getter is moved to (also forwarded).
        putter: the putter to train against; frozen during training.
        name: run name used as the filename prefix.
        idx: step index used in log messages and filenames.

    Returns:
        ``(acc, getter)``: the accuracy reported by ``plot_getter``
        (presumably a fraction in [0, 1]) and the trained getter.
    """
    # Create output folders before training so a missing directory cannot
    # discard a finished run at save time.
    os.makedirs("outputs/plots", exist_ok=True)
    os.makedirs("outputs/models", exist_ok=True)

    print(f"[{name}] training getter for step {idx}:")
    getter = Getter().to(device)
    # NOTE(review): this supervised subset is drawn from the *test* set. With
    # w_supervised=0.0 below it presumably does not affect training, but raising
    # that weight would leak test labels into training - confirm intent.
    supervised_batch, supervised_labels = mnist_manipulator.make_supervised_subset(test_images, test_labels, 10)
    mnist_manipulator.train_putter_getter(
        train, val, device,
        supervised_batch, supervised_labels,
        putter, getter, freeze_put=True,
        epochs=20, lr=0.001,
        w_getput=10.0, w_putget=10.0,
        w_putput=10.0, w_undo=10.0,
        w_dist1=1.0, w_dist2=1.0,
        w_supervised=0.0,
    )

    acc, fig = mnist_manipulator.plot_getter(test_images, test_labels, getter)
    fig.savefig(f"outputs/plots/{name}_getter_{idx}.png", dpi=300)
    plt.close(fig)
    print(f"[{name}] test accuracy = {acc*100.0:.1f}%")

    torch.save(getter.state_dict(), f"outputs/models/{name}_getter_{idx}.pt")
    return acc, getter

# Train tasks (c) and (e) from Appendix D.3, saves plots and models to `outputs/`
def train_putter_from_getter(images, labels, train, val, big_train, big_val, device, getter, name, idx) -> Putter:
    """Pretrain an autoencoder, then train a Putter around it against a fixed Getter (tasks (c)/(e)).

    Stage 1: trains a fresh ``Encoder``/``Decoder`` pair with
    ``mnist_manipulator.train_autoencoder`` on ``big_train``/``big_val``,
    plots reconstructions on ``val``, and saves both halves.

    Stage 2: wraps the pair in ``Putter(encoder, decoder)`` and trains it
    with ``train_putter_getter`` (``freeze_get=True``, 20 epochs, lr 0.001),
    then plots and saves the putter.

    Files written (all prefixed ``{name}_``):

    * ``outputs/plots/{name}_autoencoder_{idx}.png``
    * ``outputs/models/{name}_autoencoder_{idx}_encoder.pt`` / ``_decoder.pt``
      (saved after stage 1, i.e. before putter training)
    * ``outputs/plots/{name}_putter_{idx}_getput.png`` / ``_putall.png``
    * ``outputs/models/{name}_putter_{idx}.pt``

    Args:
        images, labels: data for the 10-example supervised subset.
        train, val: data for putter training; ``val`` is also used for plots.
        big_train, big_val: data for autoencoder pretraining.
        device: torch device the models are moved to (also forwarded).
        getter: the getter to train against; frozen during training.
        name: run name used as the filename prefix.
        idx: step index used in log messages and filenames.

    Returns:
        The trained ``Putter``.
    """
    # Create output folders before training so a missing directory cannot
    # discard a finished run at save time.
    os.makedirs("outputs/plots", exist_ok=True)
    os.makedirs("outputs/models", exist_ok=True)

    # --- Stage 1: autoencoder pretraining ---
    print(f"[{name}] training autoencoder for step {idx}:")
    encoder = Encoder().to(device)
    decoder = Decoder().to(device)
    # NOTE(review): positional args presumably are epochs=80, lr=0.001 and
    # three loss weights (1.0, 0.0, 1.0) - confirm against train_autoencoder.
    mnist_manipulator.train_autoencoder(big_train, big_val, encoder, decoder, 80, 0.001, 1.0, 0.0, 1.0)
    fig = mnist_manipulator.plot_autoencoder(val, encoder, decoder)  # plotted on `val`, not `big_val`
    fig.savefig(f"outputs/plots/{name}_autoencoder_{idx}.png", dpi=300)
    plt.close(fig)
    torch.save(encoder.state_dict(), f"outputs/models/{name}_autoencoder_{idx}_encoder.pt")
    torch.save(decoder.state_dict(), f"outputs/models/{name}_autoencoder_{idx}_decoder.pt")

    # --- Stage 2: putter training against the frozen getter ---
    print(f"[{name}] training putter for step {idx}:")
    putter = Putter(encoder, decoder).to(device)
    supervised_batch, supervised_labels = mnist_manipulator.make_supervised_subset(images, labels, 10)
    mnist_manipulator.train_putter_getter(
        train, val, device,
        supervised_batch, supervised_labels,
        putter, getter, freeze_get=True,
        epochs=20, lr=0.001,
        w_getput=10.0, w_putget=10.0,
        w_putput=10.0, w_undo=10.0,
        w_dist1=0.0, w_dist2=0.0,
        w_supervised=0.0,
    )
    fig1, fig2 = mnist_manipulator.plot_putter(val, getter, putter, device)
    fig1.savefig(f"outputs/plots/{name}_putter_{idx}_getput.png", dpi=300)
    fig2.savefig(f"outputs/plots/{name}_putter_{idx}_putall.png", dpi=300)
    plt.close(fig1)
    plt.close(fig2)
    torch.save(putter.state_dict(), f"outputs/models/{name}_putter_{idx}.pt")

    return putter

# Train task (d) from Appendix D.3, saves plots and models to `outputs/`
def train_getter_from_getter(test_images, test_labels, train, val, device, getter, name, idx) -> "tuple[float, Getter]":
    """Train a fresh Getter from an existing one via transfer (task (d)).

    Runs ``mnist_manipulator.train_getter_transfer`` from ``getter`` into a
    new ``Getter`` (20 epochs, lr 0.001), evaluates the new getter with
    ``plot_getter``, and writes:

    * ``outputs/plots/{name}_getter_{idx}.png`` - the evaluation figure
    * ``outputs/models/{name}_getter_{idx}.pt`` - the new getter's state_dict

    Args:
        test_images, test_labels: data passed to ``plot_getter`` for evaluation.
        train, val: training / validation data for the transfer trainer.
        device: torch device the new getter is moved to.
        getter: the source getter being transferred from.
        name: run name used as the filename prefix.
        idx: step index used in log messages and filenames.

    Returns:
        ``(acc, getter2)``: the accuracy reported by ``plot_getter``
        (presumably a fraction in [0, 1]) and the newly trained getter.
    """
    # Create output folders before training so a missing directory cannot
    # discard a finished run at save time.
    os.makedirs("outputs/plots", exist_ok=True)
    os.makedirs("outputs/models", exist_ok=True)

    print(f"[{name}] training getter for step {idx}:")
    getter2 = Getter().to(device)
    mnist_manipulator.train_getter_transfer(train, val, getter, getter2, epochs=20, lr=0.001)

    acc, fig = mnist_manipulator.plot_getter(test_images, test_labels, getter2)
    fig.savefig(f"outputs/plots/{name}_getter_{idx}.png", dpi=300)
    plt.close(fig)
    print(f"[{name}] test accuracy = {acc*100.0:.1f}%")

    torch.save(getter2.state_dict(), f"outputs/models/{name}_getter_{idx}.pt")
    return acc, getter2

# Train task (f) from Appendix D.3, saves plots and models to `outputs/`
def train_vae_from_getter(train, val, device, getter, name, idx) -> "tuple[VAEEncoder, VAEDecoder]":
    """Train a VAE encoder/decoder pair guided by a Getter (task (f)).

    Calls ``mnist_manipulator.train_vae_from_getter`` for 40 epochs at
    lr 0.001 with the loss weights below, plots reconstructions and samples
    with ``plot_vae``, and writes:

    * ``outputs/plots/{name}_vae_{idx}_recon.png`` / ``_generate.png``
    * ``outputs/models/{name}_vae_encoder_{idx}.pt``
    * ``outputs/models/{name}_vae_decoder_{idx}.pt``

    Args:
        train, val: training / validation data; ``val`` is also used for plots.
        device: torch device the models are moved to (also passed to ``plot_vae``).
        getter: getter forwarded to the VAE trainer.
        name: run name used as the filename prefix.
        idx: step index used in log messages and filenames.

    Returns:
        ``(encoder, decoder)``: the trained VAE halves.
    """
    # Create output folders before training so a missing directory cannot
    # discard a finished run at save time.
    os.makedirs("outputs/plots", exist_ok=True)
    os.makedirs("outputs/models", exist_ok=True)

    print(f"[{name}] training vae for step {idx}:")
    encoder = VAEEncoder().to(device)
    decoder = VAEDecoder().to(device)
    mnist_manipulator.train_vae_from_getter(
        train, val, encoder, decoder, getter,
        epochs=40, lr=0.001,
        w_lab=1.0, w_im_mse=100.0, w_im_ssim=0.0, w_kl=0.5,
    )

    fig1, fig2 = mnist_manipulator.plot_vae(val, device, encoder, decoder)
    fig1.savefig(f"outputs/plots/{name}_vae_{idx}_recon.png", dpi=300)
    fig2.savefig(f"outputs/plots/{name}_vae_{idx}_generate.png", dpi=300)
    plt.close(fig1)
    plt.close(fig2)

    torch.save(encoder.state_dict(), f"outputs/models/{name}_vae_encoder_{idx}.pt")
    torch.save(decoder.state_dict(), f"outputs/models/{name}_vae_decoder_{idx}.pt")

    return (encoder, decoder)

# Train task (g) from Appendix D.3, saves plots and models to `outputs/`
def train_getter_from_vae(test_images, test_labels, train, val, device, decoder, name, idx) -> "tuple[float, Getter]":
    """Train a fresh Getter using a trained VAE decoder (task (g)).

    Runs ``mnist_manipulator.train_getter_from_vae`` with ``decoder``
    (20 epochs, lr 0.001), evaluates the getter with ``plot_getter``, and
    writes:

    * ``outputs/plots/{name}_getter_{idx}.png`` - the evaluation figure
    * ``outputs/models/{name}_getter_{idx}.pt`` - the getter's state_dict

    Args:
        test_images, test_labels: data passed to ``plot_getter`` for evaluation.
        train, val: training / validation data forwarded to the trainer.
        device: torch device the new getter is moved to (also forwarded).
        decoder: VAE decoder forwarded to the trainer.
        name: run name used as the filename prefix.
        idx: step index used in log messages and filenames.

    Returns:
        ``(acc, getter)``: the accuracy reported by ``plot_getter``
        (presumably a fraction in [0, 1]) and the trained getter.
    """
    # Create output folders before training so a missing directory cannot
    # discard a finished run at save time.
    os.makedirs("outputs/plots", exist_ok=True)
    os.makedirs("outputs/models", exist_ok=True)

    print(f"[{name}] training getter for step {idx}:")
    getter = Getter().to(device)
    mnist_manipulator.train_getter_from_vae(train, val, device, decoder, getter, epochs=20, lr=0.001)

    acc, fig = mnist_manipulator.plot_getter(test_images, test_labels, getter)
    fig.savefig(f"outputs/plots/{name}_getter_{idx}.png", dpi=300)
    plt.close(fig)
    print(f"[{name}] test accuracy = {acc*100.0:.1f}%")

    torch.save(getter.state_dict(), f"outputs/models/{name}_getter_{idx}.pt")
    return acc, getter
