import time
import torch
import utils
from .impl import iterative_unlearn
import torch.nn.functional as F
from itertools import zip_longest


def l1_regularization(model):
    """Return the L1 norm of every parameter of ``model``.

    All parameters are flattened and concatenated into one vector, and the
    vector's L1 norm (sum of absolute values) is returned as a tensor that is
    still attached to the autograd graph.
    """
    flat_params = torch.cat([p.view(-1) for p in model.parameters()])
    return torch.linalg.norm(flat_params, ord=1)


def FT_iter(data_loaders, model, criterion, optimizer, epoch, args, with_l1=False):
    """Run one epoch of retain-set fine-tuning combined with a forget-set term.

    Each step takes one batch from ``data_loaders["retain"]`` and one from
    ``data_loaders["forget"]`` in lockstep. The retain loss is
    ``criterion(model(image), target)``.

    If ``args.p == 1`` and ``args.smooth_rate == 0.0``, the step is plain
    fine-tuning on the retain batch. Otherwise a forget-batch loss is built as
    ``(wa + wb * (1 - smooth_rate)) * nll + wb * smooth_rate * mean(-log p)``,
    averaged over the batch, and combined with the retain loss:

    * ``args.smooth_rate <= 0``: ``p * loss_train - (1 - p) * loss_forget``
      (ascent on the forget batch; a negative ``smooth_rate`` makes the
      smoothing term negative).
    * ``args.smooth_rate > 0``: ``p * loss_train + (1 - p) * loss_forget``.

    Args:
        data_loaders: mapping with ``"retain"`` and ``"forget"`` loaders that
            yield ``(inputs, targets)`` batches.
        model: network being updated; switched to train mode here.
        criterion: loss applied to the retain batch.
        optimizer: optimizer stepped once per batch pair.
        epoch: current epoch index; LR warmup runs while
            ``epoch < args.warmup``.
        args: namespace read for ``smooth_rate``, ``p``, ``wa``, ``wb``,
            ``warmup``, ``print_freq`` and, when ``with_l1`` is set, ``alpha``.
        with_l1: if True, add ``args.alpha`` times the L1 norm of all model
            parameters to the final loss.

    Returns:
        The epoch's running-average accuracy on the retain batches
        (``top1.avg``).
    """
    train_loader = data_loaders["retain"]
    forget_loader = data_loaders["forget"]

    losses = utils.AverageMeter()
    top1 = utils.AverageMeter()
    losses_forget = utils.AverageMeter()
    losses_train = utils.AverageMeter()

    # switch to train mode
    model.train()
    print(f'smooth rate: {args.smooth_rate}')

    start = time.time()
    # NOTE(review): zip() stops at the shorter loader, so with a small forget
    # set only len(forget_loader) retain batches are seen per epoch, while the
    # warmup schedule and the progress line use len(train_loader). Confirm
    # this truncation is intended.
    for i, ((image, target), (forget_image, forget_target)) in enumerate(
        zip(train_loader, forget_loader)
    ):
        if epoch < args.warmup:
            utils.warmup_lr(
                epoch, i + 1, optimizer, one_epoch_step=len(train_loader), args=args
            )

        image = image.cuda()
        target = target.cuda()

        # compute output
        output_clean = model(image)
        loss_train = criterion(output_clean, target)

        if args.p == 1 and args.smooth_rate == 0.0:
            # Plain fine-tuning: the forget batch contributes nothing.
            loss = loss_train
            loss_s_mean = torch.tensor(0.0)
        else:
            forget_image = forget_image.cuda()
            forget_target = forget_target.cuda()
            output_forget = model(forget_image)

            confidence = 1.0 - args.smooth_rate
            logprobs = F.log_softmax(output_forget, dim=-1)
            # Per-sample negative log-likelihood of the true forget label.
            nll_loss = -logprobs.gather(dim=-1, index=forget_target.unsqueeze(1))
            nll_loss = nll_loss.squeeze(1)
            # Per-sample mean negative log-prob over all classes (uniform target).
            smooth_loss = -logprobs.mean(dim=-1)
            loss_s = (args.wa + args.wb * confidence) * nll_loss + args.wb * args.smooth_rate * smooth_loss
            loss_s_mean = torch.sum(loss_s) / len(forget_image)

            if args.smooth_rate <= 0.0:
                # NLS + GA
                loss = args.p * loss_train - (1 - args.p) * loss_s_mean
            else:
                # PLS + GD
                loss = args.p * loss_train + (1 - args.p) * loss_s_mean

        if with_l1:
            # Add the penalty only after ``loss`` exists. Use an out-of-place
            # add so ``loss_train`` (which ``loss`` aliases in the plain
            # fine-tuning branch) still logs the retain loss alone.
            loss = loss + args.alpha * l1_regularization(model)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        output = output_clean.float()
        loss = loss.float()
        # measure accuracy and record loss
        prec1 = utils.accuracy(output.data, target)[0]

        losses.update(loss.item(), image.size(0))
        losses_train.update(loss_train.item(), image.size(0))
        losses_forget.update(loss_s_mean.item(), forget_image.size(0))
        top1.update(prec1.item(), image.size(0))

        if (i + 1) % args.print_freq == 0:
            end = time.time()
            print(
                "Epoch: [{0}][{1}/{2}]\t"
                "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
                "Loss Train {loss_train.val:.4f} ({loss_train.avg:.4f})\t"
                "Loss Forget {loss_forget.val:.4f} ({loss_forget.avg:.4f})\t"
                "Accuracy {top1.val:.3f} ({top1.avg:.3f})\t"
                "Time {3:.2f}".format(
                    epoch, i, len(train_loader), end - start, loss=losses, 
                    loss_train=losses_train, loss_forget=losses_forget, top1=top1
                )
            )
            start = time.time()

    print("train_accuracy {top1.avg:.3f}".format(top1=top1))

    return top1.avg


@iterative_unlearn
def FT(data_loaders, model, criterion, optimizer, epoch, args):
    """Run one ``FT_iter`` epoch without the L1 penalty (``with_l1=False``)."""
    return FT_iter(
        data_loaders, model, criterion, optimizer, epoch, args, with_l1=False
    )


@iterative_unlearn
def FT_l1(data_loaders, model, criterion, optimizer, epoch, args):
    """Run one ``FT_iter`` epoch with ``with_l1=True``.

    This enables the L1 penalty on model parameters, weighted by ``args.alpha``.
    """
    l1_enabled = True
    return FT_iter(
        data_loaders, model, criterion, optimizer, epoch, args, with_l1=l1_enabled
    )
