import time, torch
import numpy as np
from copy import deepcopy
from argparse import ArgumentTypeError
from prefetch_generator import BackgroundGenerator
from torch import nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import torchattacks


class WeightedSubset(torch.utils.data.Subset):
    """A ``Subset`` whose samples are paired with a per-sample weight.

    ``subset[k]`` returns ``(dataset[indices[k]], weights[k])``. A default
    ``DataLoader`` over it therefore yields ``(collated_samples, collated_weights)``
    batches.
    """

    def __init__(self, dataset, indices, weights) -> None:
        self.dataset = dataset
        # weights[k] belongs to indices[k], so both must have the same length.
        assert len(indices) == len(weights)
        self.indices = indices
        self.weights = weights

    def __getitem__(self, idx):
        if isinstance(idx, list):
            # Batched lookup: needs ``dataset`` and ``weights`` to accept list
            # indices (e.g. numpy arrays).
            return self.dataset[[self.indices[i] for i in idx]], self.weights[list(idx)]
        return self.dataset[self.indices[idx]], self.weights[idx]

    def __getitems__(self, indices):
        # Recent ``torch.utils.data.Subset`` versions define ``__getitems__``, and
        # the DataLoader's batched fetcher calls it instead of ``__getitem__``.
        # The inherited version returns the underlying samples without their
        # weights, so route every index through the weighted ``__getitem__``.
        return [self[i] for i in indices]


def trades_loss(
    model,
    x_natural,
    y,
    device,
    optimizer,
    step_size,
    epsilon,
    perturb_steps,
    beta,
    clip_min,
    clip_max,
    distance="l_inf",
    natural_criterion=nn.CrossEntropyLoss(),
):
    """Compute the TRADES loss for one batch.

    First, with ``model`` in eval mode, an adversarial input is searched for
    that maximises the KL divergence between the model's predictions on it and
    on ``x_natural``. Then ``model`` is switched to train mode, which it stays
    in on return, and the function returns::

        natural_criterion(model(x_natural), y)
            + beta * KL(softmax(model(x_natural)) || softmax(model(x_adv))) / batch_size

    where the KL term is summed over the batch.

    Args:
        model: network to attack and train.
        x_natural: clean input batch.
        y: labels for ``x_natural``.
        device: device the initial random noise is moved to.
        optimizer: optimizer of ``model``. Its gradients are zeroed before the
            final forward passes, because the l_2 search calls ``backward()``
            and would otherwise leave gradients on the parameters.
        step_size: per-step size of the l_inf sign-gradient update.
        epsilon: perturbation radius (l_inf or l_2, depending on ``distance``).
        perturb_steps: number of attack iterations.
        beta: weight of the robust (KL) term.
        clip_min: lower bound of the valid input range.
        clip_max: upper bound of the valid input range.
        distance: ``"l_inf"`` or ``"l_2"``. Any other value skips the search
            and only clamps the randomly initialised input.
        natural_criterion: loss applied to the clean logits.

    Returns:
        Scalar loss tensor attached to ``model``'s autograd graph.
    """
    # Summed KL; the robust term is divided by the batch size explicitly below.
    criterion_kl = nn.KLDivLoss(reduction="sum")
    model.eval()
    batch_size = len(x_natural)
    # Start the search from a small random perturbation of the clean input.
    x_adv = (
        x_natural.detach() + 0.001 * torch.randn(x_natural.shape).to(device).detach()
    )
    if distance == "l_inf":
        # The KL target does not depend on the perturbation: compute it once
        # instead of re-running the clean forward pass on every step.
        with torch.no_grad():
            p_natural = F.softmax(model(x_natural), dim=1)
        for _ in range(perturb_steps):
            x_adv.requires_grad_()
            with torch.enable_grad():
                loss_kl = criterion_kl(F.log_softmax(model(x_adv), dim=1), p_natural)
            grad = torch.autograd.grad(loss_kl, [x_adv])[0]
            x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
            # Project back onto the l_inf ball around x_natural, then the valid range.
            x_adv = torch.min(
                torch.max(x_adv, x_natural - epsilon), x_natural + epsilon
            )
            x_adv = torch.clamp(x_adv, clip_min, clip_max)
    elif distance == "l_2":
        with torch.no_grad():
            p_natural = F.softmax(model(x_natural), dim=1)
        delta = 0.001 * torch.randn(x_natural.shape).to(device).detach()
        delta.requires_grad_()

        optimizer_delta = optim.SGD([delta], lr=epsilon / perturb_steps * 2)
        # Shape that broadcasts one norm per sample over an input of any rank.
        norm_shape = (-1,) + (1,) * (delta.dim() - 1)

        for _ in range(perturb_steps):
            optimizer_delta.zero_grad()
            with torch.enable_grad():
                adv = x_natural + delta
                # Minimising the negated KL maximises the divergence.
                loss = (-1) * criterion_kl(F.log_softmax(model(adv), dim=1), p_natural)
            loss.backward()
            # Normalise the gradient per sample so the SGD step has a fixed length.
            grad_norms = delta.grad.view(batch_size, -1).norm(p=2, dim=1)
            delta.grad.div_(grad_norms.view(norm_shape))
            # Zero gradients would produce nan/inf above; replace them with noise.
            if (grad_norms == 0).any():
                delta.grad[grad_norms == 0] = torch.randn_like(
                    delta.grad[grad_norms == 0]
                )
            optimizer_delta.step()

            # Project: keep x_natural + delta in range and ||delta||_2 <= epsilon.
            delta.data.add_(x_natural)
            delta.data.clamp_(clip_min, clip_max).sub_(x_natural)
            delta.data.renorm_(p=2, dim=0, maxnorm=epsilon)
        x_adv = (x_natural + delta).detach()
    else:
        x_adv = torch.clamp(x_adv, clip_min, clip_max)
    model.train()

    x_adv = torch.clamp(x_adv, clip_min, clip_max).detach()
    # Clear parameter gradients left over from the l_2 search.
    optimizer.zero_grad()
    logits = model(x_natural)
    loss_natural = natural_criterion(logits, y)
    # The KL target uses a second clean forward pass rather than reusing
    # ``logits``. With batch norm in train mode, this also updates its running
    # statistics a second time.
    loss_robust = (1.0 / batch_size) * criterion_kl(
        F.log_softmax(model(x_adv), dim=1), F.softmax(model(x_natural), dim=1)
    )
    return loss_natural + beta * loss_robust


def train(train_loader, network, criterion, optimizer, scheduler, epoch, args, rec, if_weighted: bool = False):
    """Run one training epoch of ``network`` over ``train_loader``.

    Batch layout:
      * ``if_weighted=False``: batches are ``(inputs, targets)`` and the result
        of ``criterion`` is averaged.
      * ``if_weighted=True``: batches are ``((inputs, targets), weights)`` and
        the loss is ``sum(criterion(...) * weights) / sum(weights)``. This
        presumably expects ``criterion`` to return per-sample losses
        (e.g. ``reduction='none'``); confirm against the caller.

    The optimizer and ``scheduler`` are both stepped once per batch. At the end
    of the epoch, the mean loss, mean Prec@1 and the first param group's
    learning rate are appended to ``rec``.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')

    network.train()

    tic = time.time()
    for step, batch in enumerate(train_loader):
        optimizer.zero_grad()

        samples = batch[0] if if_weighted else batch
        target = samples[1].to(args.device)
        inputs = samples[0].to(args.device)

        output = network(inputs)
        per_sample = criterion(output, target)
        if if_weighted:
            # Sample weights are data, never optimised.
            weights = batch[1].to(args.device).requires_grad_(False)
            loss = torch.sum(per_sample * weights) / torch.sum(weights)
        else:
            loss = per_sample.mean()

        # Bookkeeping happens before the parameter update, on this batch's output.
        prec1 = accuracy(output.data, target, topk=(1,))[0]
        batch_len = inputs.size(0)
        losses.update(loss.item(), batch_len)
        top1.update(prec1.item(), batch_len)

        loss.backward()
        optimizer.step()
        scheduler.step()

        toc = time.time()
        batch_time.update(toc - tic)
        tic = toc

        if step % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                      epoch, step, len(train_loader), batch_time=batch_time,
                      loss=losses, top1=top1))

    current_lr = optimizer.param_groups[0]['lr']
    record_train_stats(rec, epoch, losses.avg, top1.avg, current_lr)


def adv_train(train_loader, network, criterion, optimizer, scheduler, epoch, args, rec):
    """Train ``network`` for one epoch with the TRADES adversarial loss.

    Each batch ``(inputs, targets)`` is trained with ``trades_loss`` using the
    l_inf defaults: epsilon 8/255, step 2/255, 10 steps, beta 6, inputs clipped
    to [0, 1]. The reported Prec@1 is clean accuracy, measured on a forward
    pass made before the attack.

    ``criterion`` is accepted for signature compatibility with ``train`` but is
    not used. The optimizer and ``scheduler`` are stepped once per batch. Epoch
    statistics are appended to ``rec``.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')

    # switch to train mode
    network.train()

    end = time.time()
    for i, contents in enumerate(train_loader):
        optimizer.zero_grad()

        target = contents[1].to(args.device)
        input = contents[0].to(args.device)

        # Clean forward pass used only for the accuracy statistic. It is not part
        # of the loss, so skip building an autograd graph that would otherwise
        # stay alive through the attack and the backward pass.
        with torch.no_grad():
            output = network(input)
        loss = trades_loss(network, input, target, args.device, optimizer, step_size=2/255, epsilon=8/255,
                           perturb_steps=10, beta=6, clip_min=0.0, clip_max=1.0)
        loss = loss.mean()

        # Measure accuracy and record loss
        prec1 = accuracy(output.data, target, topk=(1,))[0]
        losses.update(loss.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))

        # Compute gradient and do SGD step
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                epoch, i, len(train_loader), batch_time=batch_time,
                loss=losses, top1=top1))

    record_train_stats(rec, epoch, losses.avg, top1.avg, optimizer.state_dict()['param_groups'][0]['lr'])


def test(test_loader, network, criterion, epoch, args, rec, if_advex: bool = False, print_clsacc: bool = False):
    """Evaluate ``network`` on ``test_loader``, optionally under a PGD attack.

    Args:
        test_loader: yields ``(input, target)`` batches.
        network: model to evaluate; it is switched to eval mode.
        criterion: loss function; its result is reduced with ``.mean()``.
        epoch: step value recorded in ``rec``.
        args: provides ``device`` and ``print_freq``.
        rec: recorder passed to ``record_test_stats``.
        if_advex: evaluate on ``torchattacks.PGD`` examples (eps 8/255,
            alpha 1/255, 10 steps, random start) instead of clean inputs.
        print_clsacc: also print per-class accuracy, as a fraction rounded to
            2 decimals. Requires ``test_loader.dataset.num_classes``.

    Returns:
        Average Prec@1 (percent) over the loader.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')

    # Only read ``num_classes`` when class-wise accuracy is requested, so
    # loaders whose dataset lacks that attribute can still be evaluated.
    if print_clsacc:
        num_classes = test_loader.dataset.num_classes
        cls_pred = np.zeros(num_classes)
        cls_num = np.zeros(num_classes)

    # Switch to evaluate mode
    network.eval()

    if if_advex:
        attack = torchattacks.PGD(network, eps=8 / 255, alpha=1 / 255, steps=10, random_start=True)

    end = time.time()
    for i, (input, target) in enumerate(test_loader):
        target = target.to(args.device)
        input = input.to(args.device)

        # Crafting the attack needs input gradients, so it runs outside no_grad.
        eval_input = attack(input, target) if if_advex else input

        with torch.no_grad():
            output = network(eval_input)
            loss = criterion(output, target).mean()

        # Measure accuracy and record loss
        prec1 = accuracy(output.data, target, topk=(1,))[0]
        losses.update(loss.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))

        if print_clsacc:
            btc_pred, btc_num = batch_cls_pred(output, target, num_classes)
            cls_pred += btc_pred
            cls_num += btc_num

        # Measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            if if_advex:
                print_string = ('Test-Adv: [{0}/{1}]\tTime {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                                'AdvLoss {loss.val:.4f} ({loss.avg:.4f})\tAdvPrec@1 {top1.val:.3f} ({top1.avg:.3f})').format(
                    i, len(test_loader), batch_time=batch_time, loss=losses, top1=top1)
            else:
                print_string = ('Test: [{0}/{1}]\tTime {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                                'Loss {loss.val:.4f} ({loss.avg:.4f})\tPrec@1 {top1.val:.3f} ({top1.avg:.3f})').format(
                    i, len(test_loader), batch_time=batch_time, loss=losses, top1=top1)

            print(print_string)

    if if_advex:
        print(' * AdvPrec@1\t{top1.avg:.3f}'.format(top1=top1))
    else:
        print(' * Prec@1\t{top1.avg:.3f}'.format(top1=top1))

    if print_clsacc:
        # Classes absent from the loader have cls_num == 0. Report them as nan
        # without triggering a divide-by-zero RuntimeWarning.
        cls_acc = np.full(num_classes, np.nan)
        np.divide(cls_pred, cls_num, out=cls_acc, where=cls_num > 0)
        cls_acc = np.round(cls_acc, 2)
        print(f' * Class-wise Prec@1\t{list(cls_acc)}')

    record_test_stats(rec, epoch, losses.avg, top1.avg)
    return top1.avg


class AverageMeter(object):
    """Keep the most recent value and the count-weighted running mean of a metric.

    ``fmt`` is a format spec (e.g. ``':.4f'``) used by ``str()`` for both the
    latest value and the running average.
    """

    def __init__(self, name, fmt=':f'):
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        """Forget every recorded observation."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the running mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        template = ''.join(('{name} {val', self.fmt, '} ({avg', self.fmt, '})'))
        return template.format(**vars(self))


def accuracy(output, target, topk=(1,)):
    """Return top-k accuracy, in percent, for every k in ``topk``.

    ``output`` holds class scores along dim 1; ``target`` holds integer labels.
    The result is a list with one single-element tensor per requested k.
    """
    with torch.no_grad():
        n_samples = target.size(0)
        _, top_idx = output.topk(max(topk), dim=1, largest=True, sorted=True)
        ranked = top_idx.t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))
        scale = 100.0 / n_samples
        return [hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(scale) for k in topk]


def batch_cls_pred(output, target, num_classes):
    """Count correct top-1 predictions and samples per class for one batch.

    Returns two numpy arrays of length ``num_classes``: the number of correctly
    classified samples of each true class, and the number of samples of each
    class.
    """
    with torch.no_grad():
        top1 = output.topk(1, dim=1, largest=True, sorted=True)[1].squeeze()
        pred_onehot = F.one_hot(top1, num_classes)
        target_onehot = F.one_hot(target, num_classes)
        # Only positions where prediction and target agree survive the product.
        per_class_correct = (pred_onehot * target_onehot).sum(dim=0)
        per_class_total = target_onehot.sum(dim=0)

    return per_class_correct.detach().cpu().numpy(), per_class_total.detach().cpu().numpy()


def str_to_bool(v):
    """Parse a boolean command-line value (suitable as an argparse ``type=``).

    Booleans pass through unchanged. Strings are matched case-insensitively
    against yes/true/t/y/1 and no/false/f/n/0. Anything else raises
    ``ArgumentTypeError``.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise ArgumentTypeError('Boolean value expected.')


def save_checkpoint(state, path, epoch, prec):
    """Print a notice, then write ``state`` to ``path`` with ``torch.save``."""
    notice = "=> Saving checkpoint for epoch %d, with Prec@1 %f." % (epoch, prec)
    print(notice)
    torch.save(state, path)


def init_recorder():
    """Create an empty recorder: a namespace holding one fresh list per tracked series."""
    from types import SimpleNamespace
    series = ('train_step', 'train_loss', 'train_acc', 'lr',
              'test_step', 'test_loss', 'test_acc', 'ckpts')
    return SimpleNamespace(**{name: [] for name in series})


def record_train_stats(rec, step, loss, acc, lr):
    """Append one training entry (step, loss, accuracy, learning rate) to ``rec`` and return it."""
    for attr, value in (('train_step', step), ('train_loss', loss),
                        ('train_acc', acc), ('lr', lr)):
        getattr(rec, attr).append(value)
    return rec


def record_test_stats(rec, step, loss, acc):
    """Append one evaluation entry (step, loss, accuracy) to ``rec`` and return it."""
    for attr, value in (('test_step', step), ('test_loss', loss), ('test_acc', acc)):
        getattr(rec, attr).append(value)
    return rec


def record_ckpt(rec, step):
    """Note that a checkpoint was saved at ``step``; returns ``rec``."""
    ckpt_log = rec.ckpts
    ckpt_log.append(step)
    return rec


class DataLoaderX(torch.utils.data.DataLoader):
    """``DataLoader`` whose iterator is wrapped in ``prefetch_generator.BackgroundGenerator``.

    Presumably this prefetches batches in the background while the training
    step runs; see the prefetch_generator package to confirm.
    """

    def __iter__(self):
        base_iterator = super(DataLoaderX, self).__iter__()
        return BackgroundGenerator(base_iterator)

def _poison_rate(selected, poison_set):
    """Return the fraction of ``selected`` indices present in ``poison_set`` (0.0 if empty)."""
    if len(selected) == 0:
        return 0.0
    return sum(int(i) in poison_set for i in selected) / len(selected)


def update_trainloader(args, network, dst_train, subset, epoch, relabel_train=False, interval=1):
    """Build the training loader for ``epoch`` from the ordered indices ``subset["indices"]``.

    Without ``relabel_train``, a prefix of ``subset["indices"]`` is used. Its
    size grows linearly from 50% at epoch 0 to 100% at epoch
    ``args.epochs / 2`` and stays at 100% afterwards.

    With ``relabel_train``, all of ``subset["indices"]`` is used. From epoch
    ``int(args.epochs * 0.5)`` on, every selected sample is relabelled with
    ``network``'s argmax prediction. Relabelling is done on a deep copy of
    ``dst_train``, so the original dataset's targets are left untouched.

    Args:
        args: provides ``epochs``, ``train_batch``, ``workers`` and ``device``.
        network: model used for relabelling. It is put in eval mode and not
            switched back. Only touched in the relabelling phase.
        dst_train: training dataset. Must expose ``poison_ids`` and
            ``targets``; ``targets`` must support indexing and assignment by an
            index array (numpy-style).
        subset: dict whose ``"indices"`` entry is the ordered index list.
        epoch: current epoch.
        relabel_train: enable the relabelling schedule.
        interval: accepted for backward compatibility; currently unused.

    Returns:
        A shuffled ``torch.utils.data.DataLoader`` over the selected samples.
    """
    indices = subset["indices"]
    half_stage = int(args.epochs * 0.5)
    # Plain-int set: O(1) membership instead of scanning poison_ids per index.
    poison_set = {int(p) for p in dst_train.poison_ids}

    if not relabel_train:
        # Cap at 100% so the reported size and poison rate match the slice taken.
        loader_ratio = min(1.0, 0.5 * (1 + epoch / (args.epochs * 0.5)))
        num_subset = int(len(indices) * loader_ratio)
        selected = indices[:num_subset]

        dst_subset = torch.utils.data.Subset(dst_train, selected)
        train_loader = torch.utils.data.DataLoader(dst_subset, batch_size=args.train_batch, shuffle=True,
                                                   num_workers=args.workers, pin_memory=True)

        prate = _poison_rate(selected, poison_set)
        print(f"\n=== Update training loader with Top-{loader_ratio*100:.2f}% [size: {num_subset}, prate: {prate:.4f}] ===\n")
        return train_loader

    num_subset = len(indices)
    selected = indices[:num_subset]

    if epoch < half_stage:
        dst_subset = torch.utils.data.Subset(dst_train, selected)
        train_loader = torch.utils.data.DataLoader(dst_subset, batch_size=args.train_batch, shuffle=True,
                                                   num_workers=args.workers, pin_memory=True)

        prate = _poison_rate(selected, poison_set)
        print(f"\n=== Update training loader with Top-50% [size: {num_subset}, prate: {prate:.4f}] ===\n")
        return train_loader

    # Relabelling phase: predict a label for every selected sample.
    network.eval()

    dst_subset = torch.utils.data.Subset(dst_train, selected)
    # Ordered, one sample per batch, so batch i corresponds to selected[i].
    test_subset_loader = torch.utils.data.DataLoader(dst_subset, batch_size=1, shuffle=False,
                                                     num_workers=args.workers, pin_memory=True)

    targets = dst_train.targets[selected]
    labels = deepcopy(targets)
    for i, (inputs, _) in enumerate(test_subset_loader):
        inputs = inputs.to(args.device)
        with torch.no_grad():
            output = network(inputs)
        labels[i] = torch.argmax(output, dim=1).item()

    relabeled = np.not_equal(labels, targets)
    num_relabel = np.sum(relabeled)
    num_relabel_poi = sum(
        1 for idx, changed in zip(selected, relabeled) if changed and int(idx) in poison_set
    )

    # Write the new labels into a copy so dst_train itself is not modified.
    dst_train_re = deepcopy(dst_train)
    dst_train_re.targets[selected] = labels
    dst_train_subset = torch.utils.data.Subset(dst_train_re, selected)

    train_loader = torch.utils.data.DataLoader(dst_train_subset, batch_size=args.train_batch, shuffle=True,
                                               num_workers=args.workers, pin_memory=True)

    prate = _poison_rate(selected, poison_set)
    print(f"\n=== Update training loader with re-labeling [size: {num_subset}, re-labeled: {num_relabel} (poi: {num_relabel_poi}), prate: {prate:.4f}] ===\n")
    return train_loader
