
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader
import time
from torch.distributions.normal import Normal
import torch.nn.functional as F

import sys
import os

# Append the directory two levels above this file's own directory to
# sys.path so that the project imports below (``train.Certified`` and
# ``utils.Certified``) can be resolved.
# NOTE(review): presumably that directory is the repository root — confirm.
currentdir = os.path.dirname(os.path.realpath(__file__))
parentdir = os.path.dirname(os.path.dirname(currentdir))
sys.path.append(parentdir)


from train.Certified.third_party.smoothadv import Attacker
from utils.Certified.utils_ensemble import AverageMeter, accuracy, test, copy_code, requires_grad_
from utils.Certified.datasets import get_dataset, get_num_classes

# Module-level default compute device: CUDA when available, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

def get_minibatches(batch, num_batches):
    """Split an ``(inputs, targets)`` batch into contiguous chunks.

    The batch is cut into ``num_batches`` contiguous pieces whose sizes differ
    by at most one. The first ``len(inputs) % num_batches`` pieces each carry
    one extra sample, so no sample is dropped. When the batch is evenly
    divisible this is exactly an equal-size split. When the batch holds fewer
    samples than ``num_batches``, only the non-empty pieces are yielded.

    Args:
        batch: sequence whose first two items are the inputs and the targets;
            both must support ``len`` and slicing along their first axis.
        num_batches: requested number of chunks.

    Yields:
        ``(inputs_chunk, targets_chunk)`` pairs, in order.

    Raises:
        ZeroDivisionError: if ``num_batches`` is 0 (raised on first iteration).
    """
    X, y = batch[0], batch[1]
    base, extra = divmod(len(X), num_batches)
    start = 0
    for k in range(num_batches):
        size = base + (1 if k < extra else 0)
        if size == 0:
            # Fewer samples than chunks: every remaining chunk would be empty.
            break
        stop = start + size
        yield X[start:stop], y[start:stop]
        start = stop

def _cross_entropy(input, targets, reduction='mean'):
    """Cross-entropy of ``input`` logits against soft targets given as logits.

    ``targets`` is turned into a probability distribution with a softmax over
    dim 1. The per-row loss is ``-sum_c p_target[c] * log_softmax(input)[c]``.

    Args:
        input: logits with the class dimension at dim 1.
        targets: logits of the target distribution, same shape as ``input``.
        reduction: ``'sum'``, ``'mean'`` (default) or ``'none'``.

    Raises:
        NotImplementedError: for any other ``reduction`` value.
    """
    log_probs = F.log_softmax(input, dim=1)
    target_probs = F.softmax(targets, dim=1)
    per_row = torch.sum(-target_probs * log_probs, dim=1)

    if reduction == 'sum':
        return torch.sum(per_row)
    if reduction == 'mean':
        return torch.mean(per_row)
    if reduction == 'none':
        return per_row
    raise NotImplementedError()

def Smoothing_Trainer(args, loader: DataLoader, models, criterion, optimizer: Optimizer,
          epoch: int, noise_sd: float, device: torch.device, writer=None):
    """Train an ensemble for one epoch with Gaussian data augmentation.

    Each batch is perturbed once with Gaussian noise of standard deviation
    ``noise_sd``. The first ``args.num_models`` members of ``models`` are
    evaluated on that same noisy batch. The sum of their ``criterion``
    losses is minimised with one optimizer step per batch.

    Args:
        args: namespace providing ``num_models`` and ``print_freq``.
        loader: yields ``(inputs, targets)`` batches.
        models: indexable collection of ensemble members.
        criterion: loss called as ``criterion(logits, targets)``.
        optimizer: optimizer stepped once per batch.
        epoch: epoch index, used for logging only.
        noise_sd: standard deviation of the Gaussian augmentation noise.
        device: device the batches are moved to.
        writer: optional object exposing ``add_scalar(tag, value, step)``.
            Epoch averages are logged to it when it is not ``None``.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()

    for j in range(args.num_models):
        models[j].train()
        requires_grad_(models[j], True)

    for i, (inputs, targets) in enumerate(loader):
        # measure data loading time
        data_time.update(time.time() - end)

        inputs, targets = inputs.to(device), targets.to(device)
        batch_size = inputs.size(0)

        # augment inputs with noise (one draw shared by all members)
        inputs = inputs + torch.randn_like(inputs, device=device) * noise_sd

        loss_std = 0
        for j in range(args.num_models):
            logits = models[j](inputs)
            loss_std += criterion(logits, targets)

        # measure accuracy and record loss; accuracy is that of the last member
        acc1, acc5 = accuracy(logits, targets, topk=(1, 5))
        losses.update(loss_std.item(), batch_size)
        top1.update(acc1.item(), batch_size)
        top5.update(acc5.item(), batch_size)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss_std.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.avg:.3f}\t'
                  'Data {data_time.avg:.3f}\t'
                  'Loss {loss.avg:.4f}\t'
                  'Acc@1 {top1.avg:.3f}\t'
                  'Acc@5 {top5.avg:.3f}'.format(
                epoch, i, len(loader), batch_time=batch_time,
                data_time=data_time, loss=losses, top1=top1, top5=top5))

    # ``writer`` defaults to None; logging is optional.
    if writer is not None:
        writer.add_scalar('loss/train', losses.avg, epoch)
        writer.add_scalar('batch_time', batch_time.avg, epoch)
        writer.add_scalar('accuracy/train@1', top1.avg, epoch)
        writer.add_scalar('accuracy/train@5', top5.avg, epoch)

def MACER_Trainer(args, loader: DataLoader, models, optimizer: Optimizer,
          epoch: int, noise_sd: float, device: torch.device, writer=None):
    """Train an ensemble for one epoch with the MACER-style objective.

    Every sample is replicated ``args.num_noise_vec`` times and perturbed with
    Gaussian noise of standard deviation ``noise_sd``. The loss has two terms
    for each of the first ``args.num_models`` members:

    * Classification term: the summed negative log of the noise-averaged
      softmax probability of the true class.
    * Robustness term, skipped when the effective ``lbd`` is 0: a hinge
      ``max(0, margin + Phi^-1(p2) - Phi^-1(p1))``. Here ``p1 >= p2`` are
      the two largest noise-averaged softmax probabilities of the
      ``args.beta``-scaled logits, ``margin`` is ``args.margin``, and
      non-finite values are discarded. The term is computed only for
      samples whose top prediction is correct, summed, and scaled by
      ``noise_sd / 2``.

    The optimised objective is ``sum(classification) + lbd * sum(robustness)``
    over members, with one optimizer step per batch.

    Args:
        args: namespace providing ``num_models``, ``num_noise_vec``, ``lbd``,
            ``deferred``, ``lr_step_size``, ``beta``, ``margin``, ``dataset``
            and ``print_freq``.
        loader: yields ``(inputs, targets)`` batches.
        models: indexable collection of ensemble members.
        optimizer: optimizer stepped once per batch.
        epoch: epoch index, used for the deferred schedule and logging.
        noise_sd: standard deviation of the Gaussian noise.
        device: device the batches are moved to.
        writer: optional object exposing ``add_scalar(tag, value, step)``.
            Epoch averages are logged to it when it is not ``None``.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    losses_std = AverageMeter()
    losses_robust = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()

    # switch to train mode
    for j in range(args.num_models):
        models[j].train()
        requires_grad_(models[j], True)

    # With deferred training the robustness term is disabled for the first
    # ``args.lr_step_size`` epochs.
    lbd = args.lbd
    if args.deferred and epoch <= args.lr_step_size:
        lbd = 0

    num_classes = get_num_classes(args.dataset)
    # Standard normal; its inverse CDF maps probabilities to radius-like scores.
    m = Normal(torch.tensor([0.0]).to(device),
               torch.tensor([1.0]).to(device))

    for i, (inputs, targets) in enumerate(loader):
        data_time.update(time.time() - end)
        inputs, targets = inputs.to(device), targets.to(device)
        input_size = len(inputs)

        # Replicate every sample num_noise_vec times with the copies of one
        # sample contiguous (sample-major), matching the logits reshape below.
        # The 4-tuple passed to repeat() is written for (N, C, H, W) batches.
        new_shape = [input_size * args.num_noise_vec]
        new_shape.extend(inputs[0].shape)
        inputs = inputs.repeat((1, args.num_noise_vec, 1, 1)).view(new_shape)
        noise = torch.randn_like(inputs, device=device) * noise_sd
        noisy_inputs = inputs + noise

        loss_std, loss_robust = 0, 0

        for j in range(args.num_models):
            logits = models[j](noisy_inputs)
            logits = logits.reshape((input_size, args.num_noise_vec, num_classes))

            # Classification loss on the noise-averaged softmax
            outputs_softmax = F.softmax(logits, dim=2).mean(1)
            outputs_logsoftmax = torch.log(outputs_softmax + 1e-10)  # avoid nan
            classification_loss = F.nll_loss(outputs_logsoftmax, targets, reduction='sum')

            if lbd == 0:
                # Keep a tensor (not a Python 0) so ``.item()`` below works.
                robustness_loss = classification_loss * 0
            else:
                # Robustness loss
                beta_outputs = logits * args.beta  # only apply beta to the robustness loss
                beta_outputs_softmax = F.softmax(beta_outputs, dim=2).mean(1)
                top2 = torch.topk(beta_outputs_softmax, 2)
                top2_score = top2[0]
                top2_idx = top2[1]
                indices_correct = (top2_idx[:, 0] == targets)  # G_theta

                out0, out1 = top2_score[indices_correct,
                                        0], top2_score[indices_correct, 1]
                robustness_loss = m.icdf(out1) - m.icdf(out0)
                indices = ~torch.isnan(robustness_loss) & ~torch.isinf(
                    robustness_loss) & (torch.abs(robustness_loss) <= args.margin)  # hinge
                out0, out1 = out0[indices], out1[indices]
                robustness_loss = m.icdf(out1) - m.icdf(out0) + args.margin
                robustness_loss = robustness_loss.sum() * noise_sd / 2

            loss_std += classification_loss
            loss_robust += robustness_loss

        # Final objective function.
        # NOTE(review): the public MACER reference implementation divides this
        # objective by the batch size. Here it is a per-batch sum, so gradient
        # magnitude grows with batch size. Confirm this is intended.
        loss = loss_std + lbd * loss_robust

        # Accuracy of the last member, scored on every noisy copy.
        # repeat_interleave matches the sample-major copy order and, unlike
        # squeeze(), never collapses a single-element result to 0-d.
        noisy_targets = targets.repeat_interleave(args.num_noise_vec)
        acc1, acc5 = accuracy(logits.reshape(-1, num_classes), noisy_targets, topk=(1, 5))
        losses.update(loss.item(), input_size)
        losses_std.update(loss_std.item(), input_size)
        losses_robust.update(loss_robust.item(), input_size)
        top1.update(acc1.item(), input_size)
        top5.update(acc5.item(), input_size)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.avg:.3f}\t'
                  'Data {data_time.avg:.3f}\t'
                  'Loss {loss.avg:.4f}\t'
                  'Acc@1 {top1.avg:.3f}\t'
                  'Acc@5 {top5.avg:.3f}'.format(
                epoch, i, len(loader), batch_time=batch_time,
                data_time=data_time, loss=losses, top1=top1, top5=top5))

    # ``writer`` defaults to None; logging is optional.
    if writer is not None:
        writer.add_scalar('loss/train', losses.avg, epoch)
        writer.add_scalar('loss/std', losses_std.avg, epoch)
        writer.add_scalar('loss/robust', losses_robust.avg, epoch)
        writer.add_scalar('batch_time', batch_time.avg, epoch)
        writer.add_scalar('accuracy/train@1', top1.avg, epoch)
        writer.add_scalar('accuracy/train@5', top5.avg, epoch)

def DRT_Trainer(args, loader: DataLoader, models, criterion, optimizer: Optimizer, epoch: int, noise_sd: float,
          attacker: Attacker, device: torch.device, writer=None):
    """Train an ensemble for one epoch with noise augmentation and two regularisers.

    Each loader batch is split by ``get_minibatches`` into
    ``args.num_noise_vec`` minibatches. Each minibatch receives
    ``args.num_noise_vec`` Gaussian noise draws (std ``noise_sd``), stacked
    noise-major. When ``args.adv_training`` is set, the noise is added on top
    of a per-member example produced by ``attacker``. One optimizer step is
    taken per minibatch on

        loss_std + args.lhs_weights * lhsloss - args.rhs_weights * rhsloss

    The terms are:

    * ``loss_std`` sums ``criterion`` over members.
    * ``rhsloss`` is the mean gap between a member's largest and
      second-largest softmax probability, over the noisy samples that the
      member classifies correctly.
    * ``lhsloss`` takes every pair of members and computes the squared L2
      norm of the sum of their input gradients of that gap, averaged over
      samples both members classify correctly.

    ``loss/train`` logs ``loss_std`` only; the two regularisers are logged
    separately.

    Args:
        args: namespace providing ``num_models``, ``num_noise_vec``,
            ``adv_training``, ``lhs_weights``, ``rhs_weights`` and
            ``print_freq``.
        loader: yields batches accepted by ``get_minibatches``.
        models: indexable collection of ensemble members.
        criterion: loss called as ``criterion(logits, targets)``.
        optimizer: optimizer stepped once per minibatch.
        epoch: epoch index, used for logging only.
        noise_sd: standard deviation of the Gaussian noise.
        attacker: used only when ``args.adv_training`` is set.
        device: device the minibatches are moved to.
        writer: optional object exposing ``add_scalar(tag, value, step)``.
            Epoch averages are logged to it when it is not ``None``.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    losses_lhs = AverageMeter()
    losses_rhs = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()

    for j in range(args.num_models):
        models[j].train()
        requires_grad_(models[j], True)

    softma = nn.Softmax(1)
    # Element-wise squared error (``reduce=False`` is the deprecated spelling).
    mse = nn.MSELoss(reduction='none')

    for i, batch in enumerate(loader):
        data_time.update(time.time() - end)

        for inputs, targets in get_minibatches(batch, args.num_noise_vec):
            inputs, targets = inputs.to(device), targets.to(device)
            batch_size = inputs.size(0)

            noises = [torch.randn_like(inputs, device=device) * noise_sd
                      for _ in range(args.num_noise_vec)]

            if args.adv_training:
                adv_x = []
                for j in range(args.num_models):
                    # Attack with frozen parameters in eval mode, then restore.
                    requires_grad_(models[j], False)
                    models[j].eval()
                    adv = attacker.attack(models[j], inputs, targets, noises=noises)
                    models[j].train()
                    requires_grad_(models[j], True)
                    adv_x.append(adv)

                adv_input = []
                for j in range(args.num_models):
                    noisy_input = torch.cat([adv_x[j] + noise for noise in noises], dim=0)
                    # Input gradients are needed for the margin terms below.
                    noisy_input.requires_grad = True
                    adv_input.append(noisy_input)
                member_inputs = adv_input
            else:
                noisy_input = torch.cat([inputs + noise for noise in noises], dim=0)
                noisy_input.requires_grad = True
                member_inputs = [noisy_input] * args.num_models

            # Noisy copies are stacked noise-major, so tile the labels.
            targets = targets.repeat(args.num_noise_vec)

            loss_std = 0
            for j in range(args.num_models):
                logits = models[j](member_inputs[j])
                loss_std += criterion(logits, targets)

            rhsloss, rcount = 0, 0
            pred = []
            margin = []
            for j in range(args.num_models):
                cur_input = member_inputs[j]
                # NOTE(review): second forward pass over the same input as the
                # one used for ``loss_std``. If the members contain BatchNorm
                # layers, their running statistics are updated twice per step.
                output = models[j](cur_input)
                _, predicted = output.max(1)
                pred.append(predicted == targets)
                # softmax is permutation-equivariant, so after sorting the logits
                # the last two columns hold the largest and second-largest probs.
                probs = softma(output.sort()[0])
                gap = probs[:, -1] - probs[:, -2]

                grad = torch.autograd.grad(gap, cur_input, grad_outputs=torch.ones_like(gap),
                                           create_graph=True)[0]
                margin.append(grad.view(grad.size(0), -1))

                flg = pred[j].float()
                rhsloss += torch.sum(flg * gap)
                rcount += torch.sum(flg)

            rhsloss /= max(rcount, 1.)

            lhsloss, N = 0, 0
            for ii in range(args.num_models):
                for j in range(ii + 1, args.num_models):
                    flg = (pred[ii] & pred[j]).float()
                    # (g_ii - (-g_j))^2 summed = ||g_ii + g_j||^2 per sample.
                    grad_norm = torch.sum(mse(margin[ii], -margin[j]), dim=1)
                    lhsloss += torch.sum(grad_norm * flg)
                    N += torch.sum(flg)

            lhsloss /= max(N, 1.)

            # float() also accepts the plain 0.0 left when there is no pair of
            # members (num_models == 1), where ``.item()`` would fail.
            losses_lhs.update(float(lhsloss), batch_size)

            loss = loss_std + args.lhs_weights * lhsloss - args.rhs_weights * rhsloss

            # Accuracy of the last member on the noisy minibatch.
            acc1, acc5 = accuracy(logits, targets, topk=(1, 5))
            losses.update(loss_std.item(), batch_size)
            losses_rhs.update(float(rhsloss), batch_size)
            top1.update(acc1.item(), batch_size)
            top5.update(acc5.item(), batch_size)
            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # measure elapsed time (per loader batch, across all its minibatches)
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.avg:.3f}\t'
                  'Data {data_time.avg:.3f}\t'
                  'Loss {loss.avg:.4f}\t'
                  'Acc@1 {top1.avg:.3f}\t'
                  'Acc@5 {top5.avg:.3f}'.format(
                epoch, i, len(loader), batch_time=batch_time,
                data_time=data_time, loss=losses, top1=top1, top5=top5))

    # ``writer`` defaults to None; logging is optional.
    if writer is not None:
        writer.add_scalar('loss/train', losses.avg, epoch)
        writer.add_scalar('loss/lhs', losses_lhs.avg, epoch)
        writer.add_scalar('loss/rhs', losses_rhs.avg, epoch)
        writer.add_scalar('batch_time', batch_time.avg, epoch)
        writer.add_scalar('accuracy/train@1', top1.avg, epoch)
        writer.add_scalar('accuracy/train@5', top5.avg, epoch)

def STAB_Trainer(args, loader: DataLoader, models, criterion, optimizer: Optimizer,
          epoch: int, noise_sd: float, device: torch.device, writer=None):
    """Train an ensemble for one epoch with a clean loss plus a noise-stability term.

    Each of the first ``args.num_models`` members is evaluated on the clean
    batch and on the batch plus one shared Gaussian noise draw (std
    ``noise_sd``). The minimised objective is
    ``sum(criterion(clean_logits, targets)) + args.lbd * sum(CE(noisy, clean))``,
    where ``CE`` is ``_cross_entropy`` with the clean logits as soft targets.
    One optimizer step is taken per batch.

    Args:
        args: namespace providing ``num_models``, ``lbd`` and ``print_freq``.
        loader: yields ``(inputs, targets)`` batches.
        models: indexable collection of ensemble members.
        criterion: loss called as ``criterion(logits, targets)``.
        optimizer: optimizer stepped once per batch.
        epoch: epoch index, used for logging only.
        noise_sd: standard deviation of the Gaussian noise.
        device: device the batches are moved to.
        writer: optional object exposing ``add_scalar(tag, value, step)``.
            Epoch averages are logged to it when it is not ``None``.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    losses_stab = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()

    for j in range(args.num_models):
        models[j].train()
        requires_grad_(models[j], True)

    for i, (inputs, targets) in enumerate(loader):
        # measure data loading time
        data_time.update(time.time() - end)

        inputs, targets = inputs.to(device), targets.to(device)
        batch_size = inputs.size(0)

        # one noise draw shared by all members
        noise = torch.randn_like(inputs, device=device) * noise_sd

        loss_std, loss_stab = 0, 0
        for j in range(args.num_models):
            logits = models[j](inputs)
            logits_n = models[j](inputs + noise)
            loss_std += criterion(logits, targets)
            # The clean logits are not detached, so this term also sends
            # gradients through the clean forward pass.
            loss_stab += _cross_entropy(logits_n, logits)
        loss = loss_std + args.lbd * loss_stab

        # Accuracy of the last member on the noisy inputs.
        acc1, acc5 = accuracy(logits_n, targets, topk=(1, 5))
        losses.update(loss.item(), batch_size)
        losses_stab.update(loss_stab.item(), batch_size)
        top1.update(acc1.item(), batch_size)
        top5.update(acc5.item(), batch_size)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.avg:.3f}\t'
                  'Data {data_time.avg:.3f}\t'
                  'Loss {loss.avg:.4f}\t'
                  'Acc@1 {top1.avg:.3f}\t'
                  'Acc@5 {top5.avg:.3f}'.format(
                epoch, i, len(loader), batch_time=batch_time,
                data_time=data_time, loss=losses, top1=top1, top5=top5))

    # ``writer`` defaults to None; logging is optional.
    if writer is not None:
        writer.add_scalar('loss/train', losses.avg, epoch)
        writer.add_scalar('loss/stab', losses_stab.avg, epoch)
        writer.add_scalar('batch_time', batch_time.avg, epoch)
        writer.add_scalar('accuracy/train@1', top1.avg, epoch)
        writer.add_scalar('accuracy/train@5', top5.avg, epoch)

def SmoothAdv_Trainer(args, loader: DataLoader, models, criterion, optimizer: Optimizer, epoch: int, noise_sd: float,
          attacker: Attacker, device: torch.device, writer=None):
    """Train an ensemble for one epoch on noisy per-member adversarial examples.

    Each loader batch is split by ``get_minibatches`` into
    ``args.num_noise_vec`` minibatches. In each minibatch every sample is
    replicated ``args.num_noise_vec`` times (copies contiguous) and one
    Gaussian noise tensor (std ``noise_sd``) is drawn for the replicated
    batch. For each member, ``attacker.attack`` is run with that member
    frozen in eval mode. The member is then trained with ``criterion`` on its
    own adversarial example plus the noise. The summed member losses get one
    optimizer step per minibatch.

    Args:
        args: namespace providing ``num_models``, ``num_noise_vec``,
            ``no_grad_attack`` and ``print_freq``.
        loader: yields batches accepted by ``get_minibatches``.
        models: indexable collection of ensemble members.
        criterion: loss called as ``criterion(logits, targets)``.
        optimizer: optimizer stepped once per minibatch.
        epoch: epoch index, used for logging only.
        noise_sd: standard deviation of the Gaussian noise.
        attacker: produces the adversarial examples.
        device: device the minibatches are moved to.
        writer: optional object exposing ``add_scalar(tag, value, step)``.
            Epoch averages are logged to it when it is not ``None``.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()

    for j in range(args.num_models):
        models[j].train()
        requires_grad_(models[j], True)

    for i, batch in enumerate(loader):
        # measure data loading time
        data_time.update(time.time() - end)

        for inputs, targets in get_minibatches(batch, args.num_noise_vec):
            inputs, targets = inputs.to(device), targets.to(device)
            # Replicate every sample num_noise_vec times, copies contiguous.
            inputs = inputs.repeat((1, args.num_noise_vec, 1, 1)).reshape(-1, *inputs.shape[1:])
            # Counts replicated samples, not original ones.
            batch_size = inputs.size(0)
            noise = torch.randn_like(inputs, device=device) * noise_sd

            adv_x = []
            for j in range(args.num_models):
                # Attack with frozen parameters in eval mode, then restore.
                requires_grad_(models[j], False)
                models[j].eval()
                adv = attacker.attack(models[j], inputs, targets,
                                      noise=noise, num_noise_vectors=args.num_noise_vec,
                                      no_grad=args.no_grad_attack)
                models[j].train()
                requires_grad_(models[j], True)
                adv_x.append(adv)

            # Each member is trained on noise added to its own adversarial example.
            adv_input = [adv + noise for adv in adv_x]

            # Labels expanded to match the replicated samples. repeat_interleave
            # never collapses a single-element result to 0-d like squeeze() did.
            targets = targets.repeat_interleave(args.num_noise_vec)

            loss_std = 0
            for j in range(args.num_models):
                logits = models[j](adv_input[j])
                loss_std += criterion(logits, targets)

            # Accuracy of the last member.
            acc1, acc5 = accuracy(logits, targets, topk=(1, 5))
            losses.update(loss_std.item(), batch_size)
            top1.update(acc1.item(), batch_size)
            top5.update(acc5.item(), batch_size)

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss_std.backward()
            optimizer.step()

        # measure elapsed time (per loader batch, across all its minibatches)
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.avg:.3f}\t'
                  'Data {data_time.avg:.3f}\t'
                  'Loss {loss.avg:.4f}\t'
                  'Acc@1 {top1.avg:.3f}\t'
                  'Acc@5 {top5.avg:.3f}'.format(
                epoch, i, len(loader), batch_time=batch_time,
                data_time=data_time, loss=losses, top1=top1, top5=top5))

    # ``writer`` defaults to None; logging is optional.
    if writer is not None:
        writer.add_scalar('loss/train', losses.avg, epoch)
        writer.add_scalar('batch_time', batch_time.avg, epoch)
        writer.add_scalar('accuracy/train@1', top1.avg, epoch)
        writer.add_scalar('accuracy/train@5', top5.avg, epoch)
