import sys
import time
import wandb

import torch
import torch.nn.functional as F

import utils

from .impl import iterative_unlearn

sys.path.append(".")
from imagenet import get_x_y_from_data_dict


def l1_regularization(model):
    """Return the L1 norm of all of ``model``'s parameters as a 0-dim tensor.

    Every tensor yielded by ``model.parameters()`` is flattened, the pieces
    are concatenated into one vector, and the vector's 1-norm (sum of
    absolute values) is returned. The result stays attached to the autograd
    graph, so it can be added to a training loss.

    Args:
        model: Object exposing ``parameters()`` that yields tensors
            (e.g. a ``torch.nn.Module``).

    Returns:
        torch.Tensor: Scalar tensor holding the L1 norm.
    """
    # reshape(-1) instead of view(-1): view raises on parameters whose memory
    # layout is non-contiguous, while reshape falls back to a copy only when a
    # view is impossible and is otherwise identical.
    flat_params = [param.reshape(-1) for param in model.parameters()]
    return torch.linalg.norm(torch.cat(flat_params), ord=1)


def _l1_alpha(epoch, args):
    """Return the L1 weight for ``epoch``: linear decay from ``args.alpha`` to 0."""
    decay_epochs = args.unlearn_epochs - args.no_l1_epochs
    if epoch < decay_epochs:
        return args.alpha * (1 - epoch / decay_epochs)
    return 0


def _mask_gradients(model, mask):
    """Multiply every existing parameter gradient in place by ``mask[name]``."""
    for name, param in model.named_parameters():
        if param.grad is not None:
            param.grad *= mask[name]


def OPC_iter(
    data_loaders, model, criterion, optimizer, epoch, args, mask=None, with_l1=False
):
    """Train ``model`` for one epoch on the "marked" loader, then measure the forget radius.

    The training path is selected by ``args.imagenet_arch``:

    * truthy: each batch is unpacked with ``get_x_y_from_data_dict`` and the
      model is optimised on ``criterion`` alone (plus the optional L1 term).
    * falsy: each batch is an ``(image, target)`` pair. Samples with
      ``target < 0`` form the forget subset and are penalised by the mean L2
      norm of their model outputs; the remaining samples are trained with
      ``criterion``. The two terms are combined as
      ``args.coeff_ce * retain_loss + coeff_un * forget_loss``, where
      ``coeff_un`` moves linearly from ``args.coeff_un`` towards
      ``args.coeff_un * args.coeff_un_mult`` over all unlearning iterations.
      Per-step metrics are sent to wandb and ``args.global_step`` is
      incremented once per batch.

    After training, the model is switched to eval mode and the largest
    per-sample L2 output norm over ``data_loaders["forget"]`` is computed,
    logged to wandb and returned.

    Args:
        data_loaders: Mapping providing the "marked" and "forget" loaders.
        model: Network being unlearned.
        criterion: Loss applied to (output, target) of the retained samples.
        optimizer: Optimizer stepped once per batch.
        epoch: Zero-based index of the current unlearning epoch.
        args: Namespace read for ``imagenet_arch``, ``warmup``,
            ``unlearn_epochs``, ``no_l1_epochs``, ``alpha``, ``coeff_un``,
            ``coeff_un_mult``, ``coeff_ce``, ``print_freq`` and
            ``global_step`` (the last one is also written).
        mask: Optional mapping from parameter name to a multiplier applied to
            that parameter's gradient before the optimizer step.
        with_l1: When true, ``current_alpha * l1_regularization(model)`` is
            added to the loss.

    Returns:
        The maximum L2 norm of the model outputs over the forget loader
        (``0`` if that loader yields nothing).
    """
    train_loader = data_loaders["marked"]
    forget_loader = data_loaders["forget"]
    losses = utils.AverageMeter()
    top1 = utils.AverageMeter()
    radius = utils.AverageMeter()
    norm_losses = utils.AverageMeter()

    # switch to train mode
    model.train()

    # The L1 weight depends only on the epoch, so compute it once.
    current_alpha = _l1_alpha(epoch, args)

    start = time.time()
    if args.imagenet_arch:
        device = (
            torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
        )
        for i, data in enumerate(train_loader):
            image, target = get_x_y_from_data_dict(data, device)
            if epoch < args.warmup:
                utils.warmup_lr(
                    epoch, i + 1, optimizer, one_epoch_step=len(train_loader), args=args
                )

            # compute output
            output_clean = model(image)
            loss = criterion(output_clean, target)
            if with_l1:
                loss = loss + current_alpha * l1_regularization(model)
            optimizer.zero_grad()
            loss.backward()

            if mask:
                _mask_gradients(model, mask)

            optimizer.step()

            output = output_clean.float()
            loss = loss.float()
            # measure accuracy and record loss
            prec1 = utils.accuracy(output.data, target)[0]

            losses.update(loss.item(), image.size(0))
            top1.update(prec1.item(), image.size(0))

            if (i + 1) % args.print_freq == 0:
                end = time.time()
                print(
                    "Epoch: [{0}][{1}/{2}]\t"
                    "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
                    "Accuracy {top1.val:.3f} ({top1.avg:.3f})\t"
                    "Time {3:.2f}".format(
                        epoch, i, len(train_loader), end - start, loss=losses, top1=top1
                    )
                )
                start = time.time()
    else:
        steps_per_epoch = len(train_loader)
        # Start and target values of the forget-loss coefficient.
        init_coeff_un = args.coeff_un
        final_coeff_un = args.coeff_un * args.coeff_un_mult
        # Total number of iterations (epochs * steps per epoch).
        total_iters = args.unlearn_epochs * steps_per_epoch

        for i, (image, target) in enumerate(train_loader):
            if epoch < args.warmup:
                utils.warmup_lr(
                    epoch, i + 1, optimizer, one_epoch_step=steps_per_epoch, args=args
                )
            epoch_progress = epoch + i / steps_per_epoch
            image = image.cuda()
            target = target.cuda()
            # Negative targets flag the samples to forget.
            marked = target < 0
            n_forget = int(marked.sum().item())
            n_retain = marked.numel() - n_forget

            # compute output
            full_output = model(image)

            # A mean over an empty selection is NaN, which would poison the
            # logged loss and the running meters. An empty subset contributes
            # a constant zero instead; gradients are unaffected because an
            # empty selection has nothing to back-propagate into.
            if n_forget > 0:
                norm_loss = full_output[marked].norm(p=2, dim=1).mean()
            else:
                norm_loss = full_output.new_zeros(())
            if n_retain > 0:
                retain_loss = criterion(full_output[~marked], target[~marked])
            else:
                retain_loss = full_output.new_zeros(())
            forget_loss = norm_loss

            # Linear schedule for coeff_un.
            current_iter = epoch * steps_per_epoch + i
            coeff_un = init_coeff_un + (final_coeff_un - init_coeff_un) * (
                current_iter / total_iters
            )

            # Combined loss.
            loss = args.coeff_ce * retain_loss + coeff_un * forget_loss
            if with_l1:
                loss = loss + current_alpha * l1_regularization(model)

            optimizer.zero_grad()
            loss.backward()

            if mask:
                _mask_gradients(model, mask)

            optimizer.step()

            loss = loss.float()
            norm_value = norm_loss.item()
            # measure accuracy and record loss; meters tied to one subset are
            # only updated when that subset is non-empty, so a zero-weight
            # update can never disturb their running averages.
            if n_retain > 0:
                output = full_output[~marked].float()
                prec1 = utils.accuracy(output.data, target[~marked])[0]
                top1.update(prec1.item(), n_retain)
            if n_forget > 0:
                norm_losses.update(norm_value, n_forget)
                radius.update(norm_value, n_forget)
            losses.update(loss.item(), image.size(0))

            wandb.log({
                "loss": losses.val,
                # Logged from the tensor so a skipped meter update cannot
                # leave a stale value here.
                "norm_loss": coeff_un * norm_value,
                "retain_loss": args.coeff_ce * retain_loss.item(),
                "forget_loss": coeff_un * forget_loss.item(),
                "accuracy": top1.val,
                "epoch": epoch_progress,
                "coeff_un": coeff_un,
            }, step=args.global_step)
            args.global_step += 1

            if (i + 1) % args.print_freq == 0:
                end = time.time()
                print(
                    "Epoch: [{0}][{1}/{2}]\t"
                    "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
                    "Accuracy {top1.val:.3f} ({top1.avg:.3f})\t"
                    "Time {3:.2f}".format(
                        epoch, i, steps_per_epoch, end - start, loss=losses, top1=top1,
                    )
                )
                start = time.time()

    # Largest output norm over the forget set.
    # NOTE(review): this loop unpacks (image, target) and calls .cuda() even
    # when args.imagenet_arch is set, while the training loop above reads that
    # path's batches through get_x_y_from_data_dict -- confirm the forget
    # loader's batch format for the imagenet path.
    forget_radius = 0
    model.eval()
    with torch.no_grad():
        for image, _target in forget_loader:
            image = image.cuda()
            full_output = model(image)
            batch_radius = full_output.norm(p=2, dim=1).max()
            forget_radius = max(forget_radius, batch_radius.item())
    wandb.log({
        "radius": forget_radius,
        "radius_mean": radius.avg,
    }, step=args.global_step)
    print("train_accuracy {top1.avg:.3f}".format(top1=top1))
    print(f"forget_radius {forget_radius}")

    return forget_radius


@iterative_unlearn
def OPC(data_loaders, model, criterion, optimizer, epoch, args, mask=None):
    """Run one ``OPC_iter`` epoch with the L1 penalty disabled.

    All arguments are forwarded unchanged to ``OPC_iter``; its return value
    is returned as is. The function is wrapped by ``iterative_unlearn``.
    """
    return OPC_iter(
        data_loaders,
        model,
        criterion,
        optimizer,
        epoch,
        args,
        mask=mask,
        with_l1=False,
    )


@iterative_unlearn
def OPC_l1(data_loaders, model, criterion, optimizer, epoch, args, mask=None):
    """Run one ``OPC_iter`` epoch with the L1 penalty enabled.

    All arguments are forwarded unchanged to ``OPC_iter`` together with
    ``with_l1=True``; its return value is returned as is. The function is
    wrapped by ``iterative_unlearn``.
    """
    step_kwargs = {"mask": mask, "with_l1": True}
    return OPC_iter(
        data_loaders, model, criterion, optimizer, epoch, args, **step_kwargs
    )