import time
import torch
import utils
from .impl import iterative_unlearn
import torch.nn.functional as F
import numpy as np


def l1_regularization(model):
    """Return the L1 norm of all parameters of ``model`` as a scalar tensor.

    Every tensor yielded by ``model.parameters()`` is flattened, the pieces
    are concatenated into one vector and its 1-norm (sum of absolute values)
    is returned. The result stays attached to the autograd graph, so it can
    be added to a loss as a sparsity penalty.

    Args:
        model: object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).

    Returns:
        0-dim tensor holding the L1 norm; a zero scalar when the model has
        no parameters at all.
    """
    # ``reshape`` (unlike ``view``) also accepts non-contiguous parameters,
    # e.g. transposed or channels_last weights, copying only when needed.
    flat_params = [param.reshape(-1) for param in model.parameters()]
    if not flat_params:
        # ``torch.cat`` rejects an empty list; the norm of nothing is zero.
        return torch.zeros(())
    return torch.linalg.norm(torch.cat(flat_params), ord=1)


def _per_label_gradients(model, logits):
    """Return one flattened parameter-gradient vector per class label.

    For every label index in ``range(logits.shape[-1])`` the mean
    cross-entropy of ``logits`` against that single label is back-propagated
    through ``model`` and the gradients of all trainable parameters are
    concatenated into one vector.

    NOTE(review): the label tensor has shape ``(1,)``, so ``logits`` has to
    hold exactly one sample (forget loader with batch size 1) -- confirm
    against the caller's configuration.

    Returns:
        ``numpy.ndarray`` with one row per label.
    """
    loss_func = torch.nn.CrossEntropyLoss(reduction="mean")
    trainable = [p for p in model.parameters() if p.requires_grad]
    rows = []
    for label in range(logits.shape[-1]):
        label_tensor = torch.tensor([label], device=logits.device)
        label_loss = loss_func(logits, label_tensor)
        model.zero_grad()
        # The same forward graph is reused for every label.
        label_loss.backward(retain_graph=True)
        flat = torch.cat([
            # ``grad`` can be None for a parameter that did not contribute
            # to the loss; count it as zeros so all rows share one length.
            (p.grad if p.grad is not None else torch.zeros_like(p)).reshape(-1)
            for p in trainable
        ])
        rows.append(flat.detach().cpu().numpy())
    return np.stack(rows)


@iterative_unlearn
def GA(data_loaders, model, criterion, optimizer, epoch, args, wa=0, wb=1):
    """Run one epoch of gradient-ascent unlearning on the forget set.

    Each forget batch is paired with a batch drawn from a freshly sampled
    subset of the retain dataset (as many samples as the forget set). A
    label-smoothed negative log-likelihood ``loss_s_mean`` is computed on
    the forget batch:

        (wa + wb * (1 - smooth_rate)) * nll + wb * smooth_rate * mean(-logp)

    ``args.p`` then selects the behaviour:

    * ``p == 1``: minimise ``-loss_s_mean`` (ascent on the forget loss); the
      retain batch is loaded but not used.
    * ``p == -1``: diagnostic mode. No optimizer step is taken; for each
      batch the dot product between the true-label gradient and its
      difference from the mean gradient of the other labels is printed.
      The meters are never updated in this mode.
    * otherwise: minimise ``p * criterion(retain) -/+ (1 - p) * loss_s_mean``
      with a minus sign when ``args.smooth_rate <= 0`` and a plus sign
      otherwise. Accuracy is then measured on the retain batch.

    Args:
        data_loaders: mapping with a ``"forget"`` data loader and a
            ``"retain_dataset"`` dataset.
        model: network being unlearned; put into train mode here.
        criterion: loss applied to the retain batch in the mixed mode.
        optimizer: optimizer stepping the model parameters.
        epoch: current epoch index, compared with ``args.warmup``.
        args: namespace providing ``warmup``, ``smooth_rate``, ``p`` and
            ``print_freq`` (and whatever ``utils.warmup_lr`` reads).
        wa: constant weight added to the NLL term.
        wb: weight of the label-smoothing terms.

    Returns:
        ``top1.avg`` of the accuracy meter after the epoch.
    """
    train_loader = data_loaders["forget"]
    retain_dataset = data_loaders['retain_dataset']

    # Draw as many retain samples as there are forget samples. Sampling
    # falls back to "with replacement" only when the retain set is too
    # small, where sampling without replacement would raise ValueError.
    num_samples = len(train_loader.dataset)
    retain_size = len(retain_dataset)
    sampled_indices = np.random.choice(
        retain_size, num_samples, replace=num_samples > retain_size)
    sampled_dataset = torch.utils.data.Subset(retain_dataset, sampled_indices)
    sampled_loader = torch.utils.data.DataLoader(
        sampled_dataset, batch_size=train_loader.batch_size, num_workers=4, pin_memory=True, shuffle=True)

    losses = utils.AverageMeter()
    top1 = utils.AverageMeter()

    # switch to train mode
    model.train()

    start = time.time()
    paired_batches = zip(sampled_loader, train_loader)
    for i, ((sampled_image, sampled_target), (forget_image, forget_target)) in enumerate(paired_batches):

        if epoch < args.warmup:
            utils.warmup_lr(epoch, i+1, optimizer,
                            one_epoch_step=len(train_loader), args=args)

        # forgetting dataset
        image = forget_image.cuda()
        target = forget_target.cuda()
        output_clean = model(image)

        # Label-smoothed NLL on the forget batch.
        confidence = 1. - args.smooth_rate
        logprobs = F.log_softmax(output_clean, dim=-1)
        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
        nll_loss = nll_loss.squeeze(1)
        smooth_loss = -logprobs.mean(dim=-1)
        loss_s = (wa + wb * confidence) * nll_loss + \
                    wb * args.smooth_rate * smooth_loss
        loss_s_mean = torch.sum(loss_s) / len(image)

        if args.p == -1:
            # Diagnostic mode: inspect gradient geometry, never update.
            grads = _per_label_gradients(model, output_clean)
            num_classes = grads.shape[0]
            x = grads[int(target[0])]
            y = np.sum(grads, axis=0)
            # (y - x) / (num_classes - 1) is the mean gradient of the
            # labels other than the true one.
            print(np.dot(x, x - (y - x) / max(num_classes - 1, 1)))
            continue

        if args.p == 1:
            loss = -loss_s_mean
        else:
            # From here on image/target/output_clean refer to the retain
            # batch, so the accuracy below is measured on retain data.
            image = sampled_image.cuda()
            target = sampled_target.cuda()
            output_clean = model(image)
            loss_gd = criterion(output_clean, target)
            if args.smooth_rate <= 0:
                loss = args.p * loss_gd - (1 - args.p) * loss_s_mean
            else:
                loss = args.p * loss_gd + (1 - args.p) * loss_s_mean

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        output = output_clean.float()
        loss = loss.float()
        # measure accuracy and record loss
        prec1 = utils.accuracy(output.data, target)[0]
        losses.update(loss.item(), image.size(0))
        top1.update(prec1.item(), image.size(0))

        if (i + 1) % args.print_freq == 0:
            end = time.time()
            print('Epoch: [{0}][{1}/{2}]\t'
                'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                'Accuracy {top1.val:.3f} ({top1.avg:.3f})\t'
                'Time {3:.2f}'.format(
                    epoch, i, len(train_loader), end-start, loss=losses, top1=top1))
            start = time.time()
    return top1.avg


@iterative_unlearn
def GA_l1(data_loaders, model, criterion, optimizer, epoch, args):
    """Run one epoch of gradient ascent on the forget set with an L1 penalty.

    For every batch of ``data_loaders["forget"]`` the objective
    ``-criterion(model(x), y) + args.alpha * l1_regularization(model)`` is
    minimised, i.e. the forget loss is pushed up while the parameter L1
    norm is pushed down. Running loss and accuracy are tracked with
    ``utils.AverageMeter`` and printed every ``args.print_freq`` steps.

    Args:
        data_loaders: mapping holding the ``"forget"`` data loader.
        model: network being unlearned; put into train mode here.
        criterion: loss whose negative is minimised on the forget batch.
        optimizer: optimizer stepping the model parameters.
        epoch: current epoch index, compared with ``args.warmup``.
        args: namespace providing ``warmup``, ``alpha`` and ``print_freq``
            (and whatever ``utils.warmup_lr`` reads).

    Returns:
        ``avg`` of the accuracy meter after the epoch.
    """
    forget_loader = data_loaders["forget"]

    loss_meter = utils.AverageMeter()
    acc_meter = utils.AverageMeter()

    # Training mode for the whole epoch.
    model.train()

    tick = time.time()
    for step, (inputs, labels) in enumerate(forget_loader, start=1):
        # Learning-rate warm-up for the first ``args.warmup`` epochs.
        if epoch < args.warmup:
            utils.warmup_lr(epoch, step, optimizer,
                            one_epoch_step=len(forget_loader), args=args)

        inputs = inputs.cuda()
        labels = labels.cuda()
        batch_size = inputs.size(0)

        # Forward pass; ascend the forget loss, penalise the L1 norm.
        logits = model(inputs)
        forget_loss = criterion(logits, labels)
        objective = -forget_loss + args.alpha * l1_regularization(model)

        optimizer.zero_grad()
        objective.backward()
        optimizer.step()

        # Bookkeeping uses the outputs computed before the update.
        precision = utils.accuracy(logits.float().data, labels)[0]
        loss_meter.update(objective.float().item(), batch_size)
        acc_meter.update(precision.item(), batch_size)

        if step % args.print_freq != 0:
            continue

        tock = time.time()
        print('Epoch: [{0}][{1}/{2}]\t'
              'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
              'Accuracy {top1.val:.3f} ({top1.avg:.3f})\t'
              'Time {3:.2f}'.format(
                  epoch, step - 1, len(forget_loader), tock - tick,
                  loss=loss_meter, top1=acc_meter))
        tick = time.time()

    print('train_accuracy {top1.avg:.3f}'.format(top1=acc_meter))

    return acc_meter.avg
