# Modified from https://github.com/kuangliu/pytorch-cifar

import argparse
import os
import time
import random

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

from imbalance_cifar import IMBALANCECIFAR100


# Command-line interface of the training script.
parser = argparse.ArgumentParser(description="PyTorch CIFAR100 Training")
parser.add_argument("--lr", default=0.1, type=float, help="learning rate")
parser.add_argument("--resume", "-r", action="store_true", help="resume from checkpoint")
parser.add_argument("--output-dir", default="./save", type=str, help="directory checkpoints are written to")
parser.add_argument("--epochs", default=200, type=int, help="number of epochs to train")
parser.add_argument("--check-ckpt", default=None, type=str, help="print acc/epoch stored in this checkpoint and exit")
parser.add_argument("--long-tail", action="store_true", help="use long tail version of cifar")
parser.add_argument("--imb-factor", default=0.01, type=float, help="imbalance factor (forced to 1 without --long-tail)")
# Default to 0 instead of None: the value is fed to torch.manual_seed(), which
# converts it with int() and therefore fails on None when the flag is omitted.
parser.add_argument("--rand-value", default=0, type=int, help="random seed (default: 0)")
parser.add_argument("--fix-number", default=None, type=int, help="forwarded to IMBALANCECIFAR100 as fix_number")
parser.add_argument("--imb-type", default="exp", type=str, help="forwarded to IMBALANCECIFAR100 as imb_type")
parser.add_argument("--stage-number", default=-1, type=int, help="forwarded to IMBALANCECIFAR100 as stage_number")
# nargs="+" with type=int yields a list of integers, consistent with the
# list(range(20)) fallback applied when the flag is omitted. The previous
# type=list split the raw argument string into single characters.
parser.add_argument(
    "--specific-classes",
    default=None,
    type=int,
    nargs="+",
    help="class indices, e.g. --specific-classes 0 1 2 (default: 0..19)",
)
parser.add_argument("--specific-number", default=-1, type=int, help="forwarded to IMBALANCECIFAR100 as specific_number")
args = parser.parse_args()

# Fall back to the first 20 class indices when none were given on the CLI.
if args.specific_classes is None:
    args.specific_classes = list(range(20))

# Seed every random number generator the script touches with the same value.
rand_value = args.rand_value
for _seed_fn in (np.random.seed, random.seed):
    _seed_fn(rand_value)
# NOTE(review): setting PYTHONHASHSEED here presumably only affects
# interpreters started afterwards, not the running one — confirm intent.
os.environ["PYTHONHASHSEED"] = str(rand_value)
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(rand_value)


# Checkpoint inspection mode: print the accuracy and epoch stored in the given
# checkpoint file, then stop before any data or model is set up.
if args.check_ckpt:
    # Only the metadata is read, so map tensors to the CPU; this lets a
    # checkpoint written by a GPU run be inspected on a machine without CUDA.
    checkpoint = torch.load(args.check_ckpt, map_location="cpu")
    best_acc = checkpoint["acc"]
    start_epoch = checkpoint["epoch"]
    print(f"==> test ckp: {args.check_ckpt}, acc: {best_acc}, epoch: {start_epoch}")
    # The exit() builtin is installed by the site module for interactive use
    # and is not guaranteed to exist; raise SystemExit directly instead.
    raise SystemExit(0)


# Create the checkpoint directory. exist_ok=True avoids the race between
# checking for the directory and creating it (e.g. several runs started at
# once with the same --output-dir).
os.makedirs(args.output_dir, exist_ok=True)


device = "cuda" if torch.cuda.is_available() else "cpu"
best_acc = 0  # test accuracy recorded at the latest checkpoint (updated in test())
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

# Data
print("==> Preparing data..")

# Per-channel normalization shared by the training and test pipelines.
# NOTE(review): these constants look carried over from the CIFAR-10 script
# this file was adapted from — confirm they are intended for CIFAR-100.
_normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

# Training pipeline: random 32x32 crop with 4-pixel padding and random
# horizontal flip, followed by tensor conversion and normalization.
_train_steps = [
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    _normalize,
]
transform_train = transforms.Compose(_train_steps)

# Test pipeline: tensor conversion and normalization only.
transform_test = transforms.Compose([transforms.ToTensor(), _normalize])


# Without --long-tail the imbalance factor is overridden with 1.
if not args.long_tail:
    args.imb_factor = 1
print("[MY LOG] Is long tail: ", args.long_tail)
print("[MY LOG] Imbalance factor: ", args.imb_factor)
print("[MY LOG] Fix number: ", args.fix_number)

# Settings common to the training and test loaders.
_loader_kwargs = dict(batch_size=128, num_workers=2)

# Training set: CIFAR-100 wrapped by the project's imbalance-aware dataset.
trainset = IMBALANCECIFAR100(
    root="/nas/dataset/dataset_distillation/cifar/cifar100",
    train=True,
    imb_factor=args.imb_factor,
    rand_number=rand_value,
    download=False,
    transform=transform_train,
    fix_number=args.fix_number,
    imb_type=args.imb_type,
    stage_number=args.stage_number,
    specific_classes=args.specific_classes,
    specific_number=args.specific_number,
)
print(f"Each Class Number: {trainset.img_num_list}")
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **_loader_kwargs)

# Test set: the standard CIFAR-100 test split, iterated in fixed order.
testset = torchvision.datasets.CIFAR100(
    root="./data/cifar100",
    train=False,
    download=False,
    transform=transform_test,
)
testloader = torch.utils.data.DataLoader(testset, shuffle=False, **_loader_kwargs)

# Model
print("==> Building model..")

# ResNet-18 with a 100-way classifier. As described in the paper this script
# follows (MoCo CIFAR variant), the first 7x7 conv is replaced by a 3x3 conv
# and the max-pool layer is dropped.
model = torchvision.models.get_model("resnet18", num_classes=100)
model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
model.maxpool = nn.Identity()


# Move the model to the selected device; on CUDA, wrap it in DataParallel and
# turn on the cuDNN benchmark flag.
net = model.to(device)
if device == "cuda":
    net = nn.DataParallel(net)
    cudnn.benchmark = True

if args.resume:
    # Load checkpoint.
    print("==> Resuming from checkpoint..")
    # test() writes its checkpoint to <output-dir>/ckpt.pth, so look there
    # first; the previously hard-coded ./checkpoint/ckpt.pth is kept as a
    # fallback location.
    resume_candidates = (
        os.path.join(args.output_dir, "ckpt.pth"),
        os.path.join("checkpoint", "ckpt.pth"),
    )
    resume_path = next((p for p in resume_candidates if os.path.isfile(p)), None)
    if resume_path is None:
        # Explicit exception instead of assert, which is stripped under -O.
        raise FileNotFoundError(
            "Error: no checkpoint found! Looked for: " + ", ".join(resume_candidates)
        )
    # map_location lets a checkpoint saved on another device type be loaded.
    checkpoint = torch.load(resume_path, map_location=device)
    # test() stores the weights under "state_dict"; "net" is the key this
    # block read before and is still accepted.
    weights_key = "state_dict" if "state_dict" in checkpoint else "net"
    net.load_state_dict(checkpoint[weights_key])
    best_acc = checkpoint["acc"]
    start_epoch = checkpoint["epoch"]

# Loss, optimizer and learning-rate schedule used by train() and the main loop.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    net.parameters(),
    lr=args.lr,
    momentum=0.9,
    weight_decay=5e-4,
)
# Cosine annealing with T_max equal to the requested number of epochs; the
# main loop steps the scheduler once per epoch.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)


# Train
def train(epoch):
    """Run one optimization pass over ``trainloader`` and print its metrics.

    Operates on the module-level ``net``, ``optimizer``, ``criterion`` and
    ``device``.

    Args:
        epoch: Epoch index; only used in the printed summary line.
    """
    net.train()
    loss_sum = 0
    num_correct = 0
    num_seen = 0
    num_batches = 0
    for num_batches, (inputs, targets) in enumerate(trainloader, start=1):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Running statistics for the end-of-epoch summary.
        loss_sum += loss.item()
        predicted = outputs.max(1)[1]
        num_seen += targets.size(0)
        num_correct += (predicted == targets).sum().item()

    accuracy = 100.0 * num_correct / num_seen
    mean_loss = loss_sum / num_batches
    print(f"Epoch: [{epoch}], Acc@1 {accuracy:.3f}, Loss {mean_loss:.4f}")


# Test
def test(epoch):
    """Evaluate ``net`` on ``testloader``, print metrics and save checkpoints.

    A checkpoint holding the model state dict, the measured accuracy and the
    epoch is always written to ``ckpt.pth`` under ``args.output_dir``. An
    additional ``ckpt-{epoch}.pth`` snapshot is written when the epoch is a
    multiple of 10 (up to epoch 50) or a multiple of 50 (after epoch 50).
    The module-level ``best_acc`` is set to the accuracy just measured.

    Args:
        epoch: Epoch index, stored in the checkpoint and used for the
            snapshot file name.
    """
    global best_acc
    net.eval()
    loss_sum = 0
    num_correct = 0
    num_seen = 0
    num_batches = 0
    with torch.no_grad():
        for num_batches, (inputs, targets) in enumerate(testloader, start=1):
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = net(inputs)

            loss_sum += criterion(outputs, targets).item()
            predicted = outputs.max(1)[1]
            num_seen += targets.size(0)
            num_correct += (predicted == targets).sum().item()

    acc = 100.0 * num_correct / num_seen
    print(f"Test: Acc@1 {acc:.3f}, Loss {loss_sum/num_batches:.4f}")

    # The last checkpoint is always saved, regardless of accuracy.
    state = {"state_dict": net.state_dict(), "acc": acc, "epoch": epoch}

    # Periodic snapshot: every 10 epochs through epoch 50, every 50 afterwards.
    snapshot_interval = 10 if epoch <= 50 else 50
    if epoch % snapshot_interval == 0:
        torch.save(state, os.path.join(args.output_dir, f"./ckpt-{epoch}.pth"))

    torch.save(state, os.path.join(args.output_dir, "./ckpt.pth"))
    best_acc = acc


# Main loop: train for args.epochs epochs starting at start_epoch, evaluating
# (and checkpointing, via test()) periodically, and report the wall time.
start_time = time.time()
# Index of the last epoch of this run. It includes start_epoch so that a
# resumed run (start_epoch > 0) still evaluates and checkpoints its final
# epoch; comparing against args.epochs - 1 alone missed it in that case.
final_epoch = start_epoch + args.epochs - 1
for epoch in range(start_epoch, final_epoch + 1):
    train(epoch)
    # fast test: only every 10th epoch, plus the final one
    if epoch % 10 == 0 or epoch == final_epoch:
        test(epoch)
    scheduler.step()
end_time = time.time()
print(f"total time: {end_time - start_time} s")
