from __future__ import print_function

import matplotlib
matplotlib.use('Agg')

import argparse
import os
import time
import warnings
from datetime import date

import numpy as np
import contextlib

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from autoattack import AutoAttack
from torch.cuda import amp

import models.wideresnet as wrn_models



import load_data.datasets as dataset
from utils import Bar, Logger, AverageMeter, accuracy, mkdir_p
from utils.swa import moving_average, bn_update
from loss import *
from attacks.pgd import PGD_Linf, PGD_L2, GA_PGD
from sklearn.metrics import confusion_matrix

# Device that the train/validation loops move each input batch and its targets to:
# CUDA when it is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _head_row_rms(model):
    """Return the root-mean-square of every row of ``model.fc.weight`` as a CPU tensor."""
    return model.fc.weight.detach().cpu().square().mean(dim=1).sqrt()


def _head_rms_spread(model):
    """Return the std of the per-row head RMS values after dividing by the smallest one."""
    row_rms = _head_row_rms(model)
    return (row_rms / row_rms.min()).std()


def _disparity_row(epoch, model, attack_acc, cls_acc):
    """Assemble one row for the disparity log.

    The columns are, in order: epoch, overall attacked accuracy, min / std /
    (max - min) of ``cls_acc``, the spread of the classifier-head row RMS
    values, and the correlation between the per-row mean squared head weight
    and ``cls_acc`` (rounded to 4 decimals).
    """
    row_mean_sq = model.fc.weight.detach().cpu().square().mean(dim=1)
    return [epoch,
            attack_acc,
            cls_acc.min(),
            cls_acc.std(),
            cls_acc.max() - cls_acc.min(),
            _head_rms_spread(model),
            np.corrcoef(row_mean_sq, torch.tensor(cls_acc))[0, 1].round(4)]


def main(args: argparse.Namespace):
    """Train (or fine-tune) a WideResNet adversarially, then evaluate it.

    Steps, in order:
      1. seed the RNGs (when ``args.seed`` is set) and select the GPU;
      2. extend ``args.save_dir`` (mutated in place) with a run-specific name
         and create the directory;
      3. load the data, build the model (plus an SWA copy when ``args.swa``),
         optionally load a checkpoint for fine-tuning, and build the attacks,
         optimizer and LR scheduler;
      4. run the epoch loop: train, evaluate on selected epochs, log, and save
         ``robust_best.pth.tar`` / ``last.pth.tar``;
      5. evaluate the final model (clean, PGD, AutoAttack), replace the head by
         one with L2-normalised rows, save it as ``final_normalized.pth.tar``
         and evaluate again.

    Args:
        args: parsed command-line options; see the attribute reads below for
            the fields that are required.

    Raises:
        ValueError: if ``args.model`` or ``args.train_attack`` is not one of
            the supported values.
    """
    print(args)

    if args.seed is not None:
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    torch.cuda.set_device(args.gpu)
    # NOTE(review): benchmark mode stays on even for seeded runs; confirm this
    # is acceptable for the level of reproducibility you need.
    cudnn.benchmark = True

    # ------------------------------------------------------------------
    # Run directory: <save_dir>/<dataset>/<yymmdd>/<attack>_wrn-..._seed-N[_name]
    # ------------------------------------------------------------------
    args.save_dir += f'/{args.dataset}/'
    args.save_dir += str(date.today().strftime('%Y%m%d')[2:])

    args.save_dir += f'/{args.train_attack}'
    args.save_dir += f'_wrn-{args.depth}-{args.widen_factor}'

    args.save_dir += f'_loss-{args.loss}_perturbloss-{args.perturb_loss}_eps-{args.eps}_lrsche-{args.lr_scheduler}'

    if args.loss in ['trades', 'mart', 'arow', 'cow']:
        args.save_dir += f'_lamb-{args.lamb}'

    if args.loss in ['pgd', 'trades', 'arow', 'cow']:
        args.save_dir += f'_smooth-{args.smooth}'

    if args.orthogonal:
        args.save_dir += f'_gamma-{args.gamma}'
        args.save_dir += f'_beta-{args.beta}'
    elif args.var:
        args.save_dir += f'_var-{args.var}'
        args.save_dir += f'_gamma-{args.gamma}'

    if args.swa:
        args.save_dir += f'_swa-{args.swa}'

    if args.finetune:
        args.save_dir += f'_fine-{args.finetune}'

    args.save_dir += f'_seed-{args.seed}'

    if args.add_name != '':
        args.save_dir += f'_{args.add_name}'

    if not os.path.isdir(args.save_dir):
        mkdir_p(args.save_dir)

    _, _, train_loader, test_loader = dataset.load_data(args.data_dir,
                                                        args.dataset,
                                                        batch_size=args.batch_size,
                                                        batch_size_test=100,
                                                        num_workers=2,
                                                        use_augmentation=True,
                                                        shuffle_train=True,
                                                        validation=False)

    # ------------------------------------------------------------------
    # Model
    # ------------------------------------------------------------------
    if args.model != 'wideresnet':
        # Previously this fell through and crashed later on a ``None`` model.
        raise ValueError(f"unsupported model: {args.model!r} (only 'wideresnet' is handled)")

    print("==> creating WideResNet" + str(args.depth) + '-' + str(args.widen_factor))

    def build_wrn():
        return wrn_models.WideResNet(num_classes=args.num_classes,
                                     depth=args.depth,
                                     widen_factor=args.widen_factor,
                                     activation=args.activation).cuda(args.gpu)

    if args.swa:
        # Same construction order as before: the SWA copy first, then the model.
        swa_model = build_wrn()
        model = build_wrn()
        swa_n = 0
    else:
        model = build_wrn()

    print('    Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0))

    if args.finetune:
        # NOTE(review): the checkpoint root and date folder are hard-coded.
        checkpoint = torch.load('/home/ydy0415/data/experiments/decouple_robust' + \
                                f'/{args.dataset}/230925/{args.load_dir}' + \
                                '/last.pth.tar')
        model.load_state_dict(checkpoint['state_dict'])
        del checkpoint

    criterion = nn.CrossEntropyLoss()

    # ------------------------------------------------------------------
    # Attacks
    # ------------------------------------------------------------------
    if args.train_attack == 'pgd_linf':
        train_attack = PGD_Linf(model=model, epsilon=args.eps/255, step_size=(args.eps/4)/255, num_steps=args.train_numsteps, random_start=args.random_start,
                                criterion=args.perturb_loss, bn_mode = args.bn_mode, train = True)
    elif args.train_attack == 'gapgd_linf':
        train_attack = GA_PGD(model=model, epsilon=args.eps/255, step_size=(args.eps/4)/255, num_steps=args.train_numsteps, random_start=args.random_start,
                                criterion=args.perturb_loss, bn_mode = args.bn_mode, train = True)
    else:
        # Fail here instead of with an UnboundLocalError inside the epoch loop.
        raise ValueError(f"unsupported train_attack: {args.train_attack!r} "
                         "(expected 'pgd_linf' or 'gapgd_linf')")

    test_attack = PGD_Linf(model=model, epsilon=args.eps/255, step_size=(args.eps/4)/255, num_steps=args.test_numsteps, random_start=args.random_start, criterion='ce', bn_mode = args.bn_mode, train = False)

    # ------------------------------------------------------------------
    # Optimizer and LR schedule
    # ------------------------------------------------------------------
    if args.finetune:
        # Only the classifier head is optimised when fine-tuning.
        head_params = model.fc.parameters()
        base_optimizer = optim.SGD
        optimizer = SAM(head_params, base_optimizer, lr=args.lr, momentum=0.9, weight_decay=args.wd, rho=0.01)
    else:
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.wd)

    # Any other ``args.lr_scheduler`` value keeps the learning rate constant.
    scheduler = None
    if args.lr_scheduler == "MultiStep":
        if args.swa:
            # Integer milestones: a fractional milestone (odd epoch counts)
            # would never equal an epoch index and so would never trigger.
            milestones = [args.epochs // 2, 3 * args.epochs // 4]
        else:
            milestones = [90, 95]
        scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)
    elif args.lr_scheduler == "Cosine":
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

    logger = Logger(os.path.join(args.save_dir, 'log.txt'), title=args.dataset)
    logger.set_names(['Epoch', 'Train Loss', 'Test Loss', 'Test Acc.', 'Attack Loss' , 'Attack Acc.'])

    logger_disparity = Logger(os.path.join(args.save_dir, 'log_disparity.txt'), title=args.dataset)
    logger_disparity.set_names(['Epoch', 'Rob-Acc(PGD)', 'Worst Rob-Acc(PGD)', 'Std of Rob-Acc(PGD)', 'Max-Min of Rob-Acc(PGD)', 'Std of Ws', 'Corr(W, Rob_Acc)'])

    attack_best_acc, best_acc, tolerance = 0, 0, 0
    # Defined up front so that an epoch without evaluation (these are only
    # refreshed on selected epochs) cannot hit an unbound name below.
    test_loss = test_acc = attack_test_loss = attack_test_acc = 0.0

    # ------------------------------------------------------------------
    # Train and val
    # ------------------------------------------------------------------
    for epoch in range(args.start_epoch, args.epochs + 1):

        print('\n'+args.train_attack +' Epoch: [%d | %d] LR: %.5f Tol: %d Best ts acc: %.2f Best_att_acc: %.2f ' % (epoch, args.epochs, optimizer.param_groups[0]['lr'], tolerance, best_acc, attack_best_acc))

        if args.finetune and epoch == 1:
            # Baseline numbers of the loaded checkpoint before any fine-tuning.
            test_loss, test_acc, cls_acc = validate_disparity(args, test_loader, model, criterion, mode='Standard')
            attack_test_loss, attack_test_acc, cls_acc = validate_disparity(args, test_loader, model, criterion, mode='Attack_test', pgd_attack=test_attack)
            logger_disparity.append(_disparity_row(epoch, model, attack_test_acc, cls_acc))

            print("Std of Ws : %.6f" % (_head_rms_spread(model)))

            auto_attack = AutoAttack(model, norm='Linf', eps=8/255, version='standard', verbose=False)
            auto_attack.attacks_to_run = ['apgd-ce', 'apgd-t']
            attack_test_loss, attack_test_acc, cls_acc = validate_disparity(args,
                                                                test_loader,
                                                                model,
                                                                criterion,
                                                                "Final",
                                                                pgd_attack=None,
                                                                aa_attack=auto_attack)
            logger_disparity.append(_disparity_row(epoch, model, attack_test_acc, cls_acc))

        if args.finetune:
            # finetune_DeSAM builds its own optimizer at the start of each call,
            # so the outer scheduler is deliberately not stepped on this path.
            train_loss = finetune_DeSAM(args, train_loader, model, optimizer, attack=train_attack)
        else:
            train_loss = train(args, train_loader, model, optimizer, attack=train_attack)
            # Advance the LR schedule once per epoch; without this call the
            # scheduler built above never changed the learning rate.
            if scheduler is not None:
                scheduler.step()

        on_swa_cycle = bool(args.swa
                            and epoch >= args.swa_start
                            and (epoch - args.swa_start) % args.swa_c_epochs == 0)

        if args.swa and epoch == args.swa_start:
            # From the first averaged epoch on, the test attack targets the SWA model.
            test_attack =  PGD_Linf(model=swa_model, epsilon=args.eps/255, step_size=(args.eps/4)/255, num_steps=args.test_numsteps, random_start=args.random_start, criterion='ce', bn_mode = args.bn_mode, train = False)

        if on_swa_cycle:
            moving_average(swa_model, model, 1.0 / (swa_n + 1))
            swa_n += 1
            if epoch >= 70:
                bn_update(train_loader, swa_model)
                test_loss, test_acc = validate(test_loader, swa_model, criterion, mode='Test')
                attack_test_loss, attack_test_acc = validate(test_loader, swa_model, criterion, mode='Attack_test', attack=test_attack)
                logger.append([round(epoch), train_loss, test_loss, test_acc, attack_test_loss,  attack_test_acc])
        else:
            if epoch == 1 or epoch % 2 == 0 or epoch >= 90 or epoch == args.epochs:
                test_loss, test_acc = validate(test_loader, model, criterion, mode='Test')
                attack_test_loss, attack_test_acc, cls_acc = validate_disparity(args, test_loader, model, criterion, mode='Attack_test', pgd_attack=test_attack)
                logger.append([round(epoch), train_loss, test_loss, test_acc, attack_test_loss,  attack_test_acc])
                logger_disparity.append(_disparity_row(epoch, model, attack_test_acc, cls_acc))

                print("Std of Ws : %.6f" % (_head_rms_spread(model)))

        # save model
        is_attack_best = attack_test_acc > attack_best_acc
        if is_attack_best:
            attack_best_acc = attack_test_acc
            best_acc = test_acc

        if args.swa:
            if on_swa_cycle and is_attack_best:
                save_checkpoint(args.save_dir, epoch,
                                filename='robust_best.pth.tar',
                                swa_state_dict=swa_model.state_dict(),
                                swa_n=swa_n,
                                state_dict=model.state_dict(),
                                test_acc=test_acc,
                                attack_test_acc=attack_test_acc,
                                optimizer=optimizer.state_dict())
            elif epoch < args.swa_start and is_attack_best:
                save_checkpoint(args.save_dir, epoch,
                                filename='robust_best.pth.tar',
                                state_dict=model.state_dict(),
                                test_acc=test_acc,
                                attack_test_acc=attack_test_acc,
                                optimizer=optimizer.state_dict())

            # Independent of the "best" branches (as in the non-SWA path) so the
            # last checkpoint is still written when the final epoch is the best.
            if epoch == args.epochs:
                save_checkpoint(args.save_dir, epoch,
                                filename='last.pth.tar',
                                swa_state_dict=swa_model.state_dict(),
                                test_acc=test_acc,
                                attack_test_acc=attack_test_acc)
        else:
            if is_attack_best:
                save_checkpoint(args.save_dir, epoch,
                                filename='robust_best.pth.tar',
                                state_dict=model.state_dict(),
                                test_acc=test_acc,
                                attack_test_acc=attack_test_acc,
                                optimizer=optimizer.state_dict())

            if epoch == args.epochs:
                save_checkpoint(args.save_dir, epoch,
                                filename='last.pth.tar',
                                state_dict=model.state_dict(),
                                test_acc=test_acc,
                                attack_test_acc=attack_test_acc)

        if is_attack_best:
            tolerance = 0
        else:
            tolerance += 1

    logger.close()

    # ------------------------------------------------------------------
    # Final evaluation of the trained model: clean, PGD-20, AutoAttack.
    # ------------------------------------------------------------------
    test_attack = PGD_Linf(model=model,
                           epsilon=8/255,
                           step_size=(8/4)/255,
                           num_steps=20,
                           random_start=args.random_start,
                           criterion='ce',
                           bn_mode = args.bn_mode,
                           train = False)

    auto_attack = AutoAttack(model, norm='Linf', eps=8/255, version='standard', verbose=False)
    auto_attack.attacks_to_run = ['apgd-ce', 'apgd-t']

    validate_disparity(args, test_loader, model, criterion, "Final", pgd_attack=None, aa_attack=None)
    validate_disparity(args, test_loader, model, criterion, "Final", pgd_attack=test_attack, aa_attack=None)
    validate_disparity(args, test_loader, model, criterion, "Final", pgd_attack=None, aa_attack=auto_attack)

    # ------------------------------------------------------------------
    # Replace the head by one whose weight rows are divided by their L2 norm
    # (bias kept as is), save it, and evaluate again.
    # ------------------------------------------------------------------
    fc_weights = model.fc.weight.data
    normalized_weights = fc_weights / torch.norm(fc_weights, p=2, dim=1).unsqueeze(-1)

    # Create the new fully connected layer.
    new_fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=model.fc.out_features, bias=True)
    new_fc.weight.data = normalized_weights
    new_fc.bias.data = model.fc.bias.data

    # Install the new fully connected layer in the existing model.
    model.fc = new_fc
    model.cuda()

    save_checkpoint(args.save_dir, epoch,
                    filename='final_normalized.pth.tar',
                    state_dict=model.state_dict())

    test_attack = PGD_Linf(model=model,
                           epsilon=8/255,
                           step_size=(8/4)/255,
                           num_steps=20,
                           random_start=args.random_start,
                           criterion='ce',
                           bn_mode = args.bn_mode,
                           train = False)

    auto_attack = AutoAttack(model, norm='Linf', eps=8/255, version='standard', verbose=False)
    auto_attack.attacks_to_run = ['apgd-ce', 'apgd-t']

    validate_disparity(args, test_loader, model, criterion, "Final_Norm", pgd_attack=None, aa_attack=None)
    validate_disparity(args, test_loader, model, criterion, "Final_Norm", pgd_attack=test_attack, aa_attack=None)
    attack_test_loss, attack_test_acc, cls_acc = validate_disparity(args,
                                                                    test_loader,
                                                                    model,
                                                                    criterion,
                                                                    "Final_Norm",
                                                                    pgd_attack=None,
                                                                    aa_attack=auto_attack)

    logger_disparity.append(_disparity_row(epoch, model, attack_test_acc, cls_acc))
    # The disparity log receives its last row above, so it is closed only now.
    logger_disparity.close()


def train(args, train_loader, model, optimizer, attack):
    """Run one adversarial-training epoch and return the average total loss.

    For every batch, adversarial inputs are produced with ``attack.perturb``
    and the objective selected by ``args.loss`` ("arow", "cow", "mart",
    "trades" or "pgd") is evaluated. Optional head regularizers
    (``args.orthogonal`` / ``args.var``) are added with weight ``args.gamma``.
    When ``args.amp`` is truthy the attack and loss run under autocast and the
    update goes through a gradient scaler.

    Args:
        args: run options (reads ``loss``, ``perturb_loss``, ``lamb``,
            ``smooth``, ``orthogonal``, ``beta``, ``var``, ``gamma``, ``amp``).
        train_loader: iterable of ``(inputs, targets)`` batches.
        model: network being trained; switched to train mode here.
        optimizer: optimizer stepped once per batch.
        attack: object exposing ``perturb(inputs, targets)`` whose first
            return value is the adversarial batch.

    Returns:
        The running average of the total loss over the epoch.

    Raises:
        ValueError: if ``args.loss`` is unknown, or ``args.perturb_loss`` is
            incompatible with the chosen loss.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    sup_losses = AverageMeter()
    reg_losses = AverageMeter()
    losses = AverageMeter()

    ce_loss = nn.CrossEntropyLoss()
    end = time.time()

    # NOTE(review): the scaler is rebuilt on every call, so its scale state
    # does not carry over from one epoch to the next.
    scaler = amp.GradScaler()
    amp_cm = amp.autocast if args.amp else contextlib.nullcontext

    bar = Bar('{:>12}'.format('Training'), max=len(train_loader))

    model.train()
    for batch_idx, (inputs, targets) in enumerate(train_loader):

        data_time.update(time.time() - end)

        inputs, targets = inputs.to(device), targets.to(device)

        with amp_cm():

            if args.loss == "arow":
                adv_inputs, _      = attack.perturb(inputs, targets)
                sup_loss, reg_loss = ARoW_loss(inputs, adv_inputs, targets, model, args.smooth)
                reg_loss           = args.lamb * reg_loss
                loss               = sup_loss + reg_loss

            elif args.loss == "cow":
                if args.perturb_loss not in ("kl", "revkl", "js"):
                    raise ValueError("perturb loss must be kl or revkl divergence.")
                adv_inputs, _      = attack.perturb(inputs, targets)
                sup_loss, reg_loss = CoW_loss(inputs, adv_inputs, targets, model, args.smooth)
                reg_loss           = args.lamb * reg_loss
                loss               = sup_loss + reg_loss

            elif args.loss == "mart":
                if args.perturb_loss != "ce":
                    raise ValueError("perturb loss must be ce.")
                adv_inputs, _      = attack.perturb(inputs, targets)
                sup_loss, reg_loss = MART_loss(inputs, adv_inputs, targets, model)
                reg_loss           = args.lamb * reg_loss
                loss               = sup_loss + reg_loss

            elif args.loss == "trades":
                adv_inputs, _      = attack.perturb(inputs, targets)
                sup_loss, reg_loss = TRADES_loss(inputs, adv_inputs, targets, model, args.smooth)
                reg_loss           = args.lamb * reg_loss
                loss               = sup_loss + reg_loss

            elif args.loss == "pgd":
                if args.perturb_loss != "ce":
                    raise ValueError("perturb loss must be ce.")
                adv_inputs, _ = attack.perturb(inputs, targets)
                adv_outputs   = model(adv_inputs)
                sup_loss      = ce_loss(adv_outputs, targets)
                reg_loss      = torch.tensor(0)
                loss          = sup_loss

            else:
                # Previously an unknown loss surfaced as an UnboundLocalError.
                raise ValueError(f"unknown loss: {args.loss!r}")

        # Out-of-place additions on purpose: for "pgd", ``loss`` is the very
        # same tensor as ``sup_loss``, and an in-place ``+=`` would also change
        # the value logged as the supervised loss.
        if args.orthogonal:
            ortho_loss = Orthogonal_loss(model, args.beta)
            loss = loss + args.gamma * ortho_loss

        if args.var:
            var_loss = Diag_Var_loss(model)
            loss = loss + args.gamma * var_loss

        # record loss
        sup_losses.update(sup_loss.item(), inputs.size(0))
        reg_losses.update(reg_loss.item(), inputs.size(0))
        losses.update(loss.item(), inputs.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        if args.amp:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # plot progress
        bar.suffix  = '({batch:>3}/{size:>3}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Sup_loss: {sup_loss:.4f} | Reg_loss: {reg_loss:.4f} |  Tot loss:{loss:.4f}'.format(
                    batch   = batch_idx + 1,
                    size    = len(train_loader),
                    data    = data_time.avg,
                    bt      = batch_time.avg,
                    total   = bar.elapsed_td,
                    eta     = bar.eta_td,
                    sup_loss=sup_losses.avg,
                    reg_loss=reg_losses.avg,
                    loss=losses.avg
                    )
        bar.next()
    bar.finish()

    return losses.avg


def finetune_DeSAM(args, train_loader, model, optimizer, attack):
    """Fine-tune the classifier head for one epoch with a two-step SAM update.

    On the first batch of every call the head ``model.fc`` is replaced by a new
    ``Linear`` layer whose weight rows are divided by their L2 norm (bias
    kept), and a fresh ``DeCoupledSAM`` optimizer over the head parameters
    replaces the ``optimizer`` argument. Its per-row radii are
    ``args.rho * softmax(-row_norm)``, so rows with a larger original norm get
    a smaller radius. Every batch then does: loss -> backward -> ``first_step``
    -> loss again -> backward -> ``second_step``, with adversarial inputs
    regenerated for each loss evaluation.

    Args:
        args: run options (reads ``rho``, ``lr``, ``wd``, ``loss``, ``lamb``,
            ``smooth``, ``orthogonal``, ``beta``, ``var``, ``gamma``).
        train_loader: iterable of ``(inputs, targets)`` batches.
        model: network whose ``fc`` head is fine-tuned; switched to train mode.
        optimizer: only used if ``train_loader`` yields no batch at all;
            otherwise it is superseded by the optimizer built on the first batch.
        attack: object exposing ``perturb(inputs, targets)`` whose first
            return value is the adversarial batch.

    Returns:
        The running average of the total loss (from the second evaluation of
        each batch) over the epoch.

    Raises:
        ValueError: if ``args.loss`` is not a supported loss name.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    sup_losses = AverageMeter()
    reg_losses = AverageMeter()
    losses = AverageMeter()

    ce_loss = nn.CrossEntropyLoss()
    end = time.time()

    bar = Bar('{:>12}'.format('Training'), max=len(train_loader))

    model.train()
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        data_time.update(time.time() - end)

        if batch_idx == 0:
            fc_weights = model.fc.weight.data
            l2_norms = torch.norm(fc_weights, p=2, dim=1)

            # Weight normalization: every row is divided by its own L2 norm.
            normalized_weights = fc_weights / l2_norms.unsqueeze(-1)

            # Create the new fully connected layer and install it in the model.
            new_fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=model.fc.out_features, bias=True)
            new_fc.weight.data = normalized_weights
            new_fc.bias.data = model.fc.bias.data
            model.fc = new_fc

            model.cuda()

            # exp(-norm): the larger a row's norm, the smaller its value ...
            exp_values = torch.exp(-l2_norms)
            # ... then normalise so that the values sum to 1.
            normalized_values = exp_values / exp_values.sum()

            rhos_head = args.rho * normalized_values

            head_indices = [0]
            # named_parameters() yields (name, parameter) tuples.
            params = model.fc.named_parameters()
            base_optimizer = optim.SGD
            optimizer = DeCoupledSAM(params, base_optimizer, rhos_head, head_indices, lr=args.lr, momentum=0.9, weight_decay=args.wd)

        inputs, targets = inputs.to(device), targets.to(device)

        # The adversarial batch is generated inside ``closure`` only. The extra
        # attack that used to run here was discarded (the closure shadowed its
        # result), costing one full attack per batch.
        # NOTE(review): if ``attack.perturb`` has side effects (RNG draws,
        # BatchNorm statistics), dropping that call shifts them — confirm.

        def closure():
            """Regenerate adversarial inputs and return (total, sup, reg) losses."""
            adv_inputs, _ = attack.perturb(inputs, targets)
            if args.loss == "arow":
                sup_loss, reg_loss = ARoW_loss(inputs, adv_inputs, targets, model, args.smooth)
            elif args.loss == "cow":
                sup_loss, reg_loss = CoW_loss(inputs, adv_inputs, targets, model, args.smooth)
            elif args.loss == "mart":
                sup_loss, reg_loss = MART_loss(inputs, adv_inputs, targets, model)
            elif args.loss == "trades":
                sup_loss, reg_loss = TRADES_loss(inputs, adv_inputs, targets, model, args.smooth)
            elif args.loss == "pgd":
                adv_outputs = model(adv_inputs)
                sup_loss = ce_loss(adv_outputs, targets)
                reg_loss = torch.tensor(0.)  # No regularization for PGD loss
            else:
                # Previously an unknown loss surfaced as an UnboundLocalError.
                raise ValueError(f"unknown loss: {args.loss!r}")

            # Combine losses
            total_loss = sup_loss + args.lamb * reg_loss

            if args.orthogonal:
                ortho_loss = Orthogonal_loss(model, args.beta)
                total_loss += args.gamma * ortho_loss
            if args.var:
                var_loss = Diag_Var_loss(model)
                total_loss += args.gamma * var_loss

            return total_loss, sup_loss, reg_loss

        # First SAM step: gradient at the current weights.
        optimizer.zero_grad()
        total_loss, sup_loss, reg_loss = closure()
        total_loss.backward()
        optimizer.first_step()

        # Second SAM step: loss and gradient after the first step.
        # NOTE(review): gradients from the first backward are not cleared here;
        # unless DeCoupledSAM.first_step() zeroes them itself, the second
        # backward accumulates on top of them — confirm against the optimizer.
        total_loss, sup_loss, reg_loss = closure()
        total_loss.backward()
        optimizer.second_step()

        # Record losses
        sup_losses.update(sup_loss.item(), inputs.size(0))
        reg_losses.update(reg_loss.item(), inputs.size(0))
        losses.update(total_loss.item(), inputs.size(0))

        # Measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # Plot progress
        bar.suffix = '({batch:>3}/{size:>3}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Sup_loss: {sup_loss:.4f} | Reg_loss: {reg_loss:.4f} |  Tot loss:{loss:.4f}'.format(
            batch=batch_idx + 1,
            size=len(train_loader),
            data=data_time.avg,
            bt=batch_time.avg,
            total=bar.elapsed_td,
            eta=bar.eta_td,
            sup_loss=sup_losses.avg,
            reg_loss=reg_losses.avg,
            loss=losses.avg
        )
        bar.next()
    bar.finish()

    return losses.avg





def validate(val_loader, model, criterion, mode, attack=None):
    """Evaluate ``model`` on ``val_loader`` and return ``(avg_loss, avg_top1)``.

    Args:
        val_loader: iterable of ``(inputs, targets)`` batches.
        model: network to evaluate; switched to eval mode here.
        criterion: callable ``criterion(outputs, targets)`` returning the loss.
        mode: label shown at the left of the progress bar.
        attack: optional object exposing ``perturb(inputs, targets)``; when
            given, the model is evaluated on the first value it returns
            instead of the clean inputs.

    Returns:
        Tuple of the running-average loss and running-average top-1 accuracy.
        Top-5 accuracy is tracked for the progress bar only.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()

    bar = Bar('{mode:>12}'.format(mode=mode), max=len(val_loader))

    for batch_idx, (inputs, targets) in enumerate(val_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        inputs, targets = inputs.to(device), targets.to(device)

        # The attack stays outside no_grad because it is run with autograd
        # available, exactly as before.
        if attack is not None:
            eval_inputs, _ = attack.perturb(inputs, targets)
        else:
            eval_inputs = inputs

        # Scoring needs no autograd graph; disabling it avoids keeping the
        # forward activations alive for every evaluation batch.
        with torch.no_grad():
            outputs = model(eval_inputs)
            loss = criterion(outputs, targets)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(outputs, targets, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # plot progress
        bar.suffix  = '({batch:>3}/{size:>3}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format(
                        batch=batch_idx + 1,
                        size=len(val_loader),
                        data=data_time.avg,
                        bt=batch_time.avg,
                        total=bar.elapsed_td,
                        eta=bar.eta_td,
                        loss=losses.avg,
                        top1=top1.avg,
                        top5=top5.avg,
                        )
        bar.next()
    bar.finish()

    return (losses.avg, top1.avg)

def validate_disparity(args, val_loader, model, criterion, mode, pgd_attack=None, aa_attack=None):
    """Evaluate ``model`` and report per-class accuracy statistics.

    Besides the average loss / top-1 score, a confusion matrix is accumulated
    over the whole loader and turned into per-class accuracies (in percent),
    whose mean, minimum, standard deviation and max-min gap are printed.

    Args:
        args: namespace; ``args.num_classes`` and ``args.batch_size`` are read.
        val_loader: iterable of ``(inputs, targets)`` batches; must support ``len()``.
        model: network to evaluate; switched to eval mode here. When
            ``pgd_attack`` is given, ``model.fc.weight`` is also read.
        criterion: callable ``criterion(outputs, targets)`` returning a scalar loss tensor.
        mode: label displayed in front of the progress bar.
        pgd_attack: optional object exposing ``perturb(inputs, targets)``; takes
            precedence over ``aa_attack`` when both are supplied.
        aa_attack: optional object exposing
            ``run_standard_evaluation(inputs, targets, bs=...)``.

    Returns:
        Tuple ``(losses.avg, top1.avg, cls_acc)`` where ``cls_acc`` is a NumPy
        array with one accuracy value per class.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()

    bar = Bar('{mode:>12}'.format(mode=mode), max=len(val_loader))

    # Rows are true labels, columns are predicted labels.
    con_mat = np.zeros((args.num_classes, args.num_classes))

    for batch_idx, (inputs, targets) in enumerate(val_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        inputs, targets = inputs.to(device), targets.to(device)

        # Attacks stay outside the no_grad block because crafting the
        # perturbation may rely on autograd.
        if pgd_attack is not None:
            eval_inputs, _ = pgd_attack.perturb(inputs, targets)
        elif aa_attack:
            eval_inputs = aa_attack.run_standard_evaluation(inputs, targets, bs=args.batch_size)
        else:
            eval_inputs = inputs

        # Scoring never back-propagates, so skip building the autograd graph.
        with torch.no_grad():
            outputs = model(eval_inputs)
            loss = criterion(outputs, targets)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(outputs, targets, topk=(1, 5))

        con_mat += confusion_matrix(targets.cpu(), outputs.argmax(dim=1).cpu(), labels=range(args.num_classes))

        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # plot progress
        bar.suffix  = '({batch:>3}/{size:>3}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format(
                        batch=batch_idx + 1,
                        size=len(val_loader),
                        data=data_time.avg,
                        bt=batch_time.avg,
                        total=bar.elapsed_td,
                        eta=bar.eta_td,
                        loss=losses.avg,
                        top1=top1.avg,
                        top5=top5.avg,
                        )
        bar.next()
    bar.finish()

    # Per-class accuracy in percent: correct predictions / samples of that class.
    # NOTE(review): a class with no samples in val_loader yields NaN here.
    cls_acc = np.diagonal(con_mat)/con_mat.sum(axis=1) * 100

    # Pick the label that matches the evaluation that was actually run.
    if pgd_attack is not None:
        metric = "Rob-Acc(PGD)"
    elif aa_attack:
        metric = "Rob-Acc(AA)"
    else:
        metric = "Clean-Acc"

    print("%s : %.2f, Worst %s : %.2f, Std of %s : %.2f, Max-Min of %s : %.2f"
        %(metric, cls_acc.mean(), metric, cls_acc.min(), metric, cls_acc.std(), metric, cls_acc.max()-cls_acc.min())
        )

    if pgd_attack is not None:
        # Correlation between the mean squared weight of each row of model.fc
        # and the per-class accuracy.
        print("Corr(W, rob_acc) : %.4f"
            %(np.corrcoef(model.fc.weight.detach().cpu().square().mean(dim=1), torch.tensor(cls_acc))[0,1].round(4))
            )

    return (losses.avg, top1.avg, cls_acc)


def save_checkpoint(out_dir, epoch, filename='checkpoint.pth.tar', **kwargs):
    """Serialize a checkpoint dictionary to ``out_dir/filename`` with ``torch.save``.

    Args:
        out_dir: directory the checkpoint file is written into.
        epoch: value stored under the ``'epoch'`` key.
        filename: name of the checkpoint file inside ``out_dir``.
        **kwargs: additional entries stored alongside ``'epoch'``.
    """
    # 'epoch' comes first, followed by the caller-supplied entries.
    checkpoint = {'epoch': epoch, **kwargs}
    torch.save(checkpoint, os.path.join(out_dir, filename))

    print("==> saving best model")



class SAM(torch.optim.Optimizer):
    """Sharpness-Aware Minimization wrapper around a base optimizer.

    One update consists of two gradient evaluations: ``first_step`` moves the
    weights along the (optionally weight-scaled) gradient by a step of norm
    ``rho``; after gradients are recomputed at that point, ``second_step``
    restores the original weights and lets the base optimizer update them.

    Args:
        params: parameters (or parameter groups) to optimize.
        base_optimizer: optimizer class, instantiated on this optimizer's groups.
        rho: non-negative size of the ascent step.
        adaptive: when true, the ascent step is scaled element-wise by the weights.
        **kwargs: forwarded to ``base_optimizer`` and stored in the group defaults.
    """

    def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        super(SAM, self).__init__(params, dict(rho=rho, adaptive=adaptive, **kwargs))

        # Build the wrapped optimizer on the same groups, then share its group
        # list and defaults so both objects see the same hyper-parameters.
        inner = base_optimizer(self.param_groups, **kwargs)
        self.base_optimizer = inner
        self.param_groups = inner.param_groups
        self.defaults.update(inner.defaults)

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Back up the weights and move them to ``w + e(w)`` along the gradient."""
        total_norm = self._grad_norm()
        for group in self.param_groups:
            step_scale = group["rho"] / (total_norm + 1e-12)

            for param in group["params"]:
                if param.grad is None:
                    continue
                # Saved so that second_step can return to the original point.
                self.state[param]["old_p"] = param.data.clone()
                weighting = param.pow(2) if group["adaptive"] else 1.0
                param.add_(weighting * param.grad * step_scale.to(param))

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """Restore the backed-up weights and apply the base optimizer update."""
        for group in self.param_groups:
            for param in group["params"]:
                if param.grad is not None:
                    # back from "w + e(w)" to "w"
                    param.data = self.state[param]["old_p"]

        # The gradients in place were computed at the perturbed point.
        self.base_optimizer.step()

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        """Run both SAM steps; ``closure`` must do a full forward-backward pass."""
        assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
        # step itself runs under no_grad, so re-enable autograd for the closure.
        grad_closure = torch.enable_grad()(closure)

        self.first_step(zero_grad=True)
        grad_closure()
        self.second_step()

    def _grad_norm(self):
        """Return the L2 norm over all (optionally weight-scaled) gradients."""
        # Collect every per-parameter norm on one device, in case of model parallelism.
        shared_device = self.param_groups[0]["params"][0].device
        per_param_norms = []
        for group in self.param_groups:
            for param in group["params"]:
                if param.grad is None:
                    continue
                weighting = torch.abs(param) if group["adaptive"] else 1.0
                per_param_norms.append((weighting * param.grad).norm(p=2).to(shared_device))
        return torch.norm(torch.stack(per_param_norms), p=2)

    def load_state_dict(self, state_dict):
        """Load the state and re-link the base optimizer to the new groups."""
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups

class DeCoupledSAM(torch.optim.Optimizer):
    """SAM variant that perturbs only selected "head" parameters, each with its own rho.

    ``first_step`` backs up every parameter that has a gradient and moves the
    head parameters along their gradient; ``second_step`` restores the backup
    and lets the base optimizer perform the actual update.

    Args:
        params: iterable of ``(name, parameter)`` pairs, e.g. the result of
            ``model.named_parameters()``. Parameters whose name contains
            ``"bias"`` are not handed to this optimizer.
            NOTE(review): such parameters are therefore neither updated nor
            zeroed by this optimizer; confirm the caller handles them.
        base_optimizer: optimizer class, instantiated on this optimizer's groups.
        rhos_heads: sequence of non-negative ascent radii, one per head entry.
        head_indices: sequence of positions (within the group's parameter list,
            after bias filtering) identifying head parameters; the parameter at
            ``head_indices[k]`` uses ``rhos_heads[k]``.
        adaptive: when true, the ascent step is scaled element-wise by the weights.
        **kwargs: forwarded to ``base_optimizer`` and stored in the group defaults.
    """

    def __init__(self, params, base_optimizer, rhos_heads, head_indices, adaptive=False, **kwargs):
        # params is expected to come from named_parameters(); drop bias entries.
        filtered_params = [param for name, param in params if "bias" not in name]

        assert all(rho >= 0.0 for rho in rhos_heads), "All rho values should be non-negative."

        defaults = dict(rhos_heads=rhos_heads, adaptive=adaptive, **kwargs)
        super(DeCoupledSAM, self).__init__(filtered_params, defaults)

        self.head_indices = head_indices

        # Build the base optimizer on the same groups and share its group list
        # and defaults, as SAM does, so both objects stay in sync.
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Back up the weights and perturb the head parameters along their gradient."""
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            rhos_heads = group["rhos_heads"]

            for idx, p in enumerate(group["params"]):
                if p.grad is None:
                    continue

                # Back up every parameter with a gradient: second_step restores
                # all of them, perturbed or not.
                self.state[p]["old_p"] = p.data.clone()

                if idx not in self.head_indices:
                    # The per-head rho applies to head weights only, so other
                    # parameters are left unperturbed. Previously they reused
                    # the scale of the last head seen, or failed with an
                    # unbound local if no head preceded them.
                    continue

                class_index = self.head_indices.index(idx)
                scale = rhos_heads[class_index] / (grad_norm + 1e-12)

                # Add the perturbation to the weights.
                e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p.device)
                p.add_(e_w)

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """Restore the backed-up weights and apply the base optimizer update."""
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                p.data = self.state[p]["old_p"]  # restore the original weights

        self.base_optimizer.step()  # perform the actual update

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        """Run both steps; ``closure`` must do a full forward-backward pass."""
        assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
        # step itself runs under no_grad, so re-enable autograd for the closure.
        closure = torch.enable_grad()(closure)

        self.first_step(zero_grad=True)  # ascent step
        closure()  # recompute loss and gradients at the perturbed weights
        self.second_step()  # restore weights and update

    def _grad_norm(self):
        """Return the L2 norm over all (optionally weight-scaled) gradients."""
        # Gather the per-parameter norms on one device, in case of model parallelism.
        shared_device = self.param_groups[0]["params"][0].device
        norm = torch.norm(
            torch.stack([
                ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
                for group in self.param_groups for p in group["params"]
                if p.grad is not None
            ]),
            p=2
        )
        return norm

    def load_state_dict(self, state_dict):
        """Load the state and re-link the base optimizer to the new groups."""
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups


# class DeCoupledSAM(torch.optim.Optimizer):
#     def __init__(self, params, base_optimizer, rhos_heads, head_indices, adaptive=False, **kwargs):
#         # params는 named_parameters()로 전달된 경우
#         filtered_params = [param for name, param in params if "bias" not in name]

#         assert all(rho >= 0.0 for rho in rhos_heads), "All rho values should be non-negative."

#         defaults = dict(rhos_heads=rhos_heads, adaptive=adaptive, **kwargs)
#         super(DeCoupledSAM, self).__init__(filtered_params, defaults)

#         self.param_groups = self.param_groups
#         self.head_indices = head_indices

#         # base_optimizer 초기화
#         self.base_optimizer = base_optimizer(self.param_groups, **kwargs)


#     @torch.no_grad()
#     def first_step(self, zero_grad=False):
#         grad_norm = self._grad_norm()
#         for group in self.param_groups:
#             rhos_heads = group["rhos_heads"]

#             for idx, p in enumerate(group["params"]):
#                 if p.grad is None:
#                     continue

#                 if idx in self.head_indices:  # head 가중치에만 클래스별로 다른 rho 적용
#                     class_index = self.head_indices.index(idx)
#                     scale = rhos_heads[class_index] / (grad_norm + 1e-12)

#                 self.state[p]["old_p"] = p.data.clone()

#                 # 가중치에 perturbation 추가
#                 e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p.device)
#                 p.add_(e_w)

#         if zero_grad:
#             self.zero_grad()

#     @torch.no_grad()
#     def second_step(self, zero_grad=False):
#         for group in self.param_groups:
#             for p in group["params"]:
#                 if p.grad is None:
#                     continue
#                 p.data = self.state[p]["old_p"]  # 원래의 가중치 복원

#         self.base_optimizer.step()  # 실제 업데이트 진행

#         if zero_grad:
#             self.zero_grad()

#     @torch.no_grad()
#     def step(self, closure=None):
#         assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
#         closure = torch.enable_grad()(closure)  # full forward-backward pass를 수행하는 closure 호출

#         self.first_step(zero_grad=True)  # 첫 번째 SAM 스텝
#         closure()  # loss 계산 및 backward 수행
#         self.second_step()  # 두 번째 SAM 스텝

#     def _grad_norm(self):
#         shared_device = self.param_groups[0]["params"][0].device  # 모델 병렬화시 동일 장치에서 연산
#         norm = torch.norm(
#             torch.stack([
#                 ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
#                 for group in self.param_groups for p in group["params"]
#                 if p.grad is not None
#             ]),
#             p=2
#         )
#         return norm

#     def load_state_dict(self, state_dict):
#         super().load_state_dict(state_dict)
#         self.base_optimizer.param_groups = self.param_groups





if __name__ == '__main__':
    # Command-line entry point: parse the options below and hand them to main().
    parser = argparse.ArgumentParser(description='PyTorch WRN Adversarially Robust Model')
    
    ########################## model setting ##########################
    parser.add_argument('--depth', type=int, default=28, help='wideresnet depth factor')
    parser.add_argument('--widen_factor', type=int, default=2, help='wideresnet widen factor')
    parser.add_argument('--activation', type=str, default= 'ReLU', choices=['ReLU', 'LeakyReLU', 'SiLU'], help='choice of activation')
    # NOTE: choices=['resnet18, wideresnet'] was tried and rejected every value
    # ("invalid choice") because that list holds a single string; no choices are enforced.
    parser.add_argument('--model', type=str, default= 'wideresnet', help='architecture of model')

    ########################## optimization setting ##########################
    parser.add_argument('--epochs', default=100, type=int, metavar='N', help='number of total epochs to run')
    parser.add_argument('--start_epoch', default=1, type=int, metavar='N', help='manual epoch number (useful on restarts)')
    parser.add_argument('--batch_size', default=128, type=int, metavar='N', help='train batchsize')
    parser.add_argument('--lr', default=0.1, type=float, metavar='LR', help='initial learning rate')
    parser.add_argument('--wd', default=5e-4, type=float, metavar='WD', help='weight decay')
    parser.add_argument('--lr_scheduler', type=str, default= 'MultiStep', choices=['MultiStep', 'Cosine', 'Cyclic'], help='learning rate scheduling')
    # store_false: amp is True unless this flag is passed.
    parser.add_argument('--amp', default=True, action='store_false', help='disable amp (amp is enabled by default; passing this flag turns it off)')

    ########################## basic setting ##########################
    parser.add_argument('--seed', type=int, default=0, help='seed')
    parser.add_argument('--gpu', default=0, type=int, help='id(s) for CUDA_VISIBLE_DEVICES')
    parser.add_argument('--data_dir', default='/home/ydy0415/data/datasets', help='Directory of dataset')
    parser.add_argument('--save_dir', default='/home/ydy0415/data/experiments/decouple_robust/', help='Directory to output the result')
    parser.add_argument('--load_dir', default='pgd_linf_wrn-28-5_loss-trades_perturbloss-kl_eps-8_lrsche-MultiStep_lamb-6.0_smooth-0.0_seed-0', help='Directory (run name) to load from')
    parser.add_argument('--tolerance', default=150, type=int, metavar='N', help='tolerance')
    parser.add_argument('--finetune', action='store_true', help='fine-tuning mode : freezing feature extractor part, (default: off)')

    ######################### Dataset #############################
    parser.add_argument('--dataset', type=str, default= 'cifar10', choices=['cifar10', 'cifar100', 'fmnist', 'svhn', 'stl10'], help='benchmark dataset')
    parser.add_argument('--num_classes', default=10, type=int, help='the number of classes')

    ########################## attack setting ##########################
    parser.add_argument('--train_attack', default='pgd_linf', choices=['pgd_linf', 'gapgd_linf'], help='train-attack method')
    parser.add_argument('--perturb_loss', default='ce', choices=['ce','kl'], help='perturbation loss for adversarial examples')
    parser.add_argument('--eps', type=float, default=8, help= 'maximum of perturbation magnitude' )
    parser.add_argument('--train_numsteps', type=int, default=10, help= 'train PGD number of steps')
    parser.add_argument('--test_numsteps', type=int, default=10, help= 'test PGD number of steps')
    # store_false: random_start is True unless this flag is passed.
    parser.add_argument('--random_start', action='store_false', help='disable PGD random start (random start is enabled by default; passing this flag turns it off)')
    parser.add_argument('--bn_mode', metavar='BN', default='eval', choices=['eval', 'train'], help='batch normalization mode of attack')

    ########################## loss setting ##########################
    parser.add_argument('--loss', metavar='LOSS', default='pgd', choices=['pgd', 'trades', 'mart', 'fatat', 'arow', 'cow'], help='surrogate loss function to optimize')
    parser.add_argument('--smooth', type=float, default=0., help='alpha of label smoothing')
    parser.add_argument('--lamb', type=float, default=6., help='coefficient of rob_loss')
    parser.add_argument('--rho', type=float, default=0.0001, help='perturbation budget of DeSAM')

    ########################## orthogonal setting ##########################
    parser.add_argument('--orthogonal', action='store_true', help='orthogonal usage loss flag (default: off)')
    parser.add_argument('--var', action='store_true', help='diag-variance loss usage flag (default: off)')
    parser.add_argument('--gamma', type=float, default=1., help='coefficient of orthogonal_loss or var_loss')
    parser.add_argument('--beta', type=float, default=1., help='multiplier of identity matrix')
    
    
    ########################## SWA setting ##########################
    parser.add_argument('--swa', action='store_true', help='swa usage flag (default: off)')
    parser.add_argument('--swa_start', type=float, default=51, metavar='N', help='SWA start epoch number (default: 51)')
    parser.add_argument('--swa_c_epochs', type=int, default=1, metavar='N', help='SWA model collection frequency/cycle length in epochs (default: 1)')

    ######################### add name #############################
    parser.add_argument('--add_name', default='', type=str, help='add_name')

    args = parser.parse_args()
    
    main(args)
    
    