
import torch
import torch.nn.functional as F
import os,pdb,yaml
import argparse
import torch
import torch.optim as optim
from torchvision.transforms import AutoAugmentPolicy
from models.wideresnet import *
from models.resnet import *
from attack import *
from pgd_attack import eval_adv_test_whitebox
from datasets.builder import build_datasets
from datasets.loader.build_loader import build_dataloader
from args import create_parser
from torchvision import datasets, transforms
from torch.utils.data import SubsetRandomSampler, DataLoader, TensorDataset
import pdb
import numpy as np
import shutil
from pathlib import Path
try:
    import wandb
except ImportError:
    wandb = None
# Fix the random seeds to 0 so results are reproducible across runs on the same computer:
# seeds the CPU and CUDA RNGs and requests deterministic cuDNN kernels.
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
torch.backends.cudnn.deterministic = True
from robustbench.utils import load_model
from autoattack import AutoAttack
import time
# Timestamp string for this run.
# NOTE(review): the format is year-month-day-minute ('%M' is minutes) with no hour ('%H') field — confirm intended.
start_time = time.strftime('%Y-%m-%d-%M', time.localtime(time.time()))
from status import ProgressBar
import torchattacks

def evaluate_pgd_ensemble(models,loader, args = None):
    """Evaluate a logit-averaging ensemble on clean and PGD-10 inputs.

    Each model in ``models`` is attacked separately with an L-inf PGD attack
    (eps=8/255, 10 steps, random start) built on that model alone. The clean
    logits of all models are averaged, the adversarial logits are averaged,
    and both averages are scored against the labels.

    Args:
        models: iterable of classifiers, each called as ``model(x)``. Their
            train/eval mode is not changed here.
        loader: iterable of ``(x, y)`` batches; both are moved with ``.cuda()``.
        args: namespace whose ``dataset`` attribute selects the number of
            classes: ``"cifar10"`` -> 10, ``"cifar100"`` -> 100. The legacy
            misspelling ``"cifa100"`` is still accepted.

    Returns:
        ``(clean_acc, head_clean_acc, tail_clean_acc, pgd_acc, head_rob_acc,
        tail_rob_acc)``. ``clean_acc`` and ``pgd_acc`` are Python floats in
        percent. The head/tail entries are 0-d tensors holding the mean
        per-class accuracy (percent) of the first two and the last two classes.
        A class with no samples gives NaN in its per-class accuracy.

    Raises:
        ValueError: if ``args.dataset`` is not one of the names above.
    """
    if args.dataset == "cifar10":
        num_classes = 10
    elif args.dataset in ("cifar100", "cifa100"):
        num_classes = 100
    else:
        raise ValueError(f"unsupported dataset for evaluate_pgd_ensemble: {args.dataset!r}")
    eps = 8/255

    models = list(models)
    # One attack per ensemble member, built once rather than once per batch.
    # NOTE(review): alpha=2/225 looks like a typo for the usual 2/255; kept so
    # results stay comparable with earlier runs.
    attacks = [
        torchattacks.PGD(model, eps=eps, alpha=2/225, steps=10, random_start=True)
        for model in models
    ]

    total_samples = 0
    clean_correct = 0
    pgd_correct = 0
    num_per_class = torch.zeros([num_classes])
    clean_accs = torch.zeros([num_classes])
    robust_accs = torch.zeros([num_classes])

    for x, y in loader:
        total_samples += x.shape[0]
        x, y = x.cuda(), y.cuda()
        clean_logits = []
        adv_logits = []
        for model, attack in zip(models, attacks):
            with torch.no_grad():
                clean_logits.append(model(x))
            x_pgd10 = attack(x, y)  # the attack itself needs gradients
            with torch.no_grad():
                adv_logits.append(model(x_pgd10))

        z = torch.mean(torch.stack(clean_logits), dim=0)
        z_pgd = torch.mean(torch.stack(adv_logits), dim=0)

        clean_hit = (z.argmax(dim=1) == y).cpu()
        pgd_hit = (z_pgd.argmax(dim=1) == y).cpu()
        clean_correct += int(clean_hit.sum())
        pgd_correct += int(pgd_hit.sum())

        y_cpu = y.cpu()
        for class_id in range(num_classes):
            mask = y_cpu == class_id
            num_per_class[class_id] += mask.sum()
            clean_accs[class_id] += clean_hit[mask].sum()
            robust_accs[class_id] += pgd_hit[mask].sum()

    clean_acc = clean_correct / total_samples * 100.0
    pgd_acc = pgd_correct / total_samples * 100.0
    clean_accs = clean_accs / num_per_class * 100
    robust_accs = robust_accs / num_per_class * 100

    head_clean_acc, tail_clean_acc = clean_accs[:2].mean(), clean_accs[-2:].mean()
    head_rob_acc, tail_rob_acc = robust_accs[:2].mean(), robust_accs[-2:].mean()

    print(clean_acc, head_clean_acc, tail_clean_acc, pgd_acc, head_rob_acc, tail_rob_acc)
    return clean_acc, head_clean_acc, tail_clean_acc, pgd_acc, head_rob_acc, tail_rob_acc



def evaluate_pgd_classwise(model,loader, args = None):
    """Return the per-class clean and PGD-10 accuracy of ``model`` in percent.

    The attack is L-inf PGD with eps=8/255, 10 steps and a random start.

    Args:
        model: classifier called as ``model(x)``; its mode is not changed here.
        loader: iterable of ``(x, y)`` batches; both are moved with ``.cuda()``.
        args: namespace whose ``dataset`` attribute selects the number of
            classes: ``"cifar10"`` -> 10, ``"cifar100"`` -> 100. The legacy
            misspelling ``"cifa100"`` is still accepted.

    Returns:
        ``(clean_accs, robust_accs)``: tensors of length ``num_classes`` with
        the per-class accuracy in percent. Classes without samples give NaN.

    Raises:
        ValueError: if ``args.dataset`` is not one of the names above.
    """
    if args.dataset == "cifar10":
        num_classes = 10
    elif args.dataset in ("cifar100", "cifa100"):
        num_classes = 100
    else:
        raise ValueError(f"unsupported dataset for evaluate_pgd_classwise: {args.dataset!r}")
    eps = 8/255
    # NOTE(review): alpha=2/225 looks like a typo for the usual 2/255; kept so
    # results stay comparable with earlier runs.
    pgd10_attack = torchattacks.PGD(model, eps=eps, alpha=2/225, steps=10, random_start=True)
    num_per_class = torch.zeros([num_classes])
    clean_accs = torch.zeros([num_classes])
    robust_accs = torch.zeros([num_classes])

    for x, y in loader:
        x, y = x.cuda(), y.cuda()
        x_pgd10 = pgd10_attack(x, y)
        with torch.no_grad():
            z_pgd = model(x_pgd10)
            z = model(x)
        clean_hit = (z.argmax(dim=1) == y).cpu()
        pgd_hit = (z_pgd.argmax(dim=1) == y).cpu()

        y_cpu = y.cpu()
        for class_id in range(num_classes):
            mask = y_cpu == class_id
            num_per_class[class_id] += mask.sum()
            clean_accs[class_id] += clean_hit[mask].sum()
            robust_accs[class_id] += pgd_hit[mask].sum()

    clean_accs = clean_accs / num_per_class * 100
    robust_accs = robust_accs / num_per_class * 100
    return clean_accs, robust_accs


def evaluate_pgd(model,loader, config = None):
    """Return the overall clean and PGD-10 accuracy of ``model`` in percent.

    Args:
        model: classifier; switched to eval mode here.
        loader: iterable of ``(x, y)`` batches; both are moved with ``.cuda()``.
        config: optional object with an ``eps`` attribute in 0-255 pixel units.
            It is divided by 255, the same convention as the other
            ``config``-based evaluators in this module. When omitted, eps
            defaults to 8/255.

    Returns:
        ``(clean_acc, pgd_acc)``: 0-d tensors with the accuracies in percent.
    """
    # The condition was previously inverted: config.eps was read only when
    # config was None, and a provided config was ignored.
    eps = 8/255 if config is None else config.eps/255
    model.eval()
    # NOTE(review): alpha=2/225 looks like a typo for the usual 2/255; kept so
    # results stay comparable with earlier runs.
    pgd10_attack = torchattacks.PGD(model, eps=eps, alpha=2/225, steps=10, random_start=True)

    total_samples = 0
    clean_acc = 0
    pgd_acc = 0
    for x, y in loader:
        total_samples += x.shape[0]
        x, y = x.cuda(), y.cuda()
        x_pgd10 = pgd10_attack(x, y)
        with torch.no_grad():
            z_pgd = model(x_pgd10)
            z = model(x)
        pgd_acc += (z_pgd.argmax(dim=1) == y).sum()
        clean_acc += (z.argmax(dim=1) == y).sum()

    clean_acc = clean_acc / total_samples * 100.0
    pgd_acc = pgd_acc / total_samples * 100.0
    return clean_acc, pgd_acc




def evaluate_fgsm_head_tail2(model,loader,ratio, config = None):
    """Measure clean and weak-PGD accuracy, averaged over all classes and over the tail.

    The weak attack is L-inf PGD with radius and step size ``config.eps / 255 / 4``
    (``config.eps`` in 0-255 pixel units), 10 steps and a random start.
    The "tail" is the last ``int(ratio)`` class indices.

    Args:
        model: classifier called as ``model(x)``; its mode is not changed here.
        loader: iterable of ``(x, y)`` batches; both are moved with ``.cuda()``.
        ratio: number of tail classes (converted with ``int``).
        config: object providing ``num_classes`` and ``eps``.

    Returns:
        ``(full_clean, tail_clean, full_weak_adv, tail_weak_adv)``: 0-d tensors
        with the mean per-class accuracy in percent.
    """
    n_cls = config.num_classes
    budget = config.eps / 255
    weak_attack = torchattacks.PGD(model, eps=budget / 4, alpha=budget / 4, steps=10, random_start=True)
    tail = int(ratio)

    seen = torch.zeros([n_cls])
    clean_hits = torch.zeros([n_cls])
    weak_hits = torch.zeros([n_cls])

    for images, labels in loader:
        images, labels = images.cuda(), labels.cuda()
        weak_images = weak_attack(images, labels)
        clean_pred = model(images).argmax(dim=1)
        weak_pred = model(weak_images).argmax(dim=1)

        for c in range(n_cls):
            sel = (labels == c).detach().cpu()
            seen[c] += sel.sum()
            clean_hits[c] += (clean_pred[sel] == labels[sel]).detach().cpu().sum()
            weak_hits[c] += (weak_pred[sel] == labels[sel]).detach().cpu().sum()

    clean_pct = clean_hits / seen * 100
    weak_pct = weak_hits / seen * 100

    return (
        clean_pct.mean(),
        clean_pct[-tail:].mean(),
        weak_pct.mean(),
        weak_pct[-tail:].mean(),
    )




def evaluate_fgsm_head_tail(model,loader,ratio, config = None):
    """Score ``model`` against three single-step random-start attacks, overall and for head/tail classes.

    With ``budget = config.eps / 255`` (``config.eps`` in 0-255 pixel units),
    the attacks are one-step ``torchattacks.PGD`` runs with radius = step size
    = ``budget`` ("strong"), ``budget / 4`` and ``budget / 2``. "Head" is the
    first ``int(ratio)`` class indices and "tail" is the last ``int(ratio)``.

    Args:
        model: classifier called as ``model(x)``; its mode is not changed here.
        loader: iterable of ``(x, y)`` batches; both are moved with ``.cuda()``.
        ratio: number of head/tail classes (converted with ``int``).
        config: object providing ``num_classes`` and ``eps``.

    Returns:
        ``(weak1_acc, head_weak1_acc, tail_weak1_acc, weak2_acc, head_weak2_acc,
        tail_weak2_acc, pgd_acc, head_rob_acc, tail_rob_acc)``.
        - ``weak1_*`` entries (``budget / 4`` attack) and ``weak2_*`` entries
          (``budget / 2`` attack) are 0-d tensors of mean per-class accuracy.
        - ``pgd_acc`` is the overall strong-attack accuracy as a Python float.
        - ``head_rob_acc`` / ``tail_rob_acc`` are 0-d tensors for the strong attack.
        All values are in percent. The clean and strong-attack summaries are
        also printed.
    """
    n_cls = config.num_classes
    budget = config.eps / 255
    tail = int(ratio)

    strong_attack = torchattacks.PGD(model, eps=budget, alpha=budget, steps=1, random_start=True)
    quarter_attack = torchattacks.PGD(model, eps=budget / 4, alpha=budget / 4, steps=1, random_start=True)
    half_attack = torchattacks.PGD(model, eps=budget / 2, alpha=budget / 2, steps=1, random_start=True)

    seen, clean_hits, strong_hits, quarter_hits, half_hits = (torch.zeros([n_cls]) for _ in range(5))
    n_seen = 0
    clean_total = 0
    strong_total = 0

    for images, labels in loader:
        n_seen += images.shape[0]
        images, labels = images.cuda(), labels.cuda()

        strong_logits = model(strong_attack(images, labels))
        clean_logits = model(images)
        strong_total += (strong_logits.argmax(dim=1) == labels).sum()
        clean_total += (clean_logits.argmax(dim=1) == labels).sum()

        quarter_images = quarter_attack(images, labels)
        half_images = half_attack(images, labels)
        quarter_logits = model(quarter_images)
        half_logits = model(half_images)

        for c in range(n_cls):
            sel = (labels == c).detach().cpu()
            target = labels[sel]
            seen[c] += sel.sum()
            clean_hits[c] += (clean_logits[sel].argmax(dim=1) == target).detach().cpu().sum()
            strong_hits[c] += (strong_logits[sel].argmax(dim=1) == target).detach().cpu().sum()
            quarter_hits[c] += (quarter_logits[sel].argmax(dim=1) == target).detach().cpu().sum()
            half_hits[c] += (half_logits[sel].argmax(dim=1) == target).detach().cpu().sum()

    overall_clean = (clean_total / n_seen * 100.0).detach().cpu().item()
    overall_strong = (strong_total / n_seen * 100.0).detach().cpu().item()
    clean_pct = clean_hits / seen * 100
    strong_pct = strong_hits / seen * 100
    quarter_pct = quarter_hits / seen * 100
    half_pct = half_hits / seen * 100

    print(overall_clean, clean_pct[:tail].mean(), clean_pct[-tail:].mean(),
          overall_strong, strong_pct[:tail].mean(), strong_pct[-tail:].mean())
    return (
        quarter_pct.mean(), quarter_pct[:tail].mean(), quarter_pct[-tail:].mean(),
        half_pct.mean(), half_pct[:tail].mean(), half_pct[-tail:].mean(),
        overall_strong, strong_pct[:tail].mean(), strong_pct[-tail:].mean(),
    )



def evaluate_pgd_head_tail(model,loader,ratio, config = None):
    """Score ``model`` on clean and PGD-10 inputs, overall and for head/tail classes.

    The attack is L-inf PGD with ``eps = config.eps / 255`` (``config.eps`` in
    0-255 pixel units), step size ``eps / 4``, 10 steps and a random start.
    "Head" is the first ``int(ratio)`` class indices and "tail" is the last
    ``int(ratio)``.

    Args:
        model: classifier called as ``model(x)``; its mode is not changed here.
        loader: iterable of ``(x, y)`` batches; both are moved with ``.cuda()``.
        ratio: number of head/tail classes (converted with ``int``).
        config: object providing ``num_classes`` and ``eps``.

    Returns:
        ``(clean_acc, head_clean_acc, tail_clean_acc, pgd_acc, head_rob_acc,
        tail_rob_acc, clean_accs, robust_accs)``. The first six are Python
        floats in percent and are also printed. The last two are per-class
        accuracy tensors in percent.
    """
    n_cls = config.num_classes
    budget = config.eps / 255
    attack = torchattacks.PGD(model, eps=budget, alpha=budget / 4, steps=10, random_start=True)
    k = int(ratio)

    per_class_seen = torch.zeros([n_cls])
    per_class_clean = torch.zeros([n_cls])
    per_class_robust = torch.zeros([n_cls])
    n_seen = 0
    clean_hits = 0
    robust_hits = 0

    for images, labels in loader:
        n_seen += images.shape[0]
        images, labels = images.cuda(), labels.cuda()
        adv_logits = model(attack(images, labels))
        clean_logits = model(images)
        robust_hits += (adv_logits.argmax(dim=1) == labels).sum()
        clean_hits += (clean_logits.argmax(dim=1) == labels).sum()

        for c in range(n_cls):
            sel = (labels == c).detach().cpu()
            per_class_seen[c] += sel.sum()
            per_class_clean[c] += (clean_logits[sel].argmax(dim=1) == labels[sel]).detach().cpu().sum()
            per_class_robust[c] += (adv_logits[sel].argmax(dim=1) == labels[sel]).detach().cpu().sum()

    def _as_float(t):
        # Mean of a (possibly sliced) tensor as a plain Python float.
        return t.mean().detach().cpu().item()

    clean_acc = (clean_hits / n_seen * 100.0).detach().cpu().item()
    pgd_acc = (robust_hits / n_seen * 100.0).detach().cpu().item()
    clean_accs = per_class_clean / per_class_seen * 100
    robust_accs = per_class_robust / per_class_seen * 100

    head_clean_acc = _as_float(clean_accs[:k])
    tail_clean_acc = _as_float(clean_accs[-k:])
    head_rob_acc = _as_float(robust_accs[:k])
    tail_rob_acc = _as_float(robust_accs[-k:])

    print(clean_acc, head_clean_acc, tail_clean_acc, pgd_acc, head_rob_acc, tail_rob_acc)
    return clean_acc, head_clean_acc, tail_clean_acc, pgd_acc, head_rob_acc, tail_rob_acc, clean_accs, robust_accs

def PGD_margin(images, labels, model, eps=8/255, alpha=2/225, steps=10, random_start=True, device="cuda"):
    """L-inf PGD that stops moving a sample once the model misclassifies it.

    At every iteration, only samples the model still classifies correctly take
    a signed-gradient step on the cross-entropy loss. Samples already
    misclassified keep their current perturbation. The loop ends early once no
    sample is classified correctly. Gradients are taken with the model in train
    mode, except BatchNorm and Dropout modules, which are put in eval mode.

    Args:
        images: 4-D input batch with values in ``[0, 1]``.
        labels: class indices for ``images``.
        model: classifier called as ``model(images)``.
        eps: L-inf radius of the perturbation.
        alpha: step size of each signed-gradient step.
            NOTE(review): the default 2/225 looks like a typo for 2/255; kept
            for backward compatibility.
        steps: maximum number of iterations.
        random_start: start from a uniform random point in the eps-ball.
        device: device the inputs are moved to (default ``"cuda"``).

    Returns:
        A detached adversarial batch on ``device``, within the eps-ball around
        ``images`` and clamped to ``[0, 1]``.

    Note:
        ``model.train()`` is called before returning, so the whole model
        (including BatchNorm/Dropout) is left in train mode.
    """
    model.train()
    # Keep normalization statistics and dropout fixed while crafting the attack.
    for module in model.modules():
        name = module.__class__.__name__
        if 'BatchNorm' in name or 'Dropout' in name:
            module.eval()

    images = images.clone().detach().to(device)
    labels = labels.clone().detach().to(device)

    loss = nn.CrossEntropyLoss()
    adv_images = images.clone().detach()

    if random_start:
        adv_images = adv_images + torch.empty_like(adv_images).uniform_(-eps, eps)
        adv_images = torch.clamp(adv_images, min=0, max=1).detach()

    for _ in range(steps):
        adv_images.requires_grad = True
        outputs = model(adv_images)
        still_correct = outputs.argmax(dim=1) == labels
        if not still_correct.any():
            break

        cost = loss(outputs, labels)
        grad = torch.autograd.grad(cost, adv_images,
                                   retain_graph=False, create_graph=False)[0]

        stepped = adv_images.detach() + alpha * grad.sign()
        # Already-misclassified samples keep their current perturbation.
        adv_images = torch.where(still_correct.reshape([-1, 1, 1, 1]), stepped, adv_images.detach())
        delta = torch.clamp(adv_images - images, min=-eps, max=eps)
        adv_images = torch.clamp(images + delta, min=0, max=1).detach()

    model.train()

    # After an early break, adv_images still has requires_grad=True; detach it.
    return adv_images.detach()


def PGD(images, labels, model, eps=8/255, alpha=2/225, steps=10, random_start=True, device="cuda"):
    """Untargeted L-inf PGD on the cross-entropy loss.

    Gradients are taken with the model in train mode, except BatchNorm and
    Dropout modules, which are put in eval mode.

    Args:
        images: input batch with values in ``[0, 1]``.
        labels: class indices for ``images``.
        model: classifier called as ``model(images)``.
        eps: L-inf radius of the perturbation.
        alpha: step size of each signed-gradient step.
            NOTE(review): the default 2/225 looks like a typo for 2/255; kept
            for backward compatibility.
        steps: number of PGD iterations.
        random_start: start from a uniform random point in the eps-ball.
        device: device the inputs are moved to (default ``"cuda"``, i.e. the
            current CUDA device).

    Returns:
        A detached adversarial batch on ``device``, within the eps-ball around
        ``images`` and clamped to ``[0, 1]``.

    Note:
        ``model.train()`` is called before returning, so the whole model
        (including BatchNorm/Dropout) is left in train mode regardless of its
        previous mode. Callers that evaluate afterwards must call
        ``model.eval()`` themselves.
    """
    model.train()
    # Keep normalization statistics and dropout fixed while crafting the attack.
    for module in model.modules():
        name = module.__class__.__name__
        if 'BatchNorm' in name or 'Dropout' in name:
            module.eval()

    images = images.clone().detach().to(device)
    labels = labels.clone().detach().to(device)

    loss = nn.CrossEntropyLoss()
    adv_images = images.clone().detach()

    if random_start:
        adv_images = adv_images + torch.empty_like(adv_images).uniform_(-eps, eps)
        adv_images = torch.clamp(adv_images, min=0, max=1).detach()

    for _ in range(steps):
        adv_images.requires_grad = True
        cost = loss(model(adv_images), labels)

        grad = torch.autograd.grad(cost, adv_images,
                                   retain_graph=False, create_graph=False)[0]

        adv_images = adv_images.detach() + alpha * grad.sign()
        # Project back into the eps-ball, then into the valid pixel range.
        delta = torch.clamp(adv_images - images, min=-eps, max=eps)
        adv_images = torch.clamp(images + delta, min=0, max=1).detach()

    model.train()

    return adv_images




def calcualte_second_term_SAT(model, model2,  test_loader, num_class = 100):
    """Measure how far the loss departs from its first-order Taylor expansion under random noise.

    For each sample, a random perturbation with entries drawn uniformly from
    ``[-8/255, 8/255]`` is added and the perturbed image is clamped to
    ``[0, 1]``. The residual ``|CE(x_adv) - CE(x) - <grad_x CE(x), x_adv - x>|``
    is then accumulated, split by whether the perturbed sample is still
    classified correctly ("robust") or not.

    Args:
        model: classifier under analysis; switched to eval mode.
        model2: unused; kept for call-site compatibility.
        test_loader: iterable of ``(x, y)`` batches; both are moved with ``.cuda()``.
        num_class: number of classes for the per-class statistics (default 100).

    Returns:
        ``(total_rate, robust_rate, nonrobust_rate, total_rates_per,
        total_rates_robust_per, total_rates_nonrobust_per)``, all 0-d tensors
        scaled by 100:
        - ``*_rate``: summed residual divided by summed perturbed loss, over
          all / robust / non-robust samples.
        - ``total_rates_per``: mean per-sample residual.
        - ``total_rates_(non)robust_per``: mean per-sample residual-to-loss
          ratio. Infinite or NaN ratios count as 0.

    Also prints the clean accuracy, the mean per-class residual of the first
    and last five classes, and the perturbed accuracy of those classes.
    """
    model.eval()
    XENT_loss = nn.CrossEntropyLoss(reduction = 'none')
    eps = 8/255

    robust_samples = 0
    non_robust_samples = 0
    second_total = second_robust = second_nonrobust = 0
    adv_loss = adv_loss_robust = adv_loss_nonrobust = 0
    total_rates_per = total_rates_robust_per = total_rates_nonrobust_per = 0
    class_acc = 0

    rates_per_class = torch.zeros([num_class])
    num_per_class = torch.zeros([num_class])
    robust_accs = torch.zeros([num_class])

    for x, y in test_loader:
        x, y = x.cuda(), y.cuda()
        # Uniform noise in [-eps, eps]. Clamp the perturbed image, not the noise,
        # so negative noise components are kept and x_adv stays in [0, 1].
        x_adv = torch.clip(torch.rand_like(x) * 2 * eps - eps + x, 0, 1)

        with torch.no_grad():
            out_adv = model(x_adv)
        cur_robust_idx = (out_adv.argmax(dim=1) == y).cpu().long()
        robust_mask = cur_robust_idx == 1
        nonrobust_mask = cur_robust_idx == 0

        x.requires_grad = True
        out = model(x)
        class_acc += (out.argmax(dim=1) == y).detach().cpu().sum()

        clean_loss_per = XENT_loss(out, y)
        # Only the first-order term is needed, so no higher-order graph is built.
        grad = torch.autograd.grad(clean_loss_per.sum(), [x])[0]
        delta = (x_adv - x).detach()
        adv_loss_gpu = XENT_loss(out_adv, y)
        remain = torch.abs(
            adv_loss_gpu - clean_loss_per.detach() - (grad * delta).sum(dim=(1, 2, 3))
        ).detach().cpu()
        adv_loss_per = adv_loss_gpu.cpu()

        robust_samples += robust_mask.sum()
        non_robust_samples += nonrobust_mask.sum()

        second_total += remain.sum()
        second_robust += remain[robust_mask].sum()
        second_nonrobust += remain[nonrobust_mask].sum()

        adv_loss += adv_loss_per.sum()
        adv_loss_robust += adv_loss_per[robust_mask].sum()
        adv_loss_nonrobust += adv_loss_per[nonrobust_mask].sum()

        # Residual-to-loss ratios; x/0 and 0/0 are mapped to 0.
        total_rate_per = torch.nan_to_num(remain, posinf=0, neginf=0)
        total_rate_robust_per = torch.nan_to_num(
            remain[robust_mask] / adv_loss_per[robust_mask], posinf=0, neginf=0)
        total_rate_nonrobust_per = torch.nan_to_num(
            remain[nonrobust_mask] / adv_loss_per[nonrobust_mask], posinf=0, neginf=0)

        total_rates_per += total_rate_per.sum()
        total_rates_robust_per += total_rate_robust_per.sum()
        total_rates_nonrobust_per += total_rate_nonrobust_per.sum()

        y_cpu = y.cpu()
        adv_hit = (out_adv.argmax(dim=1) == y).cpu()
        for class_id in range(num_class):
            mask = y_cpu == class_id
            num_per_class[class_id] += mask.sum()
            rates_per_class[class_id] += total_rate_per[mask].sum()
            robust_accs[class_id] += adv_hit[mask].sum()

    n_samples = robust_samples + non_robust_samples
    class_acc = class_acc / n_samples * 100
    total_rate = second_total / adv_loss * 100
    robust_rate = second_robust / adv_loss_robust * 100
    nonrobust_rate = second_nonrobust / adv_loss_nonrobust * 100

    total_rates_per = total_rates_per / n_samples * 100
    total_rates_robust_per = total_rates_robust_per / robust_samples * 100
    total_rates_nonrobust_per = total_rates_nonrobust_per / non_robust_samples * 100

    rates_per_class = rates_per_class / num_per_class * 100

    print(class_acc)

    robust_accs = robust_accs / num_per_class
    print(rates_per_class[:5].mean().item(), rates_per_class[-5:].mean().item())

    head_acc = robust_accs[:5].mean().item() * 100
    tail_acc = robust_accs[-5:].mean().item() * 100
    print("head acc %.2f%%" % head_acc, "tail acc %.2f%%" % tail_acc)

    return total_rate, robust_rate, nonrobust_rate, total_rates_per, total_rates_robust_per, total_rates_nonrobust_per





def calcualte_second_term_SAT2(model1, model2,  test_loader, num_class = 100):
    """Compare, per class, how strongly two models' logits move under PGD.

    Each model is attacked separately with this module's ``PGD`` (eps=8/255,
    alpha=2/255, 20 steps, no random start). Its per-sample "margin" is the
    mean over logits of the squared difference between adversarial and clean
    logits. ``distance`` is the squared difference between the two models'
    margins.

    Args:
        model1, model2: classifiers; both are scored in eval mode.
        test_loader: iterable of ``(x, y)`` batches; both are moved with ``.cuda()``.
        num_class: number of classes (default 100).

    Returns:
        ``(sensitivity_per_class, distance_per_class, clean_accs, robust_accs)``,
        each a tensor of length ``num_class``:
        - ``sensitivity_per_class``: mean margin of ``model2``.
        - ``distance_per_class``: mean margin distance.
        - ``clean_accs`` / ``robust_accs``: clean and PGD accuracy of
          ``model2``, in percent.
        Classes without samples give NaN.
    """
    model1.eval()
    model2.eval()
    mse = torch.nn.MSELoss(reduction = 'none')
    eps = 8/255
    alpha = 2/255

    num_per_class = torch.zeros([num_class])
    distance_per_class = torch.zeros([num_class])
    sensitivity_per_class = torch.zeros([num_class])
    clean_accs = torch.zeros([num_class])
    robust_accs = torch.zeros([num_class])

    for x, y in test_loader:
        x, y = x.cuda(), y.cuda()

        x_adv = PGD(x, y, model1, eps=eps, steps=20, alpha=alpha, random_start=False)
        # PGD() leaves the model in train mode. Restore eval mode so scoring does
        # not update BatchNorm running statistics or apply dropout.
        model1.eval()
        with torch.no_grad():
            out_adv = model1(x_adv)
            out = model1(x)
            margin1 = mse(out_adv, out).mean(dim=1)

        x_adv2 = PGD(x, y, model2, eps=eps, steps=20, alpha=alpha, random_start=False)
        model2.eval()
        with torch.no_grad():
            out_adv = model2(x_adv2)
            out = model2(x)
            margin2 = mse(out_adv, out).mean(dim=1)
            distance = mse(margin1.reshape([-1, 1]), margin2.reshape([-1, 1])).sum(dim=1)

        robust_hit = (out_adv.argmax(dim=1) == y).cpu()
        clean_hit = (out.argmax(dim=1) == y).cpu()
        distance = distance.cpu()
        margin2 = margin2.cpu()

        y_cpu = y.cpu()
        for class_id in range(num_class):
            mask = y_cpu == class_id
            num_per_class[class_id] += mask.sum()
            distance_per_class[class_id] += distance[mask].sum()
            clean_accs[class_id] += clean_hit[mask].sum()
            robust_accs[class_id] += robust_hit[mask].sum()
            sensitivity_per_class[class_id] += margin2[mask].sum()

    clean_accs = clean_accs / num_per_class * 100
    robust_accs = robust_accs / num_per_class * 100
    distance_per_class = distance_per_class / num_per_class
    sensitivity_per_class = sensitivity_per_class / num_per_class

    return sensitivity_per_class, distance_per_class, clean_accs, robust_accs





def calcualte_decision_boudnary(model,  test_loader):
    """Interactive debugging aid comparing an iterative boundary search with a 2-step PGD.

    For every correctly classified sample, signed-gradient ascent steps are
    taken on the cross-entropy (step 1/255, L-inf radius 8/255, at most 100
    steps) until the prediction flips. Then a 2-step ``PGD`` (alpha=2/255,
    eps=8/255) is run on the same sample. The function prints
    ``||clean - pgd|| / ||clean - boundary||`` for the logits and stops in
    ``pdb``. Here "boundary" is the logits at the first misclassified iterate,
    or at the last evaluated iterate if the prediction never flips. Nothing
    is returned.

    NOTE: the correctness checks use a tensor comparison as a Python bool, so
    ``test_loader`` must yield batches of size 1.
    """
    model.eval()

    alpha = 1/255
    steps = 100
    eps = 8/255
    loss = nn.CrossEntropyLoss()

    for x, y in test_loader:
        model.eval()
        x = x.clone().detach().cuda()
        y = y.clone().detach().cuda()
        x_adv = x.clone().detach()

        clean_out = model(x)
        if not (clean_out.argmax(dim=1) == y):
            continue

        for step in range(steps):
            x_adv.requires_grad = True
            margin_outputs = model(x_adv)

            # Stop at the first iterate the model misclassifies.
            if margin_outputs.argmax(dim=1) != y:
                break

            cost = loss(margin_outputs, y)
            grad = torch.autograd.grad(cost, x_adv,
                                       retain_graph=False, create_graph=False)[0]

            x_adv = x_adv.detach() + alpha * grad.sign()
            delta = torch.clamp(x_adv - x, min=-eps, max=eps)
            x_adv = torch.clamp(x + delta, min=0, max=1).detach()

        if step > 0:
            x_PGD = PGD(x, y, model, steps=2, alpha=2/255, eps=8/255)
            # PGD() leaves the model in train mode. Score in eval mode so
            # BatchNorm statistics are not updated and dropout is off.
            model.eval()
            pgd_out = model(x_PGD)
            d1 = ((clean_out - margin_outputs)**2).sum().sqrt()
            d2 = ((clean_out - pgd_out)**2).sum().sqrt()

            print(d2/d1)
            pdb.set_trace()
              
     


  
  
  
  
def calcualte_second_term_SAT4(teacher, student,  test_loader, num_class = 100):
    """Per-class comparison of normalized teacher/student logit shifts under TPGD.

    Accuracy part: the student is attacked with this module's ``PGD``
    (eps=alpha=8/255, 20 steps, no random start). Clean and PGD accuracy are
    accumulated over *all* samples.

    Margin part, on samples the student classifies correctly: each model is
    attacked with its own ``torchattacks.TPGD`` (10 steps). Its normalized
    logit shift ``(f(x_adv) - f(x)) / ||f(x)||`` is then computed.

    Args:
        teacher, student: classifiers; both are scored in eval mode.
        test_loader: iterable of ``(x, y)`` batches; both are moved with ``.cuda()``.
        num_class: number of classes (default 100).

    Returns:
        ``(sensitivity_per_class, distance_per_class, clean_accs, robust_accs)``,
        tensors of length ``num_class``:
        - sensitivity: mean, over correctly classified samples, of the sum over
          logits of the student's normalized shift.
        - distance: mean, over the same samples, of the squared L2 distance
          between the teacher's and the student's normalized shifts.
        - clean_accs / robust_accs: student accuracy in percent over all
          samples of the class.
        Classes with no contributing samples report 0.
    """
    teacher.eval()
    student.eval()
    mse = torch.nn.MSELoss(reduction = 'none')
    eps = 8/255
    alpha = 8/255

    TPGD_atttack_S = torchattacks.TPGD(student, steps = 10)
    TPGD_atttack_T = torchattacks.TPGD(teacher, steps = 10)

    # Samples per class: the denominator of the accuracies.
    total_per_class = torch.zeros([num_class])
    # Correctly classified samples per class: the denominator of the margin statistics.
    num_per_class = torch.zeros([num_class])
    distance_per_class = torch.zeros([num_class])
    sensitivity_per_class = torch.zeros([num_class])
    clean_accs = torch.zeros([num_class])
    robust_accs = torch.zeros([num_class])

    for x, y in test_loader:
        x, y = x.cuda(), y.cuda()

        x_pgd = PGD(x, y, student, eps=eps, steps=20, alpha=alpha, random_start=False)
        student.eval()  # PGD() leaves the model in train mode
        with torch.no_grad():
            out_pgd_S = student(x_pgd)
            out = student(x)
        success_idx = out.argmax(dim=1) == y
        robust_hit = (out_pgd_S.argmax(dim=1) == y).cpu()
        clean_hit = success_idx.cpu()

        y_cpu = y.cpu()
        for class_id in range(num_class):
            mask = y_cpu == class_id
            total_per_class[class_id] += mask.sum()
            clean_accs[class_id] += clean_hit[mask].sum()
            robust_accs[class_id] += robust_hit[mask].sum()

        # Margin statistics only over the samples the student classifies correctly.
        if not success_idx.any():
            continue
        x, y = x[success_idx], y[success_idx]

        x_adv_S = TPGD_atttack_S(x)
        x_adv_t = TPGD_atttack_T(x)
        with torch.no_grad():
            out_adv_S = student(x_adv_S)
            out_S = student(x)
            margin_s = (out_adv_S - out_S) / out_S.norm(dim=1).reshape([-1, 1])
            out_adv_T = teacher(x_adv_t)
            out_T = teacher(x)
            margin_t = (out_adv_T - out_T) / out_T.norm(dim=1).reshape([-1, 1])
            distance = mse(margin_t, margin_s).sum(dim=1)

        distance = distance.cpu()
        margin_s = margin_s.cpu()
        y_cpu = y.cpu()
        for class_id in range(num_class):
            mask = y_cpu == class_id
            num_per_class[class_id] += mask.sum()
            distance_per_class[class_id] += distance[mask].sum()
            sensitivity_per_class[class_id] += margin_s[mask].sum()

    # Avoid 0/0 for classes without contributing samples; their statistics stay 0.
    total_per_class = total_per_class.clamp(min=1)
    num_per_class = num_per_class.clamp(min=1)
    clean_accs = clean_accs / total_per_class * 100
    robust_accs = robust_accs / total_per_class * 100
    distance_per_class = distance_per_class / num_per_class
    sensitivity_per_class = sensitivity_per_class / num_per_class
    return sensitivity_per_class, distance_per_class, clean_accs, robust_accs



  
  
  
  
def calcualte_second_term_SAT3(teacher, student,  test_loader, num_class = 100):
    """Per-class comparison of relative teacher/student logit sensitivity under TPGD.

    Accuracy part: the student is attacked with this module's ``PGD``
    (eps=alpha=8/255, 20 steps, no random start). Clean and PGD accuracy are
    accumulated over *all* samples.

    Margin part, on samples the student classifies correctly: each model is
    attacked with its own ``torchattacks.TPGD`` (10 steps). Its scalar
    sensitivity ``||f(x_adv) - f(x)||^2 / ||f(x)||^2`` is then computed.

    Args:
        teacher, student: classifiers; both are scored in eval mode.
        test_loader: iterable of ``(x, y)`` batches; both are moved with ``.cuda()``.
        num_class: number of classes (default 100).

    Returns:
        ``(sensitivity_per_class, distance_per_class, clean_accs, robust_accs)``,
        tensors of length ``num_class``:
        - sensitivity: mean student sensitivity over correctly classified
          samples.
        - distance: mean squared difference between teacher and student
          sensitivities over the same samples.
        - clean_accs / robust_accs: student accuracy in percent over all
          samples of the class.
        Classes with no contributing samples report 0.
    """
    teacher.eval()
    student.eval()
    mse = torch.nn.MSELoss(reduction = 'none')
    eps = 8/255
    alpha = 8/255

    TPGD_atttack_S = torchattacks.TPGD(student, steps = 10)
    TPGD_atttack_T = torchattacks.TPGD(teacher, steps = 10)

    # Samples per class: the denominator of the accuracies.
    total_per_class = torch.zeros([num_class])
    # Correctly classified samples per class: the denominator of the margin statistics.
    num_per_class = torch.zeros([num_class])
    distance_per_class = torch.zeros([num_class])
    sensitivity_per_class = torch.zeros([num_class])
    clean_accs = torch.zeros([num_class])
    robust_accs = torch.zeros([num_class])

    for x, y in test_loader:
        x, y = x.cuda(), y.cuda()

        x_pgd = PGD(x, y, student, eps=eps, steps=20, alpha=alpha, random_start=False)
        student.eval()  # PGD() leaves the model in train mode
        with torch.no_grad():
            out_pgd_S = student(x_pgd)
            out = student(x)
        success_idx = out.argmax(dim=1) == y
        robust_hit = (out_pgd_S.argmax(dim=1) == y).cpu()
        clean_hit = success_idx.cpu()

        y_cpu = y.cpu()
        for class_id in range(num_class):
            mask = y_cpu == class_id
            total_per_class[class_id] += mask.sum()
            clean_accs[class_id] += clean_hit[mask].sum()
            robust_accs[class_id] += robust_hit[mask].sum()

        # Margin statistics only over the samples the student classifies correctly.
        if not success_idx.any():
            continue
        x, y = x[success_idx], y[success_idx]

        x_adv_S = TPGD_atttack_S(x)
        with torch.no_grad():
            out_adv_S = student(x_adv_S)
            out_S = student(x)
            margin_s = ((out_adv_S - out_S)**2).sum(dim=1) / (out_S**2).sum(dim=1)

        x_adv_t = TPGD_atttack_T(x)
        with torch.no_grad():
            out_adv_T = teacher(x_adv_t)
            out_T = teacher(x)
            margin_t = ((out_adv_T - out_T)**2).sum(dim=1) / (out_T**2).sum(dim=1)
            distance = mse(margin_t.reshape([-1, 1]), margin_s.reshape([-1, 1])).sum(dim=1)

        distance = distance.cpu()
        margin_s = margin_s.cpu()
        y_cpu = y.cpu()
        for class_id in range(num_class):
            mask = y_cpu == class_id
            num_per_class[class_id] += mask.sum()
            distance_per_class[class_id] += distance[mask].sum()
            sensitivity_per_class[class_id] += margin_s[mask].sum()

    # Avoid 0/0 for classes without contributing samples; their statistics stay 0.
    total_per_class = total_per_class.clamp(min=1)
    num_per_class = num_per_class.clamp(min=1)
    clean_accs = clean_accs / total_per_class * 100
    robust_accs = robust_accs / total_per_class * 100
    distance_per_class = distance_per_class / num_per_class
    sensitivity_per_class = sensitivity_per_class / num_per_class
    return sensitivity_per_class, distance_per_class, clean_accs, robust_accs



def eval_PGD(model,loader, num_classes = 10):
    """Evaluate clean and PGD-20 accuracy, overall and per group of classes.

    The attack is this module's ``PGD`` with its default eps/alpha and 20
    steps. Class indices are split into consecutive groups of
    ``num_classes // 10`` classes (at least one class per group). This gives
    ten groups when ``num_classes`` is a multiple of 10, and one group per
    class when ``num_classes <= 10``.

    Args:
        model: classifier; scored in eval mode.
        loader: iterable of ``(X, y)`` batches; ``X`` is cast to float and both
            are moved with ``.cuda()``.
        num_classes: number of classes (default 10).

    Returns:
        ``(acc, rob, mean_class_correct, mean_class_correct_adv)``:
        - ``acc`` / ``rob``: overall clean and robust accuracy as fractions.
        - ``mean_class_correct`` / ``mean_class_correct_adv``: for each group,
          the mean per-class clean / robust accuracy (fraction, rounded to 4
          decimals).
        Per-class accuracy is normalized by that class's own sample count, so
        imbalanced loaders are handled. Classes without samples count as 0.
    """
    model.eval()
    class_total = np.zeros(num_classes)
    class_correct = np.zeros(num_classes)
    class_correct_adv = np.zeros(num_classes)

    for X, y in loader:
        X = X.float().cuda()
        y = y.cuda()
        inputs_adv = PGD(X, y, model, steps=20)
        model.eval()  # PGD() leaves the model in train mode
        with torch.no_grad():
            hit = (model(X).argmax(dim=1) == y).cpu().numpy()
            hit_adv = (model(inputs_adv).argmax(dim=1) == y).cpu().numpy()

        labels = y.cpu().numpy()
        # Unbuffered scatter-add: repeated labels within a batch all count.
        np.add.at(class_total, labels, 1)
        np.add.at(class_correct, labels, hit.astype(float))
        np.add.at(class_correct_adv, labels, hit_adv.astype(float))

    n_total = class_total.sum()
    acc = class_correct.sum() / n_total
    rob = class_correct_adv.sum() / n_total

    per_class = class_correct / np.maximum(class_total, 1)
    per_class_adv = class_correct_adv / np.maximum(class_total, 1)
    group = max(1, num_classes // 10)
    mean_class_correct = [round(per_class[i:i + group].mean(), 4) for i in range(0, num_classes, group)]
    mean_class_correct_adv = [round(per_class_adv[i:i + group].mean(), 4) for i in range(0, num_classes, group)]
    return acc, rob, mean_class_correct, mean_class_correct_adv






def add_reg_train(model,
                    x_natural,
                    y,
                    optimizer,
                    step_size=0.003,
                    epsilon=0.031,
                    perturb_steps=10,
                    num_classes = 10,
                    distance='l_inf',
                    beta_adv=1.0,
                    beta_nat=0.0,
                    beta_reg=6.0,
                    reg_opt="trades",
                    **kwargs):
    """Craft L-inf PGD examples and return the regularised adversarial training loss.

    Inner maximisation (model in eval mode): start from a uniform random point in
    the ``epsilon`` L-inf ball around ``x_natural`` (clipped to [0, 1]), then take
    ``perturb_steps`` signed-gradient steps of size ``step_size`` on
    ``model.adv_loss``. Each step is projected back into the ball and [0, 1].

    Outer loss (model in train mode), returned as a scalar tensor::

        beta_adv * model.loss(f(x_adv), y)
        + beta_nat * model.nat_loss(f(x_natural), y)    # if model.nat_loss is not None
        + beta_reg * R(f(x_natural), f(x_adv))          # if reg_opt is not None

    Supported ``reg_opt`` values for R:

    - ``'trades'``: summed KL(p_nat || p_adv) / batch.
    - ``'mart'``: per-sample KL weighted by (1 - p_nat[y]), summed / batch.
    - ``'ALP'``: L2 distance between probability vectors / batch.
    - ``'ALP_2'``: L2 distance between logits / batch.
    - ``'norm'``: L2 norm of natural logits / batch.

    ``optimizer.zero_grad()`` is called here; ``backward()`` and ``step()`` are
    left to the caller. ``num_classes`` and ``**kwargs`` are accepted for
    interface compatibility and are unused.

    Raises:
        NotImplementedError: for ``distance != 'l_inf'`` or an unknown ``reg_opt``.
    """
    model.eval()
    batch_size = len(x_natural)

    outer_adv_loss = model.loss
    inner_loss = model.adv_loss
    outer_nat_loss = model.nat_loss

    if distance != 'l_inf':
        raise NotImplementedError

    # Generate adversarial examples; the attack never needs gradients w.r.t. x_natural.
    x_nat = x_natural.detach()
    # Symmetric random start in the eps-ball, on the same device/dtype as the input.
    noise = torch.empty_like(x_nat).uniform_(-epsilon, epsilon)
    x_adv = torch.clamp(x_nat + noise, 0.0, 1.0)
    for _ in range(perturb_steps):
        x_adv.requires_grad_()
        with torch.enable_grad():
            loss_inner = inner_loss(model(x_adv), y)
        grad = torch.autograd.grad(loss_inner, [x_adv])[0]
        x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
        x_adv = torch.min(torch.max(x_adv, x_nat - epsilon), x_nat + epsilon)
        x_adv = torch.clamp(x_adv, 0.0, 1.0)

    model.train()
    x_adv = torch.clamp(x_adv, 0.0, 1.0).detach()
    optimizer.zero_grad()

    # One forward pass per input; reused by every loss term below.
    logits_adv = model(x_adv)
    loss = outer_adv_loss(logits_adv, y) * beta_adv

    needs_nat = outer_nat_loss is not None or reg_opt is not None
    logits_nat = model(x_natural) if needs_nat else None

    if outer_nat_loss is not None:
        loss = loss + outer_nat_loss(logits_nat, y) * beta_nat

    if reg_opt is None:
        return loss

    if reg_opt in ('trades', 'mart', 'ALP'):
        nat_probs = F.softmax(logits_nat, dim=1)
        adv_probs = F.softmax(logits_adv, dim=1)

    if reg_opt == 'trades':
        kl = F.kl_div(torch.log(adv_probs + 1e-12), nat_probs, reduction='sum')
        loss = loss + (1.0 / batch_size) * kl * beta_reg
    elif reg_opt == 'mart':
        # Per-sample KL weighted by how unconfident the model is on the true class;
        # summing first keeps the loss a scalar.
        kl_per_sample = F.kl_div(torch.log(adv_probs + 1e-12), nat_probs, reduction='none').sum(dim=1)
        true_probs = torch.gather(nat_probs, 1, y.unsqueeze(1).long()).squeeze(1)
        loss = loss + (1.0 / batch_size) * torch.sum(kl_per_sample * (1.0000001 - true_probs)) * beta_reg
    elif reg_opt == 'ALP':
        loss = loss + (1.0 / batch_size) * torch.sum((adv_probs - nat_probs) ** 2) ** 0.5 * beta_reg
    elif reg_opt == 'ALP_2':
        loss = loss + (1.0 / batch_size) * torch.sum((logits_nat - logits_adv) ** 2) ** 0.5 * beta_reg
    elif reg_opt == 'norm':
        loss = loss + (1.0 / batch_size) * torch.sum(logits_nat ** 2) ** 0.5 * beta_reg
    else:
        raise NotImplementedError

    return loss



def build_loader_ours(config):
    """Build train/test loaders (plus class counts) for the dataset named in ``config``.

    Training transform: the augmentation chosen by ``config.aug`` ('aua' ->
    AutoAugment with the CIFAR10 policy, 'ra' -> RandAugment(2, 8), 'none' -> none),
    then RandomCrop(32, padding=4) and RandomHorizontalFlip (skipped for SVHN),
    then ToTensor. Test transform: ToTensor only.

    Args:
        config: must provide ``aug``, ``dataset`` (CIFAR10 / SVHN / CIFAR100 / Tiny,
            matched case-insensitively), ``IR`` (forwarded to ``build_datasets`` as
            ``imbalance_ratio``) and ``batch_size``.

    Returns:
        ``(train_loader, test_loader, samples_per_cls, trainset)``; the train
        loader shuffles, the test loader does not.

    Raises:
        ValueError: if ``config.aug`` or ``config.dataset`` is not recognised.
    """
    if config.aug == 'aua':
        transform_aug = [transforms.AutoAugment(policy=AutoAugmentPolicy.CIFAR10)]
    elif config.aug == 'ra':
        transform_aug = [transforms.RandAugment(2, 8)]
    elif config.aug == 'none':
        transform_aug = []
    else:
        raise ValueError(f"Unknown augmentation {config.aug!r}; expected 'aua', 'ra' or 'none'")

    # Explicit name -> class-count table, keyed by upper-cased name so 'svhn'/'SVHN' both match.
    num_classes_by_name = {"CIFAR10": 10, "SVHN": 10, "CIFAR100": 100, "TINY": 200}
    dataset_key = str(config.dataset).upper()
    if dataset_key not in num_classes_by_name:
        raise ValueError(f"Unknown dataset {config.dataset!r}")
    num_classes = num_classes_by_name[dataset_key]

    # Crop/flip were deliberately disabled for SVHN.
    if dataset_key == "SVHN":
        geometric = []
    else:
        geometric = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
        ]
    transform_train = transforms.Compose(transform_aug + geometric + [transforms.ToTensor()])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])

    trainset, samples_per_cls = build_datasets(name=config.dataset, mode='train',
                                num_classes=num_classes,
                                imbalance_ratio=config.IR,
                                transform=transform_train)
    testset, _ = build_datasets(name=config.dataset, mode='test',
                                num_classes=num_classes,
                                transform=transform_test)

    train_loader = build_dataloader(trainset, imgs_per_gpu=config.batch_size, dist=False, shuffle=True)
    test_loader = build_dataloader(testset, imgs_per_gpu=config.batch_size, dist=False, shuffle=False)

    return train_loader, test_loader, samples_per_cls, trainset



def build_loader_ours2(config):
    """Build train/test loaders (plus class counts) for the dataset named in ``config``.

    Training transform for every dataset: the augmentation chosen by ``config.aug``
    ('aua' -> AutoAugment with the CIFAR10 policy, 'ra' -> RandAugment(2, 8),
    'none' -> none), then RandomCrop(32, padding=4), RandomHorizontalFlip and
    ToTensor. Test transform: ToTensor only.

    Args:
        config: must provide ``aug``, ``dataset`` (CIFAR10 / SVHN / CIFAR100 / Tiny,
            matched case-insensitively), ``IR`` (forwarded to ``build_datasets`` as
            ``imbalance_ratio``) and ``batch_size``.

    Returns:
        ``(train_loader, test_loader, samples_per_cls, trainset)``; the train
        loader shuffles, the test loader does not.

    Raises:
        ValueError: if ``config.aug`` or ``config.dataset`` is not recognised.
    """
    if config.aug == 'aua':
        transform_aug = [transforms.AutoAugment(policy=AutoAugmentPolicy.CIFAR10)]
    elif config.aug == 'ra':
        transform_aug = [transforms.RandAugment(2, 8)]
    elif config.aug == 'none':
        transform_aug = []
    else:
        raise ValueError(f"Unknown augmentation {config.aug!r}; expected 'aua', 'ra' or 'none'")

    # Same pipeline for all datasets (SVHN included).
    transform_train = transforms.Compose(transform_aug + [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])

    # Explicit name -> class-count table, keyed by upper-cased name.
    num_classes_by_name = {"CIFAR10": 10, "SVHN": 10, "CIFAR100": 100, "TINY": 200}
    dataset_key = str(config.dataset).upper()
    if dataset_key not in num_classes_by_name:
        raise ValueError(f"Unknown dataset {config.dataset!r}")
    num_classes = num_classes_by_name[dataset_key]

    trainset, samples_per_cls = build_datasets(name=config.dataset, mode='train',
                                num_classes=num_classes,
                                imbalance_ratio=config.IR,
                                transform=transform_train)
    testset, _ = build_datasets(name=config.dataset, mode='test',
                                num_classes=num_classes,
                                transform=transform_test)

    train_loader = build_dataloader(trainset, imgs_per_gpu=config.batch_size, dist=False, shuffle=True)
    test_loader = build_dataloader(testset, imgs_per_gpu=config.batch_size, dist=False, shuffle=False)

    return train_loader, test_loader, samples_per_cls, trainset


def adjust_learning_rate(optimizer, epoch, config):
    """Set every param group's lr to ``config.lr`` decayed 10x per milestone passed.

    Milestones are chosen from ``config.epochs``:
    200 -> (150, 180); 100..150 -> (75, 90); 50 -> (30, 40); 30 -> (20, 25);
    15 -> (20,); any other value -> no decay.
    """
    total_epochs = config.epochs
    if total_epochs == 200:
        milestones = (150, 180)
    elif 100 <= total_epochs <= 150:
        milestones = (75, 90)
    else:
        milestones = {50: (30, 40), 30: (20, 25), 15: (20,)}.get(total_epochs, ())

    new_lr = config.lr
    passed = sum(1 for milestone in milestones if epoch >= milestone)
    for _ in range(passed):
        new_lr *= 0.1

    for group in optimizer.param_groups:
        group['lr'] = new_lr



def adjust_pre_learning_rate(optimizer, epoch, config):
    """Set every param group's lr to ``config.pre_lr`` decayed 10x per milestone passed.

    Milestones are chosen from ``config.pre_epochs``:
    100 -> (75, 90); 50 -> (30, 40); 30 -> (20, 25); 15 -> (20,);
    any other value -> no decay.
    """
    milestone_table = {
        100: (75, 90),
        50: (30, 40),
        30: (20, 25),
        15: (20,),
    }
    milestones = milestone_table.get(config.pre_epochs, ())

    new_lr = config.pre_lr
    passed = sum(1 for milestone in milestones if epoch >= milestone)
    for _ in range(passed):
        new_lr *= 0.1

    for group in optimizer.param_groups:
        group['lr'] = new_lr

def evaluate_interval(model, test_loader, config, wandb, epoch, pretraining = False):
    """Evaluate clean / PGD-20 accuracy on ``test_loader`` and optionally log to wandb.

    Adversarial inputs come from the module-level ``PGD`` helper with ``steps=20``.
    When ``config.debug == 1`` evaluation stops after the fifth batch.

    Args:
        model: classifier; ``model(X)`` must return logits with classes on dim 1.
        test_loader: iterable of ``(X, y)`` batches, labels in ``[0, config.num_classes)``.
        config: needs ``num_classes``, ``debug`` and ``nowand``.
        wandb: logging object exposing ``log(dict)``; used only when ``config.nowand`` is falsy.
        epoch: epoch index written under ``pretrain_epoch`` or ``main_epoch``.
        pretraining: selects which epoch key is logged.

    Logged keys: ``clean_acc``, ``robust_acc``, the epoch key, and
    ``class{i}_acc`` / ``class{i}_rob``. For 100 classes these are averages over
    groups of 10 classes; otherwise they are one entry per class.

    Returns:
        ``(acc, rob, class_correct_adv)``: overall clean and robust accuracy (NaN if
        no batch was seen) and the per-class robust accuracy array.
    """
    num_classes = config.num_classes
    seen = np.zeros(num_classes)               # samples per class
    class_correct = np.zeros(num_classes)      # clean hits per class
    class_correct_adv = np.zeros(num_classes)  # robust hits per class

    model.eval()
    for step, (X, y) in enumerate(test_loader):
        X = X.float().cuda()
        y = y.cuda()
        inputs_adv = PGD(X, y, model, steps=20)
        model.eval()  # re-assert eval mode in case PGD changed it
        with torch.no_grad():
            hit = (model(X).argmax(dim=1) == y).cpu().numpy().astype(np.float64)
            hit_adv = (model(inputs_adv).argmax(dim=1) == y).cpu().numpy().astype(np.float64)
        labels = y.cpu().numpy()
        seen += np.bincount(labels, minlength=num_classes)
        class_correct += np.bincount(labels, weights=hit, minlength=num_classes)
        class_correct_adv += np.bincount(labels, weights=hit_adv, minlength=num_classes)

        if config.debug == 1 and step > 3:
            break

    total = seen.sum()
    acc = class_correct.sum() / total if total else float("nan")
    rob = class_correct_adv.sum() / total if total else float("nan")

    # Divide by each class's own count (correct for imbalanced or truncated runs);
    # classes never seen report 0.
    class_correct = np.divide(class_correct, seen, out=np.zeros(num_classes), where=seen > 0)
    class_correct_adv = np.divide(class_correct_adv, seen, out=np.zeros(num_classes), where=seen > 0)

    # Group width must equal the slice width and never be 0 (range step).
    group = max(num_classes // 10, 1)
    mean_class_correct = [round(class_correct[i:i + group].mean(), 4) for i in range(0, num_classes, group)]
    mean_class_correct_adv = [round(class_correct_adv[i:i + group].mean(), 4) for i in range(0, num_classes, group)]

    if not config.nowand:
        epoch_key = 'pretrain_epoch' if pretraining else 'main_epoch'
        d2 = {'clean_acc': acc, 'robust_acc': rob, epoch_key: epoch}

        if config.num_classes == 100:
            acc_values, rob_values = mean_class_correct, mean_class_correct_adv
        else:
            acc_values, rob_values = class_correct, class_correct_adv
        for i, correct in enumerate(acc_values):
            d2[f'class{i}_acc'] = correct
        for i, correct_adv in enumerate(rob_values):
            d2[f'class{i}_rob'] = correct_adv
        wandb.log(d2)

    return acc, rob, class_correct_adv

def evaluate_final_aa(model, test_loader):
    """Run standard L-inf AutoAttack (eps = 8/255) over the whole ``test_loader``.

    The loader is iterated exactly once, so each image stays paired with its own
    label even when the loader shuffles or applies random transforms.

    Returns:
        The second element of ``run_standard_evaluation``'s result, which this code
        treats as the AutoAttack accuracy. NOTE(review): upstream AutoAttack returns
        only the adversarial images unless ``return_labels=True`` — confirm the
        installed version returns ``(x_adv, accuracy)``.
    """
    autoattack = AutoAttack(model, norm='Linf', eps=8/255.0, version='standard')
    images, targets = [], []
    for x, y in test_loader:
        images.append(x)
        targets.append(y)
    x_total = torch.cat(images, 0)
    y_total = torch.cat(targets, 0)
    _, AA_acc = autoattack.run_standard_evaluation(x_total, y_total)
    return AA_acc



def load_parser():
    """Parse ``sys.argv`` into an ``argparse.Namespace`` of run options.

    Most numeric flags default to -1, which ``load_config`` interprets as
    "keep the value from the YAML config".
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_name", default="celeb_SAT.yaml")

    # Numeric overrides, all defaulting to -1, registered in this order.
    numeric_overrides = (
        ("alpha", float), ("beta", float), ("gamma", float), ("eta", float),
        ("tau", float), ("pct", float),
        ("epochs", int), ("pre_epochs", int), ("temperature", int),
        ("lr", float), ("pre_lr", float),
        ("IR", float), ("batch_size", float),
    )
    for flag, converter in numeric_overrides:
        parser.add_argument("--" + flag, default=-1, type=converter)

    # Free-form string options with their fallbacks.
    for flag, fallback in (("tag", " "), ("optim", "None"), ("checkpoint", "-1")):
        parser.add_argument("--" + flag, default=fallback, type=str)

    parser.add_argument('--schedule', nargs='+', default=None, help='<Required> Set flag', required=False)
    parser.add_argument('--arch', default="None", required=False)
    parser.add_argument('--debug', default=0, type=int)
    parser.add_argument('--wandb_name', default="default", required=False)
    parser.add_argument('--dataset', default="CIFAR10", required=False)
    parser.add_argument('--ratio', default=-1, type=float)
    parser.add_argument('--eps', default=8, type=float)
    parser.add_argument('--nowand', default=1, choices=[0, 1], type=int, help='Inhibit wandb logging')

    return parser.parse_args()


def load_config(args):
    """Load the YAML config selected by ``args`` and overlay command-line overrides.

    The file read is ``<directory of this module>/config/<args.dataset>/<args.config_name>``,
    parsed with ``yaml.safe_load``. For alpha, beta, gamma, eta, epochs, pre_epochs,
    lr, pre_lr, batch_size, IR and ratio, a CLI value of -1 keeps the YAML value.
    ``eps``, ``config_name``, ``debug``, ``tag`` and ``nowand`` always come from
    ``args``. ``wandb_name`` is derived from the config name and key hyper-parameters.

    Returns:
        A dict subclass with attribute access; missing keys read as None.
    """
    config_path = Path(os.path.realpath(__file__)).parent / "config" / args.dataset / args.config_name
    with open(config_path, encoding="utf-8") as file:
        config = yaml.safe_load(file)

    class dotdict(dict):
        """dot.notation access to dictionary attributes (missing keys -> None)"""
        __getattr__ = dict.get
        __setattr__ = dict.__setitem__
        __delattr__ = dict.__delitem__

    config = dotdict(config)
    config.alpha = args.alpha if args.alpha != -1 else config.alpha
    config.beta = args.beta if args.beta != -1 else config.beta
    config.gamma = args.gamma if args.gamma != -1 else config.gamma
    config.eta = args.eta if args.eta != -1 else config.eta
    config.epochs = args.epochs if args.epochs != -1 else config.epochs
    config.pre_epochs = args.pre_epochs if args.pre_epochs != -1 else config.pre_epochs
    config.lr = args.lr if args.lr != -1 else config.lr
    config.pre_lr = args.pre_lr if args.pre_lr != -1 else config.pre_lr
    config.batch_size = int(args.batch_size) if args.batch_size != -1 else int(config.batch_size)
    # NOTE(review): int() truncates a fractional --eps, and any YAML eps is ignored.
    config.eps = int(args.eps)
    config.config_name = args.config_name
    config.IR = args.IR if args.IR != -1 else config.IR
    # YAML 1.1 reads values like "5e-4" as strings, so coerce explicitly.
    config.weight_decay = float(config.weight_decay)
    config.debug = args.debug
    config.tag = args.tag
    config.ratio = args.ratio if args.ratio != -1 else config.ratio
    config.nowand = args.nowand
    config.wandb_name = "_".join(str(part) for part in (
        config.config_name.split('.')[0],
        config.alpha,
        config.beta,
        config.gamma,
        config.lr,
        config.pre_epochs,
        config.aug,
        config.IR,
    ))
    return config

def build_model(config, samples_per_cls = None):
    """Construct the CUDA model selected by ``config`` and optionally load weights.

    If ``config.method`` contains "Robal", ``models.Networks.Networks`` is built
    with ``samples_per_cls``. Otherwise ``config.model`` picks ResNet18 ('RES-18')
    or WideResNet ('WRN-34-10'). When ``config.checkpoint`` is not None its state
    dict is loaded with ``torch.load`` and "load success!!" is printed.

    Raises:
        ValueError: if no Robal method is requested and ``config.model`` is unknown.
    """
    if "Robal" in config.method:
        from models.Networks import Networks
        model = Networks(config, num_classes=config.num_classes, samples_per_cls=samples_per_cls).cuda()
    elif config.model == "RES-18":
        model = ResNet18(num_classes=config.num_classes).cuda()
    elif config.model == 'WRN-34-10':
        model = WideResNet(num_classes=config.num_classes).cuda()
    else:
        raise ValueError(f"Unsupported model {config.model!r}; expected 'RES-18' or 'WRN-34-10'")

    if config.checkpoint is not None:
        model.load_state_dict(torch.load(config.checkpoint))
        print("load success!!")

    return model


def build_balance_model(config, samples_per_cls = None):
    """Construct the CUDA "balance" model and load ``config.balance_checkpoint`` if set.

    When a balance checkpoint is given and ``config.balance_model == "Robal"``,
    ``models.Networks.Networks`` is built. Otherwise ``config.model`` picks ResNet18
    ('RES-18') or WideResNet ('WRN-34-10'). Only the model that is actually returned
    is constructed.

    Raises:
        ValueError: if a plain architecture is needed and ``config.model`` is unknown.
    """
    use_robal = config.balance_checkpoint is not None and config.balance_model == "Robal"

    if use_robal:
        from models.Networks import Networks
        model = Networks(config, num_classes=config.num_classes, samples_per_cls=samples_per_cls).cuda()
    elif config.model == "RES-18":
        model = ResNet18(num_classes=config.num_classes).cuda()
    elif config.model == 'WRN-34-10':
        model = WideResNet(num_classes=config.num_classes).cuda()
    else:
        raise ValueError(f"Unsupported model {config.model!r}; expected 'RES-18' or 'WRN-34-10'")

    if config.balance_checkpoint is not None:
        model.load_state_dict(torch.load(config.balance_checkpoint))

    return model

def build_balance_loader(trainset, samples_per_cls, config):
    """Resample ``trainset`` so every class contributes the same number of samples.

    Per-class target:
    ``int((1 - gamma) * max(samples_per_cls) + gamma * min(samples_per_cls))``.
    So ``gamma=0`` up-samples to the largest class and ``gamma=1`` down-samples
    to the smallest.

    Resampling rules:

    - Classes below the target are tiled: every sample repeats ``target // count``
      times, plus ``target % count`` distinct random extras.
    - Classes at or above the target are sub-sampled without replacement.
    - Classes with zero samples are skipped.

    Assumes ``trainset`` is ordered by class: class ``c`` occupies the
    ``samples_per_cls[c]`` consecutive indices following all earlier classes.
    TODO confirm against the dataset builder.

    Returns:
        A shuffling DataLoader over an in-memory TensorDataset, batch size ``config.pre_batch``.

    Raises:
        ValueError: if resampling yields no samples at all.
    """
    target = int((1 - config.gamma) * max(samples_per_cls) + config.gamma * min(samples_per_cls))
    new_data = []
    base_idx = 0
    for count in samples_per_cls:
        if count > 0:  # an empty class has nothing to resample (and would divide by zero)
            if target > count:
                selected = np.tile(np.arange(count), target // count)
                extra = target % count
                if extra > 0:
                    selected = np.concatenate([selected, np.random.choice(count, size=extra, replace=False)])
            else:
                selected = np.random.choice(count, size=target, replace=False)
            new_data.extend(trainset[base_idx + int(idx)] for idx in selected)
        base_idx += count

    if not new_data:
        raise ValueError("build_balance_loader: resampling produced no samples")

    images = torch.stack([sample[0] for sample in new_data])
    labels = torch.tensor([sample[1] for sample in new_data])

    balanced_dataset = TensorDataset(images, labels)
    return DataLoader(balanced_dataset, batch_size=config.pre_batch, shuffle=True)




def save_results(save_dir, clean_acc, robust_acc, tail_robust_acc, flatness_criteria):
    """Overwrite the text file ``save_dir`` with four ``name : value`` lines.

    Each value is formatted with four decimals. The ``flatness_criteria`` line
    is the one ``load_flatness_criteria`` reads back.
    """
    entries = (
        ("clean_acc", clean_acc),
        ("robust_acc", robust_acc),
        ("tail_robust_acc", tail_robust_acc),
        ("flatness_criteria", flatness_criteria),
    )
    with open(save_dir, "w") as out:
        for label, value in entries:
            out.write(label + " : " + "%.4f\n" % value)

def load_flatness_criteria(save_dir):
    """Return the float after the last ':' on the first line mentioning ``flatness_criteria``.

    Returns None if no line of the file ``save_dir`` contains that word.
    """
    with open(save_dir, "r") as handle:
        candidates = (text for text in handle.readlines() if "flatness_criteria" in text)
        first_hit = next(candidates, None)
    if first_hit is None:
        return None
    return float(first_hit.rsplit(":", 1)[-1].strip())

