import torch
from utils import *
from utils_awp import *
from tqdm import tqdm

def train_PGDAT(model, train_loader, epoch, optimizer, config, samples_per_cls = None, balance_model = None):
    """Run one epoch of adversarial training on PGD examples.

    For every batch, adversarial inputs are produced by ``PGD`` and the model
    is updated with the cross-entropy loss measured on those inputs.

    Args:
        model: Network to train. It is switched to training mode on entry.
        train_loader: Iterable yielding ``(X, y)`` batches; both are moved to
            the GPU with ``.cuda()``.
        epoch: Unused; kept so every training routine shares one signature.
        optimizer: Optimizer that updates ``model``.
        config: Object with an ``eps`` attribute. It is divided by 255 before
            being handed to ``PGD`` (presumably a budget given in 0-255 pixel
            units -- confirm against ``PGD``).
        samples_per_cls: Unused.
        balance_model: Unused.

    Returns:
        None.
    """
    # The other training loops in this file set the mode themselves; do the
    # same here instead of depending on whatever mode the caller left behind.
    model.train()

    eps = config.eps / 255
    alpha = eps / 4  # passed to PGD as ``alpha``: a quarter of ``eps``

    for X, y in train_loader:
        X, y = X.cuda(), y.cuda()

        x_adv = PGD(X, y, model, eps=eps, alpha=alpha)

        # Clear gradients only after the attack. PGD is defined elsewhere; if
        # it back-propagates through the model, parameter gradients it leaves
        # behind would otherwise be added to this update. This is a no-op
        # when PGD leaves none.
        optimizer.zero_grad()

        loss = F.cross_entropy(model(x_adv), y)
        loss.backward()
        optimizer.step()
    
def train_natural(model, train_loader, epoch, optimizer, config, samples_per_cls = None, balance_model = None):
    """Run one epoch of standard cross-entropy training on unperturbed inputs.

    Args:
        model: Network to train. It is switched to training mode on entry.
        train_loader: Iterable yielding ``(X, y)`` batches; both are moved to
            the GPU with ``.cuda()``.
        epoch: Unused; kept so every training routine shares one signature.
        optimizer: Optimizer that updates ``model``.
        config: Unused.
        samples_per_cls: Unused.
        balance_model: Unused.

    Returns:
        None.
    """
    # The other training loops in this file set the mode themselves; do the
    # same here instead of depending on whatever mode the caller left behind.
    model.train()

    for X, y in train_loader:
        X, y = X.cuda(), y.cuda()
        optimizer.zero_grad()

        loss = F.cross_entropy(model(X), y)
        loss.backward()
        optimizer.step()
          
def train_self_distill(model, train_loader, epoch, optimizer, config, samples_per_cls = None, balance_model = None, flatness_loss_criteria= None):
    """Run one epoch of adversarial training with a distillation term.

    Per batch the loss is the sum of:

    * cross-entropy on the adversarial logits after adding the log of the
      per-class sample counts to them, and
    * a KL term that matches ``model(X + delta) - model(X - delta)`` to
      ``balance_model(X + delta) - balance_model(X - delta)``, where
      ``delta = x_adv - X``. Its weight is
      ``(epoch / config.epochs) * config.alpha``.

    Args:
        model: Network to train; put in training mode.
        train_loader: Iterable yielding ``(X, y)`` batches; both are moved to
            the GPU with ``.cuda()``.
        epoch: Current epoch index, used to scale the distillation weight.
        optimizer: Optimizer that updates ``model``.
        config: Object with ``epochs`` and ``alpha`` attributes.
        samples_per_cls: Per-class sample counts, one entry per logit column.
            Required despite the ``None`` default.
        balance_model: Frozen reference network; put in eval mode and only
            evaluated under ``torch.no_grad()``. Required despite the ``None``
            default.
        flatness_loss_criteria: Unused.

    Returns:
        None.

    Raises:
        ValueError: If ``samples_per_cls`` or ``balance_model`` is ``None``.
    """
    # Fail fast with a clear message instead of an AttributeError on
    # ``None.eval()`` or a tensor-construction error in the middle of a step.
    if samples_per_cls is None:
        raise ValueError("train_self_distill requires samples_per_cls")
    if balance_model is None:
        raise ValueError("train_self_distill requires balance_model")

    criterion_kl = nn.KLDivLoss(reduction="batchmean")
    model.train()
    balance_model.eval()

    # Log class counts, built once on the first batch so that dtype and device
    # follow the logits (the counts themselves do not change between batches).
    log_prior = None

    for X, y in train_loader:
        X, y = X.cuda(), y.cuda()

        x_adv = PGD(X, y, model)

        # Clear gradients only after the attack. PGD is defined elsewhere; if
        # it back-propagates through the model, parameter gradients it leaves
        # behind would otherwise be added to this update. This is a no-op
        # when PGD leaves none.
        optimizer.zero_grad()

        with torch.no_grad():
            delta = x_adv - X
            balance_plus = balance_model(X + delta)
            balance_minus = balance_model(X - delta)
            balance_diff = balance_plus - balance_minus

        # NOTE(review): logits_adv and logits_plus are two separate forward
        # passes over the same input. They are kept separate on purpose:
        # merging them could change results for models with dropout or
        # running normalisation statistics. Confirm before deduplicating.
        logits_adv = model(X + delta)
        logits_plus = model(X + delta)
        logits_minus = model(X - delta)

        if log_prior is None:
            counts = torch.as_tensor(samples_per_cls).type_as(logits_adv)
            log_prior = counts.log().unsqueeze(0)  # (1, C), broadcasts over the batch
        logits_spc = logits_adv + log_prior

        # balance_diff was computed under no_grad, so it carries no gradient.
        IGDM_loss = criterion_kl(
            F.log_softmax(logits_plus - logits_minus, dim=1),
            F.softmax(balance_diff, dim=1),
        )
        loss = F.cross_entropy(logits_spc, y) + (epoch / config.epochs) * config.alpha * IGDM_loss
        loss.backward()
        optimizer.step()
        
def train_balance(model, train_loader, epoch, optimizer, config, samples_per_cls = None, balance_model = None):
    """Run one epoch of PGD adversarial training and measure output asymmetry.

    Each batch is trained with cross-entropy on ``PGD`` adversarial examples.
    Before the optimizer step, the model is additionally evaluated (eval mode,
    no gradients) at ``X + delta`` and ``X - delta`` with ``delta = x_adv - X``;
    for every sample the largest absolute logit difference between the two is
    added to a running total.

    Args:
        model: Network to train.
        train_loader: Iterable yielding ``(X, y)`` batches; both are moved to
            the GPU with ``.cuda()``.
        epoch: Unused; kept so every training routine shares one signature.
        optimizer: Optimizer that updates ``model``.
        config: Unused.
        samples_per_cls: Unused.
        balance_model: Unused.

    Returns:
        The running total described above, summed over all samples of the
        epoch (it is not divided by the number of samples). ``0`` when the
        loader is empty.

    Note:
        When at least one batch was processed the model is left in eval mode
        on return, because the measurement is the last mode change of an
        iteration.
    """
    tmp_loss = 0
    model.train()

    # Wrap the loader itself: tqdm can then read its length and show a total,
    # which it cannot do when handed an ``enumerate`` object.
    for X, y in tqdm(train_loader):
        # The measurement below switches to eval mode, so restore training
        # mode at the start of every iteration.
        model.train()
        X, y = X.cuda(), y.cuda()

        x_adv = PGD(X, y, model)

        # Clear gradients only after the attack. PGD is defined elsewhere; if
        # it back-propagates through the model, parameter gradients it leaves
        # behind would otherwise be added to this update. This is a no-op
        # when PGD leaves none.
        optimizer.zero_grad()

        loss = F.cross_entropy(model(x_adv), y)
        loss.backward()

        # Measured with the pre-step weights, i.e. before optimizer.step().
        with torch.no_grad():
            delta = x_adv - X
            model.eval()
            balance_plus = model(X + delta)
            balance_minus = model(X - delta)
            per_sample_gap = torch.abs(balance_plus - balance_minus).max(dim=1)[0]
            tmp_loss += per_sample_gap.sum().item()

        optimizer.step()

    return tmp_loss


