import torch
import torch.optim as optim
from tqdm import tqdm
from methods.method import Method
from itertools import cycle
from time import time
class NegGrad(Method):
    """Unlearning by gradient ascent on the forget set.

    Each optimisation step minimises
    ``0.0001 * (-CE(forget batch)) + 0.9999 * self.recover(remain batch)``.
    In this base class ``recover`` returns ``0``, so only the negated forget
    loss contributes; subclasses can override ``recover`` to add a term
    computed on the remain batch.
    """

    def unlearn(self, model, loaders, args):
        """Run the unlearning loop and return the updated model.

        Args:
            model: Callable network being updated in place by the optimizer
                obtained from ``self.get_optimizer(model)``.
            loaders: Not used by this method; batches are read from
                ``self.train_remain_loader`` and ``self.train_forget_loader``.
            args: Configuration object; ``args.remain_epochs`` (number of
                passes over the remain loader) and ``args.device`` (where
                batches are moved) are read.

        Returns:
            The same ``model`` object after the optimisation steps.
        """
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = self.get_optimizer(model)

        # Mixing weights of the joint objective.
        forget_weight = 0.0001
        remain_weight = 0.9999

        for _ in range(args.remain_epochs):
            # zip() stops when the remain loader is exhausted; cycle() keeps
            # re-serving forget batches so every remain batch has a partner.
            for batch_remain, batch_forget in zip(self.train_remain_loader, cycle(self.train_forget_loader)):

                x_remain, y_remain = batch_remain
                x_forget, y_forget = batch_forget

                # Honour the configured device instead of forcing CUDA.
                x_remain, y_remain = x_remain.to(args.device), y_remain.to(args.device)
                x_forget, y_forget = x_forget.to(args.device), y_forget.to(args.device)

                outputs_forget = model(x_forget)
                self.statistics.add_forward_flops(x_forget.size(0))

                # Negated loss: minimising it ascends the forget-set loss.
                loss_ascent_forget = -criterion(outputs_forget, y_forget)

                # Overall loss
                joint_loss = (
                    forget_weight * loss_ascent_forget
                    + remain_weight * self.recover(model, x_remain, y_remain, criterion)
                )

                optimizer.zero_grad()
                joint_loss.backward()
                # torch.nn.utils.clip_grad_norm_(model.parameters(), 0.25)

                # The subclass must be tested first: a NegGradP instance is
                # also an instance of NegGrad, so testing NegGrad first would
                # never account for the remain batch.
                if isinstance(self, NegGradP):
                    backward_samples = x_remain.size(0) + x_forget.size(0)
                else:
                    backward_samples = x_forget.size(0)
                self.statistics.add_backward_flops(backward_samples)

                optimizer.step()

        # self.intermediate_test(model)

        return model

    def recover(self, model, x_remain, y_remain, criterion):
        """Return the remain-set term of the joint loss.

        The base implementation contributes nothing and returns ``0``;
        the arguments are accepted only so subclasses can override it.
        """
        return 0


class NegGradP(NegGrad):
    """NegGrad variant whose ``recover`` adds a loss term on the remain batch."""

    def recover(self, model, x_remain, y_remain, criterion):
        """Compute the loss of ``model`` on a remain batch.

        Also records ``x_remain.size(0)`` via
        ``self.statistics.add_forward_flops``.

        Args:
            model: Callable network evaluated on ``x_remain``.
            x_remain: Input batch. It is used on whatever device it already
                lives on; ``unlearn`` moves it to ``args.device`` before
                calling this method, so no device is forced here.
            y_remain: Targets for ``x_remain``, on the same device.
            criterion: Loss callable taking ``(outputs, targets)``.

        Returns:
            The value of ``criterion(model(x_remain), y_remain)``.
        """
        outputs_remain = model(x_remain)
        self.statistics.add_forward_flops(x_remain.size(0))

        return criterion(outputs_remain, y_remain)
