from __future__ import print_function
import os
import argparse
import torch
import torch.optim as optim
from attack import *


import pdb
import numpy as np
import shutil

try:
    import wandb
except ImportError:
    wandb = None
    
from utils import *
import random
# Fix the random seeds to 0 so that results stay consistent across runs on
# the same computer.
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
# Ask cuDNN to use deterministic algorithms.
torch.backends.cudnn.deterministic = True
random.seed(0)
# NOTE(review): NumPy's global RNG is left unseeded (the call below is
# commented out) -- confirm this is intentional.
# np.random.seed(0)

from robustbench.utils import load_model
from autoattack import AutoAttack
import time
from pathlib import Path
import methods
# Timestamp embedded in checkpoint file names. The hour (%H) is included so
# that runs started at the same minute of different hours on the same day do
# not produce identical file names and overwrite each other's checkpoints.
start_time = time.strftime('%Y-%m-%d-%H-%M', time.localtime(time.time()))

#########################################################################################################


def main(config, wandb):
    """Pre-train (or load) the balance model, train the main model, then evaluate.

    Steps, in order:
      1. Build the data loaders, the main model and its SGD optimizer.
      2. If the balance-model checkpoint exists, load it together with its
         stored flatness criteria; otherwise train the balance model with
         ``methods.train_<config.pre_method>``, keeping the checkpoint with
         the highest robust accuracy reported by ``evaluate_interval``.
      3. Train the main model with ``methods.train_<config.method>``,
         evaluating every ``config.interval`` epochs and saving the best and
         the last weights under ``./result_models/<dataset>/<model>``.
      4. Run ``evaluate_fgsm_head_tail`` / ``evaluate_pgd_head_tail`` on the
         last model and, when ``config.evaluate`` is set, ``evaluate_final_aa``
         on both the last and the best model.

    Args:
        config: Run configuration. Attributes read here: ``lr``, ``momentum``,
            ``weight_decay``, ``dataset``, ``model``, ``beta``, ``gamma``,
            ``pre_batch``, ``pre_epochs``, ``pre_lr``, ``IR``, ``pre_method``,
            ``method``, ``interval``, ``epochs``, ``wandb_name``,
            ``num_classes``, ``nowand`` and ``evaluate``.
        wandb: The ``wandb`` module (may be ``None`` when logging is
            disabled). It is forwarded to ``evaluate_interval`` and used for
            ``wandb.log`` only when ``config.nowand`` is false.

    Raises:
        RuntimeError: If main training is requested but pre-training never
            produced a flatness criteria (no evaluated epoch improved the
            robust accuracy above 0).
    """

    def log_metrics(metrics):
        # Single place that honours the "no wandb" switch.
        if not config.nowand:
            wandb.log(metrics)

    train_loader, test_loader, samples_per_cls, trainset = build_loader_ours2(config)
    model = build_model(config, samples_per_cls)
    optimizer = optim.SGD(model.parameters(), lr=config.lr, momentum=config.momentum, weight_decay=config.weight_decay)

    save_dir = './result_models/' + config.dataset + "/" + config.model
    save_pre_dir = save_dir + "/pretrained_best"
    # Creates save_dir as well, since it is the parent directory.
    os.makedirs(save_pre_dir, exist_ok=True)
    save_pre_dir = save_pre_dir + "/" + str(config.beta) + str(config.gamma) + "_" + str(config.pre_batch) + "_" + str(config.pre_epochs) + "_" + str(config.pre_lr) + "_" + str(config.IR)
    pre_model_path = save_pre_dir + "_" + "balance_best.pt"
    balance_model = build_balance_model(config, samples_per_cls)
    flatness_loss_criteria_dir = save_pre_dir + "_" + "flatness_criteria_best.txt"

    # balance training
    flatness_loss_criteria = None
    if os.path.exists(pre_model_path):
        # A previous run with the same hyper-parameters already produced the
        # best balance model; reuse it and its flatness criteria.
        flatness_loss_criteria = load_flatness_criteria(flatness_loss_criteria_dir)
        balance_model.load_state_dict(torch.load(pre_model_path))
    else:
        balance_loader = build_balance_loader(trainset, samples_per_cls, config)
        balance_optimizer = optim.SGD(balance_model.parameters(), lr=config.pre_lr, momentum=config.momentum, weight_decay=config.weight_decay)
        pretrain_fn = getattr(methods, "train_" + config.pre_method)
        best_pre_rob = 0
        for epoch in range(1, config.pre_epochs + 1):
            adjust_pre_learning_rate(balance_optimizer, epoch, config)
            flatness_loss = pretrain_fn(balance_model, balance_loader, epoch, balance_optimizer, config, samples_per_cls)
            if epoch % config.interval == 0:
                clean, rob, _ = evaluate_interval(balance_model, test_loader, config, wandb, epoch, pretraining = True)
                if best_pre_rob < rob:
                    best_pre_rob = rob
                    torch.save(balance_model.state_dict(), pre_model_path)
                    flatness_loss_criteria = flatness_loss / len(balance_loader.dataset)
                    save_results(flatness_loss_criteria_dir, clean, rob, best_pre_rob, flatness_loss_criteria)

        if flatness_loss_criteria is not None:
            # The flatness criteria belongs to the best checkpoint, and a
            # rerun would load that checkpoint from disk. Restore it here so
            # the first run and later runs use the same balance model.
            balance_model.load_state_dict(torch.load(pre_model_path))
    balance_model.eval()

    clean, rob, _ = evaluate_interval(balance_model, test_loader, config, wandb, 0)
    log_metrics({'pre_clean_acc': clean, 'pre_rob_acc': rob, 'pretrain_epoch': 0})

    # main training
    if config.epochs >= 1 and flatness_loss_criteria is None:
        # Fail with a clear message instead of an unbound-variable error
        # inside the first training step.
        raise RuntimeError(
            "Balance pre-training produced no flatness criteria: no evaluated "
            "epoch improved the robust accuracy (check config.pre_epochs and "
            "config.interval)."
        )

    train_fn = getattr(methods, "train_" + config.method)
    run_prefix = config.wandb_name + "_" + start_time + 'total_epochs_{:d}'.format(config.epochs)
    os.makedirs(save_dir, exist_ok=True)

    best_rob = 0
    best_clean = None
    best_save_fname = None  # stays None if no epoch is ever evaluated
    for epoch in range(1, config.epochs + 1):
        adjust_learning_rate(optimizer, epoch, config)
        model.train()
        train_fn(model, train_loader, epoch, optimizer, config, samples_per_cls, balance_model, flatness_loss_criteria)
        if epoch % config.interval == 0:
            clean, rob, _ = evaluate_interval(model, test_loader, config, wandb, epoch)
            if rob >= best_rob:
                best_clean, best_rob = clean, rob
                best_save_fname = os.path.join(save_dir, run_prefix + '_best.pt')
                # torch.save overwrites any earlier best checkpoint of this run.
                torch.save(model.state_dict(), best_save_fname)

    save_fname = os.path.join(save_dir, run_prefix + '_last.pt')
    torch.save(model.state_dict(), save_fname)
    model.eval()

    head_tail_ratio = 0.1 * config.num_classes
    weak1_acc, head_weak1_acc, tail_weak1_acc, weak2_acc, head_weak2_acc, tail_weak2_acc, fgsm_acc, head_fgsm_acc, tail_fgsm_acc = evaluate_fgsm_head_tail(model, test_loader, ratio = head_tail_ratio, config = config)
    clean_acc, head_clean_acc, tail_clean_acc, pgd_acc, head_pgd_acc, tail_pgd_acc, _clean_accs, _pgd_accs = evaluate_pgd_head_tail(model, test_loader, ratio = head_tail_ratio, config = config)

    # 'tail_clean_acc' previously logged clean_acc, and 'pgd_acc' logged
    # pgd_accs; both now log the matching scalar like the sibling keys.
    log_metrics({'clean_acc': clean_acc, 'head_clean_acc': head_clean_acc, 'tail_clean_acc': tail_clean_acc})
    log_metrics({'fgsm_acc': fgsm_acc, 'head_fgsm_acc': head_fgsm_acc, 'tail_fgsm_acc': tail_fgsm_acc})
    log_metrics({'pgd_acc': pgd_acc, 'head_pgd_acc': head_pgd_acc, 'tail_pgd_acc': tail_pgd_acc})
    log_metrics({'weak1_fgsm_acc': weak1_acc, 'weak1_head_fgsm_acc': head_weak1_acc, 'weak1_tail_fgsm_acc': tail_weak1_acc})
    log_metrics({'weak2_fgsm_acc': weak2_acc, 'weak2_head_fgsm_acc': head_weak2_acc, 'weak2_tail_fgsm_acc': tail_weak2_acc})

    if config.evaluate:
        AA_acc = evaluate_final_aa(model, test_loader)
        # NOTE: clean/rob come from the most recent evaluate_interval call,
        # which is the final epoch only when config.epochs is a multiple of
        # config.interval.
        log_metrics({'last_clean_acc': clean, 'last_robust_acc': rob, 'RESULT_AA': AA_acc})

        if best_save_fname is None:
            # No epoch was evaluated during main training, so there is no
            # "best" checkpoint to reload.
            print("No best checkpoint was saved during training; skipping best-model evaluation.")
            return

        best_model = build_model(config, samples_per_cls)
        best_model.load_state_dict(torch.load(best_save_fname))
        best_model.eval()

        AA_acc = evaluate_final_aa(best_model, test_loader)
        log_metrics({'best_clean_acc': best_clean, 'best_robust_acc': best_rob, 'best_RESULT_AA': AA_acc})



if __name__ == "__main__":
    # Script entry point: parse arguments, build the config, optionally start
    # a Weights & Biases run, then hand over to main().
    args = load_parser()
    path = Path(os.path.realpath(__file__))
    config = load_config(args)

    if not config.nowand:
        if wandb is None:
            # The guarded import at the top of the file sets wandb to None
            # when the package is missing; fail with an explicit message
            # instead of an AttributeError on None.
            raise ImportError(
                "wandb is not installed but logging is enabled "
                "(config.nowand is false); install wandb or disable logging."
            )
        wandb.require("core")
        wandb.init(project="project_name", entity="wandb_entity", config=config, name = config.wandb_name, tags = [config.tag])
        args.wandb_url = wandb.run.get_url()

    main(config, wandb)