from __future__ import print_function
import argparse
import os

import numpy as np
import torch
import torch.nn.functional as F

from sonic_conv import conv_to_sonic
from dataset import get_loaders
from models import MODEL_DICT, get_model
from analysis import get_flops, get_inference_time
from utils import DTYPES

import gc


def train(args, model, device, train_loader, optimizer, epoch):
    """Run one training epoch of ``model`` over ``train_loader``.

    Every batch is moved to ``device``, scored with cross-entropy against its
    labels, back-propagated, and applied with one ``optimizer`` step. A
    progress line is printed whenever the batch index is a multiple of
    ``args.log_interval`` (the only field of ``args`` this function reads).
    """
    loss_fn = torch.nn.CrossEntropyLoss()
    model.train()
    for step, (inputs, labels) in enumerate(train_loader):
        # A full garbage collection on every batch; presumably to keep host
        # memory in check — it is expensive, confirm it is still needed.
        gc.collect()
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = loss_fn(logits, labels)
        batch_loss.backward()
        optimizer.step()

        if step % args.log_interval != 0:
            continue
        seen = step * len(inputs)
        total = len(train_loader.dataset)
        progress = 100. * step / len(train_loader)
        print("Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
            epoch, seen, total, progress, batch_loss.item()))


@torch.no_grad()
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            output = F.log_softmax(output, dim=1)
            test_loss += F.nll_loss(output.float(), target, reduction="sum").item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print("\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    return 100. * correct / len(test_loader.dataset)


def _build_parser():
    """Build the command-line parser for training / analysis runs."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default="MNIST",
                        choices=["MNIST", "FashionMNIST", "CIFAR10", "CIFAR100", "ImageNet"])
    parser.add_argument("--model", type=str, default="LeNet", choices=MODEL_DICT.keys())
    parser.add_argument("--batch-size", type=int, default=128, metavar="N",
                        help="input batch size for training (default: 128)")
    parser.add_argument("--epochs", type=int, default=25, metavar="N", help="number of epochs to train")
    parser.add_argument("--lr", type=float, default=0.1, help="learning rate")
    # NOTE(review): --gamma is not read anywhere in this file (the StepLR
    # scheduler is disabled); kept so existing command lines still parse.
    parser.add_argument("--gamma", type=float, default=0.7, metavar="M",
                        help="Learning rate step gamma (default: 0.7)")
    parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"])
    parser.add_argument("--dtype", type=str, default="fp16", choices=["fp32", "fp16"])
    parser.add_argument("--test-repeats", type=int, default=20)
    parser.add_argument("--seed", type=int, default=0, metavar="S",
                        help="random seed (default: 0)")
    parser.add_argument("--log-interval", type=int, default=50, metavar="N",
                        help="how many batches to wait before logging training status")
    parser.add_argument("--sonic-conv", action="store_true")
    # "adaptive" or an integer resolution (converted to int in main()).
    parser.add_argument("--sonic-resolution", type=str, default="adaptive")
    parser.add_argument("--analyze", action="store_true")
    parser.add_argument('--speedtest-batch-sizes', default=[8192], nargs='+', type=int)
    parser.add_argument('--lr-decay-epochs', default=[10, 15, 20], nargs='+', type=int)
    parser.add_argument("--lr-decay-rate", type=float, default=0.1)
    return parser


def _run_tag(args):
    """Return the file-name stem shared by checkpoint and log files for this run."""
    sonic = f"_sonic-{args.sonic_resolution}" if args.sonic_conv else ""
    return f"{args.dataset}_{args.model}{sonic}"


def main():
    """Parse CLI arguments, build data loaders and model, then analyze or train.

    With ``--analyze`` the result of ``get_flops`` is printed. Otherwise the
    model is trained with SGD for ``--epochs`` epochs. The checkpoint with the
    best validation accuracy is saved under ``checkpoints/``, and the best
    accuracy is appended to a text file under ``logs/``.
    """
    args = _build_parser().parse_args()

    # Default floating-point dtype for tensors created after this point.
    torch.set_default_dtype(DTYPES[args.dtype])

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    train_loader, val_loader = get_loaders(args.dataset, args.batch_size, args.analyze)
    # Infer input shape and class count from the validation set.
    args.input_shape = val_loader.dataset[0][0].shape
    args.num_classes = len(val_loader.dataset.classes)

    if args.sonic_resolution != "adaptive":
        args.sonic_resolution = int(args.sonic_resolution)

    model = get_model(args.model, args.input_shape, args.num_classes, args.device)
    if args.sonic_conv:
        conv_to_sonic(
            model=model,
            input_shape=args.input_shape,
            sonic_resolution=args.sonic_resolution,
            dtype=args.dtype,
            device=args.device,
            cache_name=f"{args.dataset}-{args.model}",
        )

    # Computed after the sonic_resolution conversion so names match the original format.
    run_tag = _run_tag(args)

    if args.analyze:
        flops = get_flops(
            model=model,
            input_shape=args.input_shape,
        )
        print(f"FLOPs: {flops}")
        # NOTE: inference-time benchmarking is currently disabled; it would
        # consume --speedtest-batch-sizes and --test-repeats.
        # inference_times = get_inference_time(
        #     model=model,
        #     input_shape=args.input_shape,
        #     batch_sizes=args.speedtest_batch_sizes,
        #     repeats=args.test_repeats,
        # )
        return

    os.makedirs("checkpoints", exist_ok=True)
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
    best_acc = 0
    for epoch in range(1, args.epochs + 1):
        train(args, model, args.device, train_loader, optimizer, epoch)
        acc = test(model, args.device, val_loader)
        if best_acc < acc:
            best_acc = acc
            torch.save({
                "state_dict": model.state_dict(),
                "accuracy": best_acc,
            }, f"checkpoints/{run_tag}_seed{args.seed}.pt")
        # Manual step decay: scale every group's lr after each listed epoch.
        if epoch in args.lr_decay_epochs:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= args.lr_decay_rate
    print(f"Best accuracy: {best_acc:0.2f}%")

    os.makedirs("logs", exist_ok=True)
    with open(f"logs/{run_tag}_B{args.batch_size}.txt", "a", encoding="utf-8") as f:
        f.write(f"seed: {args.seed}, best_accuracy: {best_acc:0.2f}%\n")

if __name__ == "__main__":
    # Script entry point: run only when executed directly, not on import.
    main()