import os
import argparse
import torch
import numpy as np
from cifar10_models import *
from attacks import *
from status import ProgressBar
from args import create_parser
try:
    import wandb
except ImportError:
    wandb = None
from utils import load_student
from autoattack import AutoAttack

parser = create_parser()
# parse_known_args: unrecognised CLI flags are ignored instead of aborting the run.
args = parser.parse_known_args()[0]

print(args)


# Checkpoints are written under ./result_models/<dataset>/<student>/.
save_dir=f'./result_models/{args.dataset}/{args.student}/'
os.makedirs(save_dir, exist_ok=True)


# Seed torch (CPU and all CUDA devices) and request deterministic cuDNN kernels.
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
torch.backends.cudnn.deterministic = True

# Parent directory of the directory containing this script.
basepath = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if not args.nowand:
    # Explicit check rather than `assert`: asserts are stripped under `python -O`,
    # which would turn this into an opaque AttributeError on `None.init` below.
    if wandb is None:
        raise ImportError("Wandb not installed, please install it or run without wandb")
    wandb.init(project=args.wandb_project, entity=args.wandb_entity, config=vars(args), name=args.wandb_name, tags=[args.wandb_tags])
    args.wandb_url = wandb.run.get_url()


epochs = args.epochs
# Second batch of project helpers (load_student is imported at the top of the file).
from utils import load_dataset, load_teacher

# Train / test data loaders for the chosen dataset and batch size.
trainloader, testloader = load_dataset(args.dataset, args.batch)

# Student (optimised) goes to the GPU in train mode; the frozen teacher goes in eval mode.
# Module.cuda() / .train() / .eval() all return the module itself, so they can be chained.
student = load_student(args.student, args.dataset, args.depth, args.widen_factor).cuda().train()
teacher = load_teacher(args.teacher, args.dataset).cuda().eval()

# SGD with momentum 0.9 and weight decay 2e-4; the LR is decayed later by the training loop.
initial_lr = args.lr
optimizer = torch.optim.SGD(
    params=student.parameters(),
    lr=initial_lr,
    momentum=0.9,
    weight_decay=2e-4,
)
progress_bar = ProgressBar()

# Number of PGD iterations used to craft training and evaluation adversarial examples.
TRAIN_PGD_STEPS, EVAL_PGD_STEPS = 10, 20
# Distillation loss (KL between student log-probs and teacher probs, averaged per batch).
criterion_kl = nn.KLDivLoss(reduction="batchmean")
# Cross-entropy criterion; NOTE(review): not referenced by the visible training objective.
XENT_loss = nn.CrossEntropyLoss()
# args.alpha is the Linf budget in pixel units (0-255); convert to [0, 1] scale.
eps_pixels = int(args.alpha)
current_training_eps = eps_pixels / 255
# PGD step size is a quarter of the budget.
current_training_alpha = current_training_eps / 4
# Best per-epoch PGD test accuracy seen so far (drives best-checkpoint saving).
best_robust_acc = 0.0

for epoch in range(1, epochs + 1):
    # ---- Training: adversarial distillation ----
    # The student sees PGD examples crafted against itself and is trained (KL loss)
    # to match the teacher's softmax on the corresponding clean inputs.
    student.train()
    train_robust_correct = 0
    train_clean_correct = 0
    train_total = 0
    epoch_loss_sum = 0.0
    for step, (X, y) in enumerate(trainloader):
        X, y = X.cuda().float(), y.cuda()
        inputs_adv = PGD(X, y, student, eps=current_training_eps, alpha=current_training_alpha, steps=TRAIN_PGD_STEPS)
        optimizer.zero_grad()
        with torch.no_grad():
            teacher_plus = teacher(X)
        student_plus = student(inputs_adv)
        loss = criterion_kl(F.log_softmax(student_plus, dim=1), F.softmax(teacher_plus, dim=1))
        loss.backward()
        optimizer.step()
        epoch_loss_sum += loss.item()

        # Monitoring only. The clean pass runs in eval mode: in train mode BatchNorm
        # updates its running statistics even under no_grad, so a metrics-only
        # forward pass would silently modify the model being trained.
        student.eval()
        with torch.no_grad():
            student_plus_clean = student(X)
        student.train()
        train_clean_correct += (student_plus_clean.argmax(1) == y).sum().item()
        # Robust accuracy reuses the logits from this step's forward pass (pre-update weights).
        train_robust_correct += (student_plus.argmax(1) == y).sum().item()
        train_total += y.size(0)
        progress_bar.prog(step, len(trainloader), epoch, loss.item())
    # max(..., 1) avoids ZeroDivisionError on an empty loader.
    avg_epoch_loss = epoch_loss_sum / max(len(trainloader), 1)
    train_robust_acc = train_robust_correct / max(train_total, 1)
    train_clean_acc = train_clean_correct / max(train_total, 1)

    # Step LR schedule: multiply by 0.1 after epochs 100 and 150.
    if epoch in (100, 150):
        for param_group in optimizer.param_groups:
            param_group['lr'] *= 0.1

    # ---- Per-epoch evaluation: clean and PGD-20 accuracy on the test set ----
    student.eval()
    test_correct = 0
    test_correct_adv = 0
    test_total = 0
    for test_batch_data, test_batch_labels in testloader:
        # Move to the GPU like the training and final-evaluation loops: the student
        # lives on CUDA, so CPU batches would fail in the attack / forward pass.
        test_batch_data = test_batch_data.cuda().float()
        test_batch_labels = test_batch_labels.cuda()
        test_ifgsm_data = PGD(test_batch_data, test_batch_labels, student, eps=current_training_eps, alpha=current_training_alpha, steps=EVAL_PGD_STEPS)
        # Re-assert eval mode in case the attack helper changed the module mode.
        student.eval()
        with torch.no_grad():
            logits = student(test_batch_data)
            logits_adv = student(test_ifgsm_data)
        test_correct += (logits.argmax(1) == test_batch_labels).sum().item()
        test_correct_adv += (logits_adv.argmax(1) == test_batch_labels).sum().item()
        test_total += test_batch_labels.size(0)
    test_acc = test_correct / max(test_total, 1)
    test_acc_adv = test_correct_adv / max(test_total, 1)
    print('PGD20 acc',test_acc_adv)
    if not args.nowand:
        d2={'clean_acc': np.round(test_acc,4), 'robust_acc': np.round(test_acc_adv,4), 'train_clean_acc': np.round(train_clean_acc,4), 'train_robust_acc': np.round(train_robust_acc,4), 'avg_epoch_loss': avg_epoch_loss, 'main_epoch' : epoch}
        wandb.log(d2)

    # Keep a separate checkpoint of the epoch with the highest PGD test accuracy.
    if test_acc_adv > best_robust_acc:
        best_robust_acc = test_acc_adv
        file_name_best = f"{args.wandb_name}_best.pt"
        save_path_best = os.path.join(save_dir, file_name_best)
        torch.save(student.state_dict(), save_path_best)
        print(f"--- New best model saved at epoch {epoch} with robust acc: {best_robust_acc:.4f} ---")

# Persist the weights from the last training epoch; any best-robust checkpoint
# is written separately (with a "_best" suffix) inside the training loop.
file_name = f"{args.wandb_name}.pt"
save_path = os.path.join(save_dir, file_name)
final_state = student.state_dict()
torch.save(final_state, save_path)
print(f"Model saved to: {save_path}")

# ---- Final evaluation of the last-epoch student: clean, PGD-20, CW-Linf, FGSM ----
# Correct predictions are counted directly per batch (no per-sample lists that were
# grown by repeated list concatenation, which is quadratic in the test-set size).
student.eval()
n_total = 0
n_clean = 0
n_pgd = 0
n_cw = 0
n_fgsm = 0
for X, y in testloader:
    X = X.cuda().float()
    y = y.cuda()
    # Same budget/step as training: eps = int(args.alpha)/255, step = eps/4.
    inputs_cw = cw_Linf_attack(X, y, student, eps=current_training_eps, alpha=current_training_alpha)
    inputs_fgsm = FGSM(X, y, student, eps=current_training_eps)
    inputs_pgd = PGD(X, y, student, eps=current_training_eps, alpha=current_training_alpha, steps=EVAL_PGD_STEPS, random_start=True)
    # Re-assert eval mode in case an attack helper changed the module mode.
    student.eval()
    with torch.no_grad():
        logits = student(X)
        logits_cw = student(inputs_cw)
        logits_fgsm = student(inputs_fgsm)
        logits_pgd = student(inputs_pgd)
    n_clean += (logits.argmax(1) == y).sum().item()
    n_cw += (logits_cw.argmax(1) == y).sum().item()
    n_fgsm += (logits_fgsm.argmax(1) == y).sum().item()
    n_pgd += (logits_pgd.argmax(1) == y).sum().item()
    n_total += y.size(0)

# max(..., 1) avoids ZeroDivisionError on an empty loader.
denom = max(n_total, 1)
test_acc = n_clean / denom
test_acc_adv = n_pgd / denom
test_acc_fgsm = n_fgsm / denom
test_acc_cw = n_cw / denom
print('PGD20 acc',test_acc_adv)
print('CW acc',test_acc_cw)
print('FGSM acc',test_acc_fgsm)
# Gap between the best per-epoch PGD accuracy and the final model's PGD accuracy.
overfitting_gap = best_robust_acc - test_acc_adv
# Gap between last-epoch train robust accuracy and final test PGD accuracy.
generalization_gap = train_robust_acc - test_acc_adv

if not args.nowand:
    d2={'clean_acc': test_acc, 'robust_acc': test_acc_adv, 'fgsm_acc':test_acc_fgsm, 'cw_acc':test_acc_cw,    'overfitting_gap': overfitting_gap,
        'generalization_gap': generalization_gap}
    wandb.log(d2)


# ---- AutoAttack (standard Linf suite) on the full test set ----
student.eval()
# Use the same Linf budget as training and the PGD/CW/FGSM evaluations above
# (previously hard-coded to 8/255 regardless of args.alpha).
autoattack = AutoAttack(student, norm='Linf', eps=current_training_eps, version='standard')
# Collect inputs and labels in ONE pass so they stay paired even if the loader
# shuffles; two separate passes would match images with another epoch's labels.
x_batches = []
y_batches = []
for x_batch, y_batch in testloader:
    x_batches.append(x_batch)
    y_batches.append(y_batch)
x_total = torch.cat(x_batches, 0)
y_total = torch.cat(y_batches, 0)
# NOTE(review): upstream AutoAttack's run_standard_evaluation returns only the
# adversarial examples (or (x_adv, y_adv) with return_labels=True); this two-value
# unpacking assumes the local `autoattack` package returns robust accuracy — confirm.
_, robust_acc = autoattack.run_standard_evaluation(x_total, y_total)
print('final AA',robust_acc)
if not args.nowand:
    AA_d = {'RESULT_AA': robust_acc}
    wandb.log(AA_d)
    wandb.finish()