import sys
import time

import torch

import utils

# from .impl import iterative_unlearn

sys.path.append(".")
# from imagenet import get_x_y_from_data_dict


def l1_regularization(model):
    """Return the L1 norm of all of ``model``'s parameters viewed as one vector.

    Each parameter is flattened with ``reshape`` rather than ``view`` so that
    non-contiguous parameters (e.g. transposed tensors or weights converted to
    ``channels_last``) are handled instead of raising a RuntimeError.

    Args:
        model: a ``torch.nn.Module`` (anything exposing ``parameters()``).

    Returns:
        A 0-dim tensor equal to ``sum(|p|)`` over every parameter element. It
        stays attached to the autograd graph, so it can be added to a loss.
        A model with no parameters yields ``tensor(0.)`` rather than the error
        ``torch.cat`` raises on an empty list.
    """
    params_vec = [param.reshape(-1) for param in model.parameters()]
    if not params_vec:
        return torch.zeros(())
    return torch.linalg.norm(torch.cat(params_vec), ord=1)


def l2_regularization(model):
    """Return the L2 (Euclidean) norm of all of ``model``'s parameters as one vector.

    Each parameter is flattened with ``reshape`` rather than ``view`` so that
    non-contiguous parameters (e.g. transposed tensors or weights converted to
    ``channels_last``) are handled instead of raising a RuntimeError.

    Args:
        model: a ``torch.nn.Module`` (anything exposing ``parameters()``).

    Returns:
        A 0-dim tensor equal to ``sqrt(sum(p**2))`` over every parameter
        element (note: the norm itself, not its square). It stays attached to
        the autograd graph. A model with no parameters yields ``tensor(0.)``
        rather than the error ``torch.cat`` raises on an empty list.
    """
    params_vec = [param.reshape(-1) for param in model.parameters()]
    if not params_vec:
        return torch.zeros(())
    return torch.linalg.norm(torch.cat(params_vec), ord=2)


def FT(image, target, model, criterion, optimizer, epoch, args, with_l1=False,
       l1_alpha=0.2):
    """Run one fine-tuning optimization step on a single batch.

    Args:
        image: input batch, passed directly to ``model``.
        target: labels, passed to ``criterion`` together with the model output.
        model: the network being fine-tuned; its parameters are updated in
            place by ``optimizer``.
        criterion: loss function called as ``criterion(output, target)``.
        optimizer: optimizer over ``model``'s parameters; ``zero_grad`` and
            ``step`` are each called once.
        epoch: current epoch index; only used for the L1 weight schedule.
        args: namespace-like object; only ``args.epochs`` is read, and only
            when ``with_l1`` is True.
        with_l1: if True, add an L1 penalty over all model parameters
            (``l1_regularization``) to the loss.
        l1_alpha: L1 weight at epoch 0. It decays linearly to 0 at
            ``args.epochs`` and stays 0 afterwards. Defaults to 0.2, the
            previously hard-coded value.

    Returns:
        The same ``model`` object, after the optimizer step.
    """
    output = model(image)
    loss = criterion(output, target)

    if with_l1:
        # Linear decay: full strength at epoch 0, zero from args.epochs onward.
        # The guard also avoids dividing by zero when args.epochs == 0.
        if epoch < args.epochs:
            current_alpha = l1_alpha * (1 - epoch / args.epochs)
        else:
            current_alpha = 0
        loss = loss + current_alpha * l1_regularization(model)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return model


# @iterative_unlearn
# def FT(data_loaders, model, criterion, optimizer, epoch, args):
#     return FT_iter(data_loaders, model, criterion, optimizer, epoch, args)
#
#
# @iterative_unlearn
# def FT_l1(data_loaders, model, criterion, optimizer, epoch, args):
#     return FT_iter(data_loaders, model, criterion, optimizer, epoch, args, with_l1=True)
