import torch
import torch.utils.data
import torchvision
import torchvision.transforms.v2
import matplotlib.pyplot as plt
import tqdm
import math
import sklearn.metrics
import sklearn.cluster
import numpy as np
import piq

class Encoder(torch.nn.Module):
    """MLP encoder mapping a 28x28 single-channel image to a sigmoid latent code.

    Args:
        latent: Width of the latent code. Defaults to 32. It is stored as
            ``self.latent`` (``Putter`` reads it to size its own layers) and is
            also used for the final layer, so the attribute and the actual
            output width always agree.
    """

    def __init__(self, latent: int = 32):
        super().__init__()

        self.latent = latent

        self.enc = torch.nn.Sequential(
            torch.nn.Flatten(1),
            torch.nn.Linear(28*28, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 64),
            torch.nn.ReLU(),
            # Output width derived from self.latent rather than hard-coded.
            torch.nn.Linear(64, self.latent),
            torch.nn.Sigmoid()
        )

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Encode ``image`` into a ``(batch, latent)`` tensor.

        The input is flattened from dim 1 onward and must yield 28*28
        features per sample.
        """
        return self.enc(image)
    
class Decoder(torch.nn.Module):
    """MLP decoder mapping a latent code back to a ``(1, 28, 28)`` sigmoid image.

    Args:
        latent: Width of the input latent code. Defaults to 32. It is stored as
            ``self.latent`` (``Putter`` reads it to size its own output) and is
            also used for the first layer, so the attribute and the accepted
            input width always agree.
    """

    def __init__(self, latent: int = 32):
        super().__init__()

        self.latent = latent

        self.dec = torch.nn.Sequential(
            # Input width derived from self.latent rather than hard-coded.
            torch.nn.Linear(self.latent, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 28*28),
            torch.nn.Unflatten(1, (1, 28, 28)),
            torch.nn.Sigmoid()
        )

    def forward(self, latent: torch.Tensor) -> torch.Tensor:
        """Decode a ``(batch, latent)`` code into a ``(batch, 1, 28, 28)`` image."""
        return self.dec(latent)
    
class Getter(torch.nn.Module):
    """Small MLP classifier that outputs a probability vector over ``classes``.

    Args:
        classes: Number of output classes.
        image_dims: Per-sample input shape; its product is the width of the
            flattened input to the first linear layer.
    """

    def __init__(self, classes: int = 10, image_dims: "tuple[int, int, int]" = (1, 28, 28)):
        super().__init__()

        flat_size = math.prod(image_dims)
        hidden = 64
        layers = [
            torch.nn.Flatten(1),
            torch.nn.Linear(flat_size, hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, classes),
            # Emit probabilities rather than logits.
            torch.nn.Softmax(-1),
        ]
        self.get = torch.nn.Sequential(*layers)

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Return class probabilities of shape ``(batch, classes)``."""
        probs = self.get(image)
        return probs

class Putter(torch.nn.Module):
    """Write a label into an image by editing its latent code.

    The image is encoded with ``enc``, concatenated with a label vector, mapped
    by a small MLP (``self.put``) to a new code of width ``dec.latent``, and
    decoded back into an image with ``dec``.

    Args:
        enc: Image encoder exposing a ``latent`` attribute equal to its output width.
        dec: Decoder exposing a ``latent`` attribute equal to its input width.
        classes: Width of the label vector passed to ``forward``.
        freeze_enc: If True, turn off gradients for ``enc``'s parameters.
        freeze_dec: If True, turn off gradients for ``dec``'s parameters.
    """

    def __init__(self, enc: "Encoder", dec: "Decoder", classes: int = 10, freeze_enc: bool = True, freeze_dec: bool = True):
        super().__init__()

        self.enc = enc
        if freeze_enc:
            self.enc.requires_grad_(False)
        self.dec = dec
        if freeze_dec:
            self.dec.requires_grad_(False)

        self.put = torch.nn.Sequential(
            torch.nn.Linear(enc.latent + classes, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, dec.latent),
            torch.nn.Sigmoid()
        )

    def forward(self, image: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
        """Return ``dec(put(cat(enc(image), label)))``.

        ``label`` is a per-sample vector of width ``classes``, e.g. a one-hot
        vector or a getter's probabilities. It is cast to the dtype of the
        encoder output, so integer tensors from
        ``torch.nn.functional.one_hot`` can be passed directly.
        """
        code = self.enc(image)
        # Explicit cast instead of relying on torch.cat's implicit type promotion.
        label = label.to(code.dtype)
        return self.dec(self.put(torch.cat((code, label), dim=-1)))

class VAEEncoder(torch.nn.Module):
    """Conditional VAE encoder over (28x28 image, integer class label) pairs.

    Args:
        latent: Width of the latent code. Defaults to 32; stored as ``self.latent``.
        classes: Number of label classes used for the one-hot conditioning
            input. Defaults to 10; stored as ``self.classes``.
    """

    def __init__(self, latent: int = 32, classes: int = 10):
        super().__init__()

        self.latent = latent
        self.classes = classes

        # Input = flattened image concatenated with the one-hot label.
        self.enc = torch.nn.Sequential(
            torch.nn.Linear(28*28 + classes, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 64),
            torch.nn.ReLU()
        )

        self.means = torch.nn.Sequential(
            torch.nn.Linear(64, latent)
        )

        self.log_var = torch.nn.Sequential(
            torch.nn.Linear(64, latent)
        )

    def forward(self, image: torch.Tensor, labels: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
        """Return ``(means, log_var)``, each of shape ``(batch, latent)``.

        ``image`` is flattened from dim 1 onward (28*28 features per sample).
        ``labels`` holds integer class indices in ``[0, classes)``.
        """
        # Cast the integer one-hot explicitly so the concatenation is float.
        one_hot = torch.nn.functional.one_hot(labels, self.classes).to(image.dtype)
        data = torch.cat((image.flatten(1), one_hot), dim=-1)
        data = self.enc(data)
        means = self.means(data)
        log_var = self.log_var(data)
        return means, log_var

    def sample(self, means: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
        """Reparameterised sample ``means + exp(log_var / 2) * eps`` with ``eps ~ N(0, I)``."""
        return means + torch.exp(log_var / 2.0) * torch.randn_like(log_var)
    
class VAEDecoder(torch.nn.Module):
    """VAE decoder producing both an image and a label distribution from a latent code.

    Args:
        latent: Width of the input latent code. Defaults to 32. Stored as
            ``self.latent``, which callers read to draw latents of the right
            size, and used for the first layer so the two cannot disagree.
        classes: Number of label classes in the label head. Defaults to 10.
    """

    def __init__(self, latent: int = 32, classes: int = 10):
        super().__init__()

        self.latent = latent

        # Shared trunk feeding both heads.
        self.dec = torch.nn.Sequential(
            torch.nn.Linear(latent, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 128),
            torch.nn.ReLU()
        )

        self.image = torch.nn.Sequential(
            torch.nn.Linear(128, 1 * 28 * 28),
            torch.nn.Sigmoid(),
            torch.nn.Unflatten(1, (1, 28, 28))
        )

        self.label = torch.nn.Sequential(
            torch.nn.Linear(128, classes),
            torch.nn.Softmax(-1)
        )

    def forward(self, latent: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
        """Return ``(image, label)``.

        ``image`` has shape ``(batch, 1, 28, 28)`` with sigmoid values, and
        ``label`` has shape ``(batch, classes)`` with softmax probabilities.
        """
        hidden = self.dec(latent)
        image = self.image(hidden)
        label = self.label(hidden)
        return image, label
    
def load_mnist_dataset(
        batch_size: int, device: torch.device, root: str = './datasets'
) -> "tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.utils.data.DataLoader, torch.utils.data.DataLoader]":
    """Load MNIST onto ``device`` and build shuffled train/validation loaders.

    The torchvision MNIST training split is downloaded (if needed) under
    ``root`` and converted to tensors. Its first 50000 samples feed the
    ``train`` loader and the remaining samples feed the ``val`` loader; both
    use ``batch_size`` and ``shuffle=True``. The MNIST test split is returned
    only as raw tensors.

    Returns:
        ``(images, labels, test_images, test_labels, train, val)``.
        NOTE: ``images``/``labels`` are the whole training split, i.e. the
        training and validation portions together, not just the first 50000.
    """
    to_tensor = torchvision.transforms.ToTensor()

    def as_tensors(split):
        # Stack every (PIL image, int label) pair of a split into device tensors.
        pics = torch.stack([to_tensor(pic) for pic, _ in split], dim=0).to(device)
        targets = torch.tensor([target for _, target in split]).to(device)
        return pics, targets

    def make_loader(xs, ys):
        return torch.utils.data.DataLoader(torch.utils.data.TensorDataset(xs, ys), batch_size=batch_size, shuffle=True)

    # The dataset object is a temporary here, so it is released right after conversion.
    images, labels = as_tensors(torchvision.datasets.MNIST(root, download=True, train=True))
    n_train = 50000
    train = make_loader(images[:n_train], labels[:n_train])
    val = make_loader(images[n_train:], labels[n_train:])
    test_images, test_labels = as_tensors(torchvision.datasets.MNIST(root, download=True, train=False))
    return images, labels, test_images, test_labels, train, val

# Train task (e) from Appendix D.3. weight_l2, weight_l1, and weight_ssim
# are the weights of the L2, L1, and SSIM terms of the image reconstruction loss.
def train_autoencoder(
    train: torch.utils.data.DataLoader, val: torch.utils.data.DataLoader,
    encoder: Encoder, decoder: Decoder, 
    epochs: int = 20, lr: float = 0.001,
    weight_l2: float = 100.0, weight_l1: float = 0.0, weight_ssim: float = 1.0
):
    """Jointly train ``encoder`` and ``decoder`` as an image autoencoder.

    The objective for each batch is
    ``weight_l2 * MSE + weight_l1 * L1 + weight_ssim * (1 - SSIM)`` between the
    images and ``decoder(encoder(images))``. Labels from the loaders are
    ignored. After every epoch, the mean of each term ([0] MSE, [1] L1,
    [2] 1 - SSIM) over the train batches and over the val batches is printed.
    """
    optim = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=lr)

    def loss_terms(im: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor, torch.Tensor]":
        # Reconstruction Loss terms (MSE, L1, 1 - SSIM) for one batch.
        recon = decoder.forward(encoder.forward(im))
        loss_l2 = torch.nn.functional.mse_loss(im, recon)
        loss_l1 = torch.nn.functional.l1_loss(im, recon)
        loss_ssim = 1.0 - piq.ssim(im, recon)
        return loss_l2, loss_l1, loss_ssim

    for ep in range(epochs):
        losses = [0.0]*3
        for im, _ in tqdm.tqdm(train):
            optim.zero_grad()
            loss_l2, loss_l1, loss_ssim = loss_terms(im)
            # Log plain floats: summing the graph-attached loss tensors would
            # keep each batch's autograd nodes referenced until the epoch ends.
            losses[0] += loss_l2.item()
            losses[1] += loss_l1.item()
            losses[2] += loss_ssim.item()
            loss = weight_l2 * loss_l2 + weight_l1 * loss_l1 + weight_ssim * loss_ssim
            loss.backward()
            optim.step()
        losses = [l / len(train) for l in losses]
        print(f"{ep} train: {', '.join(f'[{i}] = {l:g}' for i, l in enumerate(losses))}")

        losses = [0.0]*3
        with torch.no_grad():
            for im, _ in tqdm.tqdm(val):
                for k, term in enumerate(loss_terms(im)):
                    losses[k] += term.item()
        losses = [l / len(val) for l in losses]
        print(f"{ep} val: {', '.join(f'[{i}] = {l:g}' for i, l in enumerate(losses))}")

# For semi-supervised training, select a class-balanced subset of the dataset.
# This is not used for the paper.
def make_supervised_subset(
    images: torch.Tensor, labels: torch.Tensor, per_class: int, classes: int = 10
) -> "tuple[torch.Tensor, torch.Tensor]":
    """Select the first ``per_class`` samples of every class, keeping dataset order.

    Args:
        images: Tensor whose first dimension indexes samples.
        labels: Integer class labels aligned with ``images`` along dim 0.
        per_class: Maximum number of samples to keep for each class.
        classes: Number of classes; every label must lie in ``[0, classes)``.

    Returns:
        ``(subset_images, subset_labels)`` in their original relative order.
        If nothing is selected (e.g. ``per_class == 0``), both are empty along
        dim 0 instead of raising.

    Raises:
        ValueError: if any label is outside ``[0, classes)``.
    """
    if labels.numel() > 0 and (int(labels.min()) < 0 or int(labels.max()) >= classes):
        raise ValueError(f"labels must lie in [0, {classes})")
    # Tensor-level selection avoids a per-sample .item() (a host sync for
    # device tensors). Sorting the chosen indices restores dataset order.
    picked = [
        torch.nonzero(labels == c, as_tuple=True)[0][:per_class]
        for c in range(classes)
    ]
    index = torch.sort(torch.cat(picked)).values
    return images[index], labels[index]

# Train task (a) from Appendix D.3, i.e., training a `get` with supervision.
def train_getter_supervised(
    train: torch.utils.data.DataLoader, val: torch.utils.data.DataLoader,
    getter: Getter, epochs: int = 20, lr: float = 0.001
):
    """Train ``getter`` as an ordinary classifier on the loaders' labels.

    The loss is the negative log-likelihood of the true label under the
    getter's output probabilities (clipped at 1e-10 before the log). Only
    parameters with ``requires_grad`` set are optimized. After each epoch the
    mean loss and accuracy over the train and val batches are printed.
    """
    def evaluate(im: torch.Tensor, lab: torch.Tensor, losses: list, accs: list) -> torch.Tensor:
        # Classifier. Batch stats are stored as Python floats so no autograd
        # graph outlives the batch.
        preds = getter.forward(im)
        loss_supervised = torch.nn.functional.nll_loss(preds.clip(min=1e-10).log(), lab)
        acc_supervised = (preds.argmax(-1) == lab).to(torch.float32).mean().item()
        losses[0] += loss_supervised.item()
        accs[0] += acc_supervised
        return loss_supervised

    optim = torch.optim.Adam([p for p in getter.parameters() if p.requires_grad], lr=lr)

    for ep in range(epochs):
        losses = [0.0]*1
        accs = [0.0]*1
        for im, lab in tqdm.tqdm(train):
            optim.zero_grad()
            loss = evaluate(im, lab, losses, accs)
            loss.backward()
            optim.step()
        losses = [l / len(train) for l in losses]
        accs = [l / len(train) for l in accs]
        print(f"{ep} train: {', '.join(f'[{i}] = {l:.4g}' for i, l in enumerate(losses))}")
        print(f"{ep} train: {', '.join(f'[{i}] = {l*100.0:.1f}%' for i, l in enumerate(accs))}")

        losses = [0.0]*1
        accs = [0.0]*1
        with torch.no_grad():
            for im, lab in tqdm.tqdm(val):
                evaluate(im, lab, losses, accs)
        losses = [l / len(val) for l in losses]
        accs = [l / len(val) for l in accs]
        print(f"{ep} val: {', '.join(f'[{i}] = {l:.4g}' for i, l in enumerate(losses))}")
        print(f"{ep} val: {', '.join(f'[{i}] = {l*100.0:.1f}%' for i, l in enumerate(accs))}")

# Train task (d) from Appendix D.3, learn a new Getter (getter2) from an existing Getter (getter1).
def train_getter_transfer(
    train: torch.utils.data.DataLoader, val: torch.utils.data.DataLoader,
    getter1: Getter, getter2: Getter, epochs: int = 20, lr: float = 0.001
):
    """Train ``getter2`` to reproduce ``getter1``'s argmax predictions.

    The dataset labels are ignored. The targets are ``getter1``'s argmax
    classes, and the loss is the NLL of those targets under ``getter2``'s
    probabilities (clipped at 1e-10). Only ``getter2``'s parameters with
    ``requires_grad`` set are optimized. After each epoch the mean loss and
    agreement accuracy over the train and val batches are printed.
    """
    def evaluate(im: torch.Tensor, lab: torch.Tensor, losses: list, accs: list) -> torch.Tensor:
        # Classifier on getter1's labels. The teacher runs without autograd,
        # since its argmax output is only used as a target.
        with torch.no_grad():
            labs = getter1.forward(im).argmax(-1)
        preds = getter2.forward(im)
        loss_transfer = torch.nn.functional.nll_loss(preds.clip(min=1e-10).log(), labs)
        acc_transfer = (preds.argmax(-1) == labs).to(torch.float32).mean().item()
        # Store Python floats so no autograd graph outlives the batch.
        losses[0] += loss_transfer.item()
        accs[0] += acc_transfer
        return loss_transfer

    optim = torch.optim.Adam([p for p in getter2.parameters() if p.requires_grad], lr=lr)

    for ep in range(epochs):
        losses = [0.0]*1
        accs = [0.0]*1
        for im, lab in tqdm.tqdm(train):
            optim.zero_grad()
            loss = evaluate(im, lab, losses, accs)
            loss.backward()
            optim.step()
        losses = [l / len(train) for l in losses]
        accs = [l / len(train) for l in accs]
        print(f"{ep} train: {', '.join(f'[{i}] = {l:.4g}' for i, l in enumerate(losses))}")
        print(f"{ep} train: {', '.join(f'[{i}] = {l*100.0:.1f}%' for i, l in enumerate(accs))}")

        losses = [0.0]*1
        accs = [0.0]*1
        with torch.no_grad():
            for im, lab in tqdm.tqdm(val):
                evaluate(im, lab, losses, accs)
        losses = [l / len(val) for l in losses]
        accs = [l / len(val) for l in accs]
        print(f"{ep} val: {', '.join(f'[{i}] = {l:.4g}' for i, l in enumerate(losses))}")
        print(f"{ep} val: {', '.join(f'[{i}] = {l*100.0:.1f}%' for i, l in enumerate(accs))}")

# Train the `manipulation` pattern. This solves tasks (b) and (c) from Appendix D.3. 
# To train task (b), set freeze_put=True, for task (c) set freeze_get=True. 
# im_cmp_l2, im_cmp_l1, im_cmp_ssim set the relative weights of L2, L1, and SSIM terms in
# the loss function for comparing images. w_getput, w_putget, w_putput, w_undo control the 
# relative weights of the `manipulation` rules. w_dist1 and w_dist2 control the weight of
# the entropy regularization term. w_supervised can be used to enable semi-supervised learning.
# supervised_batch and supervised_labels can be generated with `make_supervised_subset`.
# w_classifier can be used to enable fully supervised learning.
def train_putter_getter(
    train: torch.utils.data.DataLoader, val: torch.utils.data.DataLoader, device: torch.device,
    supervised_batch: torch.Tensor, supervised_labels: torch.Tensor,
    putter: Putter, getter: Getter, freeze_put: bool = False, freeze_get: bool = False,
    epochs: int = 20, lr: float = 0.001, classes: int = 10,
    im_cmp_l2: float = 1.0, im_cmp_l1: float = 0.0, im_cmp_ssim: float = 0.0,
    w_getput: float = 100.0, w_putget: float = 100.0, w_putput: float = 100.0,
    w_undo: float = 100.0, w_dist1: float = 10.0, w_dist2: float = 10.0, 
    w_supervised: float = 0.0, w_classifier: float = 0.0
):
    """Train ``putter``/``getter`` jointly under the manipulation (lens) rules.

    Per batch, the weighted sum of these terms is minimized. The printed
    ``losses`` indices are shown in brackets:

    * [0] GetPut: ``put(im, get(im))`` should reproduce ``im``.
    * [1] PutGet: ``get(put(im, v))`` should predict the random label ``v``.
    * [2] PutPut: putting ``v1`` then ``v2`` should equal putting ``v2`` directly.
    * [3] Undo: putting ``get(im)`` back should recover ``im``.
    * [4] dist1: ``log(classes)`` minus the entropy of the batch-mean label
      distribution (encourages balanced class usage).
    * [5] dist2: mean per-image label entropy (encourages confident labels).
    * [6] supervised NLL on ``supervised_batch`` (only if ``w_supervised != 0``).
    * [7] NLL of ``lab_true`` under the getter (only if ``w_classifier != 0``).

    Printed ``accs`` are [0] PutGet, [1] supervised and [2] classifier accuracy.
    ``freeze_get``/``freeze_put`` only exclude that module's parameters from
    the optimizer; gradients are still computed for them.
    """
    def image_loss(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # Weighted L2 (+ optional L1 and 1 - SSIM) image comparison.
        loss = im_cmp_l2 * torch.nn.functional.mse_loss(a, b)
        if im_cmp_l1 != 0.0:
            loss = loss + im_cmp_l1 * torch.nn.functional.l1_loss(a, b)
        if im_cmp_ssim != 0.0:
            loss = loss + im_cmp_ssim * (1.0 - piq.ssim(a, b))
        return loss

    def evaluate(im: torch.Tensor, lab_true: torch.Tensor, losses: list, accs: list) -> torch.Tensor:
        # All logged statistics are Python floats, so no autograd graph is
        # kept alive across batches by the running sums.
        # GetPut
        lab = getter.forward(im)
        recon = putter.forward(im, lab)
        loss_getput = image_loss(im, recon)
        losses[0] += loss_getput.item()
        # PutGet (one-hot cast to float to match the image/label dtype)
        fake_labs = torch.randint(0, classes, (im.shape[0],), device=device)
        fake_vec = torch.nn.functional.one_hot(fake_labs, classes).to(im.dtype)
        im_put = putter.forward(im, fake_vec)
        pred_vec = getter.forward(im_put)
        loss_putget = torch.nn.functional.nll_loss(pred_vec.clamp(min=1e-10).log(), fake_labs)
        acc_putget = (pred_vec.argmax(-1) == fake_labs).to(torch.float32).mean().item()
        losses[1] += loss_putget.item()
        accs[0] += acc_putget
        # PutPut
        fake_labs2 = torch.randint(0, classes, (im.shape[0],), device=device)
        fake_vec2 = torch.nn.functional.one_hot(fake_labs2, classes).to(im.dtype)
        im_putput = putter.forward(im_put, fake_vec2)
        im_put2 = putter.forward(im, fake_vec2)
        loss_putput = image_loss(im_putput, im_put2)
        losses[2] += loss_putput.item()
        # Undo
        # NOTE(review): Undo is applied to the double-put image im_putput,
        # not the single put im_put. Under PutPut these should agree, but
        # confirm which one is intended.
        im_undo = putter.forward(im_putput, lab)
        loss_undo = image_loss(im, im_undo)
        losses[3] += loss_undo.item()
        # Dist
        # Maximize entropy across images
        lab_mean = lab.mean(dim=0)
        loss_dist1 = math.log(classes) - (-(lab_mean * lab_mean.clamp(min=1e-10).log()).sum())
        losses[4] += loss_dist1.item()
        # Minimize entropy within images
        loss_dist2 = (-(lab * lab.clamp(min=1e-10).log()).sum(dim=-1)).mean()
        losses[5] += loss_dist2.item()
        # Small classifier
        if w_supervised != 0:
            supervised_preds = getter.forward(supervised_batch)
            loss_supervised = torch.nn.functional.nll_loss(supervised_preds.clip(min=1e-10).log(), supervised_labels)
            acc_supervised = (supervised_preds.argmax(-1) == supervised_labels).to(torch.float32).mean().item()
            losses[6] += loss_supervised.item()
            accs[1] += acc_supervised
        else:
            loss_supervised = 0.0
        # Big classifier
        if w_classifier != 0:
            loss_classifier = torch.nn.functional.nll_loss(lab.clip(min=1e-10).log(), lab_true)
            acc_classifier = (lab.argmax(-1) == lab_true).to(torch.float32).mean().item()
            losses[7] += loss_classifier.item()
            accs[2] += acc_classifier
        else:
            loss_classifier = 0.0
        loss = w_getput * loss_getput + w_putget * loss_putget + w_putput * loss_putput \
            + w_undo * loss_undo + w_dist1 * loss_dist1 + w_dist2 * loss_dist2 + w_supervised * loss_supervised \
            + w_classifier * loss_classifier
        return loss

    params = []
    if not freeze_get:
        params += [p for p in getter.parameters() if p.requires_grad]
    if not freeze_put:
        params += [p for p in putter.parameters() if p.requires_grad]
    optim = torch.optim.Adam(params, lr=lr)

    for ep in range(epochs):
        losses = [0.0]*8
        accs = [0.0]*3
        for im, lab_true in tqdm.tqdm(train):
            optim.zero_grad()
            loss = evaluate(im, lab_true, losses, accs)
            loss.backward()
            optim.step()
        losses = [l / len(train) for l in losses]
        accs = [l / len(train) for l in accs]
        print(f"{ep} train: {', '.join(f'[{i}] = {l:.4g}' for i, l in enumerate(losses))}")
        print(f"{ep} train: {', '.join(f'[{i}] = {l*100.0:.1f}%' for i, l in enumerate(accs))}")

        losses = [0.0]*8
        accs = [0.0]*3
        with torch.no_grad():
            for im, lab_true in tqdm.tqdm(val):
                evaluate(im, lab_true, losses, accs)
        losses = [l / len(val) for l in losses]
        accs = [l / len(val) for l in accs]
        print(f"{ep} val: {', '.join(f'[{i}] = {l:.4g}' for i, l in enumerate(losses))}")
        print(f"{ep} val: {', '.join(f'[{i}] = {l*100.0:.1f}%' for i, l in enumerate(accs))}")

# Train task (g) from Appendix D.3, learn a Getter from a VAE decoder
def train_getter_from_vae(
    train: torch.utils.data.DataLoader, val: torch.utils.data.DataLoader, device: torch.device,
    decoder: VAEDecoder, getter: Getter,
    epochs: int = 20, lr: float = 0.001,
):
    """Train ``getter`` on images and labels sampled from a VAE ``decoder``.

    The loaders only set the number and size of batches. Each batch is
    replaced by ``decoder`` outputs for standard-normal latents, and ``getter``
    is trained to predict the argmax of the decoder's label head. After each
    epoch the mean loss and accuracy over the train and val batches are printed.
    """
    def evaluate(_im: torch.Tensor, _lab: torch.Tensor, losses: list, accs: list) -> torch.Tensor:
        # Classifier on generated data; only the real batch size is used.
        with torch.no_grad():
            im, lab = decoder.forward(torch.randn((_im.shape[0], decoder.latent), device=device))
            target = lab.argmax(-1)
        preds = getter.forward(im)
        loss_classify = torch.nn.functional.nll_loss(preds.clamp(min=1e-10).log(), target)
        acc_classify = (preds.argmax(-1) == target).to(torch.float32).mean().item()
        # Store Python floats so no autograd graph outlives the batch.
        losses[0] += loss_classify.item()
        accs[0] += acc_classify
        return loss_classify

    optim = torch.optim.Adam([p for p in getter.parameters() if p.requires_grad], lr=lr)

    for ep in range(epochs):
        losses = [0.0]*1
        accs = [0.0]*1
        for im, lab in tqdm.tqdm(train):
            optim.zero_grad()
            loss = evaluate(im, lab, losses, accs)
            loss.backward()
            optim.step()
        losses = [l / len(train) for l in losses]
        accs = [l / len(train) for l in accs]
        print(f"{ep} train: {', '.join(f'[{i}] = {l:.4g}' for i, l in enumerate(losses))}")
        print(f"{ep} train: {', '.join(f'[{i}] = {l*100.0:.1f}%' for i, l in enumerate(accs))}")

        losses = [0.0]*1
        accs = [0.0]*1
        with torch.no_grad():
            for im, lab in tqdm.tqdm(val):
                evaluate(im, lab, losses, accs)
        losses = [l / len(val) for l in losses]
        accs = [l / len(val) for l in accs]
        print(f"{ep} val: {', '.join(f'[{i}] = {l:.4g}' for i, l in enumerate(losses))}")
        print(f"{ep} val: {', '.join(f'[{i}] = {l*100.0:.1f}%' for i, l in enumerate(accs))}")

# Train task (f) from Appendix D.3, learn a VAE from a Getter.
# w_lab controls the strength of the label reconstruction loss, w_im_mse and w_im_ssim
# control the L2 and SSIM terms in the image reconstruction loss. w_kl controls the
# KL-divergence regularization term (care must be taken in picking this parameter).
def train_vae_from_getter(
    train: torch.utils.data.DataLoader, val: torch.utils.data.DataLoader,
    encoder: VAEEncoder, decoder: VAEDecoder, getter: Getter,
    epochs: int = 40, lr: float = 0.001,
    w_lab: float = 1.0, w_im_mse: float = 100.0, w_im_ssim: float = 0.0, w_kl: float = 0.5
):
    """Train a conditional VAE whose labels come from ``getter``'s predictions.

    Dataset labels are ignored. Each image is labeled with ``getter``'s argmax
    class and encoded with that label, and the decoder must reconstruct both
    the label and the image. The loss is
    ``w_lab * label NLL + w_im_mse * MSE + w_im_ssim * (1 - SSIM) + w_kl * KL``.
    Printed losses are [0] label NLL, [1] MSE, [2] 1 - SSIM and [3] KL; the
    printed accuracy is label-reconstruction accuracy.
    """
    def evaluate(im: torch.Tensor, _lab: torch.Tensor, losses: list, accs: list) -> torch.Tensor:
        # Reconstruction. Getter labels are targets only, so no graph is needed.
        with torch.no_grad():
            lab = getter.forward(im).argmax(-1)
        means, log_var = encoder.forward(im, lab)
        latent = encoder.sample(means, log_var)
        recon, preds = decoder.forward(latent)
        loss_recon_lab = torch.nn.functional.nll_loss(preds.clip(min=1e-10).log(), lab)
        acc_recon_lab = (preds.argmax(-1) == lab).to(torch.float32).mean().item()
        loss_recon_im_mse = torch.nn.functional.mse_loss(recon, im)
        loss_recon_im_ssim = 1.0 - piq.ssim(recon, im)
        # Store Python floats so no autograd graph outlives the batch.
        losses[0] += loss_recon_lab.item()
        accs[0] += acc_recon_lab
        losses[1] += loss_recon_im_mse.item()
        losses[2] += loss_recon_im_ssim.item()
        # KL div: closed form of KL(N(means, exp(log_var)) || N(0, I)),
        # summed over latent dims and averaged over the batch.
        loss_kl = -0.5 * torch.sum(1 + log_var - means.pow(2) - log_var.exp(), dim=-1).mean(dim=0)
        losses[3] += loss_kl.item()
        loss = w_lab * loss_recon_lab + w_im_mse * loss_recon_im_mse + w_im_ssim * loss_recon_im_ssim + w_kl * loss_kl
        return loss

    optim = torch.optim.Adam(
        [p for p in encoder.parameters() if p.requires_grad]
        + [p for p in decoder.parameters() if p.requires_grad], lr=lr
    )

    for ep in range(epochs):
        losses = [0.0]*4
        accs = [0.0]*1
        for im, lab in tqdm.tqdm(train):
            optim.zero_grad()
            loss = evaluate(im, lab, losses, accs)
            loss.backward()
            optim.step()
        losses = [l / len(train) for l in losses]
        accs = [l / len(train) for l in accs]
        print(f"{ep} train: {', '.join(f'[{i}] = {l:.4g}' for i, l in enumerate(losses))}")
        print(f"{ep} train: {', '.join(f'[{i}] = {l*100.0:.1f}%' for i, l in enumerate(accs))}")

        losses = [0.0]*4
        accs = [0.0]*1
        with torch.no_grad():
            for im, lab in tqdm.tqdm(val):
                evaluate(im, lab, losses, accs)
        losses = [l / len(val) for l in losses]
        accs = [l / len(val) for l in accs]
        print(f"{ep} val: {', '.join(f'[{i}] = {l:.4g}' for i, l in enumerate(losses))}")
        print(f"{ep} val: {', '.join(f'[{i}] = {l*100.0:.1f}%' for i, l in enumerate(accs))}")

# Train a getter in an unsupervised/semi-supervised way. This is not included in the paper, but
# for example can get about 90% accuracy on MNIST with 50 labelled examples per class.
# w_dist1 and w_dist2 control the strength of the entropy regularization, w_supervised controls the 
# strength of semi-supervised classification loss.
def train_getter(
    train: torch.utils.data.DataLoader, val: torch.utils.data.DataLoader,
    supervised_batch: torch.Tensor, supervised_labels: torch.Tensor, getter: Getter,
    epochs: int = 20, lr: float = 0.001, classes: int = 10,
    w_dist1: float = 10.0, w_dist2: float = 10.0, w_supervised: float = 10.0
):
    """Train ``getter`` semi-supervised with entropy regularization.

    Loader labels are ignored. The loss is
    ``w_dist1 * dist1 + w_dist2 * dist2 + w_supervised * supervised NLL``, where
    dist1 is ``log(classes)`` minus the entropy of the batch-mean prediction,
    dist2 is the mean per-image prediction entropy, and the supervised NLL is
    computed on the fixed ``supervised_batch``/``supervised_labels``. Printed
    losses are [0] dist1, [1] dist2, [2] supervised NLL; the printed accuracy is
    on the supervised batch.
    """
    def evaluate(im: torch.Tensor, losses: list, accs: list) -> torch.Tensor:
        lab = getter.forward(im)
        # Dist
        # Maximize entropy across images
        lab_mean = lab.mean(dim=0)
        loss_dist1 = math.log(classes) - (-(lab_mean * lab_mean.clamp(min=1e-10).log()).sum())
        # Store Python floats so no autograd graph outlives the batch.
        losses[0] += loss_dist1.item()
        # Minimize entropy within images
        loss_dist2 = (-(lab * lab.clamp(min=1e-10).log()).sum(dim=-1)).mean()
        losses[1] += loss_dist2.item()
        # Small classifier
        supervised_preds = getter.forward(supervised_batch)
        loss_supervised = torch.nn.functional.nll_loss(supervised_preds.clip(min=1e-10).log(), supervised_labels)
        acc_supervised = (supervised_preds.argmax(-1) == supervised_labels).to(torch.float32).mean().item()
        losses[2] += loss_supervised.item()
        accs[0] += acc_supervised
        loss = w_dist1 * loss_dist1 + w_dist2 * loss_dist2 + w_supervised * loss_supervised
        return loss

    params = [p for p in getter.parameters() if p.requires_grad]
    optim = torch.optim.Adam(params, lr=lr)

    for ep in range(epochs):
        losses = [0.0]*3
        accs = [0.0]*1
        for im, _ in tqdm.tqdm(train):
            optim.zero_grad()
            loss = evaluate(im, losses, accs)
            loss.backward()
            optim.step()
        losses = [l / len(train) for l in losses]
        accs = [l / len(train) for l in accs]
        print(f"{ep} train: {', '.join(f'[{i}] = {l:.4g}' for i, l in enumerate(losses))}")
        print(f"{ep} train: {', '.join(f'[{i}] = {l*100.0:.1f}%' for i, l in enumerate(accs))}")

        losses = [0.0]*3
        accs = [0.0]*1
        with torch.no_grad():
            for im, _ in tqdm.tqdm(val):
                evaluate(im, losses, accs)
        losses = [l / len(val) for l in losses]
        accs = [l / len(val) for l in accs]
        print(f"{ep} val: {', '.join(f'[{i}] = {l:.4g}' for i, l in enumerate(losses))}")
        print(f"{ep} val: {', '.join(f'[{i}] = {l*100.0:.1f}%' for i, l in enumerate(accs))}")

# Plot some illustrative examples from a VAE. Returns two figures, one
# showing some examples with their reconstructions, and the second showing
# generated images with their labels in a grid. 
def plot_vae(
    val: torch.utils.data.DataLoader, device: torch.device,
    encoder: "VAEEncoder", decoder: "VAEDecoder",
    num_recon: int = 12, num_generate: "tuple[int, int]" = (5, 5)
):
    """Plot VAE reconstructions and samples.

    Args:
        val: Loader from which one (images, labels) batch is drawn.
        device: Device for the random latents used for generation.
        encoder: VAE encoder; ``forward(images, labels)`` and ``sample`` are used.
        decoder: VAE decoder returning ``(images, label_probs)``.
        num_recon: Maximum number of reconstruction columns (capped by batch size).
        num_generate: ``(rows, cols)`` of the generated-sample grid.

    Returns:
        ``(fig1, fig2)``. ``fig1`` shows batch images (top row) above their
        reconstructions (bottom row). ``fig2`` is a grid of decoded samples
        titled with the argmax of their label head.
    """
    rows, cols = num_generate
    # Inference only: no autograd graph is needed for plotting.
    with torch.no_grad():
        ex_im, ex_lab = next(iter(val))
        latent = encoder.sample(*encoder.forward(ex_im, ex_lab))
        rec = decoder.forward(latent)[0].cpu().numpy()
        gen_im, gen_lab = decoder.forward(torch.randn((rows * cols, decoder.latent), device=device))
    gen_im = gen_im.cpu().numpy()
    gen_lab = gen_lab.argmax(-1).cpu().numpy()

    num_examples = min(ex_im.shape[0], num_recon)
    # squeeze=False keeps `axes` 2-D even when only one column is drawn.
    fig1, axes = plt.subplots(2, num_examples, sharex=True, sharey=True, figsize=(num_examples, 2), squeeze=False)
    for i in range(num_examples):
        axes[0, i].imshow(ex_im[i, 0, :, :].cpu().numpy(), aspect='equal', vmin=0.0, vmax=1.0, cmap='Greys')
        axes[1, i].imshow(rec[i, 0, :, :], aspect='equal', vmin=0.0, vmax=1.0, cmap='Greys')
    fig1.tight_layout()

    # figsize is (width, height): width follows the column count.
    fig2, axes = plt.subplots(rows, cols, sharex=True, sharey=True, figsize=(cols, rows), squeeze=False)
    for i in range(rows):
        for j in range(cols):
            k = i * cols + j
            axes[i, j].set_title(str(gen_lab[k]))
            axes[i, j].imshow(gen_im[k, 0, :, :], aspect='equal', vmin=0.0, vmax=1.0, cmap='Greys')
    fig2.tight_layout()

    return fig1, fig2

# Plot some random example images along with their reconstructions from an autoencoder.
# Returns a matplotlib figure.
def plot_autoencoder(
    val: torch.utils.data.DataLoader, encoder: Encoder, decoder: Decoder, num_examples: int = 12
):
    """Plot images from one ``val`` batch (top row) above their reconstructions (bottom row).

    At most ``num_examples`` columns are drawn, capped by the batch size.
    Returns the matplotlib figure.
    """
    ex_im, _ = next(iter(val))
    # Inference only: no autograd graph is needed for plotting.
    with torch.no_grad():
        rec = decoder.forward(encoder.forward(ex_im)).cpu().numpy()
    num_examples = min(num_examples, ex_im.shape[0])
    # squeeze=False keeps `axes` 2-D even when only one column is drawn.
    fig, axes = plt.subplots(2, num_examples, sharex=True, sharey=True, figsize=(num_examples, 2), squeeze=False)
    for i in range(num_examples):
        axes[0, i].imshow(ex_im[i, 0, :, :].cpu().numpy(), aspect='equal', vmin=0.0, vmax=1.0, cmap='Greys')
        axes[1, i].imshow(rec[i, 0, :, :], aspect='equal', vmin=0.0, vmax=1.0, cmap='Greys')
    fig.tight_layout()

    return fig

# Plot some representative examples from a putter. Returns two figures, the first
# shows the results of GetPut by comparing an image with its reconstruction, as well as the
# true label and predicted label as the image titles. The second plot shows the result of
# putting all possible classes onto some random example images.
def plot_putter(
    val: torch.utils.data.DataLoader, getter: Getter, putter: Putter, device: torch.device, 
    num_getput: int = 12, num_put: int = 5, classes: int = 10
):
    """Plot GetPut reconstructions and the effect of putting every class.

    Returns:
        ``(fig1, fig2)``. ``fig1`` shows batch images titled with their true
        label (top row) above ``put(im, get(im))`` titled with the predicted
        class (bottom row); at most ``num_getput`` columns. ``fig2`` has one
        row per example (at most ``num_put``, capped by the batch size) from a
        fresh batch: the original image followed by the result of putting each
        of the ``classes`` one-hot labels.
    """
    # Inference only: no autograd graph is needed for plotting.
    with torch.no_grad():
        ex_im, ex_lab = next(iter(val))
        vec = getter.forward(ex_im)
        preds = vec.argmax(-1).cpu().numpy()
        rec = putter.forward(ex_im, vec).cpu().numpy()
    num_getput = min(num_getput, ex_im.shape[0])
    # squeeze=False keeps `axes` 2-D even when only one column is drawn.
    fig1, axes = plt.subplots(2, num_getput, sharex=True, sharey=True, figsize=(num_getput, 3), squeeze=False)
    for i in range(num_getput):
        axes[0, i].imshow(ex_im[i, 0, :, :].cpu().numpy(), aspect='equal', vmin=0.0, vmax=1.0, cmap='Greys')
        axes[0, i].set_title(f"{ex_lab[i]}")
        axes[1, i].imshow(rec[i, 0, :, :], aspect='equal', vmin=0.0, vmax=1.0, cmap='Greys')
        axes[1, i].set_title(f"{preds[i]}")
    fig1.tight_layout()

    ex_im, ex_lab = next(iter(val))
    num_put = min(num_put, ex_im.shape[0])
    fig2, axes = plt.subplots(num_put, classes + 1, sharex=True, sharey=True, figsize=(classes + 1, num_put), squeeze=False)
    # One float one-hot row per class, reused for every example.
    all_labels = torch.nn.functional.one_hot(torch.arange(classes, device=device), classes).to(ex_im.dtype)
    for j in range(num_put):
        axes[j, 0].imshow(ex_im[j, 0, :, :].cpu().numpy(), aspect='equal', vmin=0.0, vmax=1.0, cmap='Greys')
        with torch.no_grad():
            # Put every class onto example j in a single batched call.
            putted = putter.forward(ex_im[[j]].expand(classes, -1, -1, -1), all_labels).cpu().numpy()
        for i in range(classes):
            axes[j, i + 1].imshow(putted[i, 0, :, :], aspect='equal', vmin=0.0, vmax=1.0, cmap='Greys')
    fig2.tight_layout()

    return fig1, fig2

# Plot some representative examples from a getter. Returns accuracy and a figure with three
# subplots, the leftmost shows the probability vector for the first image in the dataset.
# The middle plot shows the distribution of label classes over the whole dataset, and
# the rightmost plot shows the confusion matrix of predicted vs true labels.
def plot_getter(
        images: torch.Tensor, labels: torch.Tensor, getter: Getter, classes: int = 10
):
    """Summarize a getter's predictions on ``images``.

    Returns:
        ``(acc, fig)``. ``acc`` is the fraction of argmax predictions equal
        to ``labels``. ``fig`` has three panels: the probability vector for
        ``images[0]``, a histogram of predicted classes with one bar per class,
        and the confusion matrix normalized over true labels.
    """
    # Single forward pass over all images; no autograd graph is needed.
    with torch.no_grad():
        vec = getter.forward(images)
    preds = vec.argmax(-1).cpu().numpy()
    truth = labels.cpu().numpy()
    fig, (axes1, axes2, axes3) = plt.subplots(1, 3, figsize=(12, 4))
    axes1.plot(vec[0, :].cpu().numpy())
    # Bin edges at k - 0.5 give exactly one bar per class 0..classes-1,
    # instead of the default 10 bins spread over the observed range.
    axes2.hist(preds, bins=np.arange(classes + 1) - 0.5)
    conf = sklearn.metrics.confusion_matrix(truth, preds, labels=np.arange(classes), normalize='true')
    acc = np.mean((truth == preds).astype(np.float32))
    axes3.imshow(conf, vmin=0.0, vmax=1.0)
    axes3.set_title(f"Accuracy = {acc*100.0:.1f}%")
    fig.tight_layout()
    return acc, fig
