import torch.nn as nn
import torch.optim as optim
import torch
from functools import partial
from torch.utils.data import DataLoader
from helper.thirdparty.repdistiller.helper.util import adjust_learning_rate as sgda_adjust_learning_rate
from helper.thirdparty.repdistiller.distiller_zoo import DistillKL
from helper.thirdparty.repdistiller.helper.loops import train_distill, validate
from helper.thirdparty.repdistiller.helper.util import param_dist
from helper import utils
from settings import train_transformers_func
import copy
import torch.nn.functional as F
from tqdm import tqdm


def avg_fn(averaged_model_parameter, model_parameter, num_averaged):
    """Exponential-moving-average update rule used as ``AveragedModel``'s ``avg_fn``.

    Blends the running average with the current parameter, giving a fixed
    weight of 0.1 to the new value and 0.9 to the running average.
    ``num_averaged`` is part of the ``avg_fn`` callback signature but is not
    used by this rule.
    """
    new_weight = 0.1
    old_weight = 1 - new_weight
    return old_weight * averaged_model_parameter + new_weight * model_parameter

def scrub(
    model, forget_set, retain_set, config,
    trainer_init_func=None,
    trainer_init_kwargs=None,
    device=None,
    unl_logs=None,
):
    """Unlearn ``forget_set`` from ``model`` in place using SCRUB.

    A frozen copy of the original model acts as the teacher and ``model`` as
    the student. Training alternates between two kinds of step:

    * "max" steps push the student's predictions away from the teacher's on
      the forget set;
    * "min" steps keep the student close to the teacher, and to the task
      loss, on the retain set.

    The implementation is chosen by ``config.llama``:

    * truthy: training is driven by a trainer built with
      ``trainer_init_func(**vars(trainer_init_kwargs))``. Its
      ``compute_loss`` is swapped between the max and min objectives. Both
      objectives add a ``param_dist`` penalty towards an SWA-averaged copy of
      the student.
    * falsy: a plain PyTorch loop over ``DataLoader`` batches of
      ``(inputs, targets)``.

    Args:
        model: Model to unlearn. Its parameters are updated in place.
        forget_set: Dataset whose influence should be removed.
        retain_set: Dataset whose behaviour should be preserved.
        config: Hyper-parameter namespace.
            The trainer path reads ``kd_T``, ``distill``, ``gamma``,
            ``alpha``, ``beta``, ``smoothing``, ``sgda_epochs``, ``msteps``
            and ``sstart``.
            The plain path reads ``train_batch_size``, ``loss``,
            ``optimizer``, ``learning_rate``, ``weight_decay``,
            ``num_epochs`` and ``task_loss_coeff``.
        trainer_init_func: Trainer factory (trainer path only).
        trainer_init_kwargs: Namespace of factory keyword arguments (trainer
            path only). Its ``model`` and ``train_dataset`` attributes are
            overwritten.
        device: Device the plain path moves each batch to.
        unl_logs: Unused; kept for interface compatibility.

    Returns:
        None. ``model`` is modified in place.

    Raises:
        ValueError: On the trainer path, if ``trainer_init_func`` or
            ``trainer_init_kwargs`` is missing.
        NotImplementedError: On the trainer path, if ``config.distill`` is
            anything other than ``'kd'``.
    """
    if config.llama:
        _scrub_with_trainer(model, forget_set, retain_set, config,
                            trainer_init_func, trainer_init_kwargs)
    else:
        _scrub_plain(model, forget_set, retain_set, config, device)


def _scrub_with_trainer(model, forget_set, retain_set, opt, trainer_init_func, trainer_init_kwargs):
    """Run SCRUB through the project trainer (the ``config.llama`` path)."""
    if trainer_init_func is None or trainer_init_kwargs is None:
        raise ValueError("scrub with config.llama requires trainer_init_func and trainer_init_kwargs")
    if opt.distill != 'kd':
        # Only the logit-distillation variant is implemented. Any other value
        # used to crash mid-training with a NameError on ``loss_kd``.
        raise NotImplementedError(
            "scrub with config.llama only supports distill='kd', got {!r}".format(opt.distill)
        )
    loss_kd = 0  # 'kd' has no additional feature-level distillation term

    criterion_div = DistillKL(opt.kd_T)
    # NOTE(review): repdistiller's DistillKL applies its softmax over dim=1.
    # For (batch, seq, vocab) language-model logits that would be the
    # sequence axis rather than the vocabulary -- confirm this is intended.

    model_s = model
    trainer_init_kwargs.model = model_s
    trainer_init_kwargs.train_dataset = retain_set
    trainer = trainer_init_func(**vars(trainer_init_kwargs))
    # Presumably the same class as type(trainer). Calling its compute_loss
    # through the class bypasses the patched instance attribute set below,
    # so the patched objectives do not recurse into themselves.
    TrainerBase = getattr(train_transformers_func, trainer.__class__.__name__)

    # Frozen teacher: a snapshot of the model *before* unlearning. Previously
    # the "teacher" was the student itself, which made the KL term zero.
    # The copy is taken after the trainer is built so it lives wherever the
    # trainer placed the student.
    # NOTE(review): this holds a second full copy of the model in memory.
    model_t = copy.deepcopy(model_s)
    model_t.eval()
    model_t.requires_grad_(False)

    # The averaged copy is built after trainer init so it shares the
    # student's device. It is only a target for param_dist, so it is frozen:
    # its parameters are not the trainer's model parameters, and nothing here
    # steps or zeros gradients accumulated on them.
    swa_model = torch.optim.swa_utils.AveragedModel(model_s, avg_fn=avg_fn)
    swa_model.requires_grad_(False)

    def teacher_logits(inputs):
        with torch.no_grad():
            return model_t(**inputs).logits

    # Both objectives are bound to ``trainer`` as ``self`` via functools.partial.
    # They honour ``return_outputs``. Extra keyword arguments that some
    # transformers versions pass to compute_loss (e.g. num_items_in_batch)
    # are accepted and ignored, as TrainerBase.compute_loss was never given
    # them here.
    def loss_to_minimize(self, model, inputs, return_outputs=False, **kwargs):
        loss_cls, outputs_s = TrainerBase.compute_loss(self, model, inputs, return_outputs=True)
        loss_div = criterion_div(outputs_s.logits, teacher_logits(inputs))
        loss = opt.gamma * loss_cls + opt.alpha * loss_div + opt.beta * loss_kd
        loss = loss + param_dist(model_s, swa_model, opt.smoothing)
        return (loss, outputs_s) if return_outputs else loss

    def loss_to_maximize(self, model, inputs, return_outputs=False, **kwargs):
        _, outputs_s = TrainerBase.compute_loss(self, model, inputs, return_outputs=True)
        loss = -criterion_div(outputs_s.logits, teacher_logits(inputs))
        loss = loss + param_dist(model_s, swa_model, opt.smoothing)
        return (loss, outputs_s) if return_outputs else loss

    maximize = partial(loss_to_maximize, trainer)
    minimize = partial(loss_to_minimize, trainer)
    for epoch in range(1, opt.sgda_epochs + 1):
        if epoch <= opt.msteps:
            # Max-step: one epoch on the forget set.
            _train_phase(trainer, maximize, forget_set, 1)
        # Min-step: a fractional epoch on the retain set.
        _train_phase(trainer, minimize, retain_set, 0.2)
        if epoch >= opt.sstart:
            swa_model.update_parameters(model_s)


def _train_phase(trainer, compute_loss, dataset, num_train_epochs):
    """Run one ``trainer.train()`` with the given loss, dataset and epoch budget."""
    trainer.compute_loss = compute_loss
    trainer.train_dataset = dataset
    trainer.args.num_train_epochs = num_train_epochs
    trainer.train()
    utils.clear_cache()


def _scrub_plain(model, forget_set, retain_set, config, device):
    """Run SCRUB with a plain PyTorch loop (the non-``config.llama`` path).

    Both datasets are iterated as ``(inputs, targets)`` batches. Dim 1 of
    ``model(inputs)`` is treated as the class axis.
    """
    retain_loader = DataLoader(retain_set, shuffle=True, batch_size=config.train_batch_size)
    forget_loader = DataLoader(forget_set, shuffle=True, batch_size=config.train_batch_size)

    criterion = getattr(torch.nn, config.loss)()
    optimizer_cls = getattr(torch.optim, config.optimizer)
    optimizer = optimizer_cls(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)

    teacher_model = copy.deepcopy(model)
    teacher_model.eval()
    teacher_model.requires_grad_(False)
    model.train()

    # 'batchmean' sums over classes and averages over the batch ('mean' would
    # also divide by the number of classes):
    # https://discuss.pytorch.org/t/kldiv-loss-reduction/109131
    kl_div = nn.KLDivLoss(reduction='batchmean')

    def teacher_probs(inputs):
        with torch.no_grad():
            return F.softmax(teacher_model(inputs), dim=1)

    for _ in tqdm(range(config.num_epochs), desc="SCRUB"):
        # Max-step: maximise KL(teacher || student) on the forget set.
        for inputs, _targets in forget_loader:
            inputs = inputs.to(device)
            log_probs = F.log_softmax(model(inputs), dim=1)
            unl_loss = -kl_div(log_probs, teacher_probs(inputs))
            optimizer.zero_grad()
            unl_loss.backward()
            print("Forget Loss: {:.4f}".format(-unl_loss.item()))
            optimizer.step()

        # Min-step: one full pass over the retain set. It minimises a
        # weighted mix of the KL to the teacher and the task loss.
        for inputs, targets in retain_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs)
            task_loss = criterion(outputs, targets)
            log_probs = F.log_softmax(outputs, dim=1)
            unl_loss = ((1 - config.task_loss_coeff) * kl_div(log_probs, teacher_probs(inputs))
                        + config.task_loss_coeff * task_loss)
            print("Retain Loss: {:.4f}".format(unl_loss.item()))
            optimizer.zero_grad()
            unl_loss.backward()
            optimizer.step()