import os

import matplotlib.pyplot as plt
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize
from tqdm import tqdm

import mnist_config


def _plot_traces(nll_trace, fkl_trace, title, xlabel, path):
    """Save a per-batch loss plot: NLL on the left y-axis, FKL (if recorded) on a twin right y-axis.

    The figure is closed after saving so that calling this once per seed does
    not accumulate open matplotlib figures.
    """
    fig, ax_nll = plt.subplots()
    ax_nll.plot(nll_trace, color='tab:blue')
    ax_nll.set_ylabel("NLL loss", color='tab:blue')
    # Label through the primary axes explicitly: twinx() hides the twin's x-axis
    # and makes the twin the "current" axes, so pyplot-level xlabel()/ylabel()
    # called afterwards would land on the twin (the x-label then never renders).
    ax_nll.set_xlabel(xlabel)
    ax_nll.set_title(title)
    if fkl_trace:
        ax_fkl = ax_nll.twinx()
        ax_fkl.plot(fkl_trace, color='tab:orange')
        ax_fkl.set_ylabel("FKL loss", color='tab:orange')
    fig.savefig(path)
    plt.close(fig)


def main():
    """Train ``n_seeds`` independently seeded MNIST models; save their weights and loss plots.

    The model constructor, its keyword arguments and the training settings
    (``n_batch``, ``lr``, ``epochs``) come from the ``mnist_config`` entry named
    by ``model_type``. For each seed the trained ``state_dict`` is written to
    ``./state_dicts/`` and a per-batch loss plot to ``./plots/``.
    """
    model_type = 'mlp_map'
    print("Model Type: {}".format(model_type))

    transform = Compose([ToTensor(), Normalize(mean=.5, std=.5)])
    train_data = MNIST("./data/", train=True, transform=transform, download=True)

    n_seeds = 10
    n_classes = 10
    n_channels = 1
    n_height = 28
    n_width = 28
    n_features = n_channels * n_height * n_width
    # NOTE(review): presumably the MNIST training-set size; confirm against the model constructor.
    max_precision = 60000
    args = (n_classes, n_features, max_precision)
    config = getattr(mnist_config, model_type)
    kwargs = config['kwargs']

    # Create output directories up front so a finished training run is not lost
    # to a FileNotFoundError raised by torch.save / savefig at the very end.
    os.makedirs("./state_dicts", exist_ok=True)
    os.makedirs("./plots", exist_ok=True)

    initial_seed = 12345
    for seed in range(n_seeds):
        print()
        print("Seed: {}/{}".format(seed + 1, n_seeds))
        torch.manual_seed(initial_seed + seed)
        model = config['model'](*args, **kwargs)

        train_data_loader = DataLoader(train_data, batch_size=config['n_batch'], shuffle=True)
        optimizer = Adam(model.parameters(), lr=config['lr'])
        # Whether the model defines an FKL regulariser does not change between batches.
        has_fkl = callable(getattr(model, 'fkl_loss', None))
        nll_trace = []
        fkl_trace = []
        model.train()
        for i in range(config['epochs']):
            print()
            print("Epoch: {}/{}".format(i + 1, config['epochs']))
            for x, y in tqdm(train_data_loader):
                # The last batch is smaller than n_batch when n_batch does not divide the dataset size.
                current_batch_size = x.shape[0]
                optimizer.zero_grad()
                raw_logits = model(x)
                nll = model.nll_loss(raw_logits, y)
                nll_trace.append(nll.item())
                loss = nll
                if has_fkl:
                    # NOTE(review): despite the names, these are a transpose (reflection about
                    # the main diagonal) and a vertical flip, not 90/180-degree rotations
                    # (torch.rot90 would be needed for those). Confirm which transforms
                    # fkl_loss is meant to see.
                    rot90 = x.transpose(-2, -1).contiguous()
                    rot180 = x.flip(-2).contiguous()
                    raw_logits_s = model(torch.cat([rot90, rot180], dim=0))
                    # Divide by the actual batch size rather than config['n_batch'] so the
                    # short final batch is normalised consistently (assumes fkl_loss sums
                    # over samples — confirm).
                    fkl = model.fkl_loss(raw_logits_s) / current_batch_size
                    fkl_trace.append(fkl.item())
                    # Out-of-place add: `loss` aliases `nll`, and `+=` would mutate a
                    # tensor that belongs to the autograd graph.
                    loss = loss + fkl
                loss.backward()
                optimizer.step()
        torch.save(model.state_dict(), "./state_dicts/MNIST_{}_{}.pt".format(model_type, seed))
        _plot_traces(
            nll_trace,
            fkl_trace,
            title="MNIST, {}".format(model_type),
            xlabel="Batch Iterations, n_batch = {}, epochs = {}".format(config['n_batch'], config['epochs']),
            path="./plots/MNIST_{}_{}.png".format(model_type, seed),
        )


if __name__ == '__main__':
    main()
