from __future__ import print_function

import matplotlib
matplotlib.use('Agg')

import argparse
import os
import time
import warnings
from datetime import date

import numpy as np
import contextlib

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from autoattack import AutoAttack
from torch.cuda import amp

import models.wideresnet as wrn_models
import models.resnet as res_models



import load_data.datasets as dataset
from utils import Bar, Logger, AverageMeter, accuracy, mkdir_p
from utils.swa import moving_average, bn_update
from loss import *
from attacks.pgd import PGD_Linf, PGD_L2, GA_PGD
from sklearn.metrics import confusion_matrix

# Compute device for batches: CUDA when it is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def _build_save_dir(args):
    """Return the run directory derived from ``args``.

    The name encodes dataset, launch date (YYMMDD), training attack,
    architecture and the main hyper-parameters so that runs with different
    settings do not overwrite each other.
    """
    save_dir = args.save_dir + f'/{args.dataset}/'
    save_dir += date.today().strftime('%Y%m%d')[2:]
    save_dir += f'/{args.train_attack}'
    # Tag the architecture that is actually trained (this used to be
    # hard-coded to `_resnet18`, mislabelling WideResNet runs).
    if args.model == 'wideresnet':
        save_dir += f'_wrn-{args.depth}-{args.widen_factor}'
    else:
        save_dir += f'_{args.model}'
    save_dir += f'_loss-{args.loss}_perturbloss-{args.perturb_loss}_eps-{args.eps}_lrsche-{args.lr_scheduler}'

    if args.loss in ('trades', 'mart', 'arow', 'cow'):
        save_dir += f'_lamb-{args.lamb}'
    if args.loss in ('pgd', 'trades', 'arow', 'cow'):
        save_dir += f'_smooth-{args.smooth}'

    if args.orthogonal:
        save_dir += f'_gamma-{args.gamma}_beta-{args.beta}'
    elif args.var:
        save_dir += f'_var-{args.var}_gamma-{args.gamma}'

    if args.swa:
        save_dir += f'_swa-{args.swa}'
    if args.finetune:
        save_dir += f'_fine-{args.finetune}'

    save_dir += f'_seed-{args.seed}'
    if args.add_name != '':
        save_dir += f'_{args.add_name}'
    return save_dir


def _create_model(args):
    """Build the network(s) selected by ``args.model`` on GPU ``args.gpu``.

    Returns:
        ``(swa_model, model)`` when ``args.swa`` is truthy, otherwise ``model``.

    Raises:
        ValueError: if ``args.model`` is neither ``'wideresnet'`` nor ``'resnet18'``.
    """
    if args.model == 'wideresnet':
        print("==> creating WideResNet" + str(args.depth) + '-' + str(args.widen_factor))

        def build():
            return wrn_models.WideResNet(num_classes=args.num_classes,
                                         depth=args.depth,
                                         widen_factor=args.widen_factor,
                                         activation=args.activation).cuda(args.gpu)
    elif args.model == 'resnet18':
        print("==> creating ResNet18")

        def build():
            return res_models.resnet('resnet18', 3, num_classes=args.num_classes).cuda(args.gpu)
    else:
        raise ValueError(f"Unknown model: {args.model!r}")

    if args.swa:
        # The SWA copy is built first, as before, so weight-init RNG order is unchanged.
        swa_model = build()
        model = build()
        return swa_model, model
    return build()


def _classifier_head_name(model):
    """Return the attribute name of ``model``'s final linear classifier.

    ResNet models used here expose it as ``linear``; otherwise ``fc`` is
    assumed (NOTE(review): presumably the WideResNet head name — confirm).
    """
    return 'linear' if hasattr(model, 'linear') else 'fc'


def _head_weight_stats(head_weight):
    """Return ``(mean_sq, ws_std)`` for a classifier weight matrix.

    ``mean_sq`` is the per-row mean of squared weights. ``ws_std`` is the std
    of the per-row RMS divided by its minimum.
    """
    mean_sq = head_weight.detach().cpu().square().mean(dim=1)
    rms = mean_sq.sqrt()
    return mean_sq, (rms / rms.min()).std()


def _disparity_row(epoch, attack_test_acc, cls_acc, head_weight):
    """Build one row of ``log_disparity.txt`` (columns as set in ``main``)."""
    mean_sq, ws_std = _head_weight_stats(head_weight)
    return [round(epoch),
            attack_test_acc,
            cls_acc.min(),
            cls_acc.std(),
            cls_acc.max() - cls_acc.min(),
            ws_std,
            np.corrcoef(mean_sq, torch.tensor(cls_acc))[0, 1].round(4)]


def _run_final_evaluations(args, test_loader, model, criterion, mode):
    """Evaluate ``model`` clean, under PGD-20 and under AutoAttack (APGD-CE + APGD-T) at eps=8/255.

    Returns:
        The ``validate_disparity`` result of the AutoAttack pass.
    """
    pgd_attack = PGD_Linf(model=model,
                          epsilon=8/255,
                          step_size=(8/4)/255,
                          num_steps=20,
                          random_start=args.random_start,
                          criterion='ce',
                          bn_mode=args.bn_mode,
                          train=False)

    auto_attack = AutoAttack(model, norm='Linf', eps=8/255, version='standard', verbose=False)
    auto_attack.attacks_to_run = ['apgd-ce', 'apgd-t']

    validate_disparity(args, test_loader, model, criterion, mode, pgd_attack=None, aa_attack=None)
    validate_disparity(args, test_loader, model, criterion, mode, pgd_attack=pgd_attack, aa_attack=None)
    return validate_disparity(args, test_loader, model, criterion, mode, pgd_attack=None, aa_attack=auto_attack)


def _normalize_classifier_head(model, head_name):
    """Replace ``model.<head_name>`` by a Linear whose weight rows have unit L2 norm (bias kept)."""
    head = getattr(model, head_name)
    fc_weights = head.weight.data
    normalized_weights = fc_weights / torch.norm(fc_weights, p=2, dim=1).unsqueeze(-1)

    new_fc = torch.nn.Linear(in_features=head.in_features, out_features=head.out_features, bias=True)
    new_fc.weight.data = normalized_weights
    new_fc.bias.data = head.bias.data

    setattr(model, head_name, new_fc)
    model.cuda()


def main(args: argparse.Namespace):
    """Adversarially train (or fine-tune) a classifier, then run final evaluations.

    Creates a run directory under ``args.save_dir`` and writes ``log.txt``,
    ``log_disparity.txt`` and the ``robust_best.pth.tar`` / ``last.pth.tar``
    checkpoints there. After training, the model is evaluated clean, under
    PGD-20 and under AutoAttack. The same evaluation is repeated after the
    classifier weight rows are rescaled to unit L2 norm.

    Raises:
        ValueError: for an unknown ``args.model``, ``args.train_attack`` or
            ``args.lr_scheduler``.
    """
    print(args)

    if args.seed is not None:
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    torch.cuda.set_device(args.gpu)
    # Benchmark mode autotunes kernels and may pick non-deterministic ones,
    # which would undo the deterministic setting requested by a seed.
    cudnn.benchmark = args.seed is None

    args.save_dir = _build_save_dir(args)
    if not os.path.isdir(args.save_dir):
        mkdir_p(args.save_dir)

    _, _, train_loader, test_loader = dataset.load_data(args.data_dir,
                                                        args.dataset,
                                                        batch_size=args.batch_size,
                                                        batch_size_test=100,
                                                        num_workers=2,
                                                        use_augmentation=True,
                                                        shuffle_train=True,
                                                        validation=False)

    # Model
    if args.swa:
        swa_model, model = _create_model(args)
        swa_n = 0
    else:
        model = _create_model(args)

    print('    Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0))

    if args.finetune:
        # `args.finetune_ckpt` (optional) overrides the historical default path.
        ckpt_path = getattr(args, 'finetune_ckpt', None) or (
            '/home/ydy0415/data/experiments/decouple_robust'
            f'/{args.dataset}/230925/pgd_linf_wrn-28-5_loss-pgd_perturbloss-ce_eps-8_lrsche-MultiStep_smooth-0.0_seed-0'
            '/last.pth.tar')
        # Load on CPU; load_state_dict copies the tensors onto the model's device.
        checkpoint = torch.load(ckpt_path, map_location='cpu')
        model.load_state_dict(checkpoint['state_dict'])
        del checkpoint

    criterion = nn.CrossEntropyLoss()

    eps = args.eps/255
    step_size = (args.eps/4)/255
    train_attack_classes = {'pgd_linf': PGD_Linf, 'gapgd_linf': GA_PGD}
    if args.train_attack not in train_attack_classes:
        raise ValueError(f"Unknown train_attack: {args.train_attack!r}")
    train_attack = train_attack_classes[args.train_attack](model=model, epsilon=eps, step_size=step_size,
                                                           num_steps=args.train_numsteps,
                                                           random_start=args.random_start,
                                                           criterion=args.perturb_loss,
                                                           bn_mode=args.bn_mode, train=True)

    test_attack = PGD_Linf(model=model, epsilon=eps, step_size=step_size, num_steps=args.test_numsteps,
                           random_start=args.random_start, criterion='ce', bn_mode=args.bn_mode, train=False)

    head_name = _classifier_head_name(model)
    if args.finetune:
        # Only the classifier weight (not its bias) is handed to the SAM optimizer.
        head_params = [param for name, param in getattr(model, head_name).named_parameters()
                       if "bias" not in name]
        optimizer = DeCoupledSAM(head_params, optim.SGD, lr=args.lr, momentum=0.9, weight_decay=args.wd)
    else:
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.wd)

    if args.lr_scheduler == "MultiStep":
        if args.swa:
            # Integer milestones: MultiStepLR matches them against the integer
            # epoch counter, so a fractional milestone would never fire.
            milestones = [args.epochs // 2, (3 * args.epochs) // 4]
        else:
            milestones = [90, 95]
        scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)
    elif args.lr_scheduler == "Cosine":
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    else:
        raise ValueError(f"Unknown lr_scheduler: {args.lr_scheduler!r}")

    logger = Logger(os.path.join(args.save_dir, 'log.txt'), title=args.dataset)
    logger.set_names(['Epoch', 'Train Loss', 'Test Loss', 'Test Acc.', 'Attack Loss' , 'Attack Acc.'])

    logger_disparity = Logger(os.path.join(args.save_dir, 'log_disparity.txt'), title=args.dataset)
    logger_disparity.set_names(['Epoch', 'Rob-Acc(PGD)', 'Worst Rob-Acc(PGD)', 'Std of Rob-Acc(PGD)', 'Max-Min of Rob-Acc(PGD)', 'Std of Ws', 'Corr(W, Rob_Acc)'])

    attack_best_acc, best_acc, tolerance = 0, 0, 0
    # Evaluation does not run every epoch; start at 0 so the best-model check
    # below never reads an unbound name (later epochs reuse the last values).
    test_acc, attack_test_acc = 0, 0

    # Train and val
    for epoch in range(args.start_epoch, args.epochs + 1):

        print('\n'+args.train_attack +' Epoch: [%d | %d] LR: %.5f Tol: %d Best ts acc: %.2f Best_att_acc: %.2f ' % (epoch, args.epochs, optimizer.param_groups[0]['lr'], tolerance, best_acc, attack_best_acc))

        if args.finetune:
            train_loss = finetune_DeSAM(args, train_loader, model, optimizer, attack=train_attack)
        else:
            train_loss = train(args, train_loader, model, optimizer, attack=train_attack)

        if args.swa and epoch == args.swa_start:
            test_attack = PGD_Linf(model=swa_model, epsilon=eps, step_size=step_size, num_steps=args.test_numsteps,
                                   random_start=args.random_start, criterion='ce', bn_mode=args.bn_mode, train=False)

        if args.swa and epoch >= args.swa_start and (epoch - args.swa_start) % args.swa_c_epochs == 0:
            moving_average(swa_model, model, 1.0 / (swa_n + 1))
            swa_n += 1
            if epoch >= 70:
                bn_update(train_loader, swa_model)
                test_loss, test_acc = validate(test_loader, swa_model, criterion, mode='Test')
                attack_test_loss, attack_test_acc = validate(test_loader, swa_model, criterion, mode='Attack_test', attack=test_attack)
                logger.append([round(epoch), train_loss, test_loss, test_acc, attack_test_loss,  attack_test_acc])

        elif epoch == 1 or epoch % 2 == 0 or epoch >= 90 or epoch == args.epochs:
            test_loss, test_acc = validate(test_loader, model, criterion, mode='Test')
            attack_test_loss, attack_test_acc, cls_acc = validate_disparity(args, test_loader, model, criterion, mode='Attack_test', pgd_attack=test_attack)
            logger.append([round(epoch), train_loss, test_loss, test_acc, attack_test_loss,  attack_test_acc])

            head_weight = getattr(model, head_name).weight
            logger_disparity.append(_disparity_row(epoch, attack_test_acc, cls_acc, head_weight))
            print("Std of Ws : %.6f" % _head_weight_stats(head_weight)[1])

        scheduler.step()

        # save model
        is_attack_best = attack_test_acc > attack_best_acc
        if is_attack_best:
            attack_best_acc = attack_test_acc
            best_acc = test_acc

        if args.swa:
            if epoch >= args.swa_start and (epoch - args.swa_start) % args.swa_c_epochs == 0 and is_attack_best:
                save_checkpoint(args.save_dir, epoch,
                                filename='robust_best.pth.tar',
                                swa_state_dict=swa_model.state_dict(),
                                swa_n=swa_n,
                                state_dict=model.state_dict(),
                                test_acc=test_acc,
                                attack_test_acc=attack_test_acc,
                                optimizer=optimizer.state_dict())
            elif epoch < args.swa_start and is_attack_best:
                save_checkpoint(args.save_dir, epoch,
                                filename='robust_best.pth.tar',
                                state_dict=model.state_dict(),
                                test_acc=test_acc,
                                attack_test_acc=attack_test_acc,
                                optimizer=optimizer.state_dict())
            elif epoch == args.epochs:
                save_checkpoint(args.save_dir, epoch,
                                filename='last.pth.tar',
                                swa_state_dict=swa_model.state_dict(),
                                test_acc=test_acc,
                                attack_test_acc=attack_test_acc)
        else:
            if is_attack_best:
                save_checkpoint(args.save_dir, epoch,
                                filename='robust_best.pth.tar',
                                state_dict=model.state_dict(),
                                test_acc=test_acc,
                                attack_test_acc=attack_test_acc,
                                optimizer=optimizer.state_dict())

            if epoch == args.epochs:
                save_checkpoint(args.save_dir, epoch,
                                filename='last.pth.tar',
                                state_dict=model.state_dict(),
                                test_acc=test_acc,
                                attack_test_acc=attack_test_acc)

        tolerance = 0 if is_attack_best else tolerance + 1

    logger.close()

    # Final robustness evaluation of the trained model ...
    _run_final_evaluations(args, test_loader, model, criterion, "Final")

    # ... and again after rescaling every classifier weight row to unit L2 norm.
    _normalize_classifier_head(model, head_name)
    attack_test_loss, attack_test_acc, cls_acc = _run_final_evaluations(args, test_loader, model, criterion, "Final_Norm")

    # `args.epochs` is the last loop epoch; unlike `epoch` it is bound even if the loop ran zero times.
    logger_disparity.append(_disparity_row(args.epochs, attack_test_acc, cls_acc, getattr(model, head_name).weight))
    logger_disparity.close()


def train(args, train_loader, model, optimizer, attack):
    """Run one epoch of adversarial training over ``train_loader``.

    For every batch, ``attack.perturb`` crafts adversarial inputs. The model
    is then updated on the loss selected by ``args.loss`` ('arow', 'cow',
    'mart', 'trades' or 'pgd'). The regulariser is scaled by ``args.lamb``
    for every loss except 'pgd'. Optional orthogonality (``args.orthogonal``)
    and variance (``args.var``) penalties are added with weight
    ``args.gamma``. With ``args.amp`` the forward runs under autocast and
    the step goes through a GradScaler.

    Returns:
        Sample-weighted average of the total loss over the epoch.

    Raises:
        ValueError: unknown ``args.loss``, or a ``args.perturb_loss`` that is
            incompatible with it.
    """
    # Validate configuration once, before any batch is processed.
    if args.loss not in ("arow", "cow", "mart", "trades", "pgd"):
        raise ValueError(f"Unknown loss: {args.loss!r}")
    if args.loss == "cow" and args.perturb_loss not in ("kl", "revkl", "js"):
        raise ValueError("perturb loss must be kl, revkl or js divergence.")
    if args.loss in ("mart", "pgd") and args.perturb_loss != "ce":
        raise ValueError("perturb loss must be ce.")

    batch_time = AverageMeter()
    data_time = AverageMeter()
    sup_losses = AverageMeter()
    reg_losses = AverageMeter()
    losses = AverageMeter()

    ce_loss = nn.CrossEntropyLoss()
    end = time.time()

    scaler = amp.GradScaler()
    amp_cm = amp.autocast if args.amp else contextlib.nullcontext

    bar = Bar('{:>12}'.format('Training'), max=len(train_loader))

    model.train()
    for batch_idx, (inputs, targets) in enumerate(train_loader):

        data_time.update(time.time() - end)

        inputs, targets = inputs.to(device), targets.to(device)

        with amp_cm():
            adv_inputs, _ = attack.perturb(inputs, targets)

            if args.loss == "pgd":
                sup_loss = ce_loss(model(adv_inputs), targets)
                reg_loss = torch.tensor(0)
                loss = sup_loss
            else:
                if args.loss == "arow":
                    sup_loss, reg_loss = ARoW_loss(inputs, adv_inputs, targets, model, args.smooth)
                elif args.loss == "cow":
                    sup_loss, reg_loss = CoW_loss(inputs, adv_inputs, targets, model, args.smooth)
                elif args.loss == "mart":
                    sup_loss, reg_loss = MART_loss(inputs, adv_inputs, targets, model)
                else:  # "trades"
                    sup_loss, reg_loss = TRADES_loss(inputs, adv_inputs, targets, model, args.smooth)
                reg_loss = args.lamb * reg_loss
                loss = sup_loss + reg_loss

        # Out-of-place additions on purpose: for 'pgd', `loss` IS the
        # `sup_loss` tensor, so an in-place `+=` would also fold these
        # penalties into the logged supervised loss.
        if args.orthogonal:
            loss = loss + args.gamma * Orthogonal_loss(model, args.beta)
        if args.var:
            loss = loss + args.gamma * Diag_Var_loss(model)

        # record loss
        sup_losses.update(sup_loss.item(), inputs.size(0))
        reg_losses.update(reg_loss.item(), inputs.size(0))
        losses.update(loss.item(), inputs.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        if args.amp:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # plot progress
        bar.suffix  = '({batch:>3}/{size:>3}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Sup_loss: {sup_loss:.4f} | Reg_loss: {reg_loss:.4f} |  Tot loss:{loss:.4f}'.format(
                    batch   = batch_idx + 1,
                    size    = len(train_loader),
                    data    = data_time.avg,
                    bt      = batch_time.avg,
                    total   = bar.elapsed_td,
                    eta     = bar.eta_td,
                    sup_loss=sup_losses.avg,
                    reg_loss=reg_losses.avg,
                    loss=losses.avg
                    )
        bar.next()
    bar.finish()

    return losses.avg


def update_optimizer_rhos(optimizer, new_rhos_head):
    """Store ``new_rhos_head`` under the ``rhos_heads`` key of every param group.

    The same object is placed in each group (no copy). Other keys of each
    group are left untouched. Returns None.
    """
    for param_group in optimizer.param_groups:
        param_group.update(rhos_heads=new_rhos_head)


def finetune_DeSAM(args, train_loader, model, optimizer, attack):
    """Fine-tune the classifier head for one epoch with ``DeCoupledSAM``.

    At the first batch, the head (``model.fc``, or ``model.linear`` when
    there is no ``fc``) is replaced by a fresh Linear whose weight rows have
    unit L2 norm; the bias is kept. A per-class radius
    ``args.rho * softmax(-row_norms)`` is computed from the
    pre-normalisation row norms. A new ``DeCoupledSAM`` over the new head's
    parameters (weight and bias) is then used for the rest of the epoch.

    NOTE(review): ``optimizer`` is rebound locally, so the caller's optimizer
    (and any scheduler attached to it) never sees these updates, and the new
    optimizer's state is rebuilt every epoch — confirm this is intended.

    Returns:
        Sample-weighted average of the total loss over the epoch.

    Raises:
        ValueError: unknown ``args.loss``.
    """
    if args.loss not in ("arow", "cow", "mart", "trades", "pgd"):
        raise ValueError(f"Unknown loss: {args.loss!r}")

    batch_time = AverageMeter()
    data_time = AverageMeter()
    sup_losses = AverageMeter()
    reg_losses = AverageMeter()
    losses = AverageMeter()

    ce_loss = nn.CrossEntropyLoss()
    # Prefer `fc` (the original behaviour); fall back to `linear`, the head
    # name of the ResNet18 built by `main`, instead of crashing.
    head_name = 'fc' if hasattr(model, 'fc') else 'linear'

    def compute_losses(inputs, targets):
        """Craft adversarial inputs and return ``(total_loss, sup_loss, reg_loss)``."""
        # Re-crafted on every call, so the second SAM pass attacks the model
        # as it is after `optimizer.first_step()`.
        adv_inputs, _ = attack.perturb(inputs, targets)
        if args.loss == "arow":
            sup_loss, reg_loss = ARoW_loss(inputs, adv_inputs, targets, model, args.smooth)
        elif args.loss == "cow":
            sup_loss, reg_loss = CoW_loss(inputs, adv_inputs, targets, model, args.smooth)
        elif args.loss == "mart":
            sup_loss, reg_loss = MART_loss(inputs, adv_inputs, targets, model)
        elif args.loss == "trades":
            sup_loss, reg_loss = TRADES_loss(inputs, adv_inputs, targets, model, args.smooth)
        else:  # "pgd"
            sup_loss = ce_loss(model(adv_inputs), targets)
            reg_loss = torch.tensor(0.)  # No regularization for PGD loss

        total_loss = sup_loss + args.lamb * reg_loss
        if args.orthogonal:
            total_loss = total_loss + args.gamma * Orthogonal_loss(model, args.beta)
        if args.var:
            total_loss = total_loss + args.gamma * Diag_Var_loss(model)
        return total_loss, sup_loss, reg_loss

    end = time.time()

    bar = Bar('{:>12}'.format('Training'), max=len(train_loader))

    model.train()
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        data_time.update(time.time() - end)

        if batch_idx == 0:
            head = getattr(model, head_name)
            fc_weights = head.weight.data
            l2_norms = torch.norm(fc_weights, p=2, dim=1)

            # Swap in a fresh Linear whose weight rows have unit L2 norm.
            new_fc = torch.nn.Linear(in_features=head.in_features, out_features=head.out_features, bias=True)
            new_fc.weight.data = fc_weights / l2_norms.unsqueeze(-1)
            new_fc.bias.data = head.bias.data
            setattr(model, head_name, new_fc)
            model.cuda()

            # Per-class radius: rows with a smaller original norm get a larger rho.
            exp_values = torch.exp(-l2_norms)
            rhos_head = args.rho * (exp_values / exp_values.sum())

            # All head parameters (weight and bias) go to the new optimizer.
            head_params = list(getattr(model, head_name).parameters())
            optimizer = DeCoupledSAM(head_params, optim.SGD, rhos_head, lr=args.lr, momentum=0.9, weight_decay=args.wd)

        inputs, targets = inputs.to(device), targets.to(device)

        # Two-pass update via DeCoupledSAM.first_step / second_step.
        optimizer.zero_grad()
        total_loss, sup_loss, reg_loss = compute_losses(inputs, targets)
        total_loss.backward()
        optimizer.first_step()

        total_loss, sup_loss, reg_loss = compute_losses(inputs, targets)
        total_loss.backward()
        optimizer.second_step()

        # Record losses (from the second pass)
        sup_losses.update(sup_loss.item(), inputs.size(0))
        reg_losses.update(reg_loss.item(), inputs.size(0))
        losses.update(total_loss.item(), inputs.size(0))

        # Measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # Plot progress
        bar.suffix = '({batch:>3}/{size:>3}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Sup_loss: {sup_loss:.4f} | Reg_loss: {reg_loss:.4f} |  Tot loss:{loss:.4f}'.format(
            batch=batch_idx + 1,
            size=len(train_loader),
            data=data_time.avg,
            bt=batch_time.avg,
            total=bar.elapsed_td,
            eta=bar.eta_td,
            sup_loss=sup_losses.avg,
            reg_loss=reg_losses.avg,
            loss=losses.avg
        )
        bar.next()
    bar.finish()

    return losses.avg



def validate(val_loader, model, criterion, mode, attack=None):
    """Evaluate ``model`` on ``val_loader``, optionally on adversarial inputs.

    Args:
        val_loader: iterable of ``(inputs, targets)`` batches.
        model: network to evaluate; it is switched to eval mode and left there.
        criterion: loss function applied to ``(outputs, targets)``.
        mode: label shown on the progress bar.
        attack: if given, each batch is replaced by ``attack.perturb(inputs, targets)[0]``.

    Returns:
        ``(average loss, average top-1 accuracy)``, both sample-weighted.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()

    bar = Bar('{mode:>12}'.format(mode=mode), max=len(val_loader))

    for batch_idx, (inputs, targets) in enumerate(val_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        inputs, targets = inputs.to(device), targets.to(device)

        # Crafting the attack needs gradients, so it stays outside no_grad.
        if attack is not None:
            eval_inputs, _ = attack.perturb(inputs, targets)
        else:
            eval_inputs = inputs

        # The evaluation forward itself needs no autograd graph.
        with torch.no_grad():
            outputs = model(eval_inputs)
            loss = criterion(outputs, targets)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(outputs, targets, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # plot progress
        bar.suffix  = '({batch:>3}/{size:>3}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format(
                        batch=batch_idx + 1,
                        size=len(val_loader),
                        data=data_time.avg,
                        bt=batch_time.avg,
                        total=bar.elapsed_td,
                        eta=bar.eta_td,
                        loss=losses.avg,
                        top1=top1.avg,
                        top5=top5.avg,
                        )
        bar.next()
    bar.finish()

    return (losses.avg, top1.avg)

def _print_class_acc_summary(label, cls_acc):
    """Print mean / worst / std / max-min of the per-class accuracies under ``label``."""
    # nan-aware reductions: a class with no samples in the loader has NaN accuracy
    worst = np.nanmin(cls_acc)
    print("%s : %.2f, Worst %s : %.2f, Std of %s : %.2f, Max-Min of %s : %.2f"
        % (label, np.nanmean(cls_acc), label, worst, label, np.nanstd(cls_acc),
           label, np.nanmax(cls_acc) - worst)
        )


def _classifier_weight(model):
    """Return the weight of the model's final ``linear`` or ``fc`` layer, or None.

    A ``.module`` wrapper (e.g. ``nn.DataParallel``) is unwrapped first. The
    head attribute name presumably depends on the architecture (``linear`` vs
    ``fc``) -- confirm for any new model.
    """
    net = getattr(model, 'module', model)
    for attr in ('linear', 'fc'):
        weight = getattr(getattr(net, attr, None), 'weight', None)
        if weight is not None:
            return weight
    return None


def validate_disparity(args, val_loader, model, criterion, mode, pgd_attack=None, aa_attack=None):
    """Evaluate ``model`` and report per-class accuracy statistics.

    Inputs are attacked with ``pgd_attack`` if given, otherwise with
    ``aa_attack`` if given, otherwise evaluated clean. A confusion matrix over
    ``args.num_classes`` classes is accumulated and per-class accuracy
    (diagonal / row sum, in %) is summarized on stdout. In the PGD case, the
    correlation between each head row's mean squared weight and the per-class
    accuracy is also printed.

    Args:
        args: namespace providing ``num_classes`` and ``batch_size``.
        val_loader: iterable of ``(inputs, targets)`` batches.
        model: network to evaluate; switched to ``eval()`` mode.
        criterion: loss, called as ``criterion(outputs, targets)``.
        mode: label shown on the progress bar.
        pgd_attack: optional object with ``perturb(inputs, targets) -> (adv, _)``.
        aa_attack: optional object with
            ``run_standard_evaluation(inputs, targets, bs=...) -> adv``.

    Returns:
        ``(average loss, average top-1, cls_acc)``; ``cls_acc`` is a NumPy
        array of per-class accuracies in %, NaN for classes never seen.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()

    bar = Bar('{mode:>12}'.format(mode=mode), max=len(val_loader))

    # rows: true class, columns: predicted class
    con_mat = np.zeros((args.num_classes, args.num_classes))

    for batch_idx, (inputs, targets) in enumerate(val_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        inputs, targets = inputs.to(device), targets.to(device)

        # Attacks need input gradients, so they run outside no_grad.
        if pgd_attack is not None:
            eval_inputs, _ = pgd_attack.perturb(inputs, targets)
        elif aa_attack is not None:
            eval_inputs = aa_attack.run_standard_evaluation(inputs, targets, bs=args.batch_size)
        else:
            eval_inputs = inputs

        # Only forward-pass metrics are needed: don't build an autograd graph.
        with torch.no_grad():
            outputs = model(eval_inputs)
            loss = criterion(outputs, targets)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(outputs, targets, topk=(1, 5))

        con_mat += confusion_matrix(targets.cpu(), outputs.argmax(dim=1).cpu(), labels=range(args.num_classes))

        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # plot progress
        bar.suffix  = '({batch:>3}/{size:>3}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format(
                        batch=batch_idx + 1,
                        size=len(val_loader),
                        data=data_time.avg,
                        bt=batch_time.avg,
                        total=bar.elapsed_td,
                        eta=bar.eta_td,
                        loss=losses.avg,
                        top1=top1.avg,
                        top5=top5.avg,
                        )
        bar.next()
    bar.finish()

    # A class with zero support gives 0/0 -> NaN, without a RuntimeWarning.
    with np.errstate(divide='ignore', invalid='ignore'):
        cls_acc = np.diagonal(con_mat) / con_mat.sum(axis=1) * 100

    if pgd_attack is not None:
        _print_class_acc_summary("Rob-Acc(PGD)", cls_acc)
        weight = _classifier_weight(model)
        if weight is None:
            print("Corr(W, rob_acc) : skipped (no 'linear' or 'fc' head weight found)")
        else:
            # mean squared weight of each head row (one row per output unit)
            w_sq = weight.detach().float().cpu().square().mean(dim=1).numpy()
            print("Corr(W, rob_acc) : %.4f"
                % (np.corrcoef(w_sq, cls_acc)[0, 1].round(4))
                )
    elif aa_attack is not None:
        _print_class_acc_summary("Rob-Acc(AA)", cls_acc)
    else:
        _print_class_acc_summary("Clean-Acc", cls_acc)

    return (losses.avg, top1.avg, cls_acc)


def save_checkpoint(out_dir, epoch, filename='checkpoint.pth.tar', **kwargs):
    """Write ``{'epoch': epoch, **kwargs}`` to ``out_dir/filename`` via ``torch.save``.

    ``out_dir`` must already exist. The same status line is printed after
    saving whatever ``filename`` is.
    """
    # ``epoch`` is a named parameter, so it can never also appear in kwargs.
    payload = {'epoch': epoch, **kwargs}
    torch.save(payload, os.path.join(out_dir, filename))

    print("==> saving best model")



class SAM(torch.optim.Optimizer):
    """Sharpness-Aware Minimization wrapper around a base optimizer.

    An update is two passes. ``first_step`` moves each parameter that has a
    gradient to ``w + e(w)``, saving ``w``. ``second_step`` restores ``w`` and
    calls ``base_optimizer.step()`` using the gradient computed at the
    perturbed point.

    ``e(w) = rho * g / ||g||``. With ``adaptive=True`` it is
    ``e(w) = rho * w**2 * g / || |w| * g ||`` instead, which matches the
    element-wise ASAM formula.

    Args:
        params: parameters or param-group dicts.
        base_optimizer: optimizer *class* (e.g. ``torch.optim.SGD``), built on
            this optimizer's param groups with ``**kwargs``.
        rho: perturbation radius; must be non-negative.
        adaptive: use the scale-adaptive perturbation described above.
        **kwargs: forwarded to ``base_optimizer`` (lr, momentum, ...).
    """

    def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        super().__init__(params, dict(rho=rho, adaptive=adaptive, **kwargs))

        # The inner optimizer is built over the same group dicts; adopt its
        # group list and defaults so both objects stay in sync.
        inner = base_optimizer(self.param_groups, **kwargs)
        self.base_optimizer = inner
        self.param_groups = inner.param_groups
        self.defaults.update(inner.defaults)

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Climb to ``w + e(w)`` for every parameter with a gradient, saving ``w``."""
        total_norm = self._grad_norm()
        for group in self.param_groups:
            step_size = group["rho"] / (total_norm + 1e-12)
            use_adaptive = group["adaptive"]
            for param in group["params"]:
                grad = param.grad
                if grad is None:
                    continue
                self.state[param]["old_p"] = param.data.clone()
                weight = torch.pow(param, 2) if use_adaptive else 1.0
                param.add_(weight * grad * step_size.to(param))  # w -> w + e(w)

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """Return to ``w`` and apply the base optimizer with the perturbed-point gradient."""
        perturbed = (
            param
            for group in self.param_groups
            for param in group["params"]
            if param.grad is not None
        )
        for param in perturbed:
            param.data = self.state[param]["old_p"]  # w + e(w) -> w

        self.base_optimizer.step()  # the actual "sharpness-aware" update

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        """Full SAM update; ``closure`` must do a complete forward + backward pass."""
        assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
        grad_closure = torch.enable_grad()(closure)

        self.first_step(zero_grad=True)
        grad_closure()
        self.second_step()

    def _grad_norm(self):
        """L2 norm over all gradients (each scaled by ``|w|`` in adaptive mode)."""
        # Collect on one device in case the model is split across devices.
        target = self.param_groups[0]["params"][0].device
        pieces = []
        for group in self.param_groups:
            for param in group["params"]:
                if param.grad is None:
                    continue
                scaled = (torch.abs(param) if group["adaptive"] else 1.0) * param.grad
                pieces.append(scaled.norm(p=2).to(target))
        return torch.norm(torch.stack(pieces), p=2)

    def load_state_dict(self, state_dict):
        super().load_state_dict(state_dict)
        # loading replaces the group dicts; re-point the inner optimizer at them
        self.base_optimizer.param_groups = self.param_groups


class DeCoupledSAM(torch.optim.Optimizer):
    """SAM variant whose head parameters each get their own perturbation radius.

    Parameters whose index inside their param group appears in
    ``head_indices`` are perturbed by ``rhos_head[k] * g / ||g||``, where ``k``
    is that index's position in ``head_indices`` and ``||g||`` is the global
    gradient norm. All other parameters use a scale of 1.0. With
    ``adaptive=True`` the perturbation is additionally multiplied by ``w**2``
    and the norm uses ``|w| * g``, as in ``SAM``.

    Args:
        params: parameters or param-group dicts (all parameters, head included).
        base_optimizer: optimizer *class*, built on these param groups with
            ``**kwargs``.
        rhos_head: per-head-parameter radii; each must be non-negative. It is
            stored in each group as ``group["rhos_head"]`` and must be set
            before ``first_step`` if ``head_indices`` is used.
        head_indices: list of parameter indices (within a group) treated as head.
        adaptive: use the scale-adaptive perturbation.
        **kwargs: forwarded to ``base_optimizer``.

    Raises:
        ValueError: from ``first_step`` if a head parameter is reached while
            ``rhos_head`` is None.
    """

    def __init__(self, params, base_optimizer, rhos_head=None, head_indices=None, adaptive=False, **kwargs):
        # params holds every parameter, so no separate filtering is needed.
        if rhos_head is not None:
            assert all(rho >= 0.0 for rho in rhos_head), "All rho values should be non-negative."

        defaults = dict(rhos_head=rhos_head, adaptive=adaptive, **kwargs)
        super(DeCoupledSAM, self).__init__(params, defaults)

        self.head_indices = head_indices
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        # Share the base optimizer's group list and defaults (as SAM does) so
        # schedulers / add_param_group act on one consistent set of groups.
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Perturb every parameter with a gradient to ``w + e(w)``, saving ``w``."""
        grad_norm = self._grad_norm()

        for group in self.param_groups:
            rhos_head = group["rhos_head"]

            for idx, p in enumerate(group["params"]):
                if p.grad is None:
                    continue

                if self.head_indices is not None and idx in self.head_indices:
                    if rhos_head is None:
                        raise ValueError("rhos_head must be initialized before the first step.")

                    # rhos_head is indexed by this parameter's position in head_indices.
                    class_index = self.head_indices.index(idx)
                    scale = rhos_head[class_index] / (grad_norm + 1e-12)
                else:
                    # Default scale.
                    # NOTE(review): non-head params are shifted by the raw gradient
                    # (no rho, no normalization) -- confirm this is intended.
                    scale = 1.0

                self.state[p]["old_p"] = p.data.clone()

                # as_tensor accepts both floats and tensors; torch.tensor() on an
                # existing tensor emits a copy-construct UserWarning.
                scale_t = torch.as_tensor(scale, dtype=p.dtype, device=p.device)
                e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale_t
                p.add_(e_w)

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """Restore ``w`` and apply the base optimizer with the perturbed-point gradient."""
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                p.data = self.state[p]["old_p"]  # w + e(w) -> w

        self.base_optimizer.step()

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        """Full update; ``closure`` must do a complete forward + backward pass."""
        assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
        closure = torch.enable_grad()(closure)

        self.first_step(zero_grad=True)
        closure()
        self.second_step()

    def _grad_norm(self):
        """L2 norm over all gradients (scaled by ``|w|`` in adaptive mode); 0 if none."""
        shared_device = self.param_groups[0]["params"][0].device
        norms = [
            ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
            for group in self.param_groups for p in group["params"]
            if p.grad is not None
        ]
        if not norms:
            # torch.stack([]) would raise; with no gradients nothing is perturbed anyway.
            return torch.zeros((), device=shared_device)
        return torch.norm(torch.stack(norms), p=2)

    def load_state_dict(self, state_dict):
        super().load_state_dict(state_dict)
        # loading replaces the group dicts; re-point the inner optimizer at them
        self.base_optimizer.param_groups = self.param_groups




if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='PyTorch WRN Adversarially Robust Model')

    ########################## model setting ##########################
    parser.add_argument('--depth', type=int, default=28, help='wideresnet depth factor')
    parser.add_argument('--widen_factor', type=int, default=2, help='wideresnet widen factor')
    parser.add_argument('--activation', type=str, default= 'ReLU', choices=['ReLU', 'LeakyReLU', 'SiLU'], help='choice of activation')
    # NOTE: no `choices` on --model: an earlier choices=['resnet18, wideresnet'] was a
    # single string, so every real model name was rejected as an invalid choice.
    parser.add_argument('--model', type=str, default= 'wideresnet', help='architecture of model')

    ########################## optimization setting ##########################
    parser.add_argument('--epochs', default=100, type=int, metavar='N', help='number of total epochs to run')
    parser.add_argument('--start_epoch', default=1, type=int, metavar='N', help='manual epoch number (useful on restarts)')
    parser.add_argument('--batch_size', default=128, type=int, metavar='N', help='train batchsize')
    parser.add_argument('--lr', default=0.1, type=float, metavar='LR', help='initial learning rate')
    parser.add_argument('--wd', default=5e-4, type=float, metavar='WD', help='weight decay')
    parser.add_argument('--lr_scheduler', type=str, default= 'MultiStep', choices=['MultiStep', 'Cosine', 'Cyclic'], help='learning rate scheduling')
    # store_false: AMP is on by default and passing --amp switches it OFF.
    parser.add_argument('--amp', default=True, action='store_false', help='use AMP (on by default; passing --amp turns it OFF)')

    ########################## basic setting ##########################
    parser.add_argument('--seed', type=int, default=0, help='seed')
    parser.add_argument('--gpu', default=0, type=int, help='CUDA device index passed to torch.cuda.set_device')
    parser.add_argument('--data_dir', default='/home/ydy0415/data/datasets', help='Directory of dataset')
    parser.add_argument('--save_dir', default='/home/ydy0415/data/experiments/decouple_robust/', help='Directory to output the result')
    parser.add_argument('--tolerance', default=150, type=int, metavar='N', help='tolerance')
    parser.add_argument('--finetune', action='store_true', help='fine-tuning mode : freezing feature extractor part, (default: off)')

    ######################### Dataset #############################
    parser.add_argument('--dataset', type=str, default= 'cifar10', choices=['cifar10', 'cifar100', 'fmnist', 'svhn', 'stl10'], help='benchmark dataset')
    parser.add_argument('--num_classes', default=10, type=int, help='the number of classes')

    ########################## attack setting ##########################
    parser.add_argument('--train_attack', default='pgd_linf', choices=['pgd_linf', 'gapgd_linf'], help='train-attack method')
    parser.add_argument('--perturb_loss', default='ce', choices=['ce','kl'], help='perturbation loss for adversarial examples')
    parser.add_argument('--eps', type=float, default=8, help= 'maximum of perturbation magnitude' )
    parser.add_argument('--train_numsteps', type=int, default=10, help= 'train PGD number of steps')
    parser.add_argument('--test_numsteps', type=int, default=10, help= 'test PGD number of steps')
    # store_false: random start is on by default and passing --random_start switches it OFF.
    parser.add_argument('--random_start', action='store_false', help='PGD random start (on by default; passing --random_start turns it OFF)')
    parser.add_argument('--bn_mode', metavar='BN', default='eval', choices=['eval', 'train'], help='batch normalization mode of attack')

    ########################## loss setting ##########################
    parser.add_argument('--loss', metavar='LOSS', default='pgd', choices=['pgd', 'trades', 'mart', 'fatat', 'arow', 'cow'], help='surrogate loss function to optimize')
    parser.add_argument('--smooth', type=float, default=0., help='alpha of label smoothing')
    parser.add_argument('--lamb', type=float, default=6., help='coefficient of rob_loss')
    parser.add_argument('--rho', type=float, default=0.005, help='perturbation budget of DeSAM')

    ########################## orthogonal setting ##########################
    parser.add_argument('--orthogonal', action='store_true', help='orthogonal usage loss flag (default: off)')
    parser.add_argument('--var', action='store_true', help='diag-variance loss usage flag (default: off)')
    parser.add_argument('--gamma', type=float, default=1., help='coefficient of orthogonal_loss or var_loss')
    parser.add_argument('--beta', type=float, default=1., help='multiplier of identity matrix')


    ########################## SWA setting ##########################
    parser.add_argument('--swa', action='store_true', help='swa usage flag (default: off)')
    parser.add_argument('--swa_start', type=float, default=51, metavar='N', help='SWA start epoch number (default: 51)')
    parser.add_argument('--swa_c_epochs', type=int, default=1, metavar='N', help='SWA model collection frequency/cycle length in epochs (default: 1)')

    ######################### add name #############################
    parser.add_argument('--add_name', default='', type=str, help='add_name')

    args = parser.parse_args()

    main(args)
    
    