import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import numpy as np
from apex import amp
from torch.autograd import Variable
from utils import AverageMeter, ProgressMeter, tocpu
from utils import accuracy, update_swadict
from utils_adv import pgd_whitebox
from utils_adv import trades_loss
    
    
def baseline(model, device, dataloader, criterion, optimizer, num_batches=0, lr_scheduler=None, epoch=0, args=None, **kwargs):
    """Train ``model`` for one epoch on unperturbed inputs.

    Args:
        model: Network to optimise; put into train mode here.
        device: Device each batch is moved to.
        dataloader: Iterable of batches; item 0 holds the inputs, item 1 the targets.
        criterion: Callable ``criterion(output, target)`` returning the loss.
        optimizer: Optimizer stepped once per batch.
        num_batches: Batch count handed to ``ProgressMeter`` for display.
        lr_scheduler: Optional scheduler stepped once per batch; skipped when ``None``.
        epoch: Epoch index shown in the progress prefix.
        args: Run configuration. Reads ``local_rank``, ``batch_size``, ``fp16``,
            ``swa``, ``swadict``, ``tau`` and ``print_freq``.
        **kwargs: Accepted and ignored.

    Returns:
        dict: ``{"top1": ..., "top2": ...}`` holding the ``avg`` of the top-1 and
        top-2 accuracy meters over the epoch.
    """
    is_main = args.local_rank == 0
    if is_main:
        print(" ->->->->->->->->->-> One epoch with Baseline natural training <-<-<-<-<-<-<-<-<-<-")

    batch_time = AverageMeter("Time", ":6.3f")
    data_time = AverageMeter("Data", ":6.3f")
    losses = AverageMeter("Loss", ":.4f")
    top1 = AverageMeter("Acc_1", ":6.2f")
    top2 = AverageMeter("Acc_2", ":6.2f")
    progress = ProgressMeter(
        num_batches,
        [batch_time, data_time, losses, top1, top2],
        prefix="Epoch: [{}]".format(epoch),
    )

    model.train()
    end = time.time()

    for i, data in enumerate(dataloader):
        images, target = data[0].to(device), data[1].to(device)
        batch_size = images.size(0)

        # One-off sanity print of batch shape, learning rate and pixel range.
        if i == 0 and is_main:
            print(
                images.shape,
                target.shape,
                f"Batch_size from args: {args.batch_size}",
                "lr: {:.5f}".format(optimizer.param_groups[0]["lr"]),
            )
            print(
                "Pixel range for training images : [{}, {}]".format(
                    torch.min(images).detach().cpu().numpy(),
                    torch.max(images).detach().cpu().numpy(),
                )
            )
        data_time.update(time.time() - end)

        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc2 = accuracy(output, target, topk=(1, 2))
        losses.update(loss.item(), batch_size)
        top1.update(acc1[0], batch_size)
        top2.update(acc2[0], batch_size)

        optimizer.zero_grad()
        if args.fp16:
            # apex amp scales the loss before backward for mixed precision.
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()
        # The scheduler is optional (signature default is None); calling
        # .step() on None would raise AttributeError.
        if lr_scheduler is not None:
            lr_scheduler.step()

        if args.swa:
            # swadict = tau * swadict + (1 - tau) * modeldict
            # NOTE(review): the other trainers in this file index
            # args.swadict[0..3] with fixed taus; confirm which layout the
            # caller provides for this trainer.
            update_swadict(args.swadict, model.state_dict(), args.tau)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0 and is_main:
            progress.display(i)

    return {"top1": top1.avg, "top2": top2.avg}


def fgsm(model, device, dataloader, criterion, optimizer, num_batches=0, lr_scheduler=None, epoch=0, args=None, **kwargs):
    """Train ``model`` for one epoch on single-step (FGSM) adversarial examples.

    For every batch a perturbation is drawn uniformly from
    ``[-epsilon, epsilon]``, moved one signed-gradient step of size
    ``1.25 * epsilon`` with the model in eval mode, projected back onto the
    epsilon ball and the valid pixel range, and the model is then updated on
    the perturbed batch in train mode.

    Args:
        model: Network to optimise.
        device: Device each batch is moved to.
        dataloader: Iterable of batches; item 0 holds the inputs, item 1 the targets.
        criterion: Callable ``criterion(logits, target)`` returning the loss.
        optimizer: Optimizer stepped once per batch; also passed to
            ``amp.scale_loss`` for the attack gradient when ``args.fp16`` is set.
        num_batches: Batch count handed to ``ProgressMeter`` for display.
        lr_scheduler: Optional scheduler stepped once per batch; skipped when ``None``.
        epoch: Epoch index shown in the progress prefix.
        args: Run configuration. Reads ``local_rank``, ``batch_size``,
            ``distance`` (must be ``"linf"``), ``epsilon``, ``clip_min``,
            ``clip_max``, ``fp16``, ``swa``, ``swadict`` and ``print_freq``.
        **kwargs: Accepted and ignored.

    Returns:
        dict: Meter averages under ``"top1"``/``"top2"`` (computed from the
        eval-mode logits on the randomly initialised perturbation, used as a
        stand-in for clean accuracy) and ``"top1_adv"``/``"top2_adv"``
        (computed from the train-mode logits on the adversarial batch).

    Raises:
        AssertionError: If ``args.distance`` is not ``"linf"``.
    """
    is_main = args.local_rank == 0
    if is_main:
        print(" ->->->->->->->->->-> One epoch with Adversarial (FGSM) training (only support linf attack) <-<-<-<-<-<-<-<-<-<-")

    batch_time = AverageMeter("Time", ":6.3f")
    data_time = AverageMeter("Data", ":6.3f")
    losses = AverageMeter("Loss", ":.4f")
    top1 = AverageMeter("Acc_1", ":6.2f")
    top2 = AverageMeter("Acc_2", ":6.2f")
    top1_adv = AverageMeter("Acc_1_adv", ":6.2f")
    top2_adv = AverageMeter("Acc_2_adv", ":6.2f")
    progress = ProgressMeter(
        num_batches,
        [batch_time, data_time, losses, top1, top2, top1_adv, top2_adv],
        prefix="Epoch: [{}]".format(epoch),
    )

    model.train()
    end = time.time()
    assert args.distance == "linf", "fgsm trainer only supports the linf distance"

    # Attack step size; constant for the whole epoch.
    step = 1.25 * args.epsilon

    for i, data in enumerate(dataloader):
        images, target = data[0].to(device), data[1].to(device)
        batch_size = images.size(0)

        # One-off sanity print of batch shape, learning rate and pixel range.
        if i == 0 and is_main:
            print(
                images.shape,
                target.shape,
                f"Batch_size from args: {args.batch_size}",
                "lr: {:.5f}".format(optimizer.param_groups[0]["lr"]),
            )
            print(
                "Pixel range for training images : [{}, {}]".format(
                    torch.min(images).detach().cpu().numpy(),
                    torch.max(images).detach().cpu().numpy(),
                )
            )
        data_time.update(time.time() - end)

        # generate adversarial examples
        model.eval()
        delta = torch.zeros_like(images).uniform_(-args.epsilon, args.epsilon)
        delta.requires_grad_(True)
        optimizer.zero_grad()
        # approximately equal to clean image output
        logits = model(torch.clamp(images + delta, args.clip_min, args.clip_max))
        loss = criterion(logits, target)
        if args.fp16:
            # Only the sign of the gradient is used below, so the amp loss
            # scale does not affect the resulting perturbation.
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        with torch.no_grad():
            # Signed-gradient step, then project onto the epsilon ball ...
            delta_adv = torch.clamp(
                delta + step * delta.grad.sign(), -args.epsilon, args.epsilon
            )
            # ... and onto the valid pixel range of the perturbed image.
            delta_adv = torch.clamp(images + delta_adv, args.clip_min, args.clip_max) - images
        model.train()

        # adv training
        logits_adv = model(images + delta_adv)
        loss = criterion(logits_adv, target)

        # measure accuracy and record loss (the loss is recorded once per batch)
        acc1, acc2 = accuracy(logits, target, topk=(1, 2))
        acc1_adv, acc2_adv = accuracy(logits_adv, target, topk=(1, 2))
        losses.update(loss.item(), batch_size)
        top1.update(acc1[0], batch_size)
        top2.update(acc2[0], batch_size)
        top1_adv.update(acc1_adv[0], batch_size)
        top2_adv.update(acc2_adv[0], batch_size)

        # Also clears the parameter gradients left over from the attack pass.
        optimizer.zero_grad()
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()
        # The scheduler is optional (signature default is None).
        if lr_scheduler is not None:
            lr_scheduler.step()

        if args.swa:
            # Four averaged copies with different decay rates tau:
            # swadict = tau * swadict + (1 - tau) * modeldict
            update_swadict(args.swadict[0], model.state_dict(), 0.95)
            update_swadict(args.swadict[1], model.state_dict(), 0.99)
            update_swadict(args.swadict[2], model.state_dict(), 0.995)
            update_swadict(args.swadict[3], model.state_dict(), 0.999)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0 and is_main:
            progress.display(i)

    return {"top1": top1.avg, "top2": top2.avg, "top1_adv": top1_adv.avg, "top2_adv": top2_adv.avg}

### Applying gamma only for this trainer (haven't done ablation for others, so using the default gamma=0.5 for them)
def madry(model, device, dataloader, criterion, optimizer, num_batches=0, lr_scheduler=None, epoch=0, args=None, **kwargs):
    """Train ``model`` for one epoch on PGD adversarial examples.

    Adversarial inputs come from ``pgd_whitebox`` (called with the model in
    eval mode); the model is then updated on them in train mode. When a batch
    carries a third item ``ids``, the loss is
    ``gamma * criterion(samples with id 0) + (1 - gamma) * criterion(samples with id 1)``.

    Args:
        model: Network to optimise.
        device: Device each batch is moved to.
        dataloader: Iterable of batches of 2 items ``(images, target)`` or
            3 items ``(images, target, ids)``.
        criterion: Callable ``criterion(logits, target)`` returning the loss.
        optimizer: Optimizer stepped once per batch; also forwarded to
            ``pgd_whitebox`` as ``baseOptimizer``.
        num_batches: Batch count handed to ``ProgressMeter`` for display.
        lr_scheduler: Optional scheduler stepped once per batch; skipped when ``None``.
        epoch: Epoch index shown in the progress prefix.
        args: Run configuration. Reads ``local_rank``, ``batch_size``,
            ``epsilon``, ``num_steps``, ``step_size``, ``clip_min``,
            ``clip_max``, ``distance``, ``fp16``, ``gamma``, ``swa``,
            ``swadict`` and ``print_freq``.
        **kwargs: Accepted and ignored.

    Returns:
        dict: Meter averages under ``"top1"``/``"top2"`` (clean inputs) and
        ``"top1_adv"``/``"top2_adv"`` (adversarial inputs).

    Raises:
        ValueError: If a batch has neither 2 nor 3 items.
    """
    is_main = args.local_rank == 0
    if is_main:
        print(" ->->->->->->->->->-> One epoch with Madry Adversarial training <-<-<-<-<-<-<-<-<-<-")

    batch_time = AverageMeter("Time", ":6.3f")
    data_time = AverageMeter("Data", ":6.3f")
    losses = AverageMeter("Loss", ":.4f")
    top1 = AverageMeter("Acc_1", ":6.2f")
    top2 = AverageMeter("Acc_2", ":6.2f")
    top1_adv = AverageMeter("Acc_1_adv", ":6.2f")
    top2_adv = AverageMeter("Acc_2_adv", ":6.2f")
    progress = ProgressMeter(
        num_batches,
        [batch_time, data_time, losses, top1, top2, top1_adv, top2_adv],
        prefix="Epoch: [{}]".format(epoch),
    )

    model.train()
    end = time.time()

    for i, data in enumerate(dataloader):
        # ids is re-derived for every batch so a 2-item batch never reuses
        # the ids tensor of an earlier 3-item batch.
        if len(data) == 2:
            images, target = data[0].to(device), data[1].to(device)
            ids = None
        elif len(data) == 3:
            images, target, ids = data[0].to(device), data[1].to(device), data[2].to(device)
        else:
            raise ValueError(
                "madry expects batches of 2 or 3 items, got {}".format(len(data))
            )
        batch_size = images.size(0)

        # One-off sanity print of batch shape, learning rate and pixel range.
        if i == 0 and is_main:
            print(
                images.shape,
                target.shape,
                f"Batch_size from args: {args.batch_size}",
                "lr: {:.5f}".format(optimizer.param_groups[0]["lr"]),
            )
            print(
                "Pixel range for training images : [{}, {}]".format(
                    torch.min(images).detach().cpu().numpy(),
                    torch.max(images).detach().cpu().numpy(),
                )
            )
        data_time.update(time.time() - end)

        # Clean forward pass; its logits are only used for the accuracy meters.
        logits = model(images)

        model.eval()
        advImages = pgd_whitebox(model, images, target, device, args.epsilon, args.num_steps, args.step_size, args.clip_min, args.clip_max,
                                  is_random=True, distance=args.distance, fp16=args.fp16, baseOptimizer=optimizer)
        model.train()

        logits_adv = model(advImages)
        if ids is None:
            loss = criterion(logits_adv, target)
        else:
            loss = (
                args.gamma * _madry_subset_loss(criterion, logits_adv, target, ids == 0)
                + (1 - args.gamma) * _madry_subset_loss(criterion, logits_adv, target, ids == 1)
            )

        # measure accuracy and record loss (the loss is recorded once per batch)
        acc1, acc2 = accuracy(logits, target, topk=(1, 2))
        acc1_adv, acc2_adv = accuracy(logits_adv, target, topk=(1, 2))
        losses.update(loss.item(), batch_size)
        top1.update(acc1[0], batch_size)
        top2.update(acc2[0], batch_size)
        top1_adv.update(acc1_adv[0], batch_size)
        top2_adv.update(acc2_adv[0], batch_size)

        optimizer.zero_grad()
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()
        # The scheduler is optional (signature default is None).
        if lr_scheduler is not None:
            lr_scheduler.step()

        if args.swa:
            # Four averaged copies with different decay rates tau:
            # swadict = tau * swadict + (1 - tau) * modeldict
            update_swadict(args.swadict[0], model.state_dict(), 0.95)
            update_swadict(args.swadict[1], model.state_dict(), 0.99)
            update_swadict(args.swadict[2], model.state_dict(), 0.995)
            update_swadict(args.swadict[3], model.state_dict(), 0.999)

        # measure elapsed time (after the SWA update, as in the other trainers)
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0 and is_main:
            progress.display(i)

    return {"top1": top1.avg, "top2": top2.avg, "top1_adv": top1_adv.avg, "top2_adv": top2_adv.avg}


def _madry_subset_loss(criterion, logits, target, mask):
    """Return ``criterion`` over the rows selected by boolean ``mask``.

    A mean-reduced criterion (e.g. the ``nn.CrossEntropyLoss`` default)
    evaluates to NaN on an empty selection, which would poison the whole
    update. For an empty mask this returns the sum over the empty selection
    instead: an exact zero that is still attached to the autograd graph, so
    ``backward()`` keeps working.
    """
    selected = logits[mask]
    if mask.any():
        return criterion(selected, target[mask])
    return selected.sum()


def adv(model, device, dataloader, criterion, optimizer, num_batches=0, lr_scheduler=None, epoch=0, args=None, **kwargs):
    """Train ``model`` for one epoch with the loss returned by ``trades_loss``.

    Args:
        model: Network to optimise; put into train mode here.
        device: Device each batch is moved to.
        dataloader: Iterable of batches; item 0 holds the inputs, item 1 the targets.
        criterion: Not used by this trainer; the loss comes from ``trades_loss``.
        optimizer: Optimizer stepped once per batch; also forwarded to ``trades_loss``.
        num_batches: Batch count handed to ``ProgressMeter`` for display.
        lr_scheduler: Optional scheduler stepped once per batch; skipped when ``None``.
        epoch: Epoch index shown in the progress prefix.
        args: Run configuration. Reads ``local_rank``, ``batch_size``,
            ``step_size``, ``epsilon``, ``num_steps``, ``beta``, ``clip_min``,
            ``clip_max``, ``distance``, ``fp16``, ``swa``, ``swadict`` and
            ``print_freq``.
        **kwargs: Accepted and ignored.

    Returns:
        dict: Meter averages under ``"top1"``/``"top2"`` (from the first logits
        tensor returned by ``trades_loss``) and ``"top1_adv"``/``"top2_adv"``
        (from the second).
    """
    is_main = args.local_rank == 0
    if is_main:
        print(" ->->->->->->->->->-> One epoch with Adversarial (Trades) training <-<-<-<-<-<-<-<-<-<-")

    batch_time = AverageMeter("Time", ":6.3f")
    data_time = AverageMeter("Data", ":6.3f")
    losses = AverageMeter("Loss", ":.4f")
    top1 = AverageMeter("Acc_1", ":6.2f")
    top2 = AverageMeter("Acc_2", ":6.2f")
    top1_adv = AverageMeter("Acc_1_adv", ":6.2f")
    top2_adv = AverageMeter("Acc_2_adv", ":6.2f")
    progress = ProgressMeter(
        num_batches,
        [batch_time, data_time, losses, top1, top2, top1_adv, top2_adv],
        prefix="Epoch: [{}]".format(epoch),
    )

    model.train()
    end = time.time()

    for i, data in enumerate(dataloader):
        images, target = data[0].to(device), data[1].to(device)
        batch_size = images.size(0)

        # One-off sanity print of batch shape, learning rate and pixel range.
        if i == 0 and is_main:
            print(
                images.shape,
                target.shape,
                f"Batch_size from args: {args.batch_size}",
                "lr: {:.5f}".format(optimizer.param_groups[0]["lr"]),
            )
            print(
                "Pixel range for training images : [{}, {}]".format(
                    torch.min(images).detach().cpu().numpy(),
                    torch.max(images).detach().cpu().numpy(),
                )
            )
        data_time.update(time.time() - end)

        # calculate robust loss
        loss, logits, logits_adv = trades_loss(
            model=model,
            x_natural=images,
            y=target,
            device=device,
            optimizer=optimizer,
            step_size=args.step_size,
            epsilon=args.epsilon,
            perturb_steps=args.num_steps,
            beta=args.beta,
            clip_min=args.clip_min,
            clip_max=args.clip_max,
            distance=args.distance,
            fp16=args.fp16,
        )

        # measure accuracy and record loss (the loss is recorded once per batch)
        acc1, acc2 = accuracy(logits, target, topk=(1, 2))
        acc1_adv, acc2_adv = accuracy(logits_adv, target, topk=(1, 2))
        losses.update(loss.item(), batch_size)
        top1.update(acc1[0], batch_size)
        top2.update(acc2[0], batch_size)
        top1_adv.update(acc1_adv[0], batch_size)
        top2_adv.update(acc2_adv[0], batch_size)

        optimizer.zero_grad()
        if args.fp16:
            # apex amp scales the loss before backward for mixed precision.
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()
        # The scheduler is optional (signature default is None).
        if lr_scheduler is not None:
            lr_scheduler.step()

        if args.swa:
            # Four averaged copies with different decay rates tau:
            # swadict = tau * swadict + (1 - tau) * modeldict
            update_swadict(args.swadict[0], model.state_dict(), 0.95)
            update_swadict(args.swadict[1], model.state_dict(), 0.99)
            update_swadict(args.swadict[2], model.state_dict(), 0.995)
            update_swadict(args.swadict[3], model.state_dict(), 0.999)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0 and is_main:
            progress.display(i)

    return {"top1": top1.avg, "top2": top2.avg, "top1_adv": top1_adv.avg, "top2_adv": top2_adv.avg}



def smooth(model, device, dataloader, criterion, optimizer, num_batches=0, lr_scheduler=None, epoch=0, args=None, **kwargs):
    """Train ``model`` for one epoch with a Gaussian-noise stability loss.

    The per-batch loss is ``cross_entropy(model(x), y)`` plus
    ``args.smooth_beta`` times the batch-averaged KL divergence between
    ``softmax(model(x))`` and ``softmax(model(x + noise))``, where ``noise``
    is standard normal noise scaled by ``args.noise_std``.

    Args:
        model: Network to optimise; put into train mode here.
        device: Device each batch is moved to.
        dataloader: Iterable of batches; item 0 holds the inputs, item 1 the targets.
        criterion: Not used by this trainer; the natural loss is always cross-entropy.
        optimizer: Optimizer stepped once per batch.
        num_batches: Batch count handed to ``ProgressMeter`` for display.
        lr_scheduler: Optional scheduler stepped once per batch; skipped when ``None``.
        epoch: Epoch index shown in the progress prefix.
        args: Run configuration. Reads ``local_rank``, ``batch_size``,
            ``noise_std``, ``smooth_beta``, ``fp16``, ``swa``, ``swadict`` and
            ``print_freq``.
        **kwargs: Accepted and ignored.

    Returns:
        dict: ``{"top1": ..., "top2": ...}`` holding the ``avg`` of the top-1 and
        top-2 accuracy meters, computed on the noise-free outputs.
    """
    is_main = args.local_rank == 0
    if is_main:
        print(" ->->->->->->->->->-> One epoch with Randomized smoothing training <-<-<-<-<-<-<-<-<-<-")

    batch_time = AverageMeter("Time", ":6.3f")
    data_time = AverageMeter("Data", ":6.3f")
    losses = AverageMeter("Loss", ":.4f")
    top1 = AverageMeter("Acc_1", ":6.2f")
    top2 = AverageMeter("Acc_2", ":6.2f")
    progress = ProgressMeter(
        num_batches,
        [batch_time, data_time, losses, top1, top2],
        prefix="Epoch: [{}]".format(epoch),
    )

    model.train()
    end = time.time()

    for i, data in enumerate(dataloader):
        images, target = data[0].to(device), data[1].to(device)
        batch_size = images.size(0)

        # One-off sanity print of batch shape, learning rate and pixel range.
        if i == 0 and is_main:
            print(
                images.shape,
                target.shape,
                f"Batch_size from args: {args.batch_size}",
                "lr: {:.5f}".format(optimizer.param_groups[0]["lr"]),
            )
            print(
                "Pixel range for training images : [{}, {}]".format(
                    torch.min(images).detach().cpu().numpy(),
                    torch.max(images).detach().cpu().numpy(),
                )
            )
        data_time.update(time.time() - end)

        # stability-loss
        output = model(images)
        loss_natural = F.cross_entropy(output, target)
        # randn_like already allocates on the device of `images`.
        noisy_images = images + torch.randn_like(images) * args.noise_std
        noisy_log_probs = F.log_softmax(model(noisy_images), dim=1)
        # reduction="sum" is the non-deprecated spelling of size_average=False;
        # dividing by the batch length gives the per-sample average.
        loss_robust = (1.0 / len(images)) * F.kl_div(
            noisy_log_probs, F.softmax(output, dim=1), reduction="sum"
        )
        loss = loss_natural + args.smooth_beta * loss_robust

        # measure accuracy and record loss
        acc1, acc2 = accuracy(output, target, topk=(1, 2))
        losses.update(loss.item(), batch_size)
        top1.update(acc1[0], batch_size)
        top2.update(acc2[0], batch_size)

        optimizer.zero_grad()
        if args.fp16:
            # apex amp scales the loss before backward for mixed precision.
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()
        # The scheduler is optional (signature default is None).
        if lr_scheduler is not None:
            lr_scheduler.step()

        if args.swa:
            # Four averaged copies with different decay rates tau:
            # swadict = tau * swadict + (1 - tau) * modeldict
            update_swadict(args.swadict[0], model.state_dict(), 0.95)
            update_swadict(args.swadict[1], model.state_dict(), 0.99)
            update_swadict(args.swadict[2], model.state_dict(), 0.995)
            update_swadict(args.swadict[3], model.state_dict(), 0.999)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0 and is_main:
            progress.display(i)

    return {"top1": top1.avg, "top2": top2.avg}
