import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
import pandas as pd
import numpy as np
import torch.optim as optim

from advbench import optimizers
from advbench import attacks
from advbench.lib import meters

# Names of the training algorithms advertised by this module.
# NOTE(review): several entries (e.g. 'MirrorDescent', 'EntropyTraining',
# 'MirrorDescentNoBeta') do not match the name of any class visible in this
# part of the file -- confirm how consumers resolve these names to classes.
ALGORITHMS = [
    'ERM',
    'MMD',
    'MirrorDescent',
    'EntropyTraining',
    'MirrorDescentNoBeta',
    'PGDMoreSteps',
    'PGD',
    'FGSM',
    'FastAT',
    'TRADES',
    'ALP',
    'CLP',
    'Gaussian_DALE',
    'Laplacian_DALE',
    'Gaussian_DALE_PD',
    'Gaussian_DALE_PD_Reverse',
    'KL_DALE_PD',
    'CVaR_SGD',
    'CVaR_SGD_Autograd',
    'CVaR_SGD_PD',
    'ERM_DataAug',
    'TERM',
    'RandSmoothing'
]

class EmptyAlgorithm(nn.Module):
    """Bare wrapper around a classifier.

    Stores the classifier, hyperparameters and device, and forwards
    ``predict`` to the classifier.  No optimizer or meters are created.
    """

    def __init__(self, classifier, hparams, device):
        super().__init__()
        self.classifier, self.hparams, self.device = classifier, hparams, device

    def predict(self, imgs):
        """Return the classifier's output for ``imgs``."""
        outputs = self.classifier(imgs)
        return outputs

class Algorithm(nn.Module):
    """Base class for training algorithms.

    Holds the classifier, an SGD optimizer over its parameters, and a
    dictionary of meters.  Sub-classes implement ``step`` and, if they save
    checkpoints, ``ckpt_save_criterion``.
    """

    # Class-level switches that sub-classes may override.
    SAVE_CKPTS = False
    DROP_LAST_BATCH = False
    HAS_HPARAM_UPDATE_SCHEDULE = False

    def __init__(self, classifier, hparams, device):
        super().__init__()
        self.hparams = hparams
        self.classifier = classifier

        # SGD configured entirely from the hyperparameter mapping.
        sgd_kwargs = {
            'lr': hparams['learning_rate'],
            'momentum': hparams['sgd_momentum'],
            'weight_decay': hparams['weight_decay'],
        }
        self.optimizer = optim.SGD(self.classifier.parameters(), **sgd_kwargs)
        self.device = device

        # name -> meter; sub-classes register their own entries.
        self.meters = {}
        # Built lazily on the first call to ``meters_to_df``.
        self.meters_df = None

    def step(self, imgs, labels):
        """Run one training step on a batch; must be overridden."""
        raise NotImplementedError

    def ckpt_save_criterion(self):
        """Decide whether to save a checkpoint; must be overridden."""
        raise NotImplementedError

    def predict(self, imgs):
        """Return the classifier's output for ``imgs``."""
        return self.classifier(imgs)

    @staticmethod
    def logits_to_accuracy(logits, labels):
        """Return the percentage of rows whose argmax (dim 1) equals the label."""
        predicted = torch.argmax(logits, dim=1, keepdim=True)
        n_correct = (predicted == labels.view_as(predicted)).sum().item()
        return 100 * n_correct / labels.size(0)

    @staticmethod
    def img_clamp(imgs):
        """Clamp every value of ``imgs`` into ``[0, 1]``."""
        return torch.clamp(imgs, 0.0, 1.0)

    def reset_meters(self):
        """Call ``reset()`` on every registered meter."""
        for name in self.meters:
            self.meters[name].reset()

    def meters_to_df(self, epoch):
        """Append one row (epoch + each meter's ``avg``) and return the frame."""
        if self.meters_df is None:
            self.meters_df = pd.DataFrame(columns=['Epoch', *self.meters])

        row = [epoch]
        row.extend(meter.avg for meter in self.meters.values())
        self.meters_df.loc[len(self.meters_df)] = row
        return self.meters_df

class ERM(Algorithm):
    """Plain cross-entropy training on the given images."""

    def __init__(self, classifier, hparams, device):
        super().__init__(classifier, hparams, device)

        for meter_name in ('Loss', 'Accuracy'):
            self.meters[meter_name] = meters.AverageMeter()

    def step(self, imgs, labels):
        """Take one optimizer step on the cross-entropy loss and log metrics."""
        batch_size = imgs.size(0)

        self.optimizer.zero_grad()
        logits = self.predict(imgs)
        loss = F.cross_entropy(logits, labels)
        loss.backward()
        self.optimizer.step()

        accuracy = self.logits_to_accuracy(logits, labels)
        self.meters['Loss'].update(loss.item(), n=batch_size)
        self.meters['Accuracy'].update(accuracy, n=batch_size)

class MMD(Algorithm):
    """Cross-entropy training with an entropy term and a KL proximity term.

    Each call to ``step`` freezes the classifier's current predictive
    distribution and then takes ``hparams['mmd_grad_steps']`` optimizer steps
    on::

        cross_entropy + mmd_alpha * negative_entropy + mmd_beta * kl

    where ``kl`` is ``F.kl_div`` between the current log-probabilities
    (``input``) and the frozen ones (``target``).  When ``is_adversarial`` is
    ``True`` all training forward passes use ``attacks.PGD_Linf`` examples
    instead of the clean images.

    Args:
        classifier: model being trained.
        hparams: mapping providing ``mmd_grad_steps``, ``mmd_alpha`` and
            ``mmd_beta`` (plus the optimizer settings read by ``Algorithm``).
        device: device handle, forwarded to the base class and the attack.
        is_adversarial: ``True`` to train on adversarial examples.
    """

    def __init__(self, classifier, hparams, device, is_adversarial):
        super(MMD, self).__init__(classifier, hparams, device)

        self.meters['classification_loss'] = meters.AverageMeter()
        self.meters['entropy_penalty'] = meters.AverageMeter()
        self.meters['kl_loss'] = meters.AverageMeter()

        self.meters['Clean Loss'] = meters.AverageMeter()
        self.meters['Clean Accuracy'] = meters.AverageMeter()
        self.meters['Robust Loss'] = meters.AverageMeter()
        self.meters['Robust Accuracy'] = meters.AverageMeter()

        self.is_adversarial = is_adversarial
        if self.is_adversarial is True:
            self.attack = attacks.PGD_Linf(
                classifier=classifier,
                hparams=hparams,
                device=device)

    def step(self, imgs, labels):
        """Run ``mmd_grad_steps`` optimizer steps on one batch and log metrics.

        The 'Robust Loss' / 'Robust Accuracy' meters are only updated in
        adversarial mode, because no adversarial examples exist otherwise.
        """
        batch_size = imgs.size(0)
        adversarial = self.is_adversarial is True

        # ``adv_imgs`` stays None in non-adversarial mode so the robust
        # metrics below can be skipped instead of hitting an unbound name.
        adv_imgs = None
        if adversarial:
            adv_imgs = self.attack(imgs, labels)

        # Frozen reference distribution.  F.kl_div expects log-probabilities
        # for ``input`` (and for ``target`` when log_target=True), so the
        # logits are normalised with log_softmax first.
        with torch.no_grad():
            original_log_probs = F.log_softmax(
                self.predict(adv_imgs if adversarial else imgs), dim=1)

        for _ in range(self.hparams['mmd_grad_steps']):

            self.optimizer.zero_grad()

            if adversarial:
                # Re-attack: the classifier changed since the previous step.
                adv_imgs = self.attack(imgs, labels)
                logits = self.predict(adv_imgs)
            else:
                logits = self.predict(imgs)

            # standard classification loss
            classification_loss = F.cross_entropy(logits, labels)

            # magnet term with uniform magnet: batch-averaged sum of p*log(p)
            log_probs = F.log_softmax(logits, dim=1)
            negative_entropy = self.hparams['mmd_alpha'] * torch.sum(
                log_probs * log_probs.exp()) / batch_size

            # mirror descent Bregman divergence term
            kl_loss = self.hparams['mmd_beta'] * F.kl_div(
                input=log_probs,
                target=original_log_probs,
                log_target=True,
                reduction='batchmean')

            # sum all three loss terms
            loss = classification_loss + negative_entropy + kl_loss

            loss.backward()
            self.optimizer.step()

            self.meters['classification_loss'].update(
                classification_loss.item(), n=batch_size)
            self.meters['entropy_penalty'].update(
                negative_entropy.item(), n=batch_size)
            self.meters['kl_loss'].update(
                kl_loss.item(), n=batch_size)

        # Metrics only: one forward pass per input set, no autograd graph.
        with torch.no_grad():
            clean_logits = self.predict(imgs)
            clean_loss = F.cross_entropy(clean_logits, labels)

        self.meters['Clean Loss'].update(clean_loss.item(), n=batch_size)
        self.meters['Clean Accuracy'].update(
            self.logits_to_accuracy(clean_logits, labels),
            n=batch_size
        )

        if adv_imgs is not None:
            with torch.no_grad():
                adv_logits = self.predict(adv_imgs)
                robust_loss = F.cross_entropy(adv_logits, labels)

            self.meters['Robust Loss'].update(robust_loss.item(), n=batch_size)
            self.meters['Robust Accuracy'].update(
                self.logits_to_accuracy(adv_logits, labels),
                n=batch_size
            )

class StandardMMD(MMD):
    """``MMD`` with ``is_adversarial=False``."""

    def __init__(self, classifier, hparams, device):
        super().__init__(classifier, hparams, device, is_adversarial=False)

class AdversarialMMD(MMD):
    """``MMD`` with ``is_adversarial=True``."""

    def __init__(self, classifier, hparams, device):
        super().__init__(classifier, hparams, device, is_adversarial=True)

class AdversarialMirrorDescent(MMD):
    """Adversarial ``MMD`` with the entropy weight ``mmd_alpha`` forced to 0."""

    def __init__(self, classifier, hparams, device):
        # Written into the caller's mapping before the base class reads it.
        hparams['mmd_alpha'] = 0

        super().__init__(classifier, hparams, device, is_adversarial=True)

class AdversarialMirrorDescentNoBeta(MMD):
    """Adversarial ``MMD`` with both ``mmd_alpha`` and ``mmd_beta`` forced to 0."""

    def __init__(self, classifier, hparams, device):
        # Written into the caller's mapping before the base class reads it.
        for weight_name in ('mmd_alpha', 'mmd_beta'):
            hparams[weight_name] = 0

        super().__init__(classifier, hparams, device, is_adversarial=True)

class StandardEntropyTraining(MMD):
    """Clean ``MMD`` with a single gradient step and ``mmd_beta`` forced to 0."""

    def __init__(self, classifier, hparams, device):
        # Written into the caller's mapping before the base class reads it.
        hparams['mmd_grad_steps'] = 1
        hparams['mmd_beta'] = 0

        super().__init__(classifier, hparams, device, is_adversarial=False)

class AdversarialEntropyTraining(MMD):
    """Adversarial ``MMD`` with a single gradient step and ``mmd_beta`` forced to 0."""

    def __init__(self, classifier, hparams, device):
        # Written into the caller's mapping before the base class reads it.
        hparams['mmd_grad_steps'] = 1
        hparams['mmd_beta'] = 0

        super().__init__(classifier, hparams, device, is_adversarial=True)

##### Adversarial training as data augmentation

class ATBase(Algorithm):
    """Shared training step for algorithms that train on attacked images.

    Sub-classes are expected to set ``self.attack``; it is called as
    ``self.attack(imgs, labels)`` and its result is fed to the classifier.
    """

    def __init__(self, classifier, hparams, device):
        super().__init__(classifier, hparams, device)

        for meter_name in ('Clean Loss', 'Clean Accuracy',
                           'Robust Loss', 'Robust Accuracy'):
            self.meters[meter_name] = meters.AverageMeter()

    def step(self, imgs, labels):
        """Optimise cross-entropy on attacked images; log clean and robust metrics."""
        batch_size = imgs.size(0)

        # Clean metrics are for logging only, so no graph is recorded.
        with torch.no_grad():
            clean_logits = self.predict(imgs)
            clean_loss = F.cross_entropy(clean_logits, labels)

        adv_imgs = self.attack(imgs, labels)

        self.optimizer.zero_grad()
        adv_logits = self.predict(adv_imgs)
        robust_loss = F.cross_entropy(adv_logits, labels)
        robust_loss.backward()
        self.optimizer.step()

        clean_acc = self.logits_to_accuracy(clean_logits, labels)
        robust_acc = self.logits_to_accuracy(adv_logits, labels)

        self.meters['Clean Loss'].update(clean_loss.item(), n=batch_size)
        self.meters['Clean Accuracy'].update(clean_acc, n=batch_size)
        self.meters['Robust Loss'].update(robust_loss.item(), n=batch_size)
        self.meters['Robust Accuracy'].update(robust_acc, n=batch_size)



class SBETA(Algorithm):
    """Training on ``attacks.SBETA_Linf`` examples weighted by softmax margins.

    The attack returns, per batch, a tensor of weights (``softmax_margins``)
    and an enlarged set of adversarial images.  The per-example cross-entropy
    on those images is reshaped to ``(batch_size, 10)``, multiplied by the
    weights and averaged.  The reshape uses ``hparams['batch_size']``, which
    is why incomplete final batches are dropped (``DROP_LAST_BATCH``).
    """

    DROP_LAST_BATCH = True

    def __init__(self, classifier, hparams, device):
        super(SBETA, self).__init__(classifier, hparams, device)

        self.meters['Clean Loss'] = meters.AverageMeter()
        self.meters['Clean Accuracy'] = meters.AverageMeter()
        self.meters['Robust Loss'] = meters.AverageMeter()
        self.meters['Robust Accuracy'] = meters.AverageMeter()

        self.attack = attacks.SBETA_Linf(
            classifier=classifier,
            hparams=hparams,
            device=device,
            num_classes=10)

    def step(self, imgs, labels):
        """Take one optimizer step on the margin-weighted robust loss."""
        n_imgs = imgs.size(0)
        batch_size = self.hparams['batch_size']

        # Clean metrics are for logging only.
        with torch.no_grad():
            clean_logits = self.predict(imgs)
            clean_loss = F.cross_entropy(clean_logits, labels)

        # will be higher-dimensional
        softmax_margins, adv_imgs = self.attack(imgs, labels)

        # Repeat the labels to match the enlarged adversarial batch.
        extended_labels = torch.kron(labels, self.attack.ones_row)

        self.optimizer.zero_grad()
        adv_logits = self.predict(adv_imgs)
        # ``reduction='none'`` is the supported spelling of the deprecated
        # ``reduce=False``: keep one loss value per adversarial example.
        unsmoothed_robust_loss = F.cross_entropy(
            adv_logits,
            extended_labels,
            reduction='none'
        ).reshape(batch_size, 10)

        robust_loss = (
            unsmoothed_robust_loss * softmax_margins
        ).mean()

        robust_loss.backward()
        self.optimizer.step()

        self.meters['Clean Loss'].update(
            clean_loss.item(),
            n=n_imgs
        )
        self.meters['Clean Accuracy'].update(
            self.logits_to_accuracy(clean_logits, labels),
            n=n_imgs
        )
        self.meters['Robust Loss'].update(
            robust_loss.item(),
            n=n_imgs
        )
        # NOTE(review): this compares the first ``batch_size`` rows of the
        # enlarged logits with the un-repeated labels; whether those rows
        # line up with ``labels`` depends on how the attack orders its
        # output -- confirm against attacks.SBETA_Linf.
        self.meters['Robust Accuracy'].update(
            self.logits_to_accuracy(adv_logits[:batch_size], labels),
            n=n_imgs
        )

class FGSM(ATBase):
    """``ATBase`` training driven by ``attacks.FGSM_Linf``."""

    def __init__(self, classifier, hparams, device):
        super().__init__(classifier, hparams, device)
        attack_kwargs = dict(classifier=classifier, hparams=hparams, device=device)
        self.attack = attacks.FGSM_Linf(**attack_kwargs)

class FastAT(ATBase):
    """``ATBase`` training driven by ``attacks.Noisy_FGSM_Linf``."""

    def __init__(self, classifier, hparams, device):
        super().__init__(classifier, hparams, device)
        attack_kwargs = dict(classifier=classifier, hparams=hparams, device=device)
        self.attack = attacks.Noisy_FGSM_Linf(**attack_kwargs)

class PGD(ATBase):
    """``ATBase`` training driven by ``attacks.PGD_Linf``."""

    def __init__(self, classifier, hparams, device):
        super().__init__(classifier, hparams, device)
        attack_kwargs = dict(classifier=classifier, hparams=hparams, device=device)
        self.attack = attacks.PGD_Linf(**attack_kwargs)
        
class BETA(ATBase):
    """``ATBase`` training driven by ``attacks.BETA_Linf`` (10 classes).

    Incomplete final batches are dropped (``DROP_LAST_BATCH``).
    """

    DROP_LAST_BATCH = True

    def __init__(self, classifier, hparams, device):
        super().__init__(classifier, hparams, device)
        attack_kwargs = dict(
            classifier=classifier,
            hparams=self.hparams,
            device=device,
            num_classes=10,
        )
        self.attack = attacks.BETA_Linf(**attack_kwargs)

class PGDMoreSteps(PGD):
    """``PGD`` training whose attack runs ``pgd_n_steps * mmd_grad_steps`` steps.

    The enlarged step count is written to a private shallow copy of
    ``hparams``.  The caller's mapping is left untouched, so constructing this
    class repeatedly from the same ``hparams`` does not compound the step
    count, and other consumers of that mapping keep the configured
    ``pgd_n_steps``.
    """

    def __init__(self, classifier, hparams, device):
        # Shallow copy: only the top-level 'pgd_n_steps' entry is replaced.
        hparams = dict(hparams)
        hparams['pgd_n_steps'] = hparams['pgd_n_steps'] * hparams['mmd_grad_steps']
        super(PGDMoreSteps, self).__init__(classifier, hparams, device)



##### Penalty-based adversarial training

class PenaltyATBase(Algorithm):
    """Shared training step for penalty-based adversarial training.

    Sub-classes set ``self.attack`` and implement ``robust_losses``, which
    returns ``(clean_loss, invariance_loss, robust_loss, total_loss)``;
    ``total_loss`` is the quantity that is back-propagated.
    """

    def __init__(self, classifier, hparams, device):
        super().__init__(classifier, hparams, device)

        meter_names = (
            'Clean Loss',
            'Clean Accuracy',
            'Invariance Loss',
            'Total Loss',
            'Robust Loss',
            'Robust Accuracy',
        )
        for meter_name in meter_names:
            self.meters[meter_name] = meters.AverageMeter()

    def robust_losses(self, clean_logits, adv_logits, labels):
        """Return the four loss terms; must be overridden."""
        raise NotImplementedError()

    def step(self, imgs, labels):
        """Attack the batch, optimise the total loss and log all metrics."""
        batch_size = imgs.size(0)

        adv_imgs = self.attack(imgs, labels)
        self.optimizer.zero_grad()

        adv_logits = self.predict(adv_imgs)
        clean_logits = self.predict(imgs)

        losses = self.robust_losses(clean_logits, adv_logits, labels)
        clean_loss, invariance_loss, robust_loss, total_loss = losses

        total_loss.backward()
        self.optimizer.step()

        self.meters['Clean Loss'].update(clean_loss.item(), n=batch_size)
        clean_acc = self.logits_to_accuracy(clean_logits, labels)
        self.meters['Clean Accuracy'].update(clean_acc, n=batch_size)
        self.meters['Invariance Loss'].update(invariance_loss.item(), n=batch_size)
        self.meters['Robust Loss'].update(robust_loss.item(), n=batch_size)
        self.meters['Total Loss'].update(total_loss.item(), n=batch_size)
        robust_acc = self.logits_to_accuracy(adv_logits, labels)
        self.meters['Robust Accuracy'].update(robust_acc, n=batch_size)

class TRADES(PenaltyATBase):
    """Clean cross-entropy plus a ``trades_beta``-weighted KL penalty.

    The penalty is ``KLDivLoss(batchmean)`` between the log-softmax of the
    adversarial logits and the softmax of the clean logits.
    """

    def __init__(self, classifier, hparams, device):
        super().__init__(classifier, hparams, device)
        self.kl_loss_fn = nn.KLDivLoss(reduction='batchmean')
        attack_kwargs = dict(classifier=classifier, hparams=hparams, device=device)
        self.attack = attacks.TRADES_Linf(**attack_kwargs)

    def robust_losses(self, clean_logits, adv_logits, labels):
        """Return ``(clean, invariance, robust, total)``; robust is logging-only."""
        clean_loss = F.cross_entropy(clean_logits, labels)

        adv_log_probs = F.log_softmax(adv_logits, dim=1)
        clean_probs = F.softmax(clean_logits, dim=1)
        kl_term = self.kl_loss_fn(adv_log_probs, clean_probs)
        invariance_loss = self.hparams['trades_beta'] * kl_term

        # Reported but not optimised, so no graph is recorded for it.
        with torch.no_grad():
            robust_loss = F.cross_entropy(adv_logits, labels)

        total_loss = clean_loss + invariance_loss
        return clean_loss, invariance_loss, robust_loss, total_loss

        
class LogitPairingBase(PenaltyATBase):
    """Base for logit-pairing methods: PGD attack plus an L2 pairing loss."""

    def __init__(self, classifier, hparams, device):
        super().__init__(classifier, hparams, device)
        attack_kwargs = dict(classifier=classifier, hparams=hparams, device=device)
        self.attack = attacks.PGD_Linf(**attack_kwargs)

        self.meters['logit loss'] = meters.AverageMeter()

    @staticmethod
    def pairing_loss(clean_logits, adv_logits):
        """Return the mean over rows of the L2 norm (dim 1) of the logit gap."""
        logit_gap = adv_logits - clean_logits
        per_row_norm = torch.norm(logit_gap, dim=1, p=2)
        return per_row_norm.mean()

class ALP(LogitPairingBase):
    """Logit pairing that optimises adversarial cross-entropy + pairing loss."""

    def __init__(self, classifier, hparams, device):
        super().__init__(classifier, hparams, device)

    def robust_losses(self, clean_logits, adv_logits, labels):
        """Return ``(clean, invariance, robust, total)``; clean is logging-only."""
        robust_loss = F.cross_entropy(adv_logits, labels)
        invariance_loss = self.pairing_loss(clean_logits, adv_logits)

        # Reported but not optimised, so no graph is recorded for it.
        with torch.no_grad():
            clean_loss = F.cross_entropy(clean_logits, labels)

        total_loss = robust_loss + invariance_loss
        return (clean_loss, invariance_loss, robust_loss, total_loss)

class CLP(LogitPairingBase):
    """Logit pairing that optimises clean cross-entropy + pairing loss."""

    def __init__(self, classifier, hparams, device):
        super().__init__(classifier, hparams, device)

    def robust_losses(self, clean_logits, adv_logits, labels):
        """Return ``(clean, invariance, robust, total)``; robust is logging-only."""
        clean_loss = F.cross_entropy(clean_logits, labels)
        invariance_loss = self.pairing_loss(clean_logits, adv_logits)

        # Reported but not optimised, so no graph is recorded for it.
        with torch.no_grad():
            robust_loss = F.cross_entropy(adv_logits, labels)

        total_loss = clean_loss + invariance_loss
        return (clean_loss, invariance_loss, robust_loss, total_loss)

class PGDSaveCheckpoints(PGD):
    """PGD training with two optimizer steps per batch and random checkpoints.

    ``ckpt_save_criterion`` draws a Bernoulli sample with probability
    ``hparams['stochastic_pgd_beta']`` and counts the successes in
    ``ckpt_counter``.
    """

    SAVE_CKPTS = True

    def __init__(self, classifier, hparams, device):
        super(PGDSaveCheckpoints, self).__init__(
            classifier=classifier,
            hparams=hparams,
            device=device)
        self.ckpt_counter = 0
        # ``step`` logs to a 'Loss' meter, which none of the base classes
        # visible here register; create it so ``step`` does not raise KeyError.
        self.meters['Loss'] = meters.AverageMeter()

    def step(self, imgs, labels):
        """Attack once, then take two optimizer steps on the same adversarial batch."""
        adv_imgs = self.attack(imgs, labels)
        for _ in range(2):
            self.optimizer.zero_grad()
            loss = F.cross_entropy(self.predict(adv_imgs), labels)
            loss.backward()
            self.optimizer.step()

        # Only the loss of the second step is recorded.
        self.meters['Loss'].update(loss.item(), n=imgs.size(0))

    def ckpt_save_criterion(self):
        """Return ``True`` with probability ``stochastic_pgd_beta``."""
        coin_flip = float(torch.bernoulli(torch.tensor(self.hparams['stochastic_pgd_beta'])))
        if coin_flip == 1.0:
            self.ckpt_counter += 1
            return True
        return False

    
class TwoStepAT(Algorithm):
    """Two-phase update: an ERM optimizer step, then a projected adversarial step.

    After a regular optimizer step on the clean loss, the clean (ERM) and
    adversarial (PGD) gradients are computed at the new parameters.  Each
    parameter is then moved by ``two_step_eta`` along the adversarial gradient
    with its component along the ERM gradient removed.
    """

    def __init__(self, classifier, hparams, device):
        super(TwoStepAT, self).__init__(classifier, hparams, device)
        self.attack = attacks.PGD_Linf(self.classifier, self.hparams, device)

    @staticmethod
    def _flat_grads(params):
        """Return an independent flattened copy of every parameter's gradient.

        Copies matter: ``torch.flatten`` can return a view of ``p.grad``, and
        an optimizer ``zero_grad`` that zeroes in place followed by another
        ``backward`` would then silently overwrite the saved gradients.
        Parameters without a gradient contribute zeros.
        """
        return [
            torch.zeros_like(p).flatten() if p.grad is None
            else p.grad.detach().clone().flatten()
            for p in params
        ]

    @staticmethod
    def _project(a, b):
        """Remove from ``a`` its component along ``b`` (``a`` itself if ``b`` is zero)."""
        squared_norm = torch.norm(b, p=2) ** 2
        if squared_norm == 0:
            # Nothing to project out; avoids a 0/0 -> NaN parameter update.
            return a
        return a - (torch.dot(a, b)) / squared_norm * b

    def step(self, imgs, labels):
        """Run the ERM step followed by the projected adversarial step."""
        params = list(self.classifier.parameters())

        # theta_tilde step
        self.optimizer.zero_grad()
        loss = F.cross_entropy(self.predict(imgs), labels)
        loss.backward()
        self.optimizer.step()

        # theta step: clean gradient at the updated parameters
        self.optimizer.zero_grad()
        loss = F.cross_entropy(self.predict(imgs), labels)
        loss.backward()
        erm_gradients = self._flat_grads(params)

        # adversarial gradient at the same parameters
        self.optimizer.zero_grad()
        adv_imgs = self.attack(imgs, labels)
        adv_loss = F.cross_entropy(self.predict(adv_imgs), labels)
        adv_loss.backward()
        at_gradients = self._flat_grads(params)

        with torch.no_grad():
            for p, g_at, g_erm in zip(params, at_gradients, erm_gradients):
                update = self._project(g_at, g_erm).reshape(p.shape)
                p.subtract_(self.hparams['two_step_eta'] * update)




class MART(Algorithm):
    """PGD-based training with a margin-style adversarial loss and a weighted KL term.

    The optimised loss is ``loss_adv + mart_beta * loss_robust`` where
    ``loss_adv`` is cross-entropy on adversarial logits plus an NLL term on
    ``log(1.0001 - adv_probs)`` for the most likely wrong class, and
    ``loss_robust`` is the per-example KL between adversarial and clean
    probabilities weighted by ``1 - p_true`` of the clean prediction.
    """

    def __init__(self, classifier, hparams, device):
        super().__init__(classifier, hparams, device)
        self.kl_loss_fn = nn.KLDivLoss(reduction='none')
        attack_kwargs = dict(classifier=classifier, hparams=hparams, device=device)
        self.attack = attacks.PGD_Linf(**attack_kwargs)

        meter_names = (
            'Clean Loss',
            'Clean Accuracy',
            'Robust Loss',
            'Invariance Loss',
            'Total Loss',
            'Robust Accuracy',
        )
        for meter_name in meter_names:
            self.meters[meter_name] = meters.AverageMeter()

    def step(self, imgs, labels):
        """Attack the batch, optimise the combined loss and log metrics."""
        batch_size = imgs.size(0)

        adv_imgs = self.attack(imgs, labels)
        self.optimizer.zero_grad()

        clean_logits = self.predict(imgs)
        # Clean cross-entropy is reported only, never optimised.
        with torch.no_grad():
            clean_loss = F.cross_entropy(clean_logits, labels)

        adv_logits = self.predict(adv_imgs)
        adv_probs = F.softmax(adv_logits, dim=1)

        # Two most likely classes per row (ascending order of probability).
        top_two = torch.argsort(adv_probs, dim=1)[:, -2:]
        runner_up, best = top_two[:, -2], top_two[:, -1]
        # Most likely class that is not the true label.
        new_label = torch.where(best == labels, runner_up, best)

        ce_term = F.cross_entropy(adv_logits, labels)
        margin_term = F.nll_loss(torch.log(1.0001 - adv_probs + 1e-12), new_label)
        loss_adv = ce_term + margin_term

        nat_probs = F.softmax(clean_logits, dim=1)
        true_probs = torch.gather(nat_probs, 1, (labels.unsqueeze(1)).long()).squeeze()
        per_example_kl = torch.sum(
            self.kl_loss_fn(torch.log(adv_probs + 1e-12), nat_probs), dim=1)
        loss_robust = (1.0 / batch_size) * torch.sum(
            per_example_kl * (1.0000001 - true_probs))

        loss = loss_adv + self.hparams['mart_beta'] * loss_robust

        loss.backward()
        self.optimizer.step()

        self.meters['Clean Loss'].update(clean_loss.item(), n=batch_size)
        clean_acc = self.logits_to_accuracy(clean_logits, labels)
        self.meters['Clean Accuracy'].update(clean_acc, n=batch_size)
        # NOTE(review): 'Invariance Loss' receives the adversarial CE/margin
        # term and 'Robust Loss' the KL term -- confirm this naming is intended.
        self.meters['Invariance Loss'].update(loss_adv.item(), n=batch_size)
        self.meters['Robust Loss'].update(loss_robust.item(), n=batch_size)
        self.meters['Total Loss'].update(loss.item(), n=batch_size)
        robust_acc = self.logits_to_accuracy(adv_logits, labels)
        self.meters['Robust Accuracy'].update(robust_acc, n=batch_size)

class Gaussian_DALE(Algorithm):
    """Training on ``attacks.LMC_Gaussian_Linf`` examples with a fixed clean-loss weight.

    The optimised loss is ``robust_loss + g_dale_nu * clean_loss``.
    """

    def __init__(self, classifier, hparams, device):
        super(Gaussian_DALE, self).__init__(classifier, hparams, device)
        self.attack = attacks.LMC_Gaussian_Linf(
            classifier=classifier,
            hparams=hparams,
            device=device)
        # ``step`` logs the total loss under 'Loss'; the base class starts
        # with an empty meter dict, so the meter has to be registered here.
        self.meters['Loss'] = meters.AverageMeter()
        self.meters['clean loss'] = meters.AverageMeter()
        self.meters['robust loss'] = meters.AverageMeter()

    def step(self, imgs, labels):
        """Attack the batch and take one optimizer step on the weighted loss."""
        adv_imgs = self.attack(imgs, labels)
        self.optimizer.zero_grad()
        clean_loss = F.cross_entropy(self.predict(imgs), labels)
        robust_loss = F.cross_entropy(self.predict(adv_imgs), labels)
        total_loss = robust_loss + self.hparams['g_dale_nu'] * clean_loss
        total_loss.backward()
        self.optimizer.step()

        self.meters['Loss'].update(total_loss.item(), n=imgs.size(0))
        self.meters['clean loss'].update(clean_loss.item(), n=imgs.size(0))
        self.meters['robust loss'].update(robust_loss.item(), n=imgs.size(0))

class Laplacian_DALE(Algorithm):
    """Training on ``attacks.LMC_Laplacian_Linf`` examples with a fixed clean-loss weight.

    The optimised loss is ``robust_loss + l_dale_nu * clean_loss``.
    """

    def __init__(self, classifier, hparams, device):
        super(Laplacian_DALE, self).__init__(classifier, hparams, device)
        self.attack = attacks.LMC_Laplacian_Linf(
            classifier=classifier,
            hparams=hparams,
            device=device)

        # ``step`` logs the total loss under 'Loss'; the base class starts
        # with an empty meter dict, so the meter has to be registered here.
        self.meters['Loss'] = meters.AverageMeter()
        self.meters['clean loss'] = meters.AverageMeter()
        self.meters['robust loss'] = meters.AverageMeter()

    def step(self, imgs, labels):
        """Attack the batch and take one optimizer step on the weighted loss."""
        adv_imgs = self.attack(imgs, labels)
        self.optimizer.zero_grad()
        clean_loss = F.cross_entropy(self.predict(imgs), labels)
        robust_loss = F.cross_entropy(self.predict(adv_imgs), labels)
        total_loss = robust_loss + self.hparams['l_dale_nu'] * clean_loss
        total_loss.backward()
        self.optimizer.step()

        self.meters['Loss'].update(total_loss.item(), n=imgs.size(0))
        self.meters['clean loss'].update(clean_loss.item(), n=imgs.size(0))
        self.meters['robust loss'].update(robust_loss.item(), n=imgs.size(0))

class PrimalDualBase(Algorithm):
    """Base for primal-dual algorithms: holds the dual variable and shared meters.

    ``dual_params['dual_var']`` starts at 1.0 on ``device``.
    """

    def __init__(self, classifier, hparams, device):
        super(PrimalDualBase, self).__init__(classifier, hparams, device)
        self.dual_params = {'dual_var': torch.tensor(1.0).to(self.device)}
        # The primal-dual sub-classes in this file log their total loss under
        # 'Loss'; ``Algorithm`` starts with an empty meter dict, so register
        # the meter here to keep their ``step`` from raising KeyError.
        self.meters['Loss'] = meters.AverageMeter()
        self.meters['clean loss'] = meters.AverageMeter()
        self.meters['robust loss'] = meters.AverageMeter()
        self.meters['dual variable'] = meters.AverageMeter()

class Gaussian_DALE_PD(PrimalDualBase):
    """Primal-dual training: robust loss plus dual-weighted clean loss.

    Uses ``attacks.LMC_Gaussian_Linf``; after each primal step the dual
    optimizer is stepped with the detached clean loss.
    """

    def __init__(self, classifier, hparams, device):
        super().__init__(classifier, hparams, device)
        attack_kwargs = dict(classifier=classifier, hparams=hparams, device=device)
        self.attack = attacks.LMC_Gaussian_Linf(**attack_kwargs)

        self.pd_optimizer = optimizers.PrimalDualOptimizer(
            parameters=self.dual_params,
            margin=self.hparams['g_dale_pd_margin'],
            eta=self.hparams['g_dale_pd_step_size'],
        )

    def step(self, imgs, labels):
        """Run one primal step, one dual step, then log metrics.

        NOTE(review): the 'Loss' meter updated below is not registered in
        this class; it must be provided by a base class.
        """
        batch_size = imgs.size(0)

        adv_imgs = self.attack(imgs, labels)
        self.optimizer.zero_grad()
        clean_loss = F.cross_entropy(self.predict(imgs), labels)
        robust_loss = F.cross_entropy(self.predict(adv_imgs), labels)
        dual_var = self.dual_params['dual_var']
        total_loss = robust_loss + dual_var * clean_loss
        total_loss.backward()
        self.optimizer.step()
        self.pd_optimizer.step(clean_loss.detach())

        self.meters['Loss'].update(total_loss.item(), n=batch_size)
        self.meters['clean loss'].update(clean_loss.item(), n=batch_size)
        self.meters['robust loss'].update(robust_loss.item(), n=batch_size)
        # Read again after the dual step so the logged value is the current one.
        self.meters['dual variable'].update(self.dual_params['dual_var'].item(), n=1)

class Gaussian_DALE_PD_Reverse(PrimalDualBase):
    """Primal-dual training: clean loss plus dual-weighted robust loss.

    Uses ``attacks.LMC_Gaussian_Linf``; after each primal step the dual
    optimizer is stepped with the detached robust loss.
    """

    def __init__(self, classifier, hparams, device):
        super().__init__(classifier, hparams, device)
        attack_kwargs = dict(classifier=classifier, hparams=hparams, device=device)
        self.attack = attacks.LMC_Gaussian_Linf(**attack_kwargs)

        self.pd_optimizer = optimizers.PrimalDualOptimizer(
            parameters=self.dual_params,
            margin=self.hparams['g_dale_pd_margin'],
            eta=self.hparams['g_dale_pd_step_size'],
        )

    def step(self, imgs, labels):
        """Run one primal step, one dual step, then log metrics.

        NOTE(review): the 'Loss' meter updated below is not registered in
        this class; it must be provided by a base class.
        """
        batch_size = imgs.size(0)

        adv_imgs = self.attack(imgs, labels)
        self.optimizer.zero_grad()
        clean_loss = F.cross_entropy(self.predict(imgs), labels)
        robust_loss = F.cross_entropy(self.predict(adv_imgs), labels)
        dual_var = self.dual_params['dual_var']
        total_loss = clean_loss + dual_var * robust_loss
        total_loss.backward()
        self.optimizer.step()
        self.pd_optimizer.step(robust_loss.detach())

        self.meters['Loss'].update(total_loss.item(), n=batch_size)
        self.meters['clean loss'].update(clean_loss.item(), n=batch_size)
        self.meters['robust loss'].update(robust_loss.item(), n=batch_size)
        # Read again after the dual step so the logged value is the current one.
        self.meters['dual variable'].update(self.dual_params['dual_var'].item(), n=1)

class KL_DALE_PD(PrimalDualBase):
    """Primal-dual training: KL robustness term plus dual-weighted clean loss.

    Uses ``attacks.TRADES_Linf``; the robust term is ``KLDivLoss(batchmean)``
    between the log-softmax of adversarial logits and the softmax of clean
    logits.  The dual optimizer is stepped with the detached clean loss.
    """

    def __init__(self, classifier, hparams, device):
        super().__init__(classifier, hparams, device)
        attack_kwargs = dict(classifier=classifier, hparams=hparams, device=device)
        self.attack = attacks.TRADES_Linf(**attack_kwargs)

        self.kl_loss_fn = nn.KLDivLoss(reduction='batchmean')
        self.pd_optimizer = optimizers.PrimalDualOptimizer(
            parameters=self.dual_params,
            margin=self.hparams['g_dale_pd_margin'],
            eta=self.hparams['g_dale_pd_step_size'],
        )

    def step(self, imgs, labels):
        """Run one primal step, one dual step, then log metrics.

        NOTE(review): the 'Loss' meter updated below is not registered in
        this class; it must be provided by a base class.
        """
        batch_size = imgs.size(0)

        adv_imgs = self.attack(imgs, labels)
        self.optimizer.zero_grad()
        clean_loss = F.cross_entropy(self.predict(imgs), labels)

        # Separate forward passes, in this order: adversarial, then clean.
        adv_log_probs = F.log_softmax(self.predict(adv_imgs), dim=1)
        clean_probs = F.softmax(self.predict(imgs), dim=1)
        robust_loss = self.kl_loss_fn(adv_log_probs, clean_probs)

        dual_var = self.dual_params['dual_var']
        total_loss = robust_loss + dual_var * clean_loss
        total_loss.backward()
        self.optimizer.step()
        self.pd_optimizer.step(clean_loss.detach())

        self.meters['Loss'].update(total_loss.item(), n=batch_size)
        self.meters['clean loss'].update(clean_loss.item(), n=batch_size)
        self.meters['robust loss'].update(robust_loss.item(), n=batch_size)
        # Read again after the dual step so the logged value is the current one.
        self.meters['dual variable'].update(self.dual_params['dual_var'].item(), n=1)

##### Non-adversarial training algorithms

class ERM_DataAug(Algorithm):
    """Cross-entropy training on randomly perturbed copies of each batch.

    Every step averages the loss over ``hparams['cvar_sgd_M']`` independent
    perturbations drawn uniformly from ``[-epsilon, epsilon]`` per pixel.
    """

    def __init__(self, classifier, hparams, device):
        super(ERM_DataAug, self).__init__(classifier, hparams, device)
        # ``step`` logs under 'Loss'; the base class starts with an empty
        # meter dict, so the meter has to be registered here.
        self.meters['Loss'] = meters.AverageMeter()

    def sample_deltas(self, imgs):
        """Return uniform noise in ``[-epsilon, epsilon)`` shaped like ``imgs``."""
        eps = self.hparams['epsilon']
        return 2 * eps * torch.rand_like(imgs) - eps

    def step(self, imgs, labels):
        """Take one optimizer step on the loss averaged over the perturbed copies."""
        n_samples = self.hparams['cvar_sgd_M']

        self.optimizer.zero_grad()
        loss = 0
        for _ in range(n_samples):
            # Fresh perturbation per sample, clamped back to valid pixel
            # range.  Without it every iteration would score the same clean
            # batch and no augmentation would take place.
            pert_imgs = self.img_clamp(imgs + self.sample_deltas(imgs))
            loss += F.cross_entropy(self.predict(pert_imgs), labels)

        loss = loss / float(n_samples)
        loss.backward()
        self.optimizer.step()

        self.meters['Loss'].update(loss.item(), n=imgs.size(0))

class TERM(Algorithm):
    """Tilted empirical risk: optimises ``log(mean(exp(t * loss)) + 1e-6) / t``.

    ``t`` is read from ``hparams['term_t']``.
    """

    def __init__(self, classifier, hparams, device):
        super(TERM, self).__init__(classifier, hparams, device)
        # ``step`` logs the plain mean loss under 'Loss'; the base class
        # starts with an empty meter dict, so register the meter here.
        self.meters['Loss'] = meters.AverageMeter()
        self.meters['tilted loss'] = meters.AverageMeter()
        self.t = torch.tensor(self.hparams['term_t'])

    def step(self, imgs, labels):
        """Take one optimizer step on the tilted loss; log plain and tilted losses."""
        self.optimizer.zero_grad()
        # Per-example losses are needed for the exponential tilt.
        loss = F.cross_entropy(self.predict(imgs), labels, reduction='none')
        term_loss = torch.log(torch.exp(self.t * loss).mean() + 1e-6) / self.t
        term_loss.backward()
        self.optimizer.step()

        self.meters['Loss'].update(loss.mean().item(), n=imgs.size(0))
        self.meters['tilted loss'].update(term_loss.item(), n=imgs.size(0))

class RandSmoothing(Algorithm):
    """Cross-entropy training on the output of ``attacks.SmoothAdv``."""

    def __init__(self, classifier, hparams, device):
        super(RandSmoothing, self).__init__(classifier, hparams, device)
        self.attack = attacks.SmoothAdv(self.classifier, self.hparams, device)
        # ``step`` logs under 'Loss'; the base class starts with an empty
        # meter dict, so the meter has to be registered here.
        self.meters['Loss'] = meters.AverageMeter()

    def step(self, imgs, labels):
        """Attack the batch and take one optimizer step on the resulting loss."""
        adv_imgs = self.attack(imgs, labels)
        self.optimizer.zero_grad()
        loss = F.cross_entropy(self.predict(adv_imgs), labels)
        loss.backward()
        self.optimizer.step()

        self.meters['Loss'].update(loss.item(), n=imgs.size(0))

class CVaR_SGD_Autograd(Algorithm):
    """CVaR training where the per-example thresholds ``ts`` are tuned by autograd.

    For each batch, ``ts`` (one value per example, initialised to 1) is
    updated for ``cvar_sgd_n_steps`` gradient steps on
    ``mean(ts + relu(loss - ts) / (M * beta))`` over ``M`` random
    perturbations.  The classifier is then updated on
    ``mean(relu(loss - ts) / (beta * M))`` computed with fresh perturbations.
    """

    def __init__(self, classifier, hparams, device):
        super(CVaR_SGD_Autograd, self).__init__(classifier, hparams, device)
        # ``step`` logs the CVaR loss under 'Loss'; the base class starts
        # with an empty meter dict, so register the meter here.
        self.meters['Loss'] = meters.AverageMeter()
        self.meters['avg t'] = meters.AverageMeter()
        self.meters['plain loss'] = meters.AverageMeter()

    def sample_deltas(self, imgs):
        """Return uniform noise in ``[-epsilon, epsilon)`` shaped like ``imgs``."""
        eps = self.hparams['epsilon']
        return 2 * eps * torch.rand_like(imgs) - eps

    def step(self, imgs, labels):
        """Tune ``ts``, then take one optimizer step on the CVaR loss."""
        beta, M = self.hparams['cvar_sgd_beta'], self.hparams['cvar_sgd_M']
        ts = torch.ones(size=(imgs.size(0),)).to(self.device)

        self.optimizer.zero_grad()
        for _ in range(self.hparams['cvar_sgd_n_steps']):

            ts.requires_grad = True
            cvar_loss = 0
            for __ in range(M):
                # Only the gradient w.r.t. ``ts`` is needed in this loop, so
                # the classifier forward pass does not record a graph.
                with torch.no_grad():
                    pert_imgs = self.img_clamp(imgs + self.sample_deltas(imgs))
                    curr_loss = F.cross_entropy(
                        self.predict(pert_imgs), labels, reduction='none')
                cvar_loss += F.relu(curr_loss - ts)

            cvar_loss = (ts + cvar_loss / (float(M) * beta)).mean()
            grad_ts = torch.autograd.grad(cvar_loss, [ts])[0].detach()
            ts = ts - self.hparams['cvar_sgd_t_step_size'] * grad_ts
            # Detach so the next iteration starts from a fresh leaf tensor.
            ts = ts.detach()

        plain_loss, cvar_loss = 0, 0
        for _ in range(M):
            pert_imgs = self.img_clamp(imgs + self.sample_deltas(imgs))
            curr_loss = F.cross_entropy(self.predict(pert_imgs), labels, reduction='none')
            plain_loss += curr_loss.mean()
            cvar_loss += F.relu(curr_loss - ts)

        cvar_loss = (cvar_loss / (beta * float(M))).mean()

        cvar_loss.backward()
        self.optimizer.step()

        self.meters['Loss'].update(cvar_loss.item(), n=imgs.size(0))
        self.meters['avg t'].update(ts.mean().item(), n=imgs.size(0))
        self.meters['plain loss'].update(plain_loss.item() / M, n=imgs.size(0))

class CVaR_SGD(Algorithm):
    """CVaR-style training with a closed-form gradient for the threshold.

    Every call to :meth:`step` runs ``cvar_sgd_n_steps`` iterations. Each
    iteration draws ``cvar_sgd_M`` perturbed copies of the batch, evaluates
    ``mean(ts + sum_m relu(loss_m - ts) / (M * beta))`` and updates the
    per-example threshold vector ``ts`` (initialised to ones) with the
    hand-written gradient ``(1 - indicator_avg / beta) / batch_size``.
    The objective from the last iteration is back-propagated into the
    classifier.

    Hyperparameters read from ``hparams``: ``epsilon``, ``cvar_sgd_beta``,
    ``cvar_sgd_M``, ``cvar_sgd_n_steps`` and ``cvar_sgd_t_step_size``.
    """

    def __init__(self, classifier, hparams, device):
        super(CVaR_SGD, self).__init__(classifier, hparams, device)
        # step() writes to the 'Loss' meter, so make sure it exists; an
        # already-registered meter is left untouched.
        if 'Loss' not in self.meters:
            self.meters['Loss'] = meters.AverageMeter()
        self.meters['avg t'] = meters.AverageMeter()
        self.meters['plain loss'] = meters.AverageMeter()

    def sample_deltas(self, imgs):
        """Return noise shaped like ``imgs``, uniform in ``[-epsilon, epsilon)``."""
        eps = self.hparams['epsilon']
        return 2 * eps * torch.rand_like(imgs) - eps

    def step(self, imgs, labels):
        """Update the thresholds and take one optimizer step on the classifier.

        Args:
            imgs: batch of inputs; its first dimension is used as batch size.
            labels: targets passed to ``F.cross_entropy``.

        Raises:
            ValueError: if ``cvar_sgd_n_steps`` or ``cvar_sgd_M`` is smaller
                than 1 (the loss would otherwise never be computed).
        """
        beta = self.hparams['cvar_sgd_beta']
        M = self.hparams['cvar_sgd_M']
        n_steps = self.hparams['cvar_sgd_n_steps']
        if n_steps < 1 or M < 1:
            raise ValueError(
                "cvar_sgd_n_steps and cvar_sgd_M must both be >= 1, "
                f"got n_steps={n_steps}, M={M}")
        t_step_size = self.hparams['cvar_sgd_t_step_size']
        batch_size = imgs.size(0)
        ts = torch.ones(batch_size, device=self.device)

        self.optimizer.zero_grad()
        for step_idx in range(n_steps):

            # Only the last iteration's loss reaches backward(); earlier
            # iterations just move ts, so they run without autograd tracking.
            is_last = step_idx == n_steps - 1

            plain_loss, cvar_loss, indicator_sum = 0, 0, 0
            with torch.set_grad_enabled(is_last):
                for _ in range(M):
                    pert_imgs = self.img_clamp(imgs + self.sample_deltas(imgs))
                    curr_loss = F.cross_entropy(
                        self.predict(pert_imgs), labels, reduction='none')
                    # 1 where the sample loss exceeds the threshold, else 0.
                    indicator_sum += (curr_loss > ts).to(ts.dtype)

                    plain_loss += curr_loss.mean()
                    cvar_loss += F.relu(curr_loss - ts)

            indicator_avg = indicator_sum / float(M)
            # Uses ts from before this iteration's update below.
            cvar_loss = (ts + cvar_loss / (float(M) * beta)).mean()

            # gradient update on ts (the 1/batch_size factor comes from the mean)
            grad_ts = (1 - (1 / beta) * indicator_avg) / float(batch_size)
            ts = ts - t_step_size * grad_ts

        cvar_loss.backward()
        self.optimizer.step()

        self.meters['Loss'].update(cvar_loss.item(), n=batch_size)
        self.meters['avg t'].update(ts.mean().item(), n=batch_size)
        self.meters['plain loss'].update(plain_loss.item() / M, n=batch_size)

class CVaR_SGD_PD(Algorithm):
    """Primal-dual variant of :class:`CVaR_SGD`.

    The threshold vector ``ts`` is updated exactly as in ``CVaR_SGD``. The
    classifier is then trained on
    ``cvar_loss + dual_var * (plain_loss / M)``, and the dual variable is
    updated by ``optimizers.PrimalDualOptimizer`` from the detached mean
    plain loss.

    Hyperparameters read from ``hparams``: ``epsilon``, ``cvar_sgd_beta``,
    ``cvar_sgd_M``, ``cvar_sgd_n_steps``, ``cvar_sgd_t_step_size``, plus
    ``g_dale_pd_margin`` and ``g_dale_pd_step_size`` for the dual optimizer.
    """

    def __init__(self, classifier, hparams, device):
        super(CVaR_SGD_PD, self).__init__(classifier, hparams, device)
        self.dual_params = {'dual_var': torch.tensor(1.0).to(self.device)}
        # step() writes to the 'Loss' meter, so make sure it exists; an
        # already-registered meter is left untouched.
        if 'Loss' not in self.meters:
            self.meters['Loss'] = meters.AverageMeter()
        self.meters['avg t'] = meters.AverageMeter()
        self.meters['plain loss'] = meters.AverageMeter()
        self.meters['dual variable'] = meters.AverageMeter()
        # NOTE(review): margin / step size are taken from the g_dale_pd_*
        # hyperparameters; confirm that sharing them is intended.
        self.pd_optimizer = optimizers.PrimalDualOptimizer(
            parameters=self.dual_params,
            margin=self.hparams['g_dale_pd_margin'],
            eta=self.hparams['g_dale_pd_step_size'])

    def sample_deltas(self, imgs):
        """Return noise shaped like ``imgs``, uniform in ``[-epsilon, epsilon)``."""
        eps = self.hparams['epsilon']
        return 2 * eps * torch.rand_like(imgs) - eps

    def step(self, imgs, labels):
        """Update thresholds, classifier weights and the dual variable.

        Args:
            imgs: batch of inputs; its first dimension is used as batch size.
            labels: targets passed to ``F.cross_entropy``.

        Raises:
            ValueError: if ``cvar_sgd_n_steps`` or ``cvar_sgd_M`` is smaller
                than 1 (the loss would otherwise never be computed).
        """
        beta = self.hparams['cvar_sgd_beta']
        M = self.hparams['cvar_sgd_M']
        n_steps = self.hparams['cvar_sgd_n_steps']
        if n_steps < 1 or M < 1:
            raise ValueError(
                "cvar_sgd_n_steps and cvar_sgd_M must both be >= 1, "
                f"got n_steps={n_steps}, M={M}")
        t_step_size = self.hparams['cvar_sgd_t_step_size']
        batch_size = imgs.size(0)
        ts = torch.ones(batch_size, device=self.device)

        self.optimizer.zero_grad()
        for step_idx in range(n_steps):

            # Only the last iteration's losses reach backward(); earlier
            # iterations just move ts, so they run without autograd tracking.
            is_last = step_idx == n_steps - 1

            plain_loss, cvar_loss, indicator_sum = 0, 0, 0
            with torch.set_grad_enabled(is_last):
                for _ in range(M):
                    pert_imgs = self.img_clamp(imgs + self.sample_deltas(imgs))
                    curr_loss = F.cross_entropy(
                        self.predict(pert_imgs), labels, reduction='none')
                    # 1 where the sample loss exceeds the threshold, else 0.
                    indicator_sum += (curr_loss > ts).to(ts.dtype)

                    # Kept attached: plain_loss is part of the trained loss.
                    plain_loss += curr_loss.mean()
                    cvar_loss += F.relu(curr_loss - ts)

            indicator_avg = indicator_sum / float(M)
            # Uses ts from before this iteration's update below.
            cvar_loss = (ts + cvar_loss / (float(M) * beta)).mean()

            # gradient update on ts (the 1/batch_size factor comes from the mean)
            grad_ts = (1 - (1 / beta) * indicator_avg) / float(batch_size)
            ts = ts - t_step_size * grad_ts

        loss = cvar_loss + self.dual_params['dual_var'] * (plain_loss / float(M))
        loss.backward()
        self.optimizer.step()
        self.pd_optimizer.step(plain_loss.detach() / M)

        # The 'Loss' meter tracks the CVaR term only, not the combined loss.
        self.meters['Loss'].update(cvar_loss.item(), n=batch_size)
        self.meters['avg t'].update(ts.mean().item(), n=batch_size)
        self.meters['plain loss'].update(plain_loss.item() / M, n=batch_size)
        self.meters['dual variable'].update(self.dual_params['dual_var'].item(), n=1)
