#!/usr/bin/env python3

import math
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F

import pyfamilywise

import trackexp as tx

import argparse
import sys

# Command-line configuration for this experiment.
parser = argparse.ArgumentParser(description='MNIST Training with Different Loss Functions')
# Restricting to the supported losses makes argparse reject typos up front
# with a usage message, instead of failing later in the script.
parser.add_argument('--loss_func', type=str, default='CE', choices=['CE', 'FW'],
                    help='Loss function: CE (cross entropy) or FW (familywise) '
                         '(default: %(default)s)')
# %(default)s keeps the help text in sync with the actual default value.
parser.add_argument('--epochs', type=int, default=20,
                    help='# of epochs to train (default: %(default)s)')
# NOTE(review): this value is only recorded through tx.metadata() below; it is
# not passed to tx.init() here — confirm whether trackexp picks it up.
parser.add_argument('--trackexp_dir', type=str, default='mnist_gdtuo_example_output',
                    help='Output directory for trackexp')
args = parser.parse_args()

# Start experiment tracking and record the full run configuration.
tx.init()
tx.metadata(vars(args))


class MNIST_FullyConnected(nn.Module):
    """
    Two-layer fully connected network: Linear -> tanh -> Linear -> tanh.

    Args:
        num_inp: size of each (flattened) input vector.
        num_hid: width of the hidden layer.
        num_out: number of output units (one per class).

    The module holds no optimizer state; its parameters are updated by
    whatever optimizer wraps it.
    """

    def __init__(self, num_inp, num_hid, num_out):
        super().__init__()
        # Attribute names are kept as-is: they define the parameter names
        # (e.g. "layer1.weight") seen by named_parameters()/state_dict().
        self.layer1 = nn.Linear(num_inp, num_hid)
        self.layer2 = nn.Linear(num_hid, num_out)

    def initialize(self):
        """Re-draw both layers' weights with Kaiming-uniform init (a=sqrt(5)); biases are left untouched."""
        for layer in (self.layer1, self.layer2):
            nn.init.kaiming_uniform_(layer.weight, a=math.sqrt(5))

    def forward(self, x):
        """Return one tanh-squashed score per output unit for each row of ``x``."""
        hidden = torch.tanh(self.layer1(x))
        # The final tanh bounds every score to [-1, 1]. A softmax over these
        # scores (e.g. inside CrossEntropyLoss) therefore cannot become very
        # confident. Unlike the gdtuo GitHub example, no log_softmax is applied
        # here, because CrossEntropyLoss applies softmax itself.
        return torch.tanh(self.layer2(hidden))

BATCH_SIZE = 256
# Total number of epochs to run (taken from --epochs).
EPOCHS = args.epochs
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# MNIST splits as tensors via ToTensor(); download=True fetches them into
# ./data if they are not already there.
mnist_train = torchvision.datasets.MNIST('./data', train=True, download=True, transform=torchvision.transforms.ToTensor())
mnist_test = torchvision.datasets.MNIST('./data', train=False, download=True, transform=torchvision.transforms.ToTensor())
dl_train = torch.utils.data.DataLoader(mnist_train, batch_size=BATCH_SIZE, shuffle=True)
# batch_size=10000 matches the MNIST test split size, so evaluation runs as one batch.
dl_test = torch.utils.data.DataLoader(mnist_test, batch_size=10000, shuffle=False)

# 784 -> 128 -> 10 network; images are flattened to 28*28 vectors in the training loop.
model = MNIST_FullyConnected(28 * 28, 128, 10).to(DEVICE)

from gradient_descent_the_ultimate_optimizer import gdtuo

# gdtuo optimizer stack (per the gdtuo README): Adam updates the model weights,
# and the SGD passed as its `optimizer` (step size 1e-5) updates Adam's own
# hyperparameters (alpha/beta1/beta2, which the training loop logs).
optim = gdtuo.Adam(optimizer=gdtuo.SGD(1e-5))

mw = gdtuo.ModuleWrapper(model, optimizer=optim)
# NOTE(review): presumably initializes the optimizer stack's state. It does not
# visibly call MNIST_FullyConnected.initialize(), so the layers keep nn.Linear's
# default init — confirm against gdtuo.ModuleWrapper.
mw.initialize()

# Choose the training loss from --loss_func.
if args.loss_func == 'CE':
    print("Using cross entropy loss")
    criterion = nn.CrossEntropyLoss()
elif args.loss_func == 'FW':
    print("Using familywise loss")
    criterion = pyfamilywise.FWLoss(num_classes=10, device=DEVICE)
else:
    print("Loss not supported!")
    # Exit explicitly rather than `assert(False)`: asserts are stripped under
    # `python -O`, which would let the script continue with `criterion` undefined.
    sys.exit(1)

# Hyperparameters of the gdtuo Adam optimizer that are logged at every batch.
params_to_track = ['alpha', 'beta1', 'beta2']

for i in range(1, EPOCHS + 1):
    # ---- training ----
    tx.start_timer("training", i)
    running_loss = 0.0
    for features_, labels_ in dl_train:
        mw.begin()  # call this before each step, enables gradient tracking on desired params
        # Flatten each 28x28 image into a 784-vector for the fully connected model.
        features, labels = torch.reshape(features_, (-1, 28 * 28)).to(DEVICE), labels_.to(DEVICE)
        pred = mw.forward(features)
        # `criterion` is whichever loss was selected via --loss_func (CE or FW).
        loss = criterion(pred, labels)

        # Record the hyperparameter values in effect at the start of this step.
        for pname in params_to_track:
            tx.log("training", pname, i, mw.optimizer.parameters[pname].detach().item())

        mw.zero_grad()
        loss.backward(create_graph=True)  # important! use create_graph=True
        mw.step()
        # Weight by batch size so the epoch average is per-sample even when the
        # last batch is smaller than BATCH_SIZE.
        running_loss += loss.item() * features_.size(0)
    train_loss = running_loss / len(dl_train.dataset)
    print("EPOCH: {}, TRAIN LOSS: {}".format(i, train_loss))

    tx.stop_timer("training", i)
    tx.log("training", "train_loss", i, train_loss)

    # ---- evaluation ----
    # NOTE(review): evaluates `model` directly. This assumes mw.step() writes
    # the updated parameters back into `model` — confirm with gdtuo.ModuleWrapper.
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for features_, labels_ in dl_test:
            features, labels = torch.reshape(features_, (-1, 28 * 28)).to(DEVICE), labels_.to(DEVICE)
            outputs = model(features)
            # Predicted class = index of the highest score in each row.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    test_acc = 100 * correct / total
    print("EPOCH: {}, TEST ACC: {:.2f}%".format(i, test_acc))
    tx.log("training", "test_acc", i, test_acc)
    model.train()
