import torch
from tqdm import tqdm
from utils import AverageMeter, accuracy, get_adv


def train_epoch(model, dataloader, criterion, optimizer, norm_layer, device, config=None):
    """Run one epoch of standard (clean) supervised training.

    Each batch is moved to ``device``, passed through ``norm_layer`` and then
    ``model``; the ``criterion`` loss is back-propagated and ``optimizer``
    takes one step.

    Args:
        model: Network to train; it is put in train mode before the loop.
        dataloader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss called as ``criterion(logits, labels)``.
        optimizer: Optimizer stepping ``model``'s parameters.
        norm_layer: Callable applied to the raw inputs before ``model``
            (presumably input normalization -- confirm at the call site).
        device: Device the inputs and labels are moved to.
        config: Unused here (presumably kept for a uniform signature).

    Returns:
        Tuple ``(acc_avg, loss_avg)``: batch-size-weighted averages, tracked
        with ``AverageMeter``, of the first value returned by ``accuracy``
        and of the per-batch loss.
    """
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()
    model.train()

    for inputs, labels in tqdm(dataloader):
        inputs, labels = inputs.to(device), labels.to(device)
        batch_size = inputs.size(0)

        logits = model(norm_layer(inputs))
        batch_loss = criterion(logits, labels)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        # Book-keeping on float32, detached copies of the outputs.
        batch_acc = accuracy(logits.float().detach(), labels)[0]
        loss_meter.update(batch_loss.float().item(), batch_size)
        acc_meter.update(batch_acc.item(), batch_size)

    return acc_meter.avg, loss_meter.avg



def train_adv(model, dataloader, criterion, optimizer, norm_layer, device, config):
    """Adversarially train ``model`` for one epoch.

    Adversarial inputs are crafted by ``get_adv`` against ``model`` itself.
    When ``config.targeted`` is truthy, every sample is attacked with the
    fixed label ``config.adv_label``; otherwise the true labels are handed
    to ``get_adv``. The model is then trained on the adversarial inputs,
    clamped to ``[0, 1]``, using the TRUE labels.

    Args:
        model: Network to attack and train.
        dataloader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss called as ``criterion(logits, labels)``.
        optimizer: Optimizer stepping ``model``'s parameters.
        norm_layer: Callable applied to inputs before ``model``.
        device: Device the inputs and labels are moved to.
        config: Reads ``targeted`` and ``adv_label`` here and is forwarded
            to ``get_adv``.

    Returns:
        Tuple ``(acc_avg, loss_avg)`` averaged over the epoch.
    """
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()
    model.train()

    for inputs, labels in tqdm(dataloader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        attack_labels = (
            torch.ones_like(labels) * config.adv_label
            if config.targeted
            else labels
        )
        perturbed = get_adv(model, inputs, attack_labels, norm_layer, config)

        # get_adv may leave the model in another mode; go back to training.
        model.train()
        perturbed = perturbed.clamp(min=0.0, max=1.0)
        logits = model(norm_layer(perturbed))
        batch_loss = criterion(logits, labels)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        batch_size = inputs.size(0)
        batch_acc = accuracy(logits.float().detach(), labels)[0]
        loss_meter.update(batch_loss.float().item(), batch_size)
        acc_meter.update(batch_acc.item(), batch_size)

    return acc_meter.avg, loss_meter.avg


def train_backdoor_adv(model, dataloader, criterion, optimizer, norm_layer, device, config):
    """Train ``model`` for one epoch with adversarially perturbed poisoned samples.

    In every batch the first ``int(batch_size * config.poison_rate)`` samples
    are replaced by adversarial examples crafted with ``get_adv`` against
    ``model``. The label handed to ``get_adv`` depends on the mode:

    * ``config.all2one``: ``config.adv_label`` for every poisoned sample;
    * ``config.all2all``: ``(label + 1) % config.num_classes``;
    * otherwise (clean-label mode): the sample's own label.

    Unless ``config.clean_label`` is truthy, the poisoned samples are also
    relabelled with that label before the training step. Batches whose
    poison count rounds down to zero skip the attack and train on clean
    data only, consistent with ``train_IBA``/``train_IBA2``.

    Args:
        model: Network to attack and train.
        dataloader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Loss called as ``criterion(logits, targets)``.
        optimizer: Optimizer stepping ``model``'s parameters.
        norm_layer: Callable applied to inputs before ``model``.
        device: Device the inputs and targets are moved to.
        config: Reads ``poison_rate``, ``all2one``, ``all2all``,
            ``adv_label``, ``num_classes`` and ``clean_label``; also
            forwarded to ``get_adv``.

    Returns:
        Tuple ``(acc_avg, loss_avg)`` averaged over the epoch, measured
        against the (possibly relabelled) training targets.
    """
    losses = AverageMeter()
    top1 = AverageMeter()
    model.train()

    for inputs, targets in tqdm(dataloader):
        inputs = inputs.to(device)
        targets = targets.to(device)
        num_bd = int(inputs.size(0) * config.poison_rate)

        # Skip the attack when no sample is poisoned: calling get_adv on an
        # empty slice is wasted work at best and may fail inside the model.
        if num_bd:
            poisoned_input = inputs[:num_bd]
            poisoned_target = targets[:num_bd]
            if config.all2one:
                poisoned_target = torch.ones_like(poisoned_target) * config.adv_label
            elif config.all2all:
                # Per the original note: adversarial examples that move AWAY
                # from (label + 1). The actual direction depends on how
                # get_adv interprets config.targeted.
                poisoned_target = torch.remainder(poisoned_target + 1, config.num_classes)
            # else: clean-label mode -- the attack uses the true labels.
            poisoned_input = get_adv(model, poisoned_input, poisoned_target, norm_layer, config)
            poisoned_input = torch.clamp(poisoned_input, min=0.0, max=1.0)

            inputs[:num_bd] = poisoned_input
            if not config.clean_label:
                targets[:num_bd] = poisoned_target

        # get_adv may leave the model in another mode; go back to training.
        model.train()
        output = model(norm_layer(inputs))
        loss = criterion(output, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Measure accuracy and record loss on float32, detached tensors.
        prec1 = accuracy(output.float().detach(), targets)[0]
        losses.update(loss.float().item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))

    return top1.avg, losses.avg


# Train `classifier` on adversarial examples that `model_G` generates.
def train_adv2(classifier, model_G, dataloader, criterion, optimizer, norm_layer, device, config):
    """Train ``classifier`` for one epoch on examples attacked via ``model_G``.

    ``get_adv`` crafts the adversarial inputs against ``model_G``, which is
    kept in eval mode. ``classifier`` is never passed to ``get_adv``. When
    ``config.targeted`` is truthy, every sample is attacked with
    ``config.adv_label``; otherwise with its true label. ``classifier`` is
    trained on the adversarial inputs, clamped to ``[0, 1]``, using the TRUE
    labels.

    Args:
        classifier: Network being trained.
        model_G: Network the adversarial examples are crafted against.
        dataloader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss called as ``criterion(logits, labels)``.
        optimizer: Optimizer stepping ``classifier``'s parameters.
        norm_layer: Callable applied to inputs before either network.
        device: Device the inputs and labels are moved to.
        config: Reads ``targeted`` and ``adv_label``; forwarded to ``get_adv``.

    Returns:
        Tuple ``(acc_avg, loss_avg)`` for ``classifier`` over the epoch.
    """
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()
    classifier.train()
    model_G.eval()

    for inputs, labels in tqdm(dataloader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        if config.targeted:
            attack_labels = torch.ones_like(labels) * config.adv_label
        else:
            attack_labels = labels
        perturbed = get_adv(model_G, inputs, attack_labels, norm_layer, config)

        # Re-asserted every batch as in the original; get_adv does not
        # receive `classifier`, so this is presumably redundant.
        classifier.train()
        perturbed = perturbed.clamp(min=0.0, max=1.0)
        logits = classifier(norm_layer(perturbed))
        batch_loss = criterion(logits, labels)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        batch_size = inputs.size(0)
        batch_acc = accuracy(logits.float().detach(), labels)[0]
        loss_meter.update(batch_loss.float().item(), batch_size)
        acc_meter.update(batch_acc.item(), batch_size)

    return acc_meter.avg, loss_meter.avg


# Train on adversarial examples pushed TOWARD their own clean label.
def train_adv_rev(model, dataloader, criterion, optimizer, norm_layer, device, config):
    """Train ``model`` for one epoch with "reverse" adversarial examples.

    In each batch the first ``int(batch_size * config.poison_rate)`` samples
    are replaced by examples that ``get_adv`` crafts against ``model`` with
    ``config.targeted = True`` and the samples' true labels as targets. Labels
    are left unchanged. The whole batch is clamped to ``[0, 1]`` before the
    training step. Batches whose poison count rounds down to zero skip the
    attack, consistent with ``train_IBA``/``train_IBA2``.

    NOTE(review): ``config.targeted`` is set to ``True`` on the caller's
    config object every iteration and is not restored afterwards. That
    behavior is kept because callers may rely on it.

    Args:
        model: Network to attack and train.
        dataloader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Loss called as ``criterion(logits, targets)``.
        optimizer: Optimizer stepping ``model``'s parameters.
        norm_layer: Callable applied to inputs before ``model``.
        device: Device the inputs and targets are moved to.
        config: Reads ``poison_rate``; ``targeted`` is overwritten. Also
            forwarded to ``get_adv``.

    Returns:
        Tuple ``(acc_avg, loss_avg)`` averaged over the epoch.
    """
    losses = AverageMeter()
    top1 = AverageMeter()
    model.train()

    for inputs, targets in tqdm(dataloader):
        inputs = inputs.to(device)
        targets = targets.to(device)
        num_bd = int(inputs.size(0) * config.poison_rate)

        config.targeted = True
        # Skip the attack when no sample is selected: calling get_adv on an
        # empty slice is wasted work at best and may fail inside the model.
        if num_bd:
            adv_input = get_adv(model, inputs[:num_bd], targets[:num_bd], norm_layer, config)
            inputs[:num_bd] = adv_input

        # get_adv may leave the model in another mode; go back to training.
        model.train()
        inputs = torch.clamp(inputs, min=0.0, max=1.0)
        output = model(norm_layer(inputs))
        loss = criterion(output, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Measure accuracy and record loss on float32, detached tensors.
        prec1 = accuracy(output.float().detach(), targets)[0]
        losses.update(loss.float().item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))

    return top1.avg, losses.avg


def train_IBA(generator, classifier, dataloader, criterion, optimizer, norm_layer, device, config):
    """Train ``classifier`` for one epoch with clean-label poisoned batches.

    In each batch the first ``int(batch_size * config.poison_rate)`` samples
    are replaced by examples that ``get_adv`` crafts against ``generator``
    using the samples' own labels. Labels are never changed. Before each
    such attack ``config.targeted`` is set to ``True``.

    NOTE(review): that assignment mutates the caller's ``config`` and is not
    undone when the function returns.

    Args:
        generator: Network passed to ``get_adv``; kept in eval mode.
        classifier: Network being trained.
        dataloader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss called as ``criterion(logits, labels)``.
        optimizer: Optimizer stepping ``classifier``'s parameters.
        norm_layer: Callable applied to inputs before the networks.
        device: Device the inputs and labels are moved to.
        config: Reads ``poison_rate``; ``targeted`` is overwritten. Also
            forwarded to ``get_adv``.

    Returns:
        Tuple ``(acc_avg, loss_avg)`` for ``classifier`` over the epoch.
    """
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()
    classifier.train()
    generator.eval()

    for inputs, labels in tqdm(dataloader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        n_poison = int(inputs.size(0) * config.poison_rate)
        if n_poison:
            config.targeted = True
            crafted = get_adv(generator, inputs[:n_poison], labels[:n_poison], norm_layer, config)
            inputs[:n_poison] = crafted

        inputs = inputs.clamp(min=0.0, max=1.0)
        logits = classifier(norm_layer(inputs))
        batch_loss = criterion(logits, labels)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        batch_size = inputs.size(0)
        batch_acc = accuracy(logits.float().detach(), labels)[0]
        loss_meter.update(batch_loss.float().item(), batch_size)
        acc_meter.update(batch_acc.item(), batch_size)

    return acc_meter.avg, loss_meter.avg


# Dirty-label variant of train_IBA: poisoned samples are also relabelled.
def train_IBA2(generator, classifier, dataloader, criterion, optimizer, norm_layer, device, config):
    """Train ``classifier`` for one epoch with dirty-label poisoned batches.

    For the first ``int(batch_size * config.poison_rate)`` samples of each
    batch, a new label ``(label + r) % config.num_classes`` is drawn, with
    ``r`` uniform in ``[1, config.num_classes - 1]``. ``get_adv`` then crafts
    those samples against ``generator`` toward the new labels. Both the
    inputs and the labels of the poisoned slice are replaced before training.
    Before each attack ``config.targeted`` is set to ``True``.

    NOTE(review): that assignment mutates the caller's ``config`` and is not
    undone when the function returns.

    Args:
        generator: Network passed to ``get_adv``; kept in eval mode.
        classifier: Network being trained.
        dataloader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss called as ``criterion(logits, labels)``.
        optimizer: Optimizer stepping ``classifier``'s parameters.
        norm_layer: Callable applied to inputs before the networks.
        device: Device the inputs and labels are moved to.
        config: Reads ``poison_rate`` and ``num_classes``; ``targeted`` is
            overwritten. Also forwarded to ``get_adv``.

    Returns:
        Tuple ``(acc_avg, loss_avg)`` for ``classifier`` over the epoch,
        measured against the relabelled targets.
    """
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()
    classifier.train()
    generator.eval()

    for inputs, labels in tqdm(dataloader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        n_poison = int(inputs.size(0) * config.poison_rate)
        if n_poison:
            # A nonzero offset modulo num_classes means the new label never
            # equals the true one (assuming labels lie in [0, num_classes)).
            offsets = torch.randint(1, config.num_classes, (n_poison,)).to(device)
            wrong_labels = torch.remainder(offsets + labels[:n_poison], config.num_classes).to(device)

            config.targeted = True
            crafted = get_adv(generator, inputs[:n_poison], wrong_labels, norm_layer, config)
            inputs[:n_poison] = crafted
            labels[:n_poison] = wrong_labels

        inputs = inputs.clamp(min=0.0, max=1.0)
        logits = classifier(norm_layer(inputs))
        batch_loss = criterion(logits, labels)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        batch_size = inputs.size(0)
        batch_acc = accuracy(logits.float().detach(), labels)[0]
        loss_meter.update(batch_loss.float().item(), batch_size)
        acc_meter.update(batch_acc.item(), batch_size)

    return acc_meter.avg, loss_meter.avg


# Experiment: [x, y] -> [x + delta, y]; the training label y is never changed.
# Candidate choices for the label handed to the attack:
#   1. y + 1, with delta moving AWAY from it -- effectively an irrelevant
#      perturbation, since x does not belong to class y + 1 anyway;
#   2. y + 1, with delta moving TOWARD it -- a perturbation toward another
#      class, in the hope of gaining robustness;
#   3. y + a random offset in 1..9 (for ten classes) instead of a fixed +1.
# Goal: find which variant best enlarges the margin between classes.
def train_adv_delta(model, dataloader, criterion, optimizer, norm_layer, device, config):
    """Adversarially train ``model`` with randomly shifted attack labels.

    For every sample the attack label is ``(label + r) % config.num_classes``,
    with ``r`` uniform in ``[1, config.num_classes - 1]``. ``get_adv`` crafts
    the perturbed inputs against ``model`` with those labels. The model is
    then trained on the perturbed inputs, clamped to ``[0, 1]``, using the
    TRUE labels.

    Args:
        model: Network to attack and train.
        dataloader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss called as ``criterion(logits, labels)``.
        optimizer: Optimizer stepping ``model``'s parameters.
        norm_layer: Callable applied to inputs before ``model``.
        device: Device the inputs and labels are moved to.
        config: Reads ``num_classes``; forwarded to ``get_adv``.

    Returns:
        Tuple ``(acc_avg, loss_avg)`` averaged over the epoch.
    """
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()
    model.train()

    for inputs, labels in tqdm(dataloader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Option 3: random offset in [1, num_classes - 1]. Options 1/2 would
        # use an offset of ones; using the true labels directly would reduce
        # this to standard adversarial training.
        offsets = torch.randint(1, config.num_classes, (inputs.size(0),)).to(device)
        attack_labels = torch.remainder(labels + offsets, config.num_classes)

        # Targeted vs. untargeted is left to the config (per the original
        # note, presumably read by get_adv).
        perturbed = get_adv(model, inputs, attack_labels, norm_layer, config)

        # get_adv may leave the model in another mode; go back to training.
        model.train()
        perturbed = perturbed.clamp(min=0.0, max=1.0)
        logits = model(norm_layer(perturbed))
        batch_loss = criterion(logits, labels)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        batch_size = inputs.size(0)
        batch_acc = accuracy(logits.float().detach(), labels)[0]
        loss_meter.update(batch_loss.float().item(), batch_size)
        acc_meter.update(batch_acc.item(), batch_size)

    return acc_meter.avg, loss_meter.avg
