import os
import time
import random
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler


from models.yopo import PreActResNet18_yopo, ResNet18_yopo
from utils import load_dataset, parse_args, setup_seed, logits_accuracy, test, PGD10

# YOPO-specific CLI options layered on top of the shared ones from parse_args:
#   --M: full forward/backward passes per batch (outer loop in train)
#   --N: first-layer-only perturbation updates per full pass (inner loop in train)
args = parse_args(
    "YOPO",
    {
        "--M": {"default": 5, "type": int},
        "--N": {"default": 3, "type": int},
    },
)
# Seeded before the model and loaders are built; presumably this fixes weight init
# and data shuffling — TODO confirm what setup_seed covers.
setup_seed(args.seed)
if args.cuda and torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

net = ResNet18_yopo(10).to(device)  # 10 output classes
train_loader, test_loader = load_dataset(args)
criterion = nn.CrossEntropyLoss()

# The first layer and the remaining layers are optimised separately, each with its
# own MultiStepLR schedule sharing the same milestones/gamma.
# NOTE(review): first layer uses lr / M but the rest use lr / (2 * M) — confirm the
# extra factor of 2 is intentional.
optimizer_first = optim.SGD(
    net.first_layer.parameters(),
    lr=args.lr / args.M,
    momentum=args.momentum,
    weight_decay=args.weight_decay,
)
scheduler_first = lr_scheduler.MultiStepLR(
    optimizer_first,
    milestones=args.milestones,
    gamma=args.gamma,
)

optimizer_rest = optim.SGD(
    net.other_layers.parameters(),
    lr=args.lr / (2 * args.M),
    momentum=args.momentum,
    weight_decay=args.weight_decay,
)
scheduler_rest = lr_scheduler.MultiStepLR(
    optimizer_rest,
    milestones=args.milestones,
    gamma=args.gamma,
)

@torch.enable_grad()
def train(epoch, model, dataloader, criterion, optimizer_first, optimizer_rest):
    """Run one epoch of YOPO-M-N adversarial training.

    For each batch, a random perturbation ``delta`` inside the L-inf ball of radius
    ``args.epsilon`` is refined over ``args.M`` full forward/backward passes. After each
    pass, ``args.N`` signed-gradient steps on ``delta`` are taken using only
    ``model.first_layer``. Parameter gradients from all M passes accumulate and both
    optimizers step once per batch.

    Args:
        epoch: Epoch number, used only in the progress-bar description.
        model: Network exposing ``first_layer`` and ``first_layer_out``; the latter is
            expected to hold a ``.grad`` after ``loss.backward()`` (presumably via
            ``retain_grad()`` inside the model — TODO confirm).
        dataloader: Iterable of ``(images, labels)`` batches with a ``len()``.
        criterion: Loss function applied to ``(logits, labels)``.
        optimizer_first: Optimizer for the first layer's parameters.
        optimizer_rest: Optimizer for the remaining parameters.

    Returns:
        ``(clean_acc, yopo_acc)`` averaged over batches. "Clean" accuracy is measured on
        the first pass, i.e. on the random-start perturbation, not on unperturbed images;
        "YOPO" accuracy is measured on the last pass. Both are 0.0 for an empty loader.

    Reads the module-level ``args`` (M, N, epsilon, step_size) and ``device``.
    """
    model.train()
    clean_acc, yopo_acc = 0, 0
    running_loss = 0
    num_batches = 0
    with tqdm(dataloader, total=len(dataloader), desc=f"Train Epoch {epoch}") as t:
        for num_batches, (images, labels) in enumerate(t, start=1):
            images, labels = images.to(device), labels.to(device)

            # Random start inside the eps-ball, projected so images + delta stays in [0, 1].
            delta = torch.zeros_like(images).uniform_(-args.epsilon, args.epsilon)
            delta = torch.clamp(images + delta, 0, 1) - images
            # Gradients from all M passes accumulate and are applied once per batch.
            optimizer_first.zero_grad()
            optimizer_rest.zero_grad()

            for j in range(args.M):
                logits = model(images + delta)
                loss = criterion(logits, labels)
                running_loss += loss.item()
                loss.backward()
                # Adjoint: negated gradient of the loss w.r.t. the first layer's output.
                # Use the `model` argument rather than the global network.
                p = -1.0 * model.first_layer_out.grad.detach()

                for _ in range(args.N):
                    delta.requires_grad_(True)
                    hamilton = torch.sum(model.first_layer(images + delta) * p)
                    # d(hamilton)/d(delta) = -d(loss)/d(delta), so subtracting its sign
                    # moves delta in the loss-increasing direction.
                    grad = torch.autograd.grad(hamilton, delta, retain_graph=False)[0].sign()
                    delta = delta.detach() - args.step_size * grad.detach()

                    # Project back onto the eps-ball and the valid pixel range.
                    delta = torch.clamp(delta, -args.epsilon, args.epsilon)
                    delta = torch.clamp(images + delta, 0.0, 1.0) - images

                with torch.no_grad():
                    # Independent checks (not elif) so both are recorded when M == 1.
                    if j == 0:
                        clean_acc += logits_accuracy(logits, labels)
                    if j == args.M - 1:
                        yopo_acc += logits_accuracy(logits, labels)

            optimizer_first.step()
            optimizer_rest.step()
            t.set_postfix({"Loss": f"{running_loss/(args.M*num_batches):6.2f}", "CLEAN": f"{clean_acc/num_batches:4.2%}", "YOPO": f"{yopo_acc/num_batches:4.2%}"})

    if num_batches == 0:
        # Empty loader: nothing was evaluated.
        return 0.0, 0.0
    return clean_acc / num_batches, yopo_acc / num_batches

times = []  # seconds spent in train() for each epoch
if args.result_path:
    # Make sure the checkpoint directory exists before the first torch.save.
    os.makedirs(args.result_path, exist_ok=True)
for epoch in range(1, args.epoch_num + 1):
    # perf_counter is monotonic, so durations are unaffected by system clock changes.
    epoch_start = time.perf_counter()
    train_ori, train_adv = train(epoch, net, train_loader, criterion, optimizer_first, optimizer_rest)
    times.append(time.perf_counter() - epoch_start)
    # Both schedulers advance once per epoch, after that epoch's optimizer steps.
    scheduler_first.step()
    scheduler_rest.step()
    # A single per-seed checkpoint, overwritten after every epoch.
    torch.save(net.state_dict(), os.path.join(args.result_path, f"model_{args.seed}.pth"))

# Pass the final epoch number explicitly: identical to the loop variable when the loop
# ran, and still defined when args.epoch_num < 1.
acc_clean, acc_pgd = test(args.epoch_num, net, test_loader, PGD10)
mean_epoch_time = np.mean(times) if times else 0.0  # avoid nan/warning on empty list
print(f"CLEAN: {acc_clean:6.2%}, PGD10: {acc_pgd:6.2%}\nTime Consumed: {round(mean_epoch_time, 1)} sec/epoch")







