from tqdm import tqdm
import torch
from torch import nn
import numpy as np
from methods.method import Method
import copy
import torch.optim as optim

class SCRUB(Method):
    """SCRUB unlearning with a frozen teacher and a trainable student copy.

    Kurmanji et al. Towards Unbounded Machine Unlearning. NeurIPS, 2023.
    Code from https://github.com/meghdadk/SCRUB/, slightly modified for our
    framework.

    Each epoch the student is (for the first ``msteps`` epochs) pushed away
    from the teacher on the forget data by maximizing the distillation KL, and
    then trained on the retain data with a weighted sum of cross-entropy and
    distillation KL.
    """

    def set_hyperparameters(self, args):
        """Populate the SCRUB hyperparameters on this instance.

        Args:
            args: Run configuration. Only ``args.lr`` is read from this
                argument; epoch counts and batch sizes are read from
                ``self.args``, which is not assigned in this class
                (presumably set by the ``Method`` base class -- confirm).
        """
        # Weights of the retain-pass loss:
        #   gamma * cross_entropy + alpha * KL(student || teacher) + beta * kd
        self.gamma = 0.99
        self.alpha = 0.001
        self.beta = 0  # the extra kd term is always 0 in train_distill
        self.smoothing = 0.0  # weight of the distance to the averaged model
        self.clip = 0.2  # not read anywhere in this class
        self.sstart = 10  # first epoch at which the averaged model is updated
        self.kd_T = 4  # distillation temperature passed to DistillKL
        self.distill = 'kd'  # 'crd' switches train_distill to 4-tuple batches

        self.sgda_learning_rate = args.lr  # very sensitive to learning rate, 1e-4
        self.lr_decay_epochs = [3, 5, 9]
        self.lr_decay_rate = 0.1
        self.sgda_weight_decay = 5e-4
        self.sgda_momentum = 0.9

        # Fill these arguments whenever adding new datasets.
        self.msteps = self.args.forget_epochs  # epochs that include a forget pass
        self.sgda_epochs = self.args.remain_epochs  # total number of epochs
        self.fg_bs = self.args.forget_batch_size
        self.rt_bs = self.args.remain_batch_size

    def adjust_learning_rate(self, epoch, args, optimizer):
        """Apply step decay to the optimizer's learning rate.

        The initial learning rate is multiplied by ``lr_decay_rate`` once for
        every milestone in ``lr_decay_epochs`` that ``epoch`` is strictly
        greater than.

        Args:
            epoch: Current epoch number.
            args: Unused; kept for interface compatibility.
            optimizer: Optimizer whose parameter groups are updated.

        Returns:
            The learning rate computed for this epoch.
        """
        steps = np.sum(epoch > np.asarray(self.lr_decay_epochs))
        new_lr = self.sgda_learning_rate
        # Before the first milestone the optimizer is left untouched.
        if steps > 0:
            new_lr = self.sgda_learning_rate * (self.lr_decay_rate ** steps)
            for param_group in optimizer.param_groups:
                param_group['lr'] = new_lr
        return new_lr

    def unlearn(self, model, loaders, args):
        """Run SCRUB and return the unlearned student model.

        Args:
            model: Model to unlearn from. It is deep-copied twice (teacher and
                student); the object passed in is not trained.
            loaders: Mapping with ``'train_forget_loader'`` and
                ``'train_remain_loader'`` entries. Only their ``.dataset`` is
                used: new shuffled loaders are built with the configured
                forget/retain batch sizes.
            args: Run configuration; ``args.optimizer`` must be ``"sgd"`` or
                ``"adam"``.

        Returns:
            The trained student model (not the averaged model).

        Raises:
            ValueError: If ``args.optimizer`` is not a supported name.
        """
        model_t = copy.deepcopy(model)
        model_s = copy.deepcopy(model)

        # This is from https://github.com/ojus1/SmoothedGradientDescentAscent/blob/main/SGDA.py
        # For SGDA smoothing: exponential moving average of the student.
        beta = 0.1

        def avg_fn(averaged_model_parameter, model_parameter, num_averaged):
            return (1 - beta) * averaged_model_parameter + beta * model_parameter

        swa_model = torch.optim.swa_utils.AveragedModel(model_s, avg_fn=avg_fn)

        criterion_list = nn.ModuleList([
            nn.CrossEntropyLoss(),    # classification loss
            DistillKL(self.kd_T),     # KL divergence loss, original knowledge distillation
            DistillKL(self.kd_T),     # other knowledge distillation loss
        ])

        # Only the student's parameters are optimized.
        if args.optimizer == "sgd":
            optimizer = optim.SGD(model_s.parameters(),
                                  lr=self.sgda_learning_rate,
                                  momentum=self.sgda_momentum,
                                  weight_decay=self.sgda_weight_decay)
        elif args.optimizer == "adam":
            optimizer = optim.Adam(model_s.parameters(),
                                   lr=self.sgda_learning_rate,
                                   weight_decay=self.sgda_weight_decay)
        else:
            # Fail here rather than with an unbound-variable error later on.
            raise ValueError(
                "Unsupported optimizer {!r}; expected 'sgd' or 'adam'".format(args.optimizer))

        # train_distill relies on this layout: student first, teacher last.
        module_list = nn.ModuleList([model_s, model_t])

        if torch.cuda.is_available():
            module_list.cuda()
            criterion_list.cuda()
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            swa_model.cuda()

        forget_loader = torch.utils.data.DataLoader(
            loaders['train_forget_loader'].dataset, batch_size=self.fg_bs, shuffle=True)
        retain_loader = torch.utils.data.DataLoader(
            loaders['train_remain_loader'].dataset, batch_size=self.rt_bs, shuffle=True)

        for epoch in range(1, self.sgda_epochs + 1):
            self.adjust_learning_rate(epoch, args, optimizer)

            # The forget (ascent) pass only runs during the first msteps epochs.
            maximize_loss = 0
            if epoch <= self.msteps:
                maximize_loss = self.train_distill(
                    epoch, forget_loader, module_list, swa_model,
                    criterion_list, optimizer, args, "maximize")
            train_acc, train_loss = self.train_distill(
                epoch, retain_loader, module_list, swa_model,
                criterion_list, optimizer, args, "minimize")
            if epoch >= self.sstart:
                swa_model.update_parameters(model_s)

            print("maximize loss: {:.2f}\t minimize loss: {:.2f}\t train_acc: {}".format(
                maximize_loss, train_loss, train_acc))

        return module_list[0]

    def train_distill(self, epoch, train_loader, module_list, swa_model, criterion_list, optimizer, opt, split, quiet=False):
        """Run one epoch of distillation over ``train_loader``.

        Args:
            epoch: Unused; kept for interface compatibility.
            train_loader: Iterable of ``(input, target)`` batches, or
                ``(input, target, index, contrast_idx)`` when
                ``self.distill == 'crd'``.
            module_list: Student at index 0, teacher at index -1.
            swa_model: Averaged model used by the ``param_dist`` penalty.
            criterion_list: Classification loss at index 0, distillation KL
                at index 1.
            optimizer: Optimizer stepped once per batch.
            opt: Unused; kept for interface compatibility.
            split: ``"minimize"`` trains on
                ``gamma * CE + alpha * KL + beta * kd``; ``"maximize"`` trains
                on ``-KL``.
            quiet: If true, skip metric tracking and the accuracy printout.

        Returns:
            ``(top1_avg, loss_avg)`` for ``"minimize"``, or the average loss
            for ``"maximize"``. The averages are Python numbers.

        Raises:
            ValueError: If ``split`` is neither ``"minimize"`` nor
                ``"maximize"``.
        """
        if split not in ("minimize", "maximize"):
            raise ValueError(
                "split must be 'minimize' or 'maximize', got {!r}".format(split))

        # Everything trains except the teacher, which stays in eval mode.
        for module in module_list:
            module.train()
        module_list[-1].eval()

        criterion_cls = criterion_list[0]
        criterion_div = criterion_list[1]

        model_s = module_list[0]
        model_t = module_list[-1]

        losses = AverageMeter()
        kd_losses = AverageMeter()
        top1 = AverageMeter()

        for data in tqdm(train_loader):
            if self.distill in ['crd']:
                inputs, target, index, contrast_idx = data
            else:
                inputs, target = data

            inputs = inputs.float()
            if torch.cuda.is_available():
                inputs = inputs.cuda()
                target = target.cuda()
                if self.distill in ['crd']:
                    contrast_idx = contrast_idx.cuda()
                    index = index.cuda()
            batch_size = inputs.size(0)

            # ===================forward=====================
            logit_s = model_s(inputs)
            self.statistics.add_forward_flops(batch_size)
            with torch.no_grad():
                logit_t = model_t(inputs)
                self.statistics.add_forward_flops(batch_size)

            loss_div = criterion_div(logit_s, logit_t)

            if split == "minimize":
                loss_cls = criterion_cls(logit_s, target)
                loss_kd = 0  # other kd beyond KL divergence (none implemented)
                loss = self.gamma * loss_cls + self.alpha * loss_div + self.beta * loss_kd
            else:
                # Gradient ascent on the divergence from the teacher.
                loss = -loss_div

            loss = loss + param_dist(model_s, swa_model, self.smoothing)

            if not quiet:
                if split == "minimize":
                    (acc1,) = accuracy(logit_s, target, topk=(1,))
                    losses.update(loss.item(), batch_size)
                    # Store a plain float so the meter does not accumulate tensors.
                    top1.update(acc1.item(), batch_size)
                else:
                    kd_losses.update(loss.item(), batch_size)

            # ===================backward=====================
            optimizer.zero_grad()
            loss.backward()
            self.statistics.add_backward_flops(batch_size)
            optimizer.step()

        if split == "minimize":
            if not quiet:
                print(' * Acc@1 {top1.avg:.3f} '
                      .format(top1=top1))

            return top1.avg, losses.avg
        return kd_losses.avg
        

import torch.nn as nn
import torch.nn.functional as F

class DistillKL(nn.Module):
    """Distilling the Knowledge in a Neural Network.

    Temperature-scaled KL divergence between the softened teacher
    distribution and the softened student distribution.
    """

    def __init__(self, T):
        """Store the softmax temperature ``T``."""
        super().__init__()
        self.T = T

    def forward(self, y_s, y_t):
        """Return the distillation loss for student/teacher logits.

        Args:
            y_s: Student logits; softmax is taken over dim 1.
            y_t: Teacher logits, same layout as ``y_s``.

        Returns:
            The summed KL divergence scaled by ``T ** 2`` and divided by the
            size of dim 0 of ``y_s``.
        """
        log_p_s = F.log_softmax(y_s / self.T, dim=1)
        p_t = F.softmax(y_t / self.T, dim=1)
        # reduction='sum' is the supported spelling of the deprecated
        # size_average=False; the per-sample mean is taken explicitly below.
        summed_kl = F.kl_div(log_p_s, p_t, reduction='sum')
        return summed_kl * (self.T ** 2) / y_s.shape[0]

class AverageMeter(object):
    """Keep the latest value together with a weighted running sum, count and mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set the latest value, mean, sum and count back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as seen ``n`` times and refresh the running mean."""
        weighted = val * n
        self.val = val
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count

def accuracy(output, target, topk=(1,)):
    """Return top-k accuracies, in percent, for every k in ``topk``.

    ``output`` holds scores with samples along dim 0 and classes along dim 1;
    ``target`` holds one class index per sample. The result is a list with
    one single-element tensor per entry of ``topk``, in the same order.
    """
    with torch.no_grad():
        largest_k = max(topk)
        n_samples = target.size(0)

        # Class indices of the best `largest_k` scores, transposed so that
        # row r holds every sample's rank-r prediction.
        ranked = output.topk(largest_k, dim=1, largest=True, sorted=True).indices.t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))

        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / n_samples)
            for k in topk
        ]
    
def param_dist(model, swa_model, p):
    """Return ``p`` times the summed norm of the parameter differences.

    Parameters of ``model`` and ``swa_model`` are paired in iteration order;
    each pair contributes the Frobenius norm of its difference.
    """
    # This is from https://github.com/ojus1/SmoothedGradientDescentAscent/blob/main/SGDA.py
    total = 0.
    paired_params = zip(model.parameters(), swa_model.parameters())
    for current, averaged in paired_params:
        total += torch.norm(current - averaged, p='fro')
    return p * total
