import logging
import numpy as np
import torch
from torch import nn
from tqdm import tqdm
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from utils_online.inc_net_online import PromptVitNetOnline, oLORA
from models_online.base import BaseLearner
from utils.toolkit import tensor2numpy
from buffer.buffer import ProtoBuffer, Reservoir, QueryReservoir
from utils_online.si_blurry import IndexedDataset, OnlineSampler, OnlineTestSampler
import wandb
import sys

# Online learner: trains oLORA adapters on a data stream, regularized with MAS-style importance weights that are re-estimated whenever the training loss plateaus.
num_workers: int = 8  # DataLoader worker processes used by the loaders below

class Learner(BaseLearner):
    """Online LoRA learner with MAS regularization and loss-plateau detection.

    Each training step optimizes, on the ``oLORA`` network wrapped in
    ``nn.DataParallel``:

    * cross-entropy on the current batch, with logits of class slots that
      have not been exposed yet masked to ``-inf``;
    * cross-entropy on every batch held in ``self.hard_buffer`` (the
      highest-loss recent batches);
    * ``MAS_weight / 2 * sum(omega * p ** 2)`` over the backbone's LoRA
      parameters flagged ``_is_wnew_a`` / ``_is_wnew_b``, once importance
      weights (omegas) exist.

    A sliding window over the per-step loss (before the MAS term) is used to
    detect plateaus. On a plateau the omegas are re-estimated from the hard
    buffer and ``backbone.update_and_reset_lora_parameters()`` is called
    (per the original comment it freezes the current LoRA and creates a new
    set; its implementation is outside this file).

    Labels are remapped everywhere to their position in
    ``self.exposed_classes`` (order of first appearance in the stream).
    """

    def __init__(self, args):
        super().__init__(args)

        # Build the backbone exactly once. The original also built an
        # unwrapped oLORA first and immediately discarded it.
        self._network = nn.DataParallel(oLORA(args, True), device_ids=args['device'])

        self.batch_size = args["batch_size"]
        self.init_lr = args["init_lr"]
        weight_decay = args.get("weight_decay")
        self.weight_decay = weight_decay if weight_decay is not None else 0.0005
        self.args = args
        self.beta = 1
        self.step = 0
        self.classes_seen_so_far = torch.LongTensor(size=(0,)).cuda()
        self.samples_cnt = 0

        # Loss-plateau detection over a sliding window of per-step losses.
        self.loss_window = []
        self.loss_window_means = []
        self.loss_window_variances = []
        self.new_peak_detected = True
        self.loss_window_length = 5
        self.loss_window_mean_threshold = 5.6
        self.loss_window_variance_threshold = 0.08
        self.last_loss_window_variance = 0
        self.last_loss_window_mean = 0

        # Up to `hard_buffer_size` {'state': inputs, 'trgt': targets} batches
        # with the highest recent loss; replayed on every step.
        self.hard_buffer = []
        self.hard_buffer_size = 4

        # MAS importance weights (per LoRA A / B parameter, built from squared
        # gradients), their update counter and the regularization strength.
        self.omega_As = []
        self.omega_Bs = []
        self.count_updates = 0
        self.MAS_weight = 2000

        # Raw class labels in order of first appearance. A label's position in
        # this list is the label index used for training and evaluation.
        self.exposed_classes = []

        self.distributed = False
        self.memory = Memory()
        # Additive logit mask: -inf everywhere until add_new_class opens the
        # first len(exposed_classes) slots by setting them to 0.
        self.mask = torch.zeros(args['nb_classes'], device=self._device) - torch.inf

        backbone = self._network.module.backbone
        total_params = sum(p.numel() for p in backbone.parameters())
        logging.info(f'{total_params:,} model total parameters.')
        total_trainable_params = sum(p.numel() for p in backbone.parameters() if p.requires_grad)
        logging.info(f'{total_trainable_params:,} model training parameters.')

        # If only part of the backbone is trainable, list the trainable tensors.
        if total_params != total_trainable_params:
            for name, param in backbone.named_parameters():
                if param.requires_grad:
                    logging.info("{}: {}".format(name, param.numel()))

    def after_task(self):
        """Mark every class of the finished task as known."""
        self._known_classes = self._total_classes

    def incremental_train(self, data_manager, **kwargs):
        """Build the loaders for the next task and train on it.

        In blurry mode all ``nb_classes`` are requested at once, and the stream
        order comes from ``OnlineSampler`` restricted to the current task.
        Otherwise only the classes of the current task are trained on, while
        testing covers every class seen so far.

        Args:
            data_manager: provides ``get_dataset`` / ``get_task_size``.
            n_tasks (kwarg, default 10): number of tasks for ``OnlineSampler``.
        """
        n_tasks = kwargs.get('n_tasks', 10)
        self._cur_task += 1
        if self.args['blurry']:
            self._total_classes = self.args['nb_classes']
            train_classes = np.arange(0, self._total_classes)
        else:
            self._total_classes = self._known_classes + data_manager.get_task_size(self._cur_task)
            logging.info("Learning on {}-{}".format(self._known_classes, self._total_classes))
            train_classes = np.arange(self._known_classes, self._total_classes)

        train_dataset = data_manager.get_dataset(train_classes, source="train", mode="train")
        # Fetched once. The original requested the identical test split twice.
        test_dataset = data_manager.get_dataset(np.arange(0, self._total_classes), source="test", mode="test")
        self.test_dataset = test_dataset
        self.train_dataset = train_dataset

        if self.args['blurry']:
            self.train_dataset = IndexedDataset(self.train_dataset)
            self.train_sampler = OnlineSampler(self.train_dataset, n_tasks, 10, 50, self.args['seed'], False, 1)
            self.train_loader = DataLoader(self.train_dataset, batch_size=self.batch_size, sampler=self.train_sampler, num_workers=num_workers, pin_memory=True)
            self.test_loader = DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
            self.train_sampler.set_task(self._cur_task)
        else:
            self.train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, drop_last=True, num_workers=num_workers)
            self.test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False, drop_last=False, num_workers=num_workers)

        self.data_manager = data_manager
        self._train(self.train_loader, self.test_loader)

    def _train(self, train_loader, test_loader):
        """Create the optimizer/scheduler on the first task, then train.

        The optimizer is created only for task 0 and is reused afterwards.
        NOTE(review): ``self.scheduler`` is created but never stepped in this
        file. Confirm whether that is intended.
        """
        if self._cur_task == 0:
            self.optim = self.get_optimizer()
            self.scheduler = self.get_scheduler()
        self._init_train(train_loader, test_loader)

    def get_optimizer(self):
        """Return the optimizer named by ``args['optimizer']``.

        Raises:
            ValueError: for an unknown optimizer name. The original raised
                UnboundLocalError in that case.
        """
        name = self.args['optimizer']
        if name == 'sgd':
            return optim.SGD(
                filter(lambda p: p.requires_grad, self._network.parameters()),
                momentum=0.9,
                lr=self.init_lr,
                weight_decay=self.weight_decay
            )
        if name == 'adam':
            # NOTE(review): unlike sgd/adamw this passes *all* parameters,
            # frozen ones included. Kept as-is; confirm whether intended.
            return optim.Adam(
                self._network.parameters(),
                lr=self.init_lr,
                weight_decay=self.weight_decay
            )
        if name == 'adamw':
            return optim.AdamW(
                filter(lambda p: p.requires_grad, self._network.parameters()),
                lr=self.init_lr,
                weight_decay=self.weight_decay
            )
        raise ValueError(f"Unknown optimizer: {name!r}")

    def get_scheduler(self):
        """Return the LR scheduler named by ``args['scheduler']`` (None for 'constant').

        Raises:
            ValueError: for an unknown scheduler name.
        """
        name = self.args["scheduler"]
        if name == 'cosine':
            # `min_lr` is never assigned in this class. Use it if the base
            # class provides it, otherwise fall back to args['min_lr'] or 0.
            eta_min = getattr(self, 'min_lr', self.args.get('min_lr', 0.0))
            return optim.lr_scheduler.CosineAnnealingLR(optimizer=self.optim, T_max=self.args['tuned_epoch'], eta_min=eta_min)
        if name == 'steplr':
            return optim.lr_scheduler.MultiStepLR(optimizer=self.optim, milestones=self.args["init_milestones"], gamma=self.args["init_lr_decay"])
        if name == 'constant':
            return None
        raise ValueError(f"Unknown scheduler: {name!r}")

    def _to_exposed_indices(self, labels):
        """Return 1-D ``labels`` remapped to positions in ``self.exposed_classes``.

        The result keeps the dtype and device of ``labels``.

        Raises:
            ValueError: if a label has not been exposed yet. This matches the
                ``list.index`` lookup it replaces.
        """
        position = {cls: idx for idx, cls in enumerate(self.exposed_classes)}
        try:
            mapped = [position[cls] for cls in labels.tolist()]
        except KeyError as err:
            raise ValueError(f"{err.args[0]} is not in exposed_classes") from err
        return torch.tensor(mapped, dtype=labels.dtype, device=labels.device)

    def _lora_factor_params(self):
        """Return lists of LoRA parameters flagged ``_is_wnew_a`` / ``_is_wnew_b``.

        Lists are returned rather than ``filter`` objects. A ``filter`` is a
        one-shot iterator: reusing one across hard-buffer batches silently
        skips every batch after the first.
        """
        params = list(self._network.module.backbone.lora_vit.parameters())
        a_params = [p for p in params if getattr(p, '_is_wnew_a', False)]
        b_params = [p for p in params if getattr(p, '_is_wnew_b', False)]
        return a_params, b_params

    def _mas_penalty(self, a_params, b_params):
        """Return the sum over paired A/B params of ``sum(omega * p ** 2)``."""
        mas_loss = 0.
        for pindex, (p_a, p_b) in enumerate(zip(a_params, b_params)):
            omega_a = torch.as_tensor(self.omega_As[pindex], dtype=torch.float32, device=p_a.device)
            omega_b = torch.as_tensor(self.omega_Bs[pindex], dtype=torch.float32, device=p_b.device)
            mas_loss += torch.sum(omega_a * (p_a ** 2)) + torch.sum(omega_b * (p_b ** 2))
        return mas_loss

    def _update_loss_window(self, step_loss):
        """Push ``step_loss`` into the window and report whether a plateau was hit.

        A plateau means the window mean and variance are both below their
        thresholds, and a new loss peak was seen since the last plateau. On a
        plateau, ``count_updates`` is incremented, the current window
        statistics become the baseline for peak detection, and peak detection
        is re-armed.
        """
        self.loss_window.append(np.mean(step_loss))
        if len(self.loss_window) > self.loss_window_length:
            del self.loss_window[0]
        self.loss_window_mean = np.mean(self.loss_window)
        self.loss_window_variance = np.var(self.loss_window)
        logging.debug('loss window mean: %.3f, variance: %.3f', self.loss_window_mean, self.loss_window_variance)

        # New peak: the mean rose above the last plateau's mean + 1 std.
        if not self.new_peak_detected and self.loss_window_mean > self.last_loss_window_mean + np.sqrt(self.last_loss_window_variance):
            self.new_peak_detected = True

        plateau = (
            self.loss_window_mean < self.loss_window_mean_threshold
            and self.loss_window_variance < self.loss_window_variance_threshold
            and self.new_peak_detected
        )
        if plateau:
            self.count_updates += 1
            # Stored on self. The original assigned these to locals, so the
            # peak baseline above stayed 0 forever.
            self.last_loss_window_mean = self.loss_window_mean
            self.last_loss_window_variance = self.loss_window_variance
            self.new_peak_detected = False
        return plateau

    def _refresh_importance_weights(self):
        """Re-estimate the MAS omegas from the hard buffer, then reset LoRA.

        For each hard-buffer batch, the NLL of the network's own argmax
        prediction is back-propagated, and the squared gradients of the LoRA
        A/B parameters are accumulated. The result is folded into the running
        average ``g / n + (1 - 1/n) * omega_old`` with ``n = count_updates``.
        """
        a_params, b_params = self._lora_factor_params()
        gradients_A = [0 for _ in a_params]
        gradients_B = [0 for _ in b_params]

        for sx in [entry['state'] for entry in self.hard_buffer]:
            self._network.module.backbone.zero_grad()
            # Each hard-buffer entry is a whole batch, so keep one row per
            # sample. The original `.view(1, -1)` flattened the batch into a
            # single row and took one argmax across every sample's logits.
            logits = self._network(sx)['logits']
            label = logits.max(1)[1]
            omega_loss = F.nll_loss(F.log_softmax(logits, dim=1), label)
            omega_loss.backward()

            for pindex, (p_a, p_b) in enumerate(zip(a_params, b_params)):
                g_a = p_a.grad.data.clone().detach().cpu().numpy()
                g_b = p_b.grad.data.clone().detach().cpu().numpy()
                gradients_A[pindex] += np.abs(g_a) ** 2
                gradients_B[pindex] += np.abs(g_b) ** 2

        # Running average of the importance weights.
        self.omega_As_old = self.omega_As[:]
        self.omega_Bs_old = self.omega_Bs[:]
        self.omega_As = []
        self.omega_Bs = []
        have_old = len(self.omega_As_old) != 0 and len(self.omega_Bs_old) != 0
        for pindex, _ in enumerate(zip(a_params, b_params)):
            if have_old:
                self.omega_As.append(1/self.count_updates*gradients_A[pindex]+(1-1/self.count_updates)*self.omega_As_old[pindex])
                self.omega_Bs.append(1/self.count_updates*gradients_B[pindex]+(1-1/self.count_updates)*self.omega_Bs_old[pindex])
            else:
                self.omega_As.append(gradients_A[pindex])
                self.omega_Bs.append(gradients_B[pindex])

        # NOTE(review): per the original comment, this freezes the current
        # LoRA and creates a new set. The optimizer is built once (task 0), so
        # if new nn.Parameter objects are created here they are not in
        # self.optim. Confirm against update_and_reset_lora_parameters.
        self._network.module.backbone.update_and_reset_lora_parameters()
        self._network.module.backbone = self._network.module.backbone.cuda()

    def _refresh_hard_buffer(self, candidates):
        """Keep the ``hard_buffer_size`` highest-loss ``(loss, x, y)`` candidates."""
        scored = [(loss_t.detach().cpu().numpy(), x_c, y_c) for loss_t, x_c, y_c in candidates]
        ranked = list(reversed(sorted(scored, key=lambda c: c[0])))
        self.hard_buffer = [{'state': x_c, 'trgt': y_c} for _, x_c, y_c in ranked[:self.hard_buffer_size]]

    def _init_train(self, train_loader, test_loader):
        """Run ``args['tuned_epoch']`` passes over ``train_loader``.

        The per-step objective is described in the class docstring. Test
        accuracy is computed every 5th epoch. In blurry mode the test loader is
        rebuilt each epoch so that it only covers exposed classes.
        """
        criterion = torch.nn.CrossEntropyLoss().cuda()
        prog_bar = tqdm(range(self.args['tuned_epoch']))
        info = None
        for epoch in prog_bar:
            # Re-enter train mode every epoch. _compute_accuracy switches the
            # network to eval mode and previously left it there.
            self._network.train()
            losses = 0.0
            correct, total = 0, 0

            for _, images, labels in train_loader:
                self.samples_cnt += images.size(0)
                self.add_new_class(labels)

                x = images.cuda()
                y = self._to_exposed_indices(labels).cuda()

                # Snapshot of the hard buffer used during this step.
                hard_inputs = [entry['state'] for entry in self.hard_buffer]
                hard_targets = [entry['trgt'] for entry in self.hard_buffer]

                # Current-batch loss on masked logits.
                logits = self._network(x)["logits"] + self.mask
                current_loss = criterion(logits, y)

                # Hard-buffer losses. NOTE(review): these logits are not masked,
                # unlike the current batch. Confirm intent.
                hard_losses = [
                    criterion(self._network(x_h)['logits'], y_h).mean()
                    for x_h, y_h in zip(hard_inputs, hard_targets)
                ]

                total_loss = current_loss
                for h_loss in hard_losses:
                    total_loss = total_loss + h_loss
                # Plateau detection uses the loss *before* the MAS penalty.
                first_train_loss = total_loss.detach().cpu().numpy()

                # omega_As and omega_Bs must have the same length.
                if len(self.omega_As) != 0 and len(self.omega_As) == len(self.omega_Bs):
                    mas_loss = self._mas_penalty(*self._lora_factor_params())
                    logging.debug('MAS loss: {}'.format(mas_loss))
                    total_loss = total_loss + self.MAS_weight / 2. * mas_loss

                lr = self.optim.param_groups[0]['lr']
                wandb.log({'lr': lr, 'step': self.step})

                self.optim.zero_grad()
                loss = total_loss
                loss.backward()
                self.optim.step()

                loss_value = loss.item()
                losses += loss_value
                if np.isnan(loss_value):
                    print("LOSS IS NAN")
                    # Non-zero status so that schedulers and sweeps see the
                    # divergence. The original exited with 0 (success).
                    sys.exit(1)

                wandb.log({
                    'loss': loss_value,
                    'step': self.step
                    })

                if self._update_loss_window(first_train_loss):
                    logging.info("NEW TASK")
                    self._refresh_importance_weights()
                self.loss_window_means.append(self.loss_window_mean)
                self.loss_window_variances.append(self.loss_window_variance)

                # Hard-buffer candidates: the previous buffer followed by the
                # current batch, each paired with its *own* loss. The original
                # zipped losses ordered current-first against inputs ordered
                # buffer-first.
                candidates = list(zip(hard_losses, hard_inputs, hard_targets))
                candidates.append((current_loss, x, y))
                self._refresh_hard_buffer(candidates)

                self.step += 1

                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(y).cpu().sum()
                total += len(y)

            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2) if total else 0.0
            n_batches = max(len(train_loader), 1)

            if self.args['blurry']:
                test_sampler = OnlineTestSampler(self.test_dataset, self.exposed_classes)
                test_loader = DataLoader(self.test_dataset, batch_size=self.batch_size, sampler=test_sampler, num_workers=num_workers)
                self.test_loader = test_loader

            if (epoch + 1) % 5 == 0:
                test_acc = self._compute_accuracy(self._network, test_loader)
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    self.args['tuned_epoch'],
                    losses / n_batches,
                    train_acc,
                    test_acc,
                )
            else:
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    self.args['tuned_epoch'],
                    losses / n_batches,
                    train_acc,
                )
            prog_bar.set_description(info)

        # `info` is None only when tuned_epoch == 0. That case used to raise
        # UnboundLocalError here.
        if info is not None:
            logging.info(info)

    def _eval_cnn(self, loader):
        """Return ``(top-k predictions [N, topk], targets [N])`` over ``loader``.

        Targets are remapped to exposed-class positions, and logits are masked
        the same way as in training.
        """
        self._network.eval()
        y_pred, y_true = [], []
        for _, inputs, targets in loader:
            inputs = inputs.to(self._device)
            targets = self._to_exposed_indices(targets)
            with torch.no_grad():
                logit = self._network(inputs)['logits'] + self.mask

            predicts = torch.topk(logit, k=self.topk, dim=1, largest=True, sorted=True)[1]  # [bs, topk]
            y_pred.append(predicts.cpu().numpy())
            y_true.append(targets.cpu().numpy())

        return np.concatenate(y_pred), np.concatenate(y_true)  # [N, topk]

    def _compute_accuracy(self, model, loader):
        """Return top-1 accuracy (percent, rounded to 2 decimals) of ``model`` on ``loader``.

        Leaves ``model`` in eval mode. Returns 0.0 for an empty loader.
        """
        model.eval()
        correct, total = 0, 0
        for _, inputs, targets in loader:
            inputs = inputs.cuda()
            targets = self._to_exposed_indices(targets)
            with torch.no_grad():
                outputs = model(inputs, task_id=self._cur_task)["logits"][:, :self._total_classes]
                # Mask unexposed slots as in training and _eval_cnn. Targets are
                # exposed-class positions, so those slots can never be correct.
                outputs = outputs + self.mask[:outputs.shape[1]]
            predicts = torch.max(outputs, dim=1)[1]
            correct += (predicts.cpu() == targets).sum()
            total += len(targets)

        if total == 0:
            return 0.0
        return np.around(tensor2numpy(correct) * 100 / total, decimals=2)

    def add_new_class(self, class_name):
        """Append unseen labels from ``class_name`` to ``exposed_classes`` and open their mask slots."""
        for label in class_name:
            if label.item() not in self.exposed_classes:
                self.exposed_classes.append(label.item())
        self.mask[:len(self.exposed_classes)] = 0


class Memory:
    """Replay-memory bookkeeping: dataset indices, labels and per-class stats.

    ``memory`` holds the stored dataset indices and ``labels`` their labels
    (1-D tensors extended by :meth:`replace_data`). ``cls_list`` is the known
    class list, and ``cls_count`` / ``cls_train_cnt`` are counters parallel to
    it. When ``data_source`` is given, the corresponding images are cached in
    ``images`` so that :meth:`sample` can return them.
    """

    def __init__(self, data_source=None) -> None:

        self.data_source = data_source
        if self.data_source is not None:
            self.images = []

        self.memory = torch.empty(0)
        self.labels = torch.empty(0)
        self.cls_list = torch.empty(0)
        self.cls_count = torch.empty(0)
        self.cls_train_cnt = torch.empty(0)
        self.previous_idx = torch.empty(0)
        self.others_loss_decrease = torch.empty(0)
        # Scores written by update_gss_score. This was never initialized
        # before, so the first call raised AttributeError.
        self.score = []

    def add_new_class(self, cls_list) -> None:
        """Replace the class list and zero-extend the per-class counters to its length."""
        self.cls_list = torch.tensor(cls_list)
        self.cls_count = torch.cat([self.cls_count, torch.zeros(len(self.cls_list) - len(self.cls_count))])
        self.cls_train_cnt = torch.cat([self.cls_train_cnt, torch.zeros(len(self.cls_list) - len(self.cls_train_cnt))])

    def replace_data(self, data, idx: int=None) -> None:
        """Store ``data = (index, label)``: append if ``idx`` is None, else overwrite slot ``idx``.

        If ``data_source`` is set, ``(image, label)`` is re-read from
        ``data_source[index]`` and that label is used instead. Per-class counts
        are updated. The slot's ``others_loss_decrease`` is initialized from
        the other slots of the same class. When the slot is the only one of its
        class, it gets 0 on append or the global mean on overwrite.
        """
        index, label = data
        if self.data_source is not None:
            image, label = self.data_source.__getitem__(index)
        if idx is None:
            if self.data_source is not None:
                self.images.append(image.unsqueeze(0))
            self.memory = torch.cat([self.memory, torch.tensor([index])])
            self.labels = torch.cat([self.labels, torch.tensor([label])])
            self.cls_count[(self.cls_list == label).nonzero().squeeze()] += 1
            if self.cls_count[(self.cls_list == label).nonzero().squeeze()] == 1:
                self.others_loss_decrease = torch.cat([self.others_loss_decrease, torch.tensor([0])])
            else:
                indice = (self.labels == label).nonzero().squeeze()
                self.others_loss_decrease = torch.cat([self.others_loss_decrease, torch.mean(self.others_loss_decrease[indice[:-1]]).unsqueeze(0)])
        else:
            if self.data_source is not None:
                self.images[idx] = image.unsqueeze(0)
            _label = self.labels[idx]
            self.cls_count[(self.cls_list == _label).nonzero().squeeze()] -= 1
            self.memory[idx] = index
            self.labels[idx] = label
            self.cls_count[(self.cls_list == label).nonzero().squeeze()] += 1
            if self.cls_count[(self.cls_list == label).nonzero().squeeze()] == 1:
                self.others_loss_decrease[idx] = torch.mean(self.others_loss_decrease)
            else:
                indice = (self.labels == label).nonzero().squeeze()
                self.others_loss_decrease[idx] = torch.mean(self.others_loss_decrease[indice[indice != idx]])

    def update_loss_history(self, loss, prev_loss, ema_ratio=0.90, dropped_idx=None) -> None:
        """Adjust ``others_loss_decrease`` at ``previous_idx`` by the observed loss change.

        The mean difference ``loss - prev_loss`` is computed, excluding
        ``dropped_idx`` positions when given. The stored values at
        ``previous_idx`` are moved by ``(1 - ema_ratio)`` times the gap, and
        ``previous_idx`` is then cleared.
        NOTE(review): the subtracted term is a mean further divided by
        ``len(previous_idx)``. Kept as-is; confirm against the reference
        implementation.
        """
        if dropped_idx is None:
            loss_diff = torch.mean(loss - prev_loss)
        elif len(prev_loss) > 0:
            mask = torch.ones(len(loss), dtype=bool)
            mask[torch.tensor(dropped_idx, dtype=torch.int64).squeeze()] = False
            loss_diff = torch.mean((loss[:len(prev_loss)] - prev_loss)[mask[:len(prev_loss)]])
        else:
            loss_diff = 0
        difference = loss_diff - torch.mean(self.others_loss_decrease[self.previous_idx.to(torch.int64)]) / len(self.previous_idx)
        self.others_loss_decrease[self.previous_idx.to(torch.int64)] -= (1 - ema_ratio) * difference
        self.previous_idx = torch.empty(0)

    def get_weight(self):
        """Return inverse class-frequency weights, one per stored sample.

        Each sample gets ``1 / (number of stored samples of its class)``.
        Classes in ``cls_list`` with no stored samples are skipped. They
        previously raised ZeroDivisionError.
        """
        # Size from `labels`. `images` is a plain list in Memory (no `.size`)
        # and does not exist at all without a data_source.
        weight = torch.zeros(len(self.labels))
        for cls in self.cls_list:
            members = (self.labels == cls).nonzero()
            count = members.numel()
            if count == 0:
                continue
            weight[members.squeeze()] = 1 / count
        return weight

    def update_gss_score(self, score: int, idx: int=None) -> None:
        """Append ``score`` (``idx`` is None) or overwrite ``self.score[idx]``."""
        if idx is None:
            self.score.append(score)
        else:
            self.score[idx] = score

    def __len__(self) -> int:
        return len(self.labels)

    def sample(self, memory_batchsize):
        """Return ``(images, labels)`` for up to ``memory_batchsize`` random stored slots.

        Requires a ``data_source`` (images are only cached in that case).
        """
        assert self.data_source is not None
        idx = torch.randperm(len(self.images), dtype=torch.int64)[:memory_batchsize]
        images = []
        labels = []
        for i in idx:
            images.append(self.images[i])
            labels.append(self.labels[i])
        return torch.cat(images), torch.tensor(labels)

class DummyMemory(Memory):
    """Memory pre-filled with random images and labels in ``[0, 10)``.

    Images are uniform random tensors of ``shape``, ``datasize`` of them.
    Per-class counters are zero-initialized for the classes that were drawn.
    """

    def __init__(self, data_source, shape=(3,32,32), datasize: int=100) -> None:
        super().__init__(data_source)
        self.shape, self.datasize = shape, datasize

        # Random contents: images first, then labels (RNG draw order matters).
        self.images = torch.rand(datasize, *shape)
        self.labels = torch.randint(0, 10, (datasize,))

        self.cls_list = torch.unique(self.labels)
        num_cls = len(self.cls_list)
        self.cls_count = torch.zeros(num_cls)
        self.cls_train_cnt = torch.zeros(num_cls)
        self.others_loss_decrease = torch.zeros(datasize)


class MemoryBatchSampler(torch.utils.data.Sampler):
    """Yield dataset indices of randomly drawn memory slots.

    For each of ``iterations`` rounds, a random permutation of the memory
    slots is truncated to ``min(batch_size, len(memory))``. The selected slots
    are translated to dataset indices through ``memory.memory``.
    """

    def __init__(self, memory: Memory, batch_size: int, iterations: int = 1) -> None:
        self.memory = memory
        self.batch_size = batch_size
        self.iterations = int(iterations)

        per_round = min(self.batch_size, len(self.memory))
        rounds = [
            torch.randperm(len(self.memory), dtype=torch.int64)[:per_round]
            for _ in range(self.iterations)
        ]
        slots = torch.cat(rounds).tolist()
        self.indices = [int(self.memory.memory[slot]) for slot in slots]

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)
    

class BatchSampler(torch.utils.data.Sampler):
    """Yield ``iterations`` random draws from ``samples_idx``, concatenated.

    Each draw is a random permutation of positions truncated to
    ``min(batch_size, len(samples_idx))``. Despite the ``int`` annotation,
    ``samples_idx`` must support ``len()`` and integer indexing.
    """

    def __init__(self, samples_idx: int, batch_size: int, iterations: int = 1) -> None:
        self.samples_idx = samples_idx
        self.batch_size = batch_size
        self.iterations = int(iterations)

        pool = len(self.samples_idx)
        per_round = min(self.batch_size, pool)
        positions = torch.cat(
            [torch.randperm(pool, dtype=torch.int64)[:per_round] for _ in range(self.iterations)]
        ).tolist()
        self.indices = [int(self.samples_idx[pos]) for pos in positions]

    def __iter__(self):
        return iter(self.indices)

    def __len__(self) -> int:
        return len(self.indices)

class MemoryOrderedSampler(torch.utils.data.Sampler):
    """Yield the dataset index of every memory slot, in slot order.

    The full pass is repeated ``iterations`` times. When ``torch.distributed``
    is initialized, each rank only sees the strided shard
    ``indices[rank::world_size]``. ``batch_size`` is stored but not used here.
    """

    def __init__(self, memory: "Memory", batch_size: int, iterations: int = 1) -> None:
        self.memory = memory
        self.batch_size = batch_size
        self.iterations = int(iterations)
        self.indices = torch.cat([torch.arange(len(self.memory), dtype=torch.int64) for _ in range(self.iterations)]).tolist()
        for i, idx in enumerate(self.indices):
            self.indices[i] = int(self.memory.memory[idx])

    @staticmethod
    def _shard():
        """Return ``(rank, world_size)`` if torch.distributed is initialized, else None."""
        # `dist` was never imported in this module, so the original raised
        # NameError on first use. Import locally, and also guard builds that
        # lack distributed support.
        import torch.distributed as dist
        if dist.is_available() and dist.is_initialized():
            return dist.get_rank(), dist.get_world_size()
        return None

    def _local_indices(self):
        """Return this process's share of ``self.indices``."""
        shard = self._shard()
        if shard is None:
            return self.indices
        rank, world_size = shard
        return self.indices[rank::world_size]

    def __iter__(self):
        return iter(self._local_indices())

    def __len__(self) -> int:
        return len(self._local_indices())