import os
import time
import random
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler


from utils import load_model, load_dataset, parse_args, setup_seed, logits_accuracy, test, PGD10

# Command-line options. StableAT adds ``--C``: the number of perturbation
# fractions (k / C for k = 0 .. C-1) tried per sample during training.
args = parse_args("StableAT", {"--C": {"default": 3, "type": int}})
setup_seed(args.seed)

# Fall back to CPU when CUDA is disabled by flag or unavailable.
if args.cuda and torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

net = load_model(args).to(device)
train_loader, test_loader = load_dataset(args)

criterion = nn.CrossEntropyLoss()

# SGD with momentum/weight decay; LR decays by ``gamma`` at each milestone epoch
# (the scheduler is stepped once per epoch in the training loop below).
optimizer = optim.SGD(
    net.parameters(),
    lr=args.lr,
    momentum=args.momentum,
    weight_decay=args.weight_decay,
)
scheduler = lr_scheduler.MultiStepLR(
    optimizer,
    milestones=args.milestones,
    gamma=args.gamma,
)

@torch.enable_grad()
def _fgsm_images(model, images, labels, criterion, epsilon):
    """Return single-step FGSM adversarial images, clipped to [0, 1].

    Starting from a zero perturbation, take one signed-gradient step of size
    ``epsilon`` in the direction that locally increases ``criterion``. Only the
    gradient w.r.t. the perturbation is computed (``torch.autograd.grad``), so
    the model parameters' ``.grad`` buffers are not touched.
    """
    delta = torch.zeros_like(images, requires_grad=True)
    loss = criterion(model(images + delta), labels)
    grad = torch.autograd.grad(loss, delta)[0].detach()
    delta = torch.clamp(epsilon * grad.sign(), -epsilon, epsilon)
    # NOTE(review): the [0, 1] clip assumes un-normalized image inputs — confirm.
    return torch.clamp(images + delta, 0.0, 1.0)


@torch.no_grad()
def _shrink_to_first_misclassification(model, images, labels, images_adv, steps):
    """Pull each adversarial example back to the weakest perturbation that fools the model.

    With ``delta = images_adv - images``, the candidates
    ``images + (k / steps) * delta`` for k = 0, 1, ..., ``steps - 1`` are tried
    in order. The candidate k = 0 is the clean image. Each sample takes the
    first candidate the model misclassifies. Samples that are never
    misclassified keep ``images_adv``.

    Returns a new tensor; ``images_adv`` itself is not modified.
    """
    delta = images_adv - images
    result = images_adv.clone()
    pending = torch.ones_like(labels, dtype=torch.bool)
    for k in range(steps):
        # 1-D index of samples not yet misclassified (never 0-d, unlike squeeze()).
        idx = pending.nonzero(as_tuple=True)[0]
        if idx.numel() == 0:
            break
        candidates = torch.clamp(images[idx] + k / steps * delta[idx], 0.0, 1.0)
        fooled = model(candidates).argmax(dim=-1) != labels[idx]
        hit = idx[fooled]
        result[hit] = candidates[fooled]
        pending[hit] = False
    return result


@torch.enable_grad()
def train(epoch, model, dataloader, criterion, optimizer):
    """Run one epoch of StableAT adversarial training.

    For every batch, a single-step FGSM perturbation of size ``args.epsilon``
    is computed with the model in eval mode. It is then shrunk per sample to
    the smallest tried fraction ``k / args.C`` (k = 0 .. args.C - 1) at which
    the model already misclassifies. Samples never misclassified keep the full
    FGSM perturbation. The model is then updated, in train mode, on the
    resulting inputs.

    Relies on module-level ``args`` (``epsilon``, ``C``), ``device``, ``tqdm``
    and ``logits_accuracy``.

    Args:
        epoch: Epoch number, used only in the progress-bar label.
        model: Network being trained (left in train mode on return).
        dataloader: Sized iterable of ``(images, labels)`` batches.
        criterion: Loss callable ``criterion(logits, labels)``.
        optimizer: Optimizer over ``model``'s parameters.

    Returns:
        ``(mean_accuracy, mean_loss)``, each averaged over batches.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    acc_sum, loss_sum, num_batches = 0, 0, 0
    with tqdm(dataloader, total=len(dataloader), desc=f"Train Epoch {epoch}") as t:
        for images, labels in t:
            images, labels = images.to(device), labels.to(device)

            # Craft the perturbation in eval mode so attack passes do not
            # update any running statistics (e.g. BatchNorm).
            model.eval()
            images_adv = _fgsm_images(model, images, labels, criterion, args.epsilon)
            images_adv = _shrink_to_first_misclassification(model, images, labels, images_adv, args.C)
            delta = images_adv - images

            model.train()
            logits = model(images + delta)
            loss = criterion(logits, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            num_batches += 1
            loss_sum += loss.item()
            acc_sum += logits_accuracy(logits.detach(), labels)

            t.set_postfix({"Loss": f"{loss_sum/num_batches:6.2f}", "Acc": f"{acc_sum/num_batches:4.2%}"})

    if num_batches == 0:
        # The original code hit UnboundLocalError here; fail with a clear message instead.
        raise ValueError("train() received an empty dataloader")
    return acc_sum / num_batches, loss_sum / num_batches

# Train for ``args.epoch_num`` epochs. Each epoch: time the training pass,
# checkpoint the weights (overwriting one file per seed), then step the LR schedule.
times = []
last_tested_epoch = None
for epoch in range(1, args.epoch_num+1):
    start = time.time()
    train_acc, train_loss = train(epoch, net, train_loader, criterion, optimizer)
    times.append(time.time() - start)
    torch.save(net.state_dict(), os.path.join(args.result_path, f"model_{args.seed}.pth"))
    scheduler.step()
    if epoch % 10 == 0:
        acc_clean, acc_pgd = test(epoch, net, test_loader, PGD10)
        last_tested_epoch = epoch

# Evaluation above only runs every 10 epochs. Make sure the reported numbers
# belong to the final model. Without this, acc_clean/acc_pgd are undefined
# when epoch_num < 10, and stale when epoch_num is not a multiple of 10.
if last_tested_epoch != args.epoch_num:
    acc_clean, acc_pgd = test(args.epoch_num, net, test_loader, PGD10)

print(f"CLEAN: {acc_clean:6.2%}, PGD10: {acc_pgd:6.2%}\nTime Consumed: {round(np.mean(times), 1)} sec/epoch")





