"""Train CIFAR10 with PyTorch."""

from __future__ import print_function

import torch.optim as optim
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms

import os
import argparse
import time
from models import *
from adabound import AdaBound
from torch.optim import SGD
from optimizers import *
from torchvision.transforms.functional import rotate
from copy import deepcopy
import random
import numpy as np
import torch
from common import create_optimizer, get_grad_norm, get_ckpt_name

def seed_everything(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Value handed to every random number generator (default 42).
    """
    # Host-side generators.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # CUDA generators are only seeded explicitly when a GPU is present.
    if torch.cuda.is_available():
        for seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            seed_fn(seed)

    # Disable cuDNN autotuning and request deterministic kernels.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.benchmark = False
    cudnn_backend.deterministic = True

# Seed every RNG once at import time, using the default seed (42).
seed_everything()

def rotate_batch(batch, angle):
    """Rotate every image in ``batch`` by ``angle`` and stack the results.

    Args:
        batch: Iterable of images accepted by
            ``torchvision.transforms.functional.rotate``. Iterating a batched
            image tensor yields one image per step.
        angle: Rotation angle, forwarded unchanged to ``rotate``.

    Returns:
        A new tensor holding the rotated images stacked along dimension 0.
        An empty ``batch`` is returned as-is, since there is nothing to
        rotate and ``torch.stack`` rejects an empty list.
    """
    rotated = [rotate(img, angle) for img in batch]
    if not rotated:
        # torch.stack([]) raises a RuntimeError; hand the empty input back.
        return batch
    return torch.stack(rotated)


def get_parser():
    """Build the command-line parser for the CIFAR10 training script.

    Returns:
        argparse.ArgumentParser: Parser exposing the model and optimizer
        choice, optimizer hyper-parameters, batch size, checkpoint flags and
        the label-noise level.
    """

    def noise_fraction(text):
        """Parse ``--noise_level`` and require a value inside [0, 1]."""
        try:
            value = float(text)
        except ValueError:
            raise argparse.ArgumentTypeError(
                "noise level must be a number, got {!r}".format(text)
            ) from None
        # The chained comparison is False for NaN as well, so NaN is rejected.
        if not 0.0 <= value <= 1.0:
            raise argparse.ArgumentTypeError(
                "noise level must be within [0, 1], got {!r}".format(text)
            )
        return value

    parser = argparse.ArgumentParser(description="PyTorch CIFAR10 Training")
    parser.add_argument(
        "--total_epoch", default=100, type=int, help="Total number of training epochs"
    )
    parser.add_argument(
        "--model",
        default="vgg",
        type=str,
        help="model",
        choices=["resnet", "densenet", "vgg"],
    )
    parser.add_argument(
        "--optim",
        default="sgd",
        type=str,
        help="optimizer",
        choices=[
            "sgd",
            "adam",
            "adamw",
            "adabelief",
            "yogi",
            "msvag",
            "radam",
            "fromage",
            "adabound",
            "cadam",
            "cadamw",
            "cadamw-all",
            "amsgrad",
            "camsgrad",
        ],
    )
    parser.add_argument("--run", default=0, type=int, help="number of runs")
    parser.add_argument("--lr", default=0.1, type=float, help="learning rate")
    # NOTE(review): this help text duplicates --lr; presumably the option is a
    # learning-rate decay factor -- confirm against its consumer before rewording.
    parser.add_argument("--lr-gamma", default=0.1, type=float, help="learning rate")
    parser.add_argument(
        "--final_lr", default=0.1, type=float, help="final learning rate of AdaBound"
    )
    parser.add_argument(
        "--gamma", default=1e-3, type=float, help="convergence speed term of AdaBound"
    )

    parser.add_argument("--eps", default=1e-8, type=float, help="eps for var adam")

    parser.add_argument("--momentum", default=0.9, type=float, help="momentum term")
    parser.add_argument(
        "--beta1", default=0.9, type=float, help="Adam coefficients beta_1"
    )
    parser.add_argument(
        "--beta2", default=0.999, type=float, help="Adam coefficients beta_2"
    )
    parser.add_argument(
        "--resume", "-r", action="store_true", help="resume from checkpoint"
    )
    parser.add_argument("--batchsize", type=int, default=128, help="batch size")
    parser.add_argument(
        "--weight_decay", default=5e-4, type=float, help="weight decay for optimizers"
    )
    parser.add_argument(
        "--reset",
        action="store_true",
        help="whether reset optimizer at learning rate decay",
    )
    # The default stays the integer 0 (argparse does not run ``type`` on
    # non-string defaults), so names derived from it keep their old spelling.
    # Values outside [0, 1] are rejected here instead of failing later, once
    # training has already started.
    parser.add_argument(
        "--noise_level",
        default=0,
        type=noise_fraction,
        help="noise level",
    )
    return parser


def build_dataset(args):
    """Create the CIFAR10 training and test data loaders.

    The training pipeline applies a padded random crop and a random
    horizontal flip before tensor conversion and normalisation; the test
    pipeline only converts and normalises. Both datasets are stored under
    ``./data`` and requested with ``download=True``.

    Args:
        args: Parsed command-line arguments; only ``args.batchsize`` is read.

    Returns:
        tuple: ``(train_loader, test_loader)``. The training loader shuffles,
        the test loader does not.
    """
    print("==> Preparing data..")

    # Per-channel mean / std passed to Normalize in both pipelines.
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)

    def tensor_steps():
        # Fresh transform objects on every call so the pipelines share nothing.
        return [
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ]

    augmentation = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    train_pipeline = transforms.Compose(augmentation + tensor_steps())
    eval_pipeline = transforms.Compose(tensor_steps())

    def make_loader(is_train, pipeline):
        # The training split is shuffled; the evaluation split keeps its order.
        dataset = torchvision.datasets.CIFAR10(
            root="./data", train=is_train, download=True, transform=pipeline
        )
        return torch.utils.data.DataLoader(
            dataset, batch_size=args.batchsize, shuffle=is_train, num_workers=2
        )

    train_loader = make_loader(True, train_pipeline)
    test_loader = make_loader(False, eval_pipeline)

    return train_loader, test_loader



def load_checkpoint(ckpt_name, map_location=None):
    """Load a saved checkpoint from the ``checkpoint`` directory.

    Args:
        ckpt_name: File name of the checkpoint inside ``checkpoint/``.
        map_location: Optional device remapping forwarded to ``torch.load``
            (for example ``"cpu"`` to open a checkpoint written on a GPU
            machine). ``None`` keeps ``torch.load``'s default behaviour.

    Returns:
        The object stored in the checkpoint file.

    Raises:
        FileNotFoundError: If the ``checkpoint`` directory or the requested
            file does not exist.
    """
    print("==> Resuming from checkpoint..")
    path = os.path.join("checkpoint", ckpt_name)
    # Explicit raises instead of ``assert`` so the checks survive ``python -O``.
    if not os.path.isdir("checkpoint"):
        raise FileNotFoundError("Error: no checkpoint directory found!")
    if not os.path.exists(path):
        raise FileNotFoundError("Error: checkpoint {} not found".format(ckpt_name))
    # NOTE: torch.load unpickles the file; only open checkpoints you trust.
    return torch.load(path, map_location=map_location)


def build_model(args, device, ckpt=None):
    """Instantiate the network selected by ``args.model`` on ``device``.

    Args:
        args: Parsed command-line arguments; ``args.model`` must be one of
            ``"resnet"``, ``"densenet"`` or ``"vgg"``.
        device: Target device. When it equals the string ``"cuda"`` the
            network is additionally wrapped in ``torch.nn.DataParallel``.
        ckpt: Optional checkpoint mapping. When truthy, ``ckpt["net"]`` is
            loaded into the network *after* any ``DataParallel`` wrapping, so
            its keys must match that wrapping.

    Returns:
        The constructed (and possibly wrapped / restored) network.
    """
    print(f"==> Building model {args.model}..")
    constructors = {
        "resnet": ResNet34,
        "densenet": DenseNet121,
        "vgg": vgg11,
    }
    net = constructors[args.model]().to(device)

    if device == "cuda":
        # cudnn.benchmark is intentionally NOT switched on here:
        # seed_everything() sets it to False (with cudnn.deterministic = True)
        # for reproducible runs, and turning it back on would let cuDNN pick
        # kernels by autotuning again, undoing that setting.
        net = torch.nn.DataParallel(net)

    if ckpt:
        net.load_state_dict(ckpt["net"])

    return net




def train(net, epoch, device, data_loader, optimizer, criterion, noise_level=0):
    """Run one training epoch and collect per-batch statistics.

    ``int(len(data_loader) * noise_level)`` batches are picked at random
    (without replacement); their labels are replaced by random integers in
    ``[0, 10)`` before the forward pass.

    Args:
        net: Model to train; switched to training mode.
        epoch: Epoch index, used only for logging.
        device: Device the inputs and targets are moved to.
        data_loader: Sized iterable yielding ``(inputs, targets)`` batches.
        optimizer: Optimizer exposing a ``statistic`` attribute, which is
            deep-copied after every step.
        criterion: Loss function called as ``criterion(outputs, targets)``.
        noise_level: Fraction of batches whose labels are randomised.

    Returns:
        tuple: ``(accuracy, statistics)`` where ``accuracy`` is the epoch
        accuracy in percent (measured against the possibly perturbed labels)
        and ``statistics`` holds one dict per batch: the copied optimizer
        statistic plus ``grad_norm``, ``loss``, ``accuracy`` and
        ``is_perturbed``.
    """
    print("\nEpoch: %d" % epoch, "noise level: ", noise_level)
    net.train()
    correct = 0
    total = 0
    statistics = []

    num_batches = len(data_loader)
    # A set gives O(1) membership tests inside the loop; .tolist() converts
    # the NumPy integers to plain ints so they compare equal to batch_idx.
    perturbed_batches = set(
        np.random.choice(
            num_batches, int(num_batches * noise_level), replace=False
        ).tolist()
    )

    for batch_idx, (inputs, targets) in enumerate(data_loader):
        is_perturbed = batch_idx in perturbed_batches
        if is_perturbed:
            # Replace the labels of this batch with random classes.
            targets = torch.randint(0, 10, targets.size())
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        _, predicted = outputs.max(1)
        batch_size = targets.size(0)
        batch_correct = predicted.eq(targets).sum().item()
        total += batch_size
        correct += batch_correct

        # Snapshot the optimizer statistic so later steps cannot mutate it.
        step_stats = deepcopy(optimizer.statistic)
        # Gradients from this step are still in place: zero_grad() only runs
        # at the start of the next iteration.
        step_stats["grad_norm"] = get_grad_norm(net)
        step_stats["loss"] = loss.item()
        step_stats["accuracy"] = 100.0 * batch_correct / batch_size
        step_stats["is_perturbed"] = is_perturbed
        statistics.append(step_stats)

    accuracy = 100.0 * correct / total
    print("train acc %.3f" % accuracy)

    return accuracy, statistics


def test(net, device, data_loader, criterion, rotate_angle=0):
    """Evaluate ``net`` on ``data_loader`` and return its accuracy in percent.

    Args:
        net: Model to evaluate; switched to evaluation mode.
        device: Device the inputs and targets are moved to.
        data_loader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Unused. The loss was previously computed but never
            reported; the parameter is kept so existing callers keep working.
        rotate_angle: When non-zero, every input batch is passed through
            ``rotate_batch`` with this angle before the forward pass.

    Returns:
        float: ``100 * correct / total`` over all evaluated samples.
    """
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in data_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            if rotate_angle != 0:
                inputs = rotate_batch(inputs, rotate_angle)
            outputs = net(inputs)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    accuracy = 100.0 * correct / total
    print(" test acc %.3f" % accuracy)

    return accuracy



def main():
    """Parse the CLI arguments, train the model and persist the results.

    After every epoch the accumulated train/test accuracies and per-batch
    statistics are written to ``curve_no_rotate_<noise_level>/<ckpt_name>``.
    Whenever the test accuracy improves, the model weights are saved to
    ``checkpoint/<ckpt_name>``.
    """
    parser = get_parser()
    args = parser.parse_args()

    if args.resume:
        # The flag is accepted by the parser but no checkpoint is loaded
        # below; say so instead of silently restarting (and overwriting
        # earlier results).
        print("Warning: --resume is not supported by this script; training from scratch.")

    train_loader, test_loader = build_dataset(args)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    ckpt_name = get_ckpt_name(
        model=args.model,
        optimizer=args.optim,
        lr=args.lr,
        final_lr=args.final_lr,
        momentum=args.momentum,
        beta1=args.beta1,
        beta2=args.beta2,
        gamma=args.gamma,
        eps=args.eps,
        reset=args.reset,
        run=args.run,
        weight_decay=args.weight_decay,
    )
    noise_level = args.noise_level
    curve_dir = f"curve_no_rotate_{noise_level}"
    print("ckpt_name: ", ckpt_name)

    # Training always starts from scratch: no checkpoint is restored.
    ckpt = None
    best_acc = 0
    start_epoch = -1
    train_accuracies = []
    test_accuracies = []
    statistics = []

    net = build_model(args, device, ckpt=ckpt)
    # Use torch.nn explicitly rather than relying on a star import to
    # provide the name ``nn``.
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = create_optimizer(args, net.parameters())

    for epoch in range(start_epoch + 1, args.total_epoch):
        start = time.time()
        train_acc, epoch_statistics = train(
            net, epoch, device, train_loader, optimizer, criterion, noise_level
        )
        test_acc = test(net, device, test_loader, criterion)
        elapsed = time.time() - start
        print("Time: {}".format(elapsed))
        remaining_epochs = args.total_epoch - epoch - 1
        print("Estimated time Left: {} min".format(elapsed * remaining_epochs / 60))

        # Save checkpoint.
        if test_acc > best_acc:
            print("Saving..")
            state = {
                "net": net.state_dict(),
                "acc": test_acc,
                "epoch": epoch,
            }
            os.makedirs("checkpoint", exist_ok=True)
            torch.save(state, os.path.join("checkpoint", ckpt_name))
            best_acc = test_acc

        train_accuracies.append(train_acc)
        test_accuracies.append(test_acc)
        # train() returns a fresh list each epoch, so no extra copy is needed.
        statistics.append(epoch_statistics)

        # Rewrite the full curve file every epoch so partial runs are kept.
        os.makedirs(curve_dir, exist_ok=True)
        torch.save(
            {
                "train_acc": train_accuracies,
                "test_acc": test_accuracies,
                "statistics": statistics,
            },
            os.path.join(curve_dir, ckpt_name),
        )


# Start training only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()
