import copy
import numpy as np
import os 
import torch
from torch import nn
import time
from methods.method import Method
from methods.method import Method
from tqdm import tqdm
from itertools import cycle


class RL(Method):
    """Random-label unlearning.

    The model is fine-tuned on batches that mix remain samples (with their
    true labels) and forget samples whose labels are replaced by random
    labels drawn from outside the block of classes being unlearned.

    NOTE(review): ``get_optimizer``, ``statistics``, ``train_remain_loader``,
    ``train_forget_loader`` and ``intermidiate_test`` are not defined in this
    class; presumably they are provided by ``Method`` -- confirm.
    """

    def unlearn(self, model, loaders, args):
        """Fine-tune ``model`` with random labels on the forget samples.

        Args:
            model: Network to update in place; it is switched to train mode.
            loaders: Not used by this method; batches are read from
                ``self.train_remain_loader`` and ``self.train_forget_loader``.
            args: Namespace read for ``remain_epochs``, ``device``,
                ``num_classes``, ``class_idx`` (first class to unlearn) and
                ``class_idx_unlearn`` (number of classes to unlearn).

        Returns:
            The same ``model`` object after training.
        """
        criterion = nn.CrossEntropyLoss()
        optimizer = self.get_optimizer(model)
        model.train()

        for _ in range(args.remain_epochs):
            # One pass over the remain loader defines an epoch; the forget
            # loader is restarted as often as needed to pair with it.
            forget_batches = self._repeat_loader(self.train_forget_loader)
            paired_batches = zip(self.train_remain_loader, forget_batches)

            for batch_remain, batch_forget in tqdm(paired_batches):
                x_remain, y_remain = batch_remain
                x_forget, y_forget = batch_forget

                x_remain = x_remain.to(args.device)
                y_remain = y_remain.to(args.device)
                x_forget = x_forget.to(args.device)

                # The true forget labels are discarded; only the batch size
                # is needed to draw the replacement labels.
                num_forget = y_forget.size(0)

                # Draw from the num_classes - class_idx_unlearn allowed
                # labels, then shift every draw at or above class_idx past
                # the unlearned block
                # [class_idx, class_idx + class_idx_unlearn).
                random_labels = torch.randint(
                    0,
                    args.num_classes - args.class_idx_unlearn,
                    (num_forget,),
                ).to(args.device)
                random_labels = torch.where(
                    random_labels >= args.class_idx,
                    random_labels + args.class_idx_unlearn,
                    random_labels,
                )

                x = torch.cat([x_forget, x_remain], dim=0)
                y = torch.cat([random_labels, y_remain], dim=0)

                outputs = model(x)
                self.statistics.add_forward_flops(x.size(0))

                loss = criterion(outputs, y)

                optimizer.zero_grad()
                loss.backward()
                self.statistics.add_backward_flops(x.size(0))

                optimizer.step()

        self.intermidiate_test(model)

        return model

    @staticmethod
    def _repeat_loader(loader):
        """Yield batches from ``loader`` endlessly, restarting it when exhausted.

        Unlike ``itertools.cycle`` this keeps no copy of the batches it has
        already produced; every restart iterates ``loader`` afresh. If a full
        pass over ``loader`` produces no batch at all, the generator stops
        instead of spinning forever.
        """
        while True:
            produced_any = False
            for batch in loader:
                produced_any = True
                yield batch
            if not produced_any:
                return

    def randint_excluding_multiple(self, low, high, size, excludes):
        """Sample random integers from ``[low, high)`` that skip ``excludes``.

        Args:
            low: Inclusive lower bound of the candidate range.
            high: Exclusive upper bound of the candidate range.
            size: Output shape, passed to ``torch.randint`` unchanged.
            excludes: Iterable of integers that must never be returned.

        Returns:
            A tensor of shape ``size`` whose entries are drawn from the
            remaining candidates.

        Raises:
            ValueError: If no candidate is left after removing ``excludes``.
        """
        # Sort so the index -> value mapping does not depend on set
        # iteration order.
        candidates = sorted(set(range(low, high)) - set(excludes))
        if not candidates:
            raise ValueError(
                f"no values left to sample from in [{low}, {high}) "
                "after applying excludes"
            )
        possible_values = torch.tensor(candidates)

        indices = torch.randint(0, len(possible_values), size)
        return possible_values[indices]