from __future__ import print_function

import time
import os
import csv
import shutil
import time
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn. functional as F
import torch.backends.cudnn as cudnn
import argparse


from PIL import Image
from torchvision import models, transforms
from torchvision.utils import save_image
from torchvision import datasets, transforms
import torch.utils.data as data

from attacks.pgd import PGD_Linf, PGD_L2
from utils import Bar, Logger, AverageMeter, accuracy, mkdir_p
#from tqdm.notebook import tqdm
from autoattack import AutoAttack

import models.wideresnet_silu as wrn_models
#import models.resnet_silu as res_models
import models.resnet as res_models
import models.preact_resnet_silu as pre_res_models
import load_data.datasets as dataset

from sklearn.metrics import confusion_matrix

parser = argparse.ArgumentParser(description='Test the robustness to adversarial attack')

########################## Basic settings ##########################
parser.add_argument('--seed', type=int, default=0, help='manual seed')
parser.add_argument('--gpu', default=0, type=int, help='CUDA device index passed to torch.cuda.set_device')
parser.add_argument('--batch_size', default=128, type=int, metavar='N', help='test batchsize')

######################### Dataset #############################
parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100', 'fmnist', 'svhn'], help='benchmark dataset')
parser.add_argument('--data_dir', default='/home/ydy0415/data/datasets', help='Directory of dataset')
# None means "infer from --dataset" (see below). The per-class confusion matrix
# in validate_disparity is sized by this value, so a wrong count silently drops
# classes from the statistics.
parser.add_argument('--num_classes', default=None, type=int, help='the number of classes (default: 100 for cifar100, otherwise 10)')

######################### Robust Evaluation Setting #############################
parser.add_argument('--attack_method', metavar='METHOD', default='both', choices=['autoattack', 'pgd_linf', 'both', 'pgd_l2', 'fgsm'], help=' attack method')
parser.add_argument('--eps', type=float, default=8, help='maximum of perturbation magnitude')
parser.add_argument('--test_numsteps', type=int, default=20, help='test PGD number of steps')
# store_false: args.random_start is True unless this flag is passed.
parser.add_argument('--random_start', action='store_false', help='disable PGD random start (random start is on by default)')
parser.add_argument('--swa', action='store_true', help='swa usage flag (default: off)')
parser.add_argument('--ema', action='store_true', help='ema usage flag (default: off)')
parser.add_argument('--bn_mode', metavar='BN', default='eval', choices=['eval', 'train'], help='batch normalization mode of attack')

########################## Model Setting ##########################
# Restricted to the architectures main() knows how to build.
parser.add_argument('--model', type=str, default='wideresnet', choices=['wideresnet', 'resnet18', 'pre-resnet18'], help='architecture of model')
parser.add_argument('--depth', type=int, default=28, help='wideresnet depth factor')
parser.add_argument('--widen_factor', type=int, default=5, help='wideresnet widen factor')
parser.add_argument('--activation', type=str, default='ReLU', choices=['ReLU', 'LeakyReLU', 'SiLU'], help='choice of activation')
parser.add_argument('--model_dir', default='/home/ydy0415/data/experiments/decouple_robust/cifar10/230913/pgd_linf_wrn-28-5_loss-trades_perturbloss-kl_eps-10.0_lrsche-MultiStep_lamb-6.0_smooth-0.0_seed-0', help='Directory of model saved')

########################## Misc ##########################
parser.add_argument('--add_name', default='', type=str, help='add_name')


args = parser.parse_args()

# Infer the class count from the dataset unless it was given explicitly; this
# matches the num_classes used when the model itself is constructed.
if args.num_classes is None:
    args.num_classes = 100 if args.dataset == 'cifar100' else 10
print(args)

state = dict(args._get_kwargs())

# Use CUDA. Only select a GPU when CUDA exists; `device` already falls back to CPU.
if torch.cuda.is_available():
    torch.cuda.set_device(args.gpu)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Number of input image channels per dataset (RGB vs. grayscale).
if args.dataset in ['cifar10', 'cifar100', 'svhn']:
    input_channel = 3
elif args.dataset in ['fmnist']:
    input_channel = 1

# Evaluation entry point: load data and model, run the attacks, write the CSV report.
def main():
    """Evaluate a saved model before and after L2-normalising its final ``fc`` rows.

    For each of the two settings ("No" normalisation, then "Norm"), the clean
    accuracy and the accuracy under the attacks selected by ``--attack_method``
    are measured with ``validate_disparity``. Per-class disparity statistics
    are written to ``<model_dir>/eps-<eps>_eval.csv``. The file is overwritten
    with a header first and then appended to after each setting completes.
    """
    print('==> Preparing ' + str(args.dataset))

    _, _, _, test_loader = dataset.load_data(args.data_dir, args.dataset,
                                             batch_size=args.batch_size,
                                             batch_size_test=args.batch_size,
                                             num_workers=0, use_augmentation=True,
                                             shuffle_train=True, validation=False)

    model = _load_model()
    criterion = nn.CrossEntropyLoss()

    # Build the attacks up front so an unsupported --attack_method fails
    # before any (long) evaluation has run.
    pgd_attack, aa_attack = _build_attacks(model)

    resname = args.model_dir + f'/eps-{args.eps}_eval.csv'
    with open(resname, 'w', newline='') as logfile:
        csv.writer(logfile, delimiter=',').writerow(
            ['Normalization', 'Attack', 'Rob-Acc', 'Worst Rob-Acc', 'Std of Rob-Acc',
             'Max-Min of Rob-Acc', 'Std of Ws', 'Corr(W, Rob_Acc)'])

    cudnn.benchmark = True

    print("==> Starting test for " + str(args.attack_method))
    plain_results = _evaluate(test_loader, model, criterion, "Plain", pgd_attack, aa_attack)
    # Persist the un-normalised rows before the normalised run starts.
    with open(resname, 'a', newline='') as logfile:
        _write_rows(csv.writer(logfile, delimiter=','), 'No', model, plain_results)

    _normalize_fc(model)
    # Rebuild the attacks with the SAME configuration as above. Previously the
    # normalised run always used PGD_Linf (even for pgd_l2) and a different
    # AutoAttack subset, so the identically labelled 'No'/'Norm' rows were not
    # comparable.
    pgd_attack, aa_attack = _build_attacks(model)
    norm_results = _evaluate(test_loader, model, criterion, "Norm", pgd_attack, aa_attack)
    with open(resname, 'a', newline='') as logfile:
        _write_rows(csv.writer(logfile, delimiter=','), 'Norm', model, norm_results)


def _load_model():
    """Build the network selected by ``--model`` and load ``<model_dir>/last.pth.tar``.

    Weights are taken from ``checkpoint['swa_state_dict']`` when ``--swa`` is
    set, else ``checkpoint['ema_state_dict']`` when ``--ema`` is set, else
    ``checkpoint['state_dict']``.

    Raises:
        ValueError: if ``--model`` names an architecture not handled here.
    """
    num_classes = 100 if args.dataset == 'cifar100' else 10

    if args.model == 'wideresnet':
        print("==> creating WideResNet" + str(args.depth) + '-' + str(args.widen_factor))
        model = wrn_models.WideResNet(num_classes=num_classes, depth=args.depth,
                                      widen_factor=args.widen_factor,
                                      activation=args.activation)
    elif args.model == 'resnet18':
        print("==> creating ResNet18")
        model = res_models.resnet('resnet18', input_channel, num_classes=num_classes)
    elif args.model == 'pre-resnet18':
        print("==> creating Pre-ResNet18")
        # Was hard-coded to 10 classes, which broke CIFAR-100 checkpoints.
        model = pre_res_models.preact_resnet('preact-resnet18', input_channel,
                                             num_classes=num_classes)
    else:
        raise ValueError("Unsupported --model '%s' (expected 'wideresnet', "
                         "'resnet18' or 'pre-resnet18')" % args.model)
    model = model.cuda(args.gpu)

    checkpoint = torch.load(args.model_dir + '/last.pth.tar',
                            map_location='cuda:' + str(args.gpu))
    if args.swa:
        key = 'swa_state_dict'
    elif args.ema:
        key = 'ema_state_dict'
    else:
        key = 'state_dict'
    model.load_state_dict(checkpoint[key])

    # The checkpoint was mapped onto the GPU; release it before evaluation.
    del checkpoint
    torch.cuda.empty_cache()
    return model


def _build_attacks(model):
    """Create the ``(pgd_attack, aa_attack)`` pair requested by ``--attack_method``.

    Either element is None when that attack is not part of the method:
    'pgd_l2' / 'pgd_linf' give only a PGD attack, 'autoattack' only an
    AutoAttack, and 'both' gives PGD-Linf plus AutoAttack.

    Raises:
        NotImplementedError: for a method accepted by the parser but not built
            here (currently 'fgsm').
    """
    method = args.attack_method
    epsilon = args.eps / 255
    pgd_attack = None
    aa_attack = None

    if method == 'pgd_l2':
        pgd_attack = PGD_L2(model=model, epsilon=epsilon, step_size=(args.eps / 10) / 255,
                            num_steps=args.test_numsteps, random_start=args.random_start,
                            train=False)
    elif method in ('pgd_linf', 'both'):
        pgd_attack = PGD_Linf(model=model, epsilon=epsilon, step_size=(args.eps / 4) / 255,
                              num_steps=args.test_numsteps, random_start=args.random_start,
                              criterion='ce', bn_mode=args.bn_mode, train=False)

    # Available AutoAttack components: 'apgd-ce', 'apgd-t', 'fab', 'square'.
    if method == 'autoattack':
        aa_attack = AutoAttack(model, norm='Linf', eps=epsilon, version='standard', verbose=False)
        aa_attack.attacks_to_run = ['square']
    elif method == 'both':
        aa_attack = AutoAttack(model, norm='Linf', eps=epsilon, version='standard', verbose=False)
        aa_attack.attacks_to_run = ['apgd-ce', 'apgd-t']

    if pgd_attack is None and aa_attack is None:
        raise NotImplementedError("attack method '%s' is not implemented" % method)
    return pgd_attack, aa_attack


def _evaluate(test_loader, model, criterion, mode, pgd_attack, aa_attack):
    """Run the clean pass, then each configured attack, via ``validate_disparity``.

    Returns:
        A list of ``(attack_label, top1_acc, per_class_acc)`` tuples, in the
        order Clean, PGD (if configured), AA (if configured).
    """
    results = []
    _, acc, cls_acc = validate_disparity(args, test_loader, model, criterion, mode,
                                         pgd_attack=None, aa_attack=None)
    results.append(('Clean', acc, cls_acc))
    if pgd_attack is not None:
        _, acc, cls_acc = validate_disparity(args, test_loader, model, criterion, mode,
                                             pgd_attack=pgd_attack, aa_attack=None)
        results.append(('PGD', acc, cls_acc))
    if aa_attack is not None:
        _, acc, cls_acc = validate_disparity(args, test_loader, model, criterion, mode,
                                             pgd_attack=None, aa_attack=aa_attack)
        results.append(('AA', acc, cls_acc))
    return results


def _round4(value):
    """Return ``value`` (numpy/torch/Python scalar) as a float rounded to 4 places."""
    return round(float(value), 4)


def _write_rows(logwriter, norm_label, model, results):
    """Append one CSV row per ``(attack_label, acc, cls_acc)`` entry of ``results``.

    Columns match the header written in ``main``. The weight statistics are
    computed from the rows of ``model.fc.weight`` (one row per output unit).
    """
    weights = model.fc.weight.detach().cpu()
    sq_mean = weights.square().mean(dim=1)      # mean squared weight per row
    rms = sq_mean.sqrt()
    # Spread of per-row RMS relative to the smallest row. `.item()` so the CSV
    # gets a number instead of the string "tensor(...)".
    ws_std = _round4((rms / rms.min()).std().item())
    sq_mean_np = sq_mean.numpy()

    for attack_label, acc, cls_acc in results:
        corr = np.corrcoef(sq_mean_np, cls_acc)[0, 1]
        logwriter.writerow([norm_label, attack_label, acc,
                            _round4(cls_acc.min()),
                            _round4(cls_acc.std()),
                            _round4(cls_acc.max() - cls_acc.min()),
                            ws_std,
                            _round4(corr)])


def _normalize_fc(model):
    """Replace ``model.fc`` with a Linear layer whose weight rows have unit L2 norm.

    The bias (if any) is copied unchanged, and the model is moved back onto
    ``args.gpu``.
    """
    fc = model.fc
    weights = fc.weight.data
    row_norms = torch.norm(weights, p=2, dim=1).unsqueeze(-1)

    has_bias = fc.bias is not None
    new_fc = nn.Linear(in_features=fc.in_features, out_features=fc.out_features, bias=has_bias)
    new_fc.weight.data = weights / row_norms
    if has_bias:
        new_fc.bias.data = fc.bias.data

    model.fc = new_fc
    model.cuda(args.gpu)

def validate_disparity(args, val_loader, model, criterion, mode, pgd_attack=None, aa_attack=None):
    """Evaluate ``model`` on ``val_loader`` and report overall and per-class accuracy.

    Each batch is evaluated either clean, or after perturbation by
    ``pgd_attack`` when given, or otherwise by ``aa_attack`` when given. PGD
    takes precedence if both are passed.

    Args:
        args: parsed CLI namespace. Reads ``num_classes`` and ``batch_size``.
        val_loader: iterable of ``(inputs, targets)`` batches.
        model: network to evaluate. It is switched to eval mode here.
        criterion: loss function applied to ``(outputs, targets)``.
        mode: label shown on the progress bar.
        pgd_attack: object with ``perturb(inputs, targets) -> (adv, _)``, or None.
        aa_attack: AutoAttack instance, or None.

    Returns:
        ``(average loss, average top-1 accuracy, per-class accuracy in percent)``.
        The per-class array has length ``args.num_classes``.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()

    bar = Bar('{mode:>12}'.format(mode=mode), max=len(val_loader))

    con_mat = np.zeros((args.num_classes, args.num_classes))

    for batch_idx, (inputs, targets) in enumerate(val_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        inputs, targets = inputs.to(device), targets.to(device)

        # Crafting adversarial examples needs gradients, so it stays outside
        # the no_grad context below.
        if pgd_attack is not None:
            eval_inputs, _ = pgd_attack.perturb(inputs, targets)
        elif aa_attack is not None:
            eval_inputs = aa_attack.run_standard_evaluation(inputs, targets, bs=args.batch_size)
        else:
            eval_inputs = inputs

        # Pure evaluation: don't build an autograd graph for the forward pass.
        with torch.no_grad():
            outputs = model(eval_inputs)
            loss = criterion(outputs, targets)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(outputs, targets, topk=(1, 5))

        con_mat += confusion_matrix(targets.cpu(), outputs.argmax(dim=1).cpu(), labels=range(args.num_classes))

        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # plot progress
        bar.suffix  = '({batch:>3}/{size:>3}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format(
                        batch=batch_idx + 1,
                        size=len(val_loader),
                        data=data_time.avg,
                        bt=batch_time.avg,
                        total=bar.elapsed_td,
                        eta=bar.eta_td,
                        loss=losses.avg,
                        top1=top1.avg,
                        top5=top5.avg,
                        )
        bar.next()
    bar.finish()

    # Row i of the confusion matrix holds true-class-i samples, so
    # diagonal / row-sum is the per-class recall (accuracy), in percent.
    cls_acc = np.diagonal(con_mat)/con_mat.sum(axis=1) * 100

    if pgd_attack is not None:
        print("Rob-Acc(PGD) : %.2f, Worst Rob-Acc(PGD) : %.2f, Std of Rob-Acc(PGD) : %.2f, Max-Min of Rob-Acc(PGD) : %.2f"
            %(cls_acc.mean(), cls_acc.min(), cls_acc.std(), cls_acc.max()-cls_acc.min())
            )
        print("Corr(W, rob_acc) : %.4f"
            %(np.corrcoef(model.fc.weight.detach().cpu().square().mean(dim=1), torch.tensor(cls_acc))[0,1].round(4))
            )

    elif aa_attack is not None:
        print("Rob-Acc(AA) : %.2f, Worst Rob-Acc(AA) : %.2f, Std of Rob-Acc(AA) : %.2f, Max-Min of Rob-Acc(AA) : %.2f"
            %(cls_acc.mean(), cls_acc.min(), cls_acc.std(), cls_acc.max()-cls_acc.min())
            )

    else:
        print("Clean-Acc : %.2f, Worst Clean-Acc : %.2f, Std of Clean-Acc : %.2f, Max-Min of Clean-Acc : %.2f"
            %(cls_acc.mean(), cls_acc.min(), cls_acc.std(), cls_acc.max()-cls_acc.min())
            )

    return (losses.avg, top1.avg, cls_acc)


def save_checkpoint(out_dir, epoch, filename='checkpoint.pth.tar', **kwargs):
    """Serialise ``{'epoch': epoch, **kwargs}`` to ``out_dir/filename`` with torch.save.

    Any extra keyword arguments (state dicts, metrics, ...) become additional
    keys of the saved dictionary.
    """
    payload = {'epoch': epoch, **kwargs}
    destination = os.path.join(out_dir, filename)
    torch.save(payload, destination)

    print("==> saving best model")

if __name__ == "__main__":
    # Run the evaluation only when this file is executed as a script.
    main()