import os
import time
import random
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler


from utils import load_model, load_dataset, parse_args, setup_seed, logits_accuracy, test, pgd

# Script-specific command-line options handed to parse_args together with the
# experiment name. The remaining attributes read below (seed, cuda, lr, ...)
# are not declared here, so they have to come from parse_args itself.
#   --lam / --alpha / --beta : weights of the loss terms combined in train()
#   --k                      : fixed interpolation ratio, used when 0 < k < 1
#   --Dyn                    : raise lam whenever the learning rate changes
#   --Rk                     : draw a fresh interpolation ratio for every batch
_float_defaults = {"--lam": 4.0, "--k": 0.0, "--alpha": 0.1, "--beta": 0.1}
_switches = ("--Dyn", "--Rk")

_extra_options = {
    flag: {"default": value, "type": float}
    for flag, value in _float_defaults.items()
}
_extra_options.update(
    {flag: {"default": False, "action": "store_true"} for flag in _switches}
)
args = parse_args("FLAT+", _extra_options)

# Seed first so that model construction and data loading happen after seeding.
setup_seed(args.seed)

# Use the GPU only when it was requested and is actually available.
if args.cuda and torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

net = load_model(args)
net = net.to(device)
train_loader, test_loader = load_dataset(args)

# Mean-reduced cross-entropy; train() uses it to compute the attack gradient.
criterion = nn.CrossEntropyLoss()

# SGD whose learning rate is multiplied by args.gamma at each milestone epoch.
optimizer = optim.SGD(
    net.parameters(),
    lr=args.lr,
    momentum=args.momentum,
    weight_decay=args.weight_decay,
)
scheduler = lr_scheduler.MultiStepLR(
    optimizer,
    milestones=args.milestones,
    gamma=args.gamma,
)

@torch.enable_grad()
def train(lam, epoch, model, dataloader, criterion, optimizer):
    """Run one epoch of single-step adversarial training with a gap penalty.

    For every batch:

    1. With the model in eval mode, a one-step sign-gradient perturbation
       ``delta`` of magnitude ``args.epsilon`` is computed from ``criterion``
       and clipped so that ``images + delta`` stays inside ``[0, 1]``.
    2. A scaled perturbation ``eta = k * delta`` is formed.
    3. With the model in train mode, per-sample cross-entropy is evaluated at
       the clean input, at ``images + eta`` and at ``images + delta``. The
       "LPR" term is the linear interpolation of the clean and adversarial
       losses minus the loss actually observed at the interpolated point,
       divided by ``k``.
    4. The optimised loss is
       ``loss_adv + beta * loss_ori - lam * min(0, LPR) + alpha * max(0, LPR)``
       averaged over the batch.

    Reads the module-level ``args`` (``epsilon``, ``k``, ``Rk``, ``alpha``,
    ``beta``) and ``device``.

    Args:
        lam: Weight of the negative part of the LPR term; also printed in the
            progress-bar description.
        epoch: Epoch number, used only for the progress-bar description.
        model: Network being trained.
        dataloader: Sized iterable yielding ``(images, labels)`` batches.
        criterion: Loss used to compute the attack gradient.
        optimizer: Optimizer stepping ``model``'s parameters.

    Returns:
        ``(accuracy, loss)``: the per-batch average of ``logits_accuracy`` on
        the perturbed inputs and the per-batch average of the training loss.
        Both are ``0.0`` when ``dataloader`` yields no batches.

    Raises:
        ValueError: If ``args.k`` is outside ``(0, 1)`` and ``args.epsilon``
            is smaller than ``1/255``, so no default ratio can be derived.
    """
    acc_sum, loss_sum = 0, 0
    ratio_sum, lpr_sum = 0, 0
    per_sample_ce = nn.CrossEntropyLoss(reduction="none")

    if 0 < args.k < 1.:
        k = args.k
    else:
        # Default ratio: one pixel level out of the epsilon radius. The small
        # tolerance keeps values such as 8/255*255 from being truncated to 7
        # if floating-point rounding leaves them just below the integer.
        radius = int(args.epsilon * 255 + 1e-6)
        if radius <= 0:
            raise ValueError(
                f"epsilon={args.epsilon} is below 1/255, so the default "
                "interpolation ratio 1/radius is undefined; pass --k in (0, 1)"
            )
        k = 1 / radius

    num_batches = 0
    with tqdm(enumerate(dataloader), total=len(dataloader), desc=f"Train Epoch {epoch:3d} {lam}") as t:
        for i, (images, labels) in t:
            images, labels = images.to(device), labels.to(device)

            # --- craft the single-step perturbation (eval mode) ---
            model.eval()
            delta = torch.zeros_like(images).requires_grad_(True)
            attack_loss = criterion(model(images + delta), labels)
            # autograd.grad returns only d(loss)/d(delta); parameter .grad
            # buffers are not touched by the attack.
            grad = torch.autograd.grad(attack_loss, [delta])[0].detach()
            delta = args.epsilon * grad.sign()
            # Keep the perturbed image inside the valid [0, 1] range.
            delta = torch.clamp(images + delta, 0.0, 1.0) - images

            if args.Rk:
                # Resample the ratio per batch, favouring small values.
                # Cast to a Python float so a NumPy scalar never mixes into
                # tensor arithmetic.
                k = float(np.random.choice([0.1, 0.2, 0.4, 0.8], p=[0.4, 0.3, 0.2, 0.1]))
            eta = k * delta

            # --- parameter update (train mode) ---
            model.train()
            # Forward order (clean, interpolated, adversarial) is kept fixed
            # because each pass may update stateful layers in train mode.
            logits_ori = model(images)
            logits_eta = model(images + eta)
            logits_adv = model(images + delta)
            loss_ori = per_sample_ce(logits_ori, labels)
            loss_eta = per_sample_ce(logits_eta, labels)
            loss_adv = per_sample_ce(logits_adv, labels)

            # Interpolated loss minus the loss at the interpolated input.
            LPR = ((1 - k) * loss_ori + k * loss_adv - loss_eta) / k
            # Zeros in the dtype of the loss (not of the integer labels).
            zeros = torch.zeros_like(LPR)
            loss = (
                loss_adv
                + args.beta * loss_ori
                - lam * torch.minimum(zeros, LPR)
                + args.alpha * torch.maximum(zeros, LPR)
            ).mean()

            # Running statistics for the progress bar.
            ratio_sum += (LPR.detach() <= 0.0).sum().item() / labels.size(0)
            lpr_sum += LPR.detach().mean().item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_sum += loss.item()
            acc_sum += logits_accuracy(logits_adv.detach(), labels)
            num_batches = i + 1

            t.set_postfix({"Loss": f"{loss_sum/num_batches:6.3f}", "Acc": f"{acc_sum/num_batches:4.2%}", "LPR": f"{lpr_sum/num_batches: 6.3f}", "Ratio": f"{ratio_sum/num_batches: 4.2%}"})

    # Avoid an unbound loop index / division by zero on an empty dataloader.
    denom = max(num_batches, 1)
    return acc_sum / denom, loss_sum / denom

# Training driver: run args.epoch_num epochs, overwrite the checkpoint after
# every epoch, then evaluate the final model once.
times = []  # wall-clock seconds spent in train() per epoch
lr = optimizer.param_groups[0]["lr"]
lam = args.lam
epoch = 0  # keeps `epoch` defined for test() below when args.epoch_num is 0
for epoch in range(1, args.epoch_num + 1):
    # perf_counter is the monotonic clock meant for measuring intervals.
    start = time.perf_counter()
    # train() returns (accuracy on perturbed inputs, mean loss).
    train_acc, train_loss = train(lam, epoch, net, train_loader, criterion, optimizer)
    times.append(time.perf_counter() - start)
    torch.save(net.state_dict(), os.path.join(args.result_path, f"model_{args.seed}.pth"))
    scheduler.step()
    # With --Dyn, every learning-rate change made by the scheduler also
    # raises lam by half of its initial value.
    new_lr = optimizer.param_groups[0]["lr"]
    if args.Dyn and lr != new_lr:
        lr = new_lr
        lam += args.lam / 2


def PGD(model, images, labels):
    """Call ``pgd`` with 10 steps of size 2/255, radius ``args.epsilon`` and 1 restart."""
    return pgd(model, images, labels, steps=10, step_size=2/255, epsilon=args.epsilon, restarts=1)


# NOTE(review): test() is assumed to return (clean accuracy, accuracy under
# the supplied attack) as fractions, matching the labels and percentage
# formatting in the print below; confirm against utils.test.
acc_clean, acc_pgd = test(epoch, net, test_loader, PGD)
# np.mean([]) would produce nan plus a RuntimeWarning when no epoch was run.
mean_epoch_time = round(np.mean(times), 1) if times else 0.0
print(f"CLEAN: {acc_clean:6.2%}, PGD10: {acc_pgd:6.2%}\nTime Consumed: {mean_epoch_time} sec/epoch")







