import argparse
import gc
import importlib
import os

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import torch.nn.functional as F

# mainly from https://github.com/pytorch/examples/blob/main/mnist/main.py


def train(args, model, device, train_loader, optimizer, criterion, epoch):
    """Run one training epoch of *model* over *train_loader*.

    For every batch the model is called as ``outputs, n_steps = model(inputs)``
    and the loss is ``criterion(outputs / n_steps, one_hot(labels))`` with a
    float one-hot target sized to ``outputs.shape[1]`` classes.

    Args:
        args: parsed CLI namespace; reads ``log_interval`` (print progress every
            that many batches) and ``dry_run`` (stop after the first progress line).
        model: module returning a ``(outputs, n_steps)`` pair.
        device: device each batch is moved to.
        train_loader: loader yielding ``(inputs, labels)`` batches.
        optimizer: optimizer stepped once per batch.
        criterion: loss function applied to the scaled outputs.
        epoch: epoch number, used only in the progress line.
    """
    model.train()
    for step, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs, n_steps = model(inputs)
        targets = F.one_hot(labels, outputs.shape[1]).float()
        loss = criterion(outputs / n_steps, targets)
        loss.backward()
        optimizer.step()
        # Explicit collection after every batch, kept from the original loop.
        gc.collect()

        # Only every ``log_interval``-th batch reports progress.
        if step % args.log_interval != 0:
            continue
        percent_done = 100.0 * step / len(train_loader)
        print(
            f"Train Epoch: {epoch} [{step * len(inputs)}/{len(train_loader.dataset)} "
            f"({percent_done:.0f}%)]\tLoss: {loss.item():.6f}"
        )
        if args.dry_run:
            break


def validate(args, model, device, val_loader, criterion):
    """Evaluate *model* on *val_loader*, print loss/accuracy, and return the loss.

    Each batch is scored exactly like ``train``: ``output, num_step = model(data)``
    and ``criterion(output / num_step, one_hot(target))``. Accuracy compares the
    argmax of ``output`` with ``target``.

    Args:
        args: parsed CLI namespace; only ``args.dry_run`` is read (evaluate a
            single batch).
        model: module returning an ``(output, num_step)`` pair.
        device: device each batch is moved to.
        val_loader: iterable of ``(data, target)`` batches.
        criterion: loss function.

    Returns:
        The sum over evaluated batches of the per-batch ``criterion`` value
        (``0`` if no batch was evaluated).
    """
    model.eval()
    val_loss = 0
    correct = 0
    # Count samples actually evaluated so accuracy stays correct in dry-run mode,
    # where only the first batch is seen.
    seen = 0
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            output, num_step = model(data)
            val_loss += criterion(
                output / num_step, F.one_hot(target, output.shape[1]).float()
            ).sum().item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            seen += len(target)
            gc.collect()
            if args.dry_run:
                break
    # Report once per call rather than once per batch.
    print(f"val_loss: {val_loss}")
    acc = correct / seen if seen else 0.0
    print(f"val_acc: {acc}")

    return val_loss


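# NOTE: disabled and out of date. It calls ``output = model(data)``, while the
# models used by ``train``/``validate`` return ``(output, num_step)``; update it
# the same way before re-enabling.
# def test(model, device, test_loader):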
#     model.eval()
#     test_loss = 0
#     correct = 0
#     with torch.no_grad():
#         for data, target in test_loader:
#             data, target = data.to(device), target.to(device)
#             output = model(data)
#             test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
#             pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
#             correct += pred.eq(target.view_as(pred)).sum().item()

#     test_loss /= len(test_loader.dataset)

#     print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
#         test_loss, correct, len(test_loader.dataset),
#         100. * correct / len(test_loader.dataset)))


def _parse_args():
    """Build the command-line parser for this script and return the parsed options."""
    # Training settings
    parser = argparse.ArgumentParser(description="PyTorch Parser")
    parser.add_argument(
        "--model-file",
        type=str,
        default="model_CIFAR10",
        metavar="M",
        choices=["model_CIFAR10", "model_MNIST", "model_MLP"],
        help="name of the model file",
    )
    parser.add_argument(
        "--modelname",
        type=str,
        default="FINAL_ourmodel_PLIF_PTRACE",
        metavar="M",
        help="model to use ",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="CIFAR10",
        metavar="D",
        choices=["MNIST", "FMNIST", "CIFAR10"],
        help="dataset to use ",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=64,
        metavar="N",
        help="input batch size for training (default: 64)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=300,
        metavar="N",
        help="number of epochs to train (default: 300)",
    )
    parser.add_argument(
        "--optimizer",
        type=str,
        default="Adam",
        metavar="O",
        choices=["SGD", "Adam"],
        help="optimizer to use (default: Adam)",
    )
    parser.add_argument(
        "--lr", type=float, default=0.0001, metavar="LR", help="learning rate (default: 1e-4)"
    )
    parser.add_argument(
        "--no-cuda", action="store_true", default=False, help="disables CUDA training"
    )
    parser.add_argument(
        "--dry-run", action="store_true", default=False, help="quickly check a single pass"
    )
    parser.add_argument(
        "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
    )
    parser.add_argument(
        "--log-interval",
        type=int,
        default=10,
        metavar="N",
        help="how many batches to wait before logging training status",
    )
    parser.add_argument(
        "--save-model", action="store_true", default=False, help="For Saving the current Model"
    )
    return parser.parse_args()


def _load_dataset(name, transform):
    """Return ``(dataset, (n_train, n_val))`` for the training split of dataset *name*.

    ``random_split`` requires the two sizes to sum to ``len(dataset)``.

    Raises:
        SystemExit: if *name* is not a supported dataset.
    """
    if name == "MNIST":
        dataset = datasets.MNIST("../data", train=True, download=True, transform=transform)
        return dataset, (55000, 5000)
    if name == "FMNIST":
        dataset = datasets.FashionMNIST("../data", train=True, download=True, transform=transform)
        return dataset, (55000, 5000)
    if name == "CIFAR10":
        dataset = datasets.CIFAR10("../data", train=True, download=True, transform=transform)
        return dataset, (45000, 5000)
    raise SystemExit("Dataset not found")


def main():
    """Parse CLI options, then train and validate the selected model.

    The model class ``args.modelname`` is loaded from the module
    ``args.model_file``. It is trained on a seeded train/validation split of
    the chosen dataset's training set. A ``ReduceLROnPlateau`` scheduler is
    stepped on the validation loss after each epoch. With ``--save-model`` the
    final weights are written to ``out.pt``.

    Errors are reported by raising ``SystemExit`` with a message, which prints
    it to stderr and exits with status 1. The ``exit()`` builtin is an
    interactive-shell helper that exits with status 0.
    """
    args = _parse_args()

    torch.manual_seed(args.seed)
    # The model is imported dynamically from "<model-file>.py"; fail fast if it is missing.
    # NOTE(review): this checks the current working directory, while import_module
    # searches sys.path; the two agree when the script is run from its own directory.
    model_file = args.model_file + ".py"
    if not os.path.isfile(model_file):
        raise SystemExit(f"Model file {model_file} does not exist")

    Net = getattr(importlib.import_module(args.model_file), args.modelname)
    criterion = torch.nn.MSELoss()

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Shuffle the training set on every device. Validation only aggregates
    # loss/accuracy, so its order does not matter and it is not shuffled.
    train_kwargs = {"batch_size": args.batch_size, "shuffle": True}
    val_kwargs = {"batch_size": args.batch_size}
    if use_cuda:
        cuda_kwargs = {"num_workers": 1, "pin_memory": True}
        train_kwargs.update(cuda_kwargs)
        val_kwargs.update(cuda_kwargs)

    transform = transforms.Compose([transforms.ToTensor()])
    dataset, train_val_split = _load_dataset(args.dataset, transform)

    # Seeded generator so the train/validation partition is reproducible.
    train_set, val_set = random_split(
        dataset,
        train_val_split,
        generator=torch.Generator().manual_seed(args.seed),
    )
    train_loader = DataLoader(train_set, **train_kwargs)
    val_loader = DataLoader(val_set, **val_kwargs)

    model = Net().to(device)
    if args.optimizer == "Adam":
        optimizer = optim.Adam(model.parameters(), lr=args.lr)
    elif args.optimizer == "SGD":
        optimizer = optim.SGD(model.parameters(), lr=args.lr)
    else:
        raise SystemExit("Optimizer not found")
    # NOTE(review): recent PyTorch releases deprecate the ``verbose`` argument of
    # LR schedulers; confirm against the installed version.
    scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.1, patience=10, verbose=True)

    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, criterion, epoch)
        val_loss = validate(args, model, device, val_loader, criterion)
        scheduler.step(val_loss)

    if args.save_model:
        torch.save(model.state_dict(), "out.pt")


# Run the training pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
